Add codex and bailian LLM provider presets
This commit is contained in:
9
.env.example
Normal file
9
.env.example
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# Copy to .env when you want to override runtime secrets locally.
|
||||||
|
# codex provider:
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
|
||||||
|
# bailian / DashScope OpenAI-compatible provider:
|
||||||
|
DASHSCOPE_API_KEY=
|
||||||
|
|
||||||
|
# Optional override if ~/.codex/config.toml does not expose a codex base_url:
|
||||||
|
# AITUNER_CODEX_BASE_URL=
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,6 +1,7 @@
|
|||||||
.aituner/
|
.aituner/
|
||||||
.aituner-smoke/
|
.aituner-smoke/
|
||||||
.aituner-tight/
|
.aituner-tight/
|
||||||
|
.env
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
logs/
|
logs/
|
||||||
|
|||||||
@@ -142,7 +142,7 @@
|
|||||||
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target. Favor launch-safe changes grounded in the incumbent result and only propose knobs that plausibly improve throughput above the incumbent request rate.",
|
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target. Favor launch-safe changes grounded in the incumbent result and only propose knobs that plausibly improve throughput above the incumbent request rate.",
|
||||||
"max_history_trials": 8,
|
"max_history_trials": 8,
|
||||||
"endpoint": {
|
"endpoint": {
|
||||||
"base_url": "http://tianx.ipads-lab.se.sjtu.edu.cn:8317/v1",
|
"provider": "codex",
|
||||||
"model": "gpt-5.4",
|
"model": "gpt-5.4",
|
||||||
"api_key_env": "OPENAI_API_KEY",
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
"timeout_s": 180
|
"timeout_s": 180
|
||||||
|
|||||||
@@ -90,7 +90,8 @@
|
|||||||
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
|
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
|
||||||
"max_history_trials": 8,
|
"max_history_trials": 8,
|
||||||
"endpoint": {
|
"endpoint": {
|
||||||
"base_url": "https://example-openai-compatible-endpoint",
|
"provider": "custom",
|
||||||
|
"base_url": "https://example-openai-compatible-endpoint/v1",
|
||||||
"model": "gpt-4.1-mini",
|
"model": "gpt-4.1-mini",
|
||||||
"api_key_env": "OPENAI_API_KEY",
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
"timeout_s": 120
|
"timeout_s": 120
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import tomllib
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from ipaddress import ip_address
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from ipaddress import ip_address
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@@ -34,12 +36,85 @@ def _urlopen(request: urllib.request.Request, *, timeout: float):
|
|||||||
return urllib.request.urlopen(request, timeout=timeout)
|
return urllib.request.urlopen(request, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
def _find_dotenv(start: Path | None = None) -> Path | None:
    """Locate the nearest ``.env`` file at or above a starting directory.

    Args:
        start: Directory to begin the search from; the current working
            directory is used when omitted (or falsy).

    Returns:
        The path of the first ``.env`` that is a regular file, checking the
        resolved starting directory first and then each ancestor in turn, or
        ``None`` when no such file exists.
    """
    origin = (start or Path.cwd()).resolve()
    # Nearest directory first, so a project-local .env shadows outer ones.
    search_dirs = [origin, *origin.parents]
    dotenv_paths = (directory / ".env" for directory in search_dirs)
    return next((path for path in dotenv_paths if path.is_file()), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_dotenv() -> None:
    """Populate ``os.environ`` from the nearest ``.env`` file, if there is one.

    The file is located with ``_find_dotenv()``. Each ``KEY=value`` line is
    applied with ``os.environ.setdefault``, so variables that are already set
    in the process environment always win over the file.

    Parsing rules: blank lines and lines starting with ``#`` are skipped, an
    optional leading ``export `` is dropped, lines without ``=`` or with an
    empty key are ignored, and one pair of matching single or double quotes
    around the value is removed.

    Loading is best-effort: a ``.env`` that cannot be read or is not valid
    UTF-8 is ignored instead of raising.
    """
    path = _find_dotenv()
    if path is None:
        return
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # The upward search may land on a .env owned by someone else or with
        # a different encoding; that must not abort the caller's request.
        return
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        # Accept shell-style "export KEY=value" lines.
        line = line.removeprefix("export ").strip()
        key, separator, value = line.partition("=")
        key = key.strip()
        if not separator or not key:
            continue
        value = value.strip()
        # Strip one pair of matching surrounding quotes, if present.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
            value = value[1:-1]
        os.environ.setdefault(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_codex_network_env() -> None:
|
||||||
|
config_path = Path.home() / ".codex" / "config.toml"
|
||||||
|
if not config_path.is_file():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
except tomllib.TOMLDecodeError:
|
||||||
|
return
|
||||||
|
network = payload.get("network")
|
||||||
|
if not isinstance(network, dict):
|
||||||
|
return
|
||||||
|
for key, value in network.items():
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
continue
|
||||||
|
normalized = value.strip()
|
||||||
|
os.environ.setdefault(str(key), normalized)
|
||||||
|
if str(key).endswith("_proxy"):
|
||||||
|
os.environ.setdefault(str(key).upper(), normalized)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key(api_key_env: str | None, *, provider: str) -> str | None:
    """Find the API key to send for a provider.

    Lookup order:

    1. ``_load_dotenv()`` runs first, and for ``provider == "codex"`` so does
       ``_load_codex_network_env()``; both only fill in unset variables.
    2. If ``api_key_env`` names a variable with a non-empty value, that value
       is returned as-is.
    3. Otherwise, when the provider is ``codex`` or the requested variable is
       ``OPENAI_API_KEY``, the ``OPENAI_API_KEY`` entry of
       ``~/.codex/auth.json`` is returned, stripped.

    Args:
        api_key_env: Name of the environment variable holding the key, or
            ``None``/empty to skip the environment lookup.
        provider: Provider preset name from the endpoint configuration.

    Returns:
        The key, or ``None`` when nothing usable is found. A missing,
        unreadable or malformed ``auth.json`` counts as "not found".
    """
    _load_dotenv()
    if provider == "codex":
        _load_codex_network_env()
    if api_key_env:
        api_key = os.environ.get(api_key_env)
        if api_key:
            return api_key
    # NOTE(review): a non-codex provider that asks for OPENAI_API_KEY also
    # falls through to the codex credentials below - confirm this is intended.
    if provider != "codex" and api_key_env != "OPENAI_API_KEY":
        return None
    auth_path = Path.home() / ".codex" / "auth.json"
    try:
        payload = json.loads(auth_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: file missing or unreadable. ValueError: invalid JSON
        # (json.JSONDecodeError) or invalid UTF-8 (UnicodeDecodeError).
        return None
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (a list, a string) has no key.
        return None
    api_key = payload.get("OPENAI_API_KEY")
    if isinstance(api_key, str) and api_key.strip():
        return api_key.strip()
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_headers(api_key_env: str | None, provider: str = "custom") -> dict[str, str]:
    """Build the HTTP headers for a JSON request to the LLM endpoint.

    Args:
        api_key_env: Name of the environment variable expected to hold the
            API key; forwarded to ``_resolve_api_key``.
        provider: Provider preset name; forwarded to ``_resolve_api_key``.

    Returns:
        A dict that always carries ``Content-Type: application/json`` and,
        when a key was resolved, an ``Authorization: Bearer <key>`` entry.
    """
    token = _resolve_api_key(api_key_env, provider=provider)
    if not token:
        return {"Content-Type": "application/json"}
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }
|
||||||
|
|
||||||
|
|
||||||
@@ -71,6 +146,7 @@ def chat_completion(
|
|||||||
*,
|
*,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
api_key_env: str | None,
|
api_key_env: str | None,
|
||||||
|
provider: str = "custom",
|
||||||
model: str,
|
model: str,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
timeout_s: float,
|
timeout_s: float,
|
||||||
@@ -82,7 +158,7 @@ def chat_completion(
|
|||||||
data = json.dumps(payload).encode("utf-8")
|
data = json.dumps(payload).encode("utf-8")
|
||||||
request = urllib.request.Request(
|
request = urllib.request.Request(
|
||||||
url=_openai_url(base_url, "/v1/chat/completions"),
|
url=_openai_url(base_url, "/v1/chat/completions"),
|
||||||
headers=_auth_headers(api_key_env),
|
headers=_auth_headers(api_key_env, provider),
|
||||||
data=data,
|
data=data,
|
||||||
method="POST",
|
method="POST",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ def call_llm_for_proposal(
|
|||||||
response = chat_completion(
|
response = chat_completion(
|
||||||
base_url=policy.endpoint.base_url,
|
base_url=policy.endpoint.base_url,
|
||||||
api_key_env=policy.endpoint.api_key_env,
|
api_key_env=policy.endpoint.api_key_env,
|
||||||
|
provider=policy.endpoint.provider,
|
||||||
model=policy.endpoint.model,
|
model=policy.endpoint.model,
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
timeout_s=policy.endpoint.timeout_s,
|
timeout_s=policy.endpoint.timeout_s,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import tomllib
|
import tomllib
|
||||||
from dataclasses import asdict, dataclass, field, is_dataclass
|
from dataclasses import asdict, dataclass, field, is_dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -56,6 +57,47 @@ def _coerce_str_list(value: Any, *, context: str) -> list[str]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_codex_base_url() -> str:
|
||||||
|
override = os.environ.get("AITUNER_CODEX_BASE_URL", "").strip()
|
||||||
|
if override:
|
||||||
|
return override
|
||||||
|
config_path = Path.home() / ".codex" / "config.toml"
|
||||||
|
try:
|
||||||
|
payload = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
raise SpecError(
|
||||||
|
"Unable to resolve llm.endpoint.base_url for provider=codex. "
|
||||||
|
"Set llm.endpoint.base_url explicitly, set AITUNER_CODEX_BASE_URL, "
|
||||||
|
f"or provide {config_path}."
|
||||||
|
) from exc
|
||||||
|
except tomllib.TOMLDecodeError as exc:
|
||||||
|
raise SpecError(f"Invalid TOML in {config_path}: {exc}") from exc
|
||||||
|
|
||||||
|
raw_providers = payload.get("model_providers")
|
||||||
|
providers = raw_providers if isinstance(raw_providers, Mapping) else {}
|
||||||
|
provider_name = str(payload.get("model_provider") or "").strip()
|
||||||
|
if provider_name:
|
||||||
|
selected = providers.get(provider_name)
|
||||||
|
if isinstance(selected, Mapping):
|
||||||
|
base_url = str(selected.get("base_url") or "").strip()
|
||||||
|
if base_url:
|
||||||
|
return base_url
|
||||||
|
if len(providers) == 1:
|
||||||
|
only_provider = next(iter(providers.values()))
|
||||||
|
if isinstance(only_provider, Mapping):
|
||||||
|
base_url = str(only_provider.get("base_url") or "").strip()
|
||||||
|
if base_url:
|
||||||
|
return base_url
|
||||||
|
root_base_url = str(payload.get("base_url") or "").strip()
|
||||||
|
if root_base_url:
|
||||||
|
return root_base_url
|
||||||
|
raise SpecError(
|
||||||
|
"Unable to resolve llm.endpoint.base_url for provider=codex from "
|
||||||
|
f"{config_path}. Set llm.endpoint.base_url explicitly or set "
|
||||||
|
"AITUNER_CODEX_BASE_URL."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class HardwareSpec:
|
class HardwareSpec:
|
||||||
gpu_count: int
|
gpu_count: int
|
||||||
@@ -343,15 +385,37 @@ class SamplingSearchSpec:
|
|||||||
class LLMEndpointSpec:
|
class LLMEndpointSpec:
|
||||||
base_url: str
|
base_url: str
|
||||||
model: str
|
model: str
|
||||||
|
provider: str = "custom"
|
||||||
api_key_env: str = "OPENAI_API_KEY"
|
api_key_env: str = "OPENAI_API_KEY"
|
||||||
timeout_s: float = 120.0
|
timeout_s: float = 120.0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
|
def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
|
||||||
|
provider = str(data.get("provider") or "custom").strip().lower()
|
||||||
|
base_url = str(data.get("base_url") or "").strip()
|
||||||
|
api_key_env = str(data.get("api_key_env") or "").strip()
|
||||||
|
if provider == "codex":
|
||||||
|
if not base_url:
|
||||||
|
base_url = _resolve_codex_base_url()
|
||||||
|
if not api_key_env:
|
||||||
|
api_key_env = "OPENAI_API_KEY"
|
||||||
|
elif provider == "bailian":
|
||||||
|
if not base_url:
|
||||||
|
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
if not api_key_env:
|
||||||
|
api_key_env = "DASHSCOPE_API_KEY"
|
||||||
|
elif provider == "custom":
|
||||||
|
if not base_url:
|
||||||
|
raise SpecError("llm.endpoint.base_url must be a non-empty string.")
|
||||||
|
if not api_key_env:
|
||||||
|
api_key_env = "OPENAI_API_KEY"
|
||||||
|
else:
|
||||||
|
raise SpecError(f"Unsupported llm.endpoint.provider: {provider}")
|
||||||
return cls(
|
return cls(
|
||||||
base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"),
|
base_url=_require_str(base_url, context="llm.endpoint.base_url"),
|
||||||
model=_require_str(data.get("model"), context="llm.endpoint.model"),
|
model=_require_str(data.get("model"), context="llm.endpoint.model"),
|
||||||
api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(),
|
provider=provider,
|
||||||
|
api_key_env=_require_str(api_key_env, context="llm.endpoint.api_key_env"),
|
||||||
timeout_s=_require_float(
|
timeout_s=_require_float(
|
||||||
data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
|
data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -8,12 +9,19 @@ from pathlib import Path
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from aituner.cli import main as cli_main
|
from aituner.cli import main as cli_main
|
||||||
from aituner.http_client import _openai_url, _should_bypass_proxy
|
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
|
||||||
from aituner.job import append_job, build_trial_job
|
from aituner.job import append_job, build_trial_job
|
||||||
from aituner.llm import build_prompt, parse_proposal_text
|
from aituner.llm import build_prompt, parse_proposal_text
|
||||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||||
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
from aituner.slo import RequestOutcome, evaluate_request, summarize_evaluations
|
||||||
from aituner.spec import Proposal, SpecError, StudyState, TrialSummary, load_study_spec
|
from aituner.spec import (
|
||||||
|
LLMEndpointSpec,
|
||||||
|
Proposal,
|
||||||
|
SpecError,
|
||||||
|
StudyState,
|
||||||
|
TrialSummary,
|
||||||
|
load_study_spec,
|
||||||
|
)
|
||||||
from aituner.store import StudyStore
|
from aituner.store import StudyStore
|
||||||
from aituner.trace import load_trace_requests, summarize_window
|
from aituner.trace import load_trace_requests, summarize_window
|
||||||
from aituner.worker import (
|
from aituner.worker import (
|
||||||
@@ -214,6 +222,71 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
|
with self.assertRaisesRegex(SpecError, "min_input_tokens must be <="):
|
||||||
load_study_spec(study_path)
|
load_study_spec(study_path)
|
||||||
|
|
||||||
|
def test_bailian_endpoint_defaults(self) -> None:
    """The bailian preset supplies base_url and api_key_env when both are omitted."""
    # Only provider and model are given; everything else must come from the preset.
    endpoint = LLMEndpointSpec.from_dict({"provider": "bailian", "model": "qwen-plus"})
    self.assertEqual(endpoint.provider, "bailian")
    self.assertEqual(
        endpoint.base_url, "https://dashscope.aliyuncs.com/compatible-mode/v1"
    )
    self.assertEqual(endpoint.api_key_env, "DASHSCOPE_API_KEY")
|
||||||
|
|
||||||
|
def test_codex_endpoint_resolves_base_url_from_codex_config(self) -> None:
    """A codex endpoint without base_url takes it from the selected provider in config.toml."""
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        codex_dir = tmp_path / ".codex"
        codex_dir.mkdir(parents=True)
        # Config that selects the "ipads" provider and gives it a base_url.
        (codex_dir / "config.toml").write_text(
            '\n'.join(
                [
                    'model_provider = "ipads"',
                    "",
                    "[model_providers.ipads]",
                    'base_url = "http://codex.example/v1"',
                ]
            ),
            encoding="utf-8",
        )
        # HOME points at the temp dir and every other variable is cleared, so
        # no AITUNER_CODEX_BASE_URL override can leak in from the real
        # environment. Presumably relies on Path.home() honouring HOME - verify
        # on non-POSIX platforms.
        with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
            endpoint = LLMEndpointSpec.from_dict({"provider": "codex", "model": "gpt-5.4"})
            self.assertEqual(endpoint.provider, "codex")
            self.assertEqual(endpoint.base_url, "http://codex.example/v1")
            self.assertEqual(endpoint.api_key_env, "OPENAI_API_KEY")
|
||||||
|
|
||||||
|
def test_auth_headers_load_bailian_key_from_dotenv(self) -> None:
    """A quoted DASHSCOPE_API_KEY in .env ends up as the Bearer token."""
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        # The value is double-quoted in the file; the header must carry it unquoted.
        (tmp_path / ".env").write_text('DASHSCOPE_API_KEY="dash-key"\n', encoding="utf-8")
        # Start from an empty environment so the key can only come from .env.
        with mock.patch.dict(os.environ, {}, clear=True):
            # Make the temp dir the starting point of the .env search.
            with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
                headers = _auth_headers("DASHSCOPE_API_KEY", "bailian")
                self.assertEqual(headers["Authorization"], "Bearer dash-key")
|
||||||
|
|
||||||
|
def test_auth_headers_load_codex_auth_and_proxy(self) -> None:
    """For codex, the key comes from auth.json and [network] proxies are exported."""
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        codex_dir = tmp_path / ".codex"
        codex_dir.mkdir(parents=True)
        # Proxy settings that are expected to be copied into os.environ.
        (codex_dir / "config.toml").write_text(
            '\n'.join(
                [
                    "[network]",
                    'http_proxy = "http://proxy.example:3128"',
                    'https_proxy = "http://proxy.example:3128"',
                ]
            ),
            encoding="utf-8",
        )
        # Credentials file used because OPENAI_API_KEY is absent from the environment.
        (codex_dir / "auth.json").write_text(
            json.dumps({"OPENAI_API_KEY": "sk-codex-test"}),
            encoding="utf-8",
        )
        # HOME points at the temp dir; all other variables are cleared.
        with mock.patch.dict(os.environ, {"HOME": str(tmp_path)}, clear=True):
            # The temp dir holds no .env, so the search starts from a clean place.
            with mock.patch("pathlib.Path.cwd", return_value=tmp_path):
                headers = _auth_headers("OPENAI_API_KEY", "codex")
            # Checked while os.environ is still patched: both spellings must be set.
            self.assertEqual(os.environ["http_proxy"], "http://proxy.example:3128")
            self.assertEqual(os.environ["HTTP_PROXY"], "http://proxy.example:3128")
            self.assertEqual(headers["Authorization"], "Bearer sk-codex-test")
|
||||||
|
|
||||||
def test_prompt_includes_failed_trial_context(self) -> None:
|
def test_prompt_includes_failed_trial_context(self) -> None:
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
tmp_path = Path(tmp)
|
tmp_path = Path(tmp)
|
||||||
|
|||||||
Reference in New Issue
Block a user