Add multi-window baseline vs tuned compare flow
This commit is contained in:
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from aituner.cli import main as cli_main
|
||||
from aituner.compare import load_compare_spec, run_compare
|
||||
from aituner.engine import build_launch_recipe
|
||||
from aituner.http_client import _auth_headers, _openai_url, _should_bypass_proxy
|
||||
from aituner.job import append_job, build_trial_job
|
||||
@@ -162,6 +163,36 @@ def _write_study_assets(
|
||||
return study_path
|
||||
|
||||
|
||||
def _write_compare_assets(
|
||||
tmp_path: Path,
|
||||
*,
|
||||
study_path: Path,
|
||||
window_ids: list[str] | None = None,
|
||||
window_selector: dict[str, object] | None = None,
|
||||
baseline: dict[str, object] | None = None,
|
||||
tuned: dict[str, object] | None = None,
|
||||
) -> Path:
|
||||
compare_path = tmp_path / "compare.json"
|
||||
payload: dict[str, object] = {
|
||||
"compare_id": "compare-1",
|
||||
"study_spec_path": str(study_path),
|
||||
"baseline": baseline or {"config_patch": {"env_patch": {}, "flag_patch": {}}},
|
||||
"tuned": tuned
|
||||
or {
|
||||
"config_patch": {
|
||||
"env_patch": {},
|
||||
"flag_patch": {"tensor-parallel-size": 2},
|
||||
}
|
||||
},
|
||||
}
|
||||
if window_ids is not None:
|
||||
payload["window_ids"] = window_ids
|
||||
if window_selector is not None:
|
||||
payload["window_selector"] = window_selector
|
||||
compare_path.write_text(json.dumps(payload), encoding="utf-8")
|
||||
return compare_path
|
||||
|
||||
|
||||
class CoreFlowTests(unittest.TestCase):
|
||||
def test_trace_and_prompt_flow(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
@@ -1597,6 +1628,243 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.best_request_rate, 2.0)
|
||||
self.assertEqual(state.next_trial_index, 3)
|
||||
|
||||
def test_load_compare_spec_requires_window_selection(self) -> None:
    """A compare spec with neither window_ids nor window_selector is rejected."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        study_path = _write_study_assets(root)
        # Identical empty candidate patch for both sides; the spec is valid
        # except for the missing window selection.
        empty_candidate = {"config_patch": {"env_patch": {}, "flag_patch": {}}}
        spec = {
            "compare_id": "compare-1",
            "study_spec_path": str(study_path),
            "baseline": empty_candidate,
            "tuned": empty_candidate,
        }
        spec_path = root / "compare.json"
        spec_path.write_text(json.dumps(spec), encoding="utf-8")
        with self.assertRaisesRegex(SpecError, "window_ids or window_selector"):
            load_compare_spec(spec_path)
|
||||
|
||||
def test_run_compare_outputs_summary_and_report(self) -> None:
    """run_compare over two windows writes summary.json and report.md."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        study_path = _write_study_assets(root)

        # Add a second chat trace window next to the one the study assets provide.
        record = {
            "request_id": "r4",
            "timestamp": 0.0,
            "sampling_u": 0.2,
            "messages": [{"role": "user", "content": "extra"}],
            "input_length": 3000,
            "output_length": 32,
        }
        trace_file = root / "trace_windows" / "traces" / "chat_w2.jsonl"
        trace_file.write_text(json.dumps(record) + "\n", encoding="utf-8")

        windows_path = root / "trace_windows" / "windows.json"
        windows_doc = json.loads(windows_path.read_text(encoding="utf-8"))
        windows_doc["windows"].append(
            {
                "window_id": "chat_w2",
                "trace_type": "chat",
                "trace_file": "traces/chat_w2.jsonl",
                "window_start": 0.0,
                "window_end": 10.0,
                "date": "2026-03-12",
                "slot_token": "1000",
                "slot_label": "10:00-10:10",
            }
        )
        first_window = windows_doc["windows"][0]
        first_window["date"] = "2026-03-11"
        first_window["slot_token"] = "1000"
        first_window["slot_label"] = "10:00-10:10"
        windows_path.write_text(json.dumps(windows_doc), encoding="utf-8")

        compare_path = _write_compare_assets(
            root,
            study_path=study_path,
            window_ids=["chat_w1", "chat_w2"],
        )

        # Per (window, candidate) best rate; tuned wins in both windows.
        rates = {
            ("chat_w1", "baseline"): 1.0,
            ("chat_w1", "tuned"): 3.0,
            ("chat_w2", "baseline"): 3.0,
            ("chat_w2", "tuned"): 7.0,
        }

        def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
            # The trial spec points at an indirection file whose content is
            # the real per-window study spec path; follow it to learn which
            # window this trial is for.
            spec = json.loads(trial_spec_path.read_text(encoding="utf-8"))
            pointer = Path(spec["study_spec_path"])
            resolved = Path(pointer.read_text(encoding="utf-8").strip())
            study = json.loads(resolved.read_text(encoding="utf-8"))
            key = (study["trace"]["window_id"], spec["trial_id"])
            result = {
                "study_id": spec["study_id"],
                "trial_id": spec["trial_id"],
                "status": "completed",
                "best_sampling_u": 0.5,
                "best_request_rate": rates[key],
                "best_pass_rate": 1.0,
                "best_request_count": 2,
                "probes": [],
            }
            Path(spec["result_path"]).write_text(
                json.dumps(result), encoding="utf-8"
            )
            return result

        with mock.patch("aituner.compare.run_trial", side_effect=fake_run_trial):
            summary = run_compare(compare_path, output_root=root / ".compare")
        self.assertEqual(len(summary["windows"]), 2)
        self.assertEqual(summary["aggregate"]["wins"]["tuned"], 2)
        self.assertTrue((root / ".compare" / "summary.json").exists())
        self.assertTrue((root / ".compare" / "report.md").exists())
|
||||
|
||||
def test_run_compare_resolves_trial_ref_candidate(self) -> None:
    """A baseline given as trial_ref is resolved from the prior study's trial spec."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        study_path = _write_study_assets(root)

        # Lay out a prior study whose trial-0002 carries the flag patch the
        # trial_ref should resolve to.
        prior_root = root / "prior-study"
        trial_dir = prior_root / "trials" / "trial-0002"
        trial_dir.mkdir(parents=True)
        prior_spec = {
            "study_id": "prior-study",
            "trial_id": "trial-0002",
            "config_patch": {
                "env_patch": {},
                "flag_patch": {"data-parallel-size": 2},
            },
            "search": {
                "low": 0.0,
                "high": 1.0,
                "tolerance": 0.01,
                "max_probes": 8,
                "sample_seed": 20260325,
            },
            "study_spec_path": str(study_path),
            "artifact_dir": str(trial_dir),
            "probe_log_path": str(trial_dir / "probe_history.json"),
            "engine_log_path": str(trial_dir / "engine.log"),
            "result_path": str(trial_dir / "result.json"),
        }
        (trial_dir / "trial_spec.json").write_text(
            json.dumps(prior_spec), encoding="utf-8"
        )

        compare_path = _write_compare_assets(
            root,
            study_path=study_path,
            window_ids=["chat_w1"],
            baseline={
                "trial_ref": {
                    "study_root": str(prior_root),
                    "trial_id": "trial-0002",
                }
            },
        )

        def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
            spec = json.loads(trial_spec_path.read_text(encoding="utf-8"))
            # Reward the resolved flag patch so the assertion below can tell
            # whether the trial_ref candidate was actually materialized.
            flag_patch = (spec["config_patch"] or {}).get("flag_patch") or {}
            rate = 5.0 if flag_patch.get("data-parallel-size") == 2 else 2.0
            result = {
                "study_id": spec["study_id"],
                "trial_id": spec["trial_id"],
                "status": "completed",
                "best_sampling_u": 0.5,
                "best_request_rate": rate,
                "best_pass_rate": 1.0,
                "best_request_count": 2,
                "probes": [],
            }
            Path(spec["result_path"]).write_text(json.dumps(result), encoding="utf-8")
            return result

        with mock.patch("aituner.compare.run_trial", side_effect=fake_run_trial):
            summary = run_compare(compare_path, output_root=root / ".compare")
        self.assertEqual(summary["baseline_source"]["kind"], "trial_ref")
        baseline_flags = summary["windows"][0]["baseline"]["config_patch"]["flag_patch"]
        self.assertEqual(baseline_flags["data-parallel-size"], 2)
|
||||
|
||||
def test_run_compare_window_selector_filters_windows(self) -> None:
    """window_selector narrows the compared windows by trace_type and date prefix."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        study_path = _write_study_assets(root)

        # Two extra trace files: one chat window that matches the selector,
        # one thinking window that must be filtered out.
        trace_dir = root / "trace_windows" / "traces"
        for name in ("chat_w2.jsonl", "thinking_w3.jsonl"):
            record = {
                "request_id": name,
                "timestamp": 0.0,
                "sampling_u": 0.2,
                "messages": [{"role": "user", "content": name}],
                "input_length": 3000,
                "output_length": 32,
            }
            (trace_dir / name).write_text(
                json.dumps(record) + "\n", encoding="utf-8"
            )

        windows_path = root / "trace_windows" / "windows.json"
        windows_doc = json.loads(windows_path.read_text(encoding="utf-8"))
        # The pre-existing window gets an older date so the date_prefix
        # filter excludes it as well.
        windows_doc["windows"][0]["date"] = "2026-03-11"
        windows_doc["windows"][0]["slot_token"] = "1000"
        for window_id, trace_type in (("chat_w2", "chat"), ("thinking_w3", "thinking")):
            windows_doc["windows"].append(
                {
                    "window_id": window_id,
                    "trace_type": trace_type,
                    "trace_file": f"traces/{window_id}.jsonl",
                    "window_start": 0.0,
                    "window_end": 10.0,
                    "date": "2026-03-12",
                    "slot_token": "1000",
                }
            )
        windows_path.write_text(json.dumps(windows_doc), encoding="utf-8")

        compare_path = _write_compare_assets(
            root,
            study_path=study_path,
            window_selector={"trace_type": "chat", "date_prefix": "2026-03-12"},
        )

        def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
            spec = json.loads(trial_spec_path.read_text(encoding="utf-8"))
            result = {
                "study_id": spec["study_id"],
                "trial_id": spec["trial_id"],
                "status": "completed",
                "best_sampling_u": 0.5,
                "best_request_rate": 1.0,
                "best_pass_rate": 1.0,
                "best_request_count": 2,
                "probes": [],
            }
            Path(spec["result_path"]).write_text(json.dumps(result), encoding="utf-8")
            return result

        with mock.patch("aituner.compare.run_trial", side_effect=fake_run_trial):
            summary = run_compare(compare_path, output_root=root / ".compare")
        self.assertEqual(
            [row["window_id"] for row in summary["windows"]], ["chat_w2"]
        )
|
||||
|
||||
def test_proposal_expected_effects_accepts_string(self) -> None:
|
||||
proposal = Proposal.from_dict(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user