#!/usr/bin/env python3
"""Decompose MB2 transfer events into the per-stage breakdown.

Inputs:
  --a-log   P-side jsonl with `send_blocks` events
            {event=send_blocks, total_bytes, duration_s, t_start_unix, ...}
  --b-log   D-side jsonl with `receive_kv_enter` and `receive_kv_finish` events
            {event=receive_kv_*, t_start_unix, duration_s (on finish), req_ids}
  --bytes-per-token
            KV bytes per input token used to estimate input_tokens from
            total_bytes (default: 98304)

Pairing: each B receive_kv_enter is followed (in time order) by exactly
one receive_kv_finish for the same req_ids set. The send_blocks event on A
whose t_start_unix falls between enter.t_start_unix and
enter.t_start_unix + finish.duration_s (both bounds inclusive) is the
pair-matched transfer.

Output:
  per-(input_tokens) summary printed to stdout
  --out     JSON with full table + per-size aggregates

Per-stage breakdown (paper-grade vocabulary):
  pure_transfer = send_blocks.duration_s
      Network data movement: batch_transfer_sync_write wall-time on P.
  rx_total      = receive_kv_finish.duration_s
      Total time on D from receive_kv() entry to receiving FINISH from P.
      Includes ZMQ round-trip + P-side processing + pure_transfer.
  rx_overhead   = rx_total − pure_transfer
      ZMQ handshake + P-side scheduling/setup time.

We do NOT report queueing or B-side post-transfer decode here — those
require correlation with client-side t_step2 timestamps. This script
operates on log files alone.
"""
from __future__ import annotations

import argparse
import json
import statistics
from pathlib import Path

# Divisor used to turn a transfer's total_bytes into an estimated input-token
# count. Kept as the default so existing invocations produce the same buckets.
DEFAULT_BYTES_PER_TOKEN = 98304


def load_events(path: Path) -> list[dict]:
    """Read a JSON-lines log and return its object records in file order.

    Parsing is best-effort: blank lines, lines that are not valid JSON, and
    lines whose JSON value is not an object are skipped silently, because the
    downstream code subscripts every record as a dict.

    Args:
        path: Location of the jsonl file.

    Returns:
        The decoded dict records, in the order they appear in the file.
    """
    rows: list[dict] = []
    with path.open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A bare number/string/list is valid JSON but not an event record.
            if isinstance(row, dict):
                rows.append(row)
    return rows


def pair_b_events(b_events: list[dict]) -> list[dict]:
    """Pair receive_kv_enter with the matching receive_kv_finish (by req_ids).

    Events are scanned in the given order. An enter is remembered under its
    sorted req_ids tuple; the next finish with the same key closes it. A later
    enter with the same key replaces an earlier unclosed one, and a finish
    with no open enter is ignored. Records with any other (or no) ``event``
    value are skipped.

    Args:
        b_events: D-side log records.

    Returns:
        One dict per closed pair, in finish order, with keys ``req_ids``,
        ``rx_t_start_unix``, ``rx_duration_s``, ``rx_t_end_unix`` and
        ``tp_rank`` (taken from the finish record, ``None`` if absent).

    Raises:
        KeyError: If a paired enter lacks ``t_start_unix`` or a paired finish
            lacks ``duration_s``.
    """
    open_by_key: dict[tuple, dict] = {}
    paired: list[dict] = []
    for e in b_events:
        kind = e.get("event")
        if kind not in ("receive_kv_enter", "receive_kv_finish"):
            continue
        key = tuple(sorted(e.get("req_ids") or []))
        if kind == "receive_kv_enter":
            open_by_key[key] = e
            continue
        enter = open_by_key.pop(key, None)
        if enter is None:
            continue
        t_start = enter["t_start_unix"]
        duration = e["duration_s"]
        paired.append({
            "req_ids": list(key),
            "rx_t_start_unix": t_start,
            "rx_duration_s": duration,
            "rx_t_end_unix": t_start + duration,
            "tp_rank": e.get("tp_rank"),
        })
    return paired


def match_a_to_b(
    a_events: list[dict],
    b_pairs: list[dict],
    bytes_per_token: int = DEFAULT_BYTES_PER_TOKEN,
) -> list[dict]:
    """Match each B receive window to one A send_blocks event.

    A send_blocks event matches a window when its ``t_start_unix`` lies in
    ``[rx_t_start_unix, rx_t_end_unix]`` (both bounds inclusive). Each send
    event is used at most once; windows with no send event inside are dropped.

    The scan uses a single forward-only cursor over the time-sorted send
    events, so the windows are sorted by start time first. ``b_pairs`` arrives
    in finish order, and without the sort a window that started early but
    finished late would be visited after the cursor had already passed its
    send event.

    Args:
        a_events: P-side log records; only ``event == "send_blocks"`` is used.
        b_pairs: Output of ``pair_b_events``.
        bytes_per_token: Divisor turning ``total_bytes`` into the estimated
            input-token count used as the aggregation bucket.

    Returns:
        Merged rows ordered by receive start time.

    Raises:
        ValueError: If ``bytes_per_token`` is not positive.
        KeyError: If a send_blocks record lacks ``t_start_unix``,
            ``total_bytes`` or ``duration_s``.
    """
    if bytes_per_token <= 0:
        raise ValueError("bytes_per_token must be positive")

    sends = sorted(
        (e for e in a_events if e.get("event") == "send_blocks"),
        key=lambda e: e["t_start_unix"],
    )
    windows = sorted(b_pairs, key=lambda p: p["rx_t_start_unix"])

    merged: list[dict] = []
    j = 0
    for p in windows:
        lo = p["rx_t_start_unix"]
        hi = p["rx_t_end_unix"]
        # Skip send events that started before this window opened.
        while j < len(sends) and sends[j]["t_start_unix"] < lo:
            j += 1
        if j == len(sends):
            # Windows are start-sorted, so no later window can match either.
            break
        found = sends[j]
        if found["t_start_unix"] > hi:
            # The next send belongs to a later window; leave it unconsumed.
            continue
        j += 1
        kv_bytes = found["total_bytes"]
        merged.append({
            "input_tokens_est": kv_bytes // bytes_per_token,
            "total_bytes": kv_bytes,
            "pure_transfer_s": found["duration_s"],
            "rx_total_s": p["rx_duration_s"],
            # Clamp: a send measured longer than the receive yields 0, not <0.
            "rx_overhead_s": max(0.0, p["rx_duration_s"] - found["duration_s"]),
            "rx_t_start_unix": p["rx_t_start_unix"],
            "send_t_start_unix": found["t_start_unix"],
            "req_ids": p["req_ids"],
        })
    return merged


def aggregate(rows: list[dict]) -> list[dict]:
    """Summarise merged rows per ``input_tokens_est`` bucket.

    Throughput is computed per row from that row's own ``total_bytes`` and
    ``pure_transfer_s``. Rows with a non-positive transfer duration are left
    out of the throughput statistics (they would divide by zero); if a bucket
    has no usable row, its throughput fields are reported as 0.0.

    Args:
        rows: Output of ``match_a_to_b``.

    Returns:
        One summary dict per bucket, ordered by ascending ``input_tokens``.
        Times are in milliseconds rounded to 2 decimals; throughput is in
        GB/s (1e9 bytes per second). ``kv_mib`` is taken from the first row
        of the bucket.
    """
    by_size: dict[int, list[dict]] = {}
    for r in rows:
        by_size.setdefault(r["input_tokens_est"], []).append(r)

    summary: list[dict] = []
    for size in sorted(by_size):
        rs = by_size[size]
        pts = [r["pure_transfer_s"] for r in rs]
        rxs = [r["rx_total_s"] for r in rs]
        ovs = [r["rx_overhead_s"] for r in rs]
        size_mib = rs[0]["total_bytes"] / (1024 * 1024)
        bw = [
            r["total_bytes"] / r["pure_transfer_s"] / 1e9  # GB/s
            for r in rs
            if r["pure_transfer_s"] > 0
        ] or [0.0]
        summary.append({
            "input_tokens": size,
            "kv_mib": round(size_mib, 1),
            "n": len(rs),
            "pure_transfer_ms_mean": round(statistics.mean(pts) * 1000, 2),
            "pure_transfer_ms_p50": round(statistics.median(pts) * 1000, 2),
            "pure_transfer_ms_max": round(max(pts) * 1000, 2),
            "pure_transfer_ms_min": round(min(pts) * 1000, 2),
            "rx_total_ms_mean": round(statistics.mean(rxs) * 1000, 2),
            "rx_overhead_ms_mean": round(statistics.mean(ovs) * 1000, 2),
            "throughput_gbps_mean": round(statistics.mean(bw), 2),
            "throughput_gbps_p50": round(statistics.median(bw), 2),
            "throughput_gbps_max": round(max(bw), 2),
        })
    return summary


def main() -> None:
    """Parse CLI arguments, print the per-size table, optionally write JSON."""
    p = argparse.ArgumentParser()
    p.add_argument("--a-log", type=Path, required=True)
    p.add_argument("--b-log", type=Path, required=True)
    p.add_argument("--out", type=Path, default=None)
    p.add_argument("--bytes-per-token", type=int,
                   default=DEFAULT_BYTES_PER_TOKEN)
    args = p.parse_args()

    a_events = load_events(args.a_log)
    b_events = load_events(args.b_log)
    b_pairs = pair_b_events(b_events)
    merged = match_a_to_b(a_events, b_pairs, args.bytes_per_token)
    summary = aggregate(merged)

    print(f"loaded {len(a_events)} A events, {len(b_events)} B events; "
          f"paired {len(b_pairs)} B; matched {len(merged)} (A∩B)")
    print()
    print(f"{'in_tok':>8} {'KV_MiB':>8} {'n':>4} "
          f"{'pure_ms':>10} {'rx_ms':>10} {'overhead_ms':>12} "
          f"{'GB/s_p50':>10} {'GB/s_max':>10}")
    for s in summary:
        print(f"{s['input_tokens']:>8} {s['kv_mib']:>8.1f} {s['n']:>4} "
              f"{s['pure_transfer_ms_p50']:>10.1f} "
              f"{s['rx_total_ms_mean']:>10.1f} "
              f"{s['rx_overhead_ms_mean']:>12.1f} "
              f"{s['throughput_gbps_p50']:>10.2f} "
              f"{s['throughput_gbps_max']:>10.2f}")

    if args.out:
        args.out.parent.mkdir(parents=True, exist_ok=True)
        args.out.write_text(json.dumps({
            "rows": merged,
            "summary": summary,
        }, indent=2))
        print(f"\nwrote {args.out}")


if __name__ == "__main__":
    main()