#!/usr/bin/env python3
"""
Stacked-bar breakdown of PD-sep request latency.

Axes:
  X      : total input length (N_total), grouped by cache hit ratio
  Stacks : prefill compute (red) | KV transfer RDMA (orange) | decode (steelblue)

Measured constants (H20, Qwen3-Coder-30B-A3B, from microbench):
  cold_prefill_ms(n) ≈ 0.072 * n   (interference D=1 prefill_ttft, n=2k-16k)
  kv_transfer_ms(n)  = 35 + n * 96KB * 8 / 25Gbps   (warm Mooncake RDMA)
  decode_ms          = output_tokens * 7.0ms/token

Note: the sweep below spans 1k-32k input tokens, so its smallest and largest
points extrapolate the linear prefill model beyond the measured 2k-16k range.
"""
from pathlib import Path

import numpy as np

HERE = Path(__file__).parent
OUT = HERE / "lifecycle/results/fig_breakdown.png"

# ── measured constants ───────────────────────────────────────────────────────
MS_PER_TOK_COLD = 0.072            # ms / new token (cold prefill, linear regime)
KV_BYTES_PER_TOK = 2*48*4*128*2    # 98304 B per token (Qwen3-30B-A3B)
RDMA_BW_GBPS = 25                  # effective Mooncake bandwidth (measured)
RDMA_OVERHEAD_MS = 35              # warm-connection fixed overhead (measured)
DECODE_MS_PER_TOK = 7.0            # TPOT baseline p50
OUTPUT_TOKENS = 128                # representative output length for decode bar

# ── sweep parameters ─────────────────────────────────────────────────────────
N_TOTALS = [1024, 2048, 4096, 8192, 16384, 32768]
CACHE_RATIOS = [0.0, 0.25, 0.50, 0.75]
# Derived from CACHE_RATIOS so labels can never drift from the data.
CR_LABELS = [f"{cr:.0%}" for cr in CACHE_RATIOS]   # "0%", "25%", ...
CR_ALPHAS = [1.0, 0.75, 0.50, 0.28]
CR_HATCHES = [None, None, "///", "///"]

C_PREFILL = "#d62728"
C_TRANSFER = "#ff7f0e"
C_DECODE = "#1f77b4"


def prefill_ms(n_new):
    """Return cold-prefill compute time (ms) for ``n_new`` uncached tokens.

    Uses the linear model ``MS_PER_TOK_COLD * n``. ``n_new`` is clamped to at
    least 1, so a fully cached request still pays for one prefilled token.
    """
    return MS_PER_TOK_COLD * max(1, n_new)


def transfer_ms(n_new):
    """Return KV transfer time (ms) for ``n_new`` tokens.

    The result is a fixed warm-connection overhead plus the wire time of the
    tokens' KV bytes at ``RDMA_BW_GBPS`` (the ``* 8`` converts bytes to bits).
    ``n_new`` is clamped to at least 1.
    """
    kv_bytes = KV_BYTES_PER_TOK * max(1, n_new)
    bw_ms = kv_bytes * 8 / (RDMA_BW_GBPS * 1e9) * 1000
    return RDMA_OVERHEAD_MS + bw_ms


def fmt_latency(ms):
    """Format a latency in ms: seconds with one decimal if >= 1000 ms, else whole ms."""
    return f"{ms / 1000:.1f}s" if ms >= 1000 else f"{ms:.0f}ms"


def compute_breakdown(n_totals=None, cache_ratios=None):
    """Return ``(prefill, transfer, decode)`` latency matrices in ms.

    Each matrix has shape ``(len(n_totals), len(cache_ratios))``. Entry
    ``[i, j]`` models a request of ``n_totals[i]`` input tokens where a
    fraction ``cache_ratios[j]`` is already cached. Only the remaining
    (at least 1) tokens are prefilled and transferred. Decode is the same for
    every cell: ``DECODE_MS_PER_TOK * OUTPUT_TOKENS``.

    Args:
        n_totals: input lengths to sweep (defaults to ``N_TOTALS``).
        cache_ratios: cache-hit fractions to sweep (defaults to ``CACHE_RATIOS``).
    """
    n_totals = N_TOTALS if n_totals is None else n_totals
    cache_ratios = CACHE_RATIOS if cache_ratios is None else cache_ratios
    shape = (len(n_totals), len(cache_ratios))
    pf_mat = np.zeros(shape)
    tr_mat = np.zeros(shape)
    dec_mat = np.full(shape, DECODE_MS_PER_TOK * OUTPUT_TOKENS)
    for i, n_total in enumerate(n_totals):
        for j, cr in enumerate(cache_ratios):
            n_new = max(1, int(n_total * (1 - cr)))
            pf_mat[i, j] = prefill_ms(n_new)
            tr_mat[i, j] = transfer_ms(n_new)
    return pf_mat, tr_mat, dec_mat


def plot_breakdown(pf_mat, tr_mat, dec_mat, out=None, n_totals=None):
    """Draw the grouped stacked-bar figure and save it to ``out`` (default ``OUT``).

    Columns of the matrices correspond to CACHE_RATIOS / CR_ALPHAS /
    CR_HATCHES / CR_LABELS. Rows correspond to ``n_totals`` (default
    ``N_TOTALS``).
    """
    # Imported lazily so the latency model above is importable without a
    # plotting backend; "Agg" must be selected before pyplot is imported.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches

    out = OUT if out is None else Path(out)
    n_totals = N_TOTALS if n_totals is None else n_totals
    n_n, n_c = pf_mat.shape
    total_mat = pf_mat + tr_mat + dec_mat

    # Fix the y-range up front so every value label uses the same offset.
    # Reading ax.get_ylim() while bars are still being added would use a
    # moving autoscale limit instead.
    ymax = total_mat.max() * 1.18
    label_pad = ymax * 0.01

    fig, ax = plt.subplots(figsize=(13, 6.5))
    bar_w = 0.18
    group_gap = 1.0
    x_centers = np.arange(n_n) * group_gap
    offsets = np.linspace(-(n_c - 1) / 2, (n_c - 1) / 2, n_c) * bar_w

    for j in range(n_c):
        xp = x_centers + offsets[j]
        pf, tr, dc = pf_mat[:, j], tr_mat[:, j], dec_mat[:, j]
        alpha = CR_ALPHAS[j]
        kw = dict(width=bar_w, alpha=alpha, edgecolor="white", linewidth=0.5)
        if CR_HATCHES[j]:
            kw["hatch"] = CR_HATCHES[j]
        ax.bar(xp, pf, color=C_PREFILL, **kw)
        ax.bar(xp, tr, bottom=pf, color=C_TRANSFER, **kw)
        ax.bar(xp, dc, bottom=pf + tr, color=C_DECODE, **kw)
        # value labels on top of each stack
        for xpos, total in zip(xp, total_mat[:, j]):
            ax.text(xpos, total + label_pad, fmt_latency(total),
                    ha="center", va="bottom", fontsize=7.2,
                    color="black", alpha=max(alpha, 0.5))

    ax.set_ylim(0, ymax)

    # cache-ratio sub-labels below bars (negative y, outside the axes area)
    sub_y = -ymax * 0.032
    for j in range(n_c):
        for x in x_centers:
            ax.text(x + offsets[j], sub_y, CR_LABELS[j],
                    ha="center", va="top", fontsize=7.8, color="dimgrey",
                    alpha=max(CR_ALPHAS[j], 0.4))
    ax.text(x_centers[0] + offsets[0] - bar_w, sub_y, "cache\nhit:",
            ha="right", va="top", fontsize=7.5, color="dimgrey",
            style="italic")

    ax.set_xticks(x_centers)
    ax.set_xticklabels([f"{n // 1024}k" for n in n_totals], fontsize=12)
    ax.set_xlabel("Total input tokens (N)", fontsize=12)
    ax.set_ylabel("Latency (ms)", fontsize=12)
    ax.set_title(
        "PD-Disaggregated Request Latency Breakdown\n"
        f"Qwen3-Coder-30B-A3B · H20 · Mooncake RDMA · output={OUTPUT_TOKENS} tokens",
        fontsize=13, fontweight="bold")
    ax.yaxis.grid(True, linestyle="--", alpha=0.35)
    ax.set_axisbelow(True)

    # ── legend ───────────────────────────────────────────────────────────────
    phase_h = [
        mpatches.Patch(color=C_PREFILL, label="Prefill compute (P node)"),
        mpatches.Patch(color=C_TRANSFER, label="KV transfer (Mooncake RDMA)"),
        mpatches.Patch(color=C_DECODE, label="Decode generation (D node)"),
    ]
    spacer = mpatches.Patch(color="none", label="")
    cr_h = [
        mpatches.Patch(facecolor="grey", alpha=CR_ALPHAS[j],
                       hatch=(CR_HATCHES[j] or ""),
                       label=f"KV cache hit {CR_LABELS[j]}")
        for j in range(n_c)
    ]
    ax.legend(handles=phase_h + [spacer] + cr_h,
              loc="upper left", fontsize=9, framealpha=0.9,
              ncol=2, columnspacing=1.2, handlelength=1.5)

    plt.tight_layout(rect=[0, 0.05, 1, 1])
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=160, bbox_inches="tight")
    plt.close(fig)  # release the figure; avoids accumulating open figures
    print(f"Saved: {out}")


def print_table(pf_mat, tr_mat, dec_mat, n_totals=None, cache_ratios=None):
    """Print the per-cell latency breakdown (ms) as a fixed-width table."""
    n_totals = N_TOTALS if n_totals is None else n_totals
    cache_ratios = CACHE_RATIOS if cache_ratios is None else cache_ratios
    print(f"\n{'N':>6} {'cache%':>7} | {'prefill':>8} {'transfer':>9} {'decode':>8} | {'E2E':>8}")
    print("-" * 60)
    for i, n_total in enumerate(n_totals):
        for j, cr in enumerate(cache_ratios):
            pf, tr, dc = pf_mat[i, j], tr_mat[i, j], dec_mat[i, j]
            print(f"{n_total:>6} {cr*100:>6.0f}% | {pf:>8.0f} {tr:>9.0f} {dc:>8.0f} | {pf+tr+dc:>8.0f}")
    print()


def main():
    """Compute the latency model, save the figure and print the table."""
    pf_mat, tr_mat, dec_mat = compute_breakdown()
    plot_breakdown(pf_mat, tr_mat, dec_mat)
    print_table(pf_mat, tr_mat, dec_mat)


if __name__ == "__main__":
    main()