157 lines
7.2 KiB
Python
157 lines
7.2 KiB
Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
|
|
from aituner.declarative_harness import (
|
|
AxisSpec,
|
|
CoverageState,
|
|
HarnessPolicy,
|
|
OperatorSpec,
|
|
config_signature,
|
|
coverage_unit_id,
|
|
enumerate_candidate_set,
|
|
ordered_lattice_failure_region,
|
|
validate_coverage_stop,
|
|
)
|
|
|
|
|
|
class DeclarativeHarnessTests(unittest.TestCase):
    """Tests for candidate enumeration, no-repeat blocking, coverage stops and failure regions."""

    @staticmethod
    def _eligible_targets(candidate_set) -> set[object]:
        """Return the set of ``target_value``s proposed by ``candidate_set.eligible``."""
        return {candidate.target_value for candidate in candidate_set.eligible}

    def test_same_state_grammar_policy_candidate_set_is_deterministic(self) -> None:
        """Identical state, axes and policy must produce an identical candidate set.

        The two states hold the same values with different dict insertion order,
        so this also checks that enumeration does not depend on key order.
        """
        axes = (
            AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4)),
            AxisSpec(name="gmu", kind="bounded_numeric", floor=0.7, ceiling=0.95, step=0.05),
        )
        policy = HarnessPolicy(
            operators=(
                OperatorSpec(name="runtime_climb", axis="gmu", kind="local_climb", harness_priority=1),
                OperatorSpec(name="topology_bracket", axis="tp", kind="bracket", harness_priority=5),
                OperatorSpec(name="runtime_floor", axis="gmu", kind="jump_to_floor", harness_priority=2),
            )
        )

        first = enumerate_candidate_set({"tp": 2, "gmu": 0.8}, axes, policy)
        second = enumerate_candidate_set({"gmu": 0.8, "tp": 2}, axes, policy)

        # all() over an empty iterable is True, so without this guard an
        # enumerator that proposed nothing would satisfy every check below.
        self.assertTrue(first.eligible, "expected at least one eligible candidate")

        self.assertEqual(first.candidate_set_hash, second.candidate_set_hash)
        self.assertEqual(
            [candidate.action_id for candidate in first.eligible],
            [candidate.action_id for candidate in second.eligible],
        )
        self.assertEqual(
            [blocked.reason for blocked in first.blocked],
            [blocked.reason for blocked in second.blocked],
        )

        # Enumerated candidates must come back unscored.
        for candidate in first.eligible:
            with self.subTest(action_id=candidate.action_id):
                self.assertIsNone(candidate.planner_score)
                self.assertIsNone(candidate.backend_score)

    def test_toy_lattice_bracket_enumerates_all_other_lattice_points(self) -> None:
        """A bracket on an ordered lattice proposes every lattice value except the current one."""
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
        policy = HarnessPolicy(
            operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
        )

        candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy)

        self.assertEqual(self._eligible_targets(candidate_set), {1, 4, 8})
        self.assertEqual(candidate_set.blocked, ())

    def test_no_repeat_blocks_exact_candidate_signature_and_records_reason(self) -> None:
        """A candidate whose config signature was already tested is blocked with a no_repeat reason."""
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4))
        policy = HarnessPolicy(operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),))
        tested = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))

        candidate_set = enumerate_candidate_set({"tp": 2}, (axis,), policy, tested)

        self.assertEqual(candidate_set.eligible, ())
        self.assertEqual(len(candidate_set.blocked), 1)
        self.assertEqual(candidate_set.blocked[0].candidate.target_value, 4)
        self.assertEqual(candidate_set.blocked[0].reason, "no_repeat: signature already tested")

    def test_ordered_lattice_upper_boundary_uses_axis_values_not_hard_coded_tp8(self) -> None:
        """Stepping up from the last declared lattice value is blocked at the boundary.

        The lattices used here do not end at 8, so a hard-coded tp=8 limit would fail.
        """
        for values in ((1, 3, 9), (2, 5, 10, 20)):
            with self.subTest(values=values):
                axis = AxisSpec(name="parallel_size", kind="ordered_lattice", values=values)
                policy = HarnessPolicy(
                    operators=(OperatorSpec(name="step", axis=axis.name, kind="step_up"),)
                )

                candidate_set = enumerate_candidate_set({axis.name: values[-1]}, (axis,), policy)

                self.assertEqual(candidate_set.eligible, ())
                self.assertEqual(len(candidate_set.blocked), 1)
                self.assertEqual(candidate_set.blocked[0].reason, "ordered_lattice_upper_boundary")
                self.assertEqual(candidate_set.blocked[0].candidate.source_value, values[-1])

    def test_bounded_numeric_jump_to_floor_uses_declared_floor_not_fixed_gmu_values(self) -> None:
        """jump_to_floor targets the axis's declared floor for arbitrary floor/ceiling bounds.

        In both cases the current value starts below the declared floor.
        """
        for current, floor, ceiling in ((0.2, 0.6, 0.95), (0.77, 0.83, 0.91)):
            with self.subTest(current=current, floor=floor, ceiling=ceiling):
                axis = AxisSpec(
                    name="memory_fraction",
                    kind="bounded_numeric",
                    floor=floor,
                    ceiling=ceiling,
                    step=0.02,
                )
                policy = HarnessPolicy(
                    operators=(OperatorSpec(name="floor", axis="memory_fraction", kind="jump_to_floor"),)
                )

                candidate_set = enumerate_candidate_set({"memory_fraction": current}, (axis,), policy)

                self.assertEqual(len(candidate_set.eligible), 1)
                # Exact comparison on purpose: the target must be the declared
                # floor value itself, not a recomputed approximation.
                self.assertEqual(candidate_set.eligible[0].target_value, floor)
                self.assertEqual(candidate_set.eligible[0].patch, {"memory_fraction": floor})

    def test_coverage_stop_does_not_treat_signature_tested_as_coverage(self) -> None:
        """A tested signature blocks its candidate but does not satisfy a required coverage unit.

        With the only candidate blocked, the stop check must still refuse to stop
        and report the required unit as uncovered.
        """
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2))
        required_unit = coverage_unit_id("tp", "step_up", 2)
        policy = HarnessPolicy(
            operators=(OperatorSpec(name="step", axis="tp", kind="step_up"),),
            required_coverage_unit_ids=frozenset({required_unit}),
        )

        # Fail with a clear message rather than an IndexError if enumeration
        # unexpectedly produces no (or several) step_up candidates.
        fresh_set = enumerate_candidate_set({"tp": 1}, (axis,), policy)
        self.assertEqual(len(fresh_set.eligible), 1)
        candidate = fresh_set.eligible[0]

        coverage_state = CoverageState(tested_signatures=frozenset({candidate.signature}))
        candidate_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, coverage_state)

        stop = validate_coverage_stop(candidate_set, policy, coverage_state)

        self.assertEqual(candidate_set.eligible, ())
        self.assertEqual(stop.candidate_set_hash, candidate_set.candidate_set_hash)
        self.assertFalse(stop.should_stop)
        self.assertEqual(stop.reason, "coverage_units_missing")
        self.assertEqual(stop.uncovered_unit_ids, (required_unit,))

    def test_failure_invalidation_uses_conservative_region_not_exact_signature_only(self) -> None:
        """A failure region blocks every lattice value at or above its anchor.

        Exact-signature history blocks only the tested point. An upward
        failure region anchored at tp=4 must also block tp=8.
        """
        axis = AxisSpec(name="tp", kind="ordered_lattice", values=(1, 2, 4, 8))
        policy = HarnessPolicy(
            operators=(OperatorSpec(name="topology_bracket", axis="tp", kind="bracket"),)
        )

        exact_only = CoverageState(tested_signatures=frozenset({config_signature({"tp": 4})}))
        exact_set = enumerate_candidate_set({"tp": 1}, (axis,), policy, exact_only)
        self.assertEqual(self._eligible_targets(exact_set), {2, 8})

        region = ordered_lattice_failure_region(
            axis,
            4,
            direction="up",
            reason="launch_failure_at_or_above_parallel_size",
        )
        regional_set = enumerate_candidate_set(
            {"tp": 1},
            (axis,),
            policy,
            CoverageState(failed_regions=(region,)),
        )

        self.assertEqual(self._eligible_targets(regional_set), {2})

        blocked_targets = {blocked.candidate.target_value for blocked in regional_set.blocked}
        # Set difference so a failure shows exactly which targets were not blocked.
        self.assertEqual({4, 8} - blocked_targets, set())

        for blocked in regional_set.blocked:
            with self.subTest(target=blocked.candidate.target_value):
                self.assertIn("failure_region:tp:ge:4", blocked.reason)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution (`python <this file>`) runs every test case defined above
    # through unittest's command-line runner.
    unittest.main()
|