Files
aituner/scripts/stop_a_validate.py
2026-06-15 15:22:48 +08:00

102 lines
4.2 KiB
Python

#!/usr/bin/env python3
"""Validate Stop-A truncation fidelity from a full-replay trial's probe_details.
Given a completed trial that replayed the full window (adaptive_stop disabled),
recompute the L-C-A convergence prefix for each probe and compare the feasibility
verdict / pass rate of the truncated prefix against the full probe. This answers
"would Stop-A have changed the measured peak sustainable rate?" using only the
one full run (no second GPU run needed).

Example:
    PYTHONPATH=src python3 scripts/stop_a_validate.py \
        --spec configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json \
        --store-root .aituner/stopA-fulldata --tau 0.9 --tau-c 0.90
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from aituner.lca import find_convergence_prefix, resolve_length_mode
from aituner.spec import load_study_spec
from aituner.trace import load_trace_requests, select_requests_for_threshold
def _arrival_ordered(outcomes: list[dict]) -> list[dict]:
    """Return the outcomes that carry an ``arrival_s``, sorted by arrival time."""
    timed = [o for o in outcomes if o.get("arrival_s") is not None]
    timed.sort(key=lambda o: float(o["arrival_s"]))
    return timed


def _pass_rate(outcomes: list[dict]) -> float:
    """Return the fraction of *outcomes* whose ``evaluation`` field is truthy.

    Outcomes without an ``evaluation`` key count as failures. The caller must
    pass a non-empty list.
    """
    return sum(1 for o in outcomes if o.get("evaluation")) / len(outcomes)


def main() -> int:
    """Compare full-probe vs. convergence-prefix feasibility for every probe.

    Reads every ``*/trials/*/probe_details.jsonl`` under ``--store-root``. For
    each probe line it recomputes the convergence prefix for the probe's
    threshold, maps the prefix fraction onto the arrival-ordered outcomes, and
    prints one table row comparing the pass rate / feasibility verdict of the
    prefix against the full probe. A summary line follows the table.

    Returns:
        0 when at least one probe was compared; 1 when no ``probe_details.jsonl``
        file exists or none of the probes contained an outcome with an
        ``arrival_s`` (nothing could be validated).
    """
    ap = argparse.ArgumentParser(
        description=(
            "Validate Stop-A truncation fidelity from a full-replay trial's "
            "probe_details."
        ),
    )
    ap.add_argument("--spec", type=Path, required=True)
    ap.add_argument("--store-root", type=Path, required=True)
    ap.add_argument("--tau", type=float, default=0.9)
    ap.add_argument("--tau-c", type=float, default=0.90)
    ap.add_argument("--stable-checks", type=int, default=3)
    ap.add_argument("--target-pass-rate", type=float, default=0.95)
    args = ap.parse_args()

    study = load_study_spec(args.spec)
    window, requests = load_trace_requests(study, study_spec_path=args.spec)
    mode = resolve_length_mode(request_mode=study.trace.request_mode)
    gpu_count = study.hardware.gpu_count

    detail_files = sorted(args.store_root.glob("*/trials/*/probe_details.jsonl"))
    if not detail_files:
        print(f"no probe_details.jsonl under {args.store_root}")
        return 1

    print(f"target_pass_rate={args.target_pass_rate} tau={args.tau} tau_c={args.tau_c}")
    print(
        "thresh n_full stop_idx frac full_pass prefix_pass "
        "full_feas prefix_feas verdict_match"
    )

    mismatches = 0
    total = 0
    saved_fractions = []
    for detail_file in detail_files:
        with detail_file.open(encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    continue
                probe = json.loads(line)
                threshold = float(probe["threshold"])
                # Only outcomes with an arrival_s can be placed on the replay
                # timeline; the rest are ignored for both pass rates.
                ordered = _arrival_ordered(probe.get("outcomes") or [])
                n = len(ordered)
                if n == 0:
                    continue

                selected = select_requests_for_threshold(requests, threshold=threshold)
                cp = find_convergence_prefix(
                    selected,
                    window,
                    gpu_count=gpu_count,
                    length_mode=mode,
                    tau=args.tau,
                    tau_c=args.tau_c,
                    stable_checks=args.stable_checks,
                )

                # Map the convergence prefix fraction onto the replayed
                # outcomes, clamped so the prefix holds between 1 and n items.
                stop_n = max(1, min(n, round(cp.fraction * n)))
                full_pass = _pass_rate(ordered)
                prefix_pass = _pass_rate(ordered[:stop_n])
                full_feas = full_pass >= args.target_pass_rate
                prefix_feas = prefix_pass >= args.target_pass_rate
                match = full_feas == prefix_feas

                total += 1
                if not match:
                    mismatches += 1
                saved_fractions.append(1.0 - cp.fraction)
                print(
                    f"{threshold:.5f} {n:6d} {stop_n:7d} {cp.fraction:.2f} "
                    f"{full_pass:.3f} {prefix_pass:.3f} "
                    f"{str(full_feas):5s} {str(prefix_feas):5s} {match}"
                )

    if not total:
        # Files existed but nothing was comparable: do not report success for
        # a validation that validated nothing.
        print(f"no probes with arrival-stamped outcomes under {args.store_root}")
        return 1

    avg_saved = sum(saved_fractions) / len(saved_fractions)
    print(
        f"\nverdict matches: {total - mismatches}/{total} "
        f"mean replay saved: {avg_saved*100:.0f}%"
    )
    return 0
if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)