#!/usr/bin/env python3
"""
Plot prefill-decode interference results (Microbench 1).

Reads interference/results/summary.csv and produces two figures:

  1. fig_interference_heatmap.png
     Heatmap of the TPOT p90 interference index (during / baseline) over (D, P).
  2. fig_interference_lines.png
     Two panels: TPOT p90 during prefill (absolute, log scale) and prefill
     TTFT, one line per decode batch size D, x-axis = prefill tokens P.

Repeated measurements of the same (D, P) point are reduced to their median.
After the figures are saved, a summary table is printed to stdout.
"""

import csv
from collections import defaultdict
from pathlib import Path

import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.colors import LogNorm  # noqa: E402

HERE = Path(__file__).parent
CSV = HERE / "interference/results/summary.csv"
OUT_DIR = HERE / "interference/results"

# Fixed colors for the decode batch sizes normally swept; any other D falls
# back to matplotlib's default color cycle so its lines stay distinguishable.
D_COLORS = {1: "#1f77b4", 2: "#2ca02c", 4: "#ff7f0e", 8: "#d62728"}


def _fmt_tokens(p):
    """Return a compact tick label for a token count (8192 -> "8k", 1000 -> "1000")."""
    if p >= 1024 and p % 1024 == 0:
        return f"{p // 1024}k"
    return str(p)


def _d_color(d, idx):
    """Return the line color for decode batch size *d* at sweep position *idx*."""
    return D_COLORS.get(d, f"C{idx % 10}")


def load_stats(csv_path=CSV):
    """Load the summary CSV and return {(D, P): {metric: median over reps}}.

    Rows are skipped when D == 0 or when either TPOT p90 (baseline or during
    prefill) is non-positive, because the interference index is undefined there.
    Metrics per entry: bl_p50, bl_p90, dur_p50, dur_p90, ttft, idx_p90, idx_p50.
    """
    agg = defaultdict(list)
    # newline="" is what the csv module requires for correct parsing.
    with open(csv_path, newline="", encoding="utf-8") as f:
        for r in csv.DictReader(f):
            d = int(r["decode_batch_size"])
            p = int(r["new_prefill_tokens"])
            if d == 0:
                continue
            bl_p90 = float(r["tpot_baseline_p90_ms"])
            bl_p50 = float(r["tpot_baseline_p50_ms"])
            dur_p90 = float(r["tpot_during_prefill_p90_ms"])
            dur_p50 = float(r["tpot_during_prefill_p50_ms"])
            ttft = float(r["prefill_ttft_ms"])
            if bl_p90 <= 0 or dur_p90 <= 0:
                continue
            agg[(d, p)].append({
                "bl_p50": bl_p50,
                "bl_p90": bl_p90,
                "dur_p50": dur_p50,
                "dur_p90": dur_p90,
                "ttft": ttft,
                "idx_p90": dur_p90 / bl_p90,
                # NaN (not 0) marks "undefined" so it cannot bias a median.
                "idx_p50": dur_p50 / bl_p50 if bl_p50 > 0 else float("nan"),
            })
    return {
        key: {metric: float(np.median([e[metric] for e in reps])) for metric in reps[0]}
        for key, reps in agg.items()
    }


def plot_heatmap(stat, d_values, p_values, out_dir=OUT_DIR):
    """Draw the (D, P) heatmap of the TPOT p90 interference index; return the PNG path."""
    mat = np.full((len(d_values), len(p_values)), np.nan)
    for i, d in enumerate(d_values):
        for j, p in enumerate(p_values):
            s = stat.get((d, p))
            if s:
                mat[i, j] = s["idx_p90"]

    # LogNorm requires vmax > vmin; if no cell shows a slowdown (all <= 1x),
    # widen the range slightly instead of letting matplotlib raise.
    vmax = max(float(np.nanmax(mat)), 1.0 + 1e-6)

    fig, ax = plt.subplots(figsize=(9.5, 5))
    im = ax.imshow(mat, cmap="YlOrRd", aspect="auto",
                   norm=LogNorm(vmin=1, vmax=vmax))
    ax.set_xticks(range(len(p_values)))
    ax.set_xticklabels([_fmt_tokens(p) for p in p_values], fontsize=11)
    ax.set_yticks(range(len(d_values)))
    ax.set_yticklabels([f"D={d}" for d in d_values], fontsize=11)
    ax.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
    ax.set_ylabel("Decode batch size", fontsize=12)
    ax.set_title("Prefill-Decode Interference Index\n"
                 "TPOT p90 during prefill / TPOT p90 baseline (log color)",
                 fontsize=13, fontweight="bold")

    # Annotate each populated cell; switch to white text on the dark end of
    # the colormap (based on normalized color position, not a fixed value).
    for (i, j), v in np.ndenumerate(mat):
        if np.isnan(v):
            continue
        txt_color = "white" if float(im.norm(v)) > 0.6 else "black"
        label = f"{v:.1f}x" if v < 10 else f"{v:.0f}x"
        ax.text(j, i, label, ha="center", va="center",
                fontsize=10, color=txt_color, fontweight="bold")

    cbar = fig.colorbar(im, ax=ax, fraction=0.04, pad=0.02)
    cbar.set_label("Interference index (×)", fontsize=10)
    fig.tight_layout()
    out_heatmap = Path(out_dir) / "fig_interference_heatmap.png"
    fig.savefig(out_heatmap, dpi=160, bbox_inches="tight")
    print(f"Saved: {out_heatmap}")
    plt.close(fig)
    return out_heatmap


def plot_lines(stat, d_values, p_values, out_dir=OUT_DIR):
    """Draw the two-panel TPOT / TTFT line figure; return the PNG path."""
    fig, (ax_tpot, ax_ttft) = plt.subplots(1, 2, figsize=(13, 5))
    tick_labels = [_fmt_tokens(p) for p in p_values]

    def series(d, key):
        # Missing (D, P) points become NaN so matplotlib leaves a gap.
        return [stat.get((d, p), {}).get(key, np.nan) for p in p_values]

    for idx, d in enumerate(d_values):
        color = _d_color(d, idx)
        # Panel A: TPOT p90 during prefill vs. baseline.
        ax_tpot.plot(p_values, series(d, "dur_p90"), "o-", color=color,
                     label=f"D={d} (during prefill)", linewidth=2, markersize=7)
        ax_tpot.plot(p_values, series(d, "bl_p90"), "s--", color=color, alpha=0.4,
                     label=f"D={d} (baseline)", linewidth=1, markersize=5)
        # Panel B: prefill TTFT.
        ax_ttft.plot(p_values, series(d, "ttft"), "o-", color=color,
                     label=f"D={d}", linewidth=2, markersize=7)

    ax_tpot.set_xscale("log", base=2)
    ax_tpot.set_yscale("log")
    ax_tpot.set_xticks(p_values)
    ax_tpot.set_xticklabels(tick_labels)
    ax_tpot.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
    ax_tpot.set_ylabel("TPOT p90 (ms, log)", fontsize=12)
    ax_tpot.set_title("Decode TPOT during prefill chunk", fontsize=12,
                      fontweight="bold")
    ax_tpot.grid(True, which="both", linestyle="--", alpha=0.4)
    ax_tpot.legend(fontsize=8, loc="upper left", ncol=2)

    ax_ttft.set_xscale("log", base=2)
    ax_ttft.set_xticks(p_values)
    ax_ttft.set_xticklabels(tick_labels)
    ax_ttft.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
    ax_ttft.set_ylabel("Prefill TTFT (ms)", fontsize=12)
    ax_ttft.set_title("Cold prefill duration (interference window length)",
                      fontsize=12, fontweight="bold")
    ax_ttft.grid(True, linestyle="--", alpha=0.4)
    ax_ttft.legend(fontsize=9, loc="upper left", title="Decode batch")

    fig.suptitle(
        "Prefill-Decode Interference · Qwen3-Coder-30B-A3B · H20 · chunk_size=8192",
        fontsize=13, fontweight="bold", y=1.02)
    fig.tight_layout()
    out_lines = Path(out_dir) / "fig_interference_lines.png"
    fig.savefig(out_lines, dpi=160, bbox_inches="tight")
    print(f"Saved: {out_lines}")
    plt.close(fig)
    return out_lines


def format_summary(stat):
    """Return the per-(D, P) summary table as text, with data columns aligned to the header."""
    header = (f"{'D':>3} {'P':>6} | {'bl_p90':>7} {'dur_p90':>8} "
              f"{'idx_p90':>7} | {'ttft':>7}")
    lines = [header, "-" * len(header)]
    for (d, p), s in sorted(stat.items()):
        # idx column: 6 digits + "x" suffix = 7, matching the header width.
        lines.append(
            f"{d:>3} {p:>6} | {s['bl_p90']:>7.2f} {s['dur_p90']:>8.1f} "
            f"{s['idx_p90']:>6.1f}x | {s['ttft']:>7.0f}")
    return "\n".join(lines)


def main(csv_path=CSV, out_dir=OUT_DIR):
    """Load results, write both figures into *out_dir*, and print the summary table."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    stat = load_stats(csv_path)
    if not stat:
        raise SystemExit(f"No usable rows in {csv_path}; nothing to plot.")

    d_values = sorted({d for d, _ in stat})
    p_values = sorted({p for _, p in stat})

    plot_heatmap(stat, d_values, p_values, out_dir)
    plot_lines(stat, d_values, p_values, out_dir)
    print("\n" + format_summary(stat))


if __name__ == "__main__":
    main()