#!/usr/bin/env python3
"""
Plot REAL server-side breakdown from instrumented vLLM events.

Reads server_breakdown.csv (from analyze_events.py) and plots stacked bars:
  - prefill_compute (P-side)
  - rdma_transfer
  - other server overhead (dispatch + build_params + completion + promote)

Bars are grouped by total prompt tokens; within each group there is one bar
per cache hit ratio band, shaded (alpha / hatch) by band. Each bar shows the
per-cell median of every component. The figure is written to
lifecycle/results/fig_breakdown_real.png and a text summary is printed.
"""
import csv
import numpy as np
import matplotlib
matplotlib.use("Agg")  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
from collections import defaultdict

HERE = Path(__file__).parent
CSV = HERE / "lifecycle/results/server_breakdown.csv"
OUT = HERE / "lifecycle/results/fig_breakdown_real.png"

# ── load ─────────────────────────────────────────────────────────────────────
# Use a context manager so the handle is closed once the rows are materialised;
# newline="" is what the csv module expects for files it reads.
with open(CSV, newline="") as fh:
    rows = list(csv.DictReader(fh))
print(f"Loaded {len(rows)} request breakdowns")


def f(r, k, default=0.0):
    """Return column *k* of CSV row *r* as a float.

    Missing, empty or non-numeric values yield *default* instead of raising.
    """
    v = r.get(k, "")
    try:
        return float(v) if v not in ("", None) else default
    except ValueError:
        return default


# Compute per-request fields
data = []
for r in rows:
    prompt = int(f(r, "prompt_tokens"))
    cached = int(f(r, "num_local_cached"))
    delta = int(f(r, "delta_to_pull"))
    if prompt == 0 or delta < 0:
        continue
    ratio = cached / prompt if prompt > 0 else 0.0
    # Some requests have negative prefill_compute (e.g., the trivial 11-token case
    # where P_zmq_received fires before D_get_num_matched). Skip those.
    pf = f(r, "prefill_compute_ms")
    if pf < 0:
        continue
    data.append({
        "prompt": prompt,
        "cached": cached,
        "delta": delta,
        "ratio": ratio,
        "prefill_ms": pf,
        "rdma_ms": f(r, "rdma_transfer_ms"),
        "dispatch_ms": f(r, "d_to_p_dispatch_ms"),
        "build_params_ms": f(r, "build_params_ms"),
        "completion_ms": f(r, "completion_sig_ms"),
        "promote_ms": f(r, "D_promote_ms"),
        "rdma_bytes": f(r, "rdma_bytes"),
        "bandwidth_gbps": f(r, "rdma_bandwidth_gbps"),
    })

print(f"Usable: {len(data)} requests")


# ── bucket by (prompt size, cache band) ──────────────────────────────────────
def bucket_N(n):
    """Map a prompt token count to its nominal size bucket (1k/4k/16k/32k)."""
    if n < 1500:
        return 1024
    if n < 6000:
        return 4096
    if n < 22000:
        return 16384
    return 32768


def cache_band(r):
    """Map a cache hit ratio to the label of the band it falls into."""
    if r < 0.1:
        return "0% (cold)"
    if r < 0.4:
        return "~25%"
    if r < 0.6:
        return "~50%"
    return "~75% (hot)"


# Metrics collected per (bucket, band) cell.
METRIC_KEYS = ("prefill_ms", "rdma_ms", "dispatch_ms", "build_params_ms",
               "completion_ms", "promote_ms", "rdma_bytes", "bandwidth_gbps")

agg = defaultdict(lambda: defaultdict(list))
for d in data:
    nb = bucket_N(d["prompt"])
    cb = cache_band(d["ratio"])
    for k in METRIC_KEYS:
        agg[(nb, cb)][k].append(d[k])

# Stat per cell: median of every metric plus the sample count "n".
summary = {}
for k, v in agg.items():
    s = {kk: float(np.median(vv)) for kk, vv in v.items()}
    s["n"] = len(v["prefill_ms"])
    summary[k] = s

# ── plot ─────────────────────────────────────────────────────────────────────
N_BUCKETS = sorted({k[0] for k in summary})
BANDS_ALL = ["0% (cold)", "~25%", "~50%", "~75% (hot)"]
BANDS = [b for b in BANDS_ALL if any(k[1] == b for k in summary)]

C_PREFILL = "#d62728"
C_RDMA = "#ff7f0e"
C_OTHER = "#1f77b4"

BAND_ALPHAS = [1.0, 0.75, 0.50, 0.28]
BAND_HATCHES = [None, None, "///", "///"]
# Key the styling by band name rather than by position in the filtered BANDS
# list, so a band keeps its look even when another band has no data.
BAND_STYLE = dict(zip(BANDS_ALL, zip(BAND_ALPHAS, BAND_HATCHES)))

fig, ax = plt.subplots(figsize=(12, 6.5))

nN = len(N_BUCKETS)
nB = len(BANDS)
bar_w = 0.18
x_centers = np.arange(nN) * 1.0
# Centre the nB bars of each group around the group's x position.
offsets = np.linspace(-(nB - 1) / 2, (nB - 1) / 2, nB) * bar_w

ymax_data = 0
for j, band in enumerate(BANDS):
    alpha, hatch = BAND_STYLE[band]
    xp = x_centers + offsets[j]
    # One lookup per bucket; an absent cell contributes zero-height segments.
    cells = [summary.get((N, band), {}) for N in N_BUCKETS]

    pf = np.array([c.get("prefill_ms", 0) for c in cells])
    rd = np.array([c.get("rdma_ms", 0) for c in cells])
    ot = np.array([
        c.get("dispatch_ms", 0)
        + c.get("build_params_ms", 0)
        + c.get("completion_ms", 0)
        + c.get("promote_ms", 0)
        for c in cells])

    kw = dict(width=bar_w, alpha=alpha, edgecolor="white", linewidth=0.5)
    if hatch:
        kw["hatch"] = hatch

    ax.bar(xp, pf, color=C_PREFILL, **kw)
    ax.bar(xp, rd, bottom=pf, color=C_RDMA, **kw)
    ax.bar(xp, ot, bottom=pf + rd, color=C_OTHER, **kw)

    total = pf + rd + ot
    ymax_data = max(ymax_data, total.max() if len(total) > 0 else 0)

ymax = ymax_data * 1.18  # headroom for the value labels above the bars
if ymax > 0:
    # Skip when there is nothing to plot: identical limits are not a valid range.
    ax.set_ylim(0, ymax)

# Value labels
for j, band in enumerate(BANDS):
    alpha = BAND_STYLE[band][0]
    xp = x_centers + offsets[j]
    for i, N in enumerate(N_BUCKETS):
        s = summary.get((N, band))
        if s is None:
            continue
        total = (s.get("prefill_ms", 0) + s.get("rdma_ms", 0) +
                 s.get("dispatch_ms", 0) + s.get("build_params_ms", 0) +
                 s.get("completion_ms", 0) + s.get("promote_ms", 0))
        if total <= 0:
            continue
        lbl = f"{total/1000:.1f}s" if total >= 1000 else f"{total:.0f}ms"
        ax.text(xp[i], total + ymax * 0.01, lbl,
                ha="center", va="bottom", fontsize=7.2, color="black",
                alpha=max(alpha, 0.55))

# X axis
ax.set_xticks(x_centers)
ax.set_xticklabels([f"{N//1024}k" for N in N_BUCKETS], fontsize=12)
ax.set_xlabel("Total prompt tokens (bucket)", fontsize=12)
ax.set_ylabel("Server-side latency (ms)", fontsize=12)
ax.set_title(
    "REAL Server-Side PD-Sep Latency Breakdown\n"
    "Qwen3-Coder-30B-A3B · H20 · Mooncake · from instrumented vLLM events",
    fontsize=13, fontweight="bold")
ax.yaxis.grid(True, linestyle="--", alpha=0.35)
ax.set_axisbelow(True)

# Cache band sublabels (drawn just below the axis, under each bar)
for j, band in enumerate(BANDS):
    short = band.split(" ")[0]
    sub_alpha = max(BAND_STYLE[band][0], 0.5)
    for x in x_centers:
        xp = x + offsets[j]
        ax.text(xp, -ymax * 0.035, short,
                ha="center", va="top", fontsize=7, color="dimgrey",
                alpha=sub_alpha)

# Legend
phase = [
    mpatches.Patch(color=C_PREFILL, label="Prefill compute (P node)"),
    mpatches.Patch(color=C_RDMA, label="KV transfer (RDMA)"),
    mpatches.Patch(color=C_OTHER, label="Scheduling overhead (dispatch+params+signal+promote)"),
]
spacer = mpatches.Patch(color="none", label="")
band_handles = [
    mpatches.Patch(facecolor="grey", alpha=BAND_STYLE[band][0],
                   hatch=(BAND_STYLE[band][1] or ""),
                   label=f"Cache hit {band}")
    for band in BANDS
]
ax.legend(handles=phase + [spacer] + band_handles,
          loc="upper left", fontsize=8.5, framealpha=0.9,
          ncol=2, columnspacing=1.0)

plt.tight_layout(rect=[0, 0.04, 1, 1])
plt.savefig(OUT, dpi=160, bbox_inches="tight")
print(f"Saved: {OUT}")

# ── print summary ────────────────────────────────────────────────────────────
print(f"\n{'N_bucket':>10} {'band':<15} {'n':>3} | {'prefill':>8} {'rdma':>7} {'other':>6} | {'total':>7}")
print("-" * 70)
for (N, band) in sorted(summary.keys()):
    s = summary[(N, band)]
    other = s["dispatch_ms"] + s["build_params_ms"] + s["completion_ms"] + s["promote_ms"]
    total = s["prefill_ms"] + s["rdma_ms"] + other
    print(f"{N:>10} {band:<15} {s['n']:>3} | {s['prefill_ms']:>8.0f} {s['rdma_ms']:>7.0f} {other:>6.1f} | {total:>7.0f}")