#!/usr/bin/env python3 """Stop-A calibration: print the offered-L-C-A convergence curve for a raw trace window. The convergence of prefix-vs-full L-C-A is a deterministic property of the trace metadata (lengths, hash_ids, arrivals), so this runs on CPU without serving the model. Use it to pick tau / tau_c / stable_checks and to compare how late the C dimension converges across workloads (e.g. low-reuse chat vs high-reuse coder). Example: PYTHONPATH=src python3 scripts/stop_a_calibration.py \ --jsonl /dashscope/.../qwen_chat_blksz_64_032109-032111.jsonl \ --block-size 64 --window-start 3600 --window-end 4200 --gpu-count 8 --label chat """ from __future__ import annotations import argparse import json from pathlib import Path from aituner.lca import find_convergence_prefix, resolve_length_mode from aituner.trace import TraceRequest, WindowRecord def _session_root(row: dict, root_of: dict) -> object: chat_id = row.get("chat_id") parent = row.get("parent_chat_id") parent_is_root = parent is None or ( isinstance(parent, (int, float)) and not isinstance(parent, bool) and int(parent) < 0 ) root = chat_id if parent_is_root else root_of.get(parent, parent) if chat_id is not None: root_of[chat_id] = root return root def load_window(jsonl: Path, *, window_start: float, window_end: float) -> list[TraceRequest]: root_of: dict = {} requests: list[TraceRequest] = [] with jsonl.open(encoding="utf-8") as handle: for line in handle: line = line.strip() if not line: continue row = json.loads(line) _session_root(row, root_of) # keep chain complete even outside the window ts = float(row.get("timestamp") or 0.0) if not (window_start <= ts < window_end): continue hash_ids = row.get("hash_ids") requests.append( TraceRequest( row_id=str(row.get("chat_id")), arrival_s=ts - window_start, sampling_u=1.0, body={}, prompt_tokens_hint=int(row.get("input_length") or 0), completion_tokens_hint=int(row.get("output_length") or 0), metadata={"hash_ids": hash_ids if isinstance(hash_ids, list) else None}, ) ) requests.sort(key=lambda item: item.arrival_s) return requests def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--jsonl", type=Path, required=True) ap.add_argument("--block-size", type=int, required=True) ap.add_argument("--window-start", type=float, default=3600.0) ap.add_argument("--window-end", type=float, default=4200.0) ap.add_argument("--gpu-count", type=int, default=8) ap.add_argument("--length-mode", default="total") ap.add_argument("--label", default="") ap.add_argument("--tau", type=float, default=0.9) ap.add_argument("--max-checks", type=int, default=20) args = ap.parse_args() requests = load_window( args.jsonl, window_start=args.window_start, window_end=args.window_end ) window = WindowRecord( window_id=args.label or args.jsonl.stem, trace_path=args.jsonl, trace_type=args.label or "chat", window_start=0.0, window_end=float(args.window_end - args.window_start), source_payload={"block_size": args.block_size}, ) mode = resolve_length_mode(length_mode=args.length_mode) rows_with_hash = sum(1 for r in requests if r.metadata.get("hash_ids")) print( f"[{args.label}] requests={len(requests)} rows_with_hash_ids={rows_with_hash} " f"window={args.window_start:.0f}-{args.window_end:.0f}s block_size={args.block_size}" ) # Full curve (tau_c high so it never short-circuits; we read the curve directly). point = find_convergence_prefix( requests, window, gpu_count=args.gpu_count, length_mode=mode, tau=args.tau, tau_c=1.01, stable_checks=10_000, max_checks=args.max_checks, min_fraction=0.05, ) print(" frac time_s L C A") for c in point.checks: s = c["family_similarity"] print( f" {c['fraction']:.2f} {c['time_s']:7.1f} " f"{s['L']:.3f} {s['C']:.3f} {s['A']:.3f}" ) # Stop fraction at candidate tau_c values (L,A >= tau, C >= tau_c, stable for W=3). print(" -- stop fraction (tau_L=tau_A=%.2f, W=3) --" % args.tau) for tau_c in (0.85, 0.90, 0.92, 0.95): p = find_convergence_prefix( requests, window, gpu_count=args.gpu_count, length_mode=mode, tau=args.tau, tau_c=tau_c, stable_checks=3, max_checks=args.max_checks, min_fraction=0.05, ) verdict = ( f"stop@frac={p.fraction:.2f} t={p.stop_time_s:.0f}s" if p.converged else "NEVER (replays full window)" ) print(f" tau_c={tau_c:.2f}: {verdict}") return 0 if __name__ == "__main__": raise SystemExit(main())