#!/usr/bin/env python3 """Plot MB2 transfer-time + bandwidth curves.""" from __future__ import annotations import argparse import json from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np def main() -> None: p = argparse.ArgumentParser() p.add_argument("--breakdown", type=Path, required=True, help="JSON from analyze_mb2.py") p.add_argument("--out-time", type=Path, default=Path("figs/mb2_transfer_time.png")) p.add_argument("--out-bw", type=Path, default=Path("figs/mb2_transfer_bw.png")) p.add_argument("--label", default="intra-node (kv_both, dash1 GPU 0+1)") args = p.parse_args() d = json.loads(args.breakdown.read_text()) # `rows` is optional (send-only analyzer skips per-request joining). # Drop the spurious 16-token events from any rows present. if "rows" in d: _ = [r for r in d["rows"] if r["input_tokens_est"] >= 64] summary = [s for s in d["summary"] if s["input_tokens"] >= 64] kv_mib = [s["kv_mib"] for s in summary] p50_ms = [s["pure_transfer_ms_p50"] for s in summary] min_ms = [s["pure_transfer_ms_min"] for s in summary] max_ms = [s["pure_transfer_ms_max"] for s in summary] bw_p50 = [s["throughput_gbps_p50"] for s in summary] bw_max = [s["throughput_gbps_max"] for s in summary] # ---- pure transfer time vs KV size (log-log) ---- fig, ax = plt.subplots(figsize=(8, 5)) ax.errorbar(kv_mib, p50_ms, yerr=[np.array(p50_ms) - np.array(min_ms), np.array(max_ms) - np.array(p50_ms)], fmt="o-", color="#1f77b4", lw=2, markersize=7, capsize=4, label="pure_transfer (batch_transfer_sync_write)") # 9.7 GB/s reference line ref_bw_gbps = 9.7 ref_x = np.array(kv_mib) ref_y_ms = (ref_x * 1024 * 1024) / (ref_bw_gbps * 1e9) * 1000 ax.plot(ref_x, ref_y_ms, "--", color="#888", alpha=0.7, label=f"ideal {ref_bw_gbps:.1f} GB/s reference") # agentic-relevant horizontal markers for name, ms in [("typical chatbot decode (~5 s)", 5000), ("typical agentic decode (~50–200 ms)", 100)]: ax.axhline(ms, color="#c44e52", lw=0.8, ls=":", alpha=0.5) ax.text(kv_mib[-1] * 0.85, ms * 1.15, name, fontsize=8, color="#7a1d1d", ha="right") # p99 agentic KV vertical marker ax.axvline(11500, color="#c44e52", lw=0.8, ls=":", alpha=0.5) ax.text(11500, 0.7, "p99 agentic\nrequest 11.5 GiB", fontsize=8, color="#7a1d1d", ha="center") ax.set_xscale("log") ax.set_yscale("log") ax.set_xlabel("KV transfer size (MiB)") ax.set_ylabel("Pure transfer time (ms, log)") ax.set_title(f"MB2: KV transfer time vs size — {args.label}") ax.grid(True, which="both", alpha=0.3) ax.legend(loc="upper left", fontsize=9) args.out_time.parent.mkdir(parents=True, exist_ok=True) fig.tight_layout() fig.savefig(args.out_time, dpi=150) plt.close(fig) print(f"wrote {args.out_time}") # ---- bandwidth vs KV size ---- fig, ax = plt.subplots(figsize=(8, 5)) ax.plot(kv_mib, bw_p50, "o-", color="#2ca02c", lw=2, markersize=7, label="bandwidth p50") ax.plot(kv_mib, bw_max, "x--", color="#ff7f0e", lw=1.5, markersize=8, label="bandwidth max") ax.axhline(9.7, color="#888", ls="--", alpha=0.6, label="steady-state ≈ 9.7 GB/s") ax.set_xscale("log") ax.set_xlabel("KV transfer size (MiB)") ax.set_ylabel("Effective bandwidth (GB/s)") ax.set_ylim(0, 12) ax.set_title(f"MB2: KV transfer bandwidth vs size — {args.label}") ax.grid(True, which="both", alpha=0.3) ax.legend(loc="lower left", fontsize=9) args.out_bw.parent.mkdir(parents=True, exist_ok=True) fig.tight_layout() fig.savefig(args.out_bw, dpi=150) plt.close(fig) print(f"wrote {args.out_bw}") if __name__ == "__main__": main()