Make tune trial budget resumable

This commit is contained in:
2026-05-07 17:18:06 +08:00
parent 7263587cb6
commit a7a5e9ad80
3 changed files with 59 additions and 4 deletions

View File

@@ -2919,7 +2919,7 @@ class CoreFlowTests(unittest.TestCase):
"--store-root",
str(store_root),
"--max-trials",
"1",
"5",
]
)
@@ -2997,6 +2997,53 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(state.trials[0].config_patch, {"env_patch": {}, "flag_patch": {}})
self.assertEqual(state.trials[1].config_patch["flag_patch"], {"max-num-seqs": 64})
def test_cli_tune_max_trials_is_total_budget_on_resume(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
payload = json.loads(study_path.read_text(encoding="utf-8"))
payload["llm"]["endpoint"] = {
"provider": "custom",
"base_url": "http://llm.example/v1",
"wire_api": "chat.completions",
"model": "test-model",
"api_key_env": "OPENAI_API_KEY",
}
study_path.write_text(json.dumps(payload), encoding="utf-8")
store_root = tmp_path / "store"
study = load_study_spec(study_path)
store = StudyStore(store_root)
store.init_study(spec_path=study_path, study=study)
state = StudyState(
study_id=study.study_id,
next_trial_index=3,
trials=[
TrialSummary(trial_id="trial-0001", status="completed"),
TrialSummary(trial_id="trial-0002", status="completed"),
],
)
store.save_state(state)
with mock.patch("aituner.cli.call_llm_for_proposal") as llm_mock:
with mock.patch("aituner.cli.run_trial") as run_trial_mock:
exit_code = cli_main(
[
"study",
"tune",
"--spec",
str(study_path),
"--store-root",
str(store_root),
"--max-trials",
"2",
]
)
self.assertEqual(exit_code, 0)
llm_mock.assert_not_called()
run_trial_mock.assert_not_called()
self.assertEqual(store.load_state(study.study_id).next_trial_index, 3)
def test_load_compare_spec_requires_window_selection(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)