Files
agentic-kvc/microbench/plot_breakdown_real.py
Gahow Wang 72790ae6c1 PD-sep server-side profiling: vLLM patches + per-request breakdown
Instrumentation patches (microbench/patches/):
  - pd_profile.py: shared event emitter (VLLM_PD_PROFILE_LOG env var)
  - apply_patches.py: idempotent patch installer for mooncake_connector.py
    and scheduler.py, marks insertions with # PD_PROFILE_PATCH
  - analyze_events.py: joins per-process JSONL event logs by transfer_id
    into per-request phase durations

Seven events captured per request:
  D_get_num_matched → P_zmq_received → P_prefill_done →
  P_rdma_start → P_rdma_end → D_recv_complete → D_request_promoted

Driver fix (microbench/lifecycle/driver.py):
  seed_prefix_cache now sends via the proxy URL so P and D both cache
  the seeded prefix with matching block hashes. Previously seeding D
  directly produced different block hashes than the proxy-routed
  measurement requests, making incremental transfer impossible.

Real breakdown (fig_breakdown_real.png, server_breakdown.csv, n=93):
  prefill_compute  620 ms median (95% of overhead)
  rdma_transfer     42 ms median (~71 Gbps effective)
  other overhead    10 ms median (dispatch + params + signal + promote)

Mooncake transfer is NOT the bottleneck. Even with bulk RDMA the
transfer cost is <10% of prefill cost for Qwen3-30B-A3B on H20.
2026-05-26 13:59:09 +08:00

214 lines
7.6 KiB
Python

#!/usr/bin/env python3
"""
Plot REAL server-side breakdown from instrumented vLLM events.
Reads server_breakdown.csv (from analyze_events.py) and plots stacked bars:
- prefill_compute (P-side)
- rdma_transfer
- other server overhead (dispatch + build_params + completion + promote)
Grouped by total prompt tokens, colored by cache hit ratio band.
"""
import csv
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
from collections import defaultdict
# Input and output both live under lifecycle/results/ next to this script.
HERE = Path(__file__).parent
CSV = HERE / "lifecycle/results/server_breakdown.csv"
OUT = HERE / "lifecycle/results/fig_breakdown_real.png"

# ── load ─────────────────────────────────────────────────────────────────────
# Read every row up front inside a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
# newline="" is what the csv module asks for, so that line breaks embedded in
# quoted fields are parsed correctly on every platform.
with open(CSV, newline="") as csv_file:
    rows = list(csv.DictReader(csv_file))
print(f"Loaded {len(rows)} request breakdowns")
def f(r, k, default=0.0):
    """Return field *k* of row *r* as a finite float, or *default*.

    Args:
        r: Mapping of column name to raw cell value (a ``csv.DictReader`` row).
        k: Column name to read.
        default: Value returned when the cell is unusable.

    Returns:
        ``float(r[k])`` when the cell holds a finite number. *default* when
        the column is absent, the cell is ``None`` or empty, the text is not a
        number, or it parses to NaN / +-inf.
    """
    import math  # local import keeps this helper self-contained

    v = r.get(k, "")
    if v is None or v == "":
        return default
    try:
        value = float(v)
    except ValueError:
        return default
    # "nan" / "inf" parse successfully but would crash a later int() call and
    # poison medians, so they are treated like any other unusable cell.
    return value if math.isfinite(value) else default
# Compute per-request fields
# (output key, CSV column) pairs copied verbatim as floats into each record.
_FLOAT_COLUMNS = (
    ("rdma_ms", "rdma_transfer_ms"),
    ("dispatch_ms", "d_to_p_dispatch_ms"),
    ("build_params_ms", "build_params_ms"),
    ("completion_ms", "completion_sig_ms"),
    ("promote_ms", "D_promote_ms"),
    ("rdma_bytes", "rdma_bytes"),
    ("bandwidth_gbps", "rdma_bandwidth_gbps"),
)


def _request_record(row):
    """Turn one CSV row into a per-request record, or None if it is unusable.

    A row is dropped when its prompt token count is zero, when its
    delta_to_pull is negative, or when its prefill_compute_ms is negative.
    """
    n_prompt = int(f(row, "prompt_tokens"))
    n_cached = int(f(row, "num_local_cached"))
    n_delta = int(f(row, "delta_to_pull"))
    if n_prompt == 0 or n_delta < 0:
        return None

    # Some requests have negative prefill_compute (e.g., the trivial 11-token
    # case where P_zmq_received fires before D_get_num_matched). Skip those.
    prefill = f(row, "prefill_compute_ms")
    if prefill < 0:
        return None

    if n_prompt > 0:
        hit_ratio = n_cached / n_prompt
    else:
        hit_ratio = 0.0

    record = {
        "prompt": n_prompt,
        "cached": n_cached,
        "delta": n_delta,
        "ratio": hit_ratio,
        "prefill_ms": prefill,
    }
    for out_key, column in _FLOAT_COLUMNS:
        record[out_key] = f(row, column)
    return record


data = [rec for rec in map(_request_record, rows) if rec is not None]
print(f"Usable: {len(data)} requests")
# ── bucket by (prompt size, cache band) ──────────────────────────────────────
# Total prompt size buckets
def bucket_N(n):
    """Map a prompt token count to its nominal size bucket.

    Counts below 1500 map to 1024, below 6000 to 4096, below 22000 to 16384,
    and everything else to 32768.
    """
    for upper_bound, nominal in ((1500, 1024), (6000, 4096), (22000, 16384)):
        if n < upper_bound:
            return nominal
    return 32768
def cache_band(r):
    """Return the cache-hit band label for hit ratio *r*.

    Ratios below 0.1 are "0% (cold)", below 0.4 "~25%", below 0.6 "~50%",
    and anything else is "~75% (hot)".
    """
    for upper_bound, label in ((0.1, "0% (cold)"), (0.4, "~25%"), (0.6, "~50%")):
        if r < upper_bound:
            return label
    return "~75% (hot)"
# Collect every per-request metric into its (prompt bucket, cache band) cell.
_METRIC_KEYS = (
    "prefill_ms", "rdma_ms", "dispatch_ms",
    "build_params_ms", "completion_ms", "promote_ms",
    "rdma_bytes", "bandwidth_gbps",
)
agg = defaultdict(lambda: defaultdict(list))
for rec in data:
    cell_samples = agg[(bucket_N(rec["prompt"]), cache_band(rec["ratio"]))]
    for metric in _METRIC_KEYS:
        cell_samples[metric].append(rec[metric])

# Stat per cell: the median of each metric plus the sample count "n".
summary = {
    cell_key: {
        **{metric: float(np.median(samples)) for metric, samples in cell.items()},
        "n": len(cell["prefill_ms"]),
    }
    for cell_key, cell in agg.items()
}
# ── plot ─────────────────────────────────────────────────────────────────────
# One group of bars per prompt-size bucket; inside a group, one stacked bar per
# cache-hit band (prefill at the bottom, then RDMA, then the remaining phases).
N_BUCKETS = sorted({k[0] for k in summary})
BANDS_ALL = ["0% (cold)", "~25%", "~50%", "~75% (hot)"]
# Only bands that actually occur in the data get a bar slot.
BANDS = [b for b in BANDS_ALL if any(k[1] == b for k in summary)]
C_PREFILL = "#d62728"
C_RDMA = "#ff7f0e"
C_OTHER = "#1f77b4"
BAND_ALPHAS = [1.0, 0.75, 0.50, 0.28]
BAND_HATCHES = [None, None, "///", "///"]
# Look up alpha/hatch by band *name*, not by position in the filtered BANDS
# list, so a band keeps the same appearance even when another band is absent
# from the data.
BAND_STYLE = dict(zip(BANDS_ALL, zip(BAND_ALPHAS, BAND_HATCHES)))


def _cell_phases(cell):
    """Return (prefill, rdma, other) median milliseconds for a summary cell.

    *cell* is one value of ``summary`` or ``None`` for a (bucket, band)
    combination with no requests; a missing cell yields all zeros so it is
    drawn as an empty slot. "other" is dispatch + build_params + completion
    + promote.
    """
    if cell is None:
        return 0.0, 0.0, 0.0
    other = (cell.get("dispatch_ms", 0) + cell.get("build_params_ms", 0) +
             cell.get("completion_ms", 0) + cell.get("promote_ms", 0))
    return cell.get("prefill_ms", 0), cell.get("rdma_ms", 0), other


fig, ax = plt.subplots(figsize=(12, 6.5))
nN = len(N_BUCKETS)
nB = len(BANDS)
bar_w = 0.18
x_centers = np.arange(nN) * 1.0
# Horizontal offsets that centre the nB bars of a group on its tick.
offsets = np.linspace(-(nB - 1) / 2, (nB - 1) / 2, nB) * bar_w

ymax_data = 0.0
for j, band in enumerate(BANDS):
    alpha, hatch = BAND_STYLE[band]
    xp = x_centers + offsets[j]
    phases = [_cell_phases(summary.get((N, band))) for N in N_BUCKETS]
    pf = np.array([p[0] for p in phases], dtype=float)
    rd = np.array([p[1] for p in phases], dtype=float)
    ot = np.array([p[2] for p in phases], dtype=float)
    kw = dict(width=bar_w, alpha=alpha, edgecolor="white", linewidth=0.5)
    if hatch:
        kw["hatch"] = hatch
    ax.bar(xp, pf, color=C_PREFILL, **kw)
    ax.bar(xp, rd, bottom=pf, color=C_RDMA, **kw)
    ax.bar(xp, ot, bottom=pf + rd, color=C_OTHER, **kw)
    total = pf + rd + ot
    if total.size:
        ymax_data = max(ymax_data, float(total.max()))

# Leave headroom above the tallest bar for its value label. With no usable
# data fall back to a unit-high axis rather than a degenerate (0, 0) range.
ymax = ymax_data * 1.18 if ymax_data > 0 else 1.0
ax.set_ylim(0, ymax)

# Value labels
for j, band in enumerate(BANDS):
    alpha = BAND_STYLE[band][0]
    xp = x_centers + offsets[j]
    for i, N in enumerate(N_BUCKETS):
        s = summary.get((N, band))
        if s is None:
            continue
        cell_pf, cell_rd, cell_ot = _cell_phases(s)
        total = cell_pf + cell_rd + cell_ot
        if total <= 0:
            continue
        lbl = f"{total/1000:.1f}s" if total >= 1000 else f"{total:.0f}ms"
        # Keep labels legible even for the most transparent band.
        ax.text(xp[i], total + ymax * 0.01, lbl,
                ha="center", va="bottom", fontsize=7.2,
                color="black", alpha=max(alpha, 0.55))

# X axis
ax.set_xticks(x_centers)
ax.set_xticklabels([f"{N//1024}k" for N in N_BUCKETS], fontsize=12)
ax.set_xlabel("Total prompt tokens (bucket)", fontsize=12)
ax.set_ylabel("Server-side latency (ms)", fontsize=12)
ax.set_title(
    "REAL Server-Side PD-Sep Latency Breakdown\n"
    "Qwen3-Coder-30B-A3B · H20 · Mooncake · from instrumented vLLM events",
    fontsize=13, fontweight="bold")
ax.yaxis.grid(True, linestyle="--", alpha=0.35)
ax.set_axisbelow(True)

# Cache band sublabels: the band's percentage, written just below each bar.
for j, band in enumerate(BANDS):
    short = band.split(" ")[0]
    sub_alpha = max(BAND_STYLE[band][0], 0.5)
    for x in x_centers:
        ax.text(x + offsets[j], -ymax * 0.035, short,
                ha="center", va="top", fontsize=7,
                color="dimgrey", alpha=sub_alpha)

# Legend: phase colours first, then a blank spacer, then one swatch per band.
phase = [
    mpatches.Patch(color=C_PREFILL, label="Prefill compute (P node)"),
    mpatches.Patch(color=C_RDMA, label="KV transfer (RDMA)"),
    mpatches.Patch(color=C_OTHER, label="Scheduling overhead (dispatch+params+signal+promote)"),
]
spacer = mpatches.Patch(color="none", label="")
band_handles = [
    mpatches.Patch(facecolor="grey", alpha=BAND_STYLE[band][0],
                   hatch=(BAND_STYLE[band][1] or ""),
                   label=f"Cache hit {band}")
    for band in BANDS
]
ax.legend(handles=phase + [spacer] + band_handles,
          loc="upper left", fontsize=8.5, framealpha=0.9,
          ncol=2, columnspacing=1.0)
# Reserve a strip at the bottom of the figure for the sublabels below the axis.
plt.tight_layout(rect=[0, 0.04, 1, 1])
plt.savefig(OUT, dpi=160, bbox_inches="tight")
print(f"Saved: {OUT}")
# ── print summary ────────────────────────────────────────────────────────────
# Text table of the per-cell medians, one line per (bucket, band) in sorted
# order; "other" is dispatch + build_params + completion + promote.
table_header = f"\n{'N_bucket':>10} {'band':<15} {'n':>3} | {'prefill':>8} {'rdma':>7} {'other':>6} | {'total':>7}"
print(table_header)
print("-" * 70)
for cell_key in sorted(summary):
    N, band = cell_key
    s = summary[cell_key]
    other = s["dispatch_ms"] + s["build_params_ms"] + s["completion_ms"] + s["promote_ms"]
    total = s["prefill_ms"] + s["rdma_ms"] + other
    row_text = f"{N:>10} {band:<15} {s['n']:>3} | {s['prefill_ms']:>8.0f} {s['rdma_ms']:>7.0f} {other:>6.1f} | {total:>7.0f}"
    print(row_text)