Support codex responses API
This commit is contained in:
@@ -147,17 +147,25 @@ def chat_completion(
|
|||||||
base_url: str,
|
base_url: str,
|
||||||
api_key_env: str | None,
|
api_key_env: str | None,
|
||||||
provider: str = "custom",
|
provider: str = "custom",
|
||||||
|
wire_api: str = "chat.completions",
|
||||||
model: str,
|
model: str,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
timeout_s: float,
|
timeout_s: float,
|
||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
payload: dict[str, Any] = {"model": model, "messages": messages}
|
if wire_api == "responses":
|
||||||
if system_prompt:
|
payload = {"model": model, "input": messages}
|
||||||
payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
|
if system_prompt:
|
||||||
|
payload["instructions"] = system_prompt
|
||||||
|
path = "/v1/responses"
|
||||||
|
else:
|
||||||
|
payload = {"model": model, "messages": messages}
|
||||||
|
if system_prompt:
|
||||||
|
payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
|
||||||
|
path = "/v1/chat/completions"
|
||||||
data = json.dumps(payload).encode("utf-8")
|
data = json.dumps(payload).encode("utf-8")
|
||||||
request = urllib.request.Request(
|
request = urllib.request.Request(
|
||||||
url=_openai_url(base_url, "/v1/chat/completions"),
|
url=_openai_url(base_url, path),
|
||||||
headers=_auth_headers(api_key_env, provider),
|
headers=_auth_headers(api_key_env, provider),
|
||||||
data=data,
|
data=data,
|
||||||
method="POST",
|
method="POST",
|
||||||
@@ -167,7 +175,7 @@ def chat_completion(
|
|||||||
return json.loads(response.read().decode("utf-8"))
|
return json.loads(response.read().decode("utf-8"))
|
||||||
except urllib.error.HTTPError as exc:
|
except urllib.error.HTTPError as exc:
|
||||||
detail = exc.read().decode("utf-8", errors="replace")
|
detail = exc.read().decode("utf-8", errors="replace")
|
||||||
raise HttpClientError(f"chat_completion failed: {exc.code} {detail}") from exc
|
raise HttpClientError(f"llm_completion failed: {exc.code} {detail}") from exc
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|||||||
@@ -183,6 +183,45 @@ def parse_proposal_text(text: str, study: StudySpec) -> Proposal:
|
|||||||
return validate_proposal(proposal, study)
|
return validate_proposal(proposal, study)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_response_text(response: dict[str, Any]) -> str:
|
||||||
|
output_text = response.get("output_text")
|
||||||
|
if isinstance(output_text, str) and output_text:
|
||||||
|
return output_text
|
||||||
|
choices = response.get("choices")
|
||||||
|
if isinstance(choices, list) and choices:
|
||||||
|
message = choices[0].get("message", {})
|
||||||
|
if isinstance(message, dict):
|
||||||
|
content = message.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
text = "".join(
|
||||||
|
item.get("text", "")
|
||||||
|
for item in content
|
||||||
|
if isinstance(item, dict) and isinstance(item.get("text"), str)
|
||||||
|
)
|
||||||
|
if text:
|
||||||
|
return text
|
||||||
|
output = response.get("output")
|
||||||
|
if isinstance(output, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in output:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
content = item.get("content")
|
||||||
|
if not isinstance(content, list):
|
||||||
|
continue
|
||||||
|
for block in content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
continue
|
||||||
|
text = block.get("text")
|
||||||
|
if isinstance(text, str) and text:
|
||||||
|
parts.append(text)
|
||||||
|
if parts:
|
||||||
|
return "".join(parts)
|
||||||
|
raise RuntimeError("LLM response content is empty")
|
||||||
|
|
||||||
|
|
||||||
def call_llm_for_proposal(
|
def call_llm_for_proposal(
|
||||||
*,
|
*,
|
||||||
policy: LLMPolicySpec,
|
policy: LLMPolicySpec,
|
||||||
@@ -194,24 +233,10 @@ def call_llm_for_proposal(
|
|||||||
base_url=policy.endpoint.base_url,
|
base_url=policy.endpoint.base_url,
|
||||||
api_key_env=policy.endpoint.api_key_env,
|
api_key_env=policy.endpoint.api_key_env,
|
||||||
provider=policy.endpoint.provider,
|
provider=policy.endpoint.provider,
|
||||||
|
wire_api=policy.endpoint.wire_api,
|
||||||
model=policy.endpoint.model,
|
model=policy.endpoint.model,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
timeout_s=policy.endpoint.timeout_s,
|
timeout_s=policy.endpoint.timeout_s,
|
||||||
system_prompt=policy.system_prompt,
|
system_prompt=policy.system_prompt,
|
||||||
)
|
)
|
||||||
choices = response.get("choices")
|
return _extract_response_text(response)
|
||||||
if not isinstance(choices, list) or not choices:
|
|
||||||
raise RuntimeError("LLM response does not contain choices")
|
|
||||||
message = choices[0].get("message", {})
|
|
||||||
if not isinstance(message, dict):
|
|
||||||
raise RuntimeError("LLM response does not contain a valid message")
|
|
||||||
content = message.get("content")
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
return "".join(
|
|
||||||
item.get("text", "")
|
|
||||||
for item in content
|
|
||||||
if isinstance(item, dict) and isinstance(item.get("text"), str)
|
|
||||||
)
|
|
||||||
raise RuntimeError("LLM response content is empty")
|
|
||||||
|
|||||||
@@ -57,10 +57,10 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _resolve_codex_base_url() -> str:
|
def _resolve_codex_endpoint() -> tuple[str, str]:
|
||||||
override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
|
override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
|
||||||
if override:
|
if override:
|
||||||
return override
|
return override, "chat.completions"
|
||||||
config_path = Path.home() / ".codex" / "config.toml"
|
config_path = Path.home() / ".codex" / "config.toml"
|
||||||
try:
|
try:
|
||||||
payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
||||||
@@ -81,16 +81,19 @@ def _resolve_codex_base_url() -> str:
|
|||||||
if isinstance(selected, Mapping):
|
if isinstance(selected, Mapping):
|
||||||
base_url = str(selected.get("base_url") or "").strip()
|
base_url = str(selected.get("base_url") or "").strip()
|
||||||
if base_url:
|
if base_url:
|
||||||
return base_url
|
wire_api = str(selected.get("wire_api") or "chat.completions").strip()
|
||||||
|
return base_url, wire_api
|
||||||
if len(providers) == 1:
|
if len(providers) == 1:
|
||||||
only_provider = next(iter(providers.values()))
|
only_provider = next(iter(providers.values()))
|
||||||
if isinstance(only_provider, Mapping):
|
if isinstance(only_provider, Mapping):
|
||||||
base_url = str(only_provider.get("base_url") or "").strip()
|
base_url = str(only_provider.get("base_url") or "").strip()
|
||||||
if base_url:
|
if base_url:
|
||||||
return base_url
|
wire_api = str(only_provider.get("wire_api") or "chat.completions").strip()
|
||||||
|
return base_url, wire_api
|
||||||
root_base_url = str(payload.get("base_url") or "").strip()
|
root_base_url = str(payload.get("base_url") or "").strip()
|
||||||
if root_base_url:
|
if root_base_url:
|
||||||
return root_base_url
|
wire_api = str(payload.get("wire_api") or "chat.completions").strip()
|
||||||
|
return root_base_url, wire_api
|
||||||
raise SpecError(
|
raise SpecError(
|
||||||
"Unable to resolve llm.endpoint.base_url for provider=codex from "
|
"Unable to resolve llm.endpoint.base_url for provider=codex from "
|
||||||
f"{config_path}. Set llm.endpoint.base_url explicitly or set "
|
f"{config_path}. Set llm.endpoint.base_url explicitly or set "
|
||||||
@@ -386,6 +389,7 @@ class LLMEndpointSpec:
|
|||||||
base_url: str
|
base_url: str
|
||||||
model: str
|
model: str
|
||||||
provider: str = "custom"
|
provider: str = "custom"
|
||||||
|
wire_api: str = "chat.completions"
|
||||||
api_key_env: str = "OPENAI_API_KEY"
|
api_key_env: str = "OPENAI_API_KEY"
|
||||||
timeout_s: float = 120.0
|
timeout_s: float = 120.0
|
||||||
|
|
||||||
@@ -393,20 +397,29 @@ class LLMEndpointSpec:
|
|||||||
def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
|
def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
|
||||||
provider = str(data.get("provider") or "custom").strip().lower()
|
provider = str(data.get("provider") or "custom").strip().lower()
|
||||||
base_url = str(data.get("base_url") or "").strip()
|
base_url = str(data.get("base_url") or "").strip()
|
||||||
|
wire_api = str(data.get("wire_api") or "").strip()
|
||||||
api_key_env = str(data.get("api_key_env") or "").strip()
|
api_key_env = str(data.get("api_key_env") or "").strip()
|
||||||
if provider == "codex":
|
if provider == "codex":
|
||||||
if not base_url:
|
if not base_url or not wire_api:
|
||||||
base_url = _resolve_codex_base_url()
|
resolved_base_url, resolved_wire_api = _resolve_codex_endpoint()
|
||||||
|
if not base_url:
|
||||||
|
base_url = resolved_base_url
|
||||||
|
if not wire_api:
|
||||||
|
wire_api = resolved_wire_api
|
||||||
if not api_key_env:
|
if not api_key_env:
|
||||||
api_key_env = "OPENAI_API_KEY"
|
api_key_env = "OPENAI_API_KEY"
|
||||||
elif provider == "bailian":
|
elif provider == "bailian":
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
if not wire_api:
|
||||||
|
wire_api = "chat.completions"
|
||||||
if not api_key_env:
|
if not api_key_env:
|
||||||
api_key_env = "DASHSCOPE_API_KEY"
|
api_key_env = "DASHSCOPE_API_KEY"
|
||||||
elif provider == "custom":
|
elif provider == "custom":
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise SpecError("llm.endpoint.base_url must be a non-empty string.")
|
raise SpecError("llm.endpoint.base_url must be a non-empty string.")
|
||||||
|
if not wire_api:
|
||||||
|
wire_api = "chat.completions"
|
||||||
if not api_key_env:
|
if not api_key_env:
|
||||||
api_key_env = "OPENAI_API_KEY"
|
api_key_env = "OPENAI_API_KEY"
|
||||||
else:
|
else:
|
||||||
@@ -415,6 +428,7 @@ class LLMEndpointSpec:
|
|||||||
base_url=_require_str(base_url, context="llm.endpoint.base_url"),
|
base_url=_require_str(base_url, context="llm.endpoint.base_url"),
|
||||||
model=_require_str(data.get("model"), context="llm.endpoint.model"),
|
model=_require_str(data.get("model"), context="llm.endpoint.model"),
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
wire_api=_require_str(wire_api, context="llm.endpoint.wire_api"),
|
||||||
api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
|
api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
|
||||||
timeout_s=_require_float(
|
timeout_s=_require_float(
|
||||||
data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
|
data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from unittest import mock
|
|||||||
from aituner.cli import main as cli_main
|
from aituner.cli import main as cli_main
|
||||||
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
|
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
|
||||||
from aituner.job import append_job, build_trial_job
|
from aituner.job import append_job, build_trial_job
|
||||||
from aituner.llm import build_prompt, parse_proposal_text
|
from aituner.llm import _extract_response_text, build_prompt, parse_proposal_text
|
||||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||||
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||||
from aituner.spec import (
|
from aituner.spec import (
|
||||||
@@ -242,6 +242,7 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
"",
|
"",
|
||||||
"[model_providers.ipads]",
|
"[model_providers.ipads]",
|
||||||
'base_url = "http://codex.example/v1"',
|
'base_url = "http://codex.example/v1"',
|
||||||
|
'wire_api = "responses"',
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
@@ -250,8 +251,24 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
|
endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
|
||||||
self.assertEqual(endpoint.provider, "codex")
|
self.assertEqual(endpoint.provider, "codex")
|
||||||
self.assertEqual(endpoint.base_url, "http://codex.example/v1")
|
self.assertEqual(endpoint.base_url, "http://codex.example/v1")
|
||||||
|
self.assertEqual(endpoint.wire_api, "responses")
|
||||||
self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
|
self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
|
||||||
|
|
||||||
|
def test_extract_response_text_supports_responses_api_output(self) -> None:
|
||||||
|
text = _extract_response_text(
|
||||||
|
{
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"content": [
|
||||||
|
{"type": "output_text", "text": '{"diagnosis":"ok"}'}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(text, '{"diagnosis":"ok"}')
|
||||||
|
|
||||||
def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
|
def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
|
|||||||
Reference in New Issue
Block a user