Add codex and bailian LLM provider presets

2026-04-07 11:31:26 +08:00
parent f73a8a5767
commit 94c89e1103
8 changed files with 236 additions and 11 deletions

.env.example Normal file

@@ -0,0 +1,9 @@
+# Copy to .env when you want to override runtime secrets locally.
+# codex provider:
+OPENAI_API_KEY=
+
+# bailian / DashScope OpenAI-compatible provider:
+DASHSCOPE_API_KEY=
+
+# Optional override if ~/.codex/config.toml does not expose a codex base_url:
+# AITUNER_CODEX_BASE_URL=

.gitignore vendored

@@ -1,6 +1,7 @@
 .aituner/
 .aituner-smoke/
 .aituner-tight/
+.env
 __pycache__/
 *.pyc
 logs/

@@ -142,7 +142,7 @@
     "system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target. Favor launch-safe changes grounded in the incumbent result and only propose knobs that plausibly improve throughput above the incumbent request rate.",
     "max_history_trials": 8,
     "endpoint": {
-      "base_url": "http://tianx.ipads-lab.se.sjtu.edu.cn:8317/v1",
+      "provider": "codex",
       "model": "gpt-5.4",
       "api_key_env": "OPENAI_API_KEY",
       "timeout_s": 180

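With the codex preset the study config stops hard-coding a lab endpoint: base_url is resolved when the spec is loaded. A minimal sketch of the resulting behavior, mirroring the spec changes further down (resolution raises SpecError when nothing is configured):

    from aituner.spec import LLMEndpointSpec

    # "provider": "codex" lets base_url be omitted from the study config; it is
    # resolved from AITUNER_CODEX_BASE_URL or ~/.codex/config.toml at load time.
    endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
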
@@ -90,7 +90,8 @@
     "system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
     "max_history_trials": 8,
     "endpoint": {
-      "base_url": "https://example-openai-compatible-endpoint",
+      "provider": "custom",
+      "base_url": "https://example-openai-compatible-endpoint/v1",
       "model": "gpt-4.1-mini",
       "api_key_env": "OPENAI_API_KEY",
       "timeout_s": 120

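The custom preset keeps the old contract: an explicit base_url (note the path now ends in /v1) and OPENAI_API_KEY by default. Omitting base_url now fails fast at spec load, as sketched here:

    from aituner.spec import LLMEndpointSpec, SpecError

    # provider="custom" still requires an explicit base_url:
    try:
        LLMEndpointSpec.from_dict({"provider": "custom", "model": "gpt-4.1-mini"})
    except SpecError as exc:
        print(exc)  # llm.endpoint.base_url must be a non-empty string.
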
@@ -3,10 +3,12 @@ from __future__ import annotations
 import json
 import os
 import time
+import tomllib
 import urllib.error
 import urllib.request
-from ipaddress import ip_address
 from dataclasses import dataclass
+from ipaddress import ip_address
+from pathlib import Path
 from typing import Any, Iterable
 from urllib.parse import urlparse
@@ -34,12 +36,85 @@ def _urlopen(request: urllib.request.Request, *, timeout: float):
     return urllib.request.urlopen(request, timeout=timeout)
 
 
-def _auth_headers(api_key_env: str | None) -> dict[str, str]:
-    headers = {"Content-Type": "application/json"}
+def _find_dotenv(start: Path | None = None) -> Path | None:
+    current = (start or Path.cwd()).resolve()
+    for candidate_dir in (current, *current.parents):
+        candidate = candidate_dir / ".env"
+        if candidate.is_file():
+            return candidate
+    return None
+
+
+def _load_dotenv() -> None:
+    path = _find_dotenv()
+    if path is None:
+        return
+    for raw_line in path.read_text(encoding="utf-8").splitlines():
+        line = raw_line.strip()
+        if not line or line.startswith("#"):
+            continue
+        if line.startswith("export "):
+            line = line[len("export ") :].strip()
+        key, separator, value = line.partition("=")
+        if not separator:
+            continue
+        key = key.strip()
+        if not key:
+            continue
+        value = value.strip()
+        if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
+            value = value[1:-1]
+        os.environ.setdefault(key, value)
+
+
+def _load_codex_network_env() -> None:
+    config_path = Path.home() / ".codex" / "config.toml"
+    if not config_path.is_file():
+        return
+    try:
+        payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
+    except tomllib.TOMLDecodeError:
+        return
+    network = payload.get("network")
+    if not isinstance(network, dict):
+        return
+    for key, value in network.items():
+        if not isinstance(value, str) or not value.strip():
+            continue
+        normalized = value.strip()
+        os.environ.setdefault(str(key), normalized)
+        if str(key).endswith("_proxy"):
+            os.environ.setdefault(str(key).upper(), normalized)
+
+
+def _resolve_api_key(api_key_env: str | None, *, provider: str) -> str | None:
+    _load_dotenv()
+    if provider == "codex":
+        _load_codex_network_env()
     if api_key_env:
         api_key = os.environ.get(api_key_env)
         if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
+            return api_key
+    if provider != "codex" and api_key_env != "OPENAI_API_KEY":
+        return None
+    auth_path = Path.home() / ".codex" / "auth.json"
+    if not auth_path.is_file():
+        return None
+    try:
+        payload = json.loads(auth_path.read_text(encoding="utf-8"))
+    except json.JSONDecodeError:
+        return None
+    api_key = payload.get("OPENAI_API_KEY")
+    if isinstance(api_key, str) and api_key.strip():
+        return api_key.strip()
+    return None
+
+
+def _auth_headers(api_key_env: str | None, provider: str = "custom") -> dict[str, str]:
+    headers = {"Content-Type": "application/json"}
+    api_key = _resolve_api_key(api_key_env, provider=provider)
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
     return headers
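
Key lookup now follows a fixed precedence; a sketch of the observable behavior (values illustrative, mirroring the new tests below):

    from aituner.http_client import _auth_headers

    # Precedence: the live environment wins (the loaders only call
    # os.environ.setdefault), then a .env discovered upward from cwd, then,
    # for provider="codex" only, OPENAI_API_KEY from ~/.codex/auth.json.
    # [network] entries in ~/.codex/config.toml are exported as-is, and
    # *_proxy keys additionally in upper case.
    headers = _auth_headers("DASHSCOPE_API_KEY", "bailian")
    # {"Content-Type": "application/json", "Authorization": "Bearer ..."}
    # Authorization is only present when a key was actually found.
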
@@ -71,6 +146,7 @@ def chat_completion(
     *,
     base_url: str,
     api_key_env: str | None,
+    provider: str = "custom",
     model: str,
     messages: list[dict[str, Any]],
     timeout_s: float,
@@ -82,7 +158,7 @@ def chat_completion(
     data = json.dumps(payload).encode("utf-8")
     request = urllib.request.Request(
         url=_openai_url(base_url, "/v1/chat/completions"),
-        headers=_auth_headers(api_key_env),
+        headers=_auth_headers(api_key_env, provider),
         data=data,
         method="POST",
     )
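
Callers pass the provider through explicitly. A hedged call sketch; the hunk shows only part of the signature, so any additional required parameters of chat_completion would need to be supplied as well:

    response = chat_completion(
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        api_key_env="DASHSCOPE_API_KEY",
        provider="bailian",
        model="qwen-plus",  # illustrative model name
        messages=[{"role": "user", "content": "ping"}],
        timeout_s=120.0,
    )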

@@ -158,6 +158,7 @@ def call_llm_for_proposal(
     response = chat_completion(
         base_url=policy.endpoint.base_url,
         api_key_env=policy.endpoint.api_key_env,
+        provider=policy.endpoint.provider,
         model=policy.endpoint.model,
         messages=[{"role": "user", "content": prompt}],
         timeout_s=policy.endpoint.timeout_s,

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import os
 import tomllib
 from dataclasses import asdict, dataclass, field, is_dataclass
 from pathlib import Path
@@ -56,6 +57,47 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]:
     return result
 
 
+def _resolve_codex_base_url() -> str:
+    override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
+    if override:
+        return override
+    config_path = Path.home() / ".codex" / "config.toml"
+    try:
+        payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
+    except FileNotFoundError as exc:
+        raise SpecError(
+            "Unable to resolve llm.endpoint.base_url for provider=codex. "
+            "Set llm.endpoint.base_url explicitly, set AITUNER_CODEX_BASE_URL, "
+            f"or provide {config_path}."
+        ) from exc
+    except tomllib.TOMLDecodeError as exc:
+        raise SpecError(f"Invalid TOML in {config_path}: {exc}") from exc
+    raw_providers = payload.get("model_providers")
+    providers = raw_providers if isinstance(raw_providers, Mapping) else {}
+    provider_name = str(payload.get("model_provider") or "").strip()
+    if provider_name:
+        selected = providers.get(provider_name)
+        if isinstance(selected, Mapping):
+            base_url = str(selected.get("base_url") or "").strip()
+            if base_url:
+                return base_url
+    if len(providers) == 1:
+        only_provider = next(iter(providers.values()))
+        if isinstance(only_provider, Mapping):
+            base_url = str(only_provider.get("base_url") or "").strip()
+            if base_url:
+                return base_url
+    root_base_url = str(payload.get("base_url") or "").strip()
+    if root_base_url:
+        return root_base_url
+    raise SpecError(
+        "Unable to resolve llm.endpoint.base_url for provider=codex from "
+        f"{config_path}. Set llm.endpoint.base_url explicitly or set "
+        "AITUNER_CODEX_BASE_URL."
+    )
+
+
 @dataclass(frozen=True)
 class HardwareSpec:
     gpu_count: int
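
Resolution order: the AITUNER_CODEX_BASE_URL override, then the provider selected by model_provider, then a sole model_providers entry, then a root-level base_url; otherwise SpecError. A minimal sketch of the override path (private helper shown for illustration; URL borrowed from the new test fixture):

    import os
    from aituner.spec import _resolve_codex_base_url

    os.environ["AITUNER_CODEX_BASE_URL"] = "http://codex.example/v1"
    assert _resolve_codex_base_url() == "http://codex.example/v1"
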
@@ -343,15 +385,37 @@ class SamplingSearchSpec:
 class LLMEndpointSpec:
     base_url: str
     model: str
+    provider: str = "custom"
     api_key_env: str = "OPENAI_API_KEY"
     timeout_s: float = 120.0
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
+        provider = str(data.get("provider") or "custom").strip().lower()
+        base_url = str(data.get("base_url") or "").strip()
+        api_key_env = str(data.get("api_key_env") or "").strip()
+        if provider == "codex":
+            if not base_url:
+                base_url = _resolve_codex_base_url()
+            if not api_key_env:
+                api_key_env = "OPENAI_API_KEY"
+        elif provider == "bailian":
+            if not base_url:
+                base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+            if not api_key_env:
+                api_key_env = "DASHSCOPE_API_KEY"
+        elif provider == "custom":
+            if not base_url:
+                raise SpecError("llm.endpoint.base_url must be a non-empty string.")
+            if not api_key_env:
+                api_key_env = "OPENAI_API_KEY"
+        else:
+            raise SpecError(f"Unsupported llm.endpoint.provider: {provider}")
         return cls(
-            base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"),
+            base_url=_require_str(base_url, context="llm.endpoint.base_url"),
             model=_require_str(data.get("model"), context="llm.endpoint.model"),
-            api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(),
+            provider=provider,
+            api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
             timeout_s=_require_float(
                 data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
             ),
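
The bailian preset pins DashScope's OpenAI-compatible endpoint and key variable unless the spec overrides them, as the new unit test below exercises:

    from aituner.spec import LLMEndpointSpec

    endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
    assert endpoint.base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1"
    assert endpoint.api_key_env == "DASHSCOPE_API_KEY"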

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import os
 import subprocess
 import tempfile
 import unittest
@@ -8,12 +9,19 @@ from pathlib import Path
 from unittest import mock
 
 from aituner.cli import main as cli_main
-from aituner.http_client import _openai_url, _should_bypass_proxy
+from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
 from aituner.job import append_job, build_trial_job
 from aituner.llm import build_prompt, parse_proposal_text
 from aituner.search import ThresholdProbe, binary_search_max_feasible
 from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
-from aituner.spec import Proposal, SpecError, StudyState, TrialSummary, load_study_spec
+from aituner.spec import (
+    LLMEndpointSpec,
+    Proposal,
+    SpecError,
+    StudyState,
+    TrialSummary,
+    load_study_spec,
+)
 from aituner.store import StudyStore
 from aituner.trace import load_trace_requests, summarize_window
 from aituner.worker import (
@@ -214,6 +222,71 @@ class CoreFlowTests(unittest.TestCase):
         with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
             load_study_spec(study_path)
 
+    def test_bailian_endpoint_defaults(self) -> None:
+        endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
+        self.assertEqual(endpoint.provider, "bailian")
+        self.assertEqual(
+            endpoint.base_url, "https://dashscope.aliyuncs.com/compatible-mode/v1"
+        )
+        self.assertEqual(endpoint.api_key_env, "DASHSCOPE_API_KEY")
+
+    def test_codex_endpoint_resolves_base_url_from_codex_config(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            codex_dir = tmp_path / ".codex"
+            codex_dir.mkdir(parents=True)
+            (codex_dir / "config.toml").write_text(
+                '\n'.join(
+                    [
+                        'model_provider = "ipads"',
+                        "",
+                        "[model_providers.ipads]",
+                        'base_url = "http://codex.example/v1"',
+                    ]
+                ),
+                encoding="utf-8",
+            )
+            with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
+                endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
+                self.assertEqual(endpoint.provider, "codex")
+                self.assertEqual(endpoint.base_url, "http://codex.example/v1")
+                self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
+
+    def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            (tmp_path / ".env").write_text('DASHSCOPE_API_KEY="dash-key"\n', encoding="utf-8")
+            with mock.patch.dict(os.environ, {}, clear=True):
+                with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
+                    headers = _auth_headers("DASHSCOPE_API_KEY", "bailian")
+                    self.assertEqual(headers["Authorization"], "Bearer dash-key")
+
+    def test_auth_headers_load_codex_auth_and_proxy(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            codex_dir = tmp_path / ".codex"
+            codex_dir.mkdir(parents=True)
+            (codex_dir / "config.toml").write_text(
+                '\n'.join(
+                    [
+                        "[network]",
+                        'http_proxy = "http://proxy.example:3128"',
+                        'https_proxy = "http://proxy.example:3128"',
+                    ]
+                ),
+                encoding="utf-8",
+            )
+            (codex_dir / "auth.json").write_text(
+                json.dumps({"OPENAI_API_KEY": "sk-codex-test"}),
+                encoding="utf-8",
+            )
+            with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
+                with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
+                    headers = _auth_headers("OPENAI_API_KEY", "codex")
+                    self.assertEqual(os.environ["http_proxy"], "http://proxy.example:3128")
+                    self.assertEqual(os.environ["HTTP_PROXY"], "http://proxy.example:3128")
+                    self.assertEqual(headers["Authorization"], "Bearer sk-codex-test")
+
     def test_prompt_includes_failed_trial_context(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)