Initial AITuner study orchestrator
This commit is contained in:
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
.aituner/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
infra/gpu_fleet/config/fleet.toml
|
||||||
|
infra/gpu_fleet/config/jobs.toml
|
||||||
14
configs/examples/capability.example.json
Normal file
14
configs/examples/capability.example.json
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"prefill_service_by_bucket": {
|
||||||
|
"4k": {
|
||||||
|
"tp4_ms": 320,
|
||||||
|
"tp8_ms": 240
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"queueing_knee_by_bucket": {
|
||||||
|
"4k": {
|
||||||
|
"tp4_tok_s_per_gpu": 1000,
|
||||||
|
"tp8_tok_s_per_gpu": 1100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
96
configs/examples/study.example.json
Normal file
96
configs/examples/study.example.json
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
{
|
||||||
|
"study_id": "example-chat-window",
|
||||||
|
"hardware": {
|
||||||
|
"gpu_count": 8,
|
||||||
|
"gpu_model": "H20",
|
||||||
|
"host_candidates": ["dash0", "dash1"]
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"model_id": "qwen3-30b",
|
||||||
|
"served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||||
|
},
|
||||||
|
"engine": {
|
||||||
|
"engine_name": "vllm",
|
||||||
|
"engine_version": "0.x",
|
||||||
|
"exec_path": "/usr/local/bin/vllm",
|
||||||
|
"cwd": ".",
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 8000,
|
||||||
|
"healthcheck_path": "/v1/models",
|
||||||
|
"ready_timeout_s": 600,
|
||||||
|
"request_timeout_s": 600,
|
||||||
|
"launch_args": [
|
||||||
|
"serve",
|
||||||
|
"/path/to/model"
|
||||||
|
],
|
||||||
|
"base_envs": {},
|
||||||
|
"base_flags": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 8000,
|
||||||
|
"served-model-name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||||
|
},
|
||||||
|
"tunable_envs": [
|
||||||
|
"VLLM_ATTENTION_BACKEND",
|
||||||
|
"CUDA_GRAPH_MAX_BATCH_SIZE"
|
||||||
|
],
|
||||||
|
"tunable_flags": [
|
||||||
|
"tensor-parallel-size",
|
||||||
|
"data-parallel-size",
|
||||||
|
"pipeline-parallel-size",
|
||||||
|
"max-num-seqs",
|
||||||
|
"max-num-batched-tokens",
|
||||||
|
"gpu-memory-utilization",
|
||||||
|
"enable-prefix-caching",
|
||||||
|
"block-size"
|
||||||
|
],
|
||||||
|
"python_executable": "python3"
|
||||||
|
},
|
||||||
|
"trace": {
|
||||||
|
"windows_path": "trace_windows/windows.json",
|
||||||
|
"window_id": "chat_w_example_peak_0001",
|
||||||
|
"u_field": "sampling_u",
|
||||||
|
"timestamp_field": "timestamp",
|
||||||
|
"max_concurrency": 64
|
||||||
|
},
|
||||||
|
"slo": {
|
||||||
|
"target_pass_rate": 0.95,
|
||||||
|
"ttft_rule": {
|
||||||
|
"kind": "step_ms",
|
||||||
|
"buckets": [
|
||||||
|
{
|
||||||
|
"max_input_tokens": 4096,
|
||||||
|
"threshold_ms": 2000
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"max_input_tokens": 16384,
|
||||||
|
"threshold_ms": 4000
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"threshold_ms": 8000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"tpot_rule": {
|
||||||
|
"kind": "fixed_ms",
|
||||||
|
"threshold_ms": 120
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"search": {
|
||||||
|
"low": 0.0,
|
||||||
|
"high": 1.0,
|
||||||
|
"tolerance": 0.01,
|
||||||
|
"max_probes": 8,
|
||||||
|
"sample_seed": 20260325
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
|
||||||
|
"max_history_trials": 8,
|
||||||
|
"endpoint": {
|
||||||
|
"base_url": "https://example-openai-compatible-endpoint",
|
||||||
|
"model": "gpt-4.1-mini",
|
||||||
|
"api_key_env": "OPENAI_API_KEY",
|
||||||
|
"timeout_s": 120
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"capability_profile_path": "capability.example.json"
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
{"request_id":"example-1","timestamp":0.0,"sampling_u":0.10,"messages":[{"role":"user","content":"hello"}],"input_length":512,"output_length":16}
|
||||||
|
{"request_id":"example-2","timestamp":1.0,"sampling_u":0.50,"messages":[{"role":"user","content":"summarize this file"}],"input_length":2048,"output_length":64}
|
||||||
|
{"request_id":"example-3","timestamp":2.5,"sampling_u":0.90,"messages":[{"role":"user","content":"write a longer answer"}],"input_length":8192,"output_length":128}
|
||||||
15
configs/examples/trace_windows/windows.json
Normal file
15
configs/examples/trace_windows/windows.json
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"sample_seed": 20260325,
|
||||||
|
"u_field": "sampling_u",
|
||||||
|
"window_duration_seconds": 10.0,
|
||||||
|
"windows": [
|
||||||
|
{
|
||||||
|
"window_id": "chat_w_example_peak_0001",
|
||||||
|
"trace_type": "chat",
|
||||||
|
"trace_file": "traces/chat_w_example_peak_0001.jsonl",
|
||||||
|
"window_start": 0.0,
|
||||||
|
"window_end": 10.0,
|
||||||
|
"num_requests": 3
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
59
infra/gpu_fleet/config/fleet.example.toml
Normal file
59
infra/gpu_fleet/config/fleet.example.toml
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
version = 1
|
||||||
|
|
||||||
|
[paths]
|
||||||
|
state_dir = ".aituner/gpu_fleet/state"
|
||||||
|
artifacts_dir = ".aituner/gpu_fleet/artifacts"
|
||||||
|
|
||||||
|
[ssh]
|
||||||
|
connect_timeout_sec = 10
|
||||||
|
|
||||||
|
[scheduler]
|
||||||
|
gpu_free_memory_mb = 1024
|
||||||
|
gpu_free_utilization_pct = 10
|
||||||
|
prefer_pack = true
|
||||||
|
|
||||||
|
[sync]
|
||||||
|
mode = "rsync"
|
||||||
|
local_path = "."
|
||||||
|
exclude = [
|
||||||
|
".git/",
|
||||||
|
".venv/",
|
||||||
|
".aituner/",
|
||||||
|
"__pycache__/",
|
||||||
|
"*.pyc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[hosts]]
|
||||||
|
name = "dash0"
|
||||||
|
ssh_alias = "dash0"
|
||||||
|
enabled = true
|
||||||
|
sync_remote_path = "~/workspace/aituner"
|
||||||
|
fleet_root = "~/.aituner_gpu_fleet"
|
||||||
|
|
||||||
|
[[hosts]]
|
||||||
|
name = "dash1"
|
||||||
|
ssh_alias = "dash1"
|
||||||
|
enabled = true
|
||||||
|
sync_remote_path = "~/workspace/aituner"
|
||||||
|
fleet_root = "~/.aituner_gpu_fleet"
|
||||||
|
|
||||||
|
[[hosts]]
|
||||||
|
name = "dash2"
|
||||||
|
ssh_alias = "dash2"
|
||||||
|
enabled = true
|
||||||
|
sync_remote_path = "~/workspace/aituner"
|
||||||
|
fleet_root = "~/.aituner_gpu_fleet"
|
||||||
|
|
||||||
|
[[hosts]]
|
||||||
|
name = "dash3"
|
||||||
|
ssh_alias = "dash3"
|
||||||
|
enabled = true
|
||||||
|
sync_remote_path = "~/aituner"
|
||||||
|
fleet_root = "~/.aituner_gpu_fleet"
|
||||||
|
|
||||||
|
[[hosts]]
|
||||||
|
name = "dash5"
|
||||||
|
ssh_alias = "dash5"
|
||||||
|
enabled = true
|
||||||
|
sync_remote_path = "~/workspace/aituner"
|
||||||
|
fleet_root = "~/.aituner_gpu_fleet"
|
||||||
27
infra/gpu_fleet/config/jobs.example.toml
Normal file
27
infra/gpu_fleet/config/jobs.example.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# This file is an append-only queue source for the monitor.
|
||||||
|
# Each job name must stay unique and immutable once appended.
|
||||||
|
version = 1
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "smoke-train-h20-1gpu"
|
||||||
|
gpus = 1
|
||||||
|
gpu_model = "H20"
|
||||||
|
hosts = ["dash0", "dash1", "dash2"]
|
||||||
|
command = "python train.py --config configs/smoke.toml"
|
||||||
|
artifacts = ["outputs/smoke-train-h20-1gpu"]
|
||||||
|
env = { WANDB_MODE = "offline" }
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "eval-5090-4gpu"
|
||||||
|
gpus = 4
|
||||||
|
gpu_model = "5090"
|
||||||
|
hosts = ["dash5"]
|
||||||
|
command = "python eval.py --config configs/eval.toml"
|
||||||
|
artifacts = ["outputs/eval-5090-4gpu", "logs/eval-5090-4gpu.log"]
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "special-dash3-run"
|
||||||
|
gpus = 2
|
||||||
|
hosts = ["dash3"]
|
||||||
|
command = "python benchmark.py --suite long-context"
|
||||||
|
artifacts = ["outputs/special-dash3-run"]
|
||||||
8
infra/gpu_fleet/config/ssh_aliases.example.txt
Normal file
8
infra/gpu_fleet/config/ssh_aliases.example.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# One SSH alias per line.
|
||||||
|
# Lines starting with "#" are ignored.
|
||||||
|
dash0
|
||||||
|
dash1
|
||||||
|
dash2
|
||||||
|
dash3
|
||||||
|
dash5
|
||||||
|
|
||||||
1132
infra/gpu_fleet/gpu_fleet.py
Executable file
1132
infra/gpu_fleet/gpu_fleet.py
Executable file
File diff suppressed because it is too large
Load Diff
19
pyproject.toml
Normal file
19
pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=68"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "aituner"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "AITuner study orchestrator for OpenAI-compatible serving engines"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = []
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
aituner = "aituner.cli:main"
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
package-dir = {"" = "src"}
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
5
src/aituner/__init__.py
Normal file
5
src/aituner/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""AITuner package."""
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"cli",
|
||||||
|
]
|
||||||
177
src/aituner/cli.py
Normal file
177
src/aituner/cli.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .job import append_job, build_trial_job
|
||||||
|
from .llm import build_prompt, call_llm_for_proposal, load_capability_profile, parse_proposal_text
|
||||||
|
from .spec import Proposal, SpecError, load_study_spec
|
||||||
|
from .store import StudyStore
|
||||||
|
from .trace import load_trace_requests, summarize_window
|
||||||
|
from .worker import run_trial
|
||||||
|
|
||||||
|
|
||||||
|
def _study_source_path(study_root: Path) -> Path:
|
||||||
|
return Path((study_root / "study_spec.source").read_text(encoding="utf-8").strip())
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_study_init(args: argparse.Namespace) -> int:
    """Create the on-disk study layout for a spec and print its root path."""
    spec_path = Path(args.spec).resolve()
    study = load_study_spec(spec_path)
    store_root = Path(args.store_root) if args.store_root else None
    store = StudyStore(store_root)
    study_root = store.init_study(spec_path=spec_path, study=study)
    print(study_root)
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_study_prompt(args: argparse.Namespace) -> int:
    """Render the LLM prompt for the next trial and write it into the store."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    source_path = _study_source_path(Path(args.study_root).resolve())
    study = load_study_spec(source_path)
    state = store.load_state(study.study_id)
    profile = load_capability_profile(study, study_spec_path=source_path)
    window, requests = load_trace_requests(study, study_spec_path=source_path)
    prompt_text = build_prompt(
        study=study,
        window_summary=summarize_window(requests, window),
        state=state,
        capability_profile=profile,
    )
    name = args.prompt_name or f"prompt-{state.next_trial_index:04d}"
    print(store.write_prompt(study.study_id, name, prompt_text))
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_study_llm_propose(args: argparse.Namespace) -> int:
    """Ask the configured LLM endpoint for a config proposal and persist it."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    source_path = _study_source_path(Path(args.study_root).resolve())
    study = load_study_spec(source_path)
    state = store.load_state(study.study_id)
    profile = load_capability_profile(study, study_spec_path=source_path)
    window, requests = load_trace_requests(study, study_spec_path=source_path)
    prompt_text = build_prompt(
        study=study,
        window_summary=summarize_window(requests, window),
        state=state,
        capability_profile=profile,
    )
    raw_response = call_llm_for_proposal(policy=study.llm, prompt=prompt_text)
    proposal = parse_proposal_text(raw_response, study)
    name = args.proposal_name or f"proposal-{state.next_trial_index:04d}"
    print(store.write_proposal(study.study_id, name, proposal))
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_study_register_proposal(args: argparse.Namespace) -> int:
    """Validate a hand-written proposal file and copy it into the store."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    source_path = _study_source_path(Path(args.study_root).resolve())
    study = load_study_spec(source_path)
    proposal_file = Path(args.proposal_file)
    proposal = parse_proposal_text(proposal_file.read_text(encoding="utf-8"), study)
    name = args.proposal_name or proposal_file.stem
    print(store.write_proposal(study.study_id, name, proposal))
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_study_emit_job(args: argparse.Namespace) -> int:
    """Materialize a trial from a proposal and append it to the fleet jobs file."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    source_path = _study_source_path(Path(args.study_root).resolve())
    study = load_study_spec(source_path)
    state = store.load_state(study.study_id)
    proposal_text = Path(args.proposal_file).read_text(encoding="utf-8")
    proposal = parse_proposal_text(proposal_text, study)
    trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
    # Repo root is two directory levels above this module (src/aituner/cli.py).
    repo_root = Path(__file__).resolve().parents[2]
    job = build_trial_job(study=study, trial=trial, repo_root=repo_root)
    append_job(Path(args.jobs_file).resolve(), job)
    print(trial.trial_id)
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_study_ingest(args: argparse.Namespace) -> int:
    """Fold finished trial results into study state; print the best trial as JSON."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    # The study id is, by construction, the study root directory's name.
    study_id = Path(args.study_root).resolve().name
    state = store.ingest_trial_results(study_id)
    summary = {"best_trial_id": state.best_trial_id, "best_request_rate": state.best_request_rate}
    print(json.dumps(summary))
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_worker_run_trial(args: argparse.Namespace) -> int:
    """Execute one trial spec on this host and print the result as JSON."""
    outcome = run_trial(Path(args.trial_spec).resolve())
    print(json.dumps(outcome))
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def build_parser() -> argparse.ArgumentParser:
    """Build the two-level ``aituner`` CLI: ``study <cmd>`` and ``worker <cmd>``."""
    parser = argparse.ArgumentParser(description="AITuner CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)

    study = subparsers.add_parser("study")
    study_sub = study.add_subparsers(dest="study_command", required=True)

    # (name, handler, [(flag, required), ...]) for every study subcommand;
    # registration order matches the original hand-written parsers.
    study_commands = [
        ("init", cmd_study_init,
         [("--spec", True), ("--store-root", False)]),
        ("prompt", cmd_study_prompt,
         [("--study-root", True), ("--store-root", False), ("--prompt-name", False)]),
        ("llm-propose", cmd_study_llm_propose,
         [("--study-root", True), ("--store-root", False), ("--proposal-name", False)]),
        ("register-proposal", cmd_study_register_proposal,
         [("--study-root", True), ("--store-root", False), ("--proposal-file", True), ("--proposal-name", False)]),
        ("emit-job", cmd_study_emit_job,
         [("--study-root", True), ("--store-root", False), ("--proposal-file", True), ("--jobs-file", True)]),
        ("ingest", cmd_study_ingest,
         [("--study-root", True), ("--store-root", False)]),
    ]
    for name, handler, flags in study_commands:
        sub = study_sub.add_parser(name)
        for flag, required in flags:
            if required:
                sub.add_argument(flag, required=True)
            else:
                sub.add_argument(flag)
        sub.set_defaults(func=handler)

    worker = subparsers.add_parser("worker")
    worker_sub = worker.add_subparsers(dest="worker_command", required=True)
    run = worker_sub.add_parser("run-trial")
    run.add_argument("--trial-spec", required=True)
    run.set_defaults(func=cmd_worker_run_trial)

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: dispatch to the selected subcommand handler.

    SpecError is reported on stderr and mapped to exit code 2; every other
    exception propagates unchanged.
    """
    args = build_parser().parse_args(argv)
    try:
        exit_code = int(args.func(args))
    except SpecError as exc:
        print(str(exc), file=sys.stderr)
        return 2
    return exit_code
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running as a plain script: exit with main()'s return code.
    raise SystemExit(main())
|
||||||
65
src/aituner/engine.py
Normal file
65
src/aituner/engine.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .spec import ConfigPatch, EngineLaunchSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class LaunchRecipe:
    """Fully-resolved instructions for launching one engine server process.

    Produced by build_launch_recipe() from a base EngineLaunchSpec plus a
    trial's ConfigPatch; consumed by whatever spawns the process and polls
    its health endpoint.
    """

    argv: list[str]  # executable path + launch args + serialized flags
    env: dict[str, str]  # full process environment (os.environ + patches), stringified
    cwd: str | None  # user-expanded working directory; None = inherit caller's cwd
    base_url: str
    healthcheck_path: str
    ready_timeout_s: float
    request_timeout_s: float
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_flag_name(name: str) -> str:
|
||||||
|
return str(name).strip().replace("_", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_flag_parts(name: str, value: Any) -> list[str]:
    """Turn one (flag, value) pair into argv tokens.

    None drops the flag entirely; booleans become ``--flag`` / ``--no-flag``;
    lists repeat the flag once per element; anything else is ``--flag value``.
    """
    normalized = _normalize_flag_name(name)
    flag = f"--{normalized}"
    if value is None:
        return []
    if isinstance(value, bool):
        return [flag] if value else [f"--no-{normalized}"]
    if isinstance(value, list):
        tokens: list[str] = []
        for element in value:
            tokens += [flag, str(element)]
        return tokens
    return [flag, str(value)]
|
||||||
|
|
||||||
|
|
||||||
|
def build_launch_recipe(spec: EngineLaunchSpec, patch: ConfigPatch) -> LaunchRecipe:
    """Merge the base launch spec with one trial's config patch.

    Environment layering (later wins): os.environ < spec.base_envs
    < patch.env_patch.  Flag layering: spec.base_flags < patch.flag_patch,
    serialized onto argv after the fixed launch args.
    """
    merged_env = {**os.environ, **spec.base_envs, **patch.env_patch}
    merged_flags = {**spec.base_flags, **patch.flag_patch}
    argv = [spec.exec_path, *spec.launch_args]
    for name, value in merged_flags.items():
        argv += _serialize_flag_parts(name, value)
    workdir = str(Path(spec.cwd).expanduser()) if spec.cwd else None
    return LaunchRecipe(
        argv=argv,
        env={str(k): str(v) for k, v in merged_env.items()},
        cwd=workdir,
        base_url=spec.base_url,
        healthcheck_path=spec.healthcheck_path,
        ready_timeout_s=spec.ready_timeout_s,
        request_timeout_s=spec.request_timeout_s,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def shell_join(argv: list[str]) -> str:
    """Quote every argv element for POSIX shells and join with spaces."""
    quoted = [shlex.quote(part) for part in argv]
    return " ".join(quoted)
|
||||||
147
src/aituner/http_client.py
Normal file
147
src/aituner/http_client.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
|
|
||||||
|
class HttpClientError(RuntimeError):
    """Raised for HTTP client failures: readiness timeouts and HTTP error responses."""
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if api_key_env:
|
||||||
|
api_key = os.environ.get(api_key_env)
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
    """Poll ``base_url + path`` about once per second until it answers or timeout.

    Any HTTP status below 500 counts as "up" — a 4xx still proves the server
    is accepting requests.  Raises HttpClientError with the last observed
    error once the deadline passes.
    """
    url = f"{base_url.rstrip('/')}{path}"
    deadline = time.monotonic() + timeout_s
    last_error = "server_not_ready"
    while time.monotonic() < deadline:
        try:
            probe = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
            with urllib.request.urlopen(probe, timeout=5) as response:
                if 200 <= response.status < 500:
                    return
        except Exception as exc:  # noqa: BLE001 - keep polling through any transient failure
            last_error = str(exc)
        time.sleep(1.0)
    raise HttpClientError(f"Timed out waiting for {url}: {last_error}")
|
||||||
|
|
||||||
|
|
||||||
|
def chat_completion(
    *,
    base_url: str,
    api_key_env: str | None,
    model: str,
    messages: list[dict[str, Any]],
    timeout_s: float,
    system_prompt: str = "",
) -> dict[str, Any]:
    """POST a non-streaming /v1/chat/completions request; return the parsed JSON.

    When *system_prompt* is non-empty it is prepended as a system message.
    Raises HttpClientError on an HTTP error response.
    """
    if system_prompt:
        final_messages = [{"role": "system", "content": system_prompt}, *messages]
    else:
        final_messages = messages
    payload: dict[str, Any] = {"model": model, "messages": final_messages}
    request = urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        headers=_auth_headers(api_key_env),
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout_s) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise HttpClientError(f"chat_completion failed: {exc.code} {detail}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class StreamMetrics:
    """Latency metrics observed while streaming one chat completion."""

    ttft_ms: float | None  # time to first content token, ms; None if no token arrived
    tpot_ms: float | None  # mean inter-token gap, ms; None when <= 1 token was observed
    completion_tokens: int | None  # usage-reported token count, else streamed chunk count
|
||||||
|
|
||||||
|
|
||||||
|
def stream_chat_completion(
    *,
    base_url: str,
    body: dict[str, Any],
    timeout_s: float,
) -> StreamMetrics:
    """POST a streaming /v1/chat/completions request and measure TTFT / TPOT.

    *body* is sent verbatim; the caller is responsible for setting
    ``stream: true`` and any usage options the server needs.  Raises
    HttpClientError on an HTTP error response.
    """
    data = json.dumps(body).encode("utf-8")
    request = urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        headers=_auth_headers(None),
        data=data,
        method="POST",
    )
    start = time.monotonic()
    first_token_at: float | None = None
    last_token_at: float | None = None
    # Counts content-bearing SSE chunks — an approximation of token count,
    # used only when the server does not report usage.completion_tokens.
    chunk_token_count = 0
    completion_tokens: int | None = None
    try:
        with urllib.request.urlopen(request, timeout=timeout_s) as response:
            for raw in _iter_sse_lines(response):
                if raw == "[DONE]":
                    break
                payload = json.loads(raw)
                if not isinstance(payload, dict):
                    continue
                # Prefer the server-reported completion token count when present.
                usage = payload.get("usage")
                if isinstance(usage, dict):
                    comp = usage.get("completion_tokens")
                    if isinstance(comp, int) and comp >= 0:
                        completion_tokens = comp
                choices = payload.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                delta = choices[0].get("delta", {})
                if not isinstance(delta, dict):
                    continue
                content = delta.get("content")
                if isinstance(content, str) and content:
                    now = time.monotonic()
                    if first_token_at is None:
                        first_token_at = now
                    last_token_at = now
                    chunk_token_count += 1
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise HttpClientError(f"stream_chat_completion failed: {exc.code} {detail}") from exc
    ttft_ms = None if first_token_at is None else (first_token_at - start) * 1000.0
    used_tokens = completion_tokens if completion_tokens is not None else chunk_token_count
    # TPOT needs at least two tokens: it is the average gap between tokens.
    if (
        first_token_at is None
        or last_token_at is None
        or used_tokens is None
        or used_tokens <= 1
    ):
        tpot_ms = None
    else:
        tpot_ms = ((last_token_at - first_token_at) / max(used_tokens - 1, 1)) * 1000.0
    return StreamMetrics(
        ttft_ms=ttft_ms,
        tpot_ms=tpot_ms,
        completion_tokens=used_tokens if used_tokens > 0 else None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_sse_lines(response: Any) -> Iterable[str]:
|
||||||
|
for raw in response:
|
||||||
|
line = raw.decode("utf-8", errors="replace").strip()
|
||||||
|
if not line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
payload = line[len("data:") :].strip()
|
||||||
|
if payload:
|
||||||
|
yield payload
|
||||||
75
src/aituner/job.py
Normal file
75
src/aituner/job.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .spec import StudySpec, TrialSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class InfraJob:
    """One entry for the gpu_fleet append-only TOML jobs queue.

    Built by build_trial_job() and serialized by append_job(); field names
    mirror the [[jobs]] keys in infra/gpu_fleet/config/jobs.example.toml.
    """

    name: str  # unique, immutable job name (study_id + trial_id)
    gpus: int
    gpu_model: str | None  # optional scheduling constraint
    hosts: list[str]  # candidate hosts; empty list means any host
    command: str
    artifacts: list[str]  # repo-relative paths to collect after the run
    env: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
def _toml_scalar(value: Any) -> str:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return "true" if value else "false"
|
||||||
|
if isinstance(value, int):
|
||||||
|
return str(value)
|
||||||
|
text = str(value).replace("\\", "\\\\").replace('"', '\\"')
|
||||||
|
return f'"{text}"'
|
||||||
|
|
||||||
|
|
||||||
|
def _toml_list(values: list[Any]) -> str:
    """Render a Python list as an inline TOML array."""
    rendered = ", ".join(_toml_scalar(item) for item in values)
    return f"[{rendered}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _toml_inline_table(mapping: dict[str, str]) -> str:
    """Render a mapping as an inline TOML table, keys in sorted order."""
    entries = ", ".join(
        f"{key} = {_toml_scalar(value)}" for key, value in sorted(mapping.items())
    )
    return "{ " + entries + " }"
|
||||||
|
|
||||||
|
|
||||||
|
def build_trial_job(*, study: StudySpec, trial: TrialSpec, repo_root: Path) -> InfraJob:
    """Translate one materialized trial into a gpu_fleet job definition.

    Paths are rewritten relative to *repo_root* so the command and artifact
    globs remain valid after the repo is synced to a remote host.  Raises
    ValueError when the trial artifacts live outside *repo_root*.
    """
    repo = repo_root.resolve()
    artifact_dir = Path(trial.artifact_dir)
    spec_rel = (artifact_dir / "trial_spec.json").resolve().relative_to(repo)
    dir_rel = artifact_dir.resolve().relative_to(repo)
    runner = study.engine.python_executable
    return InfraJob(
        name=f"{study.study_id}-{trial.trial_id}",
        gpus=study.hardware.gpu_count,
        gpu_model=study.hardware.gpu_model,
        hosts=list(study.hardware.host_candidates),
        command=f"{runner} -m aituner.cli worker run-trial --trial-spec {spec_rel}",
        artifacts=[str(dir_rel)],
        # The worker imports aituner from the synced checkout's src/ layout.
        env={"PYTHONPATH": "src"},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def append_job(jobs_path: Path, job: InfraJob) -> None:
    """Append *job* to the TOML jobs queue, creating the file if needed.

    A ``version = 1`` header is written first when the file is empty.
    Optional fields (gpu_model, hosts, artifacts, env) are emitted only when
    present, matching the monitor's append-only [[jobs]] format.
    """
    jobs_path.parent.mkdir(parents=True, exist_ok=True)
    with jobs_path.open("a", encoding="utf-8") as handle:
        if jobs_path.stat().st_size == 0:
            handle.write("version = 1\n")
        lines = [
            "",
            "[[jobs]]",
            f"name = {_toml_scalar(job.name)}",
            f"gpus = {job.gpus}",
        ]
        if job.gpu_model:
            lines.append(f"gpu_model = {_toml_scalar(job.gpu_model)}")
        if job.hosts:
            lines.append(f"hosts = {_toml_list(job.hosts)}")
        lines.append(f"command = {_toml_scalar(job.command)}")
        if job.artifacts:
            lines.append(f"artifacts = {_toml_list(job.artifacts)}")
        if job.env:
            lines.append(f"env = {_toml_inline_table(job.env)}")
        handle.write("\n".join(lines) + "\n")
|
||||||
144
src/aituner/llm.py
Normal file
144
src/aituner/llm.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .http_client import chat_completion
|
||||||
|
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(
    *,
    study: StudySpec,
    window_summary: dict[str, Any],
    state: StudyState,
    capability_profile: dict[str, Any] | None,
) -> str:
    """Assemble the full text prompt sent to the proposal LLM.

    Sections, in order: task framing and output contract, study stack
    (hardware/model/engine and allowlists), trace window summary, SLO rules,
    capability profile, and the last max_history_trials trial outcomes.
    The exact wording is part of the LLM contract — keep it stable.
    """
    history = []
    # Only the most recent trials are shown, bounded by the LLM policy.
    for trial in state.trials[-study.llm.max_history_trials :]:
        history.append(
            {
                "trial_id": trial.trial_id,
                "status": trial.status,
                "best_sampling_u": trial.best_sampling_u,
                "best_request_rate": trial.best_request_rate,
                "best_pass_rate": trial.best_pass_rate,
                "diagnosis": trial.diagnosis,
            }
        )
    sections = [
        "You are tuning an OpenAI-compatible serving engine.",
        "Return exactly one JSON object with keys: observation, diagnosis, config_patch, expected_effects, why_not_previous_failures.",
        "config_patch must contain env_patch and flag_patch.",
        "Only use allowed tunable env keys and allowed tunable flag keys.",
        "",
        "Study stack:",
        json.dumps(
            {
                "study_id": study.study_id,
                "hardware": {
                    "gpu_count": study.hardware.gpu_count,
                    "gpu_model": study.hardware.gpu_model,
                },
                "model": {
                    "model_id": study.model.model_id,
                    "served_model_name": study.model.served_model_name,
                },
                "engine": {
                    "engine_name": study.engine.engine_name,
                    "engine_version": study.engine.engine_version,
                    "base_flags": study.engine.base_flags,
                    "base_envs": study.engine.base_envs,
                    "allowed_flag_keys": study.engine.tunable_flags,
                    "allowed_env_keys": study.engine.tunable_envs,
                },
            },
            ensure_ascii=False,
            indent=2,
        ),
        "",
        "Window summary:",
        json.dumps(window_summary, ensure_ascii=False, indent=2),
        "",
        "SLO:",
        json.dumps(
            {
                "target_pass_rate": study.slo.target_pass_rate,
                "ttft_rule": study.slo.ttft_rule,
                "tpot_rule": study.slo.tpot_rule,
            },
            # The SLO rules are spec objects, not plain dicts; serialize via
            # their attribute dicts.
            default=lambda value: value.__dict__,
            ensure_ascii=False,
            indent=2,
        ),
        "",
        "Capability profile:",
        json.dumps(capability_profile or {}, ensure_ascii=False, indent=2),
        "",
        "Trial history:",
        json.dumps(history, ensure_ascii=False, indent=2),
        "",
        "The proposal should improve the maximum feasible sampling_u under the 95%+ SLO target.",
    ]
    return "\n".join(sections)
|
||||||
|
|
||||||
|
|
||||||
|
def load_capability_profile(study: StudySpec, *, study_spec_path: Path) -> dict[str, Any] | None:
    """Load the optional capability-profile JSON referenced by the study spec.

    A relative path is resolved against the directory containing the study
    spec file. Returns ``None`` when the study declares no profile.
    """
    raw_path = study.capability_profile_path
    if not raw_path:
        return None
    profile_path = Path(raw_path)
    if not profile_path.is_absolute():
        profile_path = (study_spec_path.parent / profile_path).resolve()
    return json.loads(profile_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal:
    """Reject proposals that patch env/flag keys outside the tunable whitelists."""
    allowed_envs = set(study.engine.tunable_envs)
    allowed_flags = set(study.engine.tunable_flags)
    bad_envs = sorted(key for key in proposal.config_patch.env_patch if key not in allowed_envs)
    bad_flags = sorted(key for key in proposal.config_patch.flag_patch if key not in allowed_flags)
    if bad_envs:
        raise SpecError(f"Proposal uses unsupported env keys: {', '.join(bad_envs)}")
    if bad_flags:
        raise SpecError(f"Proposal uses unsupported flag keys: {', '.join(bad_flags)}")
    return proposal
|
||||||
|
|
||||||
|
|
||||||
|
def parse_proposal_text(text: str, study: StudySpec) -> Proposal:
    """Decode an LLM reply (a JSON object) into a validated ``Proposal``."""
    parsed = Proposal.from_dict(json.loads(text))
    return validate_proposal(parsed, study)
|
||||||
|
|
||||||
|
|
||||||
|
def call_llm_for_proposal(
    *,
    policy: LLMPolicySpec,
    prompt: str,
) -> str:
    """Send *prompt* to the configured LLM endpoint and return the reply text.

    Raises:
        RuntimeError: when no endpoint is configured, the response payload is
            malformed, or the extracted content is empty.
    """
    if policy.endpoint is None:
        raise RuntimeError("study.llm.endpoint is not configured")
    response = chat_completion(
        base_url=policy.endpoint.base_url,
        api_key_env=policy.endpoint.api_key_env,
        model=policy.endpoint.model,
        messages=[{"role": "user", "content": prompt}],
        timeout_s=policy.endpoint.timeout_s,
        system_prompt=policy.system_prompt,
    )
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise RuntimeError("LLM response does not contain choices")
    message = choices[0].get("message", {})
    if not isinstance(message, dict):
        raise RuntimeError("LLM response does not contain a valid message")
    content = message.get("content")
    if isinstance(content, str):
        text = content
    elif isinstance(content, list):
        # Multi-part content: concatenate the string "text" fragments.
        text = "".join(
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        )
    else:
        text = ""
    if not text:
        # Previously a content list with no text parts was returned as "",
        # deferring the failure to an opaque JSON parse error downstream.
        raise RuntimeError("LLM response content is empty")
    return text
|
||||||
58
src/aituner/search.py
Normal file
58
src/aituner/search.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Generic, TypeVar
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ThresholdProbe(Generic[T]):
    """Outcome of evaluating one candidate threshold during a search."""

    # Threshold value that was probed.
    threshold: float
    # Whether this threshold satisfied the caller's feasibility criterion.
    feasible: bool
    # Caller-defined evaluation details (metrics, raw results, ...).
    payload: T
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ThresholdSearchResult(Generic[T]):
    """Outcome of a feasibility binary search over a threshold range."""

    # Largest threshold found feasible (the search's `low` if none were).
    best_threshold: float
    # Payload from the best feasible probe; None when no probe was feasible.
    best_feasible_payload: T | None
    # Every probe evaluated, in evaluation order.
    probes: list[ThresholdProbe[T]]
|
||||||
|
|
||||||
|
|
||||||
|
def binary_search_max_feasible(
    *,
    low: float,
    high: float,
    tolerance: float,
    max_probes: int,
    evaluator: Callable[[float], ThresholdProbe[T]],
) -> ThresholdSearchResult[T]:
    """Binary-search the largest threshold in [low, high] that is feasible.

    Assumes feasibility is monotone (if a threshold is feasible, smaller ones
    are too). ``evaluator`` is called at most ``max_probes`` times; the search
    stops early once the bracket width is within ``tolerance``. The endpoints
    themselves are never probed, so ``best_feasible_payload`` is ``None`` when
    no interior probe was feasible.
    """
    probes: list[ThresholdProbe[T]] = []
    cache: dict[float, ThresholdProbe[T]] = {}
    best_threshold = low
    best_payload: T | None = None
    cur_low = low
    cur_high = high
    for _ in range(max_probes):
        if cur_high - cur_low <= tolerance:
            break
        # Round to tame float noise so identical midpoints hit the cache.
        threshold = round((cur_low + cur_high) / 2.0, 12)
        probe = cache.get(threshold)
        if probe is None:
            probe = evaluator(threshold)
            cache[threshold] = probe
            # Record fresh evaluations only; a cache hit previously appended
            # the same probe object again, duplicating history entries.
            probes.append(probe)
        if probe.feasible:
            if threshold >= best_threshold:
                best_threshold = threshold
                best_payload = probe.payload
            cur_low = threshold
        else:
            cur_high = threshold
    return ThresholdSearchResult(
        best_threshold=best_threshold,
        best_feasible_payload=best_payload,
        probes=probes,
    )
|
||||||
80
src/aituner/slo.py
Normal file
80
src/aituner/slo.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .spec import SloSpec, ThresholdRule
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RequestOutcome:
    """Raw measurements for one replayed request."""

    request_id: str
    # False when the request errored before producing usable metrics.
    success: bool
    # Time-to-first-token / time-per-output-token in milliseconds (None if unknown).
    ttft_ms: float | None
    tpot_ms: float | None
    prompt_tokens: int | None
    completion_tokens: int | None
    # Error description used as the failure reason when success is False.
    error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RequestEvaluation:
    """SLO verdict for one request; ``reasons`` is empty iff ``passed``."""

    request_id: str
    passed: bool
    # Human-readable breach descriptions, e.g. "ttft_ms>300.0".
    reasons: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
def _rule_threshold_ms(rule: ThresholdRule, prompt_tokens: int | None) -> float:
|
||||||
|
if rule.kind == "fixed_ms":
|
||||||
|
assert rule.threshold_ms is not None
|
||||||
|
return rule.threshold_ms
|
||||||
|
if rule.kind != "step_ms":
|
||||||
|
raise ValueError(f"Unsupported threshold rule: {rule.kind}")
|
||||||
|
prompt = float(prompt_tokens or 0)
|
||||||
|
for bucket in rule.buckets:
|
||||||
|
ceiling = bucket.get("max_input_tokens")
|
||||||
|
if ceiling is None or prompt <= ceiling:
|
||||||
|
return float(bucket["threshold_ms"])
|
||||||
|
return float(rule.buckets[-1]["threshold_ms"])
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_request(outcome: RequestOutcome, slo: SloSpec) -> RequestEvaluation:
    """Evaluate one request against the SLO; every breach becomes a reason string."""
    failure_reasons: list[str] = []
    if not outcome.success:
        # Failed requests never pass; surface the transport/engine error.
        failure_reasons.append(outcome.error or "request_failed")
        return RequestEvaluation(
            request_id=outcome.request_id, passed=False, reasons=failure_reasons
        )
    # Apply each configured latency rule in a fixed order (TTFT, then TPOT).
    checks = (
        ("ttft", slo.ttft_rule, outcome.ttft_ms),
        ("tpot", slo.tpot_rule, outcome.tpot_ms),
    )
    for label, rule, measured_ms in checks:
        if rule is None:
            continue
        if measured_ms is None:
            failure_reasons.append(f"{label}_missing")
        else:
            limit = _rule_threshold_ms(rule, outcome.prompt_tokens)
            if measured_ms > limit:
                failure_reasons.append(f"{label}_ms>{limit}")
    return RequestEvaluation(
        request_id=outcome.request_id,
        passed=not failure_reasons,
        reasons=failure_reasons,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_evaluations(
    outcomes: list[RequestOutcome], slo: SloSpec
) -> tuple[list[RequestEvaluation], dict[str, Any]]:
    """Evaluate every outcome and report the aggregate pass rate vs. the target."""
    evaluations = [evaluate_request(outcome, slo) for outcome in outcomes]
    total_count = len(evaluations)
    pass_count = sum(evaluation.passed for evaluation in evaluations)
    # An empty batch reports a 0.0 pass rate (and is therefore infeasible
    # for any positive target).
    pass_rate = pass_count / total_count if total_count else 0.0
    summary = {
        "request_count": total_count,
        "slo_pass_count": pass_count,
        "slo_pass_rate": pass_rate,
        "target_pass_rate": slo.target_pass_rate,
        "feasible": pass_rate >= slo.target_pass_rate,
    }
    return evaluations, summary
|
||||||
440
src/aituner/spec.py
Normal file
440
src/aituner/spec.py
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tomllib
|
||||||
|
from dataclasses import asdict, dataclass, field, is_dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Mapping
|
||||||
|
|
||||||
|
|
||||||
|
# Shared error type for every spec/proposal validation failure in this module.
class SpecError(ValueError):
    """Raised when a structured spec is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def _require_mapping(value: Any, *, context: str) -> Mapping[str, Any]:
|
||||||
|
if not isinstance(value, Mapping):
|
||||||
|
raise SpecError(f"{context} must be an object.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _require_str(value: Any, *, context: str) -> str:
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
raise SpecError(f"{context} must be a non-empty string.")
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _require_float(value: Any, *, context: str) -> float:
|
||||||
|
if isinstance(value, bool) or not isinstance(value, (int, float)):
|
||||||
|
raise SpecError(f"{context} must be numeric.")
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_int(value: Any, *, context: str) -> int:
|
||||||
|
if isinstance(value, bool) or not isinstance(value, int):
|
||||||
|
raise SpecError(f"{context} must be an integer.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]:
    """Normalize an optional mapping into a plain str -> str dict."""
    source = _require_mapping(value or {}, context=context)
    return {str(name): str(entry) for name, entry in source.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_any_map(value: Any, *, context: str) -> dict[str, Any]:
    """Normalize an optional mapping into a plain dict with string keys."""
    source = _require_mapping(value or {}, context=context)
    return {str(name): entry for name, entry in source.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_str_list(value: Any, *, context: str) -> list[str]:
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise SpecError(f"{context} must be a list.")
|
||||||
|
result: list[str] = []
|
||||||
|
for item in value:
|
||||||
|
result.append(_require_str(item, context=context))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class HardwareSpec:
    """GPU inventory the study may schedule onto."""

    gpu_count: int
    gpu_model: str | None = None
    host_candidates: list[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "HardwareSpec":
        """Build from a raw mapping; gpu_model and host_candidates are optional."""
        count = _require_int(data.get("gpu_count"), context="hardware.gpu_count")
        raw_model = data.get("gpu_model")
        model_name = str(raw_model).strip() if raw_model else None
        hosts = _coerce_str_list(
            data.get("host_candidates"), context="hardware.host_candidates"
        )
        return cls(gpu_count=count, gpu_model=model_name, host_candidates=hosts)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ModelSpec:
    """Identifies the model under test and the name it is served as."""

    model_id: str
    served_model_name: str

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ModelSpec":
        """Build from a raw mapping; both identifiers are required."""
        local_id = _require_str(data.get("model_id"), context="model.model_id")
        served = _require_str(
            data.get("served_model_name"), context="model.served_model_name"
        )
        return cls(model_id=local_id, served_model_name=served)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class EngineLaunchSpec:
    """How to launch and reach the inference engine process."""

    # Engine identifier, e.g. "vllm".
    engine_name: str
    engine_version: str | None
    # Executable to spawn and optional working directory.
    exec_path: str
    cwd: str | None
    # HTTP endpoint the engine serves on.
    host: str
    port: int
    # Seconds to wait for readiness / per-request completion.
    ready_timeout_s: float
    request_timeout_s: float
    healthcheck_path: str
    # Positional CLI arguments (e.g. ["serve", "/path/to/model"]).
    launch_args: list[str]
    # Always-applied environment variables and CLI flags.
    base_envs: dict[str, str]
    base_flags: dict[str, Any]
    # Whitelists of env/flag keys a tuning proposal may patch.
    tunable_envs: list[str]
    tunable_flags: list[str]
    python_executable: str = "python3"

    @property
    def base_url(self) -> str:
        """Root URL of the engine's HTTP API."""
        return f"http://{self.host}:{self.port}"

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "EngineLaunchSpec":
        """Build from a raw mapping; optional fields get the defaults shown below."""
        return cls(
            engine_name=_require_str(data.get("engine_name"), context="engine.engine_name"),
            # Optional: missing/empty values become None.
            engine_version=str(data["engine_version"]).strip()
            if data.get("engine_version")
            else None,
            exec_path=_require_str(data.get("exec_path"), context="engine.exec_path"),
            cwd=str(data["cwd"]).strip() if data.get("cwd") else None,
            host=str(data.get("host") or "127.0.0.1").strip(),
            port=_require_int(data.get("port", 8000), context="engine.port"),
            ready_timeout_s=_require_float(
                data.get("ready_timeout_s", 600.0), context="engine.ready_timeout_s"
            ),
            request_timeout_s=_require_float(
                data.get("request_timeout_s", 600.0),
                context="engine.request_timeout_s",
            ),
            healthcheck_path=str(data.get("healthcheck_path") or "/v1/models").strip(),
            launch_args=_coerce_str_list(data.get("launch_args"), context="engine.launch_args"),
            base_envs=_coerce_str_map(data.get("base_envs"), context="engine.base_envs"),
            base_flags=_coerce_any_map(data.get("base_flags"), context="engine.base_flags"),
            tunable_envs=_coerce_str_list(
                data.get("tunable_envs"), context="engine.tunable_envs"
            ),
            tunable_flags=_coerce_str_list(
                data.get("tunable_flags"), context="engine.tunable_flags"
            ),
            python_executable=str(data.get("python_executable") or "python3").strip(),
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TraceSpec:
    """Which replay-trace window drives the probes, and how to read it."""

    # Path to the windows index file (relative paths resolve against the spec).
    windows_path: str
    window_id: str
    # Optional override for the trace file recorded in the windows index.
    trace_file_override: str | None
    # Field names inside each trace row for the sampling value and timestamp.
    u_field: str
    timestamp_field: str
    # Cap on concurrently in-flight replayed requests.
    max_concurrency: int
    # Optional cap on how many requests a single probe replays.
    max_requests_per_probe: int | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
        """Build from a raw mapping; optional fields get the defaults shown below."""
        max_requests = data.get("max_requests_per_probe")
        return cls(
            windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
            window_id=_require_str(data.get("window_id"), context="trace.window_id"),
            # Optional: missing/empty values become None.
            trace_file_override=str(data["trace_file_override"]).strip()
            if data.get("trace_file_override")
            else None,
            u_field=str(data.get("u_field") or "sampling_u").strip(),
            timestamp_field=str(data.get("timestamp_field") or "timestamp").strip(),
            max_concurrency=_require_int(
                data.get("max_concurrency", 64), context="trace.max_concurrency"
            ),
            max_requests_per_probe=int(max_requests) if max_requests is not None else None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ThresholdRule:
    """Latency threshold rule: a fixed value or a step function over prompt size.

    ``kind`` is ``"fixed_ms"`` (uses ``threshold_ms``) or ``"step_ms"``
    (uses ``buckets``, each carrying ``threshold_ms`` and an optional
    ``max_input_tokens`` ceiling; a bucket without a ceiling matches any size).
    """

    kind: str
    threshold_ms: float | None = None
    buckets: list[dict[str, float]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "ThresholdRule":
        """Parse and validate a rule payload; raises SpecError on a bad shape."""
        kind = _require_str(data.get("kind"), context=f"{context}.kind")
        if kind == "fixed_ms":
            return cls(
                kind=kind,
                threshold_ms=_require_float(
                    data.get("threshold_ms"), context=f"{context}.threshold_ms"
                ),
            )
        if kind == "step_ms":
            raw = data.get("buckets")
            if not isinstance(raw, list) or not raw:
                raise SpecError(f"{context}.buckets must be a non-empty list.")
            buckets: list[dict[str, float]] = []
            for idx, item in enumerate(raw):
                mapping = _require_mapping(item, context=f"{context}.buckets[{idx}]")
                bucket: dict[str, float] = {
                    "threshold_ms": _require_float(
                        mapping.get("threshold_ms"),
                        context=f"{context}.buckets[{idx}].threshold_ms",
                    )
                }
                # max_input_tokens is optional; an explicit null means "unbounded".
                if "max_input_tokens" in mapping and mapping["max_input_tokens"] is not None:
                    bucket["max_input_tokens"] = _require_float(
                        mapping["max_input_tokens"],
                        context=f"{context}.buckets[{idx}].max_input_tokens",
                    )
                buckets.append(bucket)
            return cls(kind=kind, buckets=buckets)
        raise SpecError(f"Unsupported threshold rule kind: {kind}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SloSpec:
    """SLO definition: target pass rate plus optional TTFT/TPOT latency rules."""

    target_pass_rate: float
    ttft_rule: ThresholdRule | None
    tpot_rule: ThresholdRule | None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "SloSpec":
        """Build from a raw mapping; missing/falsy rules stay None."""

        def _load_rule(key: str) -> ThresholdRule | None:
            # A missing or falsy rule payload means "no constraint on this metric".
            if not data.get(key):
                return None
            payload = _require_mapping(data[key], context=f"slo.{key}")
            return ThresholdRule.from_dict(payload, context=f"slo.{key}")

        ttft = _load_rule("ttft_rule")
        tpot = _load_rule("tpot_rule")
        target = _require_float(
            data.get("target_pass_rate", 0.95), context="slo.target_pass_rate"
        )
        return cls(target_pass_rate=target, ttft_rule=ttft, tpot_rule=tpot)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SamplingSearchSpec:
    """Parameters for the binary search over the trace sampling fraction."""

    low: float
    high: float
    tolerance: float
    max_probes: int
    sample_seed: int

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "SamplingSearchSpec":
        """Build from a raw mapping, applying the documented defaults."""
        low = _require_float(data.get("low", 0.0), context="search.low")
        high = _require_float(data.get("high", 1.0), context="search.high")
        tolerance = _require_float(data.get("tolerance", 0.01), context="search.tolerance")
        probe_budget = _require_int(data.get("max_probes", 8), context="search.max_probes")
        seed = _require_int(
            data.get("sample_seed", 20260325), context="search.sample_seed"
        )
        return cls(
            low=low,
            high=high,
            tolerance=tolerance,
            max_probes=probe_budget,
            sample_seed=seed,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class LLMEndpointSpec:
    """OpenAI-compatible endpoint used to request tuning proposals."""

    base_url: str
    model: str
    # Name of the environment variable holding the API key.
    api_key_env: str = "OPENAI_API_KEY"
    timeout_s: float = 120.0

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
        """Build from a raw mapping; api_key_env and timeout fall back to defaults."""
        url = _require_str(data.get("base_url"), context="llm.endpoint.base_url")
        model_name = _require_str(data.get("model"), context="llm.endpoint.model")
        key_env = str(data.get("api_key_env") or "OPENAI_API_KEY").strip()
        timeout = _require_float(
            data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
        )
        return cls(base_url=url, model=model_name, api_key_env=key_env, timeout_s=timeout)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class LLMPolicySpec:
    """LLM usage policy: endpoint, system prompt and history budget."""

    # None means no endpoint configured (proposal calls will be rejected).
    endpoint: LLMEndpointSpec | None
    system_prompt: str
    max_history_trials: int

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None) -> "LLMPolicySpec":
        """Build from an optional raw mapping; a missing section yields defaults."""
        payload = _require_mapping(data or {}, context="llm")
        raw_endpoint = payload.get("endpoint")
        endpoint = (
            LLMEndpointSpec.from_dict(
                _require_mapping(raw_endpoint, context="llm.endpoint")
            )
            if raw_endpoint
            else None
        )
        prompt = str(payload.get("system_prompt") or "").strip()
        history_cap = _require_int(
            payload.get("max_history_trials", 8), context="llm.max_history_trials"
        )
        return cls(endpoint=endpoint, system_prompt=prompt, max_history_trials=history_cap)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class StudySpec:
    """Top-level tuning-study definition assembled from a JSON/TOML spec file."""

    study_id: str
    hardware: HardwareSpec
    model: ModelSpec
    engine: EngineLaunchSpec
    trace: TraceSpec
    slo: SloSpec
    search: SamplingSearchSpec
    llm: LLMPolicySpec
    # Optional path to a capability-profile JSON (relative to the spec file).
    capability_profile_path: str | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec":
        """Build the full spec tree; each section validates its own sub-mapping."""
        return cls(
            study_id=_require_str(data.get("study_id"), context="study_id"),
            hardware=HardwareSpec.from_dict(
                _require_mapping(data.get("hardware"), context="hardware")
            ),
            model=ModelSpec.from_dict(_require_mapping(data.get("model"), context="model")),
            engine=EngineLaunchSpec.from_dict(
                _require_mapping(data.get("engine"), context="engine")
            ),
            trace=TraceSpec.from_dict(_require_mapping(data.get("trace"), context="trace")),
            slo=SloSpec.from_dict(_require_mapping(data.get("slo"), context="slo")),
            search=SamplingSearchSpec.from_dict(
                _require_mapping(data.get("search"), context="search")
            ),
            # The llm section is optional; its from_dict tolerates None.
            llm=LLMPolicySpec.from_dict(data.get("llm")),
            capability_profile_path=str(data["capability_profile_path"]).strip()
            if data.get("capability_profile_path")
            else None,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ConfigPatch:
    """Env/flag overrides a proposal applies on top of the base engine config."""

    env_patch: dict[str, str] = field(default_factory=dict)
    flag_patch: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ConfigPatch":
        """Build from a raw mapping; env values are stringified, flags kept as-is."""
        envs = _coerce_str_map(data.get("env_patch"), context="config_patch.env_patch")
        flags = _coerce_any_map(
            data.get("flag_patch"), context="config_patch.flag_patch"
        )
        return cls(env_patch=envs, flag_patch=flags)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Proposal:
    """Structured tuning proposal parsed from an LLM reply."""

    observation: str
    diagnosis: str
    config_patch: ConfigPatch
    expected_effects: list[str]
    why_not_previous_failures: str = ""

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
        """Build from a decoded JSON object, validating each required field."""
        # Validate fields in declaration order so the first error surfaced
        # matches the field order of the schema.
        observation = _require_str(data.get("observation"), context="proposal.observation")
        diagnosis = _require_str(data.get("diagnosis"), context="proposal.diagnosis")
        patch = ConfigPatch.from_dict(
            _require_mapping(data.get("config_patch"), context="proposal.config_patch")
        )
        effects = _coerce_str_list(
            data.get("expected_effects"), context="proposal.expected_effects"
        )
        rationale = str(data.get("why_not_previous_failures") or "").strip()
        return cls(
            observation=observation,
            diagnosis=diagnosis,
            config_patch=patch,
            expected_effects=effects,
            why_not_previous_failures=rationale,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TrialSpec:
    """Fully-resolved description of one trial run (serialized to trial_spec.json)."""

    study_id: str
    trial_id: str
    # Env/flag overrides this trial applies on top of the engine base config.
    config_patch: ConfigPatch
    search: SamplingSearchSpec
    # Paths to the study-spec pointer and this trial's artifact files.
    study_spec_path: str
    artifact_dir: str
    probe_log_path: str
    engine_log_path: str
    result_path: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrialSummary:
    """Mutable per-trial status row kept inside StudyState."""

    trial_id: str
    # Lifecycle marker, e.g. "queued" / "completed" / "unknown".
    status: str
    best_sampling_u: float | None = None
    best_request_rate: float | None = None
    best_pass_rate: float | None = None
    result_path: str | None = None
    # Diagnosis text carried over from the proposal that spawned the trial.
    diagnosis: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StudyState:
    """Mutable study-level bookkeeping persisted as state.json."""

    study_id: str
    # Best trial seen so far by request rate (None until a result is ingested).
    best_trial_id: str | None = None
    best_request_rate: float | None = None
    # Index used to name the next trial directory (1-based).
    next_trial_index: int = 1
    trials: list[TrialSummary] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def to_jsonable(value: Any) -> Any:
    """Recursively convert dataclasses, dicts and lists to JSON-ready data."""
    if is_dataclass(value):
        # asdict already recurses into nested dataclasses; the dict branch
        # below then normalizes keys and containers uniformly.
        value = asdict(value)
    if isinstance(value, dict):
        return {str(key): to_jsonable(item) for key, item in value.items()}
    if isinstance(value, list):
        return [to_jsonable(item) for item in value]
    # Scalars (and anything else, e.g. tuples) pass through unchanged.
    return value
|
||||||
|
|
||||||
|
|
||||||
|
def load_structured_file(path: Path) -> Mapping[str, Any]:
    """Parse a JSON or TOML file and require the top level to be a mapping."""
    parsers = {".json": json.loads, ".toml": tomllib.loads, ".tml": tomllib.loads}
    parser = parsers.get(path.suffix.lower())
    if parser is None:
        # Unsupported extensions fail before the file is ever read.
        raise SpecError(f"Unsupported spec file type: {path}")
    payload = parser(path.read_text(encoding="utf-8"))
    return _require_mapping(payload, context=str(path))
|
||||||
|
|
||||||
|
|
||||||
|
def load_study_spec(path: Path) -> StudySpec:
    """Load and validate a complete StudySpec from a JSON/TOML file."""
    payload = load_structured_file(path)
    return StudySpec.from_dict(payload)
|
||||||
118
src/aituner/store.py
Normal file
118
src/aituner/store.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import replace
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .spec import Proposal, StudySpec, StudyState, TrialSpec, TrialSummary, to_jsonable
|
||||||
|
|
||||||
|
|
||||||
|
class StudyStore:
    """Filesystem-backed persistence for studies, trials, prompts and proposals.

    Layout: ``<root>/<study_id>/{prompts,proposals,trials,results}`` plus
    ``state.json``, ``study_spec.source`` and ``study_spec.snapshot.json``.
    """

    def __init__(self, root: Path | None = None):
        # Default root is ./.aituner/studies under the current directory.
        base = root or Path(".aituner") / "studies"
        self.root = base.resolve()

    def study_root(self, study_id: str) -> Path:
        """Directory holding all artifacts for one study."""
        return self.root / study_id

    def init_study(self, *, spec_path: Path, study: StudySpec) -> Path:
        """Create the study directory tree and seed state; state.json is only written once."""
        root = self.study_root(study.study_id)
        for rel in ("prompts", "proposals", "trials", "results"):
            (root / rel).mkdir(parents=True, exist_ok=True)
        # Record where the spec came from and freeze a snapshot of its parsed form.
        (root / "study_spec.source").write_text(str(spec_path.resolve()) + "\n", encoding="utf-8")
        self.write_json(root / "study_spec.snapshot.json", to_jsonable(study))
        if not (root / "state.json").exists():
            self.write_json(root / "state.json", to_jsonable(StudyState(study_id=study.study_id)))
        return root

    def load_state(self, study_id: str) -> StudyState:
        """Deserialize state.json back into a StudyState."""
        payload = json.loads((self.study_root(study_id) / "state.json").read_text(encoding="utf-8"))
        trials = [TrialSummary(**item) for item in payload.get("trials", [])]
        return StudyState(
            study_id=str(payload["study_id"]),
            best_trial_id=payload.get("best_trial_id"),
            best_request_rate=payload.get("best_request_rate"),
            next_trial_index=int(payload.get("next_trial_index", 1)),
            trials=trials,
        )

    def save_state(self, state: StudyState) -> None:
        """Persist the full study state to state.json."""
        self.write_json(self.study_root(state.study_id) / "state.json", to_jsonable(state))

    def write_prompt(self, study_id: str, prompt_name: str, prompt_text: str) -> Path:
        """Archive an LLM prompt under prompts/<name>.txt and return its path."""
        path = self.study_root(study_id) / "prompts" / f"{prompt_name}.txt"
        path.write_text(prompt_text, encoding="utf-8")
        return path

    def write_proposal(self, study_id: str, proposal_name: str, proposal: Proposal) -> Path:
        """Archive a proposal under proposals/<name>.json and return its path."""
        path = self.study_root(study_id) / "proposals" / f"{proposal_name}.json"
        self.write_json(path, to_jsonable(proposal))
        return path

    def materialize_trial(
        self,
        *,
        study: StudySpec,
        state: StudyState,
        proposal: Proposal,
    ) -> tuple[TrialSpec, StudyState]:
        """Create the next trial's directory + trial_spec.json and advance state.

        Returns the new TrialSpec and the updated (already saved) StudyState.
        """
        trial_id = f"trial-{state.next_trial_index:04d}"
        trial_root = self.study_root(study.study_id) / "trials" / trial_id
        trial_root.mkdir(parents=True, exist_ok=True)
        spec = TrialSpec(
            study_id=study.study_id,
            trial_id=trial_id,
            config_patch=proposal.config_patch,
            search=study.search,
            study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()),
            artifact_dir=str(trial_root),
            probe_log_path=str(trial_root / "probe_history.json"),
            engine_log_path=str(trial_root / "engine.log"),
            result_path=str(trial_root / "result.json"),
        )
        self.write_json(trial_root / "trial_spec.json", to_jsonable(spec))
        # NOTE(review): replace() copies the dataclass but shares the same
        # `trials` list object, so the append below is also visible through
        # the caller's `state` — confirm this aliasing is intended.
        next_state = replace(state, next_trial_index=state.next_trial_index + 1)
        next_state.trials.append(
            TrialSummary(trial_id=trial_id, status="queued", diagnosis=proposal.diagnosis)
        )
        self.save_state(next_state)
        return spec, next_state

    def ingest_trial_results(self, study_id: str) -> StudyState:
        """Scan trials/*/result.json, merge into state, and recompute the best trial."""
        state = self.load_state(study_id)
        by_id = {item.trial_id: item for item in state.trials}
        trials_dir = self.study_root(study_id) / "trials"
        best_trial_id = state.best_trial_id
        best_rate = state.best_request_rate
        for trial_dir in sorted(trials_dir.glob("trial-*")):
            result_path = trial_dir / "result.json"
            if not result_path.exists():
                # Trial not finished yet (or failed before writing a result).
                continue
            payload = json.loads(result_path.read_text(encoding="utf-8"))
            trial_id = str(payload["trial_id"])
            summary = by_id.get(trial_id)
            if summary is None:
                # Result on disk for a trial state.json never saw; adopt it.
                summary = TrialSummary(trial_id=trial_id, status="unknown")
                state.trials.append(summary)
                by_id[trial_id] = summary
            summary.status = str(payload.get("status") or "completed")
            summary.best_sampling_u = payload.get("best_sampling_u")
            summary.best_request_rate = payload.get("best_request_rate")
            summary.best_pass_rate = payload.get("best_pass_rate")
            summary.result_path = str(result_path)
            # Track the highest request rate seen across all ingested trials.
            if (
                isinstance(summary.best_request_rate, (int, float))
                and (best_rate is None or summary.best_request_rate > best_rate)
            ):
                best_rate = float(summary.best_request_rate)
                best_trial_id = trial_id
        state.best_request_rate = best_rate
        state.best_trial_id = best_trial_id
        self.save_state(state)
        return state

    @staticmethod
    def write_json(path: Path, payload: Any) -> None:
        """Write *payload* as pretty-printed UTF-8 JSON, creating parent dirs."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
178
src/aituner/trace.py
Normal file
178
src/aituner/trace.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Mapping
|
||||||
|
|
||||||
|
from .spec import StudySpec
|
||||||
|
|
||||||
|
|
||||||
|
# Shared error type for windows-index / trace-file problems in this module.
class TraceError(ValueError):
    """Raised when trace assets are invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
def _percentile(values: list[float], p: float) -> float:
|
||||||
|
if not values:
|
||||||
|
return 0.0
|
||||||
|
ordered = sorted(values)
|
||||||
|
idx = min(len(ordered) - 1, max(0, math.ceil((p / 100.0) * len(ordered)) - 1))
|
||||||
|
return ordered[idx]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class WindowRecord:
    """One entry from the windows index, with its trace file path resolved."""

    window_id: str
    trace_path: Path
    # e.g. "chat"; default applied by resolve_window_record.
    trace_type: str
    window_start: float
    window_end: float
    # Full raw index entry (keys stringified) for downstream consumers.
    source_payload: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TraceRequest:
    """One replayable request parsed from a trace JSONL row."""

    row_id: str  # request_id / id from the row, or the line index as a fallback
    arrival_s: float  # arrival offset in seconds used to pace replay
    sampling_u: float  # sampling key for threshold-based subset selection
    body: dict[str, Any]  # ready-to-send streaming chat-completion request body
    prompt_tokens_hint: int | None  # prompt-size hint from the trace, if present
    completion_tokens_hint: int | None  # completion-size hint from the trace, if present
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowRecord:
    """Locate the study's configured window inside the windows manifest.

    Reads ``study.trace.windows_path`` (resolved relative to the study-spec
    file when not absolute), finds the entry whose ``window_id`` matches
    ``study.trace.window_id``, and resolves its trace file path relative to
    the manifest's directory.

    Raises:
        TraceError: if the manifest is not a list, the matched window has no
            trace file, or no window matches the configured id.
    """
    windows_path = Path(study.trace.windows_path)
    if not windows_path.is_absolute():
        # Relative manifest paths are anchored at the study-spec's directory.
        windows_path = (study_spec_path.parent / windows_path).resolve()
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    # Accept either {"windows": [...]} or a bare list of window entries.
    windows = payload["windows"] if isinstance(payload, Mapping) and "windows" in payload else payload
    if not isinstance(windows, list):
        raise TraceError(f"windows payload must contain a list: {windows_path}")
    for item in windows:
        if not isinstance(item, Mapping):
            continue
        if str(item.get("window_id") or "").strip() != study.trace.window_id:
            continue
        # A per-study override takes precedence over the manifest's trace_file.
        trace_file = study.trace.trace_file_override or str(item.get("trace_file") or "").strip()
        if not trace_file:
            raise TraceError(f"window {study.trace.window_id} does not define trace_file")
        trace_path = Path(trace_file)
        if not trace_path.is_absolute():
            # Trace files are relative to the manifest, not the study spec.
            trace_path = (windows_path.parent / trace_path).resolve()
        return WindowRecord(
            window_id=study.trace.window_id,
            trace_path=trace_path,
            trace_type=str(item.get("trace_type") or "chat").strip(),
            window_start=float(item.get("window_start") or 0.0),
            window_end=float(item.get("window_end") or 0.0),
            source_payload={str(key): value for key, value in item.items()},
        )
    raise TraceError(f"window_id not found: {study.trace.window_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
|
||||||
|
messages = row.get("messages")
|
||||||
|
if isinstance(messages, list) and messages:
|
||||||
|
return [dict(item) for item in messages if isinstance(item, Mapping)]
|
||||||
|
prompt = row.get("prompt") or row.get("input") or row.get("text")
|
||||||
|
if isinstance(prompt, str) and prompt.strip():
|
||||||
|
return [{"role": "user", "content": prompt}]
|
||||||
|
raise TraceError("trace row is missing chat messages/prompt text")
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
|
||||||
|
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
|
||||||
|
value = row.get(key)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
continue
|
||||||
|
if isinstance(value, int) and value >= 0:
|
||||||
|
return value
|
||||||
|
if isinstance(value, float) and value >= 0:
|
||||||
|
return int(value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_prompt_tokens(row: Mapping[str, Any]) -> int | None:
|
||||||
|
for key in ("input_length", "prompt_length", "prompt_len", "input_tokens"):
|
||||||
|
value = row.get(key)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
continue
|
||||||
|
if isinstance(value, int) and value >= 0:
|
||||||
|
return value
|
||||||
|
if isinstance(value, float) and value >= 0:
|
||||||
|
return int(value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
    """Parse the study's trace window into replayable requests.

    Resolves the configured window, reads its JSONL trace, and builds one
    :class:`TraceRequest` per usable row. Rows that are blank or not JSON
    objects are skipped; rows lacking a numeric timestamp or sampling key
    raise. The result is sorted by arrival time and optionally truncated
    to ``study.trace.max_requests_per_probe``.

    Raises:
        TraceError: on unusable rows or an unresolvable window.
    """
    window = resolve_window_record(study, study_spec_path=study_spec_path)
    requests: list[TraceRequest] = []
    with window.trace_path.open("r", encoding="utf-8") as handle:
        for idx, raw in enumerate(handle):
            if not raw.strip():
                continue
            row = json.loads(raw)
            if not isinstance(row, Mapping):
                continue
            # Timestamp: configured field first, then common fallbacks.
            timestamp = row.get(study.trace.timestamp_field)
            if timestamp is None:
                timestamp = row.get("arrival_time", row.get("timestamp"))
            # bool is an int subclass, so reject it explicitly.
            if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric timestamp")
            # Sampling key defaults to 1.0 (always selected at threshold 1.0).
            sampling_u = row.get(study.trace.u_field, 1.0)
            if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
            # Streaming body with usage reporting so TTFT/TPOT can be measured.
            body: dict[str, Any] = {
                "model": study.model.served_model_name,
                "messages": _coerce_messages(row),
                "stream": True,
                "stream_options": {"include_usage": True},
            }
            completion_tokens = _coerce_completion_tokens(row)
            if completion_tokens is not None:
                body["max_tokens"] = completion_tokens
            temperature = row.get("temperature")
            if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):
                body["temperature"] = temperature
            requests.append(
                TraceRequest(
                    row_id=str(row.get("request_id") or row.get("id") or idx),
                    arrival_s=float(timestamp),
                    sampling_u=float(sampling_u),
                    body=body,
                    prompt_tokens_hint=_coerce_prompt_tokens(row),
                    completion_tokens_hint=completion_tokens,
                )
            )
    requests.sort(key=lambda item: item.arrival_s)
    if study.trace.max_requests_per_probe is not None:
        # NOTE: truncation happens after sorting, so the earliest arrivals win.
        requests = requests[: study.trace.max_requests_per_probe]
    return window, requests
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_window(requests: list[TraceRequest], window: WindowRecord) -> dict[str, Any]:
    """Aggregate request-rate and token-size statistics for one window."""
    prompt_sizes = [float(req.prompt_tokens_hint or 0) for req in requests]
    completion_sizes = [float(req.completion_tokens_hint or 0) for req in requests]
    # Prefer the declared window span; fall back to the observed arrival span.
    duration = max(window.window_end - window.window_start, 0.0)
    if not duration:
        duration = requests[-1].arrival_s - requests[0].arrival_s if len(requests) >= 2 else 0.0
    if duration > 0:
        qps = len(requests) / duration
    else:
        qps = 0.0
    return {
        "window_id": window.window_id,
        "trace_path": str(window.trace_path),
        "trace_type": window.trace_type,
        "request_count": len(requests),
        "duration_s": duration,
        "request_rate": qps,
        "prompt_tokens_p50": _percentile(prompt_sizes, 50.0),
        "prompt_tokens_p95": _percentile(prompt_sizes, 95.0),
        "completion_tokens_p50": _percentile(completion_sizes, 50.0),
        "completion_tokens_p95": _percentile(completion_sizes, 95.0),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def select_requests_for_threshold(
    requests: list[TraceRequest], *, threshold: float
) -> list[TraceRequest]:
    """Keep requests whose sampling key is at or below *threshold*, in order."""
    return list(filter(lambda req: req.sampling_u <= threshold, requests))
|
||||||
215
src/aituner/worker.py
Normal file
215
src/aituner/worker.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .engine import build_launch_recipe
|
||||||
|
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
|
||||||
|
from .search import ThresholdProbe, binary_search_max_feasible
|
||||||
|
from .slo import RequestOutcome, summarize_evaluations
|
||||||
|
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
|
||||||
|
from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ProbePayload:
    """Result of replaying one sampling-u threshold against the engine."""

    threshold: float  # the sampling_u cutoff that was probed
    request_count: int  # number of trace requests selected at this threshold
    pass_rate: float  # fraction of replayed requests that met the SLO
    request_rate: float  # selected requests divided by the window span (req/s)
    feasible: bool  # True when the SLO summary marked this probe feasible
    outcomes: list[dict[str, Any]]  # per-request metrics plus SLO evaluation detail
|
||||||
|
|
||||||
|
def _trial_spec_from_json(path: Path) -> TrialSpec:
    """Rehydrate a TrialSpec from the JSON file the orchestrator wrote."""
    raw = json.loads(path.read_text(encoding="utf-8"))
    return TrialSpec(
        study_id=str(raw["study_id"]),
        trial_id=str(raw["trial_id"]),
        config_patch=ConfigPatch.from_dict(raw["config_patch"]),
        search=SamplingSearchSpec.from_dict(raw["search"]),
        study_spec_path=str(raw["study_spec_path"]),
        artifact_dir=str(raw["artifact_dir"]),
        probe_log_path=str(raw["probe_log_path"]),
        engine_log_path=str(raw["engine_log_path"]),
        result_path=str(raw["result_path"]),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _run_one_request(
    request: TraceRequest,
    *,
    base_url: str,
    timeout_s: float,
) -> RequestOutcome:
    """Replay a single trace request and fold the result into a RequestOutcome.

    HTTP-level failures are captured as a failed outcome (with the error
    message) rather than propagated to the caller.
    """
    try:
        metrics = stream_chat_completion(base_url=base_url, body=request.body, timeout_s=timeout_s)
    except HttpClientError as exc:
        # Failure path: keep the trace hints so token accounting stays usable.
        return RequestOutcome(
            request_id=request.row_id,
            success=False,
            ttft_ms=None,
            tpot_ms=None,
            prompt_tokens=request.prompt_tokens_hint,
            completion_tokens=request.completion_tokens_hint,
            error=str(exc),
        )
    # Prefer the server-reported completion count, fall back to the hint.
    observed_completion = metrics.completion_tokens or request.completion_tokens_hint
    return RequestOutcome(
        request_id=request.row_id,
        success=True,
        ttft_ms=metrics.ttft_ms,
        tpot_ms=metrics.tpot_ms,
        prompt_tokens=request.prompt_tokens_hint,
        completion_tokens=observed_completion,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _replay_requests(
|
||||||
|
requests: list[TraceRequest],
|
||||||
|
*,
|
||||||
|
base_url: str,
|
||||||
|
timeout_s: float,
|
||||||
|
max_concurrency: int,
|
||||||
|
) -> list[RequestOutcome]:
|
||||||
|
outcomes_by_id: dict[str, RequestOutcome] = {}
|
||||||
|
lock = threading.Lock()
|
||||||
|
start = time.monotonic()
|
||||||
|
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||||
|
futures = []
|
||||||
|
for request in requests:
|
||||||
|
delay = max(0.0, request.arrival_s)
|
||||||
|
now = time.monotonic()
|
||||||
|
sleep_for = (start + delay) - now
|
||||||
|
if sleep_for > 0:
|
||||||
|
time.sleep(sleep_for)
|
||||||
|
futures.append(
|
||||||
|
pool.submit(
|
||||||
|
_run_one_request,
|
||||||
|
request,
|
||||||
|
base_url=base_url,
|
||||||
|
timeout_s=timeout_s,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for future in as_completed(futures):
|
||||||
|
outcome = future.result()
|
||||||
|
with lock:
|
||||||
|
outcomes_by_id[outcome.request_id] = outcome
|
||||||
|
return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
||||||
|
|
||||||
|
|
||||||
|
def run_trial(trial_spec_path: Path) -> dict[str, Any]:
    """Execute one tuning trial end to end and return its result payload.

    Launches the engine with the trial's config patch, waits for readiness,
    then binary-searches the largest feasible sampling-u threshold by
    replaying progressively larger trace subsets. Probe history and the
    final result are persisted via StudyStore; the engine process is always
    terminated on exit.
    """
    # Imported here (not at module top) — presumably to avoid a circular
    # import between worker and store; confirm before hoisting.
    from .store import StudyStore

    trial = _trial_spec_from_json(trial_spec_path)
    # NOTE(review): trial.study_spec_path appears to point at a file whose
    # *contents* are the actual study-spec path (a pointer file) — confirm
    # against the orchestrator's materialize step.
    study_spec_path = Path(Path(trial.study_spec_path).read_text(encoding="utf-8").strip())
    study = load_study_spec(study_spec_path)
    window, requests = load_trace_requests(study, study_spec_path=study_spec_path)
    recipe = build_launch_recipe(study.engine, trial.config_patch)
    artifact_dir = Path(trial.artifact_dir)
    artifact_dir.mkdir(parents=True, exist_ok=True)
    engine_log_path = Path(trial.engine_log_path)
    # Engine stdout+stderr are captured to one log for the trial's lifetime.
    with engine_log_path.open("w", encoding="utf-8") as engine_log:
        process = subprocess.Popen(  # noqa: S603
            recipe.argv,
            cwd=recipe.cwd,
            env=recipe.env,
            stdout=engine_log,
            stderr=subprocess.STDOUT,
            text=True,
        )
        try:
            wait_for_server(recipe.base_url, recipe.healthcheck_path, recipe.ready_timeout_s)
            probe_history: list[dict[str, Any]] = []

            def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
                # One probe: replay the subset selected by this threshold and
                # judge it against the study's SLO rules.
                selected = select_requests_for_threshold(requests, threshold=threshold)
                outcomes = _replay_requests(
                    selected,
                    base_url=recipe.base_url,
                    timeout_s=recipe.request_timeout_s,
                    max_concurrency=study.trace.max_concurrency,
                )
                evaluations, summary = summarize_evaluations(outcomes, study.slo)
                # 1e-9 floor guards against a zero-length window span.
                request_rate = (
                    len(selected) / max(window.window_end - window.window_start, 1e-9)
                    if selected
                    else 0.0
                )
                payload = ProbePayload(
                    threshold=threshold,
                    request_count=len(selected),
                    pass_rate=float(summary["slo_pass_rate"]),
                    request_rate=request_rate,
                    feasible=bool(summary["feasible"]),
                    outcomes=[
                        {
                            "request_id": outcome.request_id,
                            "success": outcome.success,
                            "ttft_ms": outcome.ttft_ms,
                            "tpot_ms": outcome.tpot_ms,
                            "prompt_tokens": outcome.prompt_tokens,
                            "completion_tokens": outcome.completion_tokens,
                            "evaluation": evaluation.passed,
                            "reasons": evaluation.reasons,
                        }
                        for outcome, evaluation in zip(outcomes, evaluations)
                    ],
                )
                probe_record = {
                    "threshold": threshold,
                    "request_count": payload.request_count,
                    "pass_rate": payload.pass_rate,
                    "request_rate": payload.request_rate,
                    "feasible": payload.feasible,
                }
                probe_history.append(probe_record)
                # Persist after every probe so progress survives a crash.
                StudyStore.write_json(Path(trial.probe_log_path), probe_history)
                return ThresholdProbe(
                    threshold=threshold,
                    feasible=payload.feasible,
                    payload=payload,
                )

            search = binary_search_max_feasible(
                low=trial.search.low,
                high=trial.search.high,
                tolerance=trial.search.tolerance,
                max_probes=trial.search.max_probes,
                evaluator=evaluator,
            )
            best = search.best_feasible_payload
            # best_* fields are null when no probe was feasible.
            result = {
                "study_id": trial.study_id,
                "trial_id": trial.trial_id,
                "status": "completed",
                "best_sampling_u": search.best_threshold if best is not None else None,
                "best_request_rate": best.request_rate if best is not None else None,
                "best_pass_rate": best.pass_rate if best is not None else None,
                "best_request_count": best.request_count if best is not None else None,
                "probes": [
                    {
                        "threshold": probe.threshold,
                        "feasible": probe.feasible,
                        "payload": {
                            "request_count": probe.payload.request_count,
                            "pass_rate": probe.payload.pass_rate,
                            "request_rate": probe.payload.request_rate,
                        },
                    }
                    for probe in search.probes
                ],
            }
            StudyStore.write_json(Path(trial.result_path), result)
            return result
        finally:
            # Graceful stop first; escalate to SIGKILL if the engine hangs.
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait(timeout=30)
|
||||||
10
tests/conftest.py
Normal file
10
tests/conftest.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# Make the in-repo package importable without installation: put <repo>/src on
# sys.path before any test module imports `aituner`.
ROOT = Path(__file__).resolve().parents[1]  # repository root (tests/..)
SRC = ROOT / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
|
||||||
267
tests/test_core_flow.py
Normal file
267
tests/test_core_flow.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from aituner.job import append_job, build_trial_job
|
||||||
|
from aituner.llm import build_prompt, parse_proposal_text
|
||||||
|
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||||
|
from aituner.slo import RequestOutcome, summarize_evaluations
|
||||||
|
from aituner.spec import Proposal, load_study_spec
|
||||||
|
from aituner.store import StudyStore
|
||||||
|
from aituner.trace import load_trace_requests, summarize_window
|
||||||
|
|
||||||
|
|
||||||
|
def _write_study_assets(tmp_path: Path) -> Path:
    """Write a minimal self-consistent study fixture under *tmp_path*.

    Creates a three-row JSONL trace, a windows manifest referencing it, a
    capability profile, and a study spec tying them together. Returns the
    path to the study-spec JSON file.
    """
    # --- trace rows: ascending timestamps and sampling_u values ---
    trace_dir = tmp_path / "trace_windows" / "traces"
    trace_dir.mkdir(parents=True)
    trace_path = trace_dir / "chat_w1.jsonl"
    rows = [
        {
            "request_id": "r1",
            "timestamp": 0.0,
            "sampling_u": 0.10,
            "messages": [{"role": "user", "content": "hello"}],
            "input_length": 1000,
            "output_length": 16
        },
        {
            "request_id": "r2",
            "timestamp": 1.0,
            "sampling_u": 0.50,
            "messages": [{"role": "user", "content": "world"}],
            "input_length": 5000,
            "output_length": 32
        },
        {
            "request_id": "r3",
            "timestamp": 2.0,
            "sampling_u": 0.90,
            "messages": [{"role": "user", "content": "!"}],
            "input_length": 20000,
            "output_length": 64
        }
    ]
    with trace_path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row) + "\n")

    # --- windows manifest pointing at the trace (relative trace_file) ---
    windows_path = tmp_path / "trace_windows" / "windows.json"
    windows_payload = {
        "u_field": "sampling_u",
        "windows": [
            {
                "window_id": "chat_w1",
                "trace_type": "chat",
                "trace_file": "traces/chat_w1.jsonl",
                "window_start": 0.0,
                "window_end": 10.0
            }
        ]
    }
    windows_path.write_text(json.dumps(windows_payload), encoding="utf-8")

    # --- capability profile consumed by prompt building ---
    capability_path = tmp_path / "capability.json"
    capability_path.write_text(
        json.dumps({"prefill_service_by_bucket": {"4k": {"tp4_ms": 320, "tp8_ms": 240}}}),
        encoding="utf-8",
    )

    # --- study spec wiring everything together ---
    study_path = tmp_path / "study.json"
    study_payload = {
        "study_id": "study-1",
        "hardware": {"gpu_count": 8, "gpu_model": "H20", "host_candidates": ["dash0"]},
        "model": {
            "model_id": "qwen",
            "served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
        },
        "engine": {
            "engine_name": "vllm",
            "engine_version": "0.1",
            "exec_path": "/usr/local/bin/vllm",
            "cwd": str(tmp_path),
            "host": "127.0.0.1",
            "port": 8000,
            "healthcheck_path": "/v1/models",
            "ready_timeout_s": 30,
            "request_timeout_s": 30,
            "launch_args": ["serve", "/models/qwen"],
            "base_envs": {"BASE_ENV": "1"},
            "base_flags": {"host": "127.0.0.1", "port": 8000},
            "tunable_envs": ["VLLM_ATTENTION_BACKEND"],
            "tunable_flags": ["tensor-parallel-size", "max-num-seqs"],
            "python_executable": "python3"
        },
        "trace": {
            "windows_path": str(windows_path),
            "window_id": "chat_w1",
            "u_field": "sampling_u",
            "timestamp_field": "timestamp",
            "max_concurrency": 4
        },
        "slo": {
            "target_pass_rate": 0.95,
            "ttft_rule": {
                "kind": "step_ms",
                "buckets": [
                    {"max_input_tokens": 4096, "threshold_ms": 2000},
                    {"max_input_tokens": 16384, "threshold_ms": 5000},
                    {"threshold_ms": 9000}
                ]
            },
            "tpot_rule": {"kind": "fixed_ms", "threshold_ms": 120}
        },
        "search": {
            "low": 0.0,
            "high": 1.0,
            "tolerance": 0.01,
            "max_probes": 8,
            "sample_seed": 20260325
        },
        "llm": {"system_prompt": "Tune it.", "max_history_trials": 8},
        "capability_profile_path": str(capability_path)
    }
    study_path.write_text(json.dumps(study_payload), encoding="utf-8")
    return study_path
|
||||||
|
|
||||||
|
|
||||||
|
class CoreFlowTests(unittest.TestCase):
    """End-to-end checks of trace loading, SLO evaluation, search, and store."""

    def test_trace_and_prompt_flow(self) -> None:
        """Trace window loads, summarizes, and feeds the LLM prompt builder."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            study = load_study_spec(study_path)
            store = StudyStore(tmp_path / ".aituner" / "studies")
            study_root = store.init_study(spec_path=study_path, study=study)
            state = store.load_state(study.study_id)

            window, requests = load_trace_requests(study, study_spec_path=study_path)
            summary = summarize_window(requests, window)
            self.assertEqual(summary["request_count"], 3)
            # 3 requests over the declared 10 s window.
            self.assertEqual(summary["request_rate"], 0.3)

            prompt = build_prompt(
                study=study,
                window_summary=summary,
                state=state,
                capability_profile={"queueing_knee_by_bucket": {"4k": 1000}},
            )
            self.assertIn("allowed_flag_keys", prompt)
            self.assertIn("study-1", prompt)
            self.assertIn("queueing_knee_by_bucket", prompt)
            self.assertTrue(study_root.exists())

    def test_slo_evaluation_step_and_fixed_rules(self) -> None:
        """Step-TTFT buckets and fixed-TPOT rule split pass/fail as expected."""
        with tempfile.TemporaryDirectory() as tmp:
            study = load_study_spec(_write_study_assets(Path(tmp)))
            outcomes = [
                # 1000 prompt tokens -> 2000 ms TTFT bucket; 1000 ms passes.
                RequestOutcome(
                    request_id="r1",
                    success=True,
                    ttft_ms=1000,
                    tpot_ms=100,
                    prompt_tokens=1000,
                    completion_tokens=16,
                ),
                # 5000 prompt tokens -> 5000 ms TTFT bucket; 6000 ms fails.
                RequestOutcome(
                    request_id="r2",
                    success=True,
                    ttft_ms=6000,
                    tpot_ms=100,
                    prompt_tokens=5000,
                    completion_tokens=16,
                ),
            ]
            evaluations, summary = summarize_evaluations(outcomes, study.slo)
            self.assertTrue(evaluations[0].passed)
            self.assertFalse(evaluations[1].passed)
            self.assertEqual(summary["slo_pass_rate"], 0.5)
            # 0.5 < target_pass_rate 0.95, so the probe is infeasible.
            self.assertFalse(summary["feasible"])

    def test_binary_search_max_feasible(self) -> None:
        """Search converges near the feasibility boundary from below."""
        result = binary_search_max_feasible(
            low=0.0,
            high=1.0,
            tolerance=0.01,
            max_probes=8,
            evaluator=lambda threshold: ThresholdProbe(
                threshold=threshold,
                feasible=threshold <= 0.625,
                payload={"threshold": threshold},
            ),
        )
        self.assertLessEqual(result.best_threshold, 0.625)
        self.assertGreaterEqual(result.best_threshold, 0.5)
        self.assertIsNotNone(result.best_feasible_payload)

    def test_proposal_validation_and_job_emission(self) -> None:
        """A valid LLM proposal materializes a trial and renders a TOML job."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            study = load_study_spec(study_path)
            store = StudyStore(tmp_path / ".aituner" / "studies")
            store.init_study(spec_path=study_path, study=study)
            state = store.load_state(study.study_id)

            proposal_text = json.dumps(
                {
                    "observation": "Current TTFT fails before TPOT.",
                    "diagnosis": "Prefill pressure dominates.",
                    "config_patch": {
                        "env_patch": {"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                        "flag_patch": {"tensor-parallel-size": 4, "max-num-seqs": 64}
                    },
                    "expected_effects": ["lower TTFT", "raise feasible sampling_u"],
                    "why_not_previous_failures": "Avoids changing unsupported envs."
                }
            )
            proposal = parse_proposal_text(proposal_text, study)
            trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)

            job = build_trial_job(study=study, trial=trial, repo_root=tmp_path)
            jobs_path = tmp_path / "jobs.toml"
            append_job(jobs_path, job)
            rendered = jobs_path.read_text(encoding="utf-8")
            self.assertIn('name = "study-1-trial-0001"', rendered)
            self.assertIn('command = "python3 -m aituner.cli worker run-trial', rendered)
            self.assertIn('PYTHONPATH = "src"', rendered)

    def test_ingest_trial_results_updates_best(self) -> None:
        """Ingesting a completed trial result updates the study's best-* state."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            study = load_study_spec(study_path)
            store = StudyStore(tmp_path / ".aituner" / "studies")
            store.init_study(spec_path=study_path, study=study)
            state = store.load_state(study.study_id)
            proposal = Proposal.from_dict(
                {
                    "observation": "Obs",
                    "diagnosis": "Diag",
                    "config_patch": {"env_patch": {}, "flag_patch": {"tensor-parallel-size": 4}},
                    "expected_effects": ["raise rate"]
                }
            )
            trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
            # Simulate a worker writing the trial result file.
            Path(trial.result_path).write_text(
                json.dumps(
                    {
                        "study_id": study.study_id,
                        "trial_id": trial.trial_id,
                        "status": "completed",
                        "best_sampling_u": 0.75,
                        "best_request_rate": 12.5,
                        "best_pass_rate": 0.97
                    }
                ),
                encoding="utf-8",
            )
            next_state = store.ingest_trial_results(study.study_id)
            self.assertEqual(next_state.best_trial_id, trial.trial_id)
            self.assertEqual(next_state.best_request_rate, 12.5)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly without a pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user