From 94c89e110345d1c177a707e2a70742f116321c9e Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Tue, 7 Apr 2026 11:31:26 +0800
Subject: [PATCH] Add codex and bailian LLM provider presets

Let llm.endpoint name a provider preset instead of spelling out the
endpoint by hand. provider="codex" resolves its base_url from
AITUNER_CODEX_BASE_URL or ~/.codex/config.toml and falls back to
~/.codex/auth.json for the API key (plus [network] proxy settings);
provider="bailian" defaults to the DashScope OpenAI-compatible endpoint
and DASHSCOPE_API_KEY. Both honor a local .env file, which never
overrides variables already set in the process environment.
---
 .env.example                               |  9 ++
 .gitignore                                 |  1 +
 .../dash0_qwen27b_tight_slo_run4_0_8k.json |  2 +-
 configs/examples/study.example.json        |  3 +-
 src/aituner/http_client.py                 | 86 +++++++++++++++++--
 src/aituner/llm.py                         |  1 +
 src/aituner/spec.py                        | 68 ++++++++++++++-
 tests/test_core_flow.py                    | 77 ++++++++++++++++-
 8 files changed, 236 insertions(+), 11 deletions(-)
 create mode 100644 .env.example

diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..fbf57db
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,9 @@
+# Copy to .env when you want to override runtime secrets locally.
+# codex provider:
+OPENAI_API_KEY=
+
+# bailian / DashScope OpenAI-compatible provider:
+DASHSCOPE_API_KEY=
+
+# Optional override if ~/.codex/config.toml does not expose a codex base_url:
+# AITUNER_CODEX_BASE_URL=
diff --git a/.gitignore b/.gitignore
index 14892c0..952ec45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 .aituner/
 .aituner-smoke/
 .aituner-tight/
+.env
 __pycache__/
 *.pyc
 logs/
diff --git a/configs/examples/dash0_qwen27b_tight_slo_run4_0_8k.json b/configs/examples/dash0_qwen27b_tight_slo_run4_0_8k.json
index a59decb..34d15c2 100644
--- a/configs/examples/dash0_qwen27b_tight_slo_run4_0_8k.json
+++ b/configs/examples/dash0_qwen27b_tight_slo_run4_0_8k.json
@@ -142,7 +142,7 @@
     "system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target. Favor launch-safe changes grounded in the incumbent result and only propose knobs that plausibly improve throughput above the incumbent request rate.",
     "max_history_trials": 8,
     "endpoint": {
-      "base_url": "http://tianx.ipads-lab.se.sjtu.edu.cn:8317/v1",
+      "provider": "codex",
       "model": "gpt-5.4",
       "api_key_env": "OPENAI_API_KEY",
       "timeout_s": 180
diff --git a/configs/examples/study.example.json b/configs/examples/study.example.json
index 57f694d..a052dee 100644
--- a/configs/examples/study.example.json
+++ b/configs/examples/study.example.json
@@ -90,7 +90,8 @@
     "system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
     "max_history_trials": 8,
     "endpoint": {
-      "base_url": "https://example-openai-compatible-endpoint",
+      "provider": "custom",
+      "base_url": "https://example-openai-compatible-endpoint/v1",
       "model": "gpt-4.1-mini",
       "api_key_env": "OPENAI_API_KEY",
       "timeout_s": 120
diff --git a/src/aituner/http_client.py b/src/aituner/http_client.py
index 33550ae..495dfd3 100644
--- a/src/aituner/http_client.py
+++ b/src/aituner/http_client.py
@@ -3,10 +3,12 @@ from __future__ import annotations
 import json
 import os
 import time
+import tomllib
 import urllib.error
 import urllib.request
-from ipaddress import ip_address
 from dataclasses import dataclass
+from ipaddress import ip_address
+from pathlib import Path
 from typing import Any, Iterable
 from urllib.parse import urlparse
 
@@ -34,12 +36,85 @@ def _urlopen(request: urllib.request.Request, *, timeout: float):
     return urllib.request.urlopen(request, timeout=timeout)
 
 
-def _auth_headers(api_key_env: str | None) -> dict[str, str]:
-    headers = {"Content-Type": "application/json"}
+def _find_dotenv(start: Path | None = None) -> Path | None:
+    current = (start or Path.cwd()).resolve()
+    for candidate_dir in (current, *current.parents):
+        candidate = candidate_dir / ".env"
+        if candidate.is_file():
+            return candidate
+    return None
+
+
+def _load_dotenv() -> None:
+    path = _find_dotenv()
+    if path is None:
+        return
+    for raw_line in path.read_text(encoding="utf-8").splitlines():
+        line = raw_line.strip()
+        if not line or line.startswith("#"):
+            continue
+        if line.startswith("export "):
+            line = line[len("export ") :].strip()
+        key, separator, value = line.partition("=")
+        if not separator:
+            continue
+        key = key.strip()
+        if not key:
+            continue
+        value = value.strip()
+        if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
+            value = value[1:-1]
+        os.environ.setdefault(key, value)
+
+
+def _load_codex_network_env() -> None:
+    config_path = Path.home() / ".codex" / "config.toml"
+    if not config_path.is_file():
+        return
+    try:
+        payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
+    except tomllib.TOMLDecodeError:
+        return
+    network = payload.get("network")
+    if not isinstance(network, dict):
+        return
+    for key, value in network.items():
+        if not isinstance(value, str) or not value.strip():
+            continue
+        normalized = value.strip()
+        os.environ.setdefault(str(key), normalized)
+        if str(key).endswith("_proxy"):
+            os.environ.setdefault(str(key).upper(), normalized)
+
+
+def _resolve_api_key(api_key_env: str | None, *, provider: str) -> str | None:
+    _load_dotenv()
+    if provider == "codex":
+        _load_codex_network_env()
     if api_key_env:
         api_key = os.environ.get(api_key_env)
         if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
+            return api_key
+    if provider != "codex" and api_key_env != "OPENAI_API_KEY":
+        return None
+    auth_path = Path.home() / ".codex" / "auth.json"
+    if not auth_path.is_file():
+        return None
+    try:
+        payload = json.loads(auth_path.read_text(encoding="utf-8"))
+    except json.JSONDecodeError:
+        return None
+    api_key = payload.get("OPENAI_API_KEY")
+    if isinstance(api_key, str) and api_key.strip():
+        return api_key.strip()
+    return None
+
+
+def _auth_headers(api_key_env: str | None, provider: str = "custom") -> dict[str, str]:
+    headers = {"Content-Type": "application/json"}
+    api_key = _resolve_api_key(api_key_env, provider=provider)
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
     return headers
 
 
@@ -71,6 +146,7 @@ def chat_completion(
     *,
     base_url: str,
     api_key_env: str | None,
+    provider: str = "custom",
     model: str,
     messages: list[dict[str, Any]],
     timeout_s: float,
@@ -82,7 +158,7 @@ def chat_completion(
     data = json.dumps(payload).encode("utf-8")
     request = urllib.request.Request(
         url=_openai_url(base_url, "/v1/chat/completions"),
-        headers=_auth_headers(api_key_env),
+        headers=_auth_headers(api_key_env, provider),
         data=data,
         method="POST",
     )
diff --git a/src/aituner/llm.py b/src/aituner/llm.py
index 6f76dd6..fa27755 100644
--- a/src/aituner/llm.py
+++ b/src/aituner/llm.py
@@ -158,6 +158,7 @@ def call_llm_for_proposal(
     response = chat_completion(
         base_url=policy.endpoint.base_url,
         api_key_env=policy.endpoint.api_key_env,
+        provider=policy.endpoint.provider,
         model=policy.endpoint.model,
         messages=[{"role": "user", "content": prompt}],
         timeout_s=policy.endpoint.timeout_s,
diff --git a/src/aituner/spec.py b/src/aituner/spec.py
index aa37115..c348ba1 100644
--- a/src/aituner/spec.py
+++ b/src/aituner/spec.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import os
 import tomllib
 from dataclasses import asdict, dataclass, field, is_dataclass
 from pathlib import Path
@@ -56,6 +57,47 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]:
     return result
 
 
+def _resolve_codex_base_url() -> str:
+    override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
+    if override:
+        return override
+    config_path = Path.home() / ".codex" / "config.toml"
+    try:
+        payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
+    except FileNotFoundError as exc:
+        raise SpecError(
+            "Unable to resolve llm.endpoint.base_url for provider=codex. "
+            "Set llm.endpoint.base_url explicitly, set AITUNER_CODEX_BASE_URL, "
+            f"or provide {config_path}."
+        ) from exc
+    except tomllib.TOMLDecodeError as exc:
+        raise SpecError(f"Invalid TOML in {config_path}: {exc}") from exc
+
+    raw_providers = payload.get("model_providers")
+    providers = raw_providers if isinstance(raw_providers, Mapping) else {}
+    provider_name = str(payload.get("model_provider") or "").strip()
+    if provider_name:
+        selected = providers.get(provider_name)
+        if isinstance(selected, Mapping):
+            base_url = str(selected.get("base_url") or "").strip()
+            if base_url:
+                return base_url
+    if len(providers) == 1:
+        only_provider = next(iter(providers.values()))
+        if isinstance(only_provider, Mapping):
+            base_url = str(only_provider.get("base_url") or "").strip()
+            if base_url:
+                return base_url
+    root_base_url = str(payload.get("base_url") or "").strip()
+    if root_base_url:
+        return root_base_url
+    raise SpecError(
+        "Unable to resolve llm.endpoint.base_url for provider=codex from "
+        f"{config_path}. Set llm.endpoint.base_url explicitly or set "
+        "AITUNER_CODEX_BASE_URL."
+    )
+
+
 @dataclass(frozen=True)
 class HardwareSpec:
     gpu_count: int
@@ -343,15 +385,37 @@ class SamplingSearchSpec:
 class LLMEndpointSpec:
     base_url: str
     model: str
+    provider: str = "custom"
     api_key_env: str = "OPENAI_API_KEY"
     timeout_s: float = 120.0
 
     @classmethod
     def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
+        provider = str(data.get("provider") or "custom").strip().lower()
+        base_url = str(data.get("base_url") or "").strip()
+        api_key_env = str(data.get("api_key_env") or "").strip()
+        if provider == "codex":
+            if not base_url:
+                base_url = _resolve_codex_base_url()
+            if not api_key_env:
+                api_key_env = "OPENAI_API_KEY"
+        elif provider == "bailian":
+            if not base_url:
+                base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+            if not api_key_env:
+                api_key_env = "DASHSCOPE_API_KEY"
+        elif provider == "custom":
+            if not base_url:
+                raise SpecError("llm.endpoint.base_url must be a non-empty string.")
+            if not api_key_env:
+                api_key_env = "OPENAI_API_KEY"
+        else:
+            raise SpecError(f"Unsupported llm.endpoint.provider: {provider}")
         return cls(
-            base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"),
+            base_url=_require_str(base_url, context="llm.endpoint.base_url"),
             model=_require_str(data.get("model"), context="llm.endpoint.model"),
-            api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(),
+            provider=provider,
+            api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
             timeout_s=_require_float(
                 data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
             ),
diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py
index d084318..2ee88fc 100644
--- a/tests/test_core_flow.py
+++ b/tests/test_core_flow.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import json
+import os
 import subprocess
 import tempfile
 import unittest
@@ -8,12 +9,19 @@ from pathlib import Path
 from unittest import mock
 
 from aituner.cli import main as cli_main
-from aituner.http_client import _openai_url, _should_bypass_proxy
+from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
 from aituner.job import append_job, build_trial_job
 from aituner.llm import build_prompt, parse_proposal_text
 from aituner.search import ThresholdProbe, binary_search_max_feasible
 from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
-from aituner.spec import Proposal, SpecError, StudyState, TrialSummary, load_study_spec
+from aituner.spec import (
+    LLMEndpointSpec,
+    Proposal,
+    SpecError,
+    StudyState,
+    TrialSummary,
+    load_study_spec,
+)
 from aituner.store import StudyStore
 from aituner.trace import load_trace_requests, summarize_window
 from aituner.worker import (
@@ -214,6 +222,71 @@ class CoreFlowTests(unittest.TestCase):
         with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
             load_study_spec(study_path)
 
+    def test_bailian_endpoint_defaults(self) -> None:
+        endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
+        self.assertEqual(endpoint.provider, "bailian")
+        self.assertEqual(
+            endpoint.base_url, "https://dashscope.aliyuncs.com/compatible-mode/v1"
+        )
+        self.assertEqual(endpoint.api_key_env, "DASHSCOPE_API_KEY")
+
+    def test_codex_endpoint_resolves_base_url_from_codex_config(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            codex_dir = tmp_path / ".codex"
+            codex_dir.mkdir(parents=True)
+            (codex_dir / "config.toml").write_text(
+                '\n'.join(
+                    [
+                        'model_provider = "ipads"',
+                        "",
+                        "[model_providers.ipads]",
+                        'base_url = "http://codex.example/v1"',
+                    ]
+                ),
+                encoding="utf-8",
+            )
+            with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
+                endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
+            self.assertEqual(endpoint.provider, "codex")
+            self.assertEqual(endpoint.base_url, "http://codex.example/v1")
+            self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
+
+    def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            (tmp_path / ".env").write_text('DASHSCOPE_API_KEY="dash-key"\n', encoding="utf-8")
+            with mock.patch.dict(os.environ, {}, clear=True):
+                with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
+                    headers = _auth_headers("DASHSCOPE_API_KEY", "bailian")
+            self.assertEqual(headers["Authorization"], "Bearer dash-key")
+
+    def test_auth_headers_load_codex_auth_and_proxy(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_path = Path(tmp)
+            codex_dir = tmp_path / ".codex"
+            codex_dir.mkdir(parents=True)
+            (codex_dir / "config.toml").write_text(
+                '\n'.join(
+                    [
+                        "[network]",
+                        'http_proxy = "http://proxy.example:3128"',
+                        'https_proxy = "http://proxy.example:3128"',
+                    ]
+                ),
+                encoding="utf-8",
+            )
+            (codex_dir / "auth.json").write_text(
+                json.dumps({"OPENAI_API_KEY": "sk-codex-test"}),
+                encoding="utf-8",
+            )
+            with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
+                with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
+                    headers = _auth_headers("OPENAI_API_KEY", "codex")
+                    self.assertEqual(os.environ["http_proxy"], "http://proxy.example:3128")
+                    self.assertEqual(os.environ["HTTP_PROXY"], "http://proxy.example:3128")
+            self.assertEqual(headers["Authorization"], "Bearer sk-codex-test")
+
     def test_prompt_includes_failed_trial_context(self) -> None:
         with tempfile.TemporaryDirectory() as tmp:
             tmp_path = Path(tmp)
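
Usage notes:

With the bailian preset, a study config can omit the endpoint URL and
key variable entirely. A minimal sketch of the llm.endpoint block (the
model name is illustrative):

    "endpoint": {
      "provider": "bailian",
      "model": "qwen-plus",
      "timeout_s": 120
    }

from_dict then fills in base_url
(https://dashscope.aliyuncs.com/compatible-mode/v1) and api_key_env
(DASHSCOPE_API_KEY), as exercised by test_bailian_endpoint_defaults.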
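
For provider=codex, base_url resolution tries an explicit
llm.endpoint.base_url first, then AITUNER_CODEX_BASE_URL, then
~/.codex/config.toml. A config.toml shape the resolver understands (the
provider name and URL mirror the test fixture, not a real deployment):

    model_provider = "ipads"

    [model_providers.ipads]
    base_url = "http://codex.example/v1"

If model_provider is unset, a sole [model_providers.*] entry or a
root-level base_url is also accepted; with none of these present, spec
loading raises SpecError.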
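
The .env loader stays intentionally small. A sketch of the accepted
syntax (values are placeholders):

    # comments and blank lines are skipped
    export DASHSCOPE_API_KEY="dash-key"
    AITUNER_CODEX_BASE_URL=http://codex.example/v1

A leading "export" and matching surrounding quotes are stripped; inline
comments after a value are not supported. Because the loader uses
os.environ.setdefault, variables already set in the process environment
win, and the nearest .env found while walking up from the current
directory is the one read. For codex (or any endpoint keyed by
OPENAI_API_KEY), ~/.codex/auth.json is the final fallback.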
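
A quick interactive check of header construction, using the private
helper the new tests exercise (assumes the key is already exported, so
no .env or auth.json lookup is needed):

    >>> import os
    >>> from aituner.http_client import _auth_headers
    >>> os.environ["DASHSCOPE_API_KEY"] = "dash-key"
    >>> _auth_headers("DASHSCOPE_API_KEY", "bailian")
    {'Content-Type': 'application/json', 'Authorization': 'Bearer dash-key'}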