Add Stop-A calibration script (CPU-only convergence curve)
Prints the offered-L-C-A convergence curve and the stop fraction at candidate tau_c values for a raw trace window, to calibrate Stop-A thresholds and compare how late C converges across workloads. No serving required. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
128
scripts/stop_a_calibration.py
Normal file
128
scripts/stop_a_calibration.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Stop-A calibration: print the offered-L-C-A convergence curve for a raw trace window.
|
||||
|
||||
The convergence of prefix-vs-full L-C-A is a deterministic property of the trace
|
||||
metadata (lengths, hash_ids, arrivals), so this runs on CPU without serving the
|
||||
model. Use it to pick tau / tau_c / stable_checks and to compare how late the C
|
||||
dimension converges across workloads (e.g. low-reuse chat vs high-reuse coder).
|
||||
|
||||
Example:
|
||||
PYTHONPATH=src python3 scripts/stop_a_calibration.py \
|
||||
--jsonl /dashscope/.../qwen_chat_blksz_64_032109-032111.jsonl \
|
||||
--block-size 64 --window-start 3600 --window-end 4200 --gpu-count 8 --label chat
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from aituner.lca import find_convergence_prefix, resolve_length_mode
|
||||
from aituner.trace import TraceRequest, WindowRecord
|
||||
|
||||
|
||||
def _session_root(row: dict, root_of: dict) -> object:
|
||||
chat_id = row.get("chat_id")
|
||||
parent = row.get("parent_chat_id")
|
||||
parent_is_root = parent is None or (
|
||||
isinstance(parent, (int, float)) and not isinstance(parent, bool) and int(parent) < 0
|
||||
)
|
||||
root = chat_id if parent_is_root else root_of.get(parent, parent)
|
||||
if chat_id is not None:
|
||||
root_of[chat_id] = root
|
||||
return root
|
||||
|
||||
|
||||
def load_window(jsonl: Path, *, window_start: float, window_end: float) -> list[TraceRequest]:
|
||||
root_of: dict = {}
|
||||
requests: list[TraceRequest] = []
|
||||
with jsonl.open(encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
_session_root(row, root_of) # keep chain complete even outside the window
|
||||
ts = float(row.get("timestamp") or 0.0)
|
||||
if not (window_start <= ts < window_end):
|
||||
continue
|
||||
hash_ids = row.get("hash_ids")
|
||||
requests.append(
|
||||
TraceRequest(
|
||||
row_id=str(row.get("chat_id")),
|
||||
arrival_s=ts - window_start,
|
||||
sampling_u=1.0,
|
||||
body={},
|
||||
prompt_tokens_hint=int(row.get("input_length") or 0),
|
||||
completion_tokens_hint=int(row.get("output_length") or 0),
|
||||
metadata={"hash_ids": hash_ids if isinstance(hash_ids, list) else None},
|
||||
)
|
||||
)
|
||||
requests.sort(key=lambda item: item.arrival_s)
|
||||
return requests
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--jsonl", type=Path, required=True)
|
||||
ap.add_argument("--block-size", type=int, required=True)
|
||||
ap.add_argument("--window-start", type=float, default=3600.0)
|
||||
ap.add_argument("--window-end", type=float, default=4200.0)
|
||||
ap.add_argument("--gpu-count", type=int, default=8)
|
||||
ap.add_argument("--length-mode", default="total")
|
||||
ap.add_argument("--label", default="")
|
||||
ap.add_argument("--tau", type=float, default=0.9)
|
||||
ap.add_argument("--max-checks", type=int, default=20)
|
||||
args = ap.parse_args()
|
||||
|
||||
requests = load_window(
|
||||
args.jsonl, window_start=args.window_start, window_end=args.window_end
|
||||
)
|
||||
window = WindowRecord(
|
||||
window_id=args.label or args.jsonl.stem,
|
||||
trace_path=args.jsonl,
|
||||
trace_type=args.label or "chat",
|
||||
window_start=0.0,
|
||||
window_end=float(args.window_end - args.window_start),
|
||||
source_payload={"block_size": args.block_size},
|
||||
)
|
||||
mode = resolve_length_mode(length_mode=args.length_mode)
|
||||
rows_with_hash = sum(1 for r in requests if r.metadata.get("hash_ids"))
|
||||
print(
|
||||
f"[{args.label}] requests={len(requests)} rows_with_hash_ids={rows_with_hash} "
|
||||
f"window={args.window_start:.0f}-{args.window_end:.0f}s block_size={args.block_size}"
|
||||
)
|
||||
|
||||
# Full curve (tau_c high so it never short-circuits; we read the curve directly).
|
||||
point = find_convergence_prefix(
|
||||
requests, window, gpu_count=args.gpu_count, length_mode=mode,
|
||||
tau=args.tau, tau_c=1.01, stable_checks=10_000, max_checks=args.max_checks,
|
||||
min_fraction=0.05,
|
||||
)
|
||||
print(" frac time_s L C A")
|
||||
for c in point.checks:
|
||||
s = c["family_similarity"]
|
||||
print(
|
||||
f" {c['fraction']:.2f} {c['time_s']:7.1f} "
|
||||
f"{s['L']:.3f} {s['C']:.3f} {s['A']:.3f}"
|
||||
)
|
||||
|
||||
# Stop fraction at candidate tau_c values (L,A >= tau, C >= tau_c, stable for W=3).
|
||||
print(" -- stop fraction (tau_L=tau_A=%.2f, W=3) --" % args.tau)
|
||||
for tau_c in (0.85, 0.90, 0.92, 0.95):
|
||||
p = find_convergence_prefix(
|
||||
requests, window, gpu_count=args.gpu_count, length_mode=mode,
|
||||
tau=args.tau, tau_c=tau_c, stable_checks=3, max_checks=args.max_checks,
|
||||
min_fraction=0.05,
|
||||
)
|
||||
verdict = (
|
||||
f"stop@frac={p.fraction:.2f} t={p.stop_time_s:.0f}s"
|
||||
if p.converged
|
||||
else "NEVER (replays full window)"
|
||||
)
|
||||
print(f" tau_c={tau_c:.2f}: {verdict}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user