Initial AITuner study orchestrator

gahow
2026-04-04 21:26:37 +08:00
commit cdcca1d9d7
24 changed files with 3357 additions and 0 deletions

5
.gitignore vendored Normal file

@@ -0,0 +1,5 @@
.aituner/
__pycache__/
*.pyc
infra/gpu_fleet/config/fleet.toml
infra/gpu_fleet/config/jobs.toml


@@ -0,0 +1,14 @@
{
"prefill_service_by_bucket": {
"4k": {
"tp4_ms": 320,
"tp8_ms": 240
}
},
"queueing_knee_by_bucket": {
"4k": {
"tp4_tok_s_per_gpu": 1000,
"tp8_tok_s_per_gpu": 1100
}
}
}


@@ -0,0 +1,96 @@
{
"study_id": "example-chat-window",
"hardware": {
"gpu_count": 8,
"gpu_model": "H20",
"host_candidates": ["dash0", "dash1"]
},
"model": {
"model_id": "qwen3-30b",
"served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
},
"engine": {
"engine_name": "vllm",
"engine_version": "0.x",
"exec_path": "/usr/local/bin/vllm",
"cwd": ".",
"host": "127.0.0.1",
"port": 8000,
"healthcheck_path": "/v1/models",
"ready_timeout_s": 600,
"request_timeout_s": 600,
"launch_args": [
"serve",
"/path/to/model"
],
"base_envs": {},
"base_flags": {
"host": "127.0.0.1",
"port": 8000,
"served-model-name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
},
"tunable_envs": [
"VLLM_ATTENTION_BACKEND",
"CUDA_GRAPH_MAX_BATCH_SIZE"
],
"tunable_flags": [
"tensor-parallel-size",
"data-parallel-size",
"pipeline-parallel-size",
"max-num-seqs",
"max-num-batched-tokens",
"gpu-memory-utilization",
"enable-prefix-caching",
"block-size"
],
"python_executable": "python3"
},
"trace": {
"windows_path": "trace_windows/windows.json",
"window_id": "chat_w_example_peak_0001",
"u_field": "sampling_u",
"timestamp_field": "timestamp",
"max_concurrency": 64
},
"slo": {
"target_pass_rate": 0.95,
"ttft_rule": {
"kind": "step_ms",
"buckets": [
{
"max_input_tokens": 4096,
"threshold_ms": 2000
},
{
"max_input_tokens": 16384,
"threshold_ms": 4000
},
{
"threshold_ms": 8000
}
]
},
"tpot_rule": {
"kind": "fixed_ms",
"threshold_ms": 120
}
},
"search": {
"low": 0.0,
"high": 1.0,
"tolerance": 0.01,
"max_probes": 8,
"sample_seed": 20260325
},
"llm": {
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
"max_history_trials": 8,
"endpoint": {
"base_url": "https://example-openai-compatible-endpoint",
"model": "gpt-4.1-mini",
"api_key_env": "OPENAI_API_KEY",
"timeout_s": 120
}
},
"capability_profile_path": "capability.example.json"
}
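
A hedged sketch of consuming a spec like the one above (the file name is an assumption): spec.py parses every section into typed dataclasses and raises SpecError on the first invalid field.

from pathlib import Path
from aituner.spec import load_study_spec

# Load and validate the study spec; the file name is illustrative.
study = load_study_spec(Path("study.example.json"))
print(study.study_id)         # "example-chat-window"
print(study.engine.base_url)  # property derived from host/port: "http://127.0.0.1:8000"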


@@ -0,0 +1,3 @@
{"request_id":"example-1","timestamp":0.0,"sampling_u":0.10,"messages":[{"role":"user","content":"hello"}],"input_length":512,"output_length":16}
{"request_id":"example-2","timestamp":1.0,"sampling_u":0.50,"messages":[{"role":"user","content":"summarize this file"}],"input_length":2048,"output_length":64}
{"request_id":"example-3","timestamp":2.5,"sampling_u":0.90,"messages":[{"role":"user","content":"write a longer answer"}],"input_length":8192,"output_length":128}


@@ -0,0 +1,15 @@
{
"sample_seed": 20260325,
"u_field": "sampling_u",
"window_duration_seconds": 10.0,
"windows": [
{
"window_id": "chat_w_example_peak_0001",
"trace_type": "chat",
"trace_file": "traces/chat_w_example_peak_0001.jsonl",
"window_start": 0.0,
"window_end": 10.0,
"num_requests": 3
}
]
}


@@ -0,0 +1,59 @@
version = 1
[paths]
state_dir = ".aituner/gpu_fleet/state"
artifacts_dir = ".aituner/gpu_fleet/artifacts"
[ssh]
connect_timeout_sec = 10
[scheduler]
gpu_free_memory_mb = 1024
gpu_free_utilization_pct = 10
prefer_pack = true
[sync]
mode = "rsync"
local_path = "."
exclude = [
".git/",
".venv/",
".aituner/",
"__pycache__/",
"*.pyc",
]
[[hosts]]
name = "dash0"
ssh_alias = "dash0"
enabled = true
sync_remote_path = "~/workspace/aituner"
fleet_root = "~/.aituner_gpu_fleet"
[[hosts]]
name = "dash1"
ssh_alias = "dash1"
enabled = true
sync_remote_path = "~/workspace/aituner"
fleet_root = "~/.aituner_gpu_fleet"
[[hosts]]
name = "dash2"
ssh_alias = "dash2"
enabled = true
sync_remote_path = "~/workspace/aituner"
fleet_root = "~/.aituner_gpu_fleet"
[[hosts]]
name = "dash3"
ssh_alias = "dash3"
enabled = true
sync_remote_path = "~/aituner"
fleet_root = "~/.aituner_gpu_fleet"
[[hosts]]
name = "dash5"
ssh_alias = "dash5"
enabled = true
sync_remote_path = "~/workspace/aituner"
fleet_root = "~/.aituner_gpu_fleet"


@@ -0,0 +1,27 @@
# This file is an append-only queue source for the monitor.
# Each job name must stay unique and immutable once appended.
version = 1
[[jobs]]
name = "smoke-train-h20-1gpu"
gpus = 1
gpu_model = "H20"
hosts = ["dash0", "dash1", "dash2"]
command = "python train.py --config configs/smoke.toml"
artifacts = ["outputs/smoke-train-h20-1gpu"]
env = { WANDB_MODE = "offline" }
[[jobs]]
name = "eval-5090-4gpu"
gpus = 4
gpu_model = "5090"
hosts = ["dash5"]
command = "python eval.py --config configs/eval.toml"
artifacts = ["outputs/eval-5090-4gpu", "logs/eval-5090-4gpu.log"]
[[jobs]]
name = "special-dash3-run"
gpus = 2
hosts = ["dash3"]
command = "python benchmark.py --suite long-context"
artifacts = ["outputs/special-dash3-run"]


@@ -0,0 +1,8 @@
# One SSH alias per line.
# Lines starting with "#" are ignored.
dash0
dash1
dash2
dash3
dash5

1132
infra/gpu_fleet/gpu_fleet.py Executable file

File diff suppressed because it is too large.

19
pyproject.toml Normal file

@@ -0,0 +1,19 @@
[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"
[project]
name = "aituner"
version = "0.1.0"
description = "AITuner study orchestrator for OpenAI-compatible serving engines"
requires-python = ">=3.11"
dependencies = []
[project.scripts]
aituner = "aituner.cli:main"
[tool.setuptools]
package-dir = {"" = "src"}
[tool.setuptools.packages.find]
where = ["src"]

5
src/aituner/__init__.py Normal file

@@ -0,0 +1,5 @@
"""AITuner package."""
__all__ = [
"cli",
]

177
src/aituner/cli.py Normal file

@@ -0,0 +1,177 @@
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from .job import append_job, build_trial_job
from .llm import build_prompt, call_llm_for_proposal, load_capability_profile, parse_proposal_text
from .spec import Proposal, SpecError, load_study_spec
from .store import StudyStore
from .trace import load_trace_requests, summarize_window
from .worker import run_trial
def _study_source_path(study_root: Path) -> Path:
return Path((study_root / "study_spec.source").read_text(encoding="utf-8").strip())
def cmd_study_init(args: argparse.Namespace) -> int:
spec_path = Path(args.spec).resolve()
study = load_study_spec(spec_path)
store = StudyStore(Path(args.store_root) if args.store_root else None)
root = store.init_study(spec_path=spec_path, study=study)
print(root)
return 0
def cmd_study_prompt(args: argparse.Namespace) -> int:
store = StudyStore(Path(args.store_root) if args.store_root else None)
study_root = Path(args.study_root).resolve()
source_path = _study_source_path(study_root)
study = load_study_spec(source_path)
state = store.load_state(study.study_id)
capability_profile = load_capability_profile(study, study_spec_path=source_path)
window, requests = load_trace_requests(study, study_spec_path=source_path)
prompt = build_prompt(
study=study,
window_summary=summarize_window(requests, window),
state=state,
capability_profile=capability_profile,
)
prompt_name = args.prompt_name or f"prompt-{state.next_trial_index:04d}"
path = store.write_prompt(study.study_id, prompt_name, prompt)
print(path)
return 0
def cmd_study_llm_propose(args: argparse.Namespace) -> int:
store = StudyStore(Path(args.store_root) if args.store_root else None)
study_root = Path(args.study_root).resolve()
source_path = _study_source_path(study_root)
study = load_study_spec(source_path)
state = store.load_state(study.study_id)
capability_profile = load_capability_profile(study, study_spec_path=source_path)
window, requests = load_trace_requests(study, study_spec_path=source_path)
prompt = build_prompt(
study=study,
window_summary=summarize_window(requests, window),
state=state,
capability_profile=capability_profile,
)
proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
proposal = parse_proposal_text(proposal_text, study)
name = args.proposal_name or f"proposal-{state.next_trial_index:04d}"
path = store.write_proposal(study.study_id, name, proposal)
print(path)
return 0
def cmd_study_register_proposal(args: argparse.Namespace) -> int:
store = StudyStore(Path(args.store_root) if args.store_root else None)
study_root = Path(args.study_root).resolve()
source_path = _study_source_path(study_root)
study = load_study_spec(source_path)
proposal = parse_proposal_text(Path(args.proposal_file).read_text(encoding="utf-8"), study)
name = args.proposal_name or Path(args.proposal_file).stem
path = store.write_proposal(study.study_id, name, proposal)
print(path)
return 0
def cmd_study_emit_job(args: argparse.Namespace) -> int:
store = StudyStore(Path(args.store_root) if args.store_root else None)
study_root = Path(args.study_root).resolve()
source_path = _study_source_path(study_root)
study = load_study_spec(source_path)
state = store.load_state(study.study_id)
proposal_text = Path(args.proposal_file).read_text(encoding="utf-8")
proposal = parse_proposal_text(proposal_text, study)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
repo_root = Path(__file__).resolve().parents[2]
job = build_trial_job(study=study, trial=trial, repo_root=repo_root)
append_job(Path(args.jobs_file).resolve(), job)
print(trial.trial_id)
return 0
def cmd_study_ingest(args: argparse.Namespace) -> int:
store = StudyStore(Path(args.store_root) if args.store_root else None)
study_root = Path(args.study_root).resolve()
study_id = study_root.name
state = store.ingest_trial_results(study_id)
print(json.dumps({"best_trial_id": state.best_trial_id, "best_request_rate": state.best_request_rate}))
return 0
def cmd_worker_run_trial(args: argparse.Namespace) -> int:
result = run_trial(Path(args.trial_spec).resolve())
print(json.dumps(result))
return 0
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="AITuner CLI")
subparsers = parser.add_subparsers(dest="command", required=True)
study = subparsers.add_parser("study")
study_sub = study.add_subparsers(dest="study_command", required=True)
init = study_sub.add_parser("init")
init.add_argument("--spec", required=True)
init.add_argument("--store-root")
init.set_defaults(func=cmd_study_init)
prompt = study_sub.add_parser("prompt")
prompt.add_argument("--study-root", required=True)
prompt.add_argument("--store-root")
prompt.add_argument("--prompt-name")
prompt.set_defaults(func=cmd_study_prompt)
llm_propose = study_sub.add_parser("llm-propose")
llm_propose.add_argument("--study-root", required=True)
llm_propose.add_argument("--store-root")
llm_propose.add_argument("--proposal-name")
llm_propose.set_defaults(func=cmd_study_llm_propose)
register = study_sub.add_parser("register-proposal")
register.add_argument("--study-root", required=True)
register.add_argument("--store-root")
register.add_argument("--proposal-file", required=True)
register.add_argument("--proposal-name")
register.set_defaults(func=cmd_study_register_proposal)
emit = study_sub.add_parser("emit-job")
emit.add_argument("--study-root", required=True)
emit.add_argument("--store-root")
emit.add_argument("--proposal-file", required=True)
emit.add_argument("--jobs-file", required=True)
emit.set_defaults(func=cmd_study_emit_job)
ingest = study_sub.add_parser("ingest")
ingest.add_argument("--study-root", required=True)
ingest.add_argument("--store-root")
ingest.set_defaults(func=cmd_study_ingest)
worker = subparsers.add_parser("worker")
worker_sub = worker.add_subparsers(dest="worker_command", required=True)
run = worker_sub.add_parser("run-trial")
run.add_argument("--trial-spec", required=True)
run.set_defaults(func=cmd_worker_run_trial)
return parser
def main(argv: list[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
try:
return int(args.func(args))
except SpecError as exc:
print(str(exc), file=sys.stderr)
return 2
if __name__ == "__main__":
raise SystemExit(main())
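
The subcommands chain into one study iteration. A minimal in-process driver sketch, assuming the default store root (.aituner/studies), the example spec file name, and that OPENAI_API_KEY is set for the llm-propose step; all paths here are illustrative.

from aituner.cli import main

main(["study", "init", "--spec", "study.example.json"])
main(["study", "prompt", "--study-root", ".aituner/studies/example-chat-window"])
main(["study", "llm-propose", "--study-root", ".aituner/studies/example-chat-window"])
main(["study", "emit-job",
      "--study-root", ".aituner/studies/example-chat-window",
      "--proposal-file", ".aituner/studies/example-chat-window/proposals/proposal-0001.json",
      "--jobs-file", "infra/gpu_fleet/config/jobs.toml"])
# After the fleet runs the trial and result.json lands:
main(["study", "ingest", "--study-root", ".aituner/studies/example-chat-window"])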

65
src/aituner/engine.py Normal file

@@ -0,0 +1,65 @@
from __future__ import annotations
import os
import shlex
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .spec import ConfigPatch, EngineLaunchSpec
@dataclass(frozen=True)
class LaunchRecipe:
argv: list[str]
env: dict[str, str]
cwd: str | None
base_url: str
healthcheck_path: str
ready_timeout_s: float
request_timeout_s: float
def _normalize_flag_name(name: str) -> str:
return str(name).strip().replace("_", "-")
def _serialize_flag_parts(name: str, value: Any) -> list[str]:
flag = f"--{_normalize_flag_name(name)}"
if value is None:
return []
if isinstance(value, bool):
return [flag] if value else [f"--no-{_normalize_flag_name(name)}"]
if isinstance(value, list):
parts: list[str] = []
for item in value:
parts.extend([flag, str(item)])
return parts
return [flag, str(value)]
def build_launch_recipe(spec: EngineLaunchSpec, patch: ConfigPatch) -> LaunchRecipe:
env = dict(os.environ)
env.update(spec.base_envs)
env.update(patch.env_patch)
flags = dict(spec.base_flags)
flags.update(patch.flag_patch)
argv = [spec.exec_path, *spec.launch_args]
for key, value in flags.items():
argv.extend(_serialize_flag_parts(key, value))
cwd = None
if spec.cwd:
cwd = str(Path(spec.cwd).expanduser())
return LaunchRecipe(
argv=argv,
env={str(key): str(value) for key, value in env.items()},
cwd=cwd,
base_url=spec.base_url,
healthcheck_path=spec.healthcheck_path,
ready_timeout_s=spec.ready_timeout_s,
request_timeout_s=spec.request_timeout_s,
)
def shell_join(argv: list[str]) -> str:
return " ".join(shlex.quote(item) for item in argv)

147
src/aituner/http_client.py Normal file

@@ -0,0 +1,147 @@
from __future__ import annotations
import json
import os
import time
import urllib.error
import urllib.request
from dataclasses import dataclass
from typing import Any, Iterable
class HttpClientError(RuntimeError):
"""Raised for HTTP client failures."""
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if api_key_env:
api_key = os.environ.get(api_key_env)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
deadline = time.monotonic() + timeout_s
url = f"{base_url.rstrip('/')}{path}"
last_error = "server_not_ready"
while time.monotonic() < deadline:
try:
request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
with urllib.request.urlopen(request, timeout=5) as response:
if 200 <= response.status < 500:
return
except Exception as exc: # noqa: BLE001
last_error = str(exc)
time.sleep(1.0)
raise HttpClientError(f"Timed out waiting for {url}: {last_error}")
def chat_completion(
*,
base_url: str,
api_key_env: str | None,
model: str,
messages: list[dict[str, Any]],
timeout_s: float,
system_prompt: str = "",
) -> dict[str, Any]:
payload: dict[str, Any] = {"model": model, "messages": messages}
if system_prompt:
payload["messages"] = [{"role": "system", "content": system_prompt}, *messages]
data = json.dumps(payload).encode("utf-8")
request = urllib.request.Request(
url=f"{base_url.rstrip('/')}/v1/chat/completions",
headers=_auth_headers(api_key_env),
data=data,
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=timeout_s) as response:
return json.loads(response.read().decode("utf-8"))
except urllib.error.HTTPError as exc:
detail = exc.read().decode("utf-8", errors="replace")
raise HttpClientError(f"chat_completion failed: {exc.code} {detail}") from exc
@dataclass(frozen=True)
class StreamMetrics:
ttft_ms: float | None
tpot_ms: float | None
completion_tokens: int | None
def stream_chat_completion(
*,
base_url: str,
body: dict[str, Any],
timeout_s: float,
) -> StreamMetrics:
data = json.dumps(body).encode("utf-8")
request = urllib.request.Request(
url=f"{base_url.rstrip('/')}/v1/chat/completions",
headers=_auth_headers(None),
data=data,
method="POST",
)
start = time.monotonic()
first_token_at: float | None = None
last_token_at: float | None = None
chunk_token_count = 0
completion_tokens: int | None = None
try:
with urllib.request.urlopen(request, timeout=timeout_s) as response:
for raw in _iter_sse_lines(response):
if raw == "[DONE]":
break
payload = json.loads(raw)
if not isinstance(payload, dict):
continue
usage = payload.get("usage")
if isinstance(usage, dict):
comp = usage.get("completion_tokens")
if isinstance(comp, int) and comp >= 0:
completion_tokens = comp
choices = payload.get("choices")
if not isinstance(choices, list) or not choices:
continue
delta = choices[0].get("delta", {})
if not isinstance(delta, dict):
continue
content = delta.get("content")
if isinstance(content, str) and content:
now = time.monotonic()
if first_token_at is None:
first_token_at = now
last_token_at = now
chunk_token_count += 1
except urllib.error.HTTPError as exc:
detail = exc.read().decode("utf-8", errors="replace")
raise HttpClientError(f"stream_chat_completion failed: {exc.code} {detail}") from exc
ttft_ms = None if first_token_at is None else (first_token_at - start) * 1000.0
used_tokens = completion_tokens if completion_tokens is not None else chunk_token_count
if (
first_token_at is None
or last_token_at is None
or used_tokens is None
or used_tokens <= 1
):
tpot_ms = None
else:
tpot_ms = ((last_token_at - first_token_at) / max(used_tokens - 1, 1)) * 1000.0
return StreamMetrics(
ttft_ms=ttft_ms,
tpot_ms=tpot_ms,
completion_tokens=used_tokens if used_tokens > 0 else None,
)
def _iter_sse_lines(response: Any) -> Iterable[str]:
for raw in response:
line = raw.decode("utf-8", errors="replace").strip()
if not line.startswith("data:"):
continue
payload = line[len("data:") :].strip()
if payload:
yield payload
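
A hedged usage sketch against a locally launched engine (URL and model name are placeholders, not fixed by this module): block until the health endpoint answers, then replay one streaming request and read the derived latency metrics.

from aituner.http_client import stream_chat_completion, wait_for_server

wait_for_server("http://127.0.0.1:8000", "/v1/models", timeout_s=60)
metrics = stream_chat_completion(
    base_url="http://127.0.0.1:8000",
    body={
        "model": "Qwen/Qwen3-30B-A3B-Instruct-2507",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    timeout_s=120,
)
print(metrics.ttft_ms, metrics.tpot_ms, metrics.completion_tokens)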

75
src/aituner/job.py Normal file

@@ -0,0 +1,75 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .spec import StudySpec, TrialSpec
@dataclass(frozen=True)
class InfraJob:
name: str
gpus: int
gpu_model: str | None
hosts: list[str]
command: str
artifacts: list[str]
env: dict[str, str]
def _toml_scalar(value: Any) -> str:
if isinstance(value, bool):
return "true" if value else "false"
if isinstance(value, int):
return str(value)
text = str(value).replace("\\", "\\\\").replace('"', '\\"')
return f'"{text}"'
def _toml_list(values: list[Any]) -> str:
return "[" + ", ".join(_toml_scalar(item) for item in values) + "]"
def _toml_inline_table(mapping: dict[str, str]) -> str:
parts = [f"{key} = {_toml_scalar(value)}" for key, value in sorted(mapping.items())]
return "{ " + ", ".join(parts) + " }"
def build_trial_job(*, study: StudySpec, trial: TrialSpec, repo_root: Path) -> InfraJob:
trial_path = Path(trial.artifact_dir) / "trial_spec.json"
rel_trial_path = trial_path.resolve().relative_to(repo_root.resolve())
rel_trial_dir = Path(trial.artifact_dir).resolve().relative_to(repo_root.resolve())
command = (
f"{study.engine.python_executable} -m aituner.cli worker run-trial "
f"--trial-spec {rel_trial_path}"
)
env = {"PYTHONPATH": "src"}
return InfraJob(
name=f"{study.study_id}-{trial.trial_id}",
gpus=study.hardware.gpu_count,
gpu_model=study.hardware.gpu_model,
hosts=list(study.hardware.host_candidates),
command=command,
artifacts=[str(rel_trial_dir)],
env=env,
)
def append_job(jobs_path: Path, job: InfraJob) -> None:
jobs_path.parent.mkdir(parents=True, exist_ok=True)
with jobs_path.open("a", encoding="utf-8") as handle:
if jobs_path.stat().st_size == 0:
handle.write("version = 1\n")
handle.write("\n[[jobs]]\n")
handle.write(f"name = {_toml_scalar(job.name)}\n")
handle.write(f"gpus = {job.gpus}\n")
if job.gpu_model:
handle.write(f"gpu_model = {_toml_scalar(job.gpu_model)}\n")
if job.hosts:
handle.write(f"hosts = {_toml_list(job.hosts)}\n")
handle.write(f"command = {_toml_scalar(job.command)}\n")
if job.artifacts:
handle.write(f"artifacts = {_toml_list(job.artifacts)}\n")
if job.env:
handle.write(f"env = {_toml_inline_table(job.env)}\n")

144
src/aituner/llm.py Normal file

@@ -0,0 +1,144 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from .http_client import chat_completion
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
def build_prompt(
*,
study: StudySpec,
window_summary: dict[str, Any],
state: StudyState,
capability_profile: dict[str, Any] | None,
) -> str:
history = []
for trial in state.trials[-study.llm.max_history_trials :]:
history.append(
{
"trial_id": trial.trial_id,
"status": trial.status,
"best_sampling_u": trial.best_sampling_u,
"best_request_rate": trial.best_request_rate,
"best_pass_rate": trial.best_pass_rate,
"diagnosis": trial.diagnosis,
}
)
sections = [
"You are tuning an OpenAI-compatible serving engine.",
"Return exactly one JSON object with keys: observation, diagnosis, config_patch, expected_effects, why_not_previous_failures.",
"config_patch must contain env_patch and flag_patch.",
"Only use allowed tunable env keys and allowed tunable flag keys.",
"",
"Study stack:",
json.dumps(
{
"study_id": study.study_id,
"hardware": {
"gpu_count": study.hardware.gpu_count,
"gpu_model": study.hardware.gpu_model,
},
"model": {
"model_id": study.model.model_id,
"served_model_name": study.model.served_model_name,
},
"engine": {
"engine_name": study.engine.engine_name,
"engine_version": study.engine.engine_version,
"base_flags": study.engine.base_flags,
"base_envs": study.engine.base_envs,
"allowed_flag_keys": study.engine.tunable_flags,
"allowed_env_keys": study.engine.tunable_envs,
},
},
ensure_ascii=False,
indent=2,
),
"",
"Window summary:",
json.dumps(window_summary, ensure_ascii=False, indent=2),
"",
"SLO:",
json.dumps(
{
"target_pass_rate": study.slo.target_pass_rate,
"ttft_rule": study.slo.ttft_rule,
"tpot_rule": study.slo.tpot_rule,
},
default=lambda value: value.__dict__,
ensure_ascii=False,
indent=2,
),
"",
"Capability profile:",
json.dumps(capability_profile or {}, ensure_ascii=False, indent=2),
"",
"Trial history:",
json.dumps(history, ensure_ascii=False, indent=2),
"",
"The proposal should improve the maximum feasible sampling_u under the 95%+ SLO target.",
]
return "\n".join(sections)
def load_capability_profile(study: StudySpec, *, study_spec_path: Path) -> dict[str, Any] | None:
if not study.capability_profile_path:
return None
path = Path(study.capability_profile_path)
if not path.is_absolute():
path = (study_spec_path.parent / path).resolve()
return json.loads(path.read_text(encoding="utf-8"))
def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal:
unknown_envs = sorted(set(proposal.config_patch.env_patch) - set(study.engine.tunable_envs))
unknown_flags = sorted(
set(proposal.config_patch.flag_patch) - set(study.engine.tunable_flags)
)
if unknown_envs:
raise SpecError(f"Proposal uses unsupported env keys: {', '.join(unknown_envs)}")
if unknown_flags:
raise SpecError(f"Proposal uses unsupported flag keys: {', '.join(unknown_flags)}")
return proposal
def parse_proposal_text(text: str, study: StudySpec) -> Proposal:
payload = json.loads(text)
proposal = Proposal.from_dict(payload)
return validate_proposal(proposal, study)
def call_llm_for_proposal(
*,
policy: LLMPolicySpec,
prompt: str,
) -> str:
if policy.endpoint is None:
raise RuntimeError("study.llm.endpoint is not configured")
response = chat_completion(
base_url=policy.endpoint.base_url,
api_key_env=policy.endpoint.api_key_env,
model=policy.endpoint.model,
messages=[{"role": "user", "content": prompt}],
timeout_s=policy.endpoint.timeout_s,
system_prompt=policy.system_prompt,
)
choices = response.get("choices")
if not isinstance(choices, list) or not choices:
raise RuntimeError("LLM response does not contain choices")
message = choices[0].get("message", {})
if not isinstance(message, dict):
raise RuntimeError("LLM response does not contain a valid message")
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
return "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and isinstance(item.get("text"), str)
)
raise RuntimeError("LLM response content is empty")

58
src/aituner/search.py Normal file

@@ -0,0 +1,58 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar
T = TypeVar("T")
@dataclass(frozen=True)
class ThresholdProbe(Generic[T]):
threshold: float
feasible: bool
payload: T
@dataclass(frozen=True)
class ThresholdSearchResult(Generic[T]):
best_threshold: float
best_feasible_payload: T | None
probes: list[ThresholdProbe[T]]
def binary_search_max_feasible(
*,
low: float,
high: float,
tolerance: float,
max_probes: int,
evaluator: Callable[[float], ThresholdProbe[T]],
) -> ThresholdSearchResult[T]:
probes: list[ThresholdProbe[T]] = []
cache: dict[float, ThresholdProbe[T]] = {}
best_threshold = low
best_payload: T | None = None
cur_low = low
cur_high = high
for _ in range(max_probes):
if cur_high - cur_low <= tolerance:
break
threshold = round((cur_low + cur_high) / 2.0, 12)
probe = cache.get(threshold)
if probe is None:
probe = evaluator(threshold)
cache[threshold] = probe
probes.append(probe)
if probe.feasible:
if threshold >= best_threshold:
best_threshold = threshold
best_payload = probe.payload
cur_low = threshold
else:
cur_high = threshold
return ThresholdSearchResult(
best_threshold=best_threshold,
best_feasible_payload=best_payload,
probes=probes,
)
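
The search bisects [low, high] for up to max_probes midpoints and keeps the highest feasible one; the payload rides along so the caller gets that probe's metrics back. A toy sketch where feasibility is a fixed cutoff:

from aituner.search import ThresholdProbe, binary_search_max_feasible

result = binary_search_max_feasible(
    low=0.0,
    high=1.0,
    tolerance=0.01,
    max_probes=8,
    evaluator=lambda t: ThresholdProbe(threshold=t, feasible=t <= 0.62, payload=t),
)
print(result.best_threshold)  # converges on 0.62 from below (0.6171875 here)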

80
src/aituner/slo.py Normal file

@@ -0,0 +1,80 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from .spec import SloSpec, ThresholdRule
@dataclass(frozen=True)
class RequestOutcome:
request_id: str
success: bool
ttft_ms: float | None
tpot_ms: float | None
prompt_tokens: int | None
completion_tokens: int | None
error: str = ""
@dataclass(frozen=True)
class RequestEvaluation:
request_id: str
passed: bool
reasons: list[str]
def _rule_threshold_ms(rule: ThresholdRule, prompt_tokens: int | None) -> float:
if rule.kind == "fixed_ms":
assert rule.threshold_ms is not None
return rule.threshold_ms
if rule.kind != "step_ms":
raise ValueError(f"Unsupported threshold rule: {rule.kind}")
prompt = float(prompt_tokens or 0)
for bucket in rule.buckets:
ceiling = bucket.get("max_input_tokens")
if ceiling is None or prompt <= ceiling:
return float(bucket["threshold_ms"])
return float(rule.buckets[-1]["threshold_ms"])
def evaluate_request(outcome: RequestOutcome, slo: SloSpec) -> RequestEvaluation:
reasons: list[str] = []
if not outcome.success:
reasons.append(outcome.error or "request_failed")
return RequestEvaluation(request_id=outcome.request_id, passed=False, reasons=reasons)
if slo.ttft_rule is not None:
if outcome.ttft_ms is None:
reasons.append("ttft_missing")
else:
threshold = _rule_threshold_ms(slo.ttft_rule, outcome.prompt_tokens)
if outcome.ttft_ms > threshold:
reasons.append(f"ttft_ms>{threshold}")
if slo.tpot_rule is not None:
if outcome.tpot_ms is None:
reasons.append("tpot_missing")
else:
threshold = _rule_threshold_ms(slo.tpot_rule, outcome.prompt_tokens)
if outcome.tpot_ms > threshold:
reasons.append(f"tpot_ms>{threshold}")
return RequestEvaluation(
request_id=outcome.request_id,
passed=not reasons,
reasons=reasons,
)
def summarize_evaluations(
outcomes: list[RequestOutcome], slo: SloSpec
) -> tuple[list[RequestEvaluation], dict[str, Any]]:
evaluations = [evaluate_request(item, slo) for item in outcomes]
total = len(evaluations)
passed = sum(1 for item in evaluations if item.passed)
pass_rate = (passed / total) if total else 0.0
return evaluations, {
"request_count": total,
"slo_pass_count": passed,
"slo_pass_rate": pass_rate,
"target_pass_rate": slo.target_pass_rate,
"feasible": pass_rate >= slo.target_pass_rate,
}
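
A minimal sketch wiring the two modules together: build an SloSpec from the same dict shape the study spec uses, then score one successful request (numbers illustrative).

from aituner.slo import RequestOutcome, summarize_evaluations
from aituner.spec import SloSpec

slo = SloSpec.from_dict({
    "target_pass_rate": 0.95,
    "ttft_rule": {"kind": "step_ms",
                  "buckets": [{"max_input_tokens": 4096, "threshold_ms": 2000},
                              {"threshold_ms": 8000}]},
    "tpot_rule": {"kind": "fixed_ms", "threshold_ms": 120},
})
outcome = RequestOutcome(request_id="r1", success=True, ttft_ms=800.0,
                         tpot_ms=95.0, prompt_tokens=512, completion_tokens=64)
evaluations, summary = summarize_evaluations([outcome], slo)
print(summary["slo_pass_rate"], summary["feasible"])  # 1.0 True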

440
src/aituner/spec.py Normal file

@@ -0,0 +1,440 @@
from __future__ import annotations
import json
import tomllib
from dataclasses import asdict, dataclass, field, is_dataclass
from pathlib import Path
from typing import Any, Mapping
class SpecError(ValueError):
"""Raised when a structured spec is invalid."""
def _require_mapping(value: Any, *, context: str) -> Mapping[str, Any]:
if not isinstance(value, Mapping):
raise SpecError(f"{context} must be an object.")
return value
def _require_str(value: Any, *, context: str) -> str:
if not isinstance(value, str) or not value.strip():
raise SpecError(f"{context} must be a non-empty string.")
return value.strip()
def _require_float(value: Any, *, context: str) -> float:
if isinstance(value, bool) or not isinstance(value, (int, float)):
raise SpecError(f"{context} must be numeric.")
return float(value)
def _require_int(value: Any, *, context: str) -> int:
if isinstance(value, bool) or not isinstance(value, int):
raise SpecError(f"{context} must be an integer.")
return value
def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]:
mapping = _require_mapping(value or {}, context=context)
return {str(key): str(item) for key, item in mapping.items()}
def _coerce_any_map(value: Any, *, context: str) -> dict[str, Any]:
mapping = _require_mapping(value or {}, context=context)
return {str(key): item for key, item in mapping.items()}
def _coerce_str_list(value: Any, *, context: str) -> list[str]:
if value is None:
return []
if not isinstance(value, list):
raise SpecError(f"{context} must be a list.")
result: list[str] = []
for item in value:
result.append(_require_str(item, context=context))
return result
@dataclass(frozen=True)
class HardwareSpec:
gpu_count: int
gpu_model: str | None = None
host_candidates: list[str] = field(default_factory=list)
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "HardwareSpec":
return cls(
gpu_count=_require_int(data.get("gpu_count"), context="hardware.gpu_count"),
gpu_model=str(data["gpu_model"]).strip() if data.get("gpu_model") else None,
host_candidates=_coerce_str_list(
data.get("host_candidates"), context="hardware.host_candidates"
),
)
@dataclass(frozen=True)
class ModelSpec:
model_id: str
served_model_name: str
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "ModelSpec":
return cls(
model_id=_require_str(data.get("model_id"), context="model.model_id"),
served_model_name=_require_str(
data.get("served_model_name"), context="model.served_model_name"
),
)
@dataclass(frozen=True)
class EngineLaunchSpec:
engine_name: str
engine_version: str | None
exec_path: str
cwd: str | None
host: str
port: int
ready_timeout_s: float
request_timeout_s: float
healthcheck_path: str
launch_args: list[str]
base_envs: dict[str, str]
base_flags: dict[str, Any]
tunable_envs: list[str]
tunable_flags: list[str]
python_executable: str = "python3"
@property
def base_url(self) -> str:
return f"http://{self.host}:{self.port}"
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "EngineLaunchSpec":
return cls(
engine_name=_require_str(data.get("engine_name"), context="engine.engine_name"),
engine_version=str(data["engine_version"]).strip()
if data.get("engine_version")
else None,
exec_path=_require_str(data.get("exec_path"), context="engine.exec_path"),
cwd=str(data["cwd"]).strip() if data.get("cwd") else None,
host=str(data.get("host") or "127.0.0.1").strip(),
port=_require_int(data.get("port", 8000), context="engine.port"),
ready_timeout_s=_require_float(
data.get("ready_timeout_s", 600.0), context="engine.ready_timeout_s"
),
request_timeout_s=_require_float(
data.get("request_timeout_s", 600.0),
context="engine.request_timeout_s",
),
healthcheck_path=str(data.get("healthcheck_path") or "/v1/models").strip(),
launch_args=_coerce_str_list(data.get("launch_args"), context="engine.launch_args"),
base_envs=_coerce_str_map(data.get("base_envs"), context="engine.base_envs"),
base_flags=_coerce_any_map(data.get("base_flags"), context="engine.base_flags"),
tunable_envs=_coerce_str_list(
data.get("tunable_envs"), context="engine.tunable_envs"
),
tunable_flags=_coerce_str_list(
data.get("tunable_flags"), context="engine.tunable_flags"
),
python_executable=str(data.get("python_executable") or "python3").strip(),
)
@dataclass(frozen=True)
class TraceSpec:
windows_path: str
window_id: str
trace_file_override: str | None
u_field: str
timestamp_field: str
max_concurrency: int
max_requests_per_probe: int | None = None
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
max_requests = data.get("max_requests_per_probe")
return cls(
windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
window_id=_require_str(data.get("window_id"), context="trace.window_id"),
trace_file_override=str(data["trace_file_override"]).strip()
if data.get("trace_file_override")
else None,
u_field=str(data.get("u_field") or "sampling_u").strip(),
timestamp_field=str(data.get("timestamp_field") or "timestamp").strip(),
max_concurrency=_require_int(
data.get("max_concurrency", 64), context="trace.max_concurrency"
),
max_requests_per_probe=int(max_requests) if max_requests is not None else None,
)
@dataclass(frozen=True)
class ThresholdRule:
kind: str
threshold_ms: float | None = None
buckets: list[dict[str, float]] = field(default_factory=list)
@classmethod
def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "ThresholdRule":
kind = _require_str(data.get("kind"), context=f"{context}.kind")
if kind == "fixed_ms":
return cls(
kind=kind,
threshold_ms=_require_float(
data.get("threshold_ms"), context=f"{context}.threshold_ms"
),
)
if kind == "step_ms":
raw = data.get("buckets")
if not isinstance(raw, list) or not raw:
raise SpecError(f"{context}.buckets must be a non-empty list.")
buckets: list[dict[str, float]] = []
for idx, item in enumerate(raw):
mapping = _require_mapping(item, context=f"{context}.buckets[{idx}]")
bucket: dict[str, float] = {
"threshold_ms": _require_float(
mapping.get("threshold_ms"),
context=f"{context}.buckets[{idx}].threshold_ms",
)
}
if "max_input_tokens" in mapping and mapping["max_input_tokens"] is not None:
bucket["max_input_tokens"] = _require_float(
mapping["max_input_tokens"],
context=f"{context}.buckets[{idx}].max_input_tokens",
)
buckets.append(bucket)
return cls(kind=kind, buckets=buckets)
raise SpecError(f"Unsupported threshold rule kind: {kind}")
@dataclass(frozen=True)
class SloSpec:
target_pass_rate: float
ttft_rule: ThresholdRule | None
tpot_rule: ThresholdRule | None
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "SloSpec":
ttft_rule = (
ThresholdRule.from_dict(
_require_mapping(data["ttft_rule"], context="slo.ttft_rule"),
context="slo.ttft_rule",
)
if data.get("ttft_rule")
else None
)
tpot_rule = (
ThresholdRule.from_dict(
_require_mapping(data["tpot_rule"], context="slo.tpot_rule"),
context="slo.tpot_rule",
)
if data.get("tpot_rule")
else None
)
return cls(
target_pass_rate=_require_float(
data.get("target_pass_rate", 0.95), context="slo.target_pass_rate"
),
ttft_rule=ttft_rule,
tpot_rule=tpot_rule,
)
@dataclass(frozen=True)
class SamplingSearchSpec:
low: float
high: float
tolerance: float
max_probes: int
sample_seed: int
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "SamplingSearchSpec":
return cls(
low=_require_float(data.get("low", 0.0), context="search.low"),
high=_require_float(data.get("high", 1.0), context="search.high"),
tolerance=_require_float(
data.get("tolerance", 0.01), context="search.tolerance"
),
max_probes=_require_int(data.get("max_probes", 8), context="search.max_probes"),
sample_seed=_require_int(
data.get("sample_seed", 20260325), context="search.sample_seed"
),
)
@dataclass(frozen=True)
class LLMEndpointSpec:
base_url: str
model: str
api_key_env: str = "OPENAI_API_KEY"
timeout_s: float = 120.0
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
return cls(
base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"),
model=_require_str(data.get("model"), context="llm.endpoint.model"),
api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(),
timeout_s=_require_float(
data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
),
)
@dataclass(frozen=True)
class LLMPolicySpec:
endpoint: LLMEndpointSpec | None
system_prompt: str
max_history_trials: int
@classmethod
def from_dict(cls, data: Mapping[str, Any] | None) -> "LLMPolicySpec":
payload = _require_mapping(data or {}, context="llm")
endpoint = (
LLMEndpointSpec.from_dict(
_require_mapping(payload["endpoint"], context="llm.endpoint")
)
if payload.get("endpoint")
else None
)
return cls(
endpoint=endpoint,
system_prompt=str(payload.get("system_prompt") or "").strip(),
max_history_trials=_require_int(
payload.get("max_history_trials", 8), context="llm.max_history_trials"
),
)
@dataclass(frozen=True)
class StudySpec:
study_id: str
hardware: HardwareSpec
model: ModelSpec
engine: EngineLaunchSpec
trace: TraceSpec
slo: SloSpec
search: SamplingSearchSpec
llm: LLMPolicySpec
capability_profile_path: str | None = None
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec":
return cls(
study_id=_require_str(data.get("study_id"), context="study_id"),
hardware=HardwareSpec.from_dict(
_require_mapping(data.get("hardware"), context="hardware")
),
model=ModelSpec.from_dict(_require_mapping(data.get("model"), context="model")),
engine=EngineLaunchSpec.from_dict(
_require_mapping(data.get("engine"), context="engine")
),
trace=TraceSpec.from_dict(_require_mapping(data.get("trace"), context="trace")),
slo=SloSpec.from_dict(_require_mapping(data.get("slo"), context="slo")),
search=SamplingSearchSpec.from_dict(
_require_mapping(data.get("search"), context="search")
),
llm=LLMPolicySpec.from_dict(data.get("llm")),
capability_profile_path=str(data["capability_profile_path"]).strip()
if data.get("capability_profile_path")
else None,
)
@dataclass(frozen=True)
class ConfigPatch:
env_patch: dict[str, str] = field(default_factory=dict)
flag_patch: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "ConfigPatch":
return cls(
env_patch=_coerce_str_map(data.get("env_patch"), context="config_patch.env_patch"),
flag_patch=_coerce_any_map(
data.get("flag_patch"), context="config_patch.flag_patch"
),
)
@dataclass(frozen=True)
class Proposal:
observation: str
diagnosis: str
config_patch: ConfigPatch
expected_effects: list[str]
why_not_previous_failures: str = ""
@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
return cls(
observation=_require_str(data.get("observation"), context="proposal.observation"),
diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"),
config_patch=ConfigPatch.from_dict(
_require_mapping(data.get("config_patch"), context="proposal.config_patch")
),
expected_effects=_coerce_str_list(
data.get("expected_effects"), context="proposal.expected_effects"
),
why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(),
)
@dataclass(frozen=True)
class TrialSpec:
study_id: str
trial_id: str
config_patch: ConfigPatch
search: SamplingSearchSpec
study_spec_path: str
artifact_dir: str
probe_log_path: str
engine_log_path: str
result_path: str
@dataclass
class TrialSummary:
trial_id: str
status: str
best_sampling_u: float | None = None
best_request_rate: float | None = None
best_pass_rate: float | None = None
result_path: str | None = None
diagnosis: str = ""
@dataclass
class StudyState:
study_id: str
best_trial_id: str | None = None
best_request_rate: float | None = None
next_trial_index: int = 1
trials: list[TrialSummary] = field(default_factory=list)
def to_jsonable(value: Any) -> Any:
if is_dataclass(value):
return {key: to_jsonable(item) for key, item in asdict(value).items()}
if isinstance(value, dict):
return {str(key): to_jsonable(item) for key, item in value.items()}
if isinstance(value, list):
return [to_jsonable(item) for item in value]
return value
def load_structured_file(path: Path) -> Mapping[str, Any]:
suffix = path.suffix.lower()
if suffix == ".json":
payload = json.loads(path.read_text(encoding="utf-8"))
elif suffix in {".toml", ".tml"}:
payload = tomllib.loads(path.read_text(encoding="utf-8"))
else:
raise SpecError(f"Unsupported spec file type: {path}")
return _require_mapping(payload, context=str(path))
def load_study_spec(path: Path) -> StudySpec:
return StudySpec.from_dict(load_structured_file(path))
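
Validation failures surface as SpecError with a dotted-path context string, which the CLI maps to exit code 2. A tiny sketch:

from aituner.spec import SpecError, StudySpec

try:
    StudySpec.from_dict({"study_id": ""})
except SpecError as exc:
    print(exc)  # study_id must be a non-empty string.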

118
src/aituner/store.py Normal file

@@ -0,0 +1,118 @@
from __future__ import annotations
import json
from dataclasses import replace
from pathlib import Path
from typing import Any
from .spec import Proposal, StudySpec, StudyState, TrialSpec, TrialSummary, to_jsonable
class StudyStore:
def __init__(self, root: Path | None = None):
base = root or Path(".aituner") / "studies"
self.root = base.resolve()
def study_root(self, study_id: str) -> Path:
return self.root / study_id
def init_study(self, *, spec_path: Path, study: StudySpec) -> Path:
root = self.study_root(study.study_id)
for rel in ("prompts", "proposals", "trials", "results"):
(root / rel).mkdir(parents=True, exist_ok=True)
(root / "study_spec.source").write_text(str(spec_path.resolve()) + "\n", encoding="utf-8")
self.write_json(root / "study_spec.snapshot.json", to_jsonable(study))
if not (root / "state.json").exists():
self.write_json(root / "state.json", to_jsonable(StudyState(study_id=study.study_id)))
return root
def load_state(self, study_id: str) -> StudyState:
payload = json.loads((self.study_root(study_id) / "state.json").read_text(encoding="utf-8"))
trials = [TrialSummary(**item) for item in payload.get("trials", [])]
return StudyState(
study_id=str(payload["study_id"]),
best_trial_id=payload.get("best_trial_id"),
best_request_rate=payload.get("best_request_rate"),
next_trial_index=int(payload.get("next_trial_index", 1)),
trials=trials,
)
def save_state(self, state: StudyState) -> None:
self.write_json(self.study_root(state.study_id) / "state.json", to_jsonable(state))
def write_prompt(self, study_id: str, prompt_name: str, prompt_text: str) -> Path:
path = self.study_root(study_id) / "prompts" / f"{prompt_name}.txt"
path.write_text(prompt_text, encoding="utf-8")
return path
def write_proposal(self, study_id: str, proposal_name: str, proposal: Proposal) -> Path:
path = self.study_root(study_id) / "proposals" / f"{proposal_name}.json"
self.write_json(path, to_jsonable(proposal))
return path
def materialize_trial(
self,
*,
study: StudySpec,
state: StudyState,
proposal: Proposal,
) -> tuple[TrialSpec, StudyState]:
trial_id = f"trial-{state.next_trial_index:04d}"
trial_root = self.study_root(study.study_id) / "trials" / trial_id
trial_root.mkdir(parents=True, exist_ok=True)
spec = TrialSpec(
study_id=study.study_id,
trial_id=trial_id,
config_patch=proposal.config_patch,
search=study.search,
study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()),
artifact_dir=str(trial_root),
probe_log_path=str(trial_root / "probe_history.json"),
engine_log_path=str(trial_root / "engine.log"),
result_path=str(trial_root / "result.json"),
)
self.write_json(trial_root / "trial_spec.json", to_jsonable(spec))
next_state = replace(state, next_trial_index=state.next_trial_index + 1)
next_state.trials.append(
TrialSummary(trial_id=trial_id, status="queued", diagnosis=proposal.diagnosis)
)
self.save_state(next_state)
return spec, next_state
def ingest_trial_results(self, study_id: str) -> StudyState:
state = self.load_state(study_id)
by_id = {item.trial_id: item for item in state.trials}
trials_dir = self.study_root(study_id) / "trials"
best_trial_id = state.best_trial_id
best_rate = state.best_request_rate
for trial_dir in sorted(trials_dir.glob("trial-*")):
result_path = trial_dir / "result.json"
if not result_path.exists():
continue
payload = json.loads(result_path.read_text(encoding="utf-8"))
trial_id = str(payload["trial_id"])
summary = by_id.get(trial_id)
if summary is None:
summary = TrialSummary(trial_id=trial_id, status="unknown")
state.trials.append(summary)
by_id[trial_id] = summary
summary.status = str(payload.get("status") or "completed")
summary.best_sampling_u = payload.get("best_sampling_u")
summary.best_request_rate = payload.get("best_request_rate")
summary.best_pass_rate = payload.get("best_pass_rate")
summary.result_path = str(result_path)
if (
isinstance(summary.best_request_rate, (int, float))
and (best_rate is None or summary.best_request_rate > best_rate)
):
best_rate = float(summary.best_request_rate)
best_trial_id = trial_id
state.best_request_rate = best_rate
state.best_trial_id = best_trial_id
self.save_state(state)
return state
@staticmethod
def write_json(path: Path, payload: Any) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
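
state.json is just StudyState pushed through to_jsonable, so a fresh study starts from exactly this shape:

from aituner.spec import StudyState, to_jsonable

print(to_jsonable(StudyState(study_id="example-chat-window")))
# {'study_id': 'example-chat-window', 'best_trial_id': None,
#  'best_request_rate': None, 'next_trial_index': 1, 'trials': []}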

178
src/aituner/trace.py Normal file

@@ -0,0 +1,178 @@
from __future__ import annotations
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping
from .spec import StudySpec
class TraceError(ValueError):
"""Raised when trace assets are invalid."""
def _percentile(values: list[float], p: float) -> float:
if not values:
return 0.0
ordered = sorted(values)
idx = min(len(ordered) - 1, max(0, math.ceil((p / 100.0) * len(ordered)) - 1))
return ordered[idx]
@dataclass(frozen=True)
class WindowRecord:
window_id: str
trace_path: Path
trace_type: str
window_start: float
window_end: float
source_payload: dict[str, Any]
@dataclass(frozen=True)
class TraceRequest:
row_id: str
arrival_s: float
sampling_u: float
body: dict[str, Any]
prompt_tokens_hint: int | None
completion_tokens_hint: int | None
def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowRecord:
windows_path = Path(study.trace.windows_path)
if not windows_path.is_absolute():
windows_path = (study_spec_path.parent / windows_path).resolve()
payload = json.loads(windows_path.read_text(encoding="utf-8"))
windows = payload["windows"] if isinstance(payload, Mapping) and "windows" in payload else payload
if not isinstance(windows, list):
raise TraceError(f"windows payload must contain a list: {windows_path}")
for item in windows:
if not isinstance(item, Mapping):
continue
if str(item.get("window_id") or "").strip() != study.trace.window_id:
continue
trace_file = study.trace.trace_file_override or str(item.get("trace_file") or "").strip()
if not trace_file:
raise TraceError(f"window {study.trace.window_id} does not define trace_file")
trace_path = Path(trace_file)
if not trace_path.is_absolute():
trace_path = (windows_path.parent / trace_path).resolve()
return WindowRecord(
window_id=study.trace.window_id,
trace_path=trace_path,
trace_type=str(item.get("trace_type") or "chat").strip(),
window_start=float(item.get("window_start") or 0.0),
window_end=float(item.get("window_end") or 0.0),
source_payload={str(key): value for key, value in item.items()},
)
raise TraceError(f"window_id not found: {study.trace.window_id}")
def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
messages = row.get("messages")
if isinstance(messages, list) and messages:
return [dict(item) for item in messages if isinstance(item, Mapping)]
prompt = row.get("prompt") or row.get("input") or row.get("text")
if isinstance(prompt, str) and prompt.strip():
return [{"role": "user", "content": prompt}]
raise TraceError("trace row is missing chat messages/prompt text")
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
value = row.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, float) and value >= 0:
return int(value)
return None
def _coerce_prompt_tokens(row: Mapping[str, Any]) -> int | None:
for key in ("input_length", "prompt_length", "prompt_len", "input_tokens"):
value = row.get(key)
if isinstance(value, bool):
continue
if isinstance(value, int) and value >= 0:
return value
if isinstance(value, float) and value >= 0:
return int(value)
return None
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
window = resolve_window_record(study, study_spec_path=study_spec_path)
requests: list[TraceRequest] = []
with window.trace_path.open("r", encoding="utf-8") as handle:
for idx, raw in enumerate(handle):
if not raw.strip():
continue
row = json.loads(raw)
if not isinstance(row, Mapping):
continue
timestamp = row.get(study.trace.timestamp_field)
if timestamp is None:
timestamp = row.get("arrival_time", row.get("timestamp"))
if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
raise TraceError(f"trace row {idx} is missing numeric timestamp")
sampling_u = row.get(study.trace.u_field, 1.0)
if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
body: dict[str, Any] = {
"model": study.model.served_model_name,
"messages": _coerce_messages(row),
"stream": True,
"stream_options": {"include_usage": True},
}
completion_tokens = _coerce_completion_tokens(row)
if completion_tokens is not None:
body["max_tokens"] = completion_tokens
temperature = row.get("temperature")
if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):
body["temperature"] = temperature
requests.append(
TraceRequest(
row_id=str(row.get("request_id") or row.get("id") or idx),
arrival_s=float(timestamp),
sampling_u=float(sampling_u),
body=body,
prompt_tokens_hint=_coerce_prompt_tokens(row),
completion_tokens_hint=completion_tokens,
)
)
requests.sort(key=lambda item: item.arrival_s)
if study.trace.max_requests_per_probe is not None:
requests = requests[: study.trace.max_requests_per_probe]
return window, requests
def summarize_window(requests: list[TraceRequest], window: WindowRecord) -> dict[str, Any]:
prompt_tokens = [float(item.prompt_tokens_hint or 0) for item in requests]
completion_tokens = [float(item.completion_tokens_hint or 0) for item in requests]
duration = max(window.window_end - window.window_start, 0.0) or (
requests[-1].arrival_s - requests[0].arrival_s if len(requests) >= 2 else 0.0
)
qps = (len(requests) / duration) if duration > 0 else 0.0
return {
"window_id": window.window_id,
"trace_path": str(window.trace_path),
"trace_type": window.trace_type,
"request_count": len(requests),
"duration_s": duration,
"request_rate": qps,
"prompt_tokens_p50": _percentile(prompt_tokens, 50.0),
"prompt_tokens_p95": _percentile(prompt_tokens, 95.0),
"completion_tokens_p50": _percentile(completion_tokens, 50.0),
"completion_tokens_p95": _percentile(completion_tokens, 95.0),
}
def select_requests_for_threshold(
requests: list[TraceRequest], *, threshold: float
) -> list[TraceRequest]:
return [item for item in requests if item.sampling_u <= threshold]
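
Thresholding is what makes sampling_u bisectable: a probe at threshold t replays exactly the subset of the window with sampling_u <= t. A toy sketch (request bodies elided):

from aituner.trace import TraceRequest, select_requests_for_threshold

requests = [
    TraceRequest(row_id=str(i), arrival_s=float(i), sampling_u=u, body={},
                 prompt_tokens_hint=None, completion_tokens_hint=None)
    for i, u in enumerate([0.10, 0.50, 0.90])
]
print([r.row_id for r in select_requests_for_threshold(requests, threshold=0.5)])
# ['0', '1'] -- the 0.90 request only joins probes at higher thresholds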

215
src/aituner/worker.py Normal file

@@ -0,0 +1,215 @@
from __future__ import annotations
import json
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from .engine import build_launch_recipe
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
from .search import ThresholdProbe, binary_search_max_feasible
from .slo import RequestOutcome, summarize_evaluations
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold
@dataclass(frozen=True)
class ProbePayload:
threshold: float
request_count: int
pass_rate: float
request_rate: float
feasible: bool
outcomes: list[dict[str, Any]]
def _trial_spec_from_json(path: Path) -> TrialSpec:
payload = json.loads(path.read_text(encoding="utf-8"))
return TrialSpec(
study_id=str(payload["study_id"]),
trial_id=str(payload["trial_id"]),
config_patch=ConfigPatch.from_dict(payload["config_patch"]),
search=SamplingSearchSpec.from_dict(payload["search"]),
study_spec_path=str(payload["study_spec_path"]),
artifact_dir=str(payload["artifact_dir"]),
probe_log_path=str(payload["probe_log_path"]),
engine_log_path=str(payload["engine_log_path"]),
result_path=str(payload["result_path"]),
)
def _run_one_request(
request: TraceRequest,
*,
base_url: str,
timeout_s: float,
) -> RequestOutcome:
try:
metrics = stream_chat_completion(base_url=base_url, body=request.body, timeout_s=timeout_s)
return RequestOutcome(
request_id=request.row_id,
success=True,
ttft_ms=metrics.ttft_ms,
tpot_ms=metrics.tpot_ms,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=metrics.completion_tokens or request.completion_tokens_hint,
)
except HttpClientError as exc:
return RequestOutcome(
request_id=request.row_id,
success=False,
ttft_ms=None,
tpot_ms=None,
prompt_tokens=request.prompt_tokens_hint,
completion_tokens=request.completion_tokens_hint,
error=str(exc),
)
def _replay_requests(
requests: list[TraceRequest],
*,
base_url: str,
timeout_s: float,
max_concurrency: int,
) -> list[RequestOutcome]:
outcomes_by_id: dict[str, RequestOutcome] = {}
lock = threading.Lock()
start = time.monotonic()
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
futures = []
for request in requests:
delay = max(0.0, request.arrival_s)
now = time.monotonic()
sleep_for = (start + delay) - now
if sleep_for > 0:
time.sleep(sleep_for)
futures.append(
pool.submit(
_run_one_request,
request,
base_url=base_url,
timeout_s=timeout_s,
)
)
for future in as_completed(futures):
outcome = future.result()
with lock:
outcomes_by_id[outcome.request_id] = outcome
return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
def run_trial(trial_spec_path: Path) -> dict[str, Any]:
from .store import StudyStore
trial = _trial_spec_from_json(trial_spec_path)
study_spec_path = Path(Path(trial.study_spec_path).read_text(encoding="utf-8").strip())
study = load_study_spec(study_spec_path)
window, requests = load_trace_requests(study, study_spec_path=study_spec_path)
recipe = build_launch_recipe(study.engine, trial.config_patch)
artifact_dir = Path(trial.artifact_dir)
artifact_dir.mkdir(parents=True, exist_ok=True)
engine_log_path = Path(trial.engine_log_path)
with engine_log_path.open("w", encoding="utf-8") as engine_log:
process = subprocess.Popen( # noqa: S603
recipe.argv,
cwd=recipe.cwd,
env=recipe.env,
stdout=engine_log,
stderr=subprocess.STDOUT,
text=True,
)
try:
wait_for_server(recipe.base_url, recipe.healthcheck_path, recipe.ready_timeout_s)
probe_history: list[dict[str, Any]] = []
def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
selected = select_requests_for_threshold(requests, threshold=threshold)
outcomes = _replay_requests(
selected,
base_url=recipe.base_url,
timeout_s=recipe.request_timeout_s,
max_concurrency=study.trace.max_concurrency,
)
evaluations, summary = summarize_evaluations(outcomes, study.slo)
request_rate = (
len(selected) / max(window.window_end - window.window_start, 1e-9)
if selected
else 0.0
)
payload = ProbePayload(
threshold=threshold,
request_count=len(selected),
pass_rate=float(summary["slo_pass_rate"]),
request_rate=request_rate,
feasible=bool(summary["feasible"]),
outcomes=[
{
"request_id": outcome.request_id,
"success": outcome.success,
"ttft_ms": outcome.ttft_ms,
"tpot_ms": outcome.tpot_ms,
"prompt_tokens": outcome.prompt_tokens,
"completion_tokens": outcome.completion_tokens,
"evaluation": evaluation.passed,
"reasons": evaluation.reasons,
}
for outcome, evaluation in zip(outcomes, evaluations)
],
)
probe_record = {
"threshold": threshold,
"request_count": payload.request_count,
"pass_rate": payload.pass_rate,
"request_rate": payload.request_rate,
"feasible": payload.feasible,
}
probe_history.append(probe_record)
StudyStore.write_json(Path(trial.probe_log_path), probe_history)
return ThresholdProbe(
threshold=threshold,
feasible=payload.feasible,
payload=payload,
)
search = binary_search_max_feasible(
low=trial.search.low,
high=trial.search.high,
tolerance=trial.search.tolerance,
max_probes=trial.search.max_probes,
evaluator=evaluator,
)
best = search.best_feasible_payload
result = {
"study_id": trial.study_id,
"trial_id": trial.trial_id,
"status": "completed",
"best_sampling_u": search.best_threshold if best is not None else None,
"best_request_rate": best.request_rate if best is not None else None,
"best_pass_rate": best.pass_rate if best is not None else None,
"best_request_count": best.request_count if best is not None else None,
"probes": [
{
"threshold": probe.threshold,
"feasible": probe.feasible,
"payload": {
"request_count": probe.payload.request_count,
"pass_rate": probe.payload.pass_rate,
"request_rate": probe.payload.request_rate,
},
}
for probe in search.probes
],
}
StudyStore.write_json(Path(trial.result_path), result)
return result
finally:
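# Always tear the engine down: ask politely with terminate() first, then
# escalate to kill() if the process ignores the 30 s grace period.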
process.terminate()
try:
process.wait(timeout=30)
except subprocess.TimeoutExpired:
process.kill()
process.wait(timeout=30)
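For orientation, a minimal sketch of how a fleet worker could drive run_trial from the command line. This is an editor's illustration, not part of the commit: the subcommand shape mirrors the "python3 -m aituner.cli worker run-trial" string asserted in tests/test_core_flow.py below, while the aituner.worker module path and the exit-code convention are assumptions.

# Hypothetical wrapper; module path and argument handling are illustrative.
import json
import sys
from pathlib import Path

from aituner.worker import run_trial  # assumed home of run_trial


def main() -> int:
    result = run_trial(Path(sys.argv[1]))  # argv[1]: path to the trial spec JSON
    json.dump(result, sys.stdout, indent=2)
    return 0 if result.get("status") == "completed" else 1


if __name__ == "__main__":
    raise SystemExit(main())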

10
tests/conftest.py Normal file
View File

@@ -0,0 +1,10 @@
from __future__ import annotations
import sys
from pathlib import Path
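# Make the src/ layout importable when tests run from a source checkout.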
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))

267
tests/test_core_flow.py Normal file
View File

@@ -0,0 +1,267 @@
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from aituner.job import append_job, build_trial_job
from aituner.llm import build_prompt, parse_proposal_text
from aituner.search import ThresholdProbe, binary_search_max_feasible
from aituner.slo import RequestOutcome, summarize_evaluations
from aituner.spec import Proposal, load_study_spec
from aituner.store import StudyStore
from aituner.trace import load_trace_requests, summarize_window
def _write_study_assets(tmp_path: Path) -> Path:
trace_dir = tmp_path / "trace_windows" / "traces"
trace_dir.mkdir(parents=True)
trace_path = trace_dir / "chat_w1.jsonl"
rows = [
{
"request_id": "r1",
"timestamp": 0.0,
"sampling_u": 0.10,
"messages": [{"role": "user", "content": "hello"}],
"input_length": 1000,
"output_length": 16
},
{
"request_id": "r2",
"timestamp": 1.0,
"sampling_u": 0.50,
"messages": [{"role": "user", "content": "world"}],
"input_length": 5000,
"output_length": 32
},
{
"request_id": "r3",
"timestamp": 2.0,
"sampling_u": 0.90,
"messages": [{"role": "user", "content": "!"}],
"input_length": 20000,
"output_length": 64
}
]
with trace_path.open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row) + "\n")
windows_path = tmp_path / "trace_windows" / "windows.json"
windows_payload = {
"u_field": "sampling_u",
"windows": [
{
"window_id": "chat_w1",
"trace_type": "chat",
"trace_file": "traces/chat_w1.jsonl",
"window_start": 0.0,
"window_end": 10.0
}
]
}
windows_path.write_text(json.dumps(windows_payload), encoding="utf-8")
capability_path = tmp_path / "capability.json"
capability_path.write_text(
json.dumps({"prefill_service_by_bucket": {"4k": {"tp4_ms": 320, "tp8_ms": 240}}}),
encoding="utf-8",
)
study_path = tmp_path / "study.json"
study_payload = {
"study_id": "study-1",
"hardware": {"gpu_count": 8, "gpu_model": "H20", "host_candidates": ["dash0"]},
"model": {
"model_id": "qwen",
"served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
},
"engine": {
"engine_name": "vllm",
"engine_version": "0.1",
"exec_path": "/usr/local/bin/vllm",
"cwd": str(tmp_path),
"host": "127.0.0.1",
"port": 8000,
"healthcheck_path": "/v1/models",
"ready_timeout_s": 30,
"request_timeout_s": 30,
"launch_args": ["serve", "/models/qwen"],
"base_envs": {"BASE_ENV": "1"},
"base_flags": {"host": "127.0.0.1", "port": 8000},
"tunable_envs": ["VLLM_ATTENTION_BACKEND"],
"tunable_flags": ["tensor-parallel-size", "max-num-seqs"],
"python_executable": "python3"
},
"trace": {
"windows_path": str(windows_path),
"window_id": "chat_w1",
"u_field": "sampling_u",
"timestamp_field": "timestamp",
"max_concurrency": 4
},
"slo": {
"target_pass_rate": 0.95,
"ttft_rule": {
"kind": "step_ms",
"buckets": [
{"max_input_tokens": 4096, "threshold_ms": 2000},
{"max_input_tokens": 16384, "threshold_ms": 5000},
{"threshold_ms": 9000}
]
},
"tpot_rule": {"kind": "fixed_ms", "threshold_ms": 120}
},
"search": {
"low": 0.0,
"high": 1.0,
"tolerance": 0.01,
"max_probes": 8,
"sample_seed": 20260325
},
"llm": {"system_prompt": "Tune it.", "max_history_trials": 8},
"capability_profile_path": str(capability_path)
}
study_path.write_text(json.dumps(study_payload), encoding="utf-8")
return study_path
class CoreFlowTests(unittest.TestCase):
def test_trace_and_prompt_flow(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
study = load_study_spec(study_path)
store = StudyStore(tmp_path / ".aituner" / "studies")
study_root = store.init_study(spec_path=study_path, study=study)
state = store.load_state(study.study_id)
window, requests = load_trace_requests(study, study_spec_path=study_path)
summary = summarize_window(requests, window)
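# The window spans 0-10 s and holds three rows, so the rate is 3 / 10 = 0.3 req/s.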
self.assertEqual(summary["request_count"], 3)
self.assertEqual(summary["request_rate"], 0.3)
prompt = build_prompt(
study=study,
window_summary=summary,
state=state,
capability_profile={"queueing_knee_by_bucket": {"4k": 1000}},
)
self.assertIn("allowed_flag_keys", prompt)
self.assertIn("study-1", prompt)
self.assertIn("queueing_knee_by_bucket", prompt)
self.assertTrue(study_root.exists())
def test_slo_evaluation_step_and_fixed_rules(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
study = load_study_spec(_write_study_assets(Path(tmp)))
outcomes = [
RequestOutcome(
request_id="r1",
success=True,
ttft_ms=1000,
tpot_ms=100,
prompt_tokens=1000,
completion_tokens=16,
),
RequestOutcome(
request_id="r2",
success=True,
ttft_ms=6000,
tpot_ms=100,
prompt_tokens=5000,
completion_tokens=16,
),
]
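# r1: 1000 prompt tokens fall in the <=4096 bucket (2000 ms TTFT budget),
# so 1000 ms TTFT and 100 ms TPOT both pass. r2: 5000 tokens fall in the
# <=16384 bucket (5000 ms budget), so 6000 ms TTFT fails. The pass rate of
# 0.5 sits below the 0.95 target, hence infeasible.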
evaluations, summary = summarize_evaluations(outcomes, study.slo)
self.assertTrue(evaluations[0].passed)
self.assertFalse(evaluations[1].passed)
self.assertEqual(summary["slo_pass_rate"], 0.5)
self.assertFalse(summary["feasible"])
def test_binary_search_max_feasible(self) -> None:
result = binary_search_max_feasible(
low=0.0,
high=1.0,
tolerance=0.01,
max_probes=8,
evaluator=lambda threshold: ThresholdProbe(
threshold=threshold,
feasible=threshold <= 0.625,
payload={"threshold": threshold},
),
)
self.assertLessEqual(result.best_threshold, 0.625)
self.assertGreaterEqual(result.best_threshold, 0.5)
self.assertIsNotNone(result.best_feasible_payload)
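# Assuming plain midpoint bisection, the probes walk 0.5 (feasible),
# 0.75 (infeasible), 0.625 (feasible), then keep shrinking the infeasible
# side, so the best feasible threshold converges into [0.5, 0.625].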
def test_proposal_validation_and_job_emission(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
study = load_study_spec(study_path)
store = StudyStore(tmp_path / ".aituner" / "studies")
store.init_study(spec_path=study_path, study=study)
state = store.load_state(study.study_id)
proposal_text = json.dumps(
{
"observation": "Current TTFT fails before TPOT.",
"diagnosis": "Prefill pressure dominates.",
"config_patch": {
"env_patch": {"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
"flag_patch": {"tensor-parallel-size": 4, "max-num-seqs": 64}
},
"expected_effects": ["lower TTFT", "raise feasible sampling_u"],
"why_not_previous_failures": "Avoids changing unsupported envs."
}
)
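# Both patched keys are declared tunable in the study spec
# (VLLM_ATTENTION_BACKEND in tunable_envs; tensor-parallel-size and
# max-num-seqs in tunable_flags), so parsing should accept the proposal.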
proposal = parse_proposal_text(proposal_text, study)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
job = build_trial_job(study=study, trial=trial, repo_root=tmp_path)
jobs_path = tmp_path / "jobs.toml"
append_job(jobs_path, job)
rendered = jobs_path.read_text(encoding="utf-8")
self.assertIn('name = "study-1-trial-0001"', rendered)
self.assertIn('command = "python3 -m aituner.cli worker run-trial', rendered)
self.assertIn('PYTHONPATH = "src"', rendered)
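# A plausible shape for the appended jobs.toml entry (illustrative; only
# the three fragments asserted above are guaranteed by this test):
#   [[jobs]]
#   name = "study-1-trial-0001"
#   command = "python3 -m aituner.cli worker run-trial <trial-spec-path>"
#   [jobs.env]
#   PYTHONPATH = "src"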
def test_ingest_trial_results_updates_best(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
study = load_study_spec(study_path)
store = StudyStore(tmp_path / ".aituner" / "studies")
store.init_study(spec_path=study_path, study=study)
state = store.load_state(study.study_id)
proposal = Proposal.from_dict(
{
"observation": "Obs",
"diagnosis": "Diag",
"config_patch": {"env_patch": {}, "flag_patch": {"tensor-parallel-size": 4}},
"expected_effects": ["raise rate"]
}
)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
Path(trial.result_path).write_text(
json.dumps(
{
"study_id": study.study_id,
"trial_id": trial.trial_id,
"status": "completed",
"best_sampling_u": 0.75,
"best_request_rate": 12.5,
"best_pass_rate": 0.97
}
),
encoding="utf-8",
)
next_state = store.ingest_trial_results(study.study_id)
self.assertEqual(next_state.best_trial_id, trial.trial_id)
self.assertEqual(next_state.best_request_rate, 12.5)
if __name__ == "__main__":
unittest.main()