Support codex reasoning effort override

This commit is contained in:
2026-04-09 00:57:33 +08:00
parent 0990a3771e
commit 46151512cd
4 changed files with 23 additions and 7 deletions

View File

@@ -152,6 +152,7 @@ def chat_completion(
     messages: list[dict[str, Any]],
     timeout_s: float,
     system_prompt: str = "",
+    reasoning_effort: str | None = None,
 ) -> dict[str, Any]:
     if wire_api == "responses":
         payload = {"model": model, "input": messages}
@@ -162,6 +163,8 @@ def chat_completion(
         payload = {"model": model, "messages": messages}
         if system_prompt:
             payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
+        if reasoning_effort:
+            payload["reasoning_effort"] = reasoning_effort
         path = "/v1/chat/completions"
     data = json.dumps(payload).encode("utf-8")
     request = urllib.request.Request(

View File

@@ -238,5 +238,6 @@ def call_llm_for_proposal(
         messages=[{"role": "user", "content": prompt}],
         timeout_s=policy.endpoint.timeout_s,
         system_prompt=policy.system_prompt,
+        reasoning_effort=policy.endpoint.reasoning_effort,
     )
     return _extract_response_text(response)

View File

@@ -57,10 +57,10 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]:
     return result


-def _resolve_codex_endpoint() -> tuple[str, str]:
+def _resolve_codex_endpoint() -> tuple[str, str, str | None]:
     override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
     if override:
-        return override, "chat.completions"
+        return override, "chat.completions", None
     config_path = Path.home() / ".codex" / "config.toml"
     try:
        payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
@@ -82,18 +82,21 @@ def _resolve_codex_endpoint() -> tuple[str, str]:
             base_url = str(selected.get("base_url") or "").strip()
             if base_url:
                 wire_api = str(selected.get("wire_api") or "chat.completions").strip()
-                return base_url, wire_api
+                reasoning_effort = str(payload.get("model_reasoning_effort") or "").strip()
+                return base_url, wire_api, reasoning_effort or None
     if len(providers) == 1:
         only_provider = next(iter(providers.values()))
         if isinstance(only_provider, Mapping):
             base_url = str(only_provider.get("base_url") or "").strip()
             if base_url:
                 wire_api = str(only_provider.get("wire_api") or "chat.completions").strip()
-                return base_url, wire_api
+                reasoning_effort = str(payload.get("model_reasoning_effort") or "").strip()
+                return base_url, wire_api, reasoning_effort or None
     root_base_url = str(payload.get("base_url") or "").strip()
     if root_base_url:
         wire_api = str(payload.get("wire_api") or "chat.completions").strip()
-        return root_base_url, wire_api
+        reasoning_effort = str(payload.get("model_reasoning_effort") or "").strip()
+        return root_base_url, wire_api, reasoning_effort or None
     raise SpecError(
         "Unable to resolve llm.endpoint.base_url for provider=codex from "
         f"{config_path}. Set llm.endpoint.base_url explicitly or set "
@@ -390,6 +393,7 @@ class LLMEndpointSpec:
     model: str
     provider: str = "custom"
     wire_api: str = "chat.completions"
+    reasoning_effort: str | None = None
     api_key_env: str = "OPENAI_API_KEY"
     timeout_s: float = 120.0
@@ -398,14 +402,19 @@ class LLMEndpointSpec:
         provider = str(data.get("provider") or "custom").strip().lower()
         base_url = str(data.get("base_url") or "").strip()
         wire_api = str(data.get("wire_api") or "").strip()
+        reasoning_effort = str(data.get("reasoning_effort") or "").strip()
         api_key_env = str(data.get("api_key_env") or "").strip()
         if provider == "codex":
-            if not base_url or not wire_api:
-                resolved_base_url, resolved_wire_api = _resolve_codex_endpoint()
+            if not base_url or not wire_api or not reasoning_effort:
+                resolved_base_url, resolved_wire_api, resolved_reasoning_effort = (
+                    _resolve_codex_endpoint()
+                )
                 if not base_url:
                     base_url = resolved_base_url
                 if not wire_api:
                     wire_api = resolved_wire_api
+                if not reasoning_effort and resolved_reasoning_effort:
+                    reasoning_effort = resolved_reasoning_effort
             if not api_key_env:
                 api_key_env = "OPENAI_API_KEY"
         elif provider == "bailian":
@@ -429,6 +438,7 @@ class LLMEndpointSpec:
             model=_require_str(data.get("model"), context="llm.endpoint.model"),
             provider=provider,
             wire_api=_require_str(wire_api, context="llm.endpoint.wire_api"),
+            reasoning_effort=reasoning_effort or None,
             api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
             timeout_s=_require_float(
                 data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"

View File

@@ -239,6 +239,7 @@ class CoreFlowTests(unittest.TestCase):
             '\n'.join(
                 [
                     'model_provider = "ipads"',
+                    'model_reasoning_effort = "high"',
                     "",
                     "[model_providers.ipads]",
                     'base_url = "http://codex.example/v1"',
@@ -252,6 +253,7 @@ class CoreFlowTests(unittest.TestCase):
         self.assertEqual(endpoint.provider, "codex")
         self.assertEqual(endpoint.base_url, "http://codex.example/v1")
         self.assertEqual(endpoint.wire_api, "responses")
+        self.assertEqual(endpoint.reasoning_effort, "high")
         self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")

     def test_extract_response_text_supports_responses_api_output(self) -> None: