plot_interference.py reads the interference sweep summary (4 D × 4 P × 3 reps,
cold prefill prompts) and produces two figures:
1. fig_interference_heatmap.png
   TPOT p90 interference index over (D, P): from 14x at D=8, P=2k up to 214x at D=1, P=32k.
2. fig_interference_lines.png
   (a) TPOT p90 during prefill vs. P, log-y, one line per D, with the baseline dashed
   (b) Cold prefill TTFT vs. P (interference window length)
This confirms the B2 finding: a cold prefill on the same worker stalls overlapping
decodes, raising their TPOT to 14–214x the baseline. The interference window grows
linearly with P (from ~140 ms at 2k to ~4.6 s at 32k) and is essentially independent
of decode batch size, because prefill compute time dominates.
159 lines
6.0 KiB
Python
159 lines
6.0 KiB
Python
#!/usr/bin/env python3
"""
Plot prefill-decode interference results (Microbench 1).

Reads interference/results/summary.csv (resolved relative to this file) and
writes two figures into the same interference/results directory:

  1. fig_interference_heatmap.png
     Heatmap of TPOT p90 interference index (during/baseline) over (D, P).

  2. fig_interference_lines.png
     Two-panel: TPOT p90 during prefill (absolute, log) and prefill TTFT,
     one line per decode batch size D, x-axis = prefill tokens P.

The per-(D, P) medians behind both figures are also printed to stdout as a
plain-text table.
"""

import csv
from collections import defaultdict
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use("Agg")  # pick the Agg backend before pyplot is imported
import matplotlib.pyplot as plt

# Input and output locations are anchored at the directory holding this script.
HERE = Path(__file__).parent
CSV = HERE / "interference/results/summary.csv"
OUT_DIR = HERE / "interference/results"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ── load + aggregate (median across reps) ────────────────────────────────────
# The with-block closes the file once every row has been read; newline="" is
# the mode the csv module asks for when it is handed a file object.
with open(CSV, newline="") as fh:
    rows = list(csv.DictReader(fh))

# (D, P) -> list with one metrics dict per usable repetition.
agg = defaultdict(list)
for r in rows:
    D = int(r["decode_batch_size"])
    P = int(r["new_prefill_tokens"])
    if D == 0:
        # D == 0 rows are excluded from the plots
        # (presumably prefill-only runs — confirm against the sweep script).
        continue
    bl_p90 = float(r["tpot_baseline_p90_ms"])
    bl_p50 = float(r["tpot_baseline_p50_ms"])
    dur_p90 = float(r["tpot_during_prefill_p90_ms"])
    dur_p50 = float(r["tpot_during_prefill_p50_ms"])
    ttft = float(r["prefill_ttft_ms"])
    if bl_p90 <= 0 or dur_p90 <= 0:
        # The p90 ratio needs both values to be positive; drop the repetition.
        continue
    agg[(D, P)].append({
        "bl_p50": bl_p50, "bl_p90": bl_p90,
        "dur_p50": dur_p50, "dur_p90": dur_p90,
        "ttft": ttft,
        "idx_p90": dur_p90 / bl_p90,
        # A non-positive p50 baseline leaves the ratio undefined; it is stored as 0.
        "idx_p50": dur_p50 / bl_p50 if bl_p50 > 0 else 0,
    })

if not agg:
    # Without this guard the script fails later with an opaque zero-size-array
    # error while building the heatmap colour scale.
    raise SystemExit(f"No usable rows found in {CSV}")

# Collapse the repetitions of each (D, P) cell to the median of every metric.
stat = {
    cell: {metric: float(np.median([rep[metric] for rep in reps])) for metric in reps[0]}
    for cell, reps in agg.items()
}

# Sorted axis values that actually occur in the data.
D_VALUES = sorted({cell[0] for cell in stat})
P_VALUES = sorted({cell[1] for cell in stat})

# ── Figure 1: heatmap of interference index (TPOT p90 during / baseline) ─────
# Rows follow D_VALUES, columns follow P_VALUES; cells without data stay NaN.
mat = np.full((len(D_VALUES), len(P_VALUES)), np.nan)
for i, D in enumerate(D_VALUES):
    for j, P in enumerate(P_VALUES):
        s = stat.get((D, P))
        if s:
            mat[i, j] = s["idx_p90"]

fig, ax = plt.subplots(figsize=(9.5, 5))
# Log colour scale starting at 1x; nanmax skips the empty (NaN) cells.
im = ax.imshow(mat, cmap="YlOrRd", aspect="auto",
               norm=matplotlib.colors.LogNorm(vmin=1, vmax=np.nanmax(mat)))

ax.set_xticks(range(len(P_VALUES)))
# ":g" keeps "2k"/"32k" for multiples of 1024 and shows e.g. "1.5k" instead of
# silently truncating other sizes the way integer division would.
ax.set_xticklabels([f"{P / 1024:g}k" for P in P_VALUES], fontsize=11)
ax.set_yticks(range(len(D_VALUES)))
ax.set_yticklabels([f"D={D}" for D in D_VALUES], fontsize=11)
ax.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
ax.set_ylabel("Decode batch size", fontsize=12)
ax.set_title("Prefill-Decode Interference Index\n"
             "TPOT p90 during prefill / TPOT p90 baseline (log color)",
             fontsize=13, fontweight="bold")

# annotate each cell that has data; imshow puts column j on x and row i on y
for (i, j), v in np.ndenumerate(mat):
    if np.isnan(v):
        continue
    # White text on the cells with the largest indices, black otherwise.
    txt_color = "white" if v > 50 else "black"
    ax.text(j, i, f"{v:.0f}x",
            ha="center", va="center",
            fontsize=10, color=txt_color, fontweight="bold")

cbar = plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02)
cbar.set_label("Interference index (×)", fontsize=10)

plt.tight_layout()
out_heatmap = OUT_DIR / "fig_interference_heatmap.png"
plt.savefig(out_heatmap, dpi=160, bbox_inches="tight")
print(f"Saved: {out_heatmap}")
plt.close(fig)

# ── Figure 2: lines, two panels ──────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Fixed colour per decode batch size; any other D falls back to gray below.
D_COLORS = {1: "#1f77b4", 2: "#2ca02c", 4: "#ff7f0e", 8: "#d62728"}

# Shared x tick labels. ":g" keeps "2k"/"32k" for multiples of 1024 and shows
# e.g. "1.5k" instead of truncating other sizes the way integer division would.
P_TICK_LABELS = [f"{P / 1024:g}k" for P in P_VALUES]

# Panel A: TPOT p90 during prefill (absolute, log y)
ax = axes[0]
for D in D_VALUES:
    # Missing (D, P) cells become NaN so the line simply has a gap there.
    ys_dur = [stat.get((D, P), {}).get("dur_p90", np.nan) for P in P_VALUES]
    ys_bl = [stat.get((D, P), {}).get("bl_p90", np.nan) for P in P_VALUES]
    color = D_COLORS.get(D, "gray")
    ax.plot(P_VALUES, ys_dur, "o-", color=color, label=f"D={D} (during prefill)",
            linewidth=2, markersize=7)
    # Baseline for the same D: same colour, dashed and faded.
    ax.plot(P_VALUES, ys_bl, "s--", color=color, alpha=0.4,
            label=f"D={D} (baseline)", linewidth=1, markersize=5)

ax.set_xscale("log", base=2)
ax.set_yscale("log")
ax.set_xticks(P_VALUES)
ax.set_xticklabels(P_TICK_LABELS)
ax.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
ax.set_ylabel("TPOT p90 (ms, log)", fontsize=12)
ax.set_title("Decode TPOT during prefill chunk", fontsize=12, fontweight="bold")
ax.grid(True, which="both", linestyle="--", alpha=0.4)
ax.legend(fontsize=8, loc="upper left", ncol=2)

# Panel B: prefill TTFT vs P
ax = axes[1]
for D in D_VALUES:
    ys = [stat.get((D, P), {}).get("ttft", np.nan) for P in P_VALUES]
    color = D_COLORS.get(D, "gray")
    ax.plot(P_VALUES, ys, "o-", color=color, label=f"D={D}",
            linewidth=2, markersize=7)

ax.set_xscale("log", base=2)
ax.set_xticks(P_VALUES)
ax.set_xticklabels(P_TICK_LABELS)
ax.set_xlabel("Cold prefill size (P tokens)", fontsize=12)
ax.set_ylabel("Prefill TTFT (ms)", fontsize=12)
ax.set_title("Cold prefill duration (interference window length)",
             fontsize=12, fontweight="bold")
ax.grid(True, linestyle="--", alpha=0.4)
ax.legend(fontsize=9, loc="upper left", title="Decode batch")

# y=1.02 places the title above the axes; bbox_inches="tight" below keeps it
# inside the saved image.
fig.suptitle(
    "Prefill-Decode Interference · Qwen3-Coder-30B-A3B · H20 · chunk_size=8192",
    fontsize=13, fontweight="bold", y=1.02)

plt.tight_layout()
out_lines = OUT_DIR / "fig_interference_lines.png"
plt.savefig(out_lines, dpi=160, bbox_inches="tight")
print(f"Saved: {out_lines}")
plt.close(fig)

# ── print summary table ──────────────────────────────────────────────────────
# One row per (D, P) cell, sorted by D then P, showing the per-cell medians.
# Row field widths match the header widths (the "x" suffix counts towards the
# idx_p90 column), and the rule is as long as the header, so columns line up.
header = f"{'D':>3} {'P':>6} | {'bl_p90':>7} {'dur_p90':>8} {'idx_p90':>7} | {'ttft':>7}"
print(f"\n{header}")
print("-" * len(header))
for (D, P) in sorted(stat):
    s = stat[(D, P)]
    print(f"{D:>3} {P:>6} | {s['bl_p90']:>7.2f} {s['dur_p90']:>8.1f} {s['idx_p90']:>6.1f}x | {s['ttft']:>7.0f}")