Normalize OpenAI base URLs
This commit is contained in:
@@ -22,6 +22,14 @@ def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_url(base_url: str, path: str) -> str:
|
||||||
|
root = base_url.rstrip("/")
|
||||||
|
normalized_path = "/" + path.lstrip("/")
|
||||||
|
if root.endswith("/v1") and normalized_path.startswith("/v1/"):
|
||||||
|
normalized_path = normalized_path[len("/v1") :]
|
||||||
|
return root + normalized_path
|
||||||
|
|
||||||
|
|
||||||
def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
|
def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
|
||||||
deadline = time.monotonic() + timeout_s
|
deadline = time.monotonic() + timeout_s
|
||||||
url = f"{base_url.rstrip('/')}{path}"
|
url = f"{base_url.rstrip('/')}{path}"
|
||||||
@@ -52,7 +60,7 @@ def chat_completion(
|
|||||||
payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
|
payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
|
||||||
data = json.dumps(payload).encode("utf-8")
|
data = json.dumps(payload).encode("utf-8")
|
||||||
request = urllib.request.Request(
|
request = urllib.request.Request(
|
||||||
url=f"{base_url.rstrip('/')}/v1/chat/completions",
|
url=_openai_url(base_url, "/v1/chat/completions"),
|
||||||
headers=_auth_headers(api_key_env),
|
headers=_auth_headers(api_key_env),
|
||||||
data=data,
|
data=data,
|
||||||
method="POST",
|
method="POST",
|
||||||
@@ -80,7 +88,7 @@ def stream_chat_completion(
|
|||||||
) -> StreamMetrics:
|
) -> StreamMetrics:
|
||||||
data = json.dumps(body).encode("utf-8")
|
data = json.dumps(body).encode("utf-8")
|
||||||
request = urllib.request.Request(
|
request = urllib.request.Request(
|
||||||
url=f"{base_url.rstrip('/')}/v1/chat/completions",
|
url=_openai_url(base_url, "/v1/chat/completions"),
|
||||||
headers=_auth_headers(None),
|
headers=_auth_headers(None),
|
||||||
data=data,
|
data=data,
|
||||||
method="POST",
|
method="POST",
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from aituner.cli import main as cli_main
|
from aituner.cli import main as cli_main
|
||||||
|
from aituner.http_client import _openai_url
|
||||||
from aituner.job import append_job, build_trial_job
|
from aituner.job import append_job, build_trial_job
|
||||||
from aituner.llm import build_prompt, parse_proposal_text
|
from aituner.llm import build_prompt, parse_proposal_text
|
||||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||||
@@ -717,6 +718,16 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(len(replayed), 3)
|
self.assertEqual(len(replayed), 3)
|
||||||
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||||
|
|
||||||
|
def test_openai_url_avoids_double_v1(self) -> None:
    """URL joining must not emit a duplicated /v1 segment."""
    expected = "http://example.com/v1/chat/completions"
    # Same endpoint path against a base with and without a trailing /v1
    # must resolve to the identical URL.
    for base in ("http://example.com", "http://example.com/v1"):
        self.assertEqual(
            _openai_url(base, "/v1/chat/completions"),
            expected,
        )
|
||||||
|
|
||||||
|
|
||||||
# Standard script entry point: run the module's unittest suite when
# executed directly.  (The diff residue had duplicated these lines;
# collapsed back to the single canonical guard.)
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user