Add codex and bailian LLM provider presets
This commit is contained in:
9
.env.example
Normal file
9
.env.example
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copy to .env when you want to override runtime secrets locally.
|
||||
# codex provider:
|
||||
OPENAI_API_KEY=
|
||||
|
||||
# bailian / DashScope OpenAI-compatible provider:
|
||||
DASHSCOPE_API_KEY=
|
||||
|
||||
# Optional override if ~/.codex/config.toml does not expose a codex base_url:
|
||||
# AITUNER_CODEX_BASE_URL=
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
||||
.aituner/
|
||||
.aituner-smoke/
|
||||
.aituner-tight/
|
||||
.env
|
||||
__pycache__/
|
||||
*.pyc
|
||||
logs/
|
||||
|
||||
@@ -142,7 +142,7 @@
|
||||
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target. Favor launch-safe changes grounded in the incumbent result and only propose knobs that plausibly improve throughput above the incumbent request rate.",
|
||||
"max_history_trials": 8,
|
||||
"endpoint": {
|
||||
"base_url": "http://tianx.ipads-lab.se.sjtu.edu.cn:8317/v1",
|
||||
"provider": "codex",
|
||||
"model": "gpt-5.4",
|
||||
"api_key_env": "OPENAI_API_KEY",
|
||||
"timeout_s": 180
|
||||
|
||||
@@ -90,7 +90,8 @@
|
||||
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
|
||||
"max_history_trials": 8,
|
||||
"endpoint": {
|
||||
"base_url": "https://example-openai-compatible-endpoint",
|
||||
"provider": "custom",
|
||||
"base_url": "https://example-openai-compatible-endpoint/v1",
|
||||
"model": "gpt-4.1-mini",
|
||||
"api_key_env": "OPENAI_API_KEY",
|
||||
"timeout_s": 120
|
||||
|
||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import tomllib
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from ipaddress import ip_address
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import ip_address
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -34,10 +36,83 @@ def _urlopen(request: urllib.request.Request, *, timeout: float):
|
||||
return urllib.request.urlopen(request, timeout=timeout)
|
||||
|
||||
|
||||
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
def _find_dotenv(start: Path | None = None) -> Path | None:
|
||||
current = (start or Path.cwd()).resolve()
|
||||
for candidate_dir in (current, *current.parents):
|
||||
candidate = candidate_dir / ".env"
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _load_dotenv() -> None:
|
||||
path = _find_dotenv()
|
||||
if path is None:
|
||||
return
|
||||
for raw_line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if line.startswith("export "):
|
||||
line = line[len("export ") :].strip()
|
||||
key, separator, value = line.partition("=")
|
||||
if not separator:
|
||||
continue
|
||||
key = key.strip()
|
||||
if not key:
|
||||
continue
|
||||
value = value.strip()
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
|
||||
value = value[1:-1]
|
||||
os.environ.setdefault(key, value)
|
||||
|
||||
|
||||
def _load_codex_network_env() -> None:
|
||||
config_path = Path.home() / ".codex" / "config.toml"
|
||||
if not config_path.is_file():
|
||||
return
|
||||
try:
|
||||
payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
||||
except tomllib.TOMLDecodeError:
|
||||
return
|
||||
network = payload.get("network")
|
||||
if not isinstance(network, dict):
|
||||
return
|
||||
for key, value in network.items():
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
continue
|
||||
normalized = value.strip()
|
||||
os.environ.setdefault(str(key), normalized)
|
||||
if str(key).endswith("_proxy"):
|
||||
os.environ.setdefault(str(key).upper(), normalized)
|
||||
|
||||
|
||||
def _resolve_api_key(api_key_env: str | None, *, provider: str) -> str | None:
|
||||
_load_dotenv()
|
||||
if provider == "codex":
|
||||
_load_codex_network_env()
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if api_key:
|
||||
return api_key
|
||||
if provider != "codex" and api_key_env != "OPENAI_API_KEY":
|
||||
return None
|
||||
auth_path = Path.home() / ".codex" / "auth.json"
|
||||
if not auth_path.is_file():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(auth_path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
api_key = payload.get("OPENAI_API_KEY")
|
||||
if isinstance(api_key, str) and api_key.strip():
|
||||
return api_key.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _auth_headers(api_key_env: str | None, provider: str = "custom") -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
api_key = _resolve_api_key(api_key_env, provider=provider)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
@@ -71,6 +146,7 @@ def chat_completion(
|
||||
*,
|
||||
base_url: str,
|
||||
api_key_env: str | None,
|
||||
provider: str = "custom",
|
||||
model: str,
|
||||
messages: list[dict[str, Any]],
|
||||
timeout_s: float,
|
||||
@@ -82,7 +158,7 @@ def chat_completion(
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
request = urllib.request.Request(
|
||||
url=_openai_url(base_url, "/v1/chat/completions"),
|
||||
headers=_auth_headers(api_key_env),
|
||||
headers=_auth_headers(api_key_env, provider),
|
||||
data=data,
|
||||
method="POST",
|
||||
)
|
||||
|
||||
@@ -158,6 +158,7 @@ def call_llm_for_proposal(
|
||||
response = chat_completion(
|
||||
base_url=policy.endpoint.base_url,
|
||||
api_key_env=policy.endpoint.api_key_env,
|
||||
provider=policy.endpoint.provider,
|
||||
model=policy.endpoint.model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
timeout_s=policy.endpoint.timeout_s,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tomllib
|
||||
from dataclasses import asdict, dataclass, field, is_dataclass
|
||||
from pathlib import Path
|
||||
@@ -56,6 +57,47 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]:
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_codex_base_url() -> str:
|
||||
override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
|
||||
if override:
|
||||
return override
|
||||
config_path = Path.home() / ".codex" / "config.toml"
|
||||
try:
|
||||
payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
||||
except FileNotFoundError as exc:
|
||||
raise SpecError(
|
||||
"Unable to resolve llm.endpoint.base_url for provider=codex. "
|
||||
"Set llm.endpoint.base_url explicitly, set AITUNER_CODEX_BASE_URL, "
|
||||
f"or provide {config_path}."
|
||||
) from exc
|
||||
except tomllib.TOMLDecodeError as exc:
|
||||
raise SpecError(f"Invalid TOML in {config_path}: {exc}") from exc
|
||||
|
||||
raw_providers = payload.get("model_providers")
|
||||
providers = raw_providers if isinstance(raw_providers, Mapping) else {}
|
||||
provider_name = str(payload.get("model_provider") or "").strip()
|
||||
if provider_name:
|
||||
selected = providers.get(provider_name)
|
||||
if isinstance(selected, Mapping):
|
||||
base_url = str(selected.get("base_url") or "").strip()
|
||||
if base_url:
|
||||
return base_url
|
||||
if len(providers) == 1:
|
||||
only_provider = next(iter(providers.values()))
|
||||
if isinstance(only_provider, Mapping):
|
||||
base_url = str(only_provider.get("base_url") or "").strip()
|
||||
if base_url:
|
||||
return base_url
|
||||
root_base_url = str(payload.get("base_url") or "").strip()
|
||||
if root_base_url:
|
||||
return root_base_url
|
||||
raise SpecError(
|
||||
"Unable to resolve llm.endpoint.base_url for provider=codex from "
|
||||
f"{config_path}. Set llm.endpoint.base_url explicitly or set "
|
||||
"AITUNER_CODEX_BASE_URL."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HardwareSpec:
|
||||
gpu_count: int
|
||||
@@ -343,15 +385,37 @@ class SamplingSearchSpec:
|
||||
class LLMEndpointSpec:
|
||||
base_url: str
|
||||
model: str
|
||||
provider: str = "custom"
|
||||
api_key_env: str = "OPENAI_API_KEY"
|
||||
timeout_s: float = 120.0
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
|
||||
provider = str(data.get("provider") or "custom").strip().lower()
|
||||
base_url = str(data.get("base_url") or "").strip()
|
||||
api_key_env = str(data.get("api_key_env") or "").strip()
|
||||
if provider == "codex":
|
||||
if not base_url:
|
||||
base_url = _resolve_codex_base_url()
|
||||
if not api_key_env:
|
||||
api_key_env = "OPENAI_API_KEY"
|
||||
elif provider == "bailian":
|
||||
if not base_url:
|
||||
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
if not api_key_env:
|
||||
api_key_env = "DASHSCOPE_API_KEY"
|
||||
elif provider == "custom":
|
||||
if not base_url:
|
||||
raise SpecError("llm.endpoint.base_url must be a non-empty string.")
|
||||
if not api_key_env:
|
||||
api_key_env = "OPENAI_API_KEY"
|
||||
else:
|
||||
raise SpecError(f"Unsupported llm.endpoint.provider: {provider}")
|
||||
return cls(
|
||||
base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"),
|
||||
base_url=_require_str(base_url, context="llm.endpoint.base_url"),
|
||||
model=_require_str(data.get("model"), context="llm.endpoint.model"),
|
||||
api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(),
|
||||
provider=provider,
|
||||
api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
|
||||
timeout_s=_require_float(
|
||||
data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
|
||||
),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -8,12 +9,19 @@ from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from aituner.cli import main as cli_main
|
||||
from aituner.http_client import _openai_url, _should_bypass_proxy
|
||||
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
|
||||
from aituner.job import append_job, build_trial_job
|
||||
from aituner.llm import build_prompt, parse_proposal_text
|
||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||
from aituner.spec import Proposal, SpecError, StudyState, TrialSummary, load_study_spec
|
||||
from aituner.spec import (
|
||||
LLMEndpointSpec,
|
||||
Proposal,
|
||||
SpecError,
|
||||
StudyState,
|
||||
TrialSummary,
|
||||
load_study_spec,
|
||||
)
|
||||
from aituner.store import StudyStore
|
||||
from aituner.trace import load_trace_requests, summarize_window
|
||||
from aituner.worker import (
|
||||
@@ -214,6 +222,71 @@ class CoreFlowTests(unittest.TestCase):
|
||||
with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
|
||||
load_study_spec(study_path)
|
||||
|
||||
def test_bailian_endpoint_defaults(self) -> None:
|
||||
endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
|
||||
self.assertEqual(endpoint.provider, "bailian")
|
||||
self.assertEqual(
|
||||
endpoint.base_url, "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
self.assertEqual(endpoint.api_key_env, "DASHSCOPE_API_KEY")
|
||||
|
||||
def test_codex_endpoint_resolves_base_url_from_codex_config(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
codex_dir = tmp_path / ".codex"
|
||||
codex_dir.mkdir(parents=True)
|
||||
(codex_dir / "config.toml").write_text(
|
||||
'\n'.join(
|
||||
[
|
||||
'model_provider = "ipads"',
|
||||
"",
|
||||
"[model_providers.ipads]",
|
||||
'base_url = "http://codex.example/v1"',
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
|
||||
endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
|
||||
self.assertEqual(endpoint.provider, "codex")
|
||||
self.assertEqual(endpoint.base_url, "http://codex.example/v1")
|
||||
self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
|
||||
|
||||
def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
(tmp_path / ".env").write_text('DASHSCOPE_API_KEY="dash-key"\n', encoding="utf-8")
|
||||
with mock.patch.dict(os.environ, {}, clear=True):
|
||||
with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
|
||||
headers = _auth_headers("DASHSCOPE_API_KEY", "bailian")
|
||||
self.assertEqual(headers["Authorization"], "Bearer dash-key")
|
||||
|
||||
def test_auth_headers_load_codex_auth_and_proxy(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
codex_dir = tmp_path / ".codex"
|
||||
codex_dir.mkdir(parents=True)
|
||||
(codex_dir / "config.toml").write_text(
|
||||
'\n'.join(
|
||||
[
|
||||
"[network]",
|
||||
'http_proxy = "http://proxy.example:3128"',
|
||||
'https_proxy = "http://proxy.example:3128"',
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
(codex_dir / "auth.json").write_text(
|
||||
json.dumps({"OPENAI_API_KEY": "sk-codex-test"}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
|
||||
with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
|
||||
headers = _auth_headers("OPENAI_API_KEY", "codex")
|
||||
self.assertEqual(os.environ["http_proxy"], "http://proxy.example:3128")
|
||||
self.assertEqual(os.environ["HTTP_PROXY"], "http://proxy.example:3128")
|
||||
self.assertEqual(headers["Authorization"], "Bearer sk-codex-test")
|
||||
|
||||
def test_prompt_includes_failed_trial_context(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
|
||||
Reference in New Issue
Block a user