diff --git a/analysis/characterization/b2_sweep_analysis.py b/analysis/characterization/b2_sweep_analysis.py index 0acc6bb..33fed43 100644 --- a/analysis/characterization/b2_sweep_analysis.py +++ b/analysis/characterization/b2_sweep_analysis.py @@ -1,10 +1,24 @@ -"""Aggregate B2 microbench cells into a single interference-index sweep summary. +"""Aggregate B2 microbench cells: same- vs different-worker prefill overlap. -Per cell (variant × prefill_size): -- read metrics.jsonl + run_window.json -- slice the shared engine_*.jsonl by run window -- run interference_index() against the slice -- record (variant, prefill_size, n_overlap, n_clean, tpot_p90_*, idx) +For each (variant × prefill_size) cell we have: +- 240 short-prompt decode requests at qps=4 +- 4 large-prompt one-token "prefill injections" + +The interesting question is *not* "does any other request's prefill overlap +this decode" (the answer is always yes — every decode begins with its own +short prefill, and at qps=4 they overlap each other constantly). The +interesting question is "does an injected large prefill on the *same* worker +materially slow this decode down?". + +So we: +1) extract each cell's injection windows = [(t_dispatch, t_finish) + for r in metrics if r.workload=="prefill"]; +2) label each decode request as overlap iff its + [t_first_token, t_finish] intersects at least one injection window; +3) compute TPOT p50/p90/p99 for overlap vs clean; +4) the per-cell interference index = TPOT_p90(overlap) / + TPOT_p90(clean). For "different" variant this should hover near 1.0; + for "same" it should rise with prefill_size. """ from __future__ import annotations @@ -13,71 +27,77 @@ import argparse import json from collections import defaultdict from pathlib import Path -from typing import Any from analysis.characterization.joined_analysis import ( _percentile, - _vllm_rid_matches, - interference_index, - load_engine_state, load_jsonl, write_json, ) -def _slice_engine_state( - engine_state_by_worker: dict[str, list[dict]], - t_start: float, - t_end: float, -) -> dict[str, list[dict]]: - sliced: dict[str, list[dict]] = {} - for worker, steps in engine_state_by_worker.items(): - sliced[worker] = [s for s in steps - if t_start <= (s.get("t_unix") or 0.0) <= t_end] - return sliced +def _overlaps(a_start: float, a_end: float, b_start: float, b_end: float) -> bool: + return a_start <= b_end and b_start <= a_end -def _to_joined_shape(metrics_rows: list[dict], variant: str) -> list[dict]: - """Adapt B2 metric rows to what interference_index expects.""" - joined: list[dict] = [] - for r in metrics_rows: - if r.get("workload") != "decode": +def _analyze_cell(metrics_rows: list[dict]) -> dict: + prefills = [r for r in metrics_rows if r.get("workload") == "prefill" + and r.get("error") is None] + decodes = [r for r in metrics_rows if r.get("workload") == "decode" + and r.get("error") is None] + + inj_windows: list[tuple[float, float]] = [] + for p in prefills: + ts = p.get("t_dispatch_unix") + te = p.get("t_finish_unix") + if ts is None or te is None: continue - joined.append({ - "request_id": r["request_id"], - "tpot_s": r.get("tpot_s"), - "ttft_s": r.get("ttft_s"), - "latency_s": r.get("latency_s"), - "endpoint_url": r.get("endpoint"), - "routed_to": r.get("endpoint"), - "t_first_token_unix": ( - (r["t_dispatch_unix"] + r["ttft_s"]) - if r.get("ttft_s") is not None - and r.get("t_dispatch_unix") is not None else None - ), - "t_finish_unix": r.get("t_finish_unix"), - "error": r.get("error"), - }) - return joined + inj_windows.append((float(ts), float(te))) + + overlap_tpots: list[float] = [] + clean_tpots: list[float] = [] + overlap_ttfts: list[float] = [] + clean_ttfts: list[float] = [] + for d in decodes: + ts = d.get("t_dispatch_unix") + te = d.get("t_finish_unix") + if ts is None or te is None: + continue + is_overlap = any(_overlaps(ts, te, ws, we) for ws, we in inj_windows) + tpot = d.get("tpot_s") + ttft = d.get("ttft_s") + if tpot is not None: + (overlap_tpots if is_overlap else clean_tpots).append(float(tpot)) + if ttft is not None: + (overlap_ttfts if is_overlap else clean_ttfts).append(float(ttft)) + + p90_overlap = _percentile(overlap_tpots, 0.90) if overlap_tpots else None + p90_clean = _percentile(clean_tpots, 0.90) if clean_tpots else None + idx = (p90_overlap / p90_clean) if (p90_overlap and p90_clean) else None + return { + "n_prefill_injections": len(prefills), + "n_decode_total": len(decodes), + "n_decode_overlap": len(overlap_tpots), + "n_decode_clean": len(clean_tpots), + "tpot_p50_overlap_s": _percentile(overlap_tpots, 0.50), + "tpot_p90_overlap_s": p90_overlap, + "tpot_p99_overlap_s": _percentile(overlap_tpots, 0.99), + "tpot_p50_clean_s": _percentile(clean_tpots, 0.50), + "tpot_p90_clean_s": p90_clean, + "tpot_p99_clean_s": _percentile(clean_tpots, 0.99), + "ttft_p90_overlap_s": _percentile(overlap_ttfts, 0.90) + if overlap_ttfts else None, + "ttft_p90_clean_s": _percentile(clean_ttfts, 0.90) + if clean_ttfts else None, + "interference_index": idx, + } def main() -> None: - p = argparse.ArgumentParser(description="B2 sweep aggregation") - p.add_argument("--sweep-dir", type=Path, required=True, - help="Top-level dir produced by scripts/b2_interference.py") - p.add_argument("--engine-state-dir", type=Path, required=True) - p.add_argument("--worker-map", type=str, required=True, - help="URL=worker_id pairs, comma-separated") + p = argparse.ArgumentParser(description="B2 sweep aggregation (window-overlap)") + p.add_argument("--sweep-dir", type=Path, required=True) p.add_argument("--out", type=Path, default=None) args = p.parse_args() - worker_map = {} - for entry in args.worker_map.split(","): - url, _, wid = entry.strip().partition("=") - if url and wid: - worker_map[url] = wid - - engine_state = load_engine_state(args.engine_state_dir) rows: list[dict] = [] for variant_dir in sorted(args.sweep_dir.glob("*/")): if variant_dir.name in ("logs",): @@ -89,27 +109,16 @@ def main() -> None: continue window = json.loads(window_path.read_text()) metrics_rows = load_jsonl(metrics_path) - joined = _to_joined_shape(metrics_rows, variant_dir.name) - sliced = _slice_engine_state( - engine_state, window["t_start_unix"], window["t_end_unix"], - ) - idx = interference_index(joined, sliced, worker_map) + cell = _analyze_cell(metrics_rows) rows.append({ "variant": variant_dir.name, "prefill_size": int(window["prefill_size"]), "decode_endpoint": window["decode_endpoint"], "prefill_endpoint": window["prefill_endpoint"], - "n_decode_requests": sum(1 for r in metrics_rows - if r.get("workload") == "decode" - and r.get("error") is None), - "n_prefill_injections": sum(1 for r in metrics_rows - if r.get("workload") == "prefill" - and r.get("error") is None), - **idx, + **cell, }) - summary = {"rows": rows} out_path = args.out or args.sweep_dir / "b2_sweep_summary.json" - write_json(out_path, summary) + write_json(out_path, {"rows": rows}) print(json.dumps(rows, indent=2)) diff --git a/analysis/characterization/render_window1_figures.py b/analysis/characterization/render_window1_figures.py new file mode 100644 index 0000000..967d6a7 --- /dev/null +++ b/analysis/characterization/render_window1_figures.py @@ -0,0 +1,303 @@ +"""Render PNG figures for Window 1 results (B1', B2, B3). + +Inputs (all expected under ): +- b3_policy_comparison.json (per-policy table) +- b2_sweep_summary.json (per-cell B2 sweep) +- apc_upper_w600.json (theoretical bounds) +- lmetric_reuse.json (intra/cross/shared decomp) +- kv_footprint_summary.json (full trace KV stats) + +Outputs (under ): +- fig_b3_apc_vs_hotspot.png +- fig_b3_latency_bars.png +- fig_b3_apc_vs_upper.png +- fig_b3_failure_breakdown.png +- fig_b3_per_worker_ttft_p90.png +- fig_b2_tpot_vs_prefill.png +- fig_b2_ttft_vs_prefill.png +- fig_reuse_decomposition.png +- fig_kv_footprint_cdf.png +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +POLICY_ORDER = ["lmetric", "load_only", "sticky", "unified", "capped"] +POLICY_COLOR = { + "lmetric": "#1f77b4", + "load_only": "#ff7f0e", + "sticky": "#d62728", + "unified": "#2ca02c", + "capped": "#9467bd", +} + + +def _load(results_dir: Path, name: str) -> dict: + return json.loads((results_dir / name).read_text()) + + +def fig_b3_apc_vs_hotspot(comp: dict, upper: dict, out: Path) -> None: + upper_intra = upper["apc_upper_intra_session"] + fig, ax = plt.subplots(figsize=(6, 4.5)) + for r in comp["rows"]: + pol = r["policy"] + ax.scatter(r["apc_ratio"] * 100, r["hotspot_index_ttft_p90"], + s=180, color=POLICY_COLOR.get(pol, "gray"), label=pol, + edgecolors="black", linewidths=0.5) + ax.annotate(pol, (r["apc_ratio"] * 100, r["hotspot_index_ttft_p90"]), + xytext=(7, 7), textcoords="offset points", + fontsize=9) + ax.axvline(upper_intra * 100, linestyle="--", color="gray", alpha=0.6, + label=f"intra-session APC upper {upper_intra * 100:.1f}%") + ax.set_xlabel("APC achieved (%)") + ax.set_ylabel("hotspot_index = max(worker TTFT p90) / median") + ax.set_title("B3: APC vs hot-spot tradeoff across policies") + ax.grid(alpha=0.3) + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_b3_latency_bars(comp: dict, out: Path) -> None: + by = {r["policy"]: r for r in comp["rows"]} + pols = [p for p in POLICY_ORDER if p in by] + metrics = [("TTFT p90 (s)", "ttft_p90_s"), + ("TPOT p90 (ms)", "tpot_p90_s"), + ("E2E p90 (s)", "e2e_p90_s")] + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + for ax, (label, key) in zip(axes, metrics): + vals = [by[p][key] * (1000 if "TPOT" in label else 1) for p in pols] + ax.bar(pols, vals, color=[POLICY_COLOR.get(p, "gray") for p in pols], + edgecolor="black", linewidth=0.5) + ax.set_title(label) + ax.tick_params(axis="x", rotation=20) + for i, v in enumerate(vals): + ax.text(i, v, f"{v:.1f}", ha="center", va="bottom", fontsize=9) + ax.grid(alpha=0.3, axis="y") + fig.suptitle("B3 headline latencies per policy") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_b3_apc_vs_upper(comp: dict, upper: dict, out: Path) -> None: + by = {r["policy"]: r for r in comp["rows"]} + pols = [p for p in POLICY_ORDER if p in by] + achieved = [by[p]["apc_ratio"] * 100 for p in pols] + fig, ax = plt.subplots(figsize=(6.5, 4)) + bars = ax.bar(pols, achieved, + color=[POLICY_COLOR.get(p, "gray") for p in pols], + edgecolor="black", linewidth=0.5) + ax.axhline(upper["apc_upper_intra_session"] * 100, linestyle="--", + color="black", alpha=0.7, + label=f"intra-session ceiling {upper['apc_upper_intra_session'] * 100:.1f}%") + ax.axhline(upper["apc_upper_any_session"] * 100, linestyle=":", + color="darkgray", alpha=0.7, + label=f"any-session ceiling {upper['apc_upper_any_session'] * 100:.1f}%") + for b, v in zip(bars, achieved): + ax.text(b.get_x() + b.get_width() / 2, v + 1, f"{v:.1f}%", + ha="center", fontsize=9) + ax.set_ylim(0, 100) + ax.set_ylabel("APC ratio (%)") + ax.set_title("B3: APC achieved vs theoretical ceiling") + ax.legend(loc="upper right", fontsize=9) + ax.grid(alpha=0.3, axis="y") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_b3_failure_breakdown(comp: dict, out: Path) -> None: + by = {r["policy"]: r for r in comp["rows"]} + pols = [p for p in POLICY_ORDER if p in by] + causes = ["same_worker_prefill_overlap", "hot_worker_queue", + "cache_miss_large_append", "high_kv_occupancy", "unknown"] + cause_color = { + "same_worker_prefill_overlap": "#d62728", + "hot_worker_queue": "#ff7f0e", + "cache_miss_large_append": "#1f77b4", + "high_kv_occupancy": "#8c564b", + "unknown": "#7f7f7f", + } + fig, ax = plt.subplots(figsize=(7, 4.5)) + bottom = [0.0] * len(pols) + for c in causes: + vals = [(by[p].get("failure_counts") or {}).get(c, 0) for p in pols] + ax.bar(pols, vals, bottom=bottom, label=c.replace("_", " "), + color=cause_color[c], edgecolor="black", linewidth=0.3) + bottom = [a + b for a, b in zip(bottom, vals)] + for i, total in enumerate(bottom): + ax.text(i, total + 3, f"n={int(total)}", ha="center", fontsize=9) + ax.set_ylabel("slow request count (TTFT > 2× p90 threshold)") + ax.set_title("B3: slow-request cause breakdown per policy") + ax.legend(fontsize=8, loc="upper right") + ax.grid(alpha=0.3, axis="y") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_b3_per_worker_ttft(results_dir: Path, comp: dict, out: Path) -> None: + """Per-worker TTFT p90 grouped bars; reads each policy's hotspot_index.json.""" + by = {r["policy"]: r for r in comp["rows"]} + pols = [p for p in POLICY_ORDER if p in by] + fig, axes = plt.subplots(1, len(pols), figsize=(3 * len(pols), 4), + sharey=True) + if len(pols) == 1: + axes = [axes] + for ax, pol in zip(axes, pols): + path = results_dir / f"per_worker_{pol}.json" + if not path.exists(): + ax.text(0.5, 0.5, f"{pol}: no data", ha="center", va="center", + transform=ax.transAxes) + continue + per = json.loads(path.read_text()).get("per_worker_ttft_p90_s") or {} + items = sorted(per.items(), key=lambda kv: int(kv[0].rsplit(":", 1)[1])) + labels = [f"e{int(k.rsplit(':', 1)[1]) - 8000}" for k, _ in items] + vals = [v for _, v in items] + ax.bar(labels, vals, color=POLICY_COLOR.get(pol, "gray"), + edgecolor="black", linewidth=0.5) + for i, v in enumerate(vals): + ax.text(i, v, f"{v:.1f}", ha="center", va="bottom", fontsize=8) + ax.set_title(f"{pol}\nhotspot={by[pol]['hotspot_index_ttft_p90']:.2f}", + fontsize=10) + ax.tick_params(axis="x", labelsize=8) + ax.grid(alpha=0.3, axis="y") + axes[0].set_ylabel("worker TTFT p90 (s)") + fig.suptitle("B3 per-worker TTFT p90 distribution") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_b2_curves(b2: dict, out_tpot: Path, out_ttft: Path) -> None: + sizes = sorted({r["prefill_size"] for r in b2["rows"]}) + by_var = {"same": {}, "different": {}} + for r in b2["rows"]: + by_var[r["variant"]][r["prefill_size"]] = r + + for name, key, ylabel, ymax_log, out in [ + ("TPOT", "tpot_p90", "TPOT p90 (ms)", True, out_tpot), + ("TTFT", "ttft_p90", "TTFT p90 (s)", True, out_ttft), + ]: + fig, axes = plt.subplots(1, 2, figsize=(11, 4)) + ax_abs, ax_idx = axes + for variant in ("different", "same"): + xs, ys_o, ys_c, idxs = [], [], [], [] + for sz in sizes: + r = by_var[variant].get(sz) + if not r: continue + ov = r.get(f"{key}_overlap_s") + cl = r.get(f"{key}_clean_s") + if ov is None or cl is None: continue + xs.append(sz) + scale = 1000 if name == "TPOT" else 1.0 + ys_o.append(ov * scale) + ys_c.append(cl * scale) + idxs.append(ov / cl) + color = "#d62728" if variant == "same" else "#1f77b4" + ax_abs.plot(xs, ys_o, "o-", color=color, + label=f"{variant} (overlap)") + ax_abs.plot(xs, ys_c, "s--", color=color, alpha=0.5, + label=f"{variant} (clean)") + ax_idx.plot(xs, idxs, "o-", color=color, label=variant, + linewidth=2) + ax_abs.set_xscale("log", base=2) + ax_abs.set_yscale("log") + ax_abs.set_xlabel("prefill injection size (tokens)") + ax_abs.set_ylabel(ylabel + " (log)") + ax_abs.set_title(f"B2 {name} absolute (overlap vs clean)") + ax_abs.legend(fontsize=8) + ax_abs.grid(alpha=0.3, which="both") + + ax_idx.set_xscale("log", base=2) + if ymax_log: + ax_idx.set_yscale("log") + ax_idx.axhline(1.0, color="black", linestyle=":", alpha=0.5) + ax_idx.set_xlabel("prefill injection size (tokens)") + ax_idx.set_ylabel(f"{name} idx = overlap / clean") + ax_idx.set_title(f"B2 {name} interference index (same vs different worker)") + ax_idx.legend() + ax_idx.grid(alpha=0.3, which="both") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_reuse_decomposition(reuse: dict, out: Path) -> None: + fr = reuse.get("fractions") or {} + labels = ["intra-session", "cross-session", "shared-prefix", "unclassified"] + vals = [fr.get("intra", 0), fr.get("cross", 0), + fr.get("shared", 0), fr.get("unclassified", 0)] + colors = ["#2ca02c", "#ff7f0e", "#9467bd", "#7f7f7f"] + fig, ax = plt.subplots(figsize=(6, 3)) + bottom = 0.0 + for label, v, c in zip(labels, vals, colors): + ax.barh(["lmetric run"], [v], left=[bottom], color=c, edgecolor="black", + linewidth=0.5, label=f"{label} ({v * 100:.1f}%)") + bottom += v + ax.set_xlabel("fraction of cached_tokens") + ax.set_xlim(0, 1) + ax.set_title("Real reuse decomposition (w600 lmetric run)") + ax.legend(fontsize=9, loc="lower right") + ax.grid(alpha=0.3, axis="x") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def fig_kv_footprint_cdf(kv: dict, out: Path) -> None: + s = kv.get("kv_mib_per_request") or {} + vals = [s.get(k) for k in ("p50", "p90", "p95", "p99")] + labels = ["p50", "p90", "p95", "p99"] + fig, ax = plt.subplots(figsize=(6, 3.5)) + ax.bar(labels, vals, color="#1f77b4", edgecolor="black", linewidth=0.5) + for i, v in enumerate(vals): + ax.text(i, v, f"{v:.0f} MiB", ha="center", va="bottom", fontsize=9) + ax.axhline(95 * 1024, color="red", linestyle="--", alpha=0.5, + label="H20 ~95 GiB usable") + ax.set_ylabel("KV bytes per request (MiB)") + ax.set_title("B1' Per-request KV footprint (Qwen3-Coder-30B-A3B, 98304 B/token)") + ax.legend() + ax.grid(alpha=0.3, axis="y") + fig.tight_layout() + fig.savefig(out, dpi=120) + plt.close(fig) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--results-dir", type=Path, required=True) + p.add_argument("--out-dir", type=Path, required=True) + args = p.parse_args() + args.out_dir.mkdir(parents=True, exist_ok=True) + + comp = _load(args.results_dir, "b3_policy_comparison.json") + upper = _load(args.results_dir, "apc_upper_w600.json") + b2 = _load(args.results_dir, "b2_sweep_summary.json") + reuse = _load(args.results_dir, "lmetric_reuse.json") + kv = _load(args.results_dir, "kv_footprint_summary.json") + + fig_b3_apc_vs_hotspot(comp, upper, args.out_dir / "fig_b3_apc_vs_hotspot.png") + fig_b3_latency_bars(comp, args.out_dir / "fig_b3_latency_bars.png") + fig_b3_apc_vs_upper(comp, upper, args.out_dir / "fig_b3_apc_vs_upper.png") + fig_b3_failure_breakdown(comp, args.out_dir / "fig_b3_failure_breakdown.png") + fig_b3_per_worker_ttft(args.results_dir, comp, + args.out_dir / "fig_b3_per_worker_ttft_p90.png") + fig_b2_curves(b2, + args.out_dir / "fig_b2_tpot_vs_prefill.png", + args.out_dir / "fig_b2_ttft_vs_prefill.png") + fig_reuse_decomposition(reuse, args.out_dir / "fig_reuse_decomposition.png") + fig_kv_footprint_cdf(kv, args.out_dir / "fig_kv_footprint_cdf.png") + print(f"wrote 8 figures to {args.out_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/compute_apc_upper_bound.py b/scripts/compute_apc_upper_bound.py new file mode 100644 index 0000000..48485ed --- /dev/null +++ b/scripts/compute_apc_upper_bound.py @@ -0,0 +1,159 @@ +"""Compute theoretical APC upper bound from a trace's hash_ids. + +vLLM prefix caching is block-level (BLOCK_SIZE token chunks) and chain- +aware: block i hits the cache iff hash_ids[0..i] matches some previously +seen request's hash_ids[0..i]. The trace's `hash_ids` field is the same +hash chain. + +Three variants: +- any_session : trie built across all previously seen requests +- intra_session : trie scoped to the request's own session_id +- shared_prefix : trie built from blocks at position 0, that appear + across >= K distinct sessions (system-prompt proxy) +""" + +from __future__ import annotations + +import argparse +import json +from collections import defaultdict +from pathlib import Path + +BLOCK_SIZE_DEFAULT = 512 + + +def _walk(trie: dict, hashes: list[int]) -> int: + depth = 0 + node = trie + for h in hashes: + if h in node: + depth += 1 + node = node[h] + else: + break + return depth + + +def _insert(trie: dict, hashes: list[int]) -> None: + node = trie + for h in hashes: + node = node.setdefault(h, {}) + + +def _resolve_session(row: dict, chat_to_session: dict[int, str]) -> str: + if "session_id" in row: + return str(row["session_id"]) + cid = int(row["chat_id"]) + pcid = int(row["parent_chat_id"]) + if pcid < 0: + sid = str(cid) + else: + sid = chat_to_session.get(pcid, str(pcid)) + chat_to_session[cid] = sid + return sid + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--trace", type=Path, required=True) + p.add_argument("--block-size", type=int, default=BLOCK_SIZE_DEFAULT) + p.add_argument("--shared-prefix-min-sessions", type=int, default=8) + p.add_argument("--out", type=Path, default=None) + args = p.parse_args() + + rows: list[dict] = [] + chat_to_session: dict[int, str] = {} + with args.trace.open("r", encoding="utf-8") as fh: + for line in fh: + line = line.strip() + if not line: + continue + r = json.loads(line) + r["session_id"] = _resolve_session(r, chat_to_session) + rows.append(r) + rows.sort(key=lambda r: float(r.get("timestamp", 0.0))) + + global_trie: dict = {} + session_tries: dict[str, dict] = defaultdict(dict) + # First-position block stats: how many sessions hit each top-level + # hash. We approximate "system prefix" as a block seen at position 0 + # across many sessions. + pos0_session_set: dict[int, set[str]] = defaultdict(set) + for r in rows: + hids = list(r.get("hash_ids") or []) + if hids: + pos0_session_set[hids[0]].add(r["session_id"]) + shared_pos0 = {h for h, s in pos0_session_set.items() + if len(s) >= args.shared_prefix_min_sessions} + + total_input = 0 + cache_any = 0 + cache_intra = 0 + cache_shared_only = 0 + per_session_input: dict[str, int] = defaultdict(int) + per_session_intra: dict[str, int] = defaultdict(int) + n_with_any_hit = 0 + n_with_intra_hit = 0 + n_with_shared_hit = 0 + + for r in rows: + hids = list(r.get("hash_ids") or []) + input_len = int(r.get("input_length") or 0) + sid = r["session_id"] + total_input += input_len + per_session_input[sid] += input_len + + g_depth = _walk(global_trie, hids) + s_depth = _walk(session_tries[sid], hids) + # Shared-prefix-only: greedy depth, but stop at first non-shared + # pos0 block (because then no later blocks can be from system + # prefix as a contiguous chain). + sh_depth = 0 + if hids and hids[0] in shared_pos0: + sh_depth = 1 + # subsequent blocks at deeper positions are NOT modeled as + # "shared system" in this conservative bound. + # accumulate + g_tokens = min(g_depth * args.block_size, input_len) + s_tokens = min(s_depth * args.block_size, input_len) + sh_tokens = min(sh_depth * args.block_size, input_len) + cache_any += g_tokens + cache_intra += s_tokens + cache_shared_only += sh_tokens + per_session_intra[sid] += s_tokens + if g_tokens > 0: + n_with_any_hit += 1 + if s_tokens > 0: + n_with_intra_hit += 1 + if sh_tokens > 0: + n_with_shared_hit += 1 + # update tries + _insert(global_trie, hids) + _insert(session_tries[sid], hids) + + out = { + "trace": str(args.trace), + "n_requests": len(rows), + "n_sessions": len({r["session_id"] for r in rows}), + "block_size": args.block_size, + "shared_prefix_min_sessions": args.shared_prefix_min_sessions, + "total_input_tokens": total_input, + "apc_upper_any_session": cache_any / total_input, + "apc_upper_intra_session": cache_intra / total_input, + "apc_upper_shared_prefix_only": cache_shared_only / total_input, + "cached_tokens_any_session": cache_any, + "cached_tokens_intra_session": cache_intra, + "cached_tokens_shared_prefix_only": cache_shared_only, + "n_requests_any_hit": n_with_any_hit, + "n_requests_intra_hit": n_with_intra_hit, + "n_requests_shared_hit": n_with_shared_hit, + "n_shared_pos0_blocks": len(shared_pos0), + } + text = json.dumps(out, indent=2) + if args.out: + args.out.write_text(text) + print(text) + + +if __name__ == "__main__": + main()