Keep target topology explicit in delta projections
This commit is contained in:
@@ -1698,6 +1698,7 @@ def _frontier_delta_projection_actions(
|
|||||||
f"_to_{target.get('trial_id')}"
|
f"_to_{target.get('trial_id')}"
|
||||||
)
|
)
|
||||||
if not runtime_delta:
|
if not runtime_delta:
|
||||||
|
target_topology_patch = _explicit_topology_patch(study, target_flags)
|
||||||
blocked_candidates.append(
|
blocked_candidates.append(
|
||||||
_blocked_candidate(
|
_blocked_candidate(
|
||||||
action_id=action_id,
|
action_id=action_id,
|
||||||
@@ -1705,7 +1706,7 @@ def _frontier_delta_projection_actions(
|
|||||||
config_patch={
|
config_patch={
|
||||||
"env_patch": {},
|
"env_patch": {},
|
||||||
"flag_patch": {
|
"flag_patch": {
|
||||||
**_preserve_topology_patch(study, target_flags),
|
**target_topology_patch,
|
||||||
**_preserve_runtime_patch(study, target_flags),
|
**_preserve_runtime_patch(study, target_flags),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1715,7 +1716,7 @@ def _frontier_delta_projection_actions(
|
|||||||
{
|
{
|
||||||
"env_patch": {},
|
"env_patch": {},
|
||||||
"flag_patch": {
|
"flag_patch": {
|
||||||
**_preserve_topology_patch(study, target_flags),
|
**target_topology_patch,
|
||||||
**_preserve_runtime_patch(study, target_flags),
|
**_preserve_runtime_patch(study, target_flags),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1725,7 +1726,7 @@ def _frontier_delta_projection_actions(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
patch = {
|
patch = {
|
||||||
**_preserve_topology_patch(study, target_flags),
|
**_explicit_topology_patch(study, target_flags),
|
||||||
**_preserve_runtime_patch(study, target_flags),
|
**_preserve_runtime_patch(study, target_flags),
|
||||||
**runtime_delta,
|
**runtime_delta,
|
||||||
}
|
}
|
||||||
@@ -2647,6 +2648,24 @@ def _preserve_topology_patch(study: StudySpec, flags: dict[str, Any]) -> dict[st
|
|||||||
return patch
|
return patch
|
||||||
|
|
||||||
|
|
||||||
|
def _explicit_topology_patch(study: StudySpec, flags: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
patch: dict[str, Any] = {}
|
||||||
|
tunable = set(study.engine.tunable_flags)
|
||||||
|
normalized = _normalized_topology_flags(flags)
|
||||||
|
for key in (
|
||||||
|
"tensor-parallel-size",
|
||||||
|
"data-parallel-size",
|
||||||
|
"expert-parallel-size",
|
||||||
|
"enable-expert-parallel",
|
||||||
|
):
|
||||||
|
if key not in tunable:
|
||||||
|
continue
|
||||||
|
if key not in flags and key not in study.engine.base_flags:
|
||||||
|
continue
|
||||||
|
patch[key] = normalized[key]
|
||||||
|
return patch
|
||||||
|
|
||||||
|
|
||||||
def _preserve_runtime_patch(study: StudySpec, flags: dict[str, Any]) -> dict[str, Any]:
|
def _preserve_runtime_patch(study: StudySpec, flags: dict[str, Any]) -> dict[str, Any]:
|
||||||
patch: dict[str, Any] = {}
|
patch: dict[str, Any] = {}
|
||||||
tunable = set(study.engine.tunable_flags)
|
tunable = set(study.engine.tunable_flags)
|
||||||
|
|||||||
@@ -2928,8 +2928,24 @@ class CoreFlowTests(unittest.TestCase):
|
|||||||
self.assertEqual(next_action["knob_family"], "frontier-delta-projection")
|
self.assertEqual(next_action["knob_family"], "frontier-delta-projection")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
next_action["config_patch"]["flag_patch"],
|
next_action["config_patch"]["flag_patch"],
|
||||||
{"gpu-memory-utilization": 0.9},
|
{
|
||||||
|
"tensor-parallel-size": 2,
|
||||||
|
"data-parallel-size": 1,
|
||||||
|
"gpu-memory-utilization": 0.9,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
proposal = build_harness_guided_proposal(context)
|
||||||
|
self.assertIsNotNone(proposal)
|
||||||
|
materialized_signature = materialized_effective_config_signature(
|
||||||
|
study=study,
|
||||||
|
state=state,
|
||||||
|
proposal=proposal,
|
||||||
|
)
|
||||||
|
tested_signatures = {
|
||||||
|
_effective_config_signature(study, trial.config_patch)
|
||||||
|
for trial in state.trials
|
||||||
|
}
|
||||||
|
self.assertNotIn(materialized_signature, tested_signatures)
|
||||||
self.assertIsNone(build_harness_stop_proposal(context))
|
self.assertIsNone(build_harness_stop_proposal(context))
|
||||||
|
|
||||||
def test_harness_validates_unmeasured_tp_frontier_before_runtime_refinement(self) -> None:
|
def test_harness_validates_unmeasured_tp_frontier_before_runtime_refinement(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user