Use full search range for every trial

This commit is contained in:
2026-05-11 12:50:22 +08:00
parent 14259fcec9
commit 8516cd88c0
2 changed files with 13 additions and 24 deletions

View File

@@ -2100,7 +2100,12 @@ class CoreFlowTests(unittest.TestCase):
}
)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(trial.search.low, 0.5)
self.assertEqual(trial.search.low, 0.0)
trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
trial_spec_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
trial_spec_payload["search"]["low"] = 0.5
trial_spec_path.write_text(json.dumps(trial_spec_payload), encoding="utf-8")
def fake_replay(requests, **kwargs):
passing = len(requests) <= 1
@@ -2126,7 +2131,7 @@ class CoreFlowTests(unittest.TestCase):
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay):
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
result = run_trial(trial_spec_path)
self.assertEqual(result["status"], "completed")
self.assertEqual(result["best_source"], "lower_range_fallback")
@@ -2168,7 +2173,7 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(state.trials, [])
self.assertEqual(len(next_state.trials), 1)
def test_materialize_trial_uses_incumbent_sampling_u_as_search_floor(self) -> None:
def test_materialize_trial_uses_full_search_range_with_incumbent(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
@@ -2203,10 +2208,10 @@ class CoreFlowTests(unittest.TestCase):
}
)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(trial.search.low, 0.375)
self.assertEqual(trial.search.low, study.search.low)
self.assertEqual(trial.search.high, 1.0)
def test_materialize_trial_uses_same_parallel_group_incumbent(self) -> None:
def test_materialize_trial_uses_full_search_range_for_same_parallel_group(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path)
@@ -2241,7 +2246,7 @@ class CoreFlowTests(unittest.TestCase):
}
)
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(trial.search.low, 0.125)
self.assertEqual(trial.search.low, study.search.low)
def test_materialize_trial_resets_search_floor_for_new_parallel_group(self) -> None:
with tempfile.TemporaryDirectory() as tmp:
@@ -2359,7 +2364,7 @@ class CoreFlowTests(unittest.TestCase):
"max-num-seqs": 160,
},
)
self.assertEqual(trial.search.low, 0.125)
self.assertEqual(trial.search.low, study.search.low)
self.assertEqual(
next_state.trials[-1].config_patch["flag_patch"],
{