diff --git a/src/aituner/http_client.py b/src/aituner/http_client.py
index 495dfd3..d94e433 100644
--- a/src/aituner/http_client.py
+++ b/src/aituner/http_client.py
@@ -147,17 +147,25 @@ def chat_completion(
     base_url: str,
     api_key_env: str | None,
     provider: str = "custom",
+    wire_api: str = "chat.completions",
     model: str,
     messages: list[dict[str, Any]],
     timeout_s: float,
     system_prompt: str = "",
 ) -> dict[str, Any]:
-    payload: dict[str, Any] = {"model": model, "messages": messages}
-    if system_prompt:
-        payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
+    if wire_api == "responses":
+        payload = {"model": model, "input": messages}
+        if system_prompt:
+            payload["instructions"] = system_prompt
+        path = "/v1/responses"
+    else:
+        payload = {"model": model, "messages": messages}
+        if system_prompt:
+            payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
+        path = "/v1/chat/completions"
     data = json.dumps(payload).encode("utf-8")
     request = urllib.request.Request(
-        url=_openai_url(base_url, "/v1/chat/completions"),
+        url=_openai_url(base_url, path),
         headers=_auth_headers(api_key_env, provider),
         data=data,
         method="POST",
@@ -167,7 +175,7 @@ def chat_completion(
             return json.loads(response.read().decode("utf-8"))
     except urllib.error.HTTPError as exc:
         detail = exc.read().decode("utf-8", errors="replace")
-        raise HttpClientError(f"chat_completion failed: {exc.code} {detail}") from exc
+        raise HttpClientError(f"llm_completion failed: {exc.code} {detail}") from exc
 
 
 @dataclass(frozen=True)
diff --git a/src/aituner/llm.py b/src/aituner/llm.py
index f458729..23ebdec 100644
--- a/src/aituner/llm.py
+++ b/src/aituner/llm.py
@@ -183,6 +183,45 @@ def parse_proposal_text(text: str, study: StudySpec) -> Proposal:
     return validate_proposal(proposal, study)
 
 
+def _extract_response_text(response: dict[str, Any]) -> str:
+    output_text = response.get("output_text")
+    if isinstance(output_text, str) and output_text:
+        return output_text
+    choices = response.get("choices")
+    if isinstance(choices, list) and choices:
+        message = choices[0].get("message", {})
+        if isinstance(message, dict):
+            content = message.get("content")
+            if isinstance(content, str):
+                return content
+            if isinstance(content, list):
+                text = "".join(
+                    item.get("text", "")
+                    for item in content
+                    if isinstance(item, dict) and isinstance(item.get("text"), str)
+                )
+                if text:
+                    return text
+    output = response.get("output")
+    if isinstance(output, list):
+        parts: list[str] = []
+        for item in output:
+            if not isinstance(item, dict):
+                continue
+            content = item.get("content")
+            if not isinstance(content, list):
+                continue
+            for block in content:
+                if not isinstance(block, dict):
+                    continue
+                text = block.get("text")
+                if isinstance(text, str) and text:
+                    parts.append(text)
+        if parts:
+            return "".join(parts)
+    raise RuntimeError("LLM response content is empty")
+
+
 def call_llm_for_proposal(
     *,
     policy: LLMPolicySpec,
@@ -194,24 +233,10 @@ def call_llm_for_proposal(
         base_url=policy.endpoint.base_url,
         api_key_env=policy.endpoint.api_key_env,
         provider=policy.endpoint.provider,
+        wire_api=policy.endpoint.wire_api,
         model=policy.endpoint.model,
         messages=[{"role": "user", "content": prompt}],
         timeout_s=policy.endpoint.timeout_s,
         system_prompt=policy.system_prompt,
     )
-    choices = response.get("choices")
-    if not isinstance(choices, list) or not choices:
-        raise RuntimeError("LLM response does not contain choices")
-    message = choices[0].get("message", {})
-    if not isinstance(message, dict):
RuntimeError("LLM response does not contain a valid message") - content = message.get("content") - if isinstance(content, str): - return content - if isinstance(content, list): - return "".join( - item.get("text", "") - for item in content - if isinstance(item, dict) and isinstance(item.get("text"), str) - ) - raise RuntimeError("LLM response content is empty") + return _extract_response_text(response) diff --git a/src/aituner/spec.py b/src/aituner/spec.py index c348ba1..0aacfa7 100644 --- a/src/aituner/spec.py +++ b/src/aituner/spec.py @@ -57,10 +57,10 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]: return result -def _resolve_codex_base_url() -> str: +def _resolve_codex_endpoint() -> tuple[str, str]: override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip() if override: - return override + return override, "chat.completions" config_path = Path.home() / ".codex" / "config.toml" try: payload = tomllib.loads(config_path.read_text(encoding="utf-8")) @@ -81,16 +81,19 @@ def _resolve_codex_base_url() -> str: if isinstance(selected, Mapping): base_url = str(selected.get("base_url") or "").strip() if base_url: - return base_url + wire_api = str(selected.get("wire_api") or "chat.completions").strip() + return base_url, wire_api if len(providers) == 1: only_provider = next(iter(providers.values())) if isinstance(only_provider, Mapping): base_url = str(only_provider.get("base_url") or "").strip() if base_url: - return base_url + wire_api = str(only_provider.get("wire_api") or "chat.completions").strip() + return base_url, wire_api root_base_url = str(payload.get("base_url") or "").strip() if root_base_url: - return root_base_url + wire_api = str(payload.get("wire_api") or "chat.completions").strip() + return root_base_url, wire_api raise SpecError( "Unable to resolve llm.endpoint.base_url for provider=codex from " f"{config_path}. 
         f"{config_path}. Set llm.endpoint.base_url explicitly or set "
@@ -386,6 +389,7 @@ class LLMEndpointSpec:
     base_url: str
     model: str
     provider: str = "custom"
+    wire_api: str = "chat.completions"
    api_key_env: str = "OPENAI_API_KEY"
     timeout_s: float = 120.0
 
@@ -393,20 +397,29 @@
     def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
         provider = str(data.get("provider") or "custom").strip().lower()
         base_url = str(data.get("base_url") or "").strip()
+        wire_api = str(data.get("wire_api") or "").strip()
         api_key_env = str(data.get("api_key_env") or "").strip()
         if provider == "codex":
-            if not base_url:
-                base_url = _resolve_codex_base_url()
+            if not base_url or not wire_api:
+                resolved_base_url, resolved_wire_api = _resolve_codex_endpoint()
+                if not base_url:
+                    base_url = resolved_base_url
+                if not wire_api:
+                    wire_api = resolved_wire_api
             if not api_key_env:
                 api_key_env = "OPENAI_API_KEY"
         elif provider == "bailian":
             if not base_url:
                 base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+            if not wire_api:
+                wire_api = "chat.completions"
             if not api_key_env:
                 api_key_env = "DASHSCOPE_API_KEY"
         elif provider == "custom":
             if not base_url:
                 raise SpecError("llm.endpoint.base_url must be a non-empty string.")
+            if not wire_api:
+                wire_api = "chat.completions"
             if not api_key_env:
                 api_key_env = "OPENAI_API_KEY"
         else:
@@ -415,6 +428,7 @@ class LLMEndpointSpec:
             base_url=_require_str(base_url, context="llm.endpoint.base_url"),
             model=_require_str(data.get("model"), context="llm.endpoint.model"),
             provider=provider,
+            wire_api=_require_str(wire_api, context="llm.endpoint.wire_api"),
             api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
             timeout_s=_require_float(
                 data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index 8c6a392..1b96269 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -11,7 +11,7 @@ from unittest import mock
 from aituner.cli import main as cli_main
 from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
 from aituner.job import append_job, build_trial_job
-from aituner.llm import build_prompt, parse_proposal_text
+from aituner.llm import _extract_response_text, build_prompt, parse_proposal_text
 from aituner.search import ThresholdProbe, binary_search_max_feasible
 from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
 from aituner.spec import (
@@ -242,6 +242,7 @@ class CoreFlowTests(unittest.TestCase):
                     "",
                     "[model_providers.ipads]",
                     'base_url = "http://codex.example/v1"',
+                    'wire_api = "responses"',
                 ]
             ),
             encoding="utf-8",
@@ -250,8 +251,24 @@
         endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
         self.assertEqual(endpoint.provider, "codex")
         self.assertEqual(endpoint.base_url, "http://codex.example/v1")
+        self.assertEqual(endpoint.wire_api, "responses")
         self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
 
+    def test_extract_response_text_supports_responses_api_output(self) -> None:
+        text = _extract_response_text(
+            {
+                "output": [
+                    {
+                        "type": "message",
+                        "content": [
+                            {"type": "output_text", "text": '{"diagnosis":"ok"}'}
+                        ],
+                    }
+                ]
+            }
+        )
+        self.assertEqual(text, '{"diagnosis":"ok"}')
+
     def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)
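
Note for reviewers (not part of the patch): _extract_response_text first checks the
convenience top-level "output_text" field, then the chat.completions shape, then the
Responses shape. A minimal sketch of the two payload shapes it accepts; both payloads
below are made-up examples, not captured API traffic:

    # Hypothetical payloads illustrating the two wire shapes handled above.
    from aituner.llm import _extract_response_text

    # chat.completions: text lives at choices[0].message.content
    chat_response = {"choices": [{"message": {"content": "hello"}}]}
    assert _extract_response_text(chat_response) == "hello"

    # responses: text is gathered from output[*].content[*].text blocks
    responses_response = {
        "output": [
            {"type": "message", "content": [{"type": "output_text", "text": "hello"}]}
        ]
    }
    assert _extract_response_text(responses_response) == "hello"

Precedence for the new field, as implemented in LLMEndpointSpec.from_dict: an explicit
llm.endpoint.wire_api wins; for provider = "codex" a missing value is resolved from the
wire_api key in ~/.codex/config.toml; otherwise it defaults to "chat.completions".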