Add multi-window baseline vs tuned compare flow

2026-04-11 13:51:54 +08:00
parent a0b2d7eab2
commit 5e54e9c8f5
5 changed files with 860 additions and 0 deletions


@@ -0,0 +1,21 @@
{
  "compare_id": "example-weekly-compare",
  "study_spec_path": "study.example.json",
  "window_selector": {
    "trace_type": "chat",
    "date_prefix": "2026-03-2",
    "slot_token": "1000"
  },
  "baseline": {
    "config_patch": {
      "env_patch": {},
      "flag_patch": {}
    }
  },
  "tuned": {
    "trial_ref": {
      "study_root": "/abs/path/to/.aituner-tight/example-study",
      "trial_id": "trial-0004"
    }
  }
}
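
The spec selects every chat window whose date starts with "2026-03-2" at slot token "1000", then compares a plain baseline config patch against the config of a previously tuned trial. A minimal sketch of driving it from Python, assuming this example is saved as compare.example.json (the filename here is illustrative):

from pathlib import Path

from aituner.compare import run_compare

# Relative paths inside the spec resolve against the spec file's directory; with
# no output_root, artifacts land under .aituner-compare/example-weekly-compare.
summary = run_compare(Path("compare.example.json").resolve())
print(summary["aggregate"]["wins"])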


@@ -5,6 +5,7 @@ import json
import sys
from pathlib import Path
from .compare import run_compare
from .job import append_job, build_trial_job
from .llm import build_prompt, call_llm_for_proposal, load_capability_profile, parse_proposal_text
from .spec import Proposal, SpecError, load_study_spec
@@ -183,6 +184,25 @@ def cmd_worker_run_trial(args: argparse.Namespace) -> int:
    return 0


def cmd_compare_run(args: argparse.Namespace) -> int:
    summary = run_compare(
        Path(args.spec).resolve(),
        output_root=Path(args.output_root).resolve() if args.output_root else None,
    )
    print(
        json.dumps(
            {
                "compare_id": summary["compare_id"],
                "compare_root": summary["compare_root"],
                "window_count": summary["aggregate"]["window_count"],
                "wins": summary["aggregate"]["wins"],
            },
            ensure_ascii=False,
        )
    )
    return 0


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="AITuner CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)
@@ -239,6 +259,13 @@ def build_parser() -> argparse.ArgumentParser:
    run.add_argument("--trial-spec", required=True)
    run.set_defaults(func=cmd_worker_run_trial)

    compare = subparsers.add_parser("compare")
    compare_sub = compare.add_subparsers(dest="compare_command", required=True)
    compare_run = compare_sub.add_parser("run")
    compare_run.add_argument("--spec", required=True)
    compare_run.add_argument("--output-root")
    compare_run.set_defaults(func=cmd_compare_run)
    return parser
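
The `compare run` subcommand dispatches through set_defaults, so the wiring can be exercised without a console entry point. A small sketch (argv values illustrative):

from aituner.cli import build_parser

parser = build_parser()
args = parser.parse_args(["compare", "run", "--spec", "compare.json"])
exit_code = args.func(args)  # invokes cmd_compare_run, printing the one-line JSON summary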

src/aituner/compare.py (new file)

@@ -0,0 +1,436 @@
from __future__ import annotations

import json
from dataclasses import replace
from pathlib import Path
from typing import Any, Mapping

from .spec import (
    CompareCandidateSpec,
    CompareSpec,
    ConfigPatch,
    SpecError,
    StudySpec,
    TrialSpec,
    load_structured_file,
    load_study_spec,
)
from .store import StudyStore
from .worker import run_trial


def load_compare_spec(path: Path) -> CompareSpec:
    return CompareSpec.from_dict(load_structured_file(path))


def run_compare(compare_spec_path: Path, *, output_root: Path | None = None) -> dict[str, Any]:
    compare_spec_path = compare_spec_path.resolve()
    compare = load_compare_spec(compare_spec_path)
    base_study_path = _resolve_path(compare.study_spec_path, base_dir=compare_spec_path.parent)
    study = load_study_spec(base_study_path)
    compare_root = (output_root or (Path(".aituner-compare") / compare.compare_id)).resolve()
    compare_root.mkdir(parents=True, exist_ok=True)
    StudyStore.write_json(compare_root / "compare_spec.snapshot.json", _compare_snapshot(compare))
    windows = _select_windows(compare, study=study, study_spec_path=base_study_path)
    baseline_patch, baseline_source = _resolve_candidate_config(
        compare.baseline,
        compare_spec_path=compare_spec_path,
    )
    tuned_patch, tuned_source = _resolve_candidate_config(
        compare.tuned,
        compare_spec_path=compare_spec_path,
    )
    per_window: list[dict[str, Any]] = []
    for window in windows:
        window_id = str(window["window_id"])
        row = {
            "window_id": window_id,
            "trace_type": window.get("trace_type"),
            "date": window.get("date"),
            "slot_token": window.get("slot_token"),
            "slot_label": window.get("slot_label"),
            "window_start": window.get("window_start"),
            "window_end": window.get("window_end"),
            "baseline": _run_compare_candidate(
                compare_root=compare_root,
                compare_id=compare.compare_id,
                study=study,
                study_spec_path=base_study_path,
                window_id=window_id,
                candidate_name="baseline",
                config_patch=baseline_patch,
                source=baseline_source,
            ),
            "tuned": _run_compare_candidate(
                compare_root=compare_root,
                compare_id=compare.compare_id,
                study=study,
                study_spec_path=base_study_path,
                window_id=window_id,
                candidate_name="tuned",
                config_patch=tuned_patch,
                source=tuned_source,
            ),
        }
        row["delta"] = _window_delta(row["baseline"], row["tuned"])
        per_window.append(row)
    summary = {
        "compare_id": compare.compare_id,
        "study_spec_path": str(base_study_path),
        "compare_root": str(compare_root),
        "baseline_source": baseline_source,
        "tuned_source": tuned_source,
        "windows": per_window,
        "aggregate": _aggregate_summary(per_window),
    }
    StudyStore.write_json(compare_root / "summary.json", summary)
    (compare_root / "report.md").write_text(_render_report(summary), encoding="utf-8")
    return summary
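
# On-disk layout produced by run_compare (sketch; window and candidate names vary):
#
#   <compare_root>/
#       compare_spec.snapshot.json
#       summary.json
#       report.md
#       runs/<window_id>/<baseline|tuned>/
#           study_spec.json      # study snapshot with trace.window_id overridden
#           study_spec.source    # absolute path to study_spec.json, consumed via the trial spec
#           trial_spec.json
#           result.json
#           probe_history.json
#           engine.log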


def _compare_snapshot(compare: CompareSpec) -> dict[str, Any]:
    return {
        "compare_id": compare.compare_id,
        "study_spec_path": compare.study_spec_path,
        "window_ids": compare.window_ids,
        "window_selector": (
            {
                "trace_type": compare.window_selector.trace_type,
                "date_prefix": compare.window_selector.date_prefix,
                "date_from": compare.window_selector.date_from,
                "date_to": compare.window_selector.date_to,
                "slot_token": compare.window_selector.slot_token,
            }
            if compare.window_selector is not None
            else None
        ),
        "baseline": _candidate_snapshot(compare.baseline),
        "tuned": _candidate_snapshot(compare.tuned),
    }


def _candidate_snapshot(candidate: CompareCandidateSpec) -> dict[str, Any]:
    if candidate.config_patch is not None:
        return {
            "kind": "config_patch",
            "config_patch": {
                "env_patch": dict(candidate.config_patch.env_patch),
                "flag_patch": dict(candidate.config_patch.flag_patch),
            },
        }
    assert candidate.trial_ref is not None
    return {
        "kind": "trial_ref",
        "trial_ref": {
            "study_root": candidate.trial_ref.study_root,
            "trial_id": candidate.trial_ref.trial_id,
        },
    }


def _resolve_path(raw_path: str, *, base_dir: Path) -> Path:
    path = Path(raw_path)
    if not path.is_absolute():
        path = (base_dir / path).resolve()
    return path


def _resolve_windows_path(study: StudySpec, *, study_spec_path: Path) -> Path:
    path = Path(study.trace.windows_path)
    if not path.is_absolute():
        path = (study_spec_path.parent / path).resolve()
    return path


def _load_windows_payload(study: StudySpec, *, study_spec_path: Path) -> list[dict[str, Any]]:
    windows_path = _resolve_windows_path(study, study_spec_path=study_spec_path)
    payload = json.loads(windows_path.read_text(encoding="utf-8"))
    raw_windows = payload.get("windows") if isinstance(payload, Mapping) else payload
    if not isinstance(raw_windows, list):
        raise SpecError(f"windows payload must contain a list: {windows_path}")
    return [
        {str(key): value for key, value in item.items()}
        for item in raw_windows
        if isinstance(item, Mapping)
    ]


def _select_windows(compare: CompareSpec, *, study: StudySpec, study_spec_path: Path) -> list[dict[str, Any]]:
    windows = _load_windows_payload(study, study_spec_path=study_spec_path)
    indexed = {str(item.get("window_id") or "").strip(): item for item in windows}
    selected: dict[str, dict[str, Any]] = {}
    for window_id in compare.window_ids:
        item = indexed.get(window_id)
        if item is None:
            raise SpecError(f"window_id not found in windows payload: {window_id}")
        selected[window_id] = item
    selector = compare.window_selector
    if selector is not None:
        for item in windows:
            window_id = str(item.get("window_id") or "").strip()
            if not window_id:
                continue
            if selector.trace_type and str(item.get("trace_type") or "").strip() != selector.trace_type:
                continue
            date_value = str(item.get("date") or "").strip()
            if selector.date_prefix and not date_value.startswith(selector.date_prefix):
                continue
            if selector.date_from and date_value and date_value < selector.date_from:
                continue
            if selector.date_to and date_value and date_value > selector.date_to:
                continue
            if selector.slot_token and str(item.get("slot_token") or "").strip() != selector.slot_token:
                continue
            selected[window_id] = item
    ordered = sorted(
        selected.values(),
        key=lambda item: (
            str(item.get("date") or ""),
            str(item.get("slot_token") or ""),
            str(item.get("window_id") or ""),
        ),
    )
    if not ordered:
        raise SpecError("Compare spec selected zero windows.")
    return ordered
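
# Selection semantics (illustrative): explicit window_ids must all exist, the
# selector (if present) unions in every match, and results are ordered by
# (date, slot_token, window_id). For example, with windows
#   chat_w1  (chat,     2026-03-21, slot 1000)
#   chat_w2  (chat,     2026-03-22, slot 1000)
#   think_w1 (thinking, 2026-03-22, slot 1000)
# a selector {"trace_type": "chat", "date_prefix": "2026-03-2"} yields
# [chat_w1, chat_w2]; think_w1 fails the trace_type filter.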


def _resolve_candidate_config(
    candidate: CompareCandidateSpec,
    *,
    compare_spec_path: Path,
) -> tuple[ConfigPatch, dict[str, Any]]:
    if candidate.config_patch is not None:
        return candidate.config_patch, {
            "kind": "config_patch",
            "config_patch": {
                "env_patch": dict(candidate.config_patch.env_patch),
                "flag_patch": dict(candidate.config_patch.flag_patch),
            },
        }
    assert candidate.trial_ref is not None
    study_root = _resolve_path(candidate.trial_ref.study_root, base_dir=compare_spec_path.parent)
    trial_spec_path = study_root / "trials" / candidate.trial_ref.trial_id / "trial_spec.json"
    if not trial_spec_path.exists():
        raise SpecError(f"trial_ref target not found: {trial_spec_path}")
    payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
    config_patch = ConfigPatch.from_dict(payload.get("config_patch") or {})
    return config_patch, {
        "kind": "trial_ref",
        "study_root": str(study_root),
        "trial_id": candidate.trial_ref.trial_id,
        "config_patch": {
            "env_patch": dict(config_patch.env_patch),
            "flag_patch": dict(config_patch.flag_patch),
        },
    }


def _run_compare_candidate(
    *,
    compare_root: Path,
    compare_id: str,
    study: StudySpec,
    study_spec_path: Path,
    window_id: str,
    candidate_name: str,
    config_patch: ConfigPatch,
    source: dict[str, Any],
) -> dict[str, Any]:
    run_root = compare_root / "runs" / window_id / candidate_name
    run_root.mkdir(parents=True, exist_ok=True)
    window_study = replace(study, trace=replace(study.trace, window_id=window_id))
    actual_study_path = run_root / "study_spec.json"
    source_path = run_root / "study_spec.source"
    actual_study_path.write_text(
        json.dumps(_study_to_snapshot(window_study), ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    source_path.write_text(str(actual_study_path) + "\n", encoding="utf-8")
    trial = TrialSpec(
        study_id=compare_id,
        trial_id=candidate_name,
        config_patch=config_patch,
        search=window_study.search,
        study_spec_path=str(source_path),
        artifact_dir=str(run_root),
        probe_log_path=str(run_root / "probe_history.json"),
        engine_log_path=str(run_root / "engine.log"),
        result_path=str(run_root / "result.json"),
    )
    trial_spec_path = run_root / "trial_spec.json"
    StudyStore.write_json(trial_spec_path, _trial_snapshot(trial))
    result = run_trial(trial_spec_path)
    parallel_size = _parallel_size_for_candidate(study=window_study, patch=config_patch)
    best_rate = result.get("best_request_rate")
    best_rate_per_gpu = (
        float(best_rate) / float(parallel_size)
        if isinstance(best_rate, (int, float)) and parallel_size > 0
        else None
    )
    return {
        "candidate": candidate_name,
        "source": source,
        "parallel_size": parallel_size,
        "config_patch": {
            "env_patch": dict(config_patch.env_patch),
            "flag_patch": dict(config_patch.flag_patch),
        },
        "status": result.get("status"),
        "best_sampling_u": result.get("best_sampling_u"),
        "best_request_rate": best_rate,
        "best_request_rate_per_gpu": best_rate_per_gpu,
        "best_pass_rate": result.get("best_pass_rate"),
        "best_request_count": result.get("best_request_count"),
        "failure_stage": result.get("failure_stage"),
        "failure_reason": result.get("failure_reason"),
        "artifact_dir": str(run_root),
        "result_path": str(run_root / "result.json"),
        "probe_log_path": str(run_root / "probe_history.json"),
        "engine_log_path": str(run_root / "engine.log"),
    }


def _study_to_snapshot(study: StudySpec) -> dict[str, Any]:
    from .spec import to_jsonable

    return to_jsonable(study)


def _trial_snapshot(trial: TrialSpec) -> dict[str, Any]:
    from .spec import to_jsonable

    return to_jsonable(trial)


def _parse_int_like(value: Any, *, default: int = 1) -> int:
    if value is None:
        return default
    if isinstance(value, bool):
        raise SpecError("Topology values must be integers.")
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str) and value.strip():
        return int(value.strip())
    raise SpecError(f"Unable to parse integer topology value: {value!r}")


def _parallel_size_for_candidate(*, study: StudySpec, patch: ConfigPatch) -> int:
    flags = dict(study.engine.base_flags)
    flags.update(patch.flag_patch)
    tp = _parse_int_like(flags.get("tensor-parallel-size"), default=1)
    dp = _parse_int_like(flags.get("data-parallel-size"), default=1)
    return tp * dp
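
# Worked example (illustrative): with base_flags {"tensor-parallel-size": 2} and
# a candidate flag_patch {"data-parallel-size": 2}, the merged topology gives
# tp * dp = 2 * 2 = 4, so a best_request_rate of 8.0 normalizes to
# best_request_rate_per_gpu = 8.0 / 4 = 2.0.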


def _window_delta(baseline: dict[str, Any], tuned: dict[str, Any]) -> dict[str, Any]:
    return {
        "request_rate_delta": _numeric_delta(
            baseline.get("best_request_rate"), tuned.get("best_request_rate")
        ),
        "request_rate_per_gpu_delta": _numeric_delta(
            baseline.get("best_request_rate_per_gpu"),
            tuned.get("best_request_rate_per_gpu"),
        ),
        "pass_rate_delta": _numeric_delta(
            baseline.get("best_pass_rate"), tuned.get("best_pass_rate")
        ),
        "winner": _winner(
            baseline.get("best_request_rate_per_gpu"), tuned.get("best_request_rate_per_gpu")
        ),
    }


def _numeric_delta(lhs: Any, rhs: Any) -> float | None:
    if not isinstance(lhs, (int, float)) or not isinstance(rhs, (int, float)):
        return None
    return float(rhs) - float(lhs)


def _winner(baseline_per_gpu: Any, tuned_per_gpu: Any) -> str:
    if isinstance(baseline_per_gpu, (int, float)) and isinstance(tuned_per_gpu, (int, float)):
        if float(tuned_per_gpu) > float(baseline_per_gpu):
            return "tuned"
        if float(tuned_per_gpu) < float(baseline_per_gpu):
            return "baseline"
        return "tie"
    return "incomparable"


def _aggregate_summary(rows: list[dict[str, Any]]) -> dict[str, Any]:
    baseline_rates = [
        float(row["baseline"]["best_request_rate"])
        for row in rows
        if isinstance(row["baseline"].get("best_request_rate"), (int, float))
    ]
    tuned_rates = [
        float(row["tuned"]["best_request_rate"])
        for row in rows
        if isinstance(row["tuned"].get("best_request_rate"), (int, float))
    ]
    baseline_per_gpu = [
        float(row["baseline"]["best_request_rate_per_gpu"])
        for row in rows
        if isinstance(row["baseline"].get("best_request_rate_per_gpu"), (int, float))
    ]
    tuned_per_gpu = [
        float(row["tuned"]["best_request_rate_per_gpu"])
        for row in rows
        if isinstance(row["tuned"].get("best_request_rate_per_gpu"), (int, float))
    ]
    wins = {"baseline": 0, "tuned": 0, "tie": 0, "incomparable": 0}
    for row in rows:
        wins[row["delta"]["winner"]] += 1
    return {
        "window_count": len(rows),
        "wins": wins,
        "baseline_mean_request_rate": _mean_or_none(baseline_rates),
        "tuned_mean_request_rate": _mean_or_none(tuned_rates),
        "baseline_mean_request_rate_per_gpu": _mean_or_none(baseline_per_gpu),
        "tuned_mean_request_rate_per_gpu": _mean_or_none(tuned_per_gpu),
    }


def _mean_or_none(values: list[float]) -> float | None:
    if not values:
        return None
    return sum(values) / len(values)


def _render_report(summary: dict[str, Any]) -> str:
    lines = [
        f"# {summary['compare_id']}",
        "",
        "## Setup",
        "",
        f"- Study spec: `{summary['study_spec_path']}`",
        f"- Compare root: `{summary['compare_root']}`",
        f"- Baseline source: `{summary['baseline_source']['kind']}`",
        f"- Tuned source: `{summary['tuned_source']['kind']}`",
        "",
        "## Aggregate",
        "",
        f"- Window count: `{summary['aggregate']['window_count']}`",
        f"- Wins: `{json.dumps(summary['aggregate']['wins'], ensure_ascii=False)}`",
        f"- Baseline mean request rate: `{summary['aggregate']['baseline_mean_request_rate']}`",
        f"- Tuned mean request rate: `{summary['aggregate']['tuned_mean_request_rate']}`",
        f"- Baseline mean request rate per GPU: `{summary['aggregate']['baseline_mean_request_rate_per_gpu']}`",
        f"- Tuned mean request rate per GPU: `{summary['aggregate']['tuned_mean_request_rate_per_gpu']}`",
        "",
        "## Per Window",
        "",
        "| Window | Date | Baseline req/s | Baseline req/s/gpu | Tuned req/s | Tuned req/s/gpu | Winner |",
        "| --- | --- | ---: | ---: | ---: | ---: | --- |",
    ]
    for row in summary["windows"]:
        baseline = row["baseline"]
        tuned = row["tuned"]
        lines.append(
            f"| `{row['window_id']}` | `{row.get('date') or ''}` | "
            f"`{baseline.get('best_request_rate')}` | `{baseline.get('best_request_rate_per_gpu')}` | "
            f"`{tuned.get('best_request_rate')}` | `{tuned.get('best_request_rate_per_gpu')}` | "
            f"`{row['delta']['winner']}` |"
        )
    lines.append("")
    return "\n".join(lines)


@@ -706,6 +706,114 @@ class StudyState:
    trials: list[TrialSummary] = field(default_factory=list)


@dataclass(frozen=True)
class CompareTrialRefSpec:
    study_root: str
    trial_id: str

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "CompareTrialRefSpec":
        return cls(
            study_root=_require_str(data.get("study_root"), context="trial_ref.study_root"),
            trial_id=_require_str(data.get("trial_id"), context="trial_ref.trial_id"),
        )


@dataclass(frozen=True)
class CompareCandidateSpec:
    config_patch: ConfigPatch | None = None
    trial_ref: CompareTrialRefSpec | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "CompareCandidateSpec":
        config_patch = (
            ConfigPatch.from_dict(
                _require_mapping(data.get("config_patch"), context=f"{context}.config_patch")
            )
            if data.get("config_patch") is not None
            else None
        )
        trial_ref = (
            CompareTrialRefSpec.from_dict(
                _require_mapping(data.get("trial_ref"), context=f"{context}.trial_ref")
            )
            if data.get("trial_ref") is not None
            else None
        )
        if (config_patch is None) == (trial_ref is None):
            raise SpecError(
                f"{context} must define exactly one of config_patch or trial_ref."
            )
        return cls(config_patch=config_patch, trial_ref=trial_ref)


@dataclass(frozen=True)
class WindowSelectorSpec:
    trace_type: str | None = None
    date_prefix: str | None = None
    date_from: str | None = None
    date_to: str | None = None
    slot_token: str | None = None

    @classmethod
    def from_dict(cls, data: Mapping[str, Any], *, context: str) -> "WindowSelectorSpec":
        selector = cls(
            trace_type=str(data.get("trace_type") or "").strip() or None,
            date_prefix=str(data.get("date_prefix") or "").strip() or None,
            date_from=str(data.get("date_from") or "").strip() or None,
            date_to=str(data.get("date_to") or "").strip() or None,
            slot_token=str(data.get("slot_token") or "").strip() or None,
        )
        if (
            selector.trace_type is None
            and selector.date_prefix is None
            and selector.date_from is None
            and selector.date_to is None
            and selector.slot_token is None
        ):
            raise SpecError(f"{context} must define at least one selector field.")
        return selector


@dataclass(frozen=True)
class CompareSpec:
    compare_id: str
    study_spec_path: str
    window_ids: list[str]
    window_selector: WindowSelectorSpec | None
    baseline: CompareCandidateSpec
    tuned: CompareCandidateSpec

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "CompareSpec":
        compare = cls(
            compare_id=_require_str(data.get("compare_id"), context="compare_id"),
            study_spec_path=_require_str(
                data.get("study_spec_path"), context="study_spec_path"
            ),
            window_ids=_coerce_str_list(data.get("window_ids"), context="window_ids"),
            window_selector=(
                WindowSelectorSpec.from_dict(
                    _require_mapping(data.get("window_selector"), context="window_selector"),
                    context="window_selector",
                )
                if data.get("window_selector") is not None
                else None
            ),
            baseline=CompareCandidateSpec.from_dict(
                _require_mapping(data.get("baseline"), context="baseline"),
                context="baseline",
            ),
            tuned=CompareCandidateSpec.from_dict(
                _require_mapping(data.get("tuned"), context="tuned"),
                context="tuned",
            ),
        )
        if not compare.window_ids and compare.window_selector is None:
            raise SpecError("Compare spec must define window_ids or window_selector.")
        return compare
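
# Validation sketch (illustrative): each candidate must carry exactly one of
# config_patch or trial_ref, and the spec as a whole needs at least one window
# source.
#
#   CompareSpec.from_dict({
#       "compare_id": "c1",
#       "study_spec_path": "study.json",
#       "baseline": {"config_patch": {"env_patch": {}, "flag_patch": {}}},
#       "tuned": {"trial_ref": {"study_root": "/prior-study", "trial_id": "trial-0001"}},
#   })
#   # -> SpecError: Compare spec must define window_ids or window_selector.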


def to_jsonable(value: Any) -> Any:
    if is_dataclass(value):
        return {key: to_jsonable(item) for key, item in asdict(value).items()}


@@ -9,6 +9,7 @@ from pathlib import Path
from unittest import mock
from aituner.cli import main as cli_main
from aituner.compare import load_compare_spec, run_compare
from aituner.engine import build_launch_recipe
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
from aituner.job import append_job, build_trial_job
@@ -162,6 +163,36 @@ def _write_study_assets(
    return study_path


def _write_compare_assets(
    tmp_path: Path,
    *,
    study_path: Path,
    window_ids: list[str] | None = None,
    window_selector: dict[str, object] | None = None,
    baseline: dict[str, object] | None = None,
    tuned: dict[str, object] | None = None,
) -> Path:
    compare_path = tmp_path / "compare.json"
    payload: dict[str, object] = {
        "compare_id": "compare-1",
        "study_spec_path": str(study_path),
        "baseline": baseline or {"config_patch": {"env_patch": {}, "flag_patch": {}}},
        "tuned": tuned
        or {
            "config_patch": {
                "env_patch": {},
                "flag_patch": {"tensor-parallel-size": 2},
            }
        },
    }
    if window_ids is not None:
        payload["window_ids"] = window_ids
    if window_selector is not None:
        payload["window_selector"] = window_selector
    compare_path.write_text(json.dumps(payload), encoding="utf-8")
    return compare_path


class CoreFlowTests(unittest.TestCase):
    def test_trace_and_prompt_flow(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
@@ -1597,6 +1628,243 @@ class CoreFlowTests(unittest.TestCase):
        self.assertEqual(state.best_request_rate, 2.0)
        self.assertEqual(state.next_trial_index, 3)

    def test_load_compare_spec_requires_window_selection(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            compare_path = tmp_path / "compare.json"
            compare_path.write_text(
                json.dumps(
                    {
                        "compare_id": "compare-1",
                        "study_spec_path": str(study_path),
                        "baseline": {"config_patch": {"env_patch": {}, "flag_patch": {}}},
                        "tuned": {"config_patch": {"env_patch": {}, "flag_patch": {}}},
                    }
                ),
                encoding="utf-8",
            )
            with self.assertRaisesRegex(SpecError, "window_ids or window_selector"):
                load_compare_spec(compare_path)

    def test_run_compare_outputs_summary_and_report(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            trace_dir = tmp_path / "trace_windows" / "traces"
            trace_path = trace_dir / "chat_w2.jsonl"
            trace_path.write_text(
                json.dumps(
                    {
                        "request_id": "r4",
                        "timestamp": 0.0,
                        "sampling_u": 0.2,
                        "messages": [{"role": "user", "content": "extra"}],
                        "input_length": 3000,
                        "output_length": 32,
                    }
                )
                + "\n",
                encoding="utf-8",
            )
            windows_path = tmp_path / "trace_windows" / "windows.json"
            windows_payload = json.loads(windows_path.read_text(encoding="utf-8"))
            windows_payload["windows"].append(
                {
                    "window_id": "chat_w2",
                    "trace_type": "chat",
                    "trace_file": "traces/chat_w2.jsonl",
                    "window_start": 0.0,
                    "window_end": 10.0,
                    "date": "2026-03-12",
                    "slot_token": "1000",
                    "slot_label": "10:00-10:10",
                }
            )
            windows_payload["windows"][0]["date"] = "2026-03-11"
            windows_payload["windows"][0]["slot_token"] = "1000"
            windows_payload["windows"][0]["slot_label"] = "10:00-10:10"
            windows_path.write_text(json.dumps(windows_payload), encoding="utf-8")
            compare_path = _write_compare_assets(
                tmp_path,
                study_path=study_path,
                window_ids=["chat_w1", "chat_w2"],
            )

            def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
                trial_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
                source_path = Path(trial_payload["study_spec_path"])
                actual_spec_path = Path(source_path.read_text(encoding="utf-8").strip())
                study_payload = json.loads(actual_spec_path.read_text(encoding="utf-8"))
                window_id = study_payload["trace"]["window_id"]
                trial_id = trial_payload["trial_id"]
                rate_map = {
                    ("chat_w1", "baseline"): 1.0,
                    ("chat_w1", "tuned"): 3.0,
                    ("chat_w2", "baseline"): 3.0,
                    ("chat_w2", "tuned"): 7.0,
                }
                best_rate = rate_map[(window_id, trial_id)]
                result = {
                    "study_id": trial_payload["study_id"],
                    "trial_id": trial_id,
                    "status": "completed",
                    "best_sampling_u": 0.5,
                    "best_request_rate": best_rate,
                    "best_pass_rate": 1.0,
                    "best_request_count": 2,
                    "probes": [],
                }
                Path(trial_payload["result_path"]).write_text(
                    json.dumps(result),
                    encoding="utf-8",
                )
                return result

            with mock.patch("aituner.compare.run_trial", side_effect=fake_run_trial):
                summary = run_compare(compare_path, output_root=tmp_path / ".compare")
            self.assertEqual(len(summary["windows"]), 2)
            self.assertEqual(summary["aggregate"]["wins"]["tuned"], 2)
            self.assertTrue((tmp_path / ".compare" / "summary.json").exists())
            self.assertTrue((tmp_path / ".compare" / "report.md").exists())

    def test_run_compare_resolves_trial_ref_candidate(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            prior_root = tmp_path / "prior-study"
            trial_dir = prior_root / "trials" / "trial-0002"
            trial_dir.mkdir(parents=True)
            trial_spec = {
                "study_id": "prior-study",
                "trial_id": "trial-0002",
                "config_patch": {
                    "env_patch": {},
                    "flag_patch": {"data-parallel-size": 2},
                },
                "search": {
                    "low": 0.0,
                    "high": 1.0,
                    "tolerance": 0.01,
                    "max_probes": 8,
                    "sample_seed": 20260325,
                },
                "study_spec_path": str(study_path),
                "artifact_dir": str(trial_dir),
                "probe_log_path": str(trial_dir / "probe_history.json"),
                "engine_log_path": str(trial_dir / "engine.log"),
                "result_path": str(trial_dir / "result.json"),
            }
            (trial_dir / "trial_spec.json").write_text(json.dumps(trial_spec), encoding="utf-8")
            compare_path = _write_compare_assets(
                tmp_path,
                study_path=study_path,
                window_ids=["chat_w1"],
                baseline={
                    "trial_ref": {
                        "study_root": str(prior_root),
                        "trial_id": "trial-0002",
                    }
                },
            )

            def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
                trial_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
                flags = (trial_payload["config_patch"] or {}).get("flag_patch") or {}
                best_rate = 5.0 if flags.get("data-parallel-size") == 2 else 2.0
                result = {
                    "study_id": trial_payload["study_id"],
                    "trial_id": trial_payload["trial_id"],
                    "status": "completed",
                    "best_sampling_u": 0.5,
                    "best_request_rate": best_rate,
                    "best_pass_rate": 1.0,
                    "best_request_count": 2,
                    "probes": [],
                }
                Path(trial_payload["result_path"]).write_text(json.dumps(result), encoding="utf-8")
                return result

            with mock.patch("aituner.compare.run_trial", side_effect=fake_run_trial):
                summary = run_compare(compare_path, output_root=tmp_path / ".compare")
            self.assertEqual(summary["baseline_source"]["kind"], "trial_ref")
            self.assertEqual(
                summary["windows"][0]["baseline"]["config_patch"]["flag_patch"]["data-parallel-size"],
                2,
            )

    def test_run_compare_window_selector_filters_windows(self) -> None:
        with tempfile.TemporaryDirectory() as tmp:
            tmp_path = Path(tmp)
            study_path = _write_study_assets(tmp_path)
            trace_dir = tmp_path / "trace_windows" / "traces"
            for name in ("chat_w2.jsonl", "thinking_w3.jsonl"):
                (trace_dir / name).write_text(
                    json.dumps(
                        {
                            "request_id": name,
                            "timestamp": 0.0,
                            "sampling_u": 0.2,
                            "messages": [{"role": "user", "content": name}],
                            "input_length": 3000,
                            "output_length": 32,
                        }
                    )
                    + "\n",
                    encoding="utf-8",
                )
            windows_path = tmp_path / "trace_windows" / "windows.json"
            windows_payload = json.loads(windows_path.read_text(encoding="utf-8"))
            windows_payload["windows"][0]["date"] = "2026-03-11"
            windows_payload["windows"][0]["slot_token"] = "1000"
            windows_payload["windows"].append(
                {
                    "window_id": "chat_w2",
                    "trace_type": "chat",
                    "trace_file": "traces/chat_w2.jsonl",
                    "window_start": 0.0,
                    "window_end": 10.0,
                    "date": "2026-03-12",
                    "slot_token": "1000",
                }
            )
            windows_payload["windows"].append(
                {
                    "window_id": "thinking_w3",
                    "trace_type": "thinking",
                    "trace_file": "traces/thinking_w3.jsonl",
                    "window_start": 0.0,
                    "window_end": 10.0,
                    "date": "2026-03-12",
                    "slot_token": "1000",
                }
            )
            windows_path.write_text(json.dumps(windows_payload), encoding="utf-8")
            compare_path = _write_compare_assets(
                tmp_path,
                study_path=study_path,
                window_selector={"trace_type": "chat", "date_prefix": "2026-03-12"},
            )

            def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
                trial_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
                result = {
                    "study_id": trial_payload["study_id"],
                    "trial_id": trial_payload["trial_id"],
                    "status": "completed",
                    "best_sampling_u": 0.5,
                    "best_request_rate": 1.0,
                    "best_pass_rate": 1.0,
                    "best_request_count": 2,
                    "probes": [],
                }
                Path(trial_payload["result_path"]).write_text(json.dumps(result), encoding="utf-8")
                return result

            with mock.patch("aituner.compare.run_trial", side_effect=fake_run_trial):
                summary = run_compare(compare_path, output_root=tmp_path / ".compare")
            self.assertEqual([row["window_id"] for row in summary["windows"]], ["chat_w2"])

    def test_proposal_expected_effects_accepts_string(self) -> None:
        proposal = Proposal.from_dict(
            {