Files
Gahow Wang dc8e6dd5a8 v2 exp(a): add remote KV-store (RDMA) tier
Extends the hit-latency microbench to a 4th tier: a remote global-KV-store
hit over RDMA, the Mooncake-Store mechanism. Two kv_both MooncakeConnector
instances (run_rdma.sh); for each prefix length, instance B serves the
request by pulling instance A's cached prefix over RDMA (do_remote_prefill,
via microbench/fresh_setup/mb2_kv_transfer.py) instead of recomputing -- the
timed pull is the remote-hit latency.

Result (TTFT p50, 11 reps): strict tier ordering
GPU(HBM) < CPU(local DRAM) < remote-RDMA-store << miss, gaps growing with
context. At 64k: GPU 0.11s, CPU 0.27s, RDMA 0.97s, miss 15.2s -> miss/RDMA
15.8x, RDMA/CPU 3.6x, CPU/GPU 2.4x. So a global RDMA store is a real win
over recompute (the blog's 46x) but pays the NIC tax (~5-7 GB/s effective)
and sits a tier below local CPU and two below GPU -- reinforcing
GPU-hit-first. README + figure updated to four tiers.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-30 12:48:37 +08:00

94 lines
3.4 KiB
Python

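"""Plot exp (a): TTFT vs reused-prefix length for each KV tier.

Tiers: miss (recompute) / GPU-tier hit / CPU-tier hit / remote RDMA KV-store
hit, plus the PCIe H2D transfer floor. Also prints a per-length p50 table with
adjacent-tier slowdown ratios.
"""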
import json
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Usage: <script> [RESULTS_DIR] [FIG_PATH] -- both positional args optional.
_cli = sys.argv[1:]
# Directory holding miss.json / gpu.json / cpu.json / pcie.json / rdma.json.
R = Path(_cli[0] if _cli else "v2/exp_a_tier_latency/results")
# Output path of the rendered figure.
FIG = Path(_cli[1] if len(_cli) > 1 else "v2/figs/exp_a_tier_latency.png")
# KV-cache bytes per token. Not referenced elsewhere in this script.
# NOTE(review): presumably layers x kv_heads x head_dim x 2 (K,V) x 2 bytes,
# e.g. 48 x 4 x 128 x 2 x 2 for Qwen3-Coder-30B-A3B -- TODO confirm.
KV_BYTES_PER_TOKEN = 98304
def load(name, root=None):
    """Load one tier's results JSON, or return None if the file is absent.

    Args:
        name: file name inside the results directory, e.g. ``"gpu.json"``.
        root: directory to look in. Defaults to the module-level results
            directory ``R``.

    Returns:
        The decoded JSON object, or None when the file does not exist, so a
        tier whose run was skipped simply drops out of the plot and table.
    """
    p = Path(R if root is None else root) / name
    if not p.exists():
        return None
    # Context manager so the handle is closed deterministically; the previous
    # json.load(open(p)) left it open until garbage collection. JSON is UTF-8.
    with p.open(encoding="utf-8") as fh:
        return json.load(fh)
# Per-tier results. A tier whose JSON file is missing comes back as None.
miss = load("miss.json")
gpu = load("gpu.json")
cpu = load("cpu.json")
pcie = load("pcie.json")
rdma = load("rdma.json")
def series(d):
    """Return ``(lengths, ttft_p50s)`` for one tier, ordered by prefix length.

    ``d["by_length"]`` maps a prefix length (a JSON object key, hence a
    string) to a record carrying ``"ttft_p50"``. A falsy ``d`` (tier not
    run) yields two empty lists so callers can skip the tier.
    """
    if not d:
        return [], []
    points = [(int(length), rec["ttft_p50"]) for length, rec in d["by_length"].items()]
    # Sort numerically (not lexicographically) by prefix length.
    points.sort(key=lambda pt: pt[0])
    lengths = [length for length, _ in points]
    p50s = [p50 for _, p50 in points]
    return lengths, p50s
def rdma_series(d=None):
    """Remote KV-store hit over RDMA: p50 of ``t_transfer_s`` per prefix length.

    The destination pulls the cached prefix from the remote pool instead of
    recomputing it, so the timed pull is the hit latency.

    Args:
        d: parsed ``rdma.json`` results. Each entry of ``d["raw"]`` carries
            ``input_tokens`` and ``t_transfer_s``. Defaults to the
            module-level ``rdma`` results.

    Returns:
        ``(lengths, p50_by_length)``: sorted int prefix lengths, and a dict
        mapping each length to the median ``t_transfer_s``. Both are empty
        when there are no RDMA results.
    """
    import statistics
    from collections import defaultdict

    if d is None:
        d = rdma
    if not d:
        return [], {}
    by = defaultdict(list)
    for r in d["raw"]:
        # int() so keys line up with the int prefix lengths used by series()
        # and the table lookups, even if the JSON stored them as strings.
        by[int(r["input_tokens"])].append(r["t_transfer_s"])
    xs = sorted(by)
    return xs, {L: statistics.median(by[L]) for L in xs}
rdma_x, rdma_p50 = rdma_series()

fig, ax = plt.subplots(figsize=(7.2, 5.0))
# Tiers plotted slowest -> fastest (the same order as the table columns), so
# the legend reads top-to-bottom in the same order as the curves. Each entry
# is ((xs, ys), label, marker, color); tiers with no data are skipped.
tiers = [
    (series(miss), "miss (recompute)", "o", "#d62728"),
    ((rdma_x, [rdma_p50[L] for L in rdma_x]),
     "remote KV-store hit (Mooncake RDMA)", "D", "#9467bd"),
    (series(cpu), "CPU-tier hit (local DRAM, PCIe)", "s", "#ff7f0e"),
    (series(gpu), "GPU-tier hit (HBM APC)", "^", "#2ca02c"),
]
for (xs, ys), lab, mk, c in tiers:
    if xs:
        ax.plot(xs, ys, marker=mk, label=lab, color=c, linewidth=2, markersize=7)

if pcie:
    # Lower bound for a CPU-tier hit: the raw host-to-device copy time.
    floor = sorted((int(k), v["transfer_s"]) for k, v in pcie["by_length"].items())
    ax.plot([x for x, _ in floor], [y for _, y in floor], "--", color="#7f7f7f",
            linewidth=1.4, label="CPU-hit transfer floor (PCIe H2D)")

ax.set_xscale("log", base=2)
ax.set_yscale("log")
ax.set_xlabel("Reused prefix length (tokens)")
ax.set_ylabel("TTFT (s, log)")
ax.set_title("Cost of serving a reused prefix from each KV tier\n"
             "Qwen3-Coder-30B-A3B, H20 (local tiers 1 GPU; RDMA pool 2 GPUs)")
ax.grid(True, which="both", alpha=0.3)
ax.legend()

FIG.parent.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
fig.savefig(FIG, dpi=140)
plt.close(fig)  # release the figure; nothing else draws on it
print("wrote", FIG)
# Table: per-length p50 TTFT for each tier plus adjacent-tier slowdown ratios.


def _fmt(x):
    """Format a latency in seconds, or a dash placeholder when missing."""
    return f"{x:.4f}" if x is not None else " - "


def _ratio(a, b):
    """Format a/b as 'N.Nx'; dash when either side is missing or zero."""
    return f"{a/b:.1f}x" if (a and b) else " -"


def _p50(d, L):
    """ttft_p50 of tier results ``d`` at prefix length ``L``, or None."""
    return d["by_length"].get(str(L), {}).get("ttft_p50") if d else None


print(f"\n{'L':>7} {'miss':>9} {'rdma':>9} {'cpu':>9} {'gpu':>9} "
      f"{'miss/rdma':>9} {'rdma/cpu':>9} {'cpu/gpu':>9}")
# Union of every tier's prefix lengths, including the RDMA tier. Previously
# RDMA-only lengths were dropped, and an RDMA-only run printed no rows at all.
allL = sorted({int(k) for d in (miss, gpu, cpu) if d for k in d["by_length"]}
              | {int(k) for k in rdma_p50})
for L in allL:
    m, c, g = _p50(miss, L), _p50(cpu, L), _p50(gpu, L)
    rd = rdma_p50.get(L)
    print(f"{L:>7} {_fmt(m):>9} {_fmt(rd):>9} {_fmt(c):>9} {_fmt(g):>9} "
          f"{_ratio(m, rd):>9} {_ratio(rd, c):>9} {_ratio(c, g):>9}")

if cpu:
    # Per-length fraction of CPU-tier runs verified as real hits (ext_hits>0).
    vf = {k: v.get("verified_frac") for k, v in cpu["by_length"].items()}
    print("\nCPU-tier verified fraction (ext_hits>0):", vf)