from __future__ import annotations

import unittest

from aituner.declarative_harness import (
    AxisSpec,
    CoverageState,
    HarnessPolicy,
    OperatorSpec,
    config_signature,
    coverage_unit_id,
    enumerate_candidate_set,
    ordered_lattice_failure_region,
    validate_coverage_stop,
)


class DeclarativeHarnessTests(unittest.TestCase):
    """Behavioral tests for declarative candidate enumeration and coverage stops.

    Each test declares axes and a ``HarnessPolicy`` and then checks which
    candidates ``enumerate_candidate_set`` marks eligible or blocked, and why.
    """

    def test_same_state_grammar_policy_candidate_set_is_deterministic(self) -> None:
        """Same state, axes and policy give the same candidate set regardless of key order."""
        axes = (
            AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4)),
            AxisSpec(name="gmu", kind="bounded_numeric", floor=0.7, ceiling=0.95, step=0.05),
        )
        policy = HarnessPolicy(
            operators=(
                OperatorSpec(name="runtime_climb", axis="gmu", kind="local_climb", harness_priority=1),
                OperatorSpec(name="topology_bracket", axis="tp", kind="bracket", harness_priority=5),
                OperatorSpec(name="runtime_floor", axis="gmu", kind="jump_to_floor", harness_priority=2),
            )
        )

        # Identical config, different dict insertion order.
        first = enumerate_candidate_set({"tp": 2, "gmu": 0.8}, axes, policy)
        second = enumerate_candidate_set({"gmu": 0.8, "tp": 2}, axes, policy)

        # Guard against vacuous success: the ``all(...)`` checks below (and the
        # list comparisons) would pass trivially on an empty eligible set.
        self.assertTrue(first.eligible, "expected at least one eligible candidate")

        self.assertEqual(first.candidate_set_hash, second.candidate_set_hash)
        self.assertEqual(
            [candidate.action_id for candidate in first.eligible],
            [candidate.action_id for candidate in second.eligible],
        )
        self.assertEqual(
            [blocked.reason for blocked in first.blocked],
            [blocked.reason for blocked in second.blocked],
        )
        # Enumeration itself must not attach any planner/backend scoring.
        self.assertTrue(all(candidate.planner_score is None for candidate in first.eligible))
        self.assertTrue(all(candidate.backend_score is None for candidate in first.eligible))

    def test_toy_lattice_bracket_enumerates_all_other_lattice_points(self) -> None:
        """A bracket operator proposes every lattice value except the current one."""
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
        policy = HarnessPolicy(
            operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
        )

        candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy)

        self.assertEqual({candidate.target_value for candidate in candidate_set.eligible}, {1, 4, 8})
        self.assertEqual(candidate_set.blocked, ())

    def test_no_repeat_blocks_exact_candidate_signature_and_records_reason(self) -> None:
        """A candidate whose signature was already tested is blocked with a no_repeat reason."""
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4))
        policy = HarnessPolicy(operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),))
        tested = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))

        candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy, tested)

        self.assertEqual(candidate_set.eligible, ())
        self.assertEqual(len(candidate_set.blocked), 1)
        self.assertEqual(candidate_set.blocked[0].candidate.target_value, 4)
        self.assertEqual(candidate_set.blocked[0].reason, "no_repeat: signature already tested")

    def test_ordered_lattice_upper_boundary_uses_axis_values_not_hard_coded_tp8(self) -> None:
        """Stepping up from the last declared lattice value is blocked, for arbitrary lattices."""
        for values in ((1, 3, 9), (2, 5, 10, 20)):
            with self.subTest(values=values):
                axis = AxisSpec(name="parallel_size", kind="ordered_lattice", values=values)
                policy = HarnessPolicy(
                    operators=(OperatorSpec(name="step", axis=axis.name, kind="step_up"),)
                )

                # Start at the top of the declared lattice, so there is nowhere to step to.
                candidate_set = enumerate_candidate_set({axis.name: values[-1]}, (axis,), policy)

                self.assertEqual(candidate_set.eligible, ())
                self.assertEqual(len(candidate_set.blocked), 1)
                self.assertEqual(candidate_set.blocked[0].reason, "ordered_lattice_upper_boundary")
                self.assertEqual(candidate_set.blocked[0].candidate.source_value, values[-1])

    def test_bounded_numeric_jump_to_floor_uses_declared_floor_not_fixed_gmu_values(self) -> None:
        """jump_to_floor targets the axis's declared floor, whatever the current value is."""
        # The first case starts below the floor, the second between floor and ceiling.
        for current, floor, ceiling in ((0.2, 0.6, 0.95), (0.77, 0.83, 0.91)):
            with self.subTest(current=current, floor=floor, ceiling=ceiling):
                axis = AxisSpec(
                    name="memory_fraction",
                    kind="bounded_numeric",
                    floor=floor,
                    ceiling=ceiling,
                    step=0.02,
                )
                policy = HarnessPolicy(
                    operators=(OperatorSpec(name="floor", axis="memory_fraction", kind="jump_to_floor"),)
                )

                candidate_set = enumerate_candidate_set({"memory_fraction": current}, (axis,), policy)

                self.assertEqual(len(candidate_set.eligible), 1)
                self.assertEqual(candidate_set.eligible[0].target_value, floor)
                self.assertEqual(candidate_set.eligible[0].patch, {"memory_fraction": floor})

    def test_coverage_stop_does_not_treat_signature_tested_as_coverage(self) -> None:
        """A tested signature alone does not satisfy a required coverage unit."""
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2))
        required_unit = coverage_unit_id("tp", "step_up", 2)
        policy = HarnessPolicy(
            operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),),
            required_coverage_unit_ids=frozenset({required_unit}),
        )

        # Fail as an assertion (not an IndexError) if the baseline set is unexpectedly empty.
        baseline = enumerate_candidate_set({"tp": 1}, (axis,), policy)
        self.assertEqual(len(baseline.eligible), 1)
        candidate = baseline.eligible[0]

        # Mark the only candidate's signature as tested. That exhausts the eligible
        # set, but it does not record the required coverage unit.
        coverage_state = CoverageState(tested_signatures=frozenset({candidate.signature}))
        candidate_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, coverage_state)
        stop = validate_coverage_stop(candidate_set, policy, coverage_state)

        self.assertEqual(candidate_set.eligible, ())
        self.assertEqual(stop.candidate_set_hash, candidate_set.candidate_set_hash)
        self.assertFalse(stop.should_stop)
        self.assertEqual(stop.reason, "coverage_units_missing")
        self.assertEqual(stop.uncovered_unit_ids, (required_unit,))

    def test_failure_invalidation_uses_conservative_region_not_exact_signature_only(self) -> None:
        """A failure region blocks every lattice value at or above the failing one."""
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
        policy = HarnessPolicy(
            operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
        )

        # Baseline: an exact tested signature only removes that one point (tp=4).
        exact_only = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))
        exact_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, exact_only)
        self.assertEqual({candidate.target_value for candidate in exact_set.eligible}, {2, 8})

        # An upward failure region at tp=4 must also remove tp=8.
        region = ordered_lattice_failure_region(
            axis,
            4,
            direction="up",
            reason="launch_failure_at_or_above_parallel_size",
        )
        regional_set = enumerate_candidate_set(
            {"tp": 1},
            (axis,),
            policy,
            CoverageState(failed_regions=(region,)),
        )

        self.assertEqual({candidate.target_value for candidate in regional_set.eligible}, {2})
        blocked_targets = {blocked.candidate.target_value for blocked in regional_set.blocked}
        # Set ``<=`` is the subset test. Unlike assertTrue(...issubset(...)),
        # assertLessEqual reports both sets when it fails.
        self.assertLessEqual({4, 8}, blocked_targets)
        self.assertTrue(
            all("failure_region:tp:ge:4" in blocked.reason for blocked in regional_set.blocked)
        )


if __name__ == "__main__":
    unittest.main()