Stop tuning when baseline is infeasible
This commit is contained in:
@@ -2997,6 +2997,97 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
|
||||
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
|
||||
|
||||
def test_cli_tune_stops_when_baseline_is_all_infeasible(self) -> None:
    """Check that ``study tune`` halts when the first trial is all-infeasible.

    ``run_trial`` is patched to report a completed trial in which every
    probe is infeasible.  The test expects exactly one recorded trial, the
    ``baseline_all_infeasible`` stop reason, and no LLM proposal request.
    A second ``study tune`` invocation against the same store must then run
    no trial and make no LLM call.
    """
    with tempfile.TemporaryDirectory() as scratch:
        workdir = Path(scratch)
        study_path = _write_study_assets(workdir)

        # Rewrite the study spec so it points at a custom LLM endpoint.
        study_doc = json.loads(study_path.read_text(encoding="utf-8"))
        endpoint = {
            "provider": "custom",
            "base_url": "http://llm.example/v1",
            "wire_api": "chat.completions",
            "model": "test-model",
            "api_key_env": "OPENAI_API_KEY",
        }
        study_doc["llm"]["endpoint"] = endpoint
        study_path.write_text(json.dumps(study_doc), encoding="utf-8")
        store_root = workdir / "store"

        def fake_run_trial(trial_spec_path: Path) -> dict[str, object]:
            # Stand-in for the real trial runner: every probe is infeasible.
            spec = json.loads(trial_spec_path.read_text(encoding="utf-8"))
            artifact_dir = Path(spec["artifact_dir"])
            infeasible_probes = [
                {
                    "threshold": 0.5,
                    "feasible": False,
                    "payload": {"pass_rate": 0.0, "request_rate": 2.0},
                },
                {
                    "threshold": 0.25,
                    "feasible": False,
                    "payload": {"pass_rate": 0.5, "request_rate": 1.0},
                },
            ]
            diagnostics = {
                "threshold": 0.25,
                "request_rate": 1.0,
                "pass_rate": 0.5,
                "early_stop_reason": "slo_pass_rate_unrecoverable",
            }
            outcome = {
                "study_id": spec["study_id"],
                "trial_id": spec["trial_id"],
                "status": "completed",
                "best_sampling_u": None,
                "best_request_rate": None,
                "best_pass_rate": None,
                "best_request_count": None,
                "probes": infeasible_probes,
                "all_infeasible_diagnostics": diagnostics,
            }
            result_file = artifact_dir / "result.json"
            result_file.write_text(json.dumps(outcome), encoding="utf-8")
            return outcome

        # Both invocations use the same command line; each gets its own copy.
        tune_argv = [
            "study",
            "tune",
            "--spec",
            str(study_path),
            "--store-root",
            str(store_root),
            "--max-trials",
            "3",
        ]

        # First run: one trial executes, then tuning must stop.
        with mock.patch(
            "aituner.cli.run_trial", side_effect=fake_run_trial
        ), mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
            first_exit = cli_main(list(tune_argv))

        self.assertEqual(first_exit, 0)
        llm_mock.assert_not_called()
        persisted = StudyStore(store_root).load_state("study-1")
        self.assertEqual(persisted.next_trial_index, 2)
        self.assertEqual(len(persisted.trials), 1)
        self.assertEqual(persisted.tuning_stop_reason, "baseline_all_infeasible")
        self.assertIn("lowest_sampled_request_rate=1", persisted.tuning_stop_diagnosis)

        # Second run against the same store: nothing further may execute.
        with mock.patch("aituner.cli.run_trial") as run_trial_mock, mock.patch(
            "aituner.cli.call_llm_for_proposal"
        ) as rerun_llm_mock:
            rerun_exit = cli_main(list(tune_argv))

        self.assertEqual(rerun_exit, 0)
        run_trial_mock.assert_not_called()
        rerun_llm_mock.assert_not_called()
|
||||
|
||||
def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
|
||||
Reference in New Issue
Block a user