Add declarative harness prototype

This commit is contained in:
2026-06-26 18:07:02 +08:00
parent 4075c7abf0
commit 384cb58f1f
5 changed files with 752 additions and 1 deletions

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import unittest
from aituner.declarative_harness import (
AxisSpec,
CoverageState,
HarnessPolicy,
OperatorSpec,
config_signature,
coverage_unit_id,
enumerate_candidate_set,
ordered_lattice_failure_region,
validate_coverage_stop,
)
class DeclarativeHarnessTests(unittest.TestCase):
def test_same_state_grammar_policy_candidate_set_is_deterministic(self) -> None:
axes = (
AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4)),
AxisSpec(name="gmu", kind="bounded_numeric", floor=0.7, ceiling=0.95, step=0.05),
)
policy = HarnessPolicy(
operators=(
OperatorSpec(name="runtime_climb", axis="gmu", kind="local_climb", harness_priority=1),
OperatorSpec(name="topology_bracket", axis="tp", kind="bracket", harness_priority=5),
OperatorSpec(name="runtime_floor", axis="gmu", kind="jump_to_floor", harness_priority=2),
)
)
first = enumerate_candidate_set({"tp": 2, "gmu": 0.8}, axes, policy)
second = enumerate_candidate_set({"gmu": 0.8, "tp": 2}, axes, policy)
self.assertEqual(first.candidate_set_hash, second.candidate_set_hash)
self.assertEqual(
[candidate.action_id for candidate in first.eligible],
[candidate.action_id for candidate in second.eligible],
)
self.assertEqual(
[blocked.reason for blocked in first.blocked],
[blocked.reason for blocked in second.blocked],
)
self.assertTrue(all(candidate.planner_score is None for candidate in first.eligible))
self.assertTrue(all(candidate.backend_score is None for candidate in first.eligible))
def test_toy_lattice_bracket_enumerates_all_other_lattice_points(self) -> None:
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
policy = HarnessPolicy(
operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
)
candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy)
self.assertEqual({candidate.target_value for candidate in candidate_set.eligible}, {1, 4, 8})
self.assertEqual(candidate_set.blocked, ())
def test_no_repeat_blocks_exact_candidate_signature_and_records_reason(self) -> None:
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4))
policy = HarnessPolicy(operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),))
tested = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))
candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy, tested)
self.assertEqual(candidate_set.eligible, ())
self.assertEqual(len(candidate_set.blocked), 1)
self.assertEqual(candidate_set.blocked[0].candidate.target_value, 4)
self.assertEqual(candidate_set.blocked[0].reason, "no_repeat: signature already tested")
def test_ordered_lattice_upper_boundary_uses_axis_values_not_hard_coded_tp8(self) -> None:
for values in ((1, 3, 9), (2, 5, 10, 20)):
with self.subTest(values=values):
axis = AxisSpec(name="parallel_size", kind="ordered_lattice", values=values)
policy = HarnessPolicy(
operators=(OperatorSpec(name="step", axis=axis.name, kind="step_up"),)
)
candidate_set = enumerate_candidate_set({axis.name: values[-1]}, (axis,), policy)
self.assertEqual(candidate_set.eligible, ())
self.assertEqual(len(candidate_set.blocked), 1)
self.assertEqual(candidate_set.blocked[0].reason, "ordered_lattice_upper_boundary")
self.assertEqual(candidate_set.blocked[0].candidate.source_value, values[-1])
def test_bounded_numeric_jump_to_floor_uses_declared_floor_not_fixed_gmu_values(self) -> None:
for current, floor, ceiling in ((0.2, 0.6, 0.95), (0.77, 0.83, 0.91)):
with self.subTest(current=current, floor=floor, ceiling=ceiling):
axis = AxisSpec(
name="memory_fraction",
kind="bounded_numeric",
floor=floor,
ceiling=ceiling,
step=0.02,
)
policy = HarnessPolicy(
operators=(OperatorSpec(name="floor", axis="memory_fraction", kind="jump_to_floor"),)
)
candidate_set = enumerate_candidate_set({"memory_fraction": current}, (axis,), policy)
self.assertEqual(len(candidate_set.eligible), 1)
self.assertEqual(candidate_set.eligible[0].target_value, floor)
self.assertEqual(candidate_set.eligible[0].patch, {"memory_fraction": floor})
def test_coverage_stop_does_not_treat_signature_tested_as_coverage(self) -> None:
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2))
required_unit = coverage_unit_id("tp", "step_up", 2)
policy = HarnessPolicy(
operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),),
required_coverage_unit_ids=frozenset({required_unit}),
)
candidate = enumerate_candidate_set({"tp": 1}, (axis,), policy).eligible[0]
coverage_state = CoverageState(tested_signatures=frozenset({candidate.signature}))
candidate_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, coverage_state)
stop = validate_coverage_stop(candidate_set, policy, coverage_state)
self.assertEqual(candidate_set.eligible, ())
self.assertEqual(stop.candidate_set_hash, candidate_set.candidate_set_hash)
self.assertFalse(stop.should_stop)
self.assertEqual(stop.reason, "coverage_units_missing")
self.assertEqual(stop.uncovered_unit_ids, (required_unit,))
def test_failure_invalidation_uses_conservative_region_not_exact_signature_only(self) -> None:
axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
policy = HarnessPolicy(
operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
)
exact_only = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))
exact_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, exact_only)
self.assertEqual({candidate.target_value for candidate in exact_set.eligible}, {2, 8})
region = ordered_lattice_failure_region(
axis,
4,
direction="up",
reason="launch_failure_at_or_above_parallel_size",
)
regional_set = enumerate_candidate_set(
{"tp": 1},
(axis,),
policy,
CoverageState(failed_regions=(region,)),
)
self.assertEqual({candidate.target_value for candidate in regional_set.eligible}, {2})
blocked_targets = {blocked.candidate.target_value for blocked in regional_set.blocked}
self.assertTrue({4, 8}.issubset(blocked_targets))
self.assertTrue(
all("failure_region:tp:ge:4" in blocked.reason for blocked in regional_set.blocked)
)
if __name__ == "__main__":
unittest.main()