#!/usr/bin/env python3 """Validate Stop-A truncation fidelity from a full-replay trial's probe_details. Given a completed trial that replayed the full window (adaptive_stop disabled), for each probe recompute the L-C-A convergence prefix and compare the feasibility verdict / pass-rate of the truncated prefix against the full probe. This answers: "would Stop-A have changed the measured peak-sustainable-rate?" using only the one full run (no second GPU run needed). Example: PYTHONPATH=src python3 scripts/stop_a_validate.py \ --spec configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json \ --store-root .aituner/stopA-fulldata --tau 0.9 --tau-c 0.90 """ from __future__ import annotations import argparse import json from pathlib import Path from aituner.lca import find_convergence_prefix, resolve_length_mode from aituner.spec import load_study_spec from aituner.trace import load_trace_requests, select_requests_for_threshold def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--spec", type=Path, required=True) ap.add_argument("--store-root", type=Path, required=True) ap.add_argument("--tau", type=float, default=0.9) ap.add_argument("--tau-c", type=float, default=0.90) ap.add_argument("--stable-checks", type=int, default=3) ap.add_argument("--target-pass-rate", type=float, default=0.95) args = ap.parse_args() study = load_study_spec(args.spec) window, requests = load_trace_requests(study, study_spec_path=args.spec) mode = resolve_length_mode(request_mode=study.trace.request_mode) gpu_count = study.hardware.gpu_count detail_files = sorted(args.store_root.glob("*/trials/*/probe_details.jsonl")) if not detail_files: print(f"no probe_details.jsonl under {args.store_root}") return 1 print(f"target_pass_rate={args.target_pass_rate} tau={args.tau} tau_c={args.tau_c}") print( "thresh n_full stop_idx frac full_pass prefix_pass " "full_feas prefix_feas verdict_match" ) mismatches = 0 total = 0 saved_fractions = [] for detail_file in detail_files: with detail_file.open(encoding="utf-8") as handle: for line in handle: line = line.strip() if not line: continue probe = json.loads(line) threshold = float(probe["threshold"]) outcomes = probe.get("outcomes") or [] # arrival-ordered outcomes that carry an arrival_s and verdict ordered = sorted( (o for o in outcomes if o.get("arrival_s") is not None), key=lambda o: float(o["arrival_s"]), ) n = len(ordered) if n == 0: continue selected = select_requests_for_threshold(requests, threshold=threshold) cp = find_convergence_prefix( selected, window, gpu_count=gpu_count, length_mode=mode, tau=args.tau, tau_c=args.tau_c, stable_checks=args.stable_checks, ) # Map the convergence prefix fraction onto the replayed outcomes. stop_n = max(1, min(n, round(cp.fraction * n))) full_pass = sum(1 for o in ordered if o.get("evaluation")) / n prefix_pass = sum(1 for o in ordered[:stop_n] if o.get("evaluation")) / stop_n full_feas = full_pass >= args.target_pass_rate prefix_feas = prefix_pass >= args.target_pass_rate match = full_feas == prefix_feas total += 1 mismatches += 0 if match else 1 saved_fractions.append(1.0 - cp.fraction) print( f"{threshold:.5f} {n:6d} {stop_n:7d} {cp.fraction:.2f} " f"{full_pass:.3f} {prefix_pass:.3f} " f"{str(full_feas):5s} {str(prefix_feas):5s} {match}" ) if total: avg_saved = sum(saved_fractions) / len(saved_fractions) print( f"\nverdict matches: {total - mismatches}/{total} " f"mean replay saved: {avg_saved*100:.0f}%" ) return 0 if __name__ == "__main__": raise SystemExit(main())