diff --git a/configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json b/configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json index 2bfafd0..787486d 100644 --- a/configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json +++ b/configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json @@ -14,7 +14,7 @@ "engine": { "engine_name": "vllm", "engine_version": "0.20.0", - "exec_path": "/tmp/wjh/venvs/vllm-0.20.0-cu129/bin/vllm", + "exec_path": "/usr/local/bin/vllm", "cwd": "/home/admin/cpfs/wjh/aituner/aituner", "host": "127.0.0.1", "port": 18230, @@ -33,7 +33,11 @@ "base_flags": { "host": "127.0.0.1", "port": 18230, - "served-model-name": "qwen3-30b-a3b-community" + "served-model-name": "qwen3-30b-a3b-community", + "gpu-memory-utilization": 0.9, + "max-model-len": 16384, + "trust-remote-code": true, + "enable-prefix-caching": true }, "tunable_envs": [], "tunable_flags": [ @@ -123,7 +127,7 @@ "low": 0.0, "high": 0.125, "tolerance": 0.001, - "max_probes": 5, + "max_probes": 4, "sample_seed": 20260325 }, "llm": { diff --git a/scripts/stop_a_validate.py b/scripts/stop_a_validate.py new file mode 100644 index 0000000..62f5b66 --- /dev/null +++ b/scripts/stop_a_validate.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Validate Stop-A truncation fidelity from a full-replay trial's probe_details. + +Given a completed trial that replayed the full window (adaptive_stop disabled), for +each probe recompute the L-C-A convergence prefix and compare the feasibility +verdict / pass-rate of the truncated prefix against the full probe. This answers: +"would Stop-A have changed the measured peak-sustainable-rate?" using only the one +full run (no second GPU run needed). 
# Example (usage text carried over from the module docstring):
#     PYTHONPATH=src python3 scripts/stop_a_validate.py \
#         --spec configs/examples/dash0_qwen30b_a3b_stopA_fulldata.json \
#         --store-root .aituner/stopA-fulldata --tau 0.9 --tau-c 0.90
from __future__ import annotations

import argparse
import json
from pathlib import Path

from aituner.lca import find_convergence_prefix, resolve_length_mode
from aituner.spec import load_study_spec
from aituner.trace import load_trace_requests, select_requests_for_threshold


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options (``argv=None`` means ``sys.argv[1:]``)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--spec", type=Path, required=True)
    ap.add_argument("--store-root", type=Path, required=True)
    ap.add_argument("--tau", type=float, default=0.9)
    ap.add_argument("--tau-c", type=float, default=0.90)
    ap.add_argument("--stable-checks", type=int, default=3)
    ap.add_argument("--target-pass-rate", type=float, default=0.95)
    return ap.parse_args(argv)


def _iter_probes(detail_file: Path):
    """Yield one decoded probe record per non-blank line of a JSONL file.

    Raises:
        ValueError: a line is not valid JSON; the message carries the file
            path and 1-based line number, and the original decode error is
            chained as the cause.
    """
    with detail_file.open(encoding="utf-8") as handle:
        for lineno, raw in enumerate(handle, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                probe = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{detail_file}:{lineno}: invalid JSON in probe record"
                ) from err
            yield probe


def _pass_rate(outcomes: list[dict]) -> float:
    """Return the share of ``outcomes`` whose "evaluation" field is truthy.

    The caller must pass a non-empty list.
    NOTE(review): "evaluation" is used as a truthy pass flag; confirm it is a
    bool and not a nested record (a non-empty dict would always count as pass).
    """
    return sum(1 for o in outcomes if o.get("evaluation")) / len(outcomes)


def main(argv: list[str] | None = None) -> int:
    """Compare truncated-prefix verdicts against full-replay verdicts.

    For every probe record found under ``--store-root`` the convergence prefix
    is recomputed with ``find_convergence_prefix``; its fraction is mapped onto
    the probe's arrival-ordered outcomes, and the pass rate / feasibility of
    that prefix is printed next to the pass rate / feasibility of the full
    probe.

    Args:
        argv: optional argument list; ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the script entry point does.

    Returns:
        0 after printing the summary; 1 when there is nothing to validate
        (no ``probe_details.jsonl`` file, or no probe with usable outcomes).

    Raises:
        ValueError: a probe_details line is not valid JSON.
        KeyError: a probe record has no "threshold" field.
    """
    args = _parse_args(argv)

    study = load_study_spec(args.spec)
    window, requests = load_trace_requests(study, study_spec_path=args.spec)
    mode = resolve_length_mode(request_mode=study.trace.request_mode)
    gpu_count = study.hardware.gpu_count

    detail_files = sorted(args.store_root.glob("*/trials/*/probe_details.jsonl"))
    if not detail_files:
        print(f"no probe_details.jsonl under {args.store_root}")
        return 1

    print(f"target_pass_rate={args.target_pass_rate} tau={args.tau} tau_c={args.tau_c}")
    print(
        "thresh n_full stop_idx frac full_pass prefix_pass "
        "full_feas prefix_feas verdict_match"
    )
    mismatches = 0
    saved_fractions: list[float] = []
    for detail_file in detail_files:
        for probe in _iter_probes(detail_file):
            threshold = float(probe["threshold"])
            outcomes = probe.get("outcomes") or []
            # Keep only outcomes that carry an arrival_s, ordered by arrival.
            ordered = sorted(
                (o for o in outcomes if o.get("arrival_s") is not None),
                key=lambda o: float(o["arrival_s"]),
            )
            n = len(ordered)
            if n == 0:
                continue
            selected = select_requests_for_threshold(requests, threshold=threshold)
            cp = find_convergence_prefix(
                selected,
                window,
                gpu_count=gpu_count,
                length_mode=mode,
                tau=args.tau,
                tau_c=args.tau_c,
                stable_checks=args.stable_checks,
            )
            # Map the convergence prefix fraction onto the replayed outcomes;
            # clamp to [1, n] so the prefix is never empty nor longer than
            # the probe itself.
            stop_n = max(1, min(n, round(cp.fraction * n)))
            full_pass = _pass_rate(ordered)
            prefix_pass = _pass_rate(ordered[:stop_n])
            full_feas = full_pass >= args.target_pass_rate
            prefix_feas = prefix_pass >= args.target_pass_rate
            match = full_feas == prefix_feas
            if not match:
                mismatches += 1
            saved_fractions.append(1.0 - cp.fraction)
            print(
                f"{threshold:.5f} {n:6d} {stop_n:7d} {cp.fraction:.2f} "
                f"{full_pass:.3f} {prefix_pass:.3f} "
                f"{str(full_feas):5s} {str(prefix_feas):5s} {match}"
            )

    total = len(saved_fractions)
    if total == 0:
        # Without this branch the run would print only the table header and
        # exit 0, which is indistinguishable from a successful validation.
        print(f"no probe with arrival-stamped outcomes under {args.store_root}")
        return 1

    avg_saved = sum(saved_fractions) / total
    print(
        f"\nverdict matches: {total - mismatches}/{total} "
        f"mean replay saved: {avg_saved*100:.0f}%"
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())