Use full search range for every trial

This commit is contained in:
2026-05-11 12:50:22 +08:00
parent 14259fcec9
commit 8516cd88c0
2 changed files with 13 additions and 24 deletions

View File

@@ -85,15 +85,11 @@ class StudyStore:
trial_root = self.study_root(study.study_id) / "trials" / trial_id trial_root = self.study_root(study.study_id) / "trials" / trial_id
trial_root.mkdir(parents=True, exist_ok=True) trial_root.mkdir(parents=True, exist_ok=True)
parallel_size = _parallel_size_for_proposal(study=study, proposal=proposal) parallel_size = _parallel_size_for_proposal(study=study, proposal=proposal)
search_low = _derive_search_floor(study=study, state=state, parallel_size=parallel_size)
spec = TrialSpec( spec = TrialSpec(
study_id=study.study_id, study_id=study.study_id,
trial_id=trial_id, trial_id=trial_id,
config_patch=proposal.config_patch, config_patch=proposal.config_patch,
search=replace( search=study.search,
study.search,
low=search_low,
),
study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()), study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()),
artifact_dir=str(trial_root), artifact_dir=str(trial_root),
probe_log_path=str(trial_root / "probe_history.json"), probe_log_path=str(trial_root / "probe_history.json"),
@@ -305,15 +301,3 @@ def _request_rate_per_gpu(best_request_rate: Any, parallel_size: int | None) ->
return None return None
return float(best_request_rate) / float(parallel_size) return float(best_request_rate) / float(parallel_size)
def _derive_search_floor(*, study: StudySpec, state: StudyState, parallel_size: int) -> float:
low = study.search.low
high = study.search.high
group_incumbent = (state.best_by_parallel_size or {}).get(str(parallel_size))
if isinstance(group_incumbent, dict) and isinstance(
group_incumbent.get("best_sampling_u"), (int, float)
):
candidate = float(group_incumbent["best_sampling_u"])
else:
candidate = low
return min(high, max(low, candidate))

View File

@@ -2100,7 +2100,12 @@ class CoreFlowTests(unittest.TestCase):
} }
) )
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(trial.search.low, 0.5) self.assertEqual(trial.search.low, 0.0)
trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
trial_spec_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
trial_spec_payload["search"]["low"] = 0.5
trial_spec_path.write_text(json.dumps(trial_spec_payload), encoding="utf-8")
def fake_replay(requests, **kwargs): def fake_replay(requests, **kwargs):
passing = len(requests) <= 1 passing = len(requests) <= 1
@@ -2126,7 +2131,7 @@ class CoreFlowTests(unittest.TestCase):
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None): with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
with mock.patch("aituner.worker._terminate_process_tree", return_value=None): with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay): with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay):
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json") result = run_trial(trial_spec_path)
self.assertEqual(result["status"], "completed") self.assertEqual(result["status"], "completed")
self.assertEqual(result["best_source"], "lower_range_fallback") self.assertEqual(result["best_source"], "lower_range_fallback")
@@ -2168,7 +2173,7 @@ class CoreFlowTests(unittest.TestCase):
self.assertEqual(state.trials, []) self.assertEqual(state.trials, [])
self.assertEqual(len(next_state.trials), 1) self.assertEqual(len(next_state.trials), 1)
def test_materialize_trial_uses_incumbent_sampling_u_as_search_floor(self) -> None: def test_materialize_trial_uses_full_search_range_with_incumbent(self) -> None:
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp) tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path) study_path = _write_study_assets(tmp_path)
@@ -2203,10 +2208,10 @@ class CoreFlowTests(unittest.TestCase):
} }
) )
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(trial.search.low, 0.375) self.assertEqual(trial.search.low, study.search.low)
self.assertEqual(trial.search.high, 1.0) self.assertEqual(trial.search.high, 1.0)
def test_materialize_trial_uses_same_parallel_group_incumbent(self) -> None: def test_materialize_trial_uses_full_search_range_for_same_parallel_group(self) -> None:
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
tmp_path = Path(tmp) tmp_path = Path(tmp)
study_path = _write_study_assets(tmp_path) study_path = _write_study_assets(tmp_path)
@@ -2241,7 +2246,7 @@ class CoreFlowTests(unittest.TestCase):
} }
) )
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal) trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
self.assertEqual(trial.search.low, 0.125) self.assertEqual(trial.search.low, study.search.low)
def test_materialize_trial_resets_search_floor_for_new_parallel_group(self) -> None: def test_materialize_trial_resets_search_floor_for_new_parallel_group(self) -> None:
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
@@ -2359,7 +2364,7 @@ class CoreFlowTests(unittest.TestCase):
"max-num-seqs": 160, "max-num-seqs": 160,
}, },
) )
self.assertEqual(trial.search.low, 0.125) self.assertEqual(trial.search.low, study.search.low)
self.assertEqual( self.assertEqual(
next_state.trials[-1].config_patch["flag_patch"], next_state.trials[-1].config_patch["flag_patch"],
{ {