Window 1 analysis: APC upper bound, B2 window-overlap, figure renderer

Three CPU-only analysis pieces that turn raw Window 1 artifacts into
publishable numbers and figures.

scripts/compute_apc_upper_bound.py
  Block-level trie walk over hash_ids to compute the theoretical APC
  ceiling on a trace, decomposed into intra-session / any-session /
  shared-prefix-only. Gives a fixed reference for what each routing
  policy could *possibly* achieve. w600 result: 79.6% intra-session,
  80.3% any-session, 0.1% shared-prefix.

analysis/characterization/b2_sweep_analysis.py (rewrite)
  Previous version used joined_analysis.interference_index() which
  labeled overlap = "any prefill in any other request during this
  decode". With short-prompt decode load this is always true
  (everyone's prefill overlaps everyone else's decode); n_overlap
  was 239/240 even in the different-worker control.

  New version labels overlap iff the decode's [t_first_token, t_finish]
  intersects an actual large *injection* window, computed from the
  cell's "prefill"-tagged metric rows. Different-worker control now
  cleanly sits at idx ≈ 1.0, same-worker scales monotonically.

analysis/characterization/render_window1_figures.py
  Renders 8 PNGs from the result JSONs: B3 latency / APC vs ceiling
  / APC vs hotspot scatter / per-worker TTFT / failure breakdown,
  B2 TPOT and TTFT curves (overlap vs clean and idx), reuse
  decomposition, KV footprint.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 23:24:54 +08:00
parent b9f324f2e6
commit b7902061d1
3 changed files with 539 additions and 68 deletions

View File

@@ -1,10 +1,24 @@
"""Aggregate B2 microbench cells into a single interference-index sweep summary.
"""Aggregate B2 microbench cells: same- vs different-worker prefill overlap.
Per cell (variant × prefill_size):
- read metrics.jsonl + run_window.json
- slice the shared engine_*.jsonl by run window
- run interference_index() against the slice
- record (variant, prefill_size, n_overlap, n_clean, tpot_p90_*, idx)
For each (variant × prefill_size) cell we have:
- 240 short-prompt decode requests at qps=4
- 4 large-prompt one-token "prefill injections"
The interesting question is *not* "does any other request's prefill overlap
this decode" (the answer is always yes — every decode begins with its own
short prefill, and at qps=4 they overlap each other constantly). The
interesting question is "does an injected large prefill on the *same* worker
materially slow this decode down?".
So we:
1) extract each cell's injection windows = [(t_dispatch, t_finish)
for r in metrics if r.workload=="prefill"];
2) label each decode request as overlap iff its
[t_first_token, t_finish] intersects at least one injection window;
3) compute TPOT p50/p90/p99 for overlap vs clean;
4) the per-cell interference index = TPOT_p90(overlap) /
TPOT_p90(clean). For "different" variant this should hover near 1.0;
for "same" it should rise with prefill_size.
"""
from __future__ import annotations
@@ -13,71 +27,77 @@ import argparse
import json
from collections import defaultdict
from pathlib import Path
from typing import Any
from analysis.characterization.joined_analysis import (
_percentile,
_vllm_rid_matches,
interference_index,
load_engine_state,
load_jsonl,
write_json,
)
def _slice_engine_state(
engine_state_by_worker: dict[str, list[dict]],
t_start: float,
t_end: float,
) -> dict[str, list[dict]]:
sliced: dict[str, list[dict]] = {}
for worker, steps in engine_state_by_worker.items():
sliced[worker] = [s for s in steps
if t_start <= (s.get("t_unix") or 0.0) <= t_end]
return sliced
def _overlaps(a_start: float, a_end: float, b_start: float, b_end: float) -> bool:
return a_start <= b_end and b_start <= a_end
def _to_joined_shape(metrics_rows: list[dict], variant: str) -> list[dict]:
"""Adapt B2 metric rows to what interference_index expects."""
joined: list[dict] = []
for r in metrics_rows:
if r.get("workload") != "decode":
def _analyze_cell(metrics_rows: list[dict]) -> dict:
prefills = [r for r in metrics_rows if r.get("workload") == "prefill"
and r.get("error") is None]
decodes = [r for r in metrics_rows if r.get("workload") == "decode"
and r.get("error") is None]
inj_windows: list[tuple[float, float]] = []
for p in prefills:
ts = p.get("t_dispatch_unix")
te = p.get("t_finish_unix")
if ts is None or te is None:
continue
joined.append({
"request_id": r["request_id"],
"tpot_s": r.get("tpot_s"),
"ttft_s": r.get("ttft_s"),
"latency_s": r.get("latency_s"),
"endpoint_url": r.get("endpoint"),
"routed_to": r.get("endpoint"),
"t_first_token_unix": (
(r["t_dispatch_unix"] + r["ttft_s"])
if r.get("ttft_s") is not None
and r.get("t_dispatch_unix") is not None else None
),
"t_finish_unix": r.get("t_finish_unix"),
"error": r.get("error"),
})
return joined
inj_windows.append((float(ts), float(te)))
overlap_tpots: list[float] = []
clean_tpots: list[float] = []
overlap_ttfts: list[float] = []
clean_ttfts: list[float] = []
for d in decodes:
ts = d.get("t_dispatch_unix")
te = d.get("t_finish_unix")
if ts is None or te is None:
continue
is_overlap = any(_overlaps(ts, te, ws, we) for ws, we in inj_windows)
tpot = d.get("tpot_s")
ttft = d.get("ttft_s")
if tpot is not None:
(overlap_tpots if is_overlap else clean_tpots).append(float(tpot))
if ttft is not None:
(overlap_ttfts if is_overlap else clean_ttfts).append(float(ttft))
p90_overlap = _percentile(overlap_tpots, 0.90) if overlap_tpots else None
p90_clean = _percentile(clean_tpots, 0.90) if clean_tpots else None
idx = (p90_overlap / p90_clean) if (p90_overlap and p90_clean) else None
return {
"n_prefill_injections": len(prefills),
"n_decode_total": len(decodes),
"n_decode_overlap": len(overlap_tpots),
"n_decode_clean": len(clean_tpots),
"tpot_p50_overlap_s": _percentile(overlap_tpots, 0.50),
"tpot_p90_overlap_s": p90_overlap,
"tpot_p99_overlap_s": _percentile(overlap_tpots, 0.99),
"tpot_p50_clean_s": _percentile(clean_tpots, 0.50),
"tpot_p90_clean_s": p90_clean,
"tpot_p99_clean_s": _percentile(clean_tpots, 0.99),
"ttft_p90_overlap_s": _percentile(overlap_ttfts, 0.90)
if overlap_ttfts else None,
"ttft_p90_clean_s": _percentile(clean_ttfts, 0.90)
if clean_ttfts else None,
"interference_index": idx,
}
def main() -> None:
p = argparse.ArgumentParser(description="B2 sweep aggregation")
p.add_argument("--sweep-dir", type=Path, required=True,
help="Top-level dir produced by scripts/b2_interference.py")
p.add_argument("--engine-state-dir", type=Path, required=True)
p.add_argument("--worker-map", type=str, required=True,
help="URL=worker_id pairs, comma-separated")
p = argparse.ArgumentParser(description="B2 sweep aggregation (window-overlap)")
p.add_argument("--sweep-dir", type=Path, required=True)
p.add_argument("--out", type=Path, default=None)
args = p.parse_args()
worker_map = {}
for entry in args.worker_map.split(","):
url, _, wid = entry.strip().partition("=")
if url and wid:
worker_map[url] = wid
engine_state = load_engine_state(args.engine_state_dir)
rows: list[dict] = []
for variant_dir in sorted(args.sweep_dir.glob("*/")):
if variant_dir.name in ("logs",):
@@ -89,27 +109,16 @@ def main() -> None:
continue
window = json.loads(window_path.read_text())
metrics_rows = load_jsonl(metrics_path)
joined = _to_joined_shape(metrics_rows, variant_dir.name)
sliced = _slice_engine_state(
engine_state, window["t_start_unix"], window["t_end_unix"],
)
idx = interference_index(joined, sliced, worker_map)
cell = _analyze_cell(metrics_rows)
rows.append({
"variant": variant_dir.name,
"prefill_size": int(window["prefill_size"]),
"decode_endpoint": window["decode_endpoint"],
"prefill_endpoint": window["prefill_endpoint"],
"n_decode_requests": sum(1 for r in metrics_rows
if r.get("workload") == "decode"
and r.get("error") is None),
"n_prefill_injections": sum(1 for r in metrics_rows
if r.get("workload") == "prefill"
and r.get("error") is None),
**idx,
**cell,
})
summary = {"rows": rows}
out_path = args.out or args.sweep_dir / "b2_sweep_summary.json"
write_json(out_path, summary)
write_json(out_path, {"rows": rows})
print(json.dumps(rows, indent=2))

View File

@@ -0,0 +1,303 @@
"""Render PNG figures for Window 1 results (B1', B2, B3).
Inputs (all expected under <results-dir>):
- b3_policy_comparison.json (per-policy table)
- b2_sweep_summary.json (per-cell B2 sweep)
- apc_upper_w600.json (theoretical bounds)
- lmetric_reuse.json (intra/cross/shared decomp)
- kv_footprint_summary.json (full trace KV stats)
Outputs (under <out-dir>):
- fig_b3_apc_vs_hotspot.png
- fig_b3_latency_bars.png
- fig_b3_apc_vs_upper.png
- fig_b3_failure_breakdown.png
- fig_b3_per_worker_ttft_p90.png
- fig_b2_tpot_vs_prefill.png
- fig_b2_ttft_vs_prefill.png
- fig_reuse_decomposition.png
- fig_kv_footprint_cdf.png
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
POLICY_ORDER = ["lmetric", "load_only", "sticky", "unified", "capped"]
POLICY_COLOR = {
"lmetric": "#1f77b4",
"load_only": "#ff7f0e",
"sticky": "#d62728",
"unified": "#2ca02c",
"capped": "#9467bd",
}
def _load(results_dir: Path, name: str) -> dict:
return json.loads((results_dir / name).read_text())
def fig_b3_apc_vs_hotspot(comp: dict, upper: dict, out: Path) -> None:
upper_intra = upper["apc_upper_intra_session"]
fig, ax = plt.subplots(figsize=(6, 4.5))
for r in comp["rows"]:
pol = r["policy"]
ax.scatter(r["apc_ratio"] * 100, r["hotspot_index_ttft_p90"],
s=180, color=POLICY_COLOR.get(pol, "gray"), label=pol,
edgecolors="black", linewidths=0.5)
ax.annotate(pol, (r["apc_ratio"] * 100, r["hotspot_index_ttft_p90"]),
xytext=(7, 7), textcoords="offset points",
fontsize=9)
ax.axvline(upper_intra * 100, linestyle="--", color="gray", alpha=0.6,
label=f"intra-session APC upper {upper_intra * 100:.1f}%")
ax.set_xlabel("APC achieved (%)")
ax.set_ylabel("hotspot_index = max(worker TTFT p90) / median")
ax.set_title("B3: APC vs hot-spot tradeoff across policies")
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_b3_latency_bars(comp: dict, out: Path) -> None:
by = {r["policy"]: r for r in comp["rows"]}
pols = [p for p in POLICY_ORDER if p in by]
metrics = [("TTFT p90 (s)", "ttft_p90_s"),
("TPOT p90 (ms)", "tpot_p90_s"),
("E2E p90 (s)", "e2e_p90_s")]
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (label, key) in zip(axes, metrics):
vals = [by[p][key] * (1000 if "TPOT" in label else 1) for p in pols]
ax.bar(pols, vals, color=[POLICY_COLOR.get(p, "gray") for p in pols],
edgecolor="black", linewidth=0.5)
ax.set_title(label)
ax.tick_params(axis="x", rotation=20)
for i, v in enumerate(vals):
ax.text(i, v, f"{v:.1f}", ha="center", va="bottom", fontsize=9)
ax.grid(alpha=0.3, axis="y")
fig.suptitle("B3 headline latencies per policy")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_b3_apc_vs_upper(comp: dict, upper: dict, out: Path) -> None:
by = {r["policy"]: r for r in comp["rows"]}
pols = [p for p in POLICY_ORDER if p in by]
achieved = [by[p]["apc_ratio"] * 100 for p in pols]
fig, ax = plt.subplots(figsize=(6.5, 4))
bars = ax.bar(pols, achieved,
color=[POLICY_COLOR.get(p, "gray") for p in pols],
edgecolor="black", linewidth=0.5)
ax.axhline(upper["apc_upper_intra_session"] * 100, linestyle="--",
color="black", alpha=0.7,
label=f"intra-session ceiling {upper['apc_upper_intra_session'] * 100:.1f}%")
ax.axhline(upper["apc_upper_any_session"] * 100, linestyle=":",
color="darkgray", alpha=0.7,
label=f"any-session ceiling {upper['apc_upper_any_session'] * 100:.1f}%")
for b, v in zip(bars, achieved):
ax.text(b.get_x() + b.get_width() / 2, v + 1, f"{v:.1f}%",
ha="center", fontsize=9)
ax.set_ylim(0, 100)
ax.set_ylabel("APC ratio (%)")
ax.set_title("B3: APC achieved vs theoretical ceiling")
ax.legend(loc="upper right", fontsize=9)
ax.grid(alpha=0.3, axis="y")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_b3_failure_breakdown(comp: dict, out: Path) -> None:
by = {r["policy"]: r for r in comp["rows"]}
pols = [p for p in POLICY_ORDER if p in by]
causes = ["same_worker_prefill_overlap", "hot_worker_queue",
"cache_miss_large_append", "high_kv_occupancy", "unknown"]
cause_color = {
"same_worker_prefill_overlap": "#d62728",
"hot_worker_queue": "#ff7f0e",
"cache_miss_large_append": "#1f77b4",
"high_kv_occupancy": "#8c564b",
"unknown": "#7f7f7f",
}
fig, ax = plt.subplots(figsize=(7, 4.5))
bottom = [0.0] * len(pols)
for c in causes:
vals = [(by[p].get("failure_counts") or {}).get(c, 0) for p in pols]
ax.bar(pols, vals, bottom=bottom, label=c.replace("_", " "),
color=cause_color[c], edgecolor="black", linewidth=0.3)
bottom = [a + b for a, b in zip(bottom, vals)]
for i, total in enumerate(bottom):
ax.text(i, total + 3, f"n={int(total)}", ha="center", fontsize=9)
ax.set_ylabel("slow request count (TTFT > 2× p90 threshold)")
ax.set_title("B3: slow-request cause breakdown per policy")
ax.legend(fontsize=8, loc="upper right")
ax.grid(alpha=0.3, axis="y")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_b3_per_worker_ttft(results_dir: Path, comp: dict, out: Path) -> None:
"""Per-worker TTFT p90 grouped bars; reads each policy's hotspot_index.json."""
by = {r["policy"]: r for r in comp["rows"]}
pols = [p for p in POLICY_ORDER if p in by]
fig, axes = plt.subplots(1, len(pols), figsize=(3 * len(pols), 4),
sharey=True)
if len(pols) == 1:
axes = [axes]
for ax, pol in zip(axes, pols):
path = results_dir / f"per_worker_{pol}.json"
if not path.exists():
ax.text(0.5, 0.5, f"{pol}: no data", ha="center", va="center",
transform=ax.transAxes)
continue
per = json.loads(path.read_text()).get("per_worker_ttft_p90_s") or {}
items = sorted(per.items(), key=lambda kv: int(kv[0].rsplit(":", 1)[1]))
labels = [f"e{int(k.rsplit(':', 1)[1]) - 8000}" for k, _ in items]
vals = [v for _, v in items]
ax.bar(labels, vals, color=POLICY_COLOR.get(pol, "gray"),
edgecolor="black", linewidth=0.5)
for i, v in enumerate(vals):
ax.text(i, v, f"{v:.1f}", ha="center", va="bottom", fontsize=8)
ax.set_title(f"{pol}\nhotspot={by[pol]['hotspot_index_ttft_p90']:.2f}",
fontsize=10)
ax.tick_params(axis="x", labelsize=8)
ax.grid(alpha=0.3, axis="y")
axes[0].set_ylabel("worker TTFT p90 (s)")
fig.suptitle("B3 per-worker TTFT p90 distribution")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_b2_curves(b2: dict, out_tpot: Path, out_ttft: Path) -> None:
sizes = sorted({r["prefill_size"] for r in b2["rows"]})
by_var = {"same": {}, "different": {}}
for r in b2["rows"]:
by_var[r["variant"]][r["prefill_size"]] = r
for name, key, ylabel, ymax_log, out in [
("TPOT", "tpot_p90", "TPOT p90 (ms)", True, out_tpot),
("TTFT", "ttft_p90", "TTFT p90 (s)", True, out_ttft),
]:
fig, axes = plt.subplots(1, 2, figsize=(11, 4))
ax_abs, ax_idx = axes
for variant in ("different", "same"):
xs, ys_o, ys_c, idxs = [], [], [], []
for sz in sizes:
r = by_var[variant].get(sz)
if not r: continue
ov = r.get(f"{key}_overlap_s")
cl = r.get(f"{key}_clean_s")
if ov is None or cl is None: continue
xs.append(sz)
scale = 1000 if name == "TPOT" else 1.0
ys_o.append(ov * scale)
ys_c.append(cl * scale)
idxs.append(ov / cl)
color = "#d62728" if variant == "same" else "#1f77b4"
ax_abs.plot(xs, ys_o, "o-", color=color,
label=f"{variant} (overlap)")
ax_abs.plot(xs, ys_c, "s--", color=color, alpha=0.5,
label=f"{variant} (clean)")
ax_idx.plot(xs, idxs, "o-", color=color, label=variant,
linewidth=2)
ax_abs.set_xscale("log", base=2)
ax_abs.set_yscale("log")
ax_abs.set_xlabel("prefill injection size (tokens)")
ax_abs.set_ylabel(ylabel + " (log)")
ax_abs.set_title(f"B2 {name} absolute (overlap vs clean)")
ax_abs.legend(fontsize=8)
ax_abs.grid(alpha=0.3, which="both")
ax_idx.set_xscale("log", base=2)
if ymax_log:
ax_idx.set_yscale("log")
ax_idx.axhline(1.0, color="black", linestyle=":", alpha=0.5)
ax_idx.set_xlabel("prefill injection size (tokens)")
ax_idx.set_ylabel(f"{name} idx = overlap / clean")
ax_idx.set_title(f"B2 {name} interference index (same vs different worker)")
ax_idx.legend()
ax_idx.grid(alpha=0.3, which="both")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_reuse_decomposition(reuse: dict, out: Path) -> None:
fr = reuse.get("fractions") or {}
labels = ["intra-session", "cross-session", "shared-prefix", "unclassified"]
vals = [fr.get("intra", 0), fr.get("cross", 0),
fr.get("shared", 0), fr.get("unclassified", 0)]
colors = ["#2ca02c", "#ff7f0e", "#9467bd", "#7f7f7f"]
fig, ax = plt.subplots(figsize=(6, 3))
bottom = 0.0
for label, v, c in zip(labels, vals, colors):
ax.barh(["lmetric run"], [v], left=[bottom], color=c, edgecolor="black",
linewidth=0.5, label=f"{label} ({v * 100:.1f}%)")
bottom += v
ax.set_xlabel("fraction of cached_tokens")
ax.set_xlim(0, 1)
ax.set_title("Real reuse decomposition (w600 lmetric run)")
ax.legend(fontsize=9, loc="lower right")
ax.grid(alpha=0.3, axis="x")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def fig_kv_footprint_cdf(kv: dict, out: Path) -> None:
s = kv.get("kv_mib_per_request") or {}
vals = [s.get(k) for k in ("p50", "p90", "p95", "p99")]
labels = ["p50", "p90", "p95", "p99"]
fig, ax = plt.subplots(figsize=(6, 3.5))
ax.bar(labels, vals, color="#1f77b4", edgecolor="black", linewidth=0.5)
for i, v in enumerate(vals):
ax.text(i, v, f"{v:.0f} MiB", ha="center", va="bottom", fontsize=9)
ax.axhline(95 * 1024, color="red", linestyle="--", alpha=0.5,
label="H20 ~95 GiB usable")
ax.set_ylabel("KV bytes per request (MiB)")
ax.set_title("B1' Per-request KV footprint (Qwen3-Coder-30B-A3B, 98304 B/token)")
ax.legend()
ax.grid(alpha=0.3, axis="y")
fig.tight_layout()
fig.savefig(out, dpi=120)
plt.close(fig)
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--results-dir", type=Path, required=True)
p.add_argument("--out-dir", type=Path, required=True)
args = p.parse_args()
args.out_dir.mkdir(parents=True, exist_ok=True)
comp = _load(args.results_dir, "b3_policy_comparison.json")
upper = _load(args.results_dir, "apc_upper_w600.json")
b2 = _load(args.results_dir, "b2_sweep_summary.json")
reuse = _load(args.results_dir, "lmetric_reuse.json")
kv = _load(args.results_dir, "kv_footprint_summary.json")
fig_b3_apc_vs_hotspot(comp, upper, args.out_dir / "fig_b3_apc_vs_hotspot.png")
fig_b3_latency_bars(comp, args.out_dir / "fig_b3_latency_bars.png")
fig_b3_apc_vs_upper(comp, upper, args.out_dir / "fig_b3_apc_vs_upper.png")
fig_b3_failure_breakdown(comp, args.out_dir / "fig_b3_failure_breakdown.png")
fig_b3_per_worker_ttft(args.results_dir, comp,
args.out_dir / "fig_b3_per_worker_ttft_p90.png")
fig_b2_curves(b2,
args.out_dir / "fig_b2_tpot_vs_prefill.png",
args.out_dir / "fig_b2_ttft_vs_prefill.png")
fig_reuse_decomposition(reuse, args.out_dir / "fig_reuse_decomposition.png")
fig_kv_footprint_cdf(kv, args.out_dir / "fig_kv_footprint_cdf.png")
print(f"wrote 8 figures to {args.out_dir}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,159 @@
"""Compute theoretical APC upper bound from a trace's hash_ids.
vLLM prefix caching is block-level (BLOCK_SIZE token chunks) and chain-
aware: block i hits the cache iff hash_ids[0..i] matches some previously
seen request's hash_ids[0..i]. The trace's `hash_ids` field is the same
hash chain.
Three variants:
- any_session : trie built across all previously seen requests
- intra_session : trie scoped to the request's own session_id
- shared_prefix : trie built from blocks at position 0, that appear
across >= K distinct sessions (system-prompt proxy)
"""
from __future__ import annotations
import argparse
import json
from collections import defaultdict
from pathlib import Path
BLOCK_SIZE_DEFAULT = 512
def _walk(trie: dict, hashes: list[int]) -> int:
depth = 0
node = trie
for h in hashes:
if h in node:
depth += 1
node = node[h]
else:
break
return depth
def _insert(trie: dict, hashes: list[int]) -> None:
node = trie
for h in hashes:
node = node.setdefault(h, {})
def _resolve_session(row: dict, chat_to_session: dict[int, str]) -> str:
if "session_id" in row:
return str(row["session_id"])
cid = int(row["chat_id"])
pcid = int(row["parent_chat_id"])
if pcid < 0:
sid = str(cid)
else:
sid = chat_to_session.get(pcid, str(pcid))
chat_to_session[cid] = sid
return sid
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--trace", type=Path, required=True)
p.add_argument("--block-size", type=int, default=BLOCK_SIZE_DEFAULT)
p.add_argument("--shared-prefix-min-sessions", type=int, default=8)
p.add_argument("--out", type=Path, default=None)
args = p.parse_args()
rows: list[dict] = []
chat_to_session: dict[int, str] = {}
with args.trace.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line:
continue
r = json.loads(line)
r["session_id"] = _resolve_session(r, chat_to_session)
rows.append(r)
rows.sort(key=lambda r: float(r.get("timestamp", 0.0)))
global_trie: dict = {}
session_tries: dict[str, dict] = defaultdict(dict)
# First-position block stats: how many sessions hit each top-level
# hash. We approximate "system prefix" as a block seen at position 0
# across many sessions.
pos0_session_set: dict[int, set[str]] = defaultdict(set)
for r in rows:
hids = list(r.get("hash_ids") or [])
if hids:
pos0_session_set[hids[0]].add(r["session_id"])
shared_pos0 = {h for h, s in pos0_session_set.items()
if len(s) >= args.shared_prefix_min_sessions}
total_input = 0
cache_any = 0
cache_intra = 0
cache_shared_only = 0
per_session_input: dict[str, int] = defaultdict(int)
per_session_intra: dict[str, int] = defaultdict(int)
n_with_any_hit = 0
n_with_intra_hit = 0
n_with_shared_hit = 0
for r in rows:
hids = list(r.get("hash_ids") or [])
input_len = int(r.get("input_length") or 0)
sid = r["session_id"]
total_input += input_len
per_session_input[sid] += input_len
g_depth = _walk(global_trie, hids)
s_depth = _walk(session_tries[sid], hids)
# Shared-prefix-only: greedy depth, but stop at first non-shared
# pos0 block (because then no later blocks can be from system
# prefix as a contiguous chain).
sh_depth = 0
if hids and hids[0] in shared_pos0:
sh_depth = 1
# subsequent blocks at deeper positions are NOT modeled as
# "shared system" in this conservative bound.
# accumulate
g_tokens = min(g_depth * args.block_size, input_len)
s_tokens = min(s_depth * args.block_size, input_len)
sh_tokens = min(sh_depth * args.block_size, input_len)
cache_any += g_tokens
cache_intra += s_tokens
cache_shared_only += sh_tokens
per_session_intra[sid] += s_tokens
if g_tokens > 0:
n_with_any_hit += 1
if s_tokens > 0:
n_with_intra_hit += 1
if sh_tokens > 0:
n_with_shared_hit += 1
# update tries
_insert(global_trie, hids)
_insert(session_tries[sid], hids)
out = {
"trace": str(args.trace),
"n_requests": len(rows),
"n_sessions": len({r["session_id"] for r in rows}),
"block_size": args.block_size,
"shared_prefix_min_sessions": args.shared_prefix_min_sessions,
"total_input_tokens": total_input,
"apc_upper_any_session": cache_any / total_input,
"apc_upper_intra_session": cache_intra / total_input,
"apc_upper_shared_prefix_only": cache_shared_only / total_input,
"cached_tokens_any_session": cache_any,
"cached_tokens_intra_session": cache_intra,
"cached_tokens_shared_prefix_only": cache_shared_only,
"n_requests_any_hit": n_with_any_hit,
"n_requests_intra_hit": n_with_intra_hit,
"n_requests_shared_hit": n_with_shared_hit,
"n_shared_pos0_blocks": len(shared_pos0),
}
text = json.dumps(out, indent=2)
if args.out:
args.out.write_text(text)
print(text)
if __name__ == "__main__":
main()