Use full search range for every trial
This commit is contained in:
@@ -85,15 +85,11 @@ class StudyStore:
|
||||
trial_root = self.study_root(study.study_id) / "trials" / trial_id
|
||||
trial_root.mkdir(parents=True, exist_ok=True)
|
||||
parallel_size = _parallel_size_for_proposal(study=study, proposal=proposal)
|
||||
search_low = _derive_search_floor(study=study, state=state, parallel_size=parallel_size)
|
||||
spec = TrialSpec(
|
||||
study_id=study.study_id,
|
||||
trial_id=trial_id,
|
||||
config_patch=proposal.config_patch,
|
||||
search=replace(
|
||||
study.search,
|
||||
low=search_low,
|
||||
),
|
||||
search=study.search,
|
||||
study_spec_path=str((self.study_root(study.study_id) / "study_spec.source").resolve()),
|
||||
artifact_dir=str(trial_root),
|
||||
probe_log_path=str(trial_root / "probe_history.json"),
|
||||
@@ -305,15 +301,3 @@ def _request_rate_per_gpu(best_request_rate: Any, parallel_size: int | None) ->
|
||||
return None
|
||||
return float(best_request_rate) / float(parallel_size)
|
||||
|
||||
|
||||
def _derive_search_floor(*, study: StudySpec, state: StudyState, parallel_size: int) -> float:
|
||||
low = study.search.low
|
||||
high = study.search.high
|
||||
group_incumbent = (state.best_by_parallel_size or {}).get(str(parallel_size))
|
||||
if isinstance(group_incumbent, dict) and isinstance(
|
||||
group_incumbent.get("best_sampling_u"), (int, float)
|
||||
):
|
||||
candidate = float(group_incumbent["best_sampling_u"])
|
||||
else:
|
||||
candidate = low
|
||||
return min(high, max(low, candidate))
|
||||
|
||||
@@ -2100,7 +2100,12 @@ class CoreFlowTests(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
self.assertEqual(trial.search.low, 0.5)
|
||||
self.assertEqual(trial.search.low, 0.0)
|
||||
|
||||
trial_spec_path = Path(trial.artifact_dir) / "trial_spec.json"
|
||||
trial_spec_payload = json.loads(trial_spec_path.read_text(encoding="utf-8"))
|
||||
trial_spec_payload["search"]["low"] = 0.5
|
||||
trial_spec_path.write_text(json.dumps(trial_spec_payload), encoding="utf-8")
|
||||
|
||||
def fake_replay(requests, **kwargs):
|
||||
passing = len(requests) <= 1
|
||||
@@ -2126,7 +2131,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
with mock.patch("aituner.worker._wait_for_server_or_exit", return_value=None):
|
||||
with mock.patch("aituner.worker._terminate_process_tree", return_value=None):
|
||||
with mock.patch("aituner.worker._replay_requests", side_effect=fake_replay):
|
||||
result = run_trial(Path(trial.artifact_dir) / "trial_spec.json")
|
||||
result = run_trial(trial_spec_path)
|
||||
|
||||
self.assertEqual(result["status"], "completed")
|
||||
self.assertEqual(result["best_source"], "lower_range_fallback")
|
||||
@@ -2168,7 +2173,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
self.assertEqual(state.trials, [])
|
||||
self.assertEqual(len(next_state.trials), 1)
|
||||
|
||||
def test_materialize_trial_uses_incumbent_sampling_u_as_search_floor(self) -> None:
|
||||
def test_materialize_trial_uses_full_search_range_with_incumbent(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
study_path = _write_study_assets(tmp_path)
|
||||
@@ -2203,10 +2208,10 @@ class CoreFlowTests(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
self.assertEqual(trial.search.low, 0.375)
|
||||
self.assertEqual(trial.search.low, study.search.low)
|
||||
self.assertEqual(trial.search.high, 1.0)
|
||||
|
||||
def test_materialize_trial_uses_same_parallel_group_incumbent(self) -> None:
|
||||
def test_materialize_trial_uses_full_search_range_for_same_parallel_group(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
study_path = _write_study_assets(tmp_path)
|
||||
@@ -2241,7 +2246,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
trial, _ = store.materialize_trial(study=study, state=state, proposal=proposal)
|
||||
self.assertEqual(trial.search.low, 0.125)
|
||||
self.assertEqual(trial.search.low, study.search.low)
|
||||
|
||||
def test_materialize_trial_resets_search_floor_for_new_parallel_group(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
@@ -2359,7 +2364,7 @@ class CoreFlowTests(unittest.TestCase):
|
||||
"max-num-seqs": 160,
|
||||
},
|
||||
)
|
||||
self.assertEqual(trial.search.low, 0.125)
|
||||
self.assertEqual(trial.search.low, study.search.low)
|
||||
self.assertEqual(
|
||||
next_state.trials[-1].config_patch["flag_patch"],
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user