Stop tuning when baseline is infeasible

This commit is contained in: (branch/tag list not captured in this extraction)
2026-05-08 01:07:36 +08:00
parent a7a5e9ad80
commit f212673f44
4 changed files with 170 additions and 0 deletions

View File

@@ -2997,6 +2997,97 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
def test_cli_tune_stops_when_baseline_is_all_infeasible(self) -> None:
    """Tuning stops after an all-infeasible baseline trial, and the stop is persisted.

    Exercises ``study tune`` end-to-end with the trial runner stubbed out so the
    baseline trial reports every probe infeasible. Verifies that:

    - the CLI still exits 0 (stopping early is a normal outcome, not an error),
    - no LLM proposal is requested (``call_llm_for_proposal`` is never called),
    - the stop reason/diagnosis are recorded in the persisted study state,
    - a second ``study tune`` invocation on the same store is a no-op
      (no further trials are run and the LLM is still never consulted).
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        study_path = _write_study_assets(tmp_path)
        # Rewrite the study spec to point at a stub custom endpoint so the test
        # needs no real provider configuration (the LLM is mocked below anyway).
        payload = json.loads(study_path.read_text(encoding="utf-8"))
        payload["llm"]["endpoint"] = {
            "provider": "custom",
            "base_url": "http://llm.example/v1",
            "wire_api": "chat.completions",
            "model": "test-model",
            "api_key_env": "OPENAI_API_KEY",
        }
        study_path.write_text(json.dumps(payload), encoding="utf-8")
        store_root = tmp_path / "store"

        def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
            # Stub trial runner: both probes come back infeasible, so there is
            # no "best" metric at any threshold. Writes result.json into the
            # trial's artifact dir, mirroring what the real runner does.
            payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
            trial_root = Path(payload["artifact_dir"])
            result = {
                "study_id": payload["study_id"],
                "trial_id": payload["trial_id"],
                "status": "completed",
                # No feasible probe => no best-* metrics to report.
                "best_sampling_u": None,
                "best_request_rate": None,
                "best_pass_rate": None,
                "best_request_count": None,
                "probes": [
                    {
                        "threshold": 0.5,
                        "feasible": False,
                        "payload": {"pass_rate": 0.0, "request_rate": 2.0},
                    },
                    {
                        # Even the lowest sampled rate (1.0) misses the SLO.
                        "threshold": 0.25,
                        "feasible": False,
                        "payload": {"pass_rate": 0.5, "request_rate": 1.0},
                    },
                ],
                "all_infeasible_diagnostics": {
                    "threshold": 0.25,
                    "request_rate": 1.0,
                    "pass_rate": 0.5,
                    "early_stop_reason": "slo_pass_rate_unrecoverable",
                },
            }
            (trial_root / "result.json").write_text(json.dumps(result), encoding="utf-8")
            return result

        # First run: baseline executes (via the stub) and tuning stops early.
        with mock.patch("aituner.cli.run_trial", side_effect=fake_run_trial):
            with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
                exit_code = cli_main(
                    [
                        "study",
                        "tune",
                        "--spec",
                        str(study_path),
                        "--store-root",
                        str(store_root),
                        "--max-trials",
                        "3",
                    ]
                )
        # Early stop is a successful outcome, not a failure exit.
        self.assertEqual(exit_code, 0)
        # Baseline never produced a feasible point, so no proposal was requested.
        llm_mock.assert_not_called()
        store = StudyStore(store_root)
        state = store.load_state("study-1")
        # Only the baseline trial ran, despite --max-trials 3.
        # NOTE(review): next_trial_index is 2 after a single recorded trial —
        # presumably the index advances past the aborted follow-up; confirm
        # against the tuner's state bookkeeping.
        self.assertEqual(state.next_trial_index, 2)
        self.assertEqual(len(state.trials), 1)
        self.assertEqual(state.tuning_stop_reason, "baseline_all_infeasible")
        # Diagnosis surfaces the lowest request rate that was actually sampled.
        self.assertIn("lowest_sampled_request_rate=1", state.tuning_stop_diagnosis)

        # Second run against the same store: a stopped study must resume as a
        # no-op — no trials executed and no LLM calls made.
        with mock.patch("aituner.cli.run_trial") as run_trial_mock:
            with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
                exit_code = cli_main(
                    [
                        "study",
                        "tune",
                        "--spec",
                        str(study_path),
                        "--store-root",
                        str(store_root),
                        "--max-trials",
                        "3",
                    ]
                )
        self.assertEqual(exit_code, 0)
        run_trial_mock.assert_not_called()
        llm_mock.assert_not_called()
def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)