Files
agentic-kvc/microbench/plot_interference.py
Gahow Wang 06dd175441 Microbench 1 plots: prefill-decode interference heatmap + lines
plot_interference.py reads the interference sweep summary (4 D × 4 P × 3 reps,
cold prefill prompts) and produces:

  fig_interference_heatmap.png
    TPOT p90 interference index over (D, P): 14x at D=8 P=2k → 214x at D=1 P=32k.

  fig_interference_lines.png
    (a) TPOT p90 during prefill vs P, log-y, one line per D + baseline dashed
    (b) Cold prefill TTFT vs P (interference window length)

Confirms B2 finding: a cold prefill on the same worker stalls overlapping
decodes, inflating their TPOT to 14-214x baseline. The interference window grows
faster than linearly with P (from ~140ms at 2k to ~4.6s at 32k: 16x the tokens,
~33x the time) and is essentially independent of decode batch size — prefill
compute time dominates.
2026-05-26 14:21:30 +08:00

159 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Plot prefill-decode interference results (Microbench 1).
Reads interference/results/summary.csv and produces two figures:
1. fig_interference_heatmap.png
Heatmap of TPOT p90 interference index (during/baseline) over (D, P).
2. fig_interference_lines.png
Two-panel: TPOT p90 during prefill (absolute, log) and prefill TTFT,
one line per decode batch size D, x-axis = prefill tokens P.
"""
import csv
from collections import defaultdict
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# ── paths ─────────────────────────────────────────────────────────────────────
# Inputs and outputs both live under interference/results/ next to this
# script. The output directory is created up front so the savefig() calls
# below never fail on a missing directory.
HERE = Path(__file__).parent
OUT_DIR = HERE / "interference/results"
CSV = HERE / "interference/results/summary.csv"
OUT_DIR.mkdir(exist_ok=True, parents=True)
# ── load + aggregate (median across reps) ────────────────────────────────────
# Each CSV row is one rep of one (decode_batch_size D, new_prefill_tokens P)
# configuration. Rows are grouped by (D, P) and every metric is reduced to its
# median across reps, giving `stat[(D, P)] -> {metric: median}`.
# newline="" is what the csv module expects; the `with` closes the file.
with open(CSV, newline="") as f:
    rows = list(csv.DictReader(f))
agg = defaultdict(list)
for r in rows:
    D = int(r["decode_batch_size"])
    P = int(r["new_prefill_tokens"])
    # D == 0 rows are skipped (presumably prefill-only runs with no decodes to
    # interfere with — confirm against the sweep script).
    if D == 0:
        continue
    bl_p90 = float(r["tpot_baseline_p90_ms"])
    bl_p50 = float(r["tpot_baseline_p50_ms"])
    dur_p90 = float(r["tpot_during_prefill_p90_ms"])
    dur_p50 = float(r["tpot_during_prefill_p50_ms"])
    ttft = float(r["prefill_ttft_ms"])
    # Drop reps without a usable p90 measurement. The check is written as
    # `not (x > 0)` so NaN is rejected too; `x <= 0` is False for NaN and would
    # let it through and poison the per-cell median.
    if not (bl_p90 > 0 and dur_p90 > 0):
        continue
    agg[(D, P)].append({
        "bl_p50": bl_p50, "bl_p90": bl_p90,
        "dur_p50": dur_p50, "dur_p90": dur_p90,
        "ttft": ttft,
        # Interference index: how many times slower decode steps are while
        # the prefill runs, relative to the no-prefill baseline.
        "idx_p90": dur_p90 / bl_p90,
        "idx_p50": dur_p50 / bl_p50 if bl_p50 > 0 else 0,
    })
stat = {k: {kk: float(np.median([e[kk] for e in v])) for kk in v[0]}
        for k, v in agg.items()}
if not stat:
    # Fail here with a clear message rather than deep inside the plotting code.
    raise SystemExit(f"No usable (D > 0) rows found in {CSV}")
D_VALUES = sorted({k[0] for k in stat})
P_VALUES = sorted({k[1] for k in stat})
# ── Figure 1: heatmap of interference index (TPOT p90 during / baseline) ─────
# Rows are decode batch sizes (D_VALUES), columns are prefill sizes
# (P_VALUES). (D, P) pairs with no usable reps stay NaN and are drawn blank.
mat = np.full((len(D_VALUES), len(P_VALUES)), np.nan)
for i, D in enumerate(D_VALUES):
    for j, P in enumerate(P_VALUES):
        s = stat.get((D, P))
        if s:
            mat[i, j] = s["idx_p90"]
finite = mat[np.isfinite(mat)]
if finite.size == 0:
    # Without this guard, .max() on an empty array raises an opaque ValueError.
    raise SystemExit("No interference data to plot; check the summary CSV.")
# Anchor the log colour scale at 1x (no interference), but stretch it below 1x
# if any cell is a speed-up so those cells are not clipped. Keep vmax >= vmin
# so LogNorm always accepts the range.
vmin = min(1.0, float(finite.min()))
vmax = max(float(finite.max()), vmin)
# Exact multiples of 1024 read as "Nk"; other sizes are printed verbatim.
# (A plain P // 1024 would show "0k" for P < 1024 and truncate 1536 to "1k".)
p_labels = [f"{P // 1024}k" if P >= 1024 and P % 1024 == 0 else str(P)
            for P in P_VALUES]
fig, ax = plt.subplots(figsize=(9.5, 5))
im = ax.imshow(mat, cmap="YlOrRd", aspect="auto",
               norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax))
ax.set_xticks(range(len(P_VALUES)))
ax.set_xticklabels(p_labels, fontsize=11)
ax.set_yticks(range(len(D_VALUES)))
ax.set_yticklabels([f"D={D}" for D in D_VALUES], fontsize=11)
ax.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
ax.set_ylabel("Decode batch size", fontsize=12)
ax.set_title("Prefill-Decode Interference Index\n"
             "TPOT p90 during prefill / TPOT p90 baseline (log color)",
             fontsize=13, fontweight="bold")
# Annotate every populated cell; use white text on the darker cells (index > 50).
for (i, j), v in np.ndenumerate(mat):
    if np.isnan(v):
        continue
    ax.text(j, i, f"{v:.0f}x",
            ha="center", va="center",
            fontsize=10, color="white" if v > 50 else "black",
            fontweight="bold")
cbar = fig.colorbar(im, ax=ax, fraction=0.04, pad=0.02)
cbar.set_label("Interference index (×)", fontsize=10)
fig.tight_layout()
out_heatmap = OUT_DIR / "fig_interference_heatmap.png"
fig.savefig(out_heatmap, dpi=160, bbox_inches="tight")
print(f"Saved: {out_heatmap}")
plt.close(fig)
# ── Figure 2: lines, two panels ──────────────────────────────────────────────
# (a) decode TPOT p90 during the prefill (solid) vs. its baseline (dashed),
#     one colour per D, log-log axes;
# (b) cold-prefill TTFT vs. P, one line per D.
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
D_COLORS = {1: "#1f77b4", 2: "#2ca02c", 4: "#ff7f0e", 8: "#d62728"}
# Batch sizes missing from D_COLORS take a colour from matplotlib's default
# cycle ("C0".."C9") instead of all collapsing onto the same grey.
line_colors = {D: D_COLORS.get(D, f"C{k % 10}")
               for k, D in enumerate(D_VALUES)}
# Exact multiples of 1024 read as "Nk"; other sizes are printed verbatim.
# (A plain P // 1024 would show "0k" for P < 1024 and truncate 1536 to "1k".)
p_labels = [f"{P // 1024}k" if P >= 1024 and P % 1024 == 0 else str(P)
            for P in P_VALUES]

# Panel A: TPOT p90 during prefill (absolute, log y)
ax = axes[0]
for D in D_VALUES:
    cells = [stat.get((D, P), {}) for P in P_VALUES]
    color = line_colors[D]
    ax.plot(P_VALUES, [c.get("dur_p90", np.nan) for c in cells], "o-",
            color=color, label=f"D={D} (during prefill)",
            linewidth=2, markersize=7)
    ax.plot(P_VALUES, [c.get("bl_p90", np.nan) for c in cells], "s--",
            color=color, alpha=0.4,
            label=f"D={D} (baseline)", linewidth=1, markersize=5)
ax.set_yscale("log")
ax.set_ylabel("TPOT p90 (ms, log)", fontsize=12)
ax.set_title("Decode TPOT during prefill chunk", fontsize=12, fontweight="bold")
ax.grid(True, which="both", linestyle="--", alpha=0.4)
ax.legend(fontsize=8, loc="upper left", ncol=2)

# Panel B: prefill TTFT vs P
ax = axes[1]
for D in D_VALUES:
    ys = [stat.get((D, P), {}).get("ttft", np.nan) for P in P_VALUES]
    ax.plot(P_VALUES, ys, "o-", color=line_colors[D], label=f"D={D}",
            linewidth=2, markersize=7)
ax.set_ylabel("Prefill TTFT (ms)", fontsize=12)
ax.set_title("Cold prefill duration (interference window length)",
             fontsize=12, fontweight="bold")
ax.grid(True, linestyle="--", alpha=0.4)
ax.legend(fontsize=9, loc="upper left", title="Decode batch")

# Both panels share the same log2 prefill-size x-axis.
for panel in axes:
    panel.set_xscale("log", base=2)
    panel.set_xticks(P_VALUES)
    panel.set_xticklabels(p_labels)
    panel.set_xlabel("Cold prefill size (P tokens)", fontsize=12)

fig.suptitle(
    "Prefill-Decode Interference · Qwen3-Coder-30B-A3B · H20 · chunk_size=8192",
    fontsize=13, fontweight="bold", y=1.02)
fig.tight_layout()
out_lines = OUT_DIR / "fig_interference_lines.png"
fig.savefig(out_lines, dpi=160, bbox_inches="tight")
print(f"Saved: {out_lines}")
plt.close(fig)
# ── print summary table ──────────────────────────────────────────────────────
# One line per (D, P) with the median metrics. Header and row field widths
# match so every value sits right under its heading. The trailing "x" counts
# toward the 7-wide idx_p90 column.
header = (f"{'D':>3} {'P':>6} | {'bl_p90':>7} {'dur_p90':>8} "
          f"{'idx_p90':>7} | {'ttft':>7}")
print(f"\n{header}")
print("-" * len(header))
for (D, P), s in sorted(stat.items()):
    print(f"{D:>3} {P:>6} | {s['bl_p90']:>7.2f} {s['dur_p90']:>8.1f} "
          f"{s['idx_p90']:>6.1f}x | {s['ttft']:>7.0f}")