Use full search range for every trial
This commit is contained in:
@@ -2100,7 +2100,12 @@ class CoreFlowTests(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
self.assertEqual(trial.search.low, 0.5)
|
||||
self.assertEqual(trial.search.low, 0.0)
|
||||
|
||||
trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
|
||||
trial_spec_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
|
||||
trial_spec_payload["search"]["low"] = 0.5
|
||||
trial_spec_path.write_text(json.dumps(trial_spec_payload), encoding="utf-8")
|
||||
|
||||
def fake_replay(requests, **kwargs):
|
||||
passing = len(requests) <= 1
|
||||
@@ -2126,7 +2131,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
|
||||
with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
|
||||
with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay):
|
||||
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
|
||||
result = run_trial(trial_spec_path)
|
||||
|
||||
self.assertEqual(result["status"], "completed")
|
||||
self.assertEqual(result["best_source"], "lower_range_fallback")
|
||||
@@ -2168,7 +2173,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.trials, [])
|
||||
self.assertEqual(len(next_state.trials), 1)
|
||||
|
||||
def test_materialize_trial_uses_incumbent_sampling_u_as_search_floor(self) -> None:
|
||||
def test_materialize_trial_uses_full_search_range_with_incumbent(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
study_path = _write_study_assets(tmp_path)
|
||||
@@ -2203,10 +2208,10 @@ class CoreFlowTests(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
self.assertEqual(trial.search.low, 0.375)
|
||||
self.assertEqual(trial.search.low, study.search.low)
|
||||
self.assertEqual(trial.search.high, 1.0)
|
||||
|
||||
def test_materialize_trial_uses_same_parallel_group_incumbent(self) -> None:
|
||||
def test_materialize_trial_uses_full_search_range_for_same_parallel_group(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
study_path = _write_study_assets(tmp_path)
|
||||
@@ -2241,7 +2246,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
self.assertEqual(trial.search.low, 0.125)
|
||||
self.assertEqual(trial.search.low, study.search.low)
|
||||
|
||||
def test_materialize_trial_resets_search_floor_for_new_parallel_group(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
@@ -2359,7 +2364,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
"max-num-seqs": 160,
|
||||
},
|
||||
)
|
||||
self.assertEqual(trial.search.low, 0.125)
|
||||
self.assertEqual(trial.search.low, study.search.low)
|
||||
self.assertEqual(
|
||||
next_state.trials[-1].config_patch["flag_patch"],
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user