Files
aituner/scripts/stop_a_calibration.py
Gahow Wang 08e53fd897 Add Stop-A calibration script (CPU-only convergence curve)
Prints the offered-L-C-A convergence curve and the stop fraction at candidate
tau_c values for a raw trace window, to calibrate Stop-A thresholds and compare
how late C converges across workloads. No serving required.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:10:02 +08:00

129 lines
5.0 KiB
Python

#!/usr/bin/env python3
# Raw docstring so the backslash line continuations in the example survive in __doc__.
r"""Stop-A calibration: print the offered-L-C-A convergence curve for a raw trace window.

The convergence of prefix-vs-full L-C-A is a deterministic property of the trace
metadata (lengths, hash_ids, arrivals), so this runs on CPU without serving the
model. Use it to pick tau / tau_c / stable_checks and to compare how late the C
dimension converges across workloads (e.g. low-reuse chat vs high-reuse coder).

Example::

    PYTHONPATH=src python3 scripts/stop_a_calibration.py \
        --jsonl /dashscope/.../qwen_chat_blksz_64_032109-032111.jsonl \
        --block-size 64 --window-start 3600 --window-end 4200 --gpu-count 8 --label chat
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from aituner.lca import find_convergence_prefix, resolve_length_mode
from aituner.trace import TraceRequest, WindowRecord
def _session_root(row: dict, root_of: dict) -> object:
chat_id = row.get("chat_id")
parent = row.get("parent_chat_id")
parent_is_root = parent is None or (
isinstance(parent, (int, float)) and not isinstance(parent, bool) and int(parent) < 0
)
root = chat_id if parent_is_root else root_of.get(parent, parent)
if chat_id is not None:
root_of[chat_id] = root
return root
def load_window(jsonl: Path, *, window_start: float, window_end: float) -> list[TraceRequest]:
    """Load the trace rows whose timestamp falls in ``[window_start, window_end)``.

    Each kept row becomes a ``TraceRequest`` with:

    - ``row_id = str(chat_id)``;
    - ``arrival_s`` measured from ``window_start``;
    - ``sampling_u = 1.0`` and an empty ``body``;
    - prompt/completion hints taken from ``input_length`` / ``output_length``;
    - ``metadata["hash_ids"]`` set to the row's ``hash_ids`` only when it is a list.

    A missing or null ``timestamp`` is treated as ``0.0``. Blank lines are skipped.

    Args:
        jsonl: path to a JSON-lines trace (one JSON object per line, UTF-8).
        window_start: inclusive window start, in the trace's timestamp units.
        window_end: exclusive window end.

    Returns:
        The in-window requests sorted by ``arrival_s`` (stable for ties).

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON; the message
            is prefixed with ``<path>:<line>`` so the bad row can be located.
        ValueError: a line is valid JSON but not an object.
    """
    root_of: dict = {}
    requests: list[TraceRequest] = []
    with jsonl.open(encoding="utf-8") as handle:
        for lineno, raw in enumerate(handle, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type but say which file/line is broken.
                raise json.JSONDecodeError(
                    f"{jsonl}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            if not isinstance(row, dict):
                raise ValueError(
                    f"{jsonl}:{lineno}: expected a JSON object, got {type(row).__name__}"
                )
            # Keep the session chain complete even for rows outside the window.
            # NOTE(review): the returned root is discarded and root_of is not read
            # anywhere else here, so session roots never reach the TraceRequest.
            # Confirm whether they should be attached to metadata.
            _session_root(row, root_of)
            ts = float(row.get("timestamp") or 0.0)
            if not (window_start <= ts < window_end):
                continue
            hash_ids = row.get("hash_ids")
            requests.append(
                TraceRequest(
                    row_id=str(row.get("chat_id")),
                    arrival_s=ts - window_start,
                    sampling_u=1.0,
                    body={},
                    prompt_tokens_hint=int(row.get("input_length") or 0),
                    completion_tokens_hint=int(row.get("output_length") or 0),
                    metadata={"hash_ids": hash_ids if isinstance(hash_ids, list) else None},
                )
            )
    requests.sort(key=lambda item: item.arrival_s)
    return requests
# Candidate C-dimension thresholds probed by default in the stop-fraction table.
_DEFAULT_TAU_C_CANDIDATES = (0.85, 0.90, 0.92, 0.95)


def _build_parser() -> argparse.ArgumentParser:
    """Return the CLI parser.

    The original flags keep their names and defaults. ``--tau-c``,
    ``--stable-checks`` and ``--min-fraction`` expose values that were
    previously hard-coded; their defaults equal those old constants.
    """
    ap = argparse.ArgumentParser(
        description=(
            "Print the offered-L-C-A convergence curve and the Stop-A stop "
            "fraction at candidate tau_c values for a raw trace window."
        )
    )
    ap.add_argument("--jsonl", type=Path, required=True)
    ap.add_argument("--block-size", type=int, required=True)
    ap.add_argument("--window-start", type=float, default=3600.0)
    ap.add_argument("--window-end", type=float, default=4200.0)
    ap.add_argument("--gpu-count", type=int, default=8)
    ap.add_argument("--length-mode", default="total")
    ap.add_argument("--label", default="")
    ap.add_argument("--tau", type=float, default=0.9)
    ap.add_argument("--max-checks", type=int, default=20)
    ap.add_argument(
        "--tau-c",
        type=float,
        nargs="+",
        default=list(_DEFAULT_TAU_C_CANDIDATES),
        help="candidate tau_c values for the stop-fraction table",
    )
    ap.add_argument(
        "--stable-checks",
        type=int,
        default=3,
        help="stable_checks (W) used for the stop-fraction table",
    )
    ap.add_argument(
        "--min-fraction",
        type=float,
        default=0.05,
        help="min_fraction passed to every find_convergence_prefix call",
    )
    return ap


def _validate_args(ap: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject arguments that cannot describe a usable window.

    Exits through ``ap.error`` (status 2) on the first problem found.
    """
    if args.window_end <= args.window_start:
        ap.error(
            f"--window-end ({args.window_end}) must be greater than "
            f"--window-start ({args.window_start})"
        )
    for flag, value in (
        ("--block-size", args.block_size),
        ("--gpu-count", args.gpu_count),
        ("--stable-checks", args.stable_checks),
    ):
        if value <= 0:
            ap.error(f"{flag} must be positive (got {value})")


def _print_curve(checks) -> None:
    """Print one row per convergence check: fraction, time, and L/C/A similarity.

    Each check is read as a mapping with ``fraction``, ``time_s`` and a
    ``family_similarity`` mapping keyed by ``L``/``C``/``A``.
    """
    # Header widths mirror the row format below so the columns line up.
    print(f" {'frac':>4} {'time_s':>7} {'L':>5} {'C':>5} {'A':>5}")
    for c in checks:
        s = c["family_similarity"]
        print(
            f" {c['fraction']:.2f} {c['time_s']:7.1f} "
            f"{s['L']:.3f} {s['C']:.3f} {s['A']:.3f}"
        )


def main() -> int:
    """Run the Stop-A calibration for one trace window and print the results.

    Prints three things:

    1. A summary line.
    2. The full L/C/A convergence curve, computed with an unreachable tau_c
       so the search never stops early.
    3. Where the search stops for each candidate tau_c.

    Returns:
        0 on success. Exits with status 2 on invalid arguments and with
        status 1 when the window contains no requests.
    """
    ap = _build_parser()
    args = ap.parse_args()
    _validate_args(ap, args)
    requests = load_window(
        args.jsonl, window_start=args.window_start, window_end=args.window_end
    )
    window = WindowRecord(
        window_id=args.label or args.jsonl.stem,
        trace_path=args.jsonl,
        trace_type=args.label or "chat",
        window_start=0.0,
        window_end=float(args.window_end - args.window_start),
        source_payload={"block_size": args.block_size},
    )
    mode = resolve_length_mode(length_mode=args.length_mode)
    rows_with_hash = sum(1 for r in requests if r.metadata.get("hash_ids"))
    print(
        f"[{args.label}] requests={len(requests)} rows_with_hash_ids={rows_with_hash} "
        f"window={args.window_start:.0f}-{args.window_end:.0f}s block_size={args.block_size}"
    )
    if not requests:
        ap.exit(1, f"[{args.label}] no requests in the window; nothing to calibrate\n")
    # Arguments shared by every find_convergence_prefix call below.
    common = {
        "gpu_count": args.gpu_count,
        "length_mode": mode,
        "tau": args.tau,
        "max_checks": args.max_checks,
        "min_fraction": args.min_fraction,
    }
    # Full curve. tau_c=1.01 is meant never to be met (assuming similarities
    # are bounded by 1), so the search runs to max_checks and we read the curve.
    point = find_convergence_prefix(
        requests, window, tau_c=1.01, stable_checks=10_000, **common
    )
    _print_curve(point.checks)
    # Stop fraction at candidate tau_c values
    # (L, A >= tau; C >= tau_c; stable for W consecutive checks).
    print(f" -- stop fraction (tau_L=tau_A={args.tau:.2f}, W={args.stable_checks}) --")
    for tau_c in args.tau_c:
        p = find_convergence_prefix(
            requests, window, tau_c=tau_c, stable_checks=args.stable_checks, **common
        )
        verdict = (
            f"stop@frac={p.fraction:.2f} t={p.stop_time_s:.0f}s"
            if p.converged
            else "NEVER (replays full window)"
        )
        print(f" tau_c={tau_c:.2f}: {verdict}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)