Normalize OpenAI base URLs
This commit is contained in:
@@ -22,6 +22,14 @@ def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
||||
return headers
|
||||
|
||||
|
||||
def _openai_url(base_url: str, path: str) -> str:
|
||||
root = base_url.rstrip("/")
|
||||
normalized_path = "/" + path.lstrip("/")
|
||||
if root.endswith("/v1") and normalized_path.startswith("/v1/"):
|
||||
normalized_path = normalized_path[len("/v1") :]
|
||||
return root + normalized_path
|
||||
|
||||
|
||||
def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
|
||||
deadline = time.monotonic() + timeout_s
|
||||
url = f"{base_url.rstrip('/')}{path}"
|
||||
@@ -52,7 +60,7 @@ def chat_completion(
|
||||
payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
request = urllib.request.Request(
|
||||
url=f"{base_url.rstrip('/')}/v1/chat/completions",
|
||||
url=_openai_url(base_url, "/v1/chat/completions"),
|
||||
headers=_auth_headers(api_key_env),
|
||||
data=data,
|
||||
method="POST",
|
||||
@@ -80,7 +88,7 @@ def stream_chat_completion(
|
||||
) -> StreamMetrics:
|
||||
data = json.dumps(body).encode("utf-8")
|
||||
request = urllib.request.Request(
|
||||
url=f"{base_url.rstrip('/')}/v1/chat/completions",
|
||||
url=_openai_url(base_url, "/v1/chat/completions"),
|
||||
headers=_auth_headers(None),
|
||||
data=data,
|
||||
method="POST",
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from aituner.cli import main as cli_main
|
||||
from aituner.http_client import _openai_url
|
||||
from aituner.job import append_job, build_trial_job
|
||||
from aituner.llm import build_prompt, parse_proposal_text
|
||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||
@@ -717,6 +718,16 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(len(replayed), 3)
|
||||
self.assertEqual(replayed[1].error, "slo_pass_rate_unrecoverable")
|
||||
|
||||
def test_openai_url_avoids_double_v1(self) -> None:
|
||||
self.assertEqual(
|
||||
_openai_url("http://example.com", "/v1/chat/completions"),
|
||||
"http://example.com/v1/chat/completions",
|
||||
)
|
||||
self.assertEqual(
|
||||
_openai_url("http://example.com/v1", "/v1/chat/completions"),
|
||||
"http://example.com/v1/chat/completions",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user