diff --git a/src/aituner/store.py b/src/aituner/store.py index 7456431..0cc33d6 100644 --- a/src/aituner/store.py +++ b/src/aituner/store.py @@ -85,15 +85,11 @@ class StudyStore: trial_root = self.study_root(study.study_id) / "trials" / trial_id trial_root.mkdir(parents=True, exist_ok=True) parallel_size = _parallel_size_for_proposal(study=study, proposal=proposal) - search_low = _derive_search_floor(study=study, state=state, parallel_size=parallel_size) spec = TrialSpec( study_id=study.study_id, trial_id=trial_id, config_patch=proposal.config_patch, - search=replace( - study.search, - low=search_low, - ), + search=study.search, study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()), artifact_dir=str(trial_root), probe_log_path=str(trial_root / "probe_history.json"), @@ -305,15 +301,3 @@ def _request_rate_per_gpu(best_request_rate: Any, parallel_size: int | None) -> return None return float(best_request_rate) / float(parallel_size) - -def _derive_search_floor(*, study: StudySpec, state: StudyState, parallel_size: int) -> float: - low = study.search.low - high = study.search.high - group_incumbent = (state.best_by_parallel_size or {}).get(str(parallel_size)) - if isinstance(group_incumbent, dict) and isinstance( - group_incumbent.get("best_sampling_u"), (int, float) - ): - candidate = float(group_incumbent["best_sampling_u"]) - else: - candidate = low - return min(high, max(low, candidate)) diff --git a/tests/test_core_flow.py b/tests/test_core_flow.py index 10ba8fd..8e20b1f 100644 --- a/tests/test_core_flow.py +++ b/tests/test_core_flow.py @@ -2100,7 +2100,12 @@ class CoreFlowTests(unittest.TestCase): } ) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) - self.assertEqual(trial.search.low, 0.5) + self.assertEqual(trial.search.low, 0.0) + + trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json" + trial_spec_payload = json.loads(trial_spec_path.read_text(encoding="utf-8")) + trial_spec_payload["search"]["low"] = 0.5 + trial_spec_path.write_text(json.dumps(trial_spec_payload), encoding="utf-8") def fake_replay(requests, **kwargs): passing = len(requests) <= 1 @@ -2126,7 +2131,7 @@ class CoreFlowTests(unittest.TestCase): with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None): with mock.patch("aituner.worker._terminate_process_tree", return_value=None): with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay): - result = run_trial(Path(trial.artifact_dir) / "trial_spec.json") + result = run_trial(trial_spec_path) self.assertEqual(result["status"], "completed") self.assertEqual(result["best_source"], "lower_range_fallback") @@ -2168,7 +2173,7 @@ class CoreFlowTests(unittest.TestCase): self.assertEqual(state.trials, []) self.assertEqual(len(next_state.trials), 1) - def test_materialize_trial_uses_incumbent_sampling_u_as_search_floor(self) -> None: + def test_materialize_trial_uses_full_search_range_with_incumbent(self) -> None: with tempfile.TemporaryDirectory() as tmp: tmp_path = Path(tmp) study_path = _write_study_assets(tmp_path) @@ -2203,10 +2208,10 @@ class CoreFlowTests(unittest.TestCase): } ) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) - self.assertEqual(trial.search.low, 0.375) + self.assertEqual(trial.search.low, study.search.low) self.assertEqual(trial.search.high, 1.0) - def test_materialize_trial_uses_same_parallel_group_incumbent(self) -> None: + def test_materialize_trial_uses_full_search_range_for_same_parallel_group(self) -> None: with tempfile.TemporaryDirectory() as tmp: tmp_path = Path(tmp) study_path = _write_study_assets(tmp_path) @@ -2241,7 +2246,7 @@ class CoreFlowTests(unittest.TestCase): } ) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) - self.assertEqual(trial.search.low, 0.125) + self.assertEqual(trial.search.low, study.search.low) def test_materialize_trial_resets_search_floor_for_new_parallel_group(self) -> None: with tempfile.TemporaryDirectory() as tmp: @@ -2359,7 +2364,7 @@ class CoreFlowTests(unittest.TestCase): "max-num-seqs": 160, }, ) - self.assertEqual(trial.search.low, 0.125) + self.assertEqual(trial.search.low, study.search.low) self.assertEqual( next_state.trials[-1].config_patch["flag_patch"], {