Add Stop-A calibration script (CPU-only convergence curve)
Prints the offered-L-C-A convergence curve and the stop fraction at candidate tau_c values for a raw trace window, to calibrate Stop-A thresholds and compare how late C converges across workloads. No serving required. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
128
scripts/stop_a_calibration.py
Normal file
128
scripts/stop_a_calibration.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Stop-A calibration: print the offered-L-C-A convergence curve for a raw trace window.
|
||||||
|
|
||||||
|
The convergence of prefix-vs-full L-C-A is a deterministic property of the trace
|
||||||
|
metadata (lengths, hash_ids, arrivals), so this runs on CPU without serving the
|
||||||
|
model. Use it to pick tau / tau_c / stable_checks and to compare how late the C
|
||||||
|
dimension converges across workloads (e.g. low-reuse chat vs high-reuse coder).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
PYTHONPATH=src python3 scripts/stop_a_calibration.py \
|
||||||
|
--jsonl /dashscope/.../qwen_chat_blksz_64_032109-032111.jsonl \
|
||||||
|
--block-size 64 --window-start 3600 --window-end 4200 --gpu-count 8 --label chat
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from aituner.lca import find_convergence_prefix, resolve_length_mode
|
||||||
|
from aituner.trace import TraceRequest, WindowRecord
|
||||||
|
|
||||||
|
|
||||||
|
def _session_root(row: dict, root_of: dict) -> object:
|
||||||
|
chat_id = row.get("chat_id")
|
||||||
|
parent = row.get("parent_chat_id")
|
||||||
|
parent_is_root = parent is None or (
|
||||||
|
isinstance(parent, (int, float)) and not isinstance(parent, bool) and int(parent) < 0
|
||||||
|
)
|
||||||
|
root = chat_id if parent_is_root else root_of.get(parent, parent)
|
||||||
|
if chat_id is not None:
|
||||||
|
root_of[chat_id] = root
|
||||||
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
def load_window(jsonl: Path, *, window_start: float, window_end: float) -> list[TraceRequest]:
|
||||||
|
root_of: dict = {}
|
||||||
|
requests: list[TraceRequest] = []
|
||||||
|
with jsonl.open(encoding="utf-8") as handle:
|
||||||
|
for line in handle:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
row = json.loads(line)
|
||||||
|
_session_root(row, root_of) # keep chain complete even outside the window
|
||||||
|
ts = float(row.get("timestamp") or 0.0)
|
||||||
|
if not (window_start <= ts < window_end):
|
||||||
|
continue
|
||||||
|
hash_ids = row.get("hash_ids")
|
||||||
|
requests.append(
|
||||||
|
TraceRequest(
|
||||||
|
row_id=str(row.get("chat_id")),
|
||||||
|
arrival_s=ts - window_start,
|
||||||
|
sampling_u=1.0,
|
||||||
|
body={},
|
||||||
|
prompt_tokens_hint=int(row.get("input_length") or 0),
|
||||||
|
completion_tokens_hint=int(row.get("output_length") or 0),
|
||||||
|
metadata={"hash_ids": hash_ids if isinstance(hash_ids, list) else None},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
requests.sort(key=lambda item: item.arrival_s)
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--jsonl", type=Path, required=True)
|
||||||
|
ap.add_argument("--block-size", type=int, required=True)
|
||||||
|
ap.add_argument("--window-start", type=float, default=3600.0)
|
||||||
|
ap.add_argument("--window-end", type=float, default=4200.0)
|
||||||
|
ap.add_argument("--gpu-count", type=int, default=8)
|
||||||
|
ap.add_argument("--length-mode", default="total")
|
||||||
|
ap.add_argument("--label", default="")
|
||||||
|
ap.add_argument("--tau", type=float, default=0.9)
|
||||||
|
ap.add_argument("--max-checks", type=int, default=20)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
requests = load_window(
|
||||||
|
args.jsonl, window_start=args.window_start, window_end=args.window_end
|
||||||
|
)
|
||||||
|
window = WindowRecord(
|
||||||
|
window_id=args.label or args.jsonl.stem,
|
||||||
|
trace_path=args.jsonl,
|
||||||
|
trace_type=args.label or "chat",
|
||||||
|
window_start=0.0,
|
||||||
|
window_end=float(args.window_end - args.window_start),
|
||||||
|
source_payload={"block_size": args.block_size},
|
||||||
|
)
|
||||||
|
mode = resolve_length_mode(length_mode=args.length_mode)
|
||||||
|
rows_with_hash = sum(1 for r in requests if r.metadata.get("hash_ids"))
|
||||||
|
print(
|
||||||
|
f"[{args.label}] requests={len(requests)} rows_with_hash_ids={rows_with_hash} "
|
||||||
|
f"window={args.window_start:.0f}-{args.window_end:.0f}s block_size={args.block_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Full curve (tau_c high so it never short-circuits; we read the curve directly).
|
||||||
|
point = find_convergence_prefix(
|
||||||
|
requests, window, gpu_count=args.gpu_count, length_mode=mode,
|
||||||
|
tau=args.tau, tau_c=1.01, stable_checks=10_000, max_checks=args.max_checks,
|
||||||
|
min_fraction=0.05,
|
||||||
|
)
|
||||||
|
print(" frac time_s L C A")
|
||||||
|
for c in point.checks:
|
||||||
|
s = c["family_similarity"]
|
||||||
|
print(
|
||||||
|
f" {c['fraction']:.2f} {c['time_s']:7.1f} "
|
||||||
|
f"{s['L']:.3f} {s['C']:.3f} {s['A']:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop fraction at candidate tau_c values (L,A >= tau, C >= tau_c, stable for W=3).
|
||||||
|
print(" -- stop fraction (tau_L=tau_A=%.2f, W=3) --" % args.tau)
|
||||||
|
for tau_c in (0.85, 0.90, 0.92, 0.95):
|
||||||
|
p = find_convergence_prefix(
|
||||||
|
requests, window, gpu_count=args.gpu_count, length_mode=mode,
|
||||||
|
tau=args.tau, tau_c=tau_c, stable_checks=3, max_checks=args.max_checks,
|
||||||
|
min_fraction=0.05,
|
||||||
|
)
|
||||||
|
verdict = (
|
||||||
|
f"stop@frac={p.fraction:.2f} t={p.stop_time_s:.0f}s"
|
||||||
|
if p.converged
|
||||||
|
else "NEVER (replays full window)"
|
||||||
|
)
|
||||||
|
print(f" tau_c={tau_c:.2f}: {verdict}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
Reference in New Issue
Block a user