Instrumentation patches (microbench/patches/):
- pd_profile.py: shared event emitter (VLLM_PD_PROFILE_LOG env var)
- apply_patches.py: idempotent patch installer for mooncake_connector.py
and scheduler.py; marks every inserted line with `# PD_PROFILE_PATCH`
- analyze_events.py: joins per-process JSONL event logs by transfer_id
into per-request phase durations
Seven events are captured per request, in this order:
D_get_num_matched → P_zmq_received → P_prefill_done →
P_rdma_start → P_rdma_end → D_recv_complete → D_request_promoted
Driver fix (microbench/lifecycle/driver.py):
seed_prefix_cache now sends via the proxy URL, so P and D both cache
the seeded prefix with matching block hashes. Previously, seeding D
directly produced block hashes that differed from those of the
proxy-routed measurement requests, which made incremental transfer impossible.
Real breakdown (fig_breakdown_real.png, server_breakdown.csv, n=93):
prefill_compute 620 ms median (~95% of server-side latency)
rdma_transfer 42 ms median (~71 Gbps effective)
other overhead 10 ms median (dispatch + params + signal + promote)
Mooncake transfer is NOT the bottleneck. Even with bulk RDMA the
transfer cost is <10% of prefill cost for Qwen3-30B-A3B on H20.
214 lines · 7.6 KiB · Python
#!/usr/bin/env python3
"""
Plot the REAL server-side latency breakdown from instrumented vLLM events.

Reads ``lifecycle/results/server_breakdown.csv`` (written by
analyze_events.py) and saves ``lifecycle/results/fig_breakdown_real.png``.
The figure shows stacked bars of per-cell median latency, split into:

  - prefill_compute (P-side)
  - rdma_transfer
  - other server overhead (dispatch + build_params + completion + promote)

Bars are grouped by total prompt tokens (bucketed). Segment colour encodes
the phase. Within each group, the cache-hit-ratio band is encoded by bar
opacity and hatching. A per-cell summary table is also printed to stdout.
"""

import csv
from collections import defaultdict
from pathlib import Path

import numpy as np
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported, so the
# figure can be rendered straight to a file without a display.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import matplotlib.patches as mpatches  # noqa: E402

# Paths are resolved relative to this script, not the working directory.
HERE = Path(__file__).parent
CSV = HERE / "lifecycle/results/server_breakdown.csv"
OUT = HERE / "lifecycle/results/fig_breakdown_real.png"
# ── load ─────────────────────────────────────────────────────────────────────
# newline="" is what the csv module expects for files it reads. The context
# manager closes the handle once every row has been materialised into a list.
with open(CSV, newline="", encoding="utf-8") as fh:
    rows = list(csv.DictReader(fh))
print(f"Loaded {len(rows)} request breakdowns")
def f(r, k, default=0.0):
    """Return column ``k`` of CSV row ``r`` as a float.

    Falls back to ``default`` when the column is absent, empty, ``None``,
    not parseable as a number, or non-finite (``nan``/``inf``).

    Non-finite values are treated as missing, like empty cells, because some
    callers cast the result with ``int()`` (which raises on them), and a
    single ``nan`` would turn a whole cell's ``np.median`` into ``nan``.
    """
    v = r.get(k, "")
    if v in ("", None):
        return default
    try:
        x = float(v)
    except (TypeError, ValueError):
        return default
    return x if np.isfinite(x) else default
# ── per-request records ─────────────────────────────────────────────────────
# Convert each CSV row into the record consumed by the aggregation below.
# Rows are dropped when prompt_tokens is missing/zero, when delta_to_pull is
# negative, or when prefill_compute_ms is negative.
data = []
for row in rows:
    n_prompt = int(f(row, "prompt_tokens"))
    n_cached = int(f(row, "num_local_cached"))
    n_delta = int(f(row, "delta_to_pull"))
    if n_prompt == 0 or n_delta < 0:
        continue
    # Negative prompt counts are not filtered above, so keep the guard.
    hit_ratio = 0.0 if n_prompt <= 0 else n_cached / n_prompt

    # A negative prefill_compute shows up for some requests (e.g. the trivial
    # 11-token case where P_zmq_received fires before D_get_num_matched).
    prefill = f(row, "prefill_compute_ms")
    if prefill < 0:
        continue

    data.append(dict(
        prompt=n_prompt,
        cached=n_cached,
        delta=n_delta,
        ratio=hit_ratio,
        prefill_ms=prefill,
        rdma_ms=f(row, "rdma_transfer_ms"),
        dispatch_ms=f(row, "d_to_p_dispatch_ms"),
        build_params_ms=f(row, "build_params_ms"),
        completion_ms=f(row, "completion_sig_ms"),
        promote_ms=f(row, "D_promote_ms"),
        rdma_bytes=f(row, "rdma_bytes"),
        bandwidth_gbps=f(row, "rdma_bandwidth_gbps"),
    ))

print(f"Usable: {len(data)} requests")
# ── bucket by (prompt size, cache band) ──────────────────────────────────────
def bucket_N(n):
    """Map a total prompt length ``n`` (tokens) to its nominal size bucket.

    Upper bounds are exclusive: below 1500 -> 1024, below 6000 -> 4096,
    below 22000 -> 16384, and everything else -> 32768.
    """
    for upper, bucket in ((1500, 1024), (6000, 4096), (22000, 16384)):
        if n < upper:
            return bucket
    return 32768
def cache_band(r):
    """Return the cache-hit band label for a cached/prompt ratio ``r``.

    The cut points 0.1, 0.4 and 0.6 are exclusive upper bounds. Any ratio
    that is not below 0.6 (including one that fails every comparison) falls
    into the hot band.
    """
    cut_points = (
        (0.1, "0% (cold)"),
        (0.4, "~25%"),
        (0.6, "~50%"),
    )
    for cut, label in cut_points:
        if r < cut:
            return label
    return "~75% (hot)"
# Group every metric by its (prompt-size bucket, cache band) cell.
agg = defaultdict(lambda: defaultdict(list))
for rec in data:
    cell = agg[(bucket_N(rec["prompt"]), cache_band(rec["ratio"]))]
    for metric in ("prefill_ms", "rdma_ms", "dispatch_ms",
                   "build_params_ms", "completion_ms", "promote_ms",
                   "rdma_bytes", "bandwidth_gbps"):
        cell[metric].append(rec[metric])


def _cell_stats(metrics):
    """Return the per-metric medians of one cell, plus its request count under "n"."""
    stats = {name: float(np.median(values)) for name, values in metrics.items()}
    stats["n"] = len(metrics["prefill_ms"])
    return stats


# Median of each metric per cell.
summary = {cell_key: _cell_stats(metrics) for cell_key, metrics in agg.items()}
# ── plot ─────────────────────────────────────────────────────────────────────
# Size buckets present in the data (ascending), and the cache bands that
# actually occur, kept in cold -> hot order.
N_BUCKETS = sorted({cell[0] for cell in summary})
BANDS_ALL = ["0% (cold)", "~25%", "~50%", "~75% (hot)"]
present_bands = {cell[1] for cell in summary}
BANDS = [b for b in BANDS_ALL if b in present_bands]

# One colour per latency phase (stack segment).
C_PREFILL, C_RDMA, C_OTHER = "#d62728", "#ff7f0e", "#1f77b4"

# NOTE(review): these styles are indexed by a band's position in the filtered
# BANDS list, not in BANDS_ALL. If a band is absent, the remaining bands shift
# styles. The bars and the legend use the same indexing, so a single figure
# stays self-consistent.
BAND_ALPHAS = [1.0, 0.75, 0.50, 0.28]
BAND_HATCHES = [None, None, "///", "///"]

fig, ax = plt.subplots(figsize=(12, 6.5))

nN, nB = len(N_BUCKETS), len(BANDS)
bar_w = 0.18
# Group centres sit at integer x. The bars of a group are spread
# symmetrically around the centre, one bar width apart.
x_centers = np.arange(nN, dtype=float)
half_span = (nB - 1) / 2
offsets = np.linspace(-half_span, half_span, nB) * bar_w
# Draw one stacked bar (prefill -> RDMA -> other) per (size bucket, band) cell.
# Cells with no data are drawn as zero-height bars.
ymax_data = 0
for j, band in enumerate(BANDS):
    alpha = BAND_ALPHAS[j]
    hatch = BAND_HATCHES[j]
    xp = x_centers + offsets[j]

    cells = [summary.get((N, band), {}) for N in N_BUCKETS]
    pf = np.array([c.get("prefill_ms", 0) for c in cells])
    rd = np.array([c.get("rdma_ms", 0) for c in cells])
    # "Other" = dispatch + build_params + completion signal + promote.
    ot = np.array([
        c.get("dispatch_ms", 0) + c.get("build_params_ms", 0)
        + c.get("completion_ms", 0) + c.get("promote_ms", 0)
        for c in cells
    ])

    kw = dict(width=bar_w, alpha=alpha, edgecolor="white", linewidth=0.5)
    if hatch:
        kw["hatch"] = hatch

    ax.bar(xp, pf, color=C_PREFILL, **kw)
    ax.bar(xp, rd, bottom=pf, color=C_RDMA, **kw)
    ax.bar(xp, ot, bottom=pf + rd, color=C_OTHER, **kw)

    total = pf + rd + ot
    if total.size:
        ymax_data = max(ymax_data, total.max())

# 18% headroom above the tallest stack leaves room for the value labels. With
# no (or all-zero) data, fall back to a unit range, because set_ylim(0, 0)
# would create a singular y axis.
ymax = ymax_data * 1.18 if ymax_data > 0 else 1.0
ax.set_ylim(0, ymax)
# Value labels: total latency above each populated bar, shown in seconds once
# it reaches 1000 ms.
for j, band in enumerate(BANDS):
    label_alpha = max(BAND_ALPHAS[j], 0.55)
    xp = x_centers + offsets[j]
    for i, N in enumerate(N_BUCKETS):
        cell = summary.get((N, band))
        if cell is None:
            continue
        total = (cell.get("prefill_ms", 0) + cell.get("rdma_ms", 0)
                 + cell.get("dispatch_ms", 0) + cell.get("build_params_ms", 0)
                 + cell.get("completion_ms", 0) + cell.get("promote_ms", 0))
        if total <= 0:
            continue
        if total >= 1000:
            lbl = f"{total/1000:.1f}s"
        else:
            lbl = f"{total:.0f}ms"
        ax.text(xp[i], total + ymax * 0.01, lbl,
                ha="center", va="bottom", fontsize=7.2,
                color="black", alpha=label_alpha)
# Dashed horizontal grid, drawn underneath the bars.
ax.yaxis.grid(True, linestyle="--", alpha=0.35)
ax.set_axisbelow(True)

# X ticks at the group centres, labelled in k-tokens.
tick_labels = [f"{N//1024}k" for N in N_BUCKETS]
ax.set_xticks(x_centers)
ax.set_xticklabels(tick_labels, fontsize=12)

ax.set_xlabel("Total prompt tokens (bucket)", fontsize=12)
ax.set_ylabel("Server-side latency (ms)", fontsize=12)

title_text = (
    "REAL Server-Side PD-Sep Latency Breakdown\n"
    "Qwen3-Coder-30B-A3B · H20 · Mooncake · from instrumented vLLM events"
)
ax.set_title(title_text, fontsize=13, fontweight="bold")
# Short band tag ("0%", "~25%", ...) written under every bar slot, just below
# the bottom of the y range.
sublabel_y = -ymax * 0.035
for j, band in enumerate(BANDS):
    tag = band.split(" ")[0]
    tag_alpha = max(BAND_ALPHAS[j], 0.5)
    for x_bar in x_centers + offsets[j]:
        ax.text(x_bar, sublabel_y, tag,
                ha="center", va="top", fontsize=7,
                color="dimgrey", alpha=tag_alpha)
# Legend entries: the three phase colours, a blank spacer, then one grey
# swatch per cache band, styled the same way as that band's bars.
phase = [
    mpatches.Patch(color=C_PREFILL, label="Prefill compute (P node)"),
    mpatches.Patch(color=C_RDMA, label="KV transfer (RDMA)"),
    mpatches.Patch(color=C_OTHER, label="Scheduling overhead (dispatch+params+signal+promote)"),
]
spacer = mpatches.Patch(color="none", label="")
band_handles = [
    mpatches.Patch(
        facecolor="grey",
        alpha=BAND_ALPHAS[idx],
        hatch=(BAND_HATCHES[idx] or ""),
        label=f"Cache hit {band_name}",
    )
    for idx, band_name in enumerate(BANDS)
]
legend_handles = phase + [spacer] + band_handles
ax.legend(handles=legend_handles, loc="upper left", fontsize=8.5,
          framealpha=0.9, ncol=2, columnspacing=1.0)

# The bottom 4% of the figure is kept out of tight_layout, presumably to make
# room for the band sublabels drawn below the axis.
plt.tight_layout(rect=[0, 0.04, 1, 1])
plt.savefig(OUT, dpi=160, bbox_inches="tight")
print(f"Saved: {OUT}")
# ── print summary ────────────────────────────────────────────────────────────
# Console table of the same per-cell medians that were plotted, ordered by
# (size bucket, band label).
rule = "-" * 70
print(f"\n{'N_bucket':>10} {'band':<15} {'n':>3} | {'prefill':>8} {'rdma':>7} {'other':>6} | {'total':>7}")
print(rule)
for (N, band), s in sorted(summary.items()):
    other = s["dispatch_ms"] + s["build_params_ms"] + s["completion_ms"] + s["promote_ms"]
    total = s["prefill_ms"] + s["rdma_ms"] + other
    print(f"{N:>10} {band:<15} {s['n']:>3} | {s['prefill_ms']:>8.0f} {s['rdma_ms']:>7.0f} {other:>6.1f} | {total:>7.0f}")