Add declarative harness prototype
This commit is contained in:
156
tests/test_declarative_harness.py
Normal file
156
tests/test_declarative_harness.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from aituner.declarative_harness import (
|
||||
AxisSpec,
|
||||
CoverageState,
|
||||
HarnessPolicy,
|
||||
OperatorSpec,
|
||||
config_signature,
|
||||
coverage_unit_id,
|
||||
enumerate_candidate_set,
|
||||
ordered_lattice_failure_region,
|
||||
validate_coverage_stop,
|
||||
)
|
||||
|
||||
|
||||
class DeclarativeHarnessTests(unittest.TestCase):
|
||||
def test_same_state_grammar_policy_candidate_set_is_deterministic(self) -> None:
|
||||
axes = (
|
||||
AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4)),
|
||||
AxisSpec(name="gmu", kind="bounded_numeric", floor=0.7, ceiling=0.95, step=0.05),
|
||||
)
|
||||
policy = HarnessPolicy(
|
||||
operators=(
|
||||
OperatorSpec(name="runtime_climb", axis="gmu", kind="local_climb", harness_priority=1),
|
||||
OperatorSpec(name="topology_bracket", axis="tp", kind="bracket", harness_priority=5),
|
||||
OperatorSpec(name="runtime_floor", axis="gmu", kind="jump_to_floor", harness_priority=2),
|
||||
)
|
||||
)
|
||||
|
||||
first = enumerate_candidate_set({"tp": 2, "gmu": 0.8}, axes, policy)
|
||||
second = enumerate_candidate_set({"gmu": 0.8, "tp": 2}, axes, policy)
|
||||
|
||||
self.assertEqual(first.candidate_set_hash, second.candidate_set_hash)
|
||||
self.assertEqual(
|
||||
[candidate.action_id for candidate in first.eligible],
|
||||
[candidate.action_id for candidate in second.eligible],
|
||||
)
|
||||
self.assertEqual(
|
||||
[blocked.reason for blocked in first.blocked],
|
||||
[blocked.reason for blocked in second.blocked],
|
||||
)
|
||||
self.assertTrue(all(candidate.planner_score is None for candidate in first.eligible))
|
||||
self.assertTrue(all(candidate.backend_score is None for candidate in first.eligible))
|
||||
|
||||
def test_toy_lattice_bracket_enumerates_all_other_lattice_points(self) -> None:
|
||||
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
|
||||
policy = HarnessPolicy(
|
||||
operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
|
||||
)
|
||||
|
||||
candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy)
|
||||
|
||||
self.assertEqual({candidate.target_value for candidate in candidate_set.eligible}, {1, 4, 8})
|
||||
self.assertEqual(candidate_set.blocked, ())
|
||||
|
||||
def test_no_repeat_blocks_exact_candidate_signature_and_records_reason(self) -> None:
|
||||
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4))
|
||||
policy = HarnessPolicy(operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),))
|
||||
tested = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))
|
||||
|
||||
candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy, tested)
|
||||
|
||||
self.assertEqual(candidate_set.eligible, ())
|
||||
self.assertEqual(len(candidate_set.blocked), 1)
|
||||
self.assertEqual(candidate_set.blocked[0].candidate.target_value, 4)
|
||||
self.assertEqual(candidate_set.blocked[0].reason, "no_repeat: signature already tested")
|
||||
|
||||
def test_ordered_lattice_upper_boundary_uses_axis_values_not_hard_coded_tp8(self) -> None:
|
||||
for values in ((1, 3, 9), (2, 5, 10, 20)):
|
||||
with self.subTest(values=values):
|
||||
axis = AxisSpec(name="parallel_size", kind="ordered_lattice", values=values)
|
||||
policy = HarnessPolicy(
|
||||
operators=(OperatorSpec(name="step", axis=axis.name, kind="step_up"),)
|
||||
)
|
||||
|
||||
candidate_set = enumerate_candidate_set({axis.name: values[-1]}, (axis,), policy)
|
||||
|
||||
self.assertEqual(candidate_set.eligible, ())
|
||||
self.assertEqual(len(candidate_set.blocked), 1)
|
||||
self.assertEqual(candidate_set.blocked[0].reason, "ordered_lattice_upper_boundary")
|
||||
self.assertEqual(candidate_set.blocked[0].candidate.source_value, values[-1])
|
||||
|
||||
def test_bounded_numeric_jump_to_floor_uses_declared_floor_not_fixed_gmu_values(self) -> None:
|
||||
for current, floor, ceiling in ((0.2, 0.6, 0.95), (0.77, 0.83, 0.91)):
|
||||
with self.subTest(current=current, floor=floor, ceiling=ceiling):
|
||||
axis = AxisSpec(
|
||||
name="memory_fraction",
|
||||
kind="bounded_numeric",
|
||||
floor=floor,
|
||||
ceiling=ceiling,
|
||||
step=0.02,
|
||||
)
|
||||
policy = HarnessPolicy(
|
||||
operators=(OperatorSpec(name="floor", axis="memory_fraction", kind="jump_to_floor"),)
|
||||
)
|
||||
|
||||
candidate_set = enumerate_candidate_set({"memory_fraction": current}, (axis,), policy)
|
||||
|
||||
self.assertEqual(len(candidate_set.eligible), 1)
|
||||
self.assertEqual(candidate_set.eligible[0].target_value, floor)
|
||||
self.assertEqual(candidate_set.eligible[0].patch, {"memory_fraction": floor})
|
||||
|
||||
def test_coverage_stop_does_not_treat_signature_tested_as_coverage(self) -> None:
|
||||
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2))
|
||||
required_unit = coverage_unit_id("tp", "step_up", 2)
|
||||
policy = HarnessPolicy(
|
||||
operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),),
|
||||
required_coverage_unit_ids=frozenset({required_unit}),
|
||||
)
|
||||
candidate = enumerate_candidate_set({"tp": 1}, (axis,), policy).eligible[0]
|
||||
coverage_state = CoverageState(tested_signatures=frozenset({candidate.signature}))
|
||||
candidate_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, coverage_state)
|
||||
|
||||
stop = validate_coverage_stop(candidate_set, policy, coverage_state)
|
||||
|
||||
self.assertEqual(candidate_set.eligible, ())
|
||||
self.assertEqual(stop.candidate_set_hash, candidate_set.candidate_set_hash)
|
||||
self.assertFalse(stop.should_stop)
|
||||
self.assertEqual(stop.reason, "coverage_units_missing")
|
||||
self.assertEqual(stop.uncovered_unit_ids, (required_unit,))
|
||||
|
||||
def test_failure_invalidation_uses_conservative_region_not_exact_signature_only(self) -> None:
|
||||
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
|
||||
policy = HarnessPolicy(
|
||||
operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
|
||||
)
|
||||
|
||||
exact_only = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))
|
||||
exact_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, exact_only)
|
||||
self.assertEqual({candidate.target_value for candidate in exact_set.eligible}, {2, 8})
|
||||
|
||||
region = ordered_lattice_failure_region(
|
||||
axis,
|
||||
4,
|
||||
direction="up",
|
||||
reason="launch_failure_at_or_above_parallel_size",
|
||||
)
|
||||
regional_set = enumerate_candidate_set(
|
||||
{"tp": 1},
|
||||
(axis,),
|
||||
policy,
|
||||
CoverageState(failed_regions=(region,)),
|
||||
)
|
||||
|
||||
self.assertEqual({candidate.target_value for candidate in regional_set.eligible}, {2})
|
||||
blocked_targets = {blocked.candidate.target_value for blocked in regional_set.blocked}
|
||||
self.assertTrue({4, 8}.issubset(blocked_targets))
|
||||
self.assertTrue(
|
||||
all("failure_region:tp:ge:4" in blocked.reason for blocked in regional_set.blocked)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user