agentic-kvc/microbench/fresh_setup/analyze_mb2.py
Gahow Wang de164e5a64 MB2: pure KV-transfer cost on dash1 intra-node — Mooncake ~9.7 GB/s steady
Full sweep result on dash1 GPU 0+1 with vanilla vLLM 0.18.1 +
mooncake-transfer-engine 0.3.11, kv_both connector. Per-stage decomposition
via the instrumentation patch (analyze_mb2.py pairs A's send_blocks with
B's receive_kv enter/finish by time window).

Steady-state (1k..32k tokens, 96 MiB..3 GiB KV):
   pure_transfer ≈ size / 9.7 GB/s
   rx_overhead   ≈ 2–3 ms (ZMQ handshake + P-side setup)
   bandwidth     ≈ 9.6–10.1 GB/s, very stable
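The steady-state numbers amount to a linear cost model: a fixed rx_overhead plus size over bandwidth. Here is a minimal sketch. It assumes 96 KiB of KV per token, the 98304-byte factor analyze_mb2.py divides by to estimate input tokens. It also assumes decimal GB/s and a fixed 2.5 ms overhead, taken as the midpoint of the 2–3 ms range. `predict_rx_ms` is a hypothetical name and is not part of the repo.

```python
KV_BYTES_PER_TOKEN = 98304      # 96 KiB per token, the factor used in analyze_mb2.py
STEADY_BW_BYTES_S = 9.7e9       # steady-state bandwidth, decimal GB/s
RX_OVERHEAD_S = 2.5e-3          # assumed midpoint of the 2-3 ms rx_overhead


def predict_rx_ms(tokens: int) -> tuple[float, float]:
    """Return (pure_transfer_ms, rx_total_ms) under the linear steady-state model."""
    size_bytes = tokens * KV_BYTES_PER_TOKEN
    pure_s = size_bytes / STEADY_BW_BYTES_S
    return pure_s * 1e3, (pure_s + RX_OVERHEAD_S) * 1e3


for tokens in (1024, 4096, 16384, 32768):
    pure_ms, rx_ms = predict_rx_ms(tokens)
    mib = tokens * KV_BYTES_PER_TOKEN / 2**20
    print(f"{tokens:>6} tok  {mib:>6.0f} MiB  pure {pure_ms:7.1f} ms  rx {rx_ms:7.1f} ms")
```

Under these assumptions, pure_transfer is about 10.4 ms at 1k tokens (96 MiB) and about 332 ms at 32k tokens (3 GiB). At the small end, the fixed overhead is therefore a noticeable fraction of rx_total.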

Large-size regime (65k..131k tokens, 6..12 GiB):
   p50 bandwidth collapses to 3.4–4.5 GB/s
   max bandwidth still reaches ~9.7 GB/s in some runs
   the p99 agentic request (11.5 GiB) falls in this regime

Implication for §3.2 PD-disaggregation cost argument:
   median agentic decode = 50–200 ms (tool-call JSON output)
   KV transfer for a p99 agentic request (11.5 GiB ≈ 12.35 GB):
     best case (9.7 GB/s)  ≈ 1.27 s
     observed range         1.5 – 10 s
   ⇒ KV transfer is ≈ 7.5–200× the decode it enables (observed range vs. 50–200 ms decode).
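The transfer-versus-decode comparison is unit arithmetic, and the main pitfall is mixing binary GiB with decimal GB/s. analyze_mb2.py computes bandwidth as `size_bytes / p / 1e9`. Here is a minimal sketch that uses the 50–200 ms decode range and the 1.5–10 s observed range from the lines above. `ratio_range` is a hypothetical helper.

```python
GIB = 2**30

p99_kv_bytes = 11.5 * GIB          # p99 agentic request, binary GiB
best_bw = 9.7e9                    # bytes/s (decimal GB/s, as in analyze_mb2.py)
decode_s = (0.050, 0.200)          # median agentic decode range
observed_transfer_s = (1.5, 10.0)  # observed transfer range at this size


def ratio_range(transfer_s: tuple[float, float],
                decode: tuple[float, float]) -> tuple[float, float]:
    """Smallest and largest transfer/decode ratio over the two ranges."""
    lo = transfer_s[0] / decode[1]   # fastest transfer vs slowest decode
    hi = transfer_s[1] / decode[0]   # slowest transfer vs fastest decode
    return lo, hi


best_case_s = p99_kv_bytes / best_bw
print(f"best-case transfer: {best_case_s:.2f} s")
print("best case vs decode: %.1fx .. %.1fx" % ratio_range((best_case_s, best_case_s), decode_s))
print("observed vs decode:  %.1fx .. %.1fx" % ratio_range(observed_transfer_s, decode_s))
```

With 11.5 GiB taken as binary, the best case comes to about 1.27 s. Reading the size as decimal GB instead would give about 1.19 s.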

This is intra-node, i.e. a lower bound on transfer cost. Inter-node RDMA
is expected to be slower; that is MB2 phase 2.

Adds:
- analyze_mb2.py: pair A.send_blocks ↔ B.receive_kv by time window;
  per-size aggregation (n, ms_p50, ms_min/max, GB/s_p50/max)
- plot_mb2.py: log-log transfer-time chart + bandwidth-vs-size chart
- analysis/mb2/A_intra_kvboth.jsonl, B_intra_kvboth.jsonl: raw events
  (51 + 102 events including the sanity preamble)
- analysis/mb2/intra_kvboth_breakdown.json: paired and aggregated
- figs/mb2_transfer_time_intra.png, figs/mb2_transfer_bw_intra.png
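The two regimes described above can be recovered from the `--out` JSON. Here is a minimal sketch. It assumes the `summary` layout that analyze_mb2.py writes, with `input_tokens` and `throughput_gbps_p50` per size. It takes 65536 tokens (the "65k" in the sweep description) as the regime boundary. `regime_bandwidth` and `load_summary` are hypothetical names.

```python
import json
from pathlib import Path

LARGE_REGIME_MIN_TOKENS = 65536  # assumed boundary: the "65k" start of the large regime


def regime_bandwidth(summary: list[dict],
                     boundary: int = LARGE_REGIME_MIN_TOKENS) -> dict[str, tuple[float, float]]:
    """Return (min, max) of per-size p50 GB/s below and at/above the boundary."""
    out: dict[str, tuple[float, float]] = {}
    for name, keep in (("steady", lambda t: t < boundary),
                       ("large", lambda t: t >= boundary)):
        vals = [s["throughput_gbps_p50"] for s in summary if keep(s["input_tokens"])]
        if vals:  # a regime with no sizes in the sweep is simply omitted
            out[name] = (min(vals), max(vals))
    return out


def load_summary(path: Path) -> list[dict]:
    """Read the per-size summary from the JSON written by analyze_mb2.py --out."""
    return json.loads(path.read_text())["summary"]
```

For example, `regime_bandwidth(load_summary(Path("analysis/mb2/intra_kvboth_breakdown.json")))` would give the p50 bandwidth spread for each regime.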

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-27 19:04:03 +08:00

179 lines
6.5 KiB
Python

#!/usr/bin/env python3
"""Decompose MB2 transfer events into the per-stage breakdown.

Inputs:
  --a-log   P-side jsonl with `send_blocks` events
              {event=send_blocks, total_bytes, duration_s, t_start_unix, ...}
  --b-log   D-side jsonl with `receive_kv_enter` and `receive_kv_finish` events
              {event=receive_kv_*, t_start_unix, duration_s (on finish), req_ids}

Pairing: each B receive_kv_enter is followed (in time order) by exactly one
receive_kv_finish for the same req_ids set. The earliest not-yet-matched
send_blocks event on A whose t_start_unix falls within
[enter.t_start_unix, enter.t_start_unix + finish.duration_s] (both ends
inclusive) is the pair-matched transfer.

Output:
  per-(input_tokens) summary printed to stdout
  --out     JSON with full table + per-size aggregates

Per-stage breakdown (paper-grade vocabulary):
  pure_transfer = send_blocks.duration_s
      Network data movement: batch_transfer_sync_write wall-time on P.
  rx_total      = receive_kv_finish.duration_s
      Total time on D from receive_kv() entry to receiving FINISH from P.
      Includes ZMQ round-trip + P-side processing + pure_transfer.
  rx_overhead   = max(0, rx_total - pure_transfer)
      ZMQ handshake + P-side scheduling/setup time.

We do NOT report queueing or B-side post-transfer decode here; those
require correlation with client-side t_step2 timestamps. This script
operates on log files alone.
"""

from __future__ import annotations

import argparse
import json
import statistics
from pathlib import Path


def load_events(path: Path) -> list[dict]:
    """Read one JSON object per line; lines that fail to parse are skipped."""
    rows = []
    with path.open() as f:
        for line in f:
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError:
                # Blank or partially written lines are not fatal.
                continue
    return rows


def pair_b_events(b_events: list[dict]) -> list[dict]:
    """Pair receive_kv_enter with the matching receive_kv_finish (by req_ids).

    Events are scanned in log order; an open enter is closed by the next
    finish carrying the same (sorted) req_ids.
    """
    open_by_key: dict[tuple, dict] = {}
    paired = []
    for e in b_events:
        key = tuple(sorted(e.get("req_ids", [])))
        event = e.get("event")
        if event == "receive_kv_enter":
            open_by_key[key] = e
        elif event == "receive_kv_finish":
            enter = open_by_key.pop(key, None)
            if enter is None:
                # Finish without a recorded enter (e.g. truncated log): skip.
                continue
            paired.append({
                "req_ids": list(key),
                "rx_t_start_unix": enter["t_start_unix"],
                "rx_duration_s": e["duration_s"],
                # finish.duration_s is measured from receive_kv() entry.
                "rx_t_end_unix": enter["t_start_unix"] + e["duration_s"],
                "tp_rank": e.get("tp_rank"),
            })
    return paired
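

# KV-cache bytes per token for the benchmarked model/config (96 KiB), so
# 1024 tokens -> 96 MiB. Only used to label rows with an estimated length.
KV_BYTES_PER_TOKEN = 98304


def match_a_to_b(a_events: list[dict], b_pairs: list[dict]) -> list[dict]:
    """For each B pair, take the earliest unmatched A send_blocks event whose
    t_start_unix lies within [rx_t_start, rx_t_end] (inclusive). Returns
    merged rows; B pairs with no A event in their window are dropped."""
    a_by_t = sorted(
        (e for e in a_events if e.get("event") == "send_blocks"),
        key=lambda e: e["t_start_unix"],
    )
    # The two-pointer scan below needs B windows in start-time order; pairs
    # come out of pair_b_events in finish order, which can differ.
    b_sorted = sorted(b_pairs, key=lambda p: p["rx_t_start_unix"])
    merged = []
    j = 0
    for p in b_sorted:
        lo = p["rx_t_start_unix"]
        hi = p["rx_t_end_unix"]
        found = None
        # Skip A events that started before this window opened.
        while j < len(a_by_t) and a_by_t[j]["t_start_unix"] < lo:
            j += 1
        if j < len(a_by_t):
            a = a_by_t[j]
            if a["t_start_unix"] <= hi:
                found = a
                j += 1  # each A event matches at most one B pair
        if found is None:
            continue
        kv_bytes = found["total_bytes"]
        merged.append({
            "input_tokens_est": kv_bytes // KV_BYTES_PER_TOKEN,
            "total_bytes": kv_bytes,
            "pure_transfer_s": found["duration_s"],
            "rx_total_s": p["rx_duration_s"],
            # Clamped at 0 so a pure_transfer slightly longer than rx_total
            # does not produce a negative overhead.
            "rx_overhead_s": max(0.0, p["rx_duration_s"] - found["duration_s"]),
            "rx_t_start_unix": p["rx_t_start_unix"],
            "send_t_start_unix": found["t_start_unix"],
            "req_ids": p["req_ids"],
        })
    return merged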


def aggregate(rows: list[dict]) -> list[dict]:
    """Group merged rows by estimated input length and summarize each group."""
    by_size: dict[int, list[dict]] = {}
    for r in rows:
        by_size.setdefault(r["input_tokens_est"], []).append(r)
    summary = []
    for size in sorted(by_size):
        rs = by_size[size]
        pts = [r["pure_transfer_s"] for r in rs]
        rxs = [r["rx_total_s"] for r in rs]
        ovs = [r["rx_overhead_s"] for r in rs]
        # All rows in a bucket are assumed to carry the same KV size.
        size_bytes = rs[0]["total_bytes"]
        size_mib = size_bytes / (1024 * 1024)  # binary MiB
        # Decimal GB/s (1e9 bytes/s); the "gbps" keys below hold GB/s, not Gbit/s.
        bw = [size_bytes / p / 1e9 for p in pts]
        summary.append({
            "input_tokens": size,
            "kv_mib": round(size_mib, 1),
            "n": len(rs),
            "pure_transfer_ms_mean": round(statistics.mean(pts) * 1000, 2),
            "pure_transfer_ms_p50": round(statistics.median(pts) * 1000, 2),
            "pure_transfer_ms_max": round(max(pts) * 1000, 2),
            "pure_transfer_ms_min": round(min(pts) * 1000, 2),
            "rx_total_ms_mean": round(statistics.mean(rxs) * 1000, 2),
            "rx_overhead_ms_mean": round(statistics.mean(ovs) * 1000, 2),
            "throughput_gbps_mean": round(statistics.mean(bw), 2),
            "throughput_gbps_p50": round(statistics.median(bw), 2),
            "throughput_gbps_max": round(max(bw), 2),
        })
    return summary


def main() -> None:
    p = argparse.ArgumentParser(
        description="Decompose MB2 transfer events into the per-stage breakdown.")
    p.add_argument("--a-log", type=Path, required=True,
                   help="P-side jsonl with send_blocks events")
    p.add_argument("--b-log", type=Path, required=True,
                   help="D-side jsonl with receive_kv_enter/receive_kv_finish events")
    p.add_argument("--out", type=Path, default=None,
                   help="optional JSON output: merged rows + per-size summary")
    args = p.parse_args()

    a_events = load_events(args.a_log)
    b_events = load_events(args.b_log)
    b_pairs = pair_b_events(b_events)
    merged = match_a_to_b(a_events, b_pairs)
    summary = aggregate(merged)

    print(f"loaded {len(a_events)} A events, {len(b_events)} B events; "
          f"paired {len(b_pairs)} B; matched {len(merged)} (A∩B)")
    print()
    # pure_transfer is reported as p50; rx_total and rx_overhead as means.
    print(f"{'in_tok':>8} {'KV_MiB':>8} {'n':>4} "
          f"{'pure_p50_ms':>12} {'rx_mean_ms':>12} {'ovh_mean_ms':>12} "
          f"{'GB/s_p50':>10} {'GB/s_max':>10}")
    for s in summary:
        print(f"{s['input_tokens']:>8} {s['kv_mib']:>8.1f} {s['n']:>4} "
              f"{s['pure_transfer_ms_p50']:>12.1f} "
              f"{s['rx_total_ms_mean']:>12.1f} "
              f"{s['rx_overhead_ms_mean']:>12.1f} "
              f"{s['throughput_gbps_p50']:>10.2f} "
              f"{s['throughput_gbps_max']:>10.2f}")

    if args.out:
        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text(json.dumps({
            "rows": merged,
            "summary": summary,
        }, indent=2))
        print(f"\nwrote {args.out}")


if __name__ == "__main__":
    main()
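To check the pairing logic without a GPU run, you can generate synthetic logs in the schema given in the module docstring. Here is a minimal sketch. `write_synthetic_logs`, the `A_synth.jsonl`/`B_synth.jsonl` names, the fixed 1 ms send delay and the reuse of the enter timestamp on finish events are all hypothetical choices. They mirror only the fields the script reads.

```python
import json
import tempfile
from pathlib import Path

KV_BYTES_PER_TOKEN = 98304  # per-token factor the script divides total_bytes by


def write_synthetic_logs(out_dir: Path, token_sizes: list[int],
                         bw_bytes_s: float = 9.7e9,
                         overhead_s: float = 2.5e-3) -> tuple[Path, Path]:
    """Write hypothetical A (P-side) and B (D-side) jsonl logs, one transfer per size.

    Per transfer: D enters receive_kv at t0, P starts send_blocks 1 ms later,
    and D's finish duration_s is pure transfer + overhead_s, so with the
    defaults the send start lies inside the window [t0, t0 + duration_s].
    """
    a_path = out_dir / "A_synth.jsonl"
    b_path = out_dir / "B_synth.jsonl"
    t0 = 1_700_000_000.0  # arbitrary unix start time
    with a_path.open("w") as fa, b_path.open("w") as fb:
        for i, tokens in enumerate(token_sizes):
            nbytes = tokens * KV_BYTES_PER_TOKEN
            pure_s = nbytes / bw_bytes_s
            req_ids = [f"req-{i}"]
            fb.write(json.dumps({"event": "receive_kv_enter", "t_start_unix": t0,
                                 "req_ids": req_ids, "tp_rank": 0}) + "\n")
            fa.write(json.dumps({"event": "send_blocks", "total_bytes": nbytes,
                                 "duration_s": pure_s,
                                 "t_start_unix": t0 + 1e-3}) + "\n")
            fb.write(json.dumps({"event": "receive_kv_finish", "t_start_unix": t0,
                                 "duration_s": pure_s + overhead_s,
                                 "req_ids": req_ids, "tp_rank": 0}) + "\n")
            t0 += 1.0  # 1 s apart so receive windows never overlap
    return a_path, b_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        a, b = write_synthetic_logs(Path(d), [1024, 4096, 32768])
        print(a.read_text(), b.read_text(), sep="\n")
```

On logs like these, the script should report roughly 2.5 ms overhead and 9.7 GB/s at every size. That makes a quick sanity check of the enter/finish pairing and the time-window match.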