Full sweep result on dash1 GPU 0+1 with vanilla vLLM 0.18.1 +
mooncake-transfer-engine 0.3.11, kv_both connector. Per-stage decomposition
via the instrumentation patch (analyze_mb2.py pairs A's send_blocks with
B's receive_kv enter/finish by time window).
Steady-state (1k..32k tokens, 96 MiB..3 GiB KV):
pure_transfer ≈ size / 9.7 GB/s
rx_overhead ≈ 2–3 ms (ZMQ handshake + P-side setup)
bandwidth ≈ 9.6–10.1 GB/s, very stable
Large-size regime (65k..131k tokens, 6..12 GiB):
p50 bandwidth collapses to 3.4–4.5 GB/s
max bandwidth still hits ~9.7 GB/s (some runs achieve it)
p99 agentic request (11.5 GiB) lands here
Implication for §3.2 PD-disaggregation cost argument:
median agentic decode = 50–200 ms (tool-call JSON output)
median agentic-tail KV transfer (p99 11.5 GiB):
best case (9.7 GB/s) ≈ 1.27 s (11.5 GiB ≈ 12.35 GB)
observed range 1.5 – 10 s
⇒ KV transfer is 8–100× larger than the decode it enables.
This is intra-node — the lower-bound transfer cost. Inter-node RDMA
will be slower; that's MB2 phase 2.
Adds:
- analyze_mb2.py: pair A.send_blocks ↔ B.receive_kv by time window;
per-size aggregation (n, ms_p50, ms_min/max, GB/s_p50/max)
- plot_mb2.py: log-log transfer-time chart + bandwidth-vs-size chart
- analysis/mb2/A_intra_kvboth.jsonl, B_intra_kvboth.jsonl: raw events
(51 + 102 events including the sanity preamble)
- analysis/mb2/intra_kvboth_breakdown.json: paired and aggregated
- figs/mb2_transfer_time_intra.png, figs/mb2_transfer_bw_intra.png
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
179 lines
6.5 KiB
Python
179 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
||
"""Decompose MB2 transfer events into the per-stage breakdown.
|
||
|
||
Inputs:
|
||
--a-log P-side jsonl with `send_blocks` events
|
||
{event=send_blocks, total_bytes, duration_s, t_start_unix, ...}
|
||
--b-log D-side jsonl with `receive_kv_enter` and `receive_kv_finish` events
|
||
{event=receive_kv_*, t_start_unix, duration_s (on finish), req_ids}
|
||
|
||
Pairing: each B receive_kv_enter is followed (in time order) by exactly one
|
||
receive_kv_finish for the same req_ids set. The send_blocks event on A whose
|
||
t_start_unix falls inside the closed (inclusive) window from enter.t_start_unix to
|
||
enter.t_start_unix + finish.duration_s is the pair-matched transfer.
|
||
|
||
Output:
|
||
per-(input_tokens) summary printed to stdout
|
||
--out JSON with full table + per-size aggregates
|
||
|
||
Per-stage breakdown (paper-grade vocabulary):
|
||
|
||
pure_transfer = send_blocks.duration_s
|
||
Network data movement: batch_transfer_sync_write wall-time on P.
|
||
rx_total = receive_kv_finish.duration_s
|
||
Total time on D from receive_kv() entry to receiving FINISH from P.
|
||
Includes ZMQ round-trip + P-side processing + pure_transfer.
|
||
rx_overhead = rx_total − pure_transfer
|
||
ZMQ handshake + P-side scheduling/setup time.
|
||
|
||
We do NOT report queueing or B-side post-transfer decode here — those
|
||
require correlation with client-side t_step2 timestamps. This script
|
||
operates on log files alone.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import statistics
|
||
from pathlib import Path
|
||
|
||
|
||
def load_events(path: Path) -> list[dict]:
    """Load instrumentation events from a JSON-lines log file.

    Every line is parsed on its own. The following lines are skipped rather
    than treated as errors:

    * blank / whitespace-only lines,
    * lines that are not valid JSON,
    * lines holding a JSON value that is not an object (a bare number,
      ``null``, a list, ...). Callers in this file index events with
      ``e["event"]`` / ``e.get(...)``, which only works on dicts.

    Args:
        path: Location of the ``.jsonl`` log.

    Returns:
        The parsed event dicts, in file order.
    """
    rows: list[dict] = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with path.open(encoding="utf-8") as f:
        for raw in f:
            text = raw.strip()
            if not text:
                continue
            try:
                row = json.loads(text)
            except json.JSONDecodeError:
                # Best-effort loader: a corrupt line is dropped, not fatal.
                continue
            if isinstance(row, dict):
                rows.append(row)
    return rows
|
||
|
||
|
||
def pair_b_events(b_events: list[dict]) -> list[dict]:
    """Join each ``receive_kv_enter`` with its ``receive_kv_finish``.

    Events are walked in the order given. An enter event is remembered under
    the sorted tuple of its ``req_ids``; a later finish event with the same
    key closes it and produces one output record. A finish with no pending
    enter is ignored, and a second enter for the same key replaces the first.
    Events of any other type are ignored.

    Args:
        b_events: D-side events, each carrying an ``event`` field.

    Returns:
        One dict per completed pair with the keys ``req_ids``,
        ``rx_t_start_unix`` (from the enter event), ``rx_duration_s`` (from
        the finish event), ``rx_t_end_unix`` (start + duration) and
        ``tp_rank`` (from the finish event, ``None`` when absent).
    """
    pending: dict[tuple, dict] = {}
    result = []
    for event in b_events:
        ids = tuple(sorted(event.get("req_ids", [])))
        kind = event["event"]

        if kind == "receive_kv_enter":
            pending[ids] = event
            continue
        if kind != "receive_kv_finish":
            continue

        opener = pending.pop(ids, None)
        if opener is None:
            # Finish without a recorded enter: nothing to pair it with.
            continue

        started_at = opener["t_start_unix"]
        elapsed = event["duration_s"]
        record = {
            "req_ids": list(ids),
            "rx_t_start_unix": started_at,
            "rx_duration_s": elapsed,
            "rx_t_end_unix": started_at + elapsed,
            "tp_rank": event.get("tp_rank"),
        }
        result.append(record)
    return result
|
||
|
||
|
||
def match_a_to_b(
    a_events: list[dict],
    b_pairs: list[dict],
    bytes_per_token: int = 98304,
) -> list[dict]:
    """Attach to each B receive window the A ``send_blocks`` event inside it.

    A send matches a window when its ``t_start_unix`` lies in the closed
    interval ``[rx_t_start_unix, rx_t_end_unix]`` (both bounds inclusive).
    Each send is used at most once; windows with no matching send are
    dropped from the result.

    The match is a single forward sweep over the sends sorted by start time,
    so the windows are also processed in ascending start order. They are
    sorted here because ``pair_b_events`` emits pairs in *finish* order;
    with overlapping receives that order differs from start order and an
    unsorted sweep would skip sends and lose pairs.

    Args:
        a_events: P-side events; only those with ``event == "send_blocks"``
            are considered.
        b_pairs: Records with ``rx_t_start_unix``, ``rx_t_end_unix``,
            ``rx_duration_s`` and ``req_ids``.
        bytes_per_token: KV bytes per input token, used only to derive
            ``input_tokens_est``. The default 98304 (96 KiB) is the value
            previously hard-coded here; it is presumably specific to the
            model used in the sweep -- confirm before reusing elsewhere.

    Returns:
        Merged rows, ordered by receive-window start time, with the keys
        ``input_tokens_est``, ``total_bytes``, ``pure_transfer_s``,
        ``rx_total_s``, ``rx_overhead_s`` (clamped at 0), ``rx_t_start_unix``,
        ``send_t_start_unix`` and ``req_ids``.

    Raises:
        ValueError: If ``bytes_per_token`` is not positive.
    """
    if bytes_per_token <= 0:
        raise ValueError("bytes_per_token must be positive")

    sends = sorted(
        (e for e in a_events if e.get("event") == "send_blocks"),
        key=lambda e: e["t_start_unix"],
    )
    windows = sorted(b_pairs, key=lambda p: p["rx_t_start_unix"])

    merged = []
    j = 0  # index of the first send not yet consumed or skipped
    for p in windows:
        lo = p["rx_t_start_unix"]
        hi = p["rx_t_end_unix"]

        # Sends that started before this window opened can no longer match
        # this or any later window (windows are in ascending start order).
        while j < len(sends) and sends[j]["t_start_unix"] < lo:
            j += 1
        if j >= len(sends):
            break
        if sends[j]["t_start_unix"] > hi:
            # Next send belongs to a later window; leave it unconsumed.
            continue

        found = sends[j]
        j += 1

        kv_bytes = found["total_bytes"]
        merged.append({
            "input_tokens_est": kv_bytes // bytes_per_token,
            "total_bytes": kv_bytes,
            "pure_transfer_s": found["duration_s"],
            "rx_total_s": p["rx_duration_s"],
            # Clamp: timer jitter can make the P-side duration exceed rx_total.
            "rx_overhead_s": max(0.0, p["rx_duration_s"] - found["duration_s"]),
            "rx_t_start_unix": p["rx_t_start_unix"],
            "send_t_start_unix": found["t_start_unix"],
            "req_ids": p["req_ids"],
        })
    return merged
|
||
|
||
|
||
def aggregate(rows: list[dict]) -> list[dict]:
    """Summarise matched transfers per ``input_tokens_est`` bucket.

    Args:
        rows: Merged rows carrying ``input_tokens_est``, ``total_bytes``,
            ``pure_transfer_s``, ``rx_total_s`` and ``rx_overhead_s``.

    Returns:
        One summary dict per bucket, sorted by ascending token count, with
        the sample count ``n``, mean/p50/min/max of the pure transfer time in
        milliseconds, mean rx_total and rx_overhead in milliseconds, and
        mean/p50/max throughput in GB/s (1 GB = 1e9 bytes). ``kv_mib`` is the
        size of the first row in the bucket.

        Throughput is computed from each row's own ``total_bytes``: buckets
        are keyed by a floor-divided token estimate, so rows in one bucket do
        not necessarily have identical byte counts. Rows whose
        ``pure_transfer_s`` is not positive are left out of the throughput
        statistics (they still count towards ``n`` and the timing columns);
        if no row in a bucket has a positive duration the throughput fields
        are reported as 0.0.
    """
    by_size: dict[int, list[dict]] = {}
    for r in rows:
        by_size.setdefault(r["input_tokens_est"], []).append(r)

    summary = []
    for size in sorted(by_size):
        rs = by_size[size]
        pts = [r["pure_transfer_s"] for r in rs]
        rxs = [r["rx_total_s"] for r in rs]
        ovs = [r["rx_overhead_s"] for r in rs]
        size_mib = rs[0]["total_bytes"] / (1024 * 1024)

        # GB/s per sample; a zero duration would divide by zero, so skip it.
        bw = [
            r["total_bytes"] / r["pure_transfer_s"] / 1e9
            for r in rs
            if r["pure_transfer_s"] > 0
        ]
        if not bw:
            bw = [0.0]

        summary.append({
            "input_tokens": size,
            "kv_mib": round(size_mib, 1),
            "n": len(rs),
            "pure_transfer_ms_mean": round(statistics.mean(pts) * 1000, 2),
            "pure_transfer_ms_p50": round(statistics.median(pts) * 1000, 2),
            "pure_transfer_ms_max": round(max(pts) * 1000, 2),
            "pure_transfer_ms_min": round(min(pts) * 1000, 2),
            "rx_total_ms_mean": round(statistics.mean(rxs) * 1000, 2),
            "rx_overhead_ms_mean": round(statistics.mean(ovs) * 1000, 2),
            "throughput_gbps_mean": round(statistics.mean(bw), 2),
            "throughput_gbps_p50": round(statistics.median(bw), 2),
            "throughput_gbps_max": round(max(bw), 2),
        })
    return summary
|
||
|
||
|
||
def main() -> None:
    """Command-line entry point.

    Reads the P-side (``--a-log``) and D-side (``--b-log``) event logs, pairs
    the D-side enter/finish events, matches them against P-side
    ``send_blocks`` events, and prints a per-size table to stdout. When
    ``--out`` is given, the matched rows and the per-size summary are also
    written there as indented JSON (parent directories are created).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--a-log", type=Path, required=True)
    parser.add_argument("--b-log", type=Path, required=True)
    parser.add_argument("--out", type=Path, default=None)
    opts = parser.parse_args()

    sender_events = load_events(opts.a_log)
    receiver_events = load_events(opts.b_log)
    receiver_pairs = pair_b_events(receiver_events)
    matched_rows = match_a_to_b(sender_events, receiver_pairs)
    per_size = aggregate(matched_rows)

    counts_line = (
        f"loaded {len(sender_events)} A events, {len(receiver_events)} B events; "
        f"paired {len(receiver_pairs)} B; matched {len(matched_rows)} (A∩B)"
    )
    print(counts_line)
    print()

    # Column widths in the header mirror the widths used for the data rows.
    header = (
        f"{'in_tok':>8} {'KV_MiB':>8} {'n':>4} "
        f"{'pure_ms':>10} {'rx_ms':>10} {'overhead_ms':>12} "
        f"{'GB/s_p50':>10} {'GB/s_max':>10}"
    )
    print(header)
    for entry in per_size:
        table_row = (
            f"{entry['input_tokens']:>8} {entry['kv_mib']:>8.1f} {entry['n']:>4} "
            f"{entry['pure_transfer_ms_p50']:>10.1f} "
            f"{entry['rx_total_ms_mean']:>10.1f} "
            f"{entry['rx_overhead_ms_mean']:>12.1f} "
            f"{entry['throughput_gbps_p50']:>10.2f} "
            f"{entry['throughput_gbps_max']:>10.2f}"
        )
        print(table_row)

    out_path = opts.out
    if not out_path:
        return

    out_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "rows": matched_rows,
        "summary": per_size,
    }
    out_path.write_text(json.dumps(payload, indent=2))
    print(f"\nwrote {out_path}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|