"""Render the three PD-vs-colo crossover figures from fig_agg JSON dumps. Inputs (produced by `fig_agg.py --json`): analysis/mb5_pd_ablation/fig1_reuse_fixed.json reuse axis (N=8, FIXED real prefill delta=2048; vary cached prefix -> reuse = pfx/(pfx+delta). Controlled-variable: real new-prefill work is constant across the sweep, only the cached fraction (and total context) grows. Supersedes the old fig1.json, which held input=8192 and sliced prefix out of it so delta shrank 15x as reuse rose — a confound, not a pure reuse axis.) analysis/mb5_pd_ablation/fig2.json shape axis (N=8, reuse~70%) analysis/mb5_pd_ablation/fig3_conc32k.json concurrency (in32768/out128, reuse~0.984 = 32256 resident + 512 real new-prefill per turn; retuned 2026-05-31 to the agentic corner so PD pays the full-context per-turn KV-transfer tax while colo keeps it resident; vary N by step 8 up to the mean-E2E<=10s SLO ceiling) Each figure overlays colo + the three PD ratios and marks the PD-best advantage. All three share the corrected (uncontaminated, e13391e-gated-off) stack. """ from __future__ import annotations import json import re from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt ROOT = Path(__file__).resolve().parents[2] DATA = ROOT / "analysis" / "mb5_pd_ablation" OUT = ROOT / "figs" / "mb5_pd_ablation" OUT.mkdir(parents=True, exist_ok=True) PD_ARMS = ["2P+6D", "4P+4D", "6P+2D"] STYLE = { "colo": dict(color="k", marker="o", lw=2.4, ls="-", label="colo (8×kv_both)"), "2P+6D": dict(color="#1f77b4", marker="s", lw=1.6, ls="--", label="PD 2P+6D"), "4P+4D": dict(color="#2ca02c", marker="^", lw=1.6, ls="--", label="PD 4P+4D"), "6P+2D": dict(color="#ff7f0e", marker="v", lw=1.6, ls="--", label="PD 6P+2D"), } def load(name): return json.load(open(DATA / name)) def by_axis(rows, keyfn): """Group rows -> {axis_val: {arm: row}}.""" out = {} for r in rows: k = keyfn(r["name"]) if k is None: continue out.setdefault(k, {})[r["arm"]] = r return out def pd_best(armmap, metric="e2e_p90"): vals = [(a, armmap[a][metric]) for a in PD_ARMS if a in armmap and armmap[a].get(metric) is not None] return min(vals, key=lambda t: t[1]) if vals else (None, None) def series(grp, xs, arm, metric): return [grp[x][arm].get(metric) if arm in grp[x] else None for x in xs] # ---------- Fig 1: reuse axis ---------- def _reuse_pct(name): """Reuse % from a `reuse_p{pfx}_d{delta}_{arm}` run name: pfx/(pfx+delta).""" m = re.search(r"_p(\d+)_d(\d+)", name) if not m: return None pfx, delta = int(m.group(1)), int(m.group(2)) return round(pfx / (pfx + delta) * 100) def fig_reuse(): g = by_axis(load("fig1_reuse_fixed.json"), _reuse_pct) xs = sorted(g) reuse = xs fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.2)) for arm in ["colo", *PD_ARMS]: ax1.plot(reuse, series(g, xs, arm, "e2e_p90"), **STYLE[arm]) ax1.set_xlabel("intra-session KV reuse (%) [fixed real prefill, delta=2048]") ax1.set_ylabel("E2E latency p90 (s)") ax1.set_title("(a) E2E-p90 vs reuse (N=8, delta=2048/out256)") ax1.legend(fontsize=8); ax1.grid(alpha=.3) adv, putil = [], [] for x in xs: co = g[x]["colo"]["e2e_p90"]; _, b = pd_best(g[x]) adv.append(co / b if b else None) a = pd_best(g[x])[0] putil.append(g[x][a].get("pu") if a else None) ax2.plot(reuse, adv, color="purple", marker="D", lw=2, label="PD-best advantage (colo/PD)") ax2.axhline(1.0, color="grey", ls=":", lw=1) ax2.set_xlabel("intra-session KV reuse (%)"); ax2.set_ylabel("advantage (>1 = PD wins)") ax2b = ax2.twinx() ax2b.plot(reuse, putil, color="brown", marker="x", lw=1.4, ls="-.", label="PD-best prefill-GPU util") ax2b.set_ylabel("prefill-GPU util (%)", color="brown"); ax2b.tick_params(axis="y", colors="brown") ax2.set_title("(b) advantage erodes; prefill GPUs go idle") l1, la1 = ax2.get_legend_handles_labels(); l2, la2 = ax2b.get_legend_handles_labels() ax2.legend(l1 + l2, la1 + la2, fontsize=8, loc="center right"); ax2.grid(alpha=.3) fig.suptitle("Fig 1 — Reuse axis (fixed real prefill delta=2048): PD's edge vs rising cache reuse", fontsize=11, y=1.02) fig.tight_layout(); p = OUT / "fig1_reuse_axis.png"; fig.savefig(p, dpi=130, bbox_inches="tight") print("wrote", p) # ---------- Fig 2: shape axis ---------- def fig_shape(): g = by_axis(load("fig2.json"), lambda n: ((int(m.group(1)), int(m.group(2))) if (m := re.search(r"_in(\d+)_out(\d+)_", n)) else None)) xs = sorted(g, key=lambda t: t[0]) # ascending input labels = [f"in{i}\nout{o}" for i, o in xs] xi = list(range(len(xs))) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.2)) for arm in ["colo", *PD_ARMS]: ax1.plot(xi, series(g, xs, arm, "e2e_p90"), **STYLE[arm]) ax1.set_xticks(xi); ax1.set_xticklabels(labels, fontsize=7) ax1.set_xlabel("shape (decode-heavy → prefill-heavy)"); ax1.set_ylabel("E2E latency p90 (s)") ax1.set_title("(a) E2E-p90 vs shape (N=8, reuse~70%)") ax1.legend(fontsize=8); ax1.grid(alpha=.3) adv, comp = [], [] for x in xs: co = g[x]["colo"]["e2e_p90"]; a, b = pd_best(g[x]) adv.append(co / b if b else None) # completion of the worst PD arm (exposes catastrophic ratio) worst = min((g[x][arm]["n"] / g[x][arm]["req"]) for arm in PD_ARMS if arm in g[x]) comp.append(worst * 100) ax2.plot(xi, adv, color="purple", marker="D", lw=2, label="PD-best advantage (colo/PD)") ax2.axhline(1.0, color="grey", ls=":", lw=1) ax2.set_xticks(xi); ax2.set_xticklabels(labels, fontsize=7) ax2.set_xlabel("shape"); ax2.set_ylabel("advantage (>1 = PD wins)") ax2b = ax2.twinx() ax2b.plot(xi, comp, color="red", marker="x", lw=1.4, ls="-.", label="worst-PD-arm completion %") ax2b.set_ylabel("worst PD completion (%)", color="red"); ax2b.tick_params(axis="y", colors="red") ax2b.set_ylim(80, 101) ax2.set_title("(b) advantage peaks mid-sweep; wrong ratio catastrophic at prefill extreme") l1, la1 = ax2.get_legend_handles_labels(); l2, la2 = ax2b.get_legend_handles_labels() ax2.legend(l1 + l2, la1 + la2, fontsize=8, loc="lower left"); ax2.grid(alpha=.3) fig.suptitle("Fig 2 — Shape axis: PD wins decode-heavy, ties prefill-heavy; optimal ratio rotates", fontsize=11, y=1.02) fig.tight_layout(); p = OUT / "fig2_shape_axis.png"; fig.savefig(p, dpi=130, bbox_inches="tight") print("wrote", p) # ---------- Fig 3: concurrency axis ---------- def fig_conc(): g = by_axis(load("fig3_conc32k.json"), lambda n: (int(m.group(1)) if (m := re.search(r"_N(\d+)_", n)) else None)) xs = sorted(g) fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4.2)) # (a) request completion % — the headline (latency percentiles count successes # only, so they understate PD; completion is the honest collapse signal). for arm in ["colo", *PD_ARMS]: comp = [(g[x][arm]["n"] / g[x][arm]["req"] * 100) if arm in g[x] else None for x in xs] ax1.plot(xs, comp, **STYLE[arm]) ax1.axhline(100, color="grey", ls=":", lw=1) ax1.set_xticks(xs); ax1.set_xticklabels(xs, fontsize=7) ax1.set_xlabel("concurrent sessions N"); ax1.set_ylabel("request completion (%)") ax1.set_title("(a) completion: colo 100%, PD collapses"); ax1.legend(fontsize=8); ax1.grid(alpha=.3) for arm in ["colo", *PD_ARMS]: ax2.plot(xs, series(g, xs, arm, "e2e_mean"), **STYLE[arm]) ax2.axhline(10.0, color="red", ls=":", lw=1, label="SLO 10s") ax2.set_yscale("log"); ax2.set_xticks(xs); ax2.set_xticklabels(xs, fontsize=7) ax2.set_xlabel("concurrent sessions N"); ax2.set_ylabel("E2E latency mean (s, log)") ax2.set_title("(b) mean-E2E (successes only)"); ax2.legend(fontsize=8); ax2.grid(alpha=.3, which="both") for arm in ["colo", *PD_ARMS]: ax3.plot(xs, series(g, xs, arm, "tps"), **STYLE[arm]) ax3.set_xticks(xs); ax3.set_xticklabels(xs, fontsize=7) ax3.set_xlabel("concurrent sessions N"); ax3.set_ylabel("throughput (tok/s)") ax3.set_title("(c) TPS"); ax3.legend(fontsize=8); ax3.grid(alpha=.3) fig.suptitle("Fig 3 — Concurrency axis (in32768/out128, reuse~0.984, PD capped 600s / colo uncapped): " "colo degrades gracefully (100% completion), PD collapses earlier as N rises", fontsize=10, y=1.02) fig.tight_layout(); p = OUT / "fig3_concurrency_axis.png"; fig.savefig(p, dpi=130, bbox_inches="tight") print("wrote", p) if __name__ == "__main__": fig_reuse(); fig_shape(); fig_conc()