Initial AITuner study orchestrator
This commit is contained in:
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
.aituner/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
infra/gpu_fleet/config/fleet.toml
|
||||
infra/gpu_fleet/config/jobs.toml
|
||||
14
configs/examples/capability.example.json
Normal file
14
configs/examples/capability.example.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"prefill_service_by_bucket": {
|
||||
"4k": {
|
||||
"tp4_ms": 320,
|
||||
"tp8_ms": 240
|
||||
}
|
||||
},
|
||||
"queueing_knee_by_bucket": {
|
||||
"4k": {
|
||||
"tp4_tok_s_per_gpu": 1000,
|
||||
"tp8_tok_s_per_gpu": 1100
|
||||
}
|
||||
}
|
||||
}
|
||||
96
configs/examples/study.example.json
Normal file
96
configs/examples/study.example.json
Normal file
@@ -0,0 +1,96 @@
|
||||
{
|
||||
"study_id": "example-chat-window",
|
||||
"hardware": {
|
||||
"gpu_count": 8,
|
||||
"gpu_model": "H20",
|
||||
"host_candidates": ["dash0", "dash1"]
|
||||
},
|
||||
"model": {
|
||||
"model_id": "qwen3-30b",
|
||||
"served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||
},
|
||||
"engine": {
|
||||
"engine_name": "vllm",
|
||||
"engine_version": "0.x",
|
||||
"exec_path": "/usr/local/bin/vllm",
|
||||
"cwd": ".",
|
||||
"host": "127.0.0.1",
|
||||
"port": 8000,
|
||||
"healthcheck_path": "/v1/models",
|
||||
"ready_timeout_s": 600,
|
||||
"request_timeout_s": 600,
|
||||
"launch_args": [
|
||||
"serve",
|
||||
"/path/to/model"
|
||||
],
|
||||
"base_envs": {},
|
||||
"base_flags": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 8000,
|
||||
"served-model-name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
|
||||
},
|
||||
"tunable_envs": [
|
||||
"VLLM_ATTENTION_BACKEND",
|
||||
"CUDA_GRAPH_MAX_BATCH_SIZE"
|
||||
],
|
||||
"tunable_flags": [
|
||||
"tensor-parallel-size",
|
||||
"data-parallel-size",
|
||||
"pipeline-parallel-size",
|
||||
"max-num-seqs",
|
||||
"max-num-batched-tokens",
|
||||
"gpu-memory-utilization",
|
||||
"enable-prefix-caching",
|
||||
"block-size"
|
||||
],
|
||||
"python_executable": "python3"
|
||||
},
|
||||
"trace": {
|
||||
"windows_path": "trace_windows/windows.json",
|
||||
"window_id": "chat_w_example_peak_0001",
|
||||
"u_field": "sampling_u",
|
||||
"timestamp_field": "timestamp",
|
||||
"max_concurrency": 64
|
||||
},
|
||||
"slo": {
|
||||
"target_pass_rate": 0.95,
|
||||
"ttft_rule": {
|
||||
"kind": "step_ms",
|
||||
"buckets": [
|
||||
{
|
||||
"max_input_tokens": 4096,
|
||||
"threshold_ms": 2000
|
||||
},
|
||||
{
|
||||
"max_input_tokens": 16384,
|
||||
"threshold_ms": 4000
|
||||
},
|
||||
{
|
||||
"threshold_ms": 8000
|
||||
}
|
||||
]
|
||||
},
|
||||
"tpot_rule": {
|
||||
"kind": "fixed_ms",
|
||||
"threshold_ms": 120
|
||||
}
|
||||
},
|
||||
"search": {
|
||||
"low": 0.0,
|
||||
"high": 1.0,
|
||||
"tolerance": 0.01,
|
||||
"max_probes": 8,
|
||||
"sample_seed": 20260325
|
||||
},
|
||||
"llm": {
|
||||
"system_prompt": "Propose a single engine config patch that increases the maximum feasible sampling_u under the SLO target.",
|
||||
"max_history_trials": 8,
|
||||
"endpoint": {
|
||||
"base_url": "https://example-openai-compatible-endpoint",
|
||||
"model": "gpt-4.1-mini",
|
||||
"api_key_env": "OPENAI_API_KEY",
|
||||
"timeout_s": 120
|
||||
}
|
||||
},
|
||||
"capability_profile_path": "capability.example.json"
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
{"request_id":"example-1","timestamp":0.0,"sampling_u":0.10,"messages":[{"role":"user","content":"hello"}],"input_length":512,"output_length":16}
|
||||
{"request_id":"example-2","timestamp":1.0,"sampling_u":0.50,"messages":[{"role":"user","content":"summarize this file"}],"input_length":2048,"output_length":64}
|
||||
{"request_id":"example-3","timestamp":2.5,"sampling_u":0.90,"messages":[{"role":"user","content":"write a longer answer"}],"input_length":8192,"output_length":128}
|
||||
15
configs/examples/trace_windows/windows.json
Normal file
15
configs/examples/trace_windows/windows.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"sample_seed": 20260325,
|
||||
"u_field": "sampling_u",
|
||||
"window_duration_seconds": 10.0,
|
||||
"windows": [
|
||||
{
|
||||
"window_id": "chat_w_example_peak_0001",
|
||||
"trace_type": "chat",
|
||||
"trace_file": "traces/chat_w_example_peak_0001.jsonl",
|
||||
"window_start": 0.0,
|
||||
"window_end": 10.0,
|
||||
"num_requests": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
59
infra/gpu_fleet/config/fleet.example.toml
Normal file
59
infra/gpu_fleet/config/fleet.example.toml
Normal file
@@ -0,0 +1,59 @@
|
||||
version = 1
|
||||
|
||||
[paths]
|
||||
state_dir = ".aituner/gpu_fleet/state"
|
||||
artifacts_dir = ".aituner/gpu_fleet/artifacts"
|
||||
|
||||
[ssh]
|
||||
connect_timeout_sec = 10
|
||||
|
||||
[scheduler]
|
||||
gpu_free_memory_mb = 1024
|
||||
gpu_free_utilization_pct = 10
|
||||
prefer_pack = true
|
||||
|
||||
[sync]
|
||||
mode = "rsync"
|
||||
local_path = "."
|
||||
exclude = [
|
||||
".git/",
|
||||
".venv/",
|
||||
".aituner/",
|
||||
"__pycache__/",
|
||||
"*.pyc",
|
||||
]
|
||||
|
||||
[[hosts]]
|
||||
name = "dash0"
|
||||
ssh_alias = "dash0"
|
||||
enabled = true
|
||||
sync_remote_path = "~/workspace/aituner"
|
||||
fleet_root = "~/.aituner_gpu_fleet"
|
||||
|
||||
[[hosts]]
|
||||
name = "dash1"
|
||||
ssh_alias = "dash1"
|
||||
enabled = true
|
||||
sync_remote_path = "~/workspace/aituner"
|
||||
fleet_root = "~/.aituner_gpu_fleet"
|
||||
|
||||
[[hosts]]
|
||||
name = "dash2"
|
||||
ssh_alias = "dash2"
|
||||
enabled = true
|
||||
sync_remote_path = "~/workspace/aituner"
|
||||
fleet_root = "~/.aituner_gpu_fleet"
|
||||
|
||||
[[hosts]]
|
||||
name = "dash3"
|
||||
ssh_alias = "dash3"
|
||||
enabled = true
|
||||
sync_remote_path = "~/aituner"
|
||||
fleet_root = "~/.aituner_gpu_fleet"
|
||||
|
||||
[[hosts]]
|
||||
name = "dash5"
|
||||
ssh_alias = "dash5"
|
||||
enabled = true
|
||||
sync_remote_path = "~/workspace/aituner"
|
||||
fleet_root = "~/.aituner_gpu_fleet"
|
||||
27
infra/gpu_fleet/config/jobs.example.toml
Normal file
27
infra/gpu_fleet/config/jobs.example.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
# This file is an append-only queue source for the monitor.
|
||||
# Each job name must stay unique and immutable once appended.
|
||||
version = 1
|
||||
|
||||
[[jobs]]
|
||||
name = "smoke-train-h20-1gpu"
|
||||
gpus = 1
|
||||
gpu_model = "H20"
|
||||
hosts = ["dash0", "dash1", "dash2"]
|
||||
command = "python train.py --config configs/smoke.toml"
|
||||
artifacts = ["outputs/smoke-train-h20-1gpu"]
|
||||
env = { WANDB_MODE = "offline" }
|
||||
|
||||
[[jobs]]
|
||||
name = "eval-5090-4gpu"
|
||||
gpus = 4
|
||||
gpu_model = "5090"
|
||||
hosts = ["dash5"]
|
||||
command = "python eval.py --config configs/eval.toml"
|
||||
artifacts = ["outputs/eval-5090-4gpu", "logs/eval-5090-4gpu.log"]
|
||||
|
||||
[[jobs]]
|
||||
name = "special-dash3-run"
|
||||
gpus = 2
|
||||
hosts = ["dash3"]
|
||||
command = "python benchmark.py --suite long-context"
|
||||
artifacts = ["outputs/special-dash3-run"]
|
||||
8
infra/gpu_fleet/config/ssh_aliases.example.txt
Normal file
8
infra/gpu_fleet/config/ssh_aliases.example.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# One SSH alias per line.
|
||||
# Lines starting with "#" are ignored.
|
||||
dash0
|
||||
dash1
|
||||
dash2
|
||||
dash3
|
||||
dash5
|
||||
|
||||
1132
infra/gpu_fleet/gpu_fleet.py
Executable file
1132
infra/gpu_fleet/gpu_fleet.py
Executable file
File diff suppressed because it is too large
Load Diff
19
pyproject.toml
Normal file
19
pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "aituner"
|
||||
version = "0.1.0"
|
||||
description = "AITuner study orchestrator for OpenAI-compatible serving engines"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = []
|
||||
|
||||
[project.scripts]
|
||||
aituner = "aituner.cli:main"
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "src"}
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
5
src/aituner/__init__.py
Normal file
5
src/aituner/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""AITuner package."""
|
||||
|
||||
__all__ = [
|
||||
"cli",
|
||||
]
|
||||
177
src/aituner/cli.py
Normal file
177
src/aituner/cli.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from .job import append_job, build_trial_job
|
||||
from .llm import build_prompt, call_llm_for_proposal, load_capability_profile, parse_proposal_text
|
||||
from .spec import Proposal, SpecError, load_study_spec
|
||||
from .store import StudyStore
|
||||
from .trace import load_trace_requests, summarize_window
|
||||
from .worker import run_trial
|
||||
|
||||
|
||||
def _study_source_path(study_root: Path) -> Path:
|
||||
return Path((study_root / "study_spec.source").read_text(encoding="utf-8").strip())
|
||||
|
||||
|
||||
def cmd_study_init(args: argparse.Namespace) -> int:
    """``study init``: validate a study spec and create its store directory."""
    spec_path = Path(args.spec).resolve()
    study = load_study_spec(spec_path)
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    root = store.init_study(spec_path=spec_path, study=study)
    # Print the study root so shell callers can capture it.
    print(root)
    return 0


def cmd_study_prompt(args: argparse.Namespace) -> int:
    """``study prompt``: render the next LLM prompt and write it to the store."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    study_root = Path(args.study_root).resolve()
    # Re-load the spec from the path recorded at init time.
    source_path = _study_source_path(study_root)
    study = load_study_spec(source_path)
    state = store.load_state(study.study_id)
    capability_profile = load_capability_profile(study, study_spec_path=source_path)
    window, requests = load_trace_requests(study, study_spec_path=source_path)
    prompt = build_prompt(
        study=study,
        window_summary=summarize_window(requests, window),
        state=state,
        capability_profile=capability_profile,
    )
    # Default name keys the prompt to the trial it would produce.
    prompt_name = args.prompt_name or f"prompt-{state.next_trial_index:04d}"
    path = store.write_prompt(study.study_id, prompt_name, prompt)
    print(path)
    return 0


def cmd_study_llm_propose(args: argparse.Namespace) -> int:
    """``study llm-propose``: build the prompt, call the LLM, store the parsed proposal."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    study_root = Path(args.study_root).resolve()
    source_path = _study_source_path(study_root)
    study = load_study_spec(source_path)
    state = store.load_state(study.study_id)
    capability_profile = load_capability_profile(study, study_spec_path=source_path)
    window, requests = load_trace_requests(study, study_spec_path=source_path)
    prompt = build_prompt(
        study=study,
        window_summary=summarize_window(requests, window),
        state=state,
        capability_profile=capability_profile,
    )
    proposal_text = call_llm_for_proposal(policy=study.llm, prompt=prompt)
    # Parsing also validates the proposal against the study's allow-lists.
    proposal = parse_proposal_text(proposal_text, study)
    name = args.proposal_name or f"proposal-{state.next_trial_index:04d}"
    path = store.write_proposal(study.study_id, name, proposal)
    print(path)
    return 0


def cmd_study_register_proposal(args: argparse.Namespace) -> int:
    """``study register-proposal``: validate and store a hand-written proposal file."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    study_root = Path(args.study_root).resolve()
    source_path = _study_source_path(study_root)
    study = load_study_spec(source_path)
    proposal = parse_proposal_text(Path(args.proposal_file).read_text(encoding="utf-8"), study)
    name = args.proposal_name or Path(args.proposal_file).stem
    path = store.write_proposal(study.study_id, name, proposal)
    print(path)
    return 0


def cmd_study_emit_job(args: argparse.Namespace) -> int:
    """``study emit-job``: materialize a trial from a proposal and append it to the jobs queue."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    study_root = Path(args.study_root).resolve()
    source_path = _study_source_path(study_root)
    study = load_study_spec(source_path)
    state = store.load_state(study.study_id)
    proposal_text = Path(args.proposal_file).read_text(encoding="utf-8")
    proposal = parse_proposal_text(proposal_text, study)
    trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
    # src/aituner/cli.py -> parents[2] is the repository root.
    repo_root = Path(__file__).resolve().parents[2]
    job = build_trial_job(study=study, trial=trial, repo_root=repo_root)
    append_job(Path(args.jobs_file).resolve(), job)
    print(trial.trial_id)
    return 0


def cmd_study_ingest(args: argparse.Namespace) -> int:
    """``study ingest``: fold finished trial results into study state; print the best trial."""
    store = StudyStore(Path(args.store_root) if args.store_root else None)
    study_root = Path(args.study_root).resolve()
    # NOTE(review): the study id is taken from the directory name here,
    # not from the recorded spec — assumed to match; confirm in StudyStore.
    study_id = study_root.name
    state = store.ingest_trial_results(study_id)
    print(json.dumps({"best_trial_id": state.best_trial_id, "best_request_rate": state.best_request_rate}))
    return 0


def cmd_worker_run_trial(args: argparse.Namespace) -> int:
    """``worker run-trial``: execute one trial spec and print its result as JSON."""
    result = run_trial(Path(args.trial_spec).resolve())
    print(json.dumps(result))
    return 0
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Build the AITuner CLI parser (``study`` and ``worker`` command groups)."""
    parser = argparse.ArgumentParser(description="AITuner CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)

    study = subparsers.add_parser("study")
    study_sub = study.add_subparsers(dest="study_command", required=True)

    # (subcommand, handler, [(flag, required), ...]) — declaration order
    # matches the CLI surface so help output stays stable.
    study_commands = [
        ("init", cmd_study_init, [("--spec", True), ("--store-root", False)]),
        ("prompt", cmd_study_prompt,
         [("--study-root", True), ("--store-root", False), ("--prompt-name", False)]),
        ("llm-propose", cmd_study_llm_propose,
         [("--study-root", True), ("--store-root", False), ("--proposal-name", False)]),
        ("register-proposal", cmd_study_register_proposal,
         [("--study-root", True), ("--store-root", False),
          ("--proposal-file", True), ("--proposal-name", False)]),
        ("emit-job", cmd_study_emit_job,
         [("--study-root", True), ("--store-root", False),
          ("--proposal-file", True), ("--jobs-file", True)]),
        ("ingest", cmd_study_ingest, [("--study-root", True), ("--store-root", False)]),
    ]
    for name, handler, options in study_commands:
        sub = study_sub.add_parser(name)
        for flag, required in options:
            if required:
                sub.add_argument(flag, required=True)
            else:
                sub.add_argument(flag)
        sub.set_defaults(func=handler)

    worker = subparsers.add_parser("worker")
    worker_sub = worker.add_subparsers(dest="worker_command", required=True)
    run = worker_sub.add_parser("run-trial")
    run.add_argument("--trial-spec", required=True)
    run.set_defaults(func=cmd_worker_run_trial)

    return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    Parses *argv* (or ``sys.argv[1:]`` when None) and dispatches to the
    handler the selected subcommand registered via ``set_defaults(func=...)``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        return int(args.func(args))
    except SpecError as exc:
        # Spec problems are user errors: report on stderr, exit code 2.
        print(str(exc), file=sys.stderr)
        return 2


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
65
src/aituner/engine.py
Normal file
65
src/aituner/engine.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shlex
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .spec import ConfigPatch, EngineLaunchSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LaunchRecipe:
    """Fully resolved recipe for launching one engine server process."""

    # Process argv: executable, launch args, then serialized flags.
    argv: list[str]
    # Complete child environment (inherited env merged with base + patch).
    env: dict[str, str]
    # Working directory for the child, or None to inherit the caller's.
    cwd: str | None
    # Base URL the launched server is expected to serve on.
    base_url: str
    # HTTP path polled to decide the server is ready.
    healthcheck_path: str
    # Seconds to wait for readiness before giving up.
    ready_timeout_s: float
    # Per-request timeout (seconds) once the server is up.
    request_timeout_s: float
|
||||
|
||||
|
||||
def _normalize_flag_name(name: str) -> str:
|
||||
return str(name).strip().replace("_", "-")
|
||||
|
||||
|
||||
def _serialize_flag_parts(name: str, value: Any) -> list[str]:
|
||||
flag = f"--{_normalize_flag_name(name)}"
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, bool):
|
||||
return [flag] if value else [f"--no-{_normalize_flag_name(name)}"]
|
||||
if isinstance(value, list):
|
||||
parts: list[str] = []
|
||||
for item in value:
|
||||
parts.extend([flag, str(item)])
|
||||
return parts
|
||||
return [flag, str(value)]
|
||||
|
||||
|
||||
def build_launch_recipe(spec: EngineLaunchSpec, patch: ConfigPatch) -> LaunchRecipe:
    """Merge base engine config with a trial patch into a concrete launch recipe.

    Env precedence: os.environ < spec.base_envs < patch.env_patch.
    Flag precedence: spec.base_flags < patch.flag_patch.
    """
    merged_env = {**os.environ, **spec.base_envs, **patch.env_patch}
    merged_flags = {**spec.base_flags, **patch.flag_patch}

    argv: list[str] = [spec.exec_path, *spec.launch_args]
    for name, value in merged_flags.items():
        argv += _serialize_flag_parts(name, value)

    workdir = str(Path(spec.cwd).expanduser()) if spec.cwd else None
    return LaunchRecipe(
        argv=argv,
        env={str(k): str(v) for k, v in merged_env.items()},
        cwd=workdir,
        base_url=spec.base_url,
        healthcheck_path=spec.healthcheck_path,
        ready_timeout_s=spec.ready_timeout_s,
        request_timeout_s=spec.request_timeout_s,
    )
|
||||
|
||||
|
||||
def shell_join(argv: list[str]) -> str:
    """Quote each argv token for POSIX shells and join them with single spaces."""
    quoted = map(shlex.quote, argv)
    return " ".join(quoted)
|
||||
147
src/aituner/http_client.py
Normal file
147
src/aituner/http_client.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
class HttpClientError(RuntimeError):
    """Raised for HTTP client failures (readiness timeouts and HTTP error responses)."""
|
||||
|
||||
|
||||
def _auth_headers(api_key_env: str | None) -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
def wait_for_server(base_url: str, path: str, timeout_s: float) -> None:
    """Poll ``base_url + path`` until the server answers or *timeout_s* elapses.

    Any response with status below 500 counts as ready — a 4xx still proves
    the server is up and speaking HTTP. Raises HttpClientError on timeout,
    carrying the last observed error in the message.
    """
    deadline = time.monotonic() + timeout_s
    url = f"{base_url.rstrip('/')}{path}"
    last_error = "server_not_ready"
    while time.monotonic() < deadline:
        try:
            request = urllib.request.Request(url=url, headers=_auth_headers(None), method="GET")
            with urllib.request.urlopen(request, timeout=5) as response:
                if 200 <= response.status < 500:
                    return
        except Exception as exc:  # noqa: BLE001
            # Connection refused/reset is expected while the engine boots;
            # remember the reason so the timeout message is diagnosable.
            last_error = str(exc)
        time.sleep(1.0)
    raise HttpClientError(f"Timed out waiting for {url}: {last_error}")
|
||||
|
||||
|
||||
def chat_completion(
    *,
    base_url: str,
    api_key_env: str | None,
    model: str,
    messages: list[dict[str, Any]],
    timeout_s: float,
    system_prompt: str = "",
) -> dict[str, Any]:
    """POST a non-streaming ``/v1/chat/completions`` request and return the parsed JSON body.

    A non-empty *system_prompt* is prepended as a system message.
    Raises HttpClientError when the server answers with an HTTP error.
    """
    all_messages = messages
    if system_prompt:
        all_messages = [{"role": "system", "content": system_prompt}, *messages]
    body = json.dumps({"model": model, "messages": all_messages}).encode("utf-8")
    request = urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        headers=_auth_headers(api_key_env),
        data=body,
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout_s) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise HttpClientError(f"chat_completion failed: {exc.code} {detail}") from exc
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class StreamMetrics:
    """Latency metrics extracted from one streamed chat completion."""

    # Time to first token in milliseconds; None if no token arrived.
    ttft_ms: float | None
    # Mean time per output token after the first (ms); None when undefined.
    tpot_ms: float | None
    # Completion token count; None when zero/unknown.
    completion_tokens: int | None
|
||||
|
||||
|
||||
def stream_chat_completion(
    *,
    base_url: str,
    body: dict[str, Any],
    timeout_s: float,
) -> StreamMetrics:
    """POST *body* to ``/v1/chat/completions`` and measure streaming latency.

    Parses the SSE stream chunk by chunk: the first non-empty content delta
    timestamps TTFT, the last one closes the TPOT window. Token count comes
    from the server's ``usage.completion_tokens`` when present, otherwise
    falls back to the number of content chunks received.

    Note: the caller is responsible for setting ``"stream": true`` (and any
    usage options) in *body* — this function sends it verbatim.
    """
    data = json.dumps(body).encode("utf-8")
    request = urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        headers=_auth_headers(None),
        data=data,
        method="POST",
    )
    start = time.monotonic()
    first_token_at: float | None = None
    last_token_at: float | None = None
    chunk_token_count = 0
    completion_tokens: int | None = None
    try:
        with urllib.request.urlopen(request, timeout=timeout_s) as response:
            for raw in _iter_sse_lines(response):
                if raw == "[DONE]":
                    break
                payload = json.loads(raw)
                if not isinstance(payload, dict):
                    continue
                # Prefer the authoritative server-reported token count.
                usage = payload.get("usage")
                if isinstance(usage, dict):
                    comp = usage.get("completion_tokens")
                    if isinstance(comp, int) and comp >= 0:
                        completion_tokens = comp
                choices = payload.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                delta = choices[0].get("delta", {})
                if not isinstance(delta, dict):
                    continue
                content = delta.get("content")
                if isinstance(content, str) and content:
                    now = time.monotonic()
                    if first_token_at is None:
                        first_token_at = now
                    last_token_at = now
                    chunk_token_count += 1
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise HttpClientError(f"stream_chat_completion failed: {exc.code} {detail}") from exc
    ttft_ms = None if first_token_at is None else (first_token_at - start) * 1000.0
    used_tokens = completion_tokens if completion_tokens is not None else chunk_token_count
    # TPOT is undefined with fewer than two tokens (no inter-token gap).
    if (
        first_token_at is None
        or last_token_at is None
        or used_tokens is None
        or used_tokens <= 1
    ):
        tpot_ms = None
    else:
        tpot_ms = ((last_token_at - first_token_at) / max(used_tokens - 1, 1)) * 1000.0
    return StreamMetrics(
        ttft_ms=ttft_ms,
        tpot_ms=tpot_ms,
        completion_tokens=used_tokens if used_tokens > 0 else None,
    )
|
||||
|
||||
|
||||
def _iter_sse_lines(response: Any) -> Iterable[str]:
|
||||
for raw in response:
|
||||
line = raw.decode("utf-8", errors="replace").strip()
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
payload = line[len("data:") :].strip()
|
||||
if payload:
|
||||
yield payload
|
||||
75
src/aituner/job.py
Normal file
75
src/aituner/job.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .spec import StudySpec, TrialSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class InfraJob:
    """One entry for the GPU-fleet append-only jobs queue."""

    # Unique, immutable job name (see jobs.example.toml contract).
    name: str
    # Number of GPUs requested.
    gpus: int
    # Required GPU model, or None for any.
    gpu_model: str | None
    # Candidate hosts allowed to run the job.
    hosts: list[str]
    # Shell command the fleet runner executes.
    command: str
    # Paths (relative to the synced checkout) to collect after the run.
    artifacts: list[str]
    # Extra environment variables for the command.
    env: dict[str, str]
|
||||
|
||||
|
||||
def _toml_scalar(value: Any) -> str:
|
||||
if isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
if isinstance(value, int):
|
||||
return str(value)
|
||||
text = str(value).replace("\\", "\\\\").replace('"', '\\"')
|
||||
return f'"{text}"'
|
||||
|
||||
|
||||
def _toml_list(values: list[Any]) -> str:
    """Render a Python list as an inline TOML array."""
    rendered = ", ".join(_toml_scalar(item) for item in values)
    return f"[{rendered}]"
|
||||
|
||||
|
||||
def _toml_inline_table(mapping: dict[str, str]) -> str:
    """Render a mapping as an inline TOML table, keys sorted for stable output."""
    entries = ", ".join(
        f"{key} = {_toml_scalar(val)}" for key, val in sorted(mapping.items())
    )
    return "{ " + entries + " }"
|
||||
|
||||
|
||||
def build_trial_job(*, study: StudySpec, trial: TrialSpec, repo_root: Path) -> InfraJob:
    """Translate a materialized trial into a GPU-fleet job entry.

    Paths are rewritten relative to *repo_root* so the fleet runner can
    execute the command from a synced checkout.
    """
    trial_path = Path(trial.artifact_dir) / "trial_spec.json"
    # NOTE(review): relative_to raises ValueError if the artifact dir is not
    # under repo_root — assumed to hold for store-managed trials; confirm.
    rel_trial_path = trial_path.resolve().relative_to(repo_root.resolve())
    rel_trial_dir = Path(trial.artifact_dir).resolve().relative_to(repo_root.resolve())
    command = (
        f"{study.engine.python_executable} -m aituner.cli worker run-trial "
        f"--trial-spec {rel_trial_path}"
    )
    # The package lives under src/ (see pyproject package-dir), so the
    # worker needs it on PYTHONPATH.
    env = {"PYTHONPATH": "src"}
    return InfraJob(
        name=f"{study.study_id}-{trial.trial_id}",
        gpus=study.hardware.gpu_count,
        gpu_model=study.hardware.gpu_model,
        hosts=list(study.hardware.host_candidates),
        command=command,
        artifacts=[str(rel_trial_dir)],
        env=env,
    )
|
||||
|
||||
|
||||
def append_job(jobs_path: Path, job: InfraJob) -> None:
    """Append *job* as a ``[[jobs]]`` table to the TOML queue at *jobs_path*.

    Creates the file (and parent directories) on first use, writing the
    ``version`` header exactly once. Optional fields are emitted only when
    present.
    """
    jobs_path.parent.mkdir(parents=True, exist_ok=True)
    with jobs_path.open("a", encoding="utf-8") as handle:
        if jobs_path.stat().st_size == 0:
            handle.write("version = 1\n")
        lines = [
            "",
            "[[jobs]]",
            f"name = {_toml_scalar(job.name)}",
            f"gpus = {job.gpus}",
        ]
        if job.gpu_model:
            lines.append(f"gpu_model = {_toml_scalar(job.gpu_model)}")
        if job.hosts:
            lines.append(f"hosts = {_toml_list(job.hosts)}")
        lines.append(f"command = {_toml_scalar(job.command)}")
        if job.artifacts:
            lines.append(f"artifacts = {_toml_list(job.artifacts)}")
        if job.env:
            lines.append(f"env = {_toml_inline_table(job.env)}")
        handle.write("\n".join(lines) + "\n")
|
||||
144
src/aituner/llm.py
Normal file
144
src/aituner/llm.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .http_client import chat_completion
|
||||
from .spec import LLMPolicySpec, Proposal, SpecError, StudySpec, StudyState
|
||||
|
||||
|
||||
def build_prompt(
    *,
    study: StudySpec,
    window_summary: dict[str, Any],
    state: StudyState,
    capability_profile: dict[str, Any] | None,
) -> str:
    """Render the tuning prompt sent to the proposal LLM.

    The prompt includes the study stack, the trace-window summary, the SLO,
    the optional capability profile, and the most recent trial history,
    bounded by ``study.llm.max_history_trials``.
    """
    limit = study.llm.max_history_trials
    # Guard the slice: a plain [-limit:] returns the FULL list when limit == 0.
    recent_trials = state.trials[-limit:] if limit > 0 else []
    history = [
        {
            "trial_id": trial.trial_id,
            "status": trial.status,
            "best_sampling_u": trial.best_sampling_u,
            "best_request_rate": trial.best_request_rate,
            "best_pass_rate": trial.best_pass_rate,
            "diagnosis": trial.diagnosis,
        }
        for trial in recent_trials
    ]
    sections = [
        "You are tuning an OpenAI-compatible serving engine.",
        "Return exactly one JSON object with keys: observation, diagnosis, config_patch, expected_effects, why_not_previous_failures.",
        "config_patch must contain env_patch and flag_patch.",
        "Only use allowed tunable env keys and allowed tunable flag keys.",
        "",
        "Study stack:",
        json.dumps(
            {
                "study_id": study.study_id,
                "hardware": {
                    "gpu_count": study.hardware.gpu_count,
                    "gpu_model": study.hardware.gpu_model,
                },
                "model": {
                    "model_id": study.model.model_id,
                    "served_model_name": study.model.served_model_name,
                },
                "engine": {
                    "engine_name": study.engine.engine_name,
                    "engine_version": study.engine.engine_version,
                    "base_flags": study.engine.base_flags,
                    "base_envs": study.engine.base_envs,
                    "allowed_flag_keys": study.engine.tunable_flags,
                    "allowed_env_keys": study.engine.tunable_envs,
                },
            },
            ensure_ascii=False,
            indent=2,
        ),
        "",
        "Window summary:",
        json.dumps(window_summary, ensure_ascii=False, indent=2),
        "",
        "SLO:",
        json.dumps(
            {
                "target_pass_rate": study.slo.target_pass_rate,
                "ttft_rule": study.slo.ttft_rule,
                "tpot_rule": study.slo.tpot_rule,
            },
            # The rule objects are not JSON-native; serialize via their
            # attribute dicts.
            default=lambda value: value.__dict__,
            ensure_ascii=False,
            indent=2,
        ),
        "",
        "Capability profile:",
        json.dumps(capability_profile or {}, ensure_ascii=False, indent=2),
        "",
        "Trial history:",
        json.dumps(history, ensure_ascii=False, indent=2),
        "",
        "The proposal should improve the maximum feasible sampling_u under the 95%+ SLO target.",
    ]
    return "\n".join(sections)
|
||||
|
||||
|
||||
def load_capability_profile(study: StudySpec, *, study_spec_path: Path) -> dict[str, Any] | None:
    """Load the optional capability-profile JSON for *study*.

    Relative paths resolve against the directory of *study_spec_path*.
    Returns None when no profile path is configured.
    """
    configured = study.capability_profile_path
    if not configured:
        return None
    candidate = Path(configured)
    if not candidate.is_absolute():
        candidate = (study_spec_path.parent / candidate).resolve()
    return json.loads(candidate.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def validate_proposal(proposal: Proposal, study: StudySpec) -> Proposal:
    """Reject proposals touching env/flag keys outside the study's allow-lists."""
    allowed_envs = set(study.engine.tunable_envs)
    allowed_flags = set(study.engine.tunable_flags)
    unknown_envs = sorted(set(proposal.config_patch.env_patch) - allowed_envs)
    unknown_flags = sorted(set(proposal.config_patch.flag_patch) - allowed_flags)
    if unknown_envs:
        raise SpecError(f"Proposal uses unsupported env keys: {', '.join(unknown_envs)}")
    if unknown_flags:
        raise SpecError(f"Proposal uses unsupported flag keys: {', '.join(unknown_flags)}")
    return proposal
|
||||
|
||||
|
||||
def parse_proposal_text(text: str, study: StudySpec) -> Proposal:
    """Parse raw LLM output as a JSON proposal and validate it against *study*."""
    proposal = Proposal.from_dict(json.loads(text))
    return validate_proposal(proposal, study)
|
||||
|
||||
|
||||
def call_llm_for_proposal(
    *,
    policy: LLMPolicySpec,
    prompt: str,
) -> str:
    """Send *prompt* to the configured LLM endpoint and return the reply text.

    Handles both the plain-string message content and the content-parts
    list form, concatenating the ``text`` fields of the latter.

    Raises:
        RuntimeError: when no endpoint is configured or the response is not
            a usable chat completion.
    """
    if policy.endpoint is None:
        raise RuntimeError("study.llm.endpoint is not configured")
    response = chat_completion(
        base_url=policy.endpoint.base_url,
        api_key_env=policy.endpoint.api_key_env,
        model=policy.endpoint.model,
        messages=[{"role": "user", "content": prompt}],
        timeout_s=policy.endpoint.timeout_s,
        system_prompt=policy.system_prompt,
    )
    choices = response.get("choices")
    if not isinstance(choices, list) or not choices:
        raise RuntimeError("LLM response does not contain choices")
    message = choices[0].get("message", {})
    if not isinstance(message, dict):
        raise RuntimeError("LLM response does not contain a valid message")
    content = message.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Content-parts form: join the text fragments in order.
        return "".join(
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        )
    raise RuntimeError("LLM response content is empty")
|
||||
58
src/aituner/search.py
Normal file
58
src/aituner/search.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ThresholdProbe(Generic[T]):
|
||||
threshold: float
|
||||
feasible: bool
|
||||
payload: T
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ThresholdSearchResult(Generic[T]):
|
||||
best_threshold: float
|
||||
best_feasible_payload: T | None
|
||||
probes: list[ThresholdProbe[T]]
|
||||
|
||||
|
||||
def binary_search_max_feasible(
    *,
    low: float,
    high: float,
    tolerance: float,
    max_probes: int,
    evaluator: Callable[[float], ThresholdProbe[T]],
) -> ThresholdSearchResult[T]:
    """Binary-search the largest feasible threshold in [low, high].

    Bisects until the bracket narrows to `tolerance` or `max_probes`
    evaluations have been spent.  If no probe is feasible, the result reports
    `low` with a None payload.
    """
    history: list[ThresholdProbe[T]] = []
    # Memoize evaluations in case rounding ever lands on the same midpoint.
    seen: dict[float, ThresholdProbe[T]] = {}
    best_threshold = low
    best_payload: T | None = None
    lo, hi = low, high
    for _ in range(max_probes):
        if hi - lo <= tolerance:
            break
        # Round to dodge float noise in the midpoint.
        midpoint = round((lo + hi) / 2.0, 12)
        probe = seen.get(midpoint)
        if probe is None:
            probe = evaluator(midpoint)
            seen[midpoint] = probe
        history.append(probe)
        if not probe.feasible:
            hi = midpoint
            continue
        if midpoint >= best_threshold:
            best_threshold = midpoint
            best_payload = probe.payload
        lo = midpoint
    return ThresholdSearchResult(
        best_threshold=best_threshold,
        best_feasible_payload=best_payload,
        probes=history,
    )
|
||||
80
src/aituner/slo.py
Normal file
80
src/aituner/slo.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .spec import SloSpec, ThresholdRule
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RequestOutcome:
    """Raw measurements for one replayed request."""

    request_id: str
    # False when the request errored; `error` then carries the reason.
    success: bool
    # Time to first token in milliseconds (None when not measured).
    ttft_ms: float | None
    # Time per output token in milliseconds (None when not measured).
    tpot_ms: float | None
    prompt_tokens: int | None
    completion_tokens: int | None
    error: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RequestEvaluation:
    """SLO verdict for one request."""

    request_id: str
    # True only when every configured rule passed.
    passed: bool
    # Human-readable failure reasons (empty when passed).
    reasons: list[str]
|
||||
|
||||
|
||||
def _rule_threshold_ms(rule: ThresholdRule, prompt_tokens: int | None) -> float:
|
||||
if rule.kind == "fixed_ms":
|
||||
assert rule.threshold_ms is not None
|
||||
return rule.threshold_ms
|
||||
if rule.kind != "step_ms":
|
||||
raise ValueError(f"Unsupported threshold rule: {rule.kind}")
|
||||
prompt = float(prompt_tokens or 0)
|
||||
for bucket in rule.buckets:
|
||||
ceiling = bucket.get("max_input_tokens")
|
||||
if ceiling is None or prompt <= ceiling:
|
||||
return float(bucket["threshold_ms"])
|
||||
return float(rule.buckets[-1]["threshold_ms"])
|
||||
|
||||
|
||||
def evaluate_request(outcome: RequestOutcome, slo: SloSpec) -> RequestEvaluation:
    """Check one request outcome against the study SLOs.

    Failed requests short-circuit with their error as the reason.  For each
    configured rule, a missing measurement or a value above the resolved
    threshold adds a failure reason.
    """
    reasons: list[str] = []
    if not outcome.success:
        reasons.append(outcome.error or "request_failed")
        return RequestEvaluation(request_id=outcome.request_id, passed=False, reasons=reasons)
    checks = (
        (slo.ttft_rule, outcome.ttft_ms, "ttft"),
        (slo.tpot_rule, outcome.tpot_ms, "tpot"),
    )
    for rule, measured, label in checks:
        if rule is None:
            continue
        if measured is None:
            reasons.append(f"{label}_missing")
            continue
        limit = _rule_threshold_ms(rule, outcome.prompt_tokens)
        if measured > limit:
            reasons.append(f"{label}_ms>{limit}")
    return RequestEvaluation(
        request_id=outcome.request_id,
        passed=not reasons,
        reasons=reasons,
    )
|
||||
|
||||
|
||||
def summarize_evaluations(
    outcomes: list[RequestOutcome], slo: SloSpec
) -> tuple[list[RequestEvaluation], dict[str, Any]]:
    """Evaluate every outcome and aggregate pass-rate statistics.

    Returns the per-request evaluations plus a summary dict whose `feasible`
    flag is True when the pass rate reaches the SLO target.  An empty input
    yields a 0.0 pass rate.
    """
    evaluations = [evaluate_request(outcome, slo) for outcome in outcomes]
    total = len(evaluations)
    passed = sum(evaluation.passed for evaluation in evaluations)
    pass_rate = passed / total if total else 0.0
    summary = {
        "request_count": total,
        "slo_pass_count": passed,
        "slo_pass_rate": pass_rate,
        "target_pass_rate": slo.target_pass_rate,
        "feasible": pass_rate >= slo.target_pass_rate,
    }
    return evaluations, summary
|
||||
440
src/aituner/spec.py
Normal file
440
src/aituner/spec.py
Normal file
@@ -0,0 +1,440 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tomllib
|
||||
from dataclasses import asdict, dataclass, field, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
class SpecError(ValueError):
    """Raised when a structured spec (study/trial JSON or TOML) is invalid."""
|
||||
|
||||
|
||||
def _require_mapping(value: Any, *, context: str) -> Mapping[str, Any]:
|
||||
if not isinstance(value, Mapping):
|
||||
raise SpecError(f"{context} must be an object.")
|
||||
return value
|
||||
|
||||
|
||||
def _require_str(value: Any, *, context: str) -> str:
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
raise SpecError(f"{context} must be a non-empty string.")
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _require_float(value: Any, *, context: str) -> float:
|
||||
if isinstance(value, bool) or not isinstance(value, (int, float)):
|
||||
raise SpecError(f"{context} must be numeric.")
|
||||
return float(value)
|
||||
|
||||
|
||||
def _require_int(value: Any, *, context: str) -> int:
|
||||
if isinstance(value, bool) or not isinstance(value, int):
|
||||
raise SpecError(f"{context} must be an integer.")
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_str_map(value: Any, *, context: str) -> dict[str, str]:
    """Coerce an optional mapping into str->str (absent/None becomes empty)."""
    entries = _require_mapping(value or {}, context=context)
    return {str(k): str(v) for k, v in entries.items()}
|
||||
|
||||
|
||||
def _coerce_any_map(value: Any, *, context: str) -> dict[str, Any]:
    """Coerce an optional mapping into a str-keyed dict (absent/None becomes empty)."""
    entries = _require_mapping(value or {}, context=context)
    return {str(k): v for k, v in entries.items()}
|
||||
|
||||
|
||||
def _coerce_str_list(value: Any, *, context: str) -> list[str]:
    """Coerce an optional list into stripped non-empty strings.

    None becomes []; a non-list raises SpecError; each element must satisfy
    `_require_str`.
    """
    if value is None:
        return []
    if not isinstance(value, list):
        raise SpecError(f"{context} must be a list.")
    return [_require_str(item, context=context) for item in value]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class HardwareSpec:
    """GPU resources a study may use."""

    gpu_count: int
    gpu_model: str | None = None
    # Hosts the study may run on (may be empty).
    host_candidates: list[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "HardwareSpec":
        """Validate and build a HardwareSpec from a raw mapping."""
        return cls(
            gpu_count=_require_int(data.get("gpu_count"), context="hardware.gpu_count"),
            gpu_model=str(data["gpu_model"]).strip() if data.get("gpu_model") else None,
            host_candidates=_coerce_str_list(
                data.get("host_candidates"), context="hardware.host_candidates"
            ),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelSpec:
    """Identifies the model under test."""

    # Short internal identifier for the model.
    model_id: str
    # Name clients use to address the model on the serving engine.
    served_model_name: str

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ModelSpec":
        """Validate and build a ModelSpec from a raw mapping."""
        return cls(
            model_id=_require_str(data.get("model_id"), context="model.model_id"),
            served_model_name=_require_str(
                data.get("served_model_name"), context="model.served_model_name"
            ),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EngineLaunchSpec:
    """Everything needed to launch and talk to the serving engine."""

    engine_name: str
    engine_version: str | None
    # Executable to launch (e.g. the vllm binary).
    exec_path: str
    cwd: str | None
    host: str
    port: int
    # Seconds to wait for the engine to become healthy after launch.
    ready_timeout_s: float
    # Per-request timeout, seconds.
    request_timeout_s: float
    # Path polled for readiness (default "/v1/models").
    healthcheck_path: str
    # Positional CLI arguments (e.g. ["serve", "/path/to/model"]).
    launch_args: list[str]
    # Environment/flag baseline every trial starts from.
    base_envs: dict[str, str]
    base_flags: dict[str, Any]
    # Names the tuner is allowed to patch per trial.
    tunable_envs: list[str]
    tunable_flags: list[str]
    python_executable: str = "python3"

    @property
    def base_url(self) -> str:
        """HTTP root the engine serves on."""
        return f"http://{self.host}:{self.port}"

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "EngineLaunchSpec":
        """Validate and build an EngineLaunchSpec, applying defaults."""
        return cls(
            engine_name=_require_str(data.get("engine_name"), context="engine.engine_name"),
            engine_version=str(data["engine_version"]).strip()
            if data.get("engine_version")
            else None,
            exec_path=_require_str(data.get("exec_path"), context="engine.exec_path"),
            cwd=str(data["cwd"]).strip() if data.get("cwd") else None,
            host=str(data.get("host") or "127.0.0.1").strip(),
            port=_require_int(data.get("port", 8000), context="engine.port"),
            ready_timeout_s=_require_float(
                data.get("ready_timeout_s", 600.0), context="engine.ready_timeout_s"
            ),
            request_timeout_s=_require_float(
                data.get("request_timeout_s", 600.0),
                context="engine.request_timeout_s",
            ),
            healthcheck_path=str(data.get("healthcheck_path") or "/v1/models").strip(),
            launch_args=_coerce_str_list(data.get("launch_args"), context="engine.launch_args"),
            base_envs=_coerce_str_map(data.get("base_envs"), context="engine.base_envs"),
            base_flags=_coerce_any_map(data.get("base_flags"), context="engine.base_flags"),
            tunable_envs=_coerce_str_list(
                data.get("tunable_envs"), context="engine.tunable_envs"
            ),
            tunable_flags=_coerce_str_list(
                data.get("tunable_flags"), context="engine.tunable_flags"
            ),
            python_executable=str(data.get("python_executable") or "python3").strip(),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TraceSpec:
    """Selects the trace window to replay and how to read its rows."""

    # Path to the windows index file.
    windows_path: str
    # Which window from the index to replay.
    window_id: str
    # Optional trace file overriding the one recorded in the index.
    trace_file_override: str | None
    # Row field holding the sampling value used for thresholding.
    u_field: str
    # Row field holding the arrival timestamp.
    timestamp_field: str
    # Upper bound on concurrent in-flight requests during replay.
    max_concurrency: int
    # Optional cap on requests replayed per probe.
    max_requests_per_probe: int | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "TraceSpec":
        """Validate and build a TraceSpec, applying defaults."""
        max_requests = data.get("max_requests_per_probe")
        return cls(
            windows_path=_require_str(data.get("windows_path"), context="trace.windows_path"),
            window_id=_require_str(data.get("window_id"), context="trace.window_id"),
            trace_file_override=str(data["trace_file_override"]).strip()
            if data.get("trace_file_override")
            else None,
            u_field=str(data.get("u_field") or "sampling_u").strip(),
            timestamp_field=str(data.get("timestamp_field") or "timestamp").strip(),
            max_concurrency=_require_int(
                data.get("max_concurrency", 64), context="trace.max_concurrency"
            ),
            max_requests_per_probe=int(max_requests) if max_requests is not None else None,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ThresholdRule:
    """Latency limit, either fixed or stepped by prompt size.

    kind == "fixed_ms": `threshold_ms` holds the single limit.
    kind == "step_ms": `buckets` hold {max_input_tokens?, threshold_ms}
    entries matched in order; a bucket without `max_input_tokens` matches
    any prompt size.
    """

    kind: str
    threshold_ms: float | None = None
    buckets: list[dict[str, float]] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "ThresholdRule":
        """Validate and build a ThresholdRule; raises SpecError on bad shapes."""
        kind = _require_str(data.get("kind"), context=f"{context}.kind")
        if kind == "fixed_ms":
            return cls(
                kind=kind,
                threshold_ms=_require_float(
                    data.get("threshold_ms"), context=f"{context}.threshold_ms"
                ),
            )
        if kind == "step_ms":
            raw = data.get("buckets")
            if not isinstance(raw, list) or not raw:
                raise SpecError(f"{context}.buckets must be a non-empty list.")
            buckets: list[dict[str, float]] = []
            for idx, item in enumerate(raw):
                mapping = _require_mapping(item, context=f"{context}.buckets[{idx}]")
                bucket: dict[str, float] = {
                    "threshold_ms": _require_float(
                        mapping.get("threshold_ms"),
                        context=f"{context}.buckets[{idx}].threshold_ms",
                    )
                }
                # max_input_tokens is optional; omitting it means "matches any size".
                if "max_input_tokens" in mapping and mapping["max_input_tokens"] is not None:
                    bucket["max_input_tokens"] = _require_float(
                        mapping["max_input_tokens"],
                        context=f"{context}.buckets[{idx}].max_input_tokens",
                    )
                buckets.append(bucket)
            return cls(kind=kind, buckets=buckets)
        raise SpecError(f"Unsupported threshold rule kind: {kind}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SloSpec:
    """Latency SLOs plus the pass rate a probe must reach to be feasible."""

    # Minimum fraction of requests that must pass every configured rule.
    target_pass_rate: float
    # Time-to-first-token rule; None disables the check.
    ttft_rule: ThresholdRule | None
    # Time-per-output-token rule; None disables the check.
    tpot_rule: ThresholdRule | None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "SloSpec":
        """Validate and build an SloSpec (target_pass_rate defaults to 0.95)."""
        ttft_rule = (
            ThresholdRule.from_dict(
                _require_mapping(data["ttft_rule"], context="slo.ttft_rule"),
                context="slo.ttft_rule",
            )
            if data.get("ttft_rule")
            else None
        )
        tpot_rule = (
            ThresholdRule.from_dict(
                _require_mapping(data["tpot_rule"], context="slo.tpot_rule"),
                context="slo.tpot_rule",
            )
            if data.get("tpot_rule")
            else None
        )
        return cls(
            target_pass_rate=_require_float(
                data.get("target_pass_rate", 0.95), context="slo.target_pass_rate"
            ),
            ttft_rule=ttft_rule,
            tpot_rule=tpot_rule,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SamplingSearchSpec:
    """Bounds and budget for the sampling-threshold binary search."""

    low: float
    high: float
    # Search stops once the bracket is at most this wide.
    tolerance: float
    # Hard cap on probe evaluations.
    max_probes: int
    sample_seed: int

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "SamplingSearchSpec":
        """Validate and build a SamplingSearchSpec, applying defaults."""
        return cls(
            low=_require_float(data.get("low", 0.0), context="search.low"),
            high=_require_float(data.get("high", 1.0), context="search.high"),
            tolerance=_require_float(
                data.get("tolerance", 0.01), context="search.tolerance"
            ),
            max_probes=_require_int(data.get("max_probes", 8), context="search.max_probes"),
            sample_seed=_require_int(
                data.get("sample_seed", 20260325), context="search.sample_seed"
            ),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMEndpointSpec:
    """Chat-completion endpoint used to generate tuning proposals."""

    base_url: str
    model: str
    # Environment variable the API key is read from.
    api_key_env: str = "OPENAI_API_KEY"
    timeout_s: float = 120.0

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "LLMEndpointSpec":
        """Validate and build an LLMEndpointSpec, applying defaults."""
        return cls(
            base_url=_require_str(data.get("base_url"), context="llm.endpoint.base_url"),
            model=_require_str(data.get("model"), context="llm.endpoint.model"),
            api_key_env=str(data.get("api_key_env") or "OPENAI_API_KEY").strip(),
            timeout_s=_require_float(
                data.get("timeout_s", 120.0), context="llm.endpoint.timeout_s"
            ),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LLMPolicySpec:
    """Policy for asking an LLM to propose the next configuration."""

    # None means no endpoint is configured; proposal calls then fail fast.
    endpoint: LLMEndpointSpec | None
    system_prompt: str
    # Cap on previous trials retained (presumably for prompt history —
    # confirm against the orchestrator's prompt builder).
    max_history_trials: int

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None) -> "LLMPolicySpec":
        """Validate and build an LLMPolicySpec; a missing section yields defaults."""
        payload = _require_mapping(data or {}, context="llm")
        endpoint = (
            LLMEndpointSpec.from_dict(
                _require_mapping(payload["endpoint"], context="llm.endpoint")
            )
            if payload.get("endpoint")
            else None
        )
        return cls(
            endpoint=endpoint,
            system_prompt=str(payload.get("system_prompt") or "").strip(),
            max_history_trials=_require_int(
                payload.get("max_history_trials", 8), context="llm.max_history_trials"
            ),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class StudySpec:
    """Top-level, validated description of one tuning study."""

    study_id: str
    hardware: HardwareSpec
    model: ModelSpec
    engine: EngineLaunchSpec
    trace: TraceSpec
    slo: SloSpec
    search: SamplingSearchSpec
    llm: LLMPolicySpec
    # Optional path to a capability profile file.
    capability_profile_path: str | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "StudySpec":
        """Validate and build a StudySpec from a parsed JSON/TOML document."""
        return cls(
            study_id=_require_str(data.get("study_id"), context="study_id"),
            hardware=HardwareSpec.from_dict(
                _require_mapping(data.get("hardware"), context="hardware")
            ),
            model=ModelSpec.from_dict(_require_mapping(data.get("model"), context="model")),
            engine=EngineLaunchSpec.from_dict(
                _require_mapping(data.get("engine"), context="engine")
            ),
            trace=TraceSpec.from_dict(_require_mapping(data.get("trace"), context="trace")),
            slo=SloSpec.from_dict(_require_mapping(data.get("slo"), context="slo")),
            search=SamplingSearchSpec.from_dict(
                _require_mapping(data.get("search"), context="search")
            ),
            llm=LLMPolicySpec.from_dict(data.get("llm")),
            capability_profile_path=str(data["capability_profile_path"]).strip()
            if data.get("capability_profile_path")
            else None,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConfigPatch:
    """Env-var and CLI-flag overrides applied on top of the engine baseline."""

    env_patch: dict[str, str] = field(default_factory=dict)
    flag_patch: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ConfigPatch":
        """Coerce raw mappings into a ConfigPatch (missing keys become empty)."""
        return cls(
            env_patch=_coerce_str_map(data.get("env_patch"), context="config_patch.env_patch"),
            flag_patch=_coerce_any_map(
                data.get("flag_patch"), context="config_patch.flag_patch"
            ),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Proposal:
    """One LLM-produced tuning suggestion plus its rationale."""

    observation: str
    diagnosis: str
    # The configuration change this proposal wants the next trial to apply.
    config_patch: ConfigPatch
    expected_effects: list[str]
    why_not_previous_failures: str = ""

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "Proposal":
        """Validate and build a Proposal from parsed LLM JSON."""
        return cls(
            observation=_require_str(data.get("observation"), context="proposal.observation"),
            diagnosis=_require_str(data.get("diagnosis"), context="proposal.diagnosis"),
            config_patch=ConfigPatch.from_dict(
                _require_mapping(data.get("config_patch"), context="proposal.config_patch")
            ),
            expected_effects=_coerce_str_list(
                data.get("expected_effects"), context="proposal.expected_effects"
            ),
            why_not_previous_failures=str(data.get("why_not_previous_failures") or "").strip(),
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TrialSpec:
    """Fully-resolved description of one trial handed to the worker."""

    study_id: str
    trial_id: str
    # Overrides this trial applies on top of the engine baseline.
    config_patch: ConfigPatch
    search: SamplingSearchSpec
    # Path of the study_spec.source pointer file written by the store.
    study_spec_path: str
    artifact_dir: str
    probe_log_path: str
    engine_log_path: str
    result_path: str
|
||||
|
||||
|
||||
@dataclass
class TrialSummary:
    """Mutable per-trial rollup stored inside the study state."""

    trial_id: str
    # Lifecycle marker, e.g. "queued", "completed", or "unknown".
    status: str
    best_sampling_u: float | None = None
    best_request_rate: float | None = None
    best_pass_rate: float | None = None
    result_path: str | None = None
    diagnosis: str = ""
|
||||
|
||||
|
||||
@dataclass
class StudyState:
    """Mutable study-wide progress persisted as state.json."""

    study_id: str
    best_trial_id: str | None = None
    best_request_rate: float | None = None
    # 1-based counter used to mint the next trial id.
    next_trial_index: int = 1
    trials: list[TrialSummary] = field(default_factory=list)
|
||||
|
||||
|
||||
def to_jsonable(value: Any) -> Any:
    """Recursively convert dataclasses/dicts/lists into JSON-ready structures.

    Dataclasses become dicts via `asdict`, dict keys are stringified, and
    everything else is returned unchanged.
    """
    if is_dataclass(value):
        value = asdict(value)
    if isinstance(value, dict):
        return {str(k): to_jsonable(v) for k, v in value.items()}
    if isinstance(value, list):
        return [to_jsonable(v) for v in value]
    return value
|
||||
|
||||
|
||||
def load_structured_file(path: Path) -> Mapping[str, Any]:
|
||||
suffix = path.suffix.lower()
|
||||
if suffix == ".json":
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
elif suffix in {".toml", ".tml"}:
|
||||
payload = tomllib.loads(path.read_text(encoding="utf-8"))
|
||||
else:
|
||||
raise SpecError(f"Unsupported spec file type: {path}")
|
||||
return _require_mapping(payload, context=str(path))
|
||||
|
||||
|
||||
def load_study_spec(path: Path) -> StudySpec:
    """Load a study spec file (JSON/TOML) and validate it into a StudySpec."""
    document = load_structured_file(path)
    return StudySpec.from_dict(document)
|
||||
118
src/aituner/store.py
Normal file
118
src/aituner/store.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .spec import Proposal, StudySpec, StudyState, TrialSpec, TrialSummary, to_jsonable
|
||||
|
||||
|
||||
class StudyStore:
    """Filesystem-backed layout for study artifacts under `<root>/<study_id>/`.

    Per study the store keeps prompts/, proposals/, trials/, results/ plus a
    study_spec.source pointer, a study_spec.snapshot.json freeze, and the
    mutable state.json.
    """

    def __init__(self, root: Path | None = None):
        # Default root is ./.aituner/studies relative to the current directory.
        base = root or Path(".aituner") / "studies"
        self.root = base.resolve()

    def study_root(self, study_id: str) -> Path:
        """Directory holding every artifact of one study."""
        return self.root / study_id

    def init_study(self, *, spec_path: Path, study: StudySpec) -> Path:
        """Create the study directory tree and persist the spec snapshot.

        Safe to re-run: existing state.json is never clobbered.
        """
        root = self.study_root(study.study_id)
        for rel in ("prompts", "proposals", "trials", "results"):
            (root / rel).mkdir(parents=True, exist_ok=True)
        # Remember where the spec came from, and freeze a JSON snapshot of it.
        (root / "study_spec.source").write_text(str(spec_path.resolve()) + "\n", encoding="utf-8")
        self.write_json(root / "study_spec.snapshot.json", to_jsonable(study))
        if not (root / "state.json").exists():
            self.write_json(root / "state.json", to_jsonable(StudyState(study_id=study.study_id)))
        return root

    def load_state(self, study_id: str) -> StudyState:
        """Read state.json back into a StudyState."""
        payload = json.loads((self.study_root(study_id) / "state.json").read_text(encoding="utf-8"))
        trials = [TrialSummary(**item) for item in payload.get("trials", [])]
        return StudyState(
            study_id=str(payload["study_id"]),
            best_trial_id=payload.get("best_trial_id"),
            best_request_rate=payload.get("best_request_rate"),
            next_trial_index=int(payload.get("next_trial_index", 1)),
            trials=trials,
        )

    def save_state(self, state: StudyState) -> None:
        """Persist the study state as state.json."""
        self.write_json(self.study_root(state.study_id) / "state.json", to_jsonable(state))

    def write_prompt(self, study_id: str, prompt_name: str, prompt_text: str) -> Path:
        """Store one prompt as prompts/<name>.txt and return its path."""
        path = self.study_root(study_id) / "prompts" / f"{prompt_name}.txt"
        path.write_text(prompt_text, encoding="utf-8")
        return path

    def write_proposal(self, study_id: str, proposal_name: str, proposal: Proposal) -> Path:
        """Store one parsed proposal as proposals/<name>.json and return its path."""
        path = self.study_root(study_id) / "proposals" / f"{proposal_name}.json"
        self.write_json(path, to_jsonable(proposal))
        return path

    def materialize_trial(
        self,
        *,
        study: StudySpec,
        state: StudyState,
        proposal: Proposal,
    ) -> tuple[TrialSpec, StudyState]:
        """Allocate the next trial id, write its spec, and advance the state.

        Returns the TrialSpec plus the updated (and already saved) state.
        """
        trial_id = f"trial-{state.next_trial_index:04d}"
        trial_root = self.study_root(study.study_id) / "trials" / trial_id
        trial_root.mkdir(parents=True, exist_ok=True)
        spec = TrialSpec(
            study_id=study.study_id,
            trial_id=trial_id,
            config_patch=proposal.config_patch,
            search=study.search,
            study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()),
            artifact_dir=str(trial_root),
            probe_log_path=str(trial_root / "probe_history.json"),
            engine_log_path=str(trial_root / "engine.log"),
            result_path=str(trial_root / "result.json"),
        )
        self.write_json(trial_root / "trial_spec.json", to_jsonable(spec))
        # NOTE(review): dataclasses.replace() copies shallowly, so next_state
        # shares the `trials` list with `state`; the append below is visible
        # through both objects — confirm no caller relies on `state` staying
        # unchanged.
        next_state = replace(state, next_trial_index=state.next_trial_index + 1)
        next_state.trials.append(
            TrialSummary(trial_id=trial_id, status="queued", diagnosis=proposal.diagnosis)
        )
        self.save_state(next_state)
        return spec, next_state

    def ingest_trial_results(self, study_id: str) -> StudyState:
        """Scan trials/*/result.json, fold results into the state, and save it.

        The best trial is tracked by highest best_request_rate.
        """
        state = self.load_state(study_id)
        by_id = {item.trial_id: item for item in state.trials}
        trials_dir = self.study_root(study_id) / "trials"
        best_trial_id = state.best_trial_id
        best_rate = state.best_request_rate
        for trial_dir in sorted(trials_dir.glob("trial-*")):
            result_path = trial_dir / "result.json"
            if not result_path.exists():
                # Trial has not finished (or failed before writing results).
                continue
            payload = json.loads(result_path.read_text(encoding="utf-8"))
            trial_id = str(payload["trial_id"])
            summary = by_id.get(trial_id)
            if summary is None:
                # Result on disk for a trial the state never recorded.
                summary = TrialSummary(trial_id=trial_id, status="unknown")
                state.trials.append(summary)
                by_id[trial_id] = summary
            summary.status = str(payload.get("status") or "completed")
            summary.best_sampling_u = payload.get("best_sampling_u")
            summary.best_request_rate = payload.get("best_request_rate")
            summary.best_pass_rate = payload.get("best_pass_rate")
            summary.result_path = str(result_path)
            if (
                isinstance(summary.best_request_rate, (int, float))
                and (best_rate is None or summary.best_request_rate > best_rate)
            ):
                best_rate = float(summary.best_request_rate)
                best_trial_id = trial_id
        state.best_request_rate = best_rate
        state.best_trial_id = best_trial_id
        self.save_state(state)
        return state

    @staticmethod
    def write_json(path: Path, payload: Any) -> None:
        """Write pretty-printed UTF-8 JSON, creating parent directories."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
178
src/aituner/trace.py
Normal file
178
src/aituner/trace.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
from .spec import StudySpec
|
||||
|
||||
|
||||
class TraceError(ValueError):
    """Raised when trace assets (window index or trace rows) are invalid."""
|
||||
|
||||
|
||||
def _percentile(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
ordered = sorted(values)
|
||||
idx = min(len(ordered) - 1, max(0, math.ceil((p / 100.0) * len(ordered)) - 1))
|
||||
return ordered[idx]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WindowRecord:
    """One entry of the windows index, with its trace path resolved."""

    window_id: str
    # Resolved path of the JSONL trace file for this window.
    trace_path: Path
    # Workload kind recorded in the index (defaults to "chat").
    trace_type: str
    # Window bounds as recorded in the index (0.0 when absent).
    window_start: float
    window_end: float
    # Raw index entry, keys coerced to str.
    source_payload: dict[str, Any]
||||
|
||||
|
||||
@dataclass(frozen=True)
class TraceRequest:
    """One replayable request parsed from a trace row."""

    row_id: str
    # Arrival timestamp used to order/schedule the request.
    arrival_s: float
    # Sampling value compared against the search threshold.
    sampling_u: float
    # Ready-to-send chat-completion request body.
    body: dict[str, Any]
    # Token counts taken from the trace row, when present.
    prompt_tokens_hint: int | None
    completion_tokens_hint: int | None
||||
|
||||
|
||||
def resolve_window_record(study: StudySpec, *, study_spec_path: Path) -> WindowRecord:
    """Find the study's window in the windows index and resolve its trace path.

    Relative index paths resolve against the study spec's directory; relative
    trace paths resolve against the index file's directory.  Raises
    TraceError when the index is malformed or the window id is absent.
    """
    windows_path = Path(study.trace.windows_path)
    if not windows_path.is_absolute():
        windows_path = (study_spec_path.parent / windows_path).resolve()
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    # Accept either {"windows": [...]} or a bare top-level list.
    windows = payload["windows"] if isinstance(payload, Mapping) and "windows" in payload else payload
    if not isinstance(windows, list):
        raise TraceError(f"windows payload must contain a list: {windows_path}")
    for item in windows:
        if not isinstance(item, Mapping):
            continue
        if str(item.get("window_id") or "").strip() != study.trace.window_id:
            continue
        # The study may override the trace file recorded in the index.
        trace_file = study.trace.trace_file_override or str(item.get("trace_file") or "").strip()
        if not trace_file:
            raise TraceError(f"window {study.trace.window_id} does not define trace_file")
        trace_path = Path(trace_file)
        if not trace_path.is_absolute():
            trace_path = (windows_path.parent / trace_path).resolve()
        return WindowRecord(
            window_id=study.trace.window_id,
            trace_path=trace_path,
            trace_type=str(item.get("trace_type") or "chat").strip(),
            window_start=float(item.get("window_start") or 0.0),
            window_end=float(item.get("window_end") or 0.0),
            source_payload={str(key): value for key, value in item.items()},
        )
    raise TraceError(f"window_id not found: {study.trace.window_id}")
|
||||
|
||||
|
||||
def _coerce_messages(row: Mapping[str, Any]) -> list[dict[str, Any]]:
|
||||
messages = row.get("messages")
|
||||
if isinstance(messages, list) and messages:
|
||||
return [dict(item) for item in messages if isinstance(item, Mapping)]
|
||||
prompt = row.get("prompt") or row.get("input") or row.get("text")
|
||||
if isinstance(prompt, str) and prompt.strip():
|
||||
return [{"role": "user", "content": prompt}]
|
||||
raise TraceError("trace row is missing chat messages/prompt text")
|
||||
|
||||
|
||||
def _coerce_completion_tokens(row: Mapping[str, Any]) -> int | None:
|
||||
for key in ("max_completion_tokens", "max_tokens", "output_length", "completion_tokens"):
|
||||
value = row.get(key)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value >= 0:
|
||||
return value
|
||||
if isinstance(value, float) and value >= 0:
|
||||
return int(value)
|
||||
return None
|
||||
|
||||
|
||||
def _coerce_prompt_tokens(row: Mapping[str, Any]) -> int | None:
|
||||
for key in ("input_length", "prompt_length", "prompt_len", "input_tokens"):
|
||||
value = row.get(key)
|
||||
if isinstance(value, bool):
|
||||
continue
|
||||
if isinstance(value, int) and value >= 0:
|
||||
return value
|
||||
if isinstance(value, float) and value >= 0:
|
||||
return int(value)
|
||||
return None
|
||||
|
||||
|
||||
def load_trace_requests(study: StudySpec, *, study_spec_path: Path) -> tuple[WindowRecord, list[TraceRequest]]:
    """Load the study's trace window as ready-to-send chat request bodies.

    Returns the resolved window plus requests sorted by arrival time and,
    when configured, truncated to trace.max_requests_per_probe.  Raises
    TraceError for rows without a numeric timestamp or sampling value.
    """
    window = resolve_window_record(study, study_spec_path=study_spec_path)
    requests: list[TraceRequest] = []
    # The trace file is JSONL: one request row per non-blank line.
    with window.trace_path.open("r", encoding="utf-8") as handle:
        for idx, raw in enumerate(handle):
            if not raw.strip():
                continue
            row = json.loads(raw)
            if not isinstance(row, Mapping):
                continue
            timestamp = row.get(study.trace.timestamp_field)
            if timestamp is None:
                # Fall back to common field names when the configured one is absent.
                timestamp = row.get("arrival_time", row.get("timestamp"))
            if isinstance(timestamp, bool) or not isinstance(timestamp, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric timestamp")
            # sampling_u defaults to 1.0 so unlabeled rows are always selected.
            sampling_u = row.get(study.trace.u_field, 1.0)
            if isinstance(sampling_u, bool) or not isinstance(sampling_u, (int, float)):
                raise TraceError(f"trace row {idx} is missing numeric {study.trace.u_field}")
            # Build a streaming chat-completion body with usage reporting on.
            body: dict[str, Any] = {
                "model": study.model.served_model_name,
                "messages": _coerce_messages(row),
                "stream": True,
                "stream_options": {"include_usage": True},
            }
            completion_tokens = _coerce_completion_tokens(row)
            if completion_tokens is not None:
                body["max_tokens"] = completion_tokens
            temperature = row.get("temperature")
            if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):
                body["temperature"] = temperature
            requests.append(
                TraceRequest(
                    row_id=str(row.get("request_id") or row.get("id") or idx),
                    arrival_s=float(timestamp),
                    sampling_u=float(sampling_u),
                    body=body,
                    prompt_tokens_hint=_coerce_prompt_tokens(row),
                    completion_tokens_hint=completion_tokens,
                )
            )
    requests.sort(key=lambda item: item.arrival_s)
    if study.trace.max_requests_per_probe is not None:
        requests = requests[: study.trace.max_requests_per_probe]
    return window, requests
|
||||
|
||||
|
||||
def summarize_window(requests: list[TraceRequest], window: WindowRecord) -> dict[str, Any]:
    """Aggregate rate and token-size statistics for a resolved trace window."""
    prompt_sizes = [float(req.prompt_tokens_hint or 0) for req in requests]
    completion_sizes = [float(req.completion_tokens_hint or 0) for req in requests]
    duration = max(window.window_end - window.window_start, 0.0)
    if duration == 0.0 and len(requests) >= 2:
        # No usable window bounds: fall back to the observed arrival span.
        duration = requests[-1].arrival_s - requests[0].arrival_s
    request_rate = len(requests) / duration if duration > 0 else 0.0
    return {
        "window_id": window.window_id,
        "trace_path": str(window.trace_path),
        "trace_type": window.trace_type,
        "request_count": len(requests),
        "duration_s": duration,
        "request_rate": request_rate,
        "prompt_tokens_p50": _percentile(prompt_sizes, 50.0),
        "prompt_tokens_p95": _percentile(prompt_sizes, 95.0),
        "completion_tokens_p50": _percentile(completion_sizes, 50.0),
        "completion_tokens_p95": _percentile(completion_sizes, 95.0),
    }
|
||||
|
||||
|
||||
def select_requests_for_threshold(
    requests: list[TraceRequest], *, threshold: float
) -> list[TraceRequest]:
    """Return the requests whose sampling coordinate is at or below *threshold*.

    Preserves the input order; an empty input yields an empty list.
    """
    kept: list[TraceRequest] = []
    for candidate in requests:
        if candidate.sampling_u <= threshold:
            kept.append(candidate)
    return kept
|
||||
215
src/aituner/worker.py
Normal file
215
src/aituner/worker.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .engine import build_launch_recipe
|
||||
from .http_client import HttpClientError, stream_chat_completion, wait_for_server
|
||||
from .search import ThresholdProbe, binary_search_max_feasible
|
||||
from .slo import RequestOutcome, summarize_evaluations
|
||||
from .spec import ConfigPatch, SamplingSearchSpec, TrialSpec, load_study_spec
|
||||
from .trace import TraceRequest, load_trace_requests, select_requests_for_threshold
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProbePayload:
    """Immutable record of one threshold probe's replay result.

    Attached as the payload of the ``ThresholdProbe`` returned by the
    evaluator in ``run_trial`` so the binary search can surface the best
    feasible probe's metrics in the trial result.
    """

    # Sampling-u cutoff that was probed.
    threshold: float
    # Number of trace requests selected at this threshold.
    request_count: int
    # summary["slo_pass_rate"] from summarize_evaluations for the replay.
    pass_rate: float
    # Selected requests divided by the window span (requests/second).
    request_rate: float
    # summary["feasible"] verdict from summarize_evaluations.
    feasible: bool
    # Per-request outcome dicts (ttft/tpot/tokens/pass reasons) for logging.
    outcomes: list[dict[str, Any]]
|
||||
|
||||
def _trial_spec_from_json(path: Path) -> TrialSpec:
    """Deserialize a TrialSpec from the JSON file at *path*."""
    data = json.loads(path.read_text(encoding="utf-8"))
    # Nested structures go through their own from_dict constructors.
    patch = ConfigPatch.from_dict(data["config_patch"])
    search_spec = SamplingSearchSpec.from_dict(data["search"])
    return TrialSpec(
        study_id=str(data["study_id"]),
        trial_id=str(data["trial_id"]),
        config_patch=patch,
        search=search_spec,
        study_spec_path=str(data["study_spec_path"]),
        artifact_dir=str(data["artifact_dir"]),
        probe_log_path=str(data["probe_log_path"]),
        engine_log_path=str(data["engine_log_path"]),
        result_path=str(data["result_path"]),
    )
|
||||
|
||||
|
||||
def _run_one_request(
    request: TraceRequest,
    *,
    base_url: str,
    timeout_s: float,
) -> RequestOutcome:
    """Replay one trace request against the engine and record its outcome.

    Streams the chat completion and captures latency metrics; an
    ``HttpClientError`` is converted into a failed outcome instead of
    propagating, so one bad request never aborts a probe.
    """
    try:
        metrics = stream_chat_completion(base_url=base_url, body=request.body, timeout_s=timeout_s)
    except HttpClientError as exc:
        # Failed request: fall back to the trace hints so token accounting
        # stays populated, and record the error text.
        return RequestOutcome(
            request_id=request.row_id,
            success=False,
            ttft_ms=None,
            tpot_ms=None,
            prompt_tokens=request.prompt_tokens_hint,
            completion_tokens=request.completion_tokens_hint,
            error=str(exc),
        )
    # Prefer the engine-reported completion token count over the trace hint.
    return RequestOutcome(
        request_id=request.row_id,
        success=True,
        ttft_ms=metrics.ttft_ms,
        tpot_ms=metrics.tpot_ms,
        prompt_tokens=request.prompt_tokens_hint,
        completion_tokens=metrics.completion_tokens or request.completion_tokens_hint,
    )
|
||||
|
||||
|
||||
def _replay_requests(
    requests: list[TraceRequest],
    *,
    base_url: str,
    timeout_s: float,
    max_concurrency: int,
) -> list[RequestOutcome]:
    """Replay *requests* open-loop at their recorded arrival offsets.

    Submission happens on the calling thread: before dispatching each request
    we sleep until ``start + arrival_s``, so the trace's arrival pattern is
    preserved up to saturation of the *max_concurrency* worker pool.

    Returns outcomes in the original request order; requests that produced no
    outcome are omitted.
    """
    # Fix: dropped the threading.Lock the original held around these dict
    # writes — outcomes_by_id is only ever mutated below in the as_completed
    # loop, on this (the calling) thread; worker threads run _run_one_request
    # and never touch shared state, so the lock guarded nothing.
    outcomes_by_id: dict[str, RequestOutcome] = {}
    start = time.monotonic()
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        futures = []
        for request in requests:
            # Pace submission to the trace timeline; never sleep for the past.
            target = start + max(0.0, request.arrival_s)
            sleep_for = target - time.monotonic()
            if sleep_for > 0:
                time.sleep(sleep_for)
            futures.append(
                pool.submit(
                    _run_one_request,
                    request,
                    base_url=base_url,
                    timeout_s=timeout_s,
                )
            )
        for future in as_completed(futures):
            outcome = future.result()
            outcomes_by_id[outcome.request_id] = outcome
    # Re-order completed outcomes back into trace order.
    return [outcomes_by_id[item.row_id] for item in requests if item.row_id in outcomes_by_id]
|
||||
|
||||
|
||||
def run_trial(trial_spec_path: Path) -> dict[str, Any]:
    """Execute one tuning trial end to end and return its result dict.

    Reads the trial spec at *trial_spec_path*, launches the engine with the
    trial's config patch, binary-searches the largest feasible sampling-u
    threshold by replaying trace slices against the server, and persists both
    a rolling probe log and the final result JSON. The engine process is
    always torn down, even on failure.
    """
    # Local import to avoid a circular dependency at module import time.
    from .store import StudyStore

    trial = _trial_spec_from_json(trial_spec_path)
    # NOTE(review): trial.study_spec_path is itself read as a file whose
    # *contents* are the real study-spec path (a pointer file). Looks
    # intentional, but confirm against how the orchestrator writes it.
    study_spec_path = Path(Path(trial.study_spec_path).read_text(encoding="utf-8").strip())
    study = load_study_spec(study_spec_path)
    window, requests = load_trace_requests(study, study_spec_path=study_spec_path)
    # Merge the study's base engine config with this trial's patch.
    recipe = build_launch_recipe(study.engine, trial.config_patch)
    artifact_dir = Path(trial.artifact_dir)
    artifact_dir.mkdir(parents=True, exist_ok=True)
    engine_log_path = Path(trial.engine_log_path)
    # Engine stdout+stderr are captured into one log file for the trial.
    with engine_log_path.open("w", encoding="utf-8") as engine_log:
        process = subprocess.Popen(  # noqa: S603
            recipe.argv,
            cwd=recipe.cwd,
            env=recipe.env,
            stdout=engine_log,
            stderr=subprocess.STDOUT,
            text=True,
        )
        try:
            # Block until the healthcheck endpoint answers (or times out).
            wait_for_server(recipe.base_url, recipe.healthcheck_path, recipe.ready_timeout_s)
            # Mutated by the evaluator closure; flushed to disk after every
            # probe so partial progress survives a crash.
            probe_history: list[dict[str, Any]] = []

            def evaluator(threshold: float) -> ThresholdProbe[ProbePayload]:
                """Probe one sampling-u threshold: replay, score SLOs, log."""
                selected = select_requests_for_threshold(requests, threshold=threshold)
                outcomes = _replay_requests(
                    selected,
                    base_url=recipe.base_url,
                    timeout_s=recipe.request_timeout_s,
                    max_concurrency=study.trace.max_concurrency,
                )
                evaluations, summary = summarize_evaluations(outcomes, study.slo)
                # 1e-9 floor avoids division by zero on a degenerate window.
                request_rate = (
                    len(selected) / max(window.window_end - window.window_start, 1e-9)
                    if selected
                    else 0.0
                )
                payload = ProbePayload(
                    threshold=threshold,
                    request_count=len(selected),
                    pass_rate=float(summary["slo_pass_rate"]),
                    request_rate=request_rate,
                    feasible=bool(summary["feasible"]),
                    # outcomes and evaluations are parallel lists (same order).
                    outcomes=[
                        {
                            "request_id": outcome.request_id,
                            "success": outcome.success,
                            "ttft_ms": outcome.ttft_ms,
                            "tpot_ms": outcome.tpot_ms,
                            "prompt_tokens": outcome.prompt_tokens,
                            "completion_tokens": outcome.completion_tokens,
                            "evaluation": evaluation.passed,
                            "reasons": evaluation.reasons,
                        }
                        for outcome, evaluation in zip(outcomes, evaluations)
                    ],
                )
                # Compact per-probe record (no per-request detail) for the log.
                probe_record = {
                    "threshold": threshold,
                    "request_count": payload.request_count,
                    "pass_rate": payload.pass_rate,
                    "request_rate": payload.request_rate,
                    "feasible": payload.feasible,
                }
                probe_history.append(probe_record)
                # Rewrite the full history each probe so the log is always
                # a complete, valid JSON document.
                StudyStore.write_json(Path(trial.probe_log_path), probe_history)
                return ThresholdProbe(
                    threshold=threshold,
                    feasible=payload.feasible,
                    payload=payload,
                )

            search = binary_search_max_feasible(
                low=trial.search.low,
                high=trial.search.high,
                tolerance=trial.search.tolerance,
                max_probes=trial.search.max_probes,
                evaluator=evaluator,
            )
            best = search.best_feasible_payload
            # best is None when no probed threshold was feasible; the result
            # then reports null best_* fields but still lists all probes.
            result = {
                "study_id": trial.study_id,
                "trial_id": trial.trial_id,
                "status": "completed",
                "best_sampling_u": search.best_threshold if best is not None else None,
                "best_request_rate": best.request_rate if best is not None else None,
                "best_pass_rate": best.pass_rate if best is not None else None,
                "best_request_count": best.request_count if best is not None else None,
                "probes": [
                    {
                        "threshold": probe.threshold,
                        "feasible": probe.feasible,
                        "payload": {
                            "request_count": probe.payload.request_count,
                            "pass_rate": probe.payload.pass_rate,
                            "request_rate": probe.payload.request_rate,
                        },
                    }
                    for probe in search.probes
                ],
            }
            StudyStore.write_json(Path(trial.result_path), result)
            return result
        finally:
            # Graceful stop first; escalate to SIGKILL if the engine hangs.
            process.terminate()
            try:
                process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait(timeout=30)
|
||||
10
tests/conftest.py
Normal file
10
tests/conftest.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Repository root (tests/ lives one level below it).
ROOT = Path(__file__).resolve().parents[1]
# src/ layout: the aituner package lives under src/.
SRC = ROOT / "src"
# Make the in-repo package importable without installing it.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
|
||||
267
tests/test_core_flow.py
Normal file
267
tests/test_core_flow.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from aituner.job import append_job, build_trial_job
|
||||
from aituner.llm import build_prompt, parse_proposal_text
|
||||
from aituner.search import ThresholdProbe, binary_search_max_feasible
|
||||
from aituner.slo import RequestOutcome, summarize_evaluations
|
||||
from aituner.spec import Proposal, load_study_spec
|
||||
from aituner.store import StudyStore
|
||||
from aituner.trace import load_trace_requests, summarize_window
|
||||
|
||||
|
||||
def _write_study_assets(tmp_path: Path) -> Path:
    """Write a minimal self-consistent study fixture under *tmp_path*.

    Creates a three-request JSONL trace, its window index, a capability
    profile, and a study spec that references them. Returns the path to the
    study spec JSON.
    """
    # --- trace: three requests spread across sampling_u thresholds -------
    trace_dir = tmp_path / "trace_windows" / "traces"
    trace_dir.mkdir(parents=True)
    trace_path = trace_dir / "chat_w1.jsonl"
    rows = [
        {
            "request_id": "r1",
            "timestamp": 0.0,
            "sampling_u": 0.10,
            "messages": [{"role": "user", "content": "hello"}],
            "input_length": 1000,
            "output_length": 16
        },
        {
            "request_id": "r2",
            "timestamp": 1.0,
            "sampling_u": 0.50,
            "messages": [{"role": "user", "content": "world"}],
            "input_length": 5000,
            "output_length": 32
        },
        {
            "request_id": "r3",
            "timestamp": 2.0,
            "sampling_u": 0.90,
            "messages": [{"role": "user", "content": "!"}],
            "input_length": 20000,
            "output_length": 64
        }
    ]
    with trace_path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row) + "\n")

    # --- window index: one 10 s chat window over the trace above ---------
    windows_path = tmp_path / "trace_windows" / "windows.json"
    windows_payload = {
        "u_field": "sampling_u",
        "windows": [
            {
                "window_id": "chat_w1",
                "trace_type": "chat",
                "trace_file": "traces/chat_w1.jsonl",
                "window_start": 0.0,
                "window_end": 10.0
            }
        ]
    }
    windows_path.write_text(json.dumps(windows_payload), encoding="utf-8")

    # --- capability profile: minimal prefill-service table ---------------
    capability_path = tmp_path / "capability.json"
    capability_path.write_text(
        json.dumps({"prefill_service_by_bucket": {"4k": {"tp4_ms": 320, "tp8_ms": 240}}}),
        encoding="utf-8",
    )

    # --- study spec tying hardware/engine/trace/SLO/search together ------
    study_path = tmp_path / "study.json"
    study_payload = {
        "study_id": "study-1",
        "hardware": {"gpu_count": 8, "gpu_model": "H20", "host_candidates": ["dash0"]},
        "model": {
            "model_id": "qwen",
            "served_model_name": "Qwen/Qwen3-30B-A3B-Instruct-2507"
        },
        "engine": {
            "engine_name": "vllm",
            "engine_version": "0.1",
            "exec_path": "/usr/local/bin/vllm",
            "cwd": str(tmp_path),
            "host": "127.0.0.1",
            "port": 8000,
            "healthcheck_path": "/v1/models",
            "ready_timeout_s": 30,
            "request_timeout_s": 30,
            "launch_args": ["serve", "/models/qwen"],
            "base_envs": {"BASE_ENV": "1"},
            "base_flags": {"host": "127.0.0.1", "port": 8000},
            "tunable_envs": ["VLLM_ATTENTION_BACKEND"],
            "tunable_flags": ["tensor-parallel-size", "max-num-seqs"],
            "python_executable": "python3"
        },
        "trace": {
            "windows_path": str(windows_path),
            "window_id": "chat_w1",
            "u_field": "sampling_u",
            "timestamp_field": "timestamp",
            "max_concurrency": 4
        },
        "slo": {
            "target_pass_rate": 0.95,
            # Step rule: TTFT budget grows with input size; last bucket
            # (no max_input_tokens) is the catch-all.
            "ttft_rule": {
                "kind": "step_ms",
                "buckets": [
                    {"max_input_tokens": 4096, "threshold_ms": 2000},
                    {"max_input_tokens": 16384, "threshold_ms": 5000},
                    {"threshold_ms": 9000}
                ]
            },
            "tpot_rule": {"kind": "fixed_ms", "threshold_ms": 120}
        },
        "search": {
            "low": 0.0,
            "high": 1.0,
            "tolerance": 0.01,
            "max_probes": 8,
            "sample_seed": 20260325
        },
        "llm": {"system_prompt": "Tune it.", "max_history_trials": 8},
        "capability_profile_path": str(capability_path)
    }
    study_path.write_text(json.dumps(study_payload), encoding="utf-8")
    return study_path
|
||||
|
||||
|
||||
class CoreFlowTests(unittest.TestCase):
    """End-to-end checks for the study pipeline using on-disk fixtures."""

    def test_trace_and_prompt_flow(self) -> None:
        """Trace loading, window summarization, and prompt building cohere."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            study = load_study_spec(study_path)
            store = StudyStore(tmp_path / ".aituner" / "studies")
            study_root = store.init_study(spec_path=study_path, study=study)
            state = store.load_state(study.study_id)

            window, requests = load_trace_requests(study, study_spec_path=study_path)
            summary = summarize_window(requests, window)
            self.assertEqual(summary["request_count"], 3)
            # 3 requests over the fixture's 10 s window.
            self.assertEqual(summary["request_rate"], 0.3)

            prompt = build_prompt(
                study=study,
                window_summary=summary,
                state=state,
                capability_profile={"queueing_knee_by_bucket": {"4k": 1000}},
            )
            # The prompt must surface the tunables, the study id, and the
            # capability profile contents.
            self.assertIn("allowed_flag_keys", prompt)
            self.assertIn("study-1", prompt)
            self.assertIn("queueing_knee_by_bucket", prompt)
            self.assertTrue(study_root.exists())

    def test_slo_evaluation_step_and_fixed_rules(self) -> None:
        """Step TTFT buckets and the fixed TPOT rule score as expected."""
        with tempfile.TemporaryDirectory() as tmp:
            study = load_study_spec(_write_study_assets(Path(tmp)))
            outcomes = [
                # 1000-token prompt: 2000 ms TTFT budget; 1000 ms passes.
                RequestOutcome(
                    request_id="r1",
                    success=True,
                    ttft_ms=1000,
                    tpot_ms=100,
                    prompt_tokens=1000,
                    completion_tokens=16,
                ),
                # 5000-token prompt: 5000 ms TTFT budget; 6000 ms fails.
                RequestOutcome(
                    request_id="r2",
                    success=True,
                    ttft_ms=6000,
                    tpot_ms=100,
                    prompt_tokens=5000,
                    completion_tokens=16,
                ),
            ]
            evaluations, summary = summarize_evaluations(outcomes, study.slo)
            self.assertTrue(evaluations[0].passed)
            self.assertFalse(evaluations[1].passed)
            self.assertEqual(summary["slo_pass_rate"], 0.5)
            # 0.5 < target_pass_rate (0.95), so the probe is infeasible.
            self.assertFalse(summary["feasible"])

    def test_binary_search_max_feasible(self) -> None:
        """The search converges near a synthetic 0.625 feasibility edge."""
        result = binary_search_max_feasible(
            low=0.0,
            high=1.0,
            tolerance=0.01,
            max_probes=8,
            evaluator=lambda threshold: ThresholdProbe(
                threshold=threshold,
                feasible=threshold <= 0.625,
                payload={"threshold": threshold},
            ),
        )
        self.assertLessEqual(result.best_threshold, 0.625)
        self.assertGreaterEqual(result.best_threshold, 0.5)
        self.assertIsNotNone(result.best_feasible_payload)

    def test_proposal_validation_and_job_emission(self) -> None:
        """A parsed proposal materializes a trial and renders a jobs.toml entry."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            study = load_study_spec(study_path)
            store = StudyStore(tmp_path / ".aituner" / "studies")
            store.init_study(spec_path=study_path, study=study)
            state = store.load_state(study.study_id)

            # Well-formed LLM proposal touching only tunable envs/flags.
            proposal_text = json.dumps(
                {
                    "observation": "Current TTFT fails before TPOT.",
                    "diagnosis": "Prefill pressure dominates.",
                    "config_patch": {
                        "env_patch": {"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                        "flag_patch": {"tensor-parallel-size": 4, "max-num-seqs": 64}
                    },
                    "expected_effects": ["lower TTFT", "raise feasible sampling_u"],
                    "why_not_previous_failures": "Avoids changing unsupported envs."
                }
            )
            proposal = parse_proposal_text(proposal_text, study)
            trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)

            job = build_trial_job(study=study, trial=trial, repo_root=tmp_path)
            jobs_path = tmp_path / "jobs.toml"
            append_job(jobs_path, job)
            rendered = jobs_path.read_text(encoding="utf-8")
            self.assertIn('name = "study-1-trial-0001"', rendered)
            self.assertIn('command = "python3 -m aituner.cli worker run-trial', rendered)
            self.assertIn('PYTHONPATH = "src"', rendered)

    def test_ingest_trial_results_updates_best(self) -> None:
        """Ingesting a completed trial's result JSON updates the study best."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            study = load_study_spec(study_path)
            store = StudyStore(tmp_path / ".aituner" / "studies")
            store.init_study(spec_path=study_path, study=study)
            state = store.load_state(study.study_id)
            proposal = Proposal.from_dict(
                {
                    "observation": "Obs",
                    "diagnosis": "Diag",
                    "config_patch": {"env_patch": {}, "flag_patch": {"tensor-parallel-size": 4}},
                    "expected_effects": ["raise rate"]
                }
            )
            trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
            # Simulate a worker having written the trial result file.
            Path(trial.result_path).write_text(
                json.dumps(
                    {
                        "study_id": study.study_id,
                        "trial_id": trial.trial_id,
                        "status": "completed",
                        "best_sampling_u": 0.75,
                        "best_request_rate": 12.5,
                        "best_pass_rate": 0.97
                    }
                ),
                encoding="utf-8",
            )
            next_state = store.ingest_trial_results(study.study_id)
            self.assertEqual(next_state.best_trial_id, trial.trial_id)
            self.assertEqual(next_state.best_request_rate, 12.5)
|
||||
|
||||
|
||||
# Allow running this file directly: python tests/test_core_flow.py
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user