Files
agentic-kvc/microbench/connector_tax/cache_sweep/analyze_dst_migration.py
Gahow Wang 1262c9c22e Migration transfer-cost study: KV transfer is slow on busy GPUs
MIGRATION_TRANSFER_COST.md: under real load, migration KV transfer runs at
~3 GB/s vs ~10 GB/s idle. Decomposed (instruments + MB6 microbench) into
~55% RDMA-actual (HBM/PCIe contention with running kernels: 7.6->4.0 GB/s)
+ ~45% control-plane GIL starvation during long prefills. Reproduced on a
fresh upstream venv (byte-identical transfer path) -> upstream/hardware
inherent, not our patch. Layerwise is the wrong lever; the tax is structural
on a loaded agentic cluster. Includes mb6_transfer_under_load + run_mb6,
instrument_dst_migration/mooncake, and the dst/transfer decomposition analyzers.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-29 11:53:01 +08:00

334 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
"""Analyze dst-side migration breakdown for unified_v3 runs.

Joins the proxy ``breakdown.json`` (per-request route + phase timestamps)
with the dst engine per-PID logs written by ``instrument_dst_migration.py``
(``dm_mig_pid<pid>.jsonl``) to attribute each migration's dst-side
wall-clock time to these phases:

    T_relay              proxy decode-sent → dst arrival
    T_admission_pre_kv   dst arrival → status=WAITING_FOR_REMOTE_KVS
                         (waiting in dst's scheduler queue before the KV
                         pull is even initiated)
    T_kv_pull            WAITING_FOR_REMOTE_KVS → finished_recving
                         (the actual RDMA transfer + connector ack)
    T_admission_post_kv  finished_recving → first time in self.running
                         (KV ready, waiting for a batch slot)
    T_first_iter         first scheduled → first generated token
                         (one decode-iter compute + sampler latency)

The proxy-measured T_proxy_total_dst_first_token (proxy decode-sent →
proxy-observed first token) is reported alongside as a cross-check; any
gap between it and the sum of the phases above is uninstrumented time.

Layerwise transfer can at best eliminate T_kv_pull. Everything else is
queueing or compute that layerwise does not touch.

Usage:
    python analyze_dst_migration.py \\
        --proxy-breakdown <RUNDIR>/breakdown.json \\
        --dst-log-dir <DST_LOG_DIR> \\
        [--output <RUNDIR>/dst_migration_breakdown.csv] \\
        [--plot <RUNDIR>/dst_migration_breakdown.png]
"""
from __future__ import annotations
import argparse
import json
import math
import os
import re
import statistics
import sys
from pathlib import Path
def _core_req_id(rid: str) -> str:
"""Normalize a vLLM engine req_id back to the proxy's request_id.
vLLM wraps the proxy id `S:T:U:N` as `cmpl-S:T:U:N-<dp_rank>-<hex>`.
Strip the `cmpl-` prefix and the trailing `-<digits>-<hex>` suffix so
it joins against the proxy `breakdown.json` request_id.
"""
if not rid:
return rid
s = rid
if s.startswith("cmpl-"):
s = s[len("cmpl-"):]
m = re.match(r"^(.*)-\d+-[0-9a-fA-F]+$", s)
if m:
s = m.group(1)
return s
def _pct(vals: list[float], q: float) -> float:
if not vals:
return float("nan")
vs = sorted(vals)
i = max(0, min(len(vs) - 1, int(math.ceil(q * len(vs))) - 1))
return vs[i]
def _summary(name: str, vals: list[float]) -> dict:
if not vals:
return {"name": name, "n": 0}
return {
"name": name,
"n": len(vals),
"mean_s": statistics.mean(vals),
"p50_s": _pct(vals, 0.5),
"p90_s": _pct(vals, 0.9),
"p99_s": _pct(vals, 0.99),
"max_s": max(vals),
"sum_s": sum(vals),
}
def load_dst_log(dst_log_dir: Path) -> dict[str, dict]:
    """Load the per-PID dst migration logs, keyed by normalized request id.

    Reads every ``dm_mig_pid*.jsonl`` file directly under ``dst_log_dir``
    (in sorted name order), one JSON object per line. Each record's
    ``req_id`` is normalized with ``_core_req_id`` so it joins against the
    proxy's ``request_id``; the raw engine id is kept under ``_raw_req_id``.

    Blank lines and records without a ``req_id`` are skipped. Lines that
    are not valid JSON, or are JSON but not an object, are also skipped;
    their count is reported. If the same request appears more than once,
    a record with ``t_first_token_unix`` set replaces one without it;
    otherwise the first record seen is kept.

    Args:
        dst_log_dir: Directory containing the dst engine JSONL logs.

    Returns:
        Mapping of normalized request id -> dst record dict.
    """
    by_req: dict[str, dict] = {}
    found_files = sorted(dst_log_dir.glob("dm_mig_pid*.jsonl"))
    print(f"[analyze] dst log files: {len(found_files)} under {dst_log_dir}")
    n_malformed = 0
    for f in found_files:
        with f.open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue
                try:
                    rec = json.loads(line)
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    # e.g. a truncated last line; skip it rather than abort
                    # the whole analysis, but keep count so it is visible.
                    n_malformed += 1
                    continue
                if not isinstance(rec, dict):
                    # Valid JSON but not a record object; `.get` would crash.
                    n_malformed += 1
                    continue
                rid = rec.get("req_id")
                if not rid:
                    continue
                key = _core_req_id(rid)
                rec["_raw_req_id"] = rid
                # A req should not show up twice; if it does, prefer the
                # record with t_first_token_unix populated.
                prev = by_req.get(key)
                if prev is None or (
                    rec.get("t_first_token_unix")
                    and not prev.get("t_first_token_unix")
                ):
                    by_req[key] = rec
    if n_malformed:
        print(f"[analyze] skipped malformed dst log lines: {n_malformed}")
    print(f"[analyze] unique dst records: {len(by_req)}")
    return by_req
def load_proxy_breakdown(path: Path) -> list[dict]:
    """Load the proxy's ``breakdown.json``, which must be a JSON array.

    Args:
        path: Path to ``breakdown.json``.

    Returns:
        The decoded list of per-request records.

    Raises:
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``) or its top-level value is not a list.
    """
    with path.open(encoding="utf-8") as fh:
        data = json.load(fh)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not isinstance(data, list):
        raise ValueError(f"unexpected breakdown.json shape: {type(data)}")
    return data
def decompose(proxy_recs: list[dict], dst_by_req: dict[str, dict]) -> list[dict]:
    """Build per-migration breakdown rows by joining proxy + dst by req_id.

    Only proxy records with ``route_class == "PD_SEP_V2"`` are treated as
    migrations. Each one is looked up in ``dst_by_req`` by its
    ``request_id``. Migrations with no dst record are counted and skipped.
    A dst record without ``t_first_token_unix`` still yields a row, but
    every phase whose endpoint timestamp is missing is ``None``.

    ``T_relay_s`` subtracts a proxy-side timestamp from a dst-side one. It
    is therefore only meaningful if the two processes' clocks agree.

    Args:
        proxy_recs: Records from the proxy ``breakdown.json``.
        dst_by_req: dst records keyed by normalized request id.

    Returns:
        One dict per joined migration: identifying fields, dst queue state
        at arrival, phase durations in seconds (``T_*_s``; ``None`` when an
        endpoint is missing), and the raw unix timestamps.
    """

    def _diff(a, b):
        # Duration a - b in seconds, or None if either endpoint is missing.
        if a is None or b is None:
            return None
        return float(a) - float(b)

    rows: list[dict] = []
    migrations = [x for x in proxy_recs if x.get("route_class") == "PD_SEP_V2"]
    print(f"[analyze] proxy migrations: {len(migrations)} "
          f"(of {len(proxy_recs)} total requests)")
    miss_in_dst = 0
    missing_phases = 0
    for p in migrations:
        rid = p.get("request_id")
        dst = dst_by_req.get(rid)
        if dst is None:
            miss_in_dst += 1
            continue
        if dst.get("t_first_token_unix") is None:
            # The row is still emitted; phases that end at the missing
            # timestamp come out as None via _diff.
            missing_phases += 1
        arrival = dst.get("arrival_state") or {}
        target_idx = p.get("v3_target_idx")
        if target_idx is None:
            # Explicit None check: `a or b` would discard a valid index 0.
            target_idx = p.get("v3_decode_target_idx")
        t_decode_sent = p.get("t_decode_sent_unix")
        t_first_tok = p.get("t_first_token_unix")
        t_arrival = dst.get("t_arrival_unix")
        t_wait_kvs = dst.get("t_wait_for_kvs_unix")
        t_kv_done = dst.get("t_kv_recv_done_unix")
        t_first_sched = dst.get("t_first_scheduled_unix")
        t_first_tok_dst = dst.get("t_first_token_unix")
        rows.append({
            "request_id": rid,
            "session_id": p.get("session_id"),
            "input_length": p.get("input_length"),
            "v3_new_local": p.get("v3_new_local"),
            "v3_target_idx": target_idx,
            "arrival_n_running": arrival.get("n_running"),
            "arrival_n_waiting": arrival.get("n_waiting"),
            "arrival_pending_prefill_tok": arrival.get("pending_prefill_tok"),
            "arrival_n_waiting_for_kvs": arrival.get("n_waiting_for_kvs"),
            # Phase durations (seconds)
            "T_proxy_total_dst_first_token_s": _diff(t_first_tok, t_decode_sent),
            "T_relay_s": _diff(t_arrival, t_decode_sent),
            "T_admission_pre_kv_s": _diff(t_wait_kvs, t_arrival),
            "T_kv_pull_s": _diff(t_kv_done, t_wait_kvs),
            "T_admission_post_kv_s": _diff(t_first_sched, t_kv_done),
            "T_first_iter_s": _diff(t_first_tok_dst, t_first_sched),
            # Raw timestamps for debugging
            "t_decode_sent_unix": t_decode_sent,
            "t_dst_arrival_unix": t_arrival,
            "t_dst_wait_for_kvs_unix": t_wait_kvs,
            "t_dst_kv_recv_done_unix": t_kv_done,
            "t_dst_first_scheduled_unix": t_first_sched,
            "t_dst_first_token_unix": t_first_tok_dst,
            "t_proxy_first_token_unix": t_first_tok,
        })
    print(f"[analyze] missing in dst log: {miss_in_dst}")
    print(f"[analyze] dst record incomplete (no t_first_token): {missing_phases}")
    return rows
def emit_summary(rows: list[dict]) -> None:
    """Print the per-phase statistics table and the aggregate attribution.

    The table has one line per phase: n / mean / p50 / p90 / p99 / max /
    sum, computed over the rows where that phase is not ``None``. The
    aggregate section then prints:

    * the per-phase sums over the five component phases and the proxy
      total for comparison;
    * each component's share of the decomposed total;
    * the split between ``T_kv_pull_s`` (the only phase layerwise transfer
      could shrink) and everything else.

    If ``rows`` is empty, prints a single notice and returns.
    """
    if not rows:
        print("[analyze] no rows — nothing to summarize.")
        return

    total_key = "T_proxy_total_dst_first_token_s"
    components = (
        "T_relay_s",
        "T_admission_pre_kv_s",
        "T_kv_pull_s",
        "T_admission_post_kv_s",
        "T_first_iter_s",
    )

    def present(key):
        # Values of `key` across rows, skipping rows where it is missing/None.
        return [row[key] for row in rows if row.get(key) is not None]

    rule = "=" * 88
    print()
    print(rule)
    print(f"Migration dst-side phase breakdown (n_migrations={len(rows)})")
    print(rule)
    print(f"{'phase':<36} {'n':>4} {'mean(s)':>9} {'p50':>8} {'p90':>8} "
          f"{'p99':>8} {'max':>8} {'sum(s)':>9}")
    print("-" * 88)
    for key in (total_key, *components):
        observed = present(key)
        if observed:
            st = _summary(key, observed)
            print(f"{key:<36} {st['n']:>4} {st['mean_s']:>9.3f} {st['p50_s']:>8.3f} "
                  f"{st['p90_s']:>8.3f} {st['p99_s']:>8.3f} {st['max_s']:>8.3f} "
                  f"{st['sum_s']:>9.2f}")
        else:
            print(f"{key:<36} {'n/a':>4}")

    print()
    print("Aggregate attribution (sum across all migrations):")
    sums = {key: sum(present(key)) for key in components}
    total = sum(sums.values())
    total_proxy = sum(present(total_key))
    print(f" decomposed sum : {total:>8.2f} s")
    print(f" proxy total sum : {total_proxy:>8.2f} s "
          f"(should be ~equal; gap = uninstrumented)")
    if total > 0:
        for key, secs in sums.items():
            print(f" {key:<28} {secs:>8.2f} s ({secs/total*100:5.1f} %)")

    def share(secs):
        # Percentage of the decomposed total; 0 when there is no total.
        return secs / total * 100 if total else 0

    # Headline: layerwise can only shave the KV-pull phase.
    addressable = sums.get("T_kv_pull_s", 0.0)
    residual = sum(secs for key, secs in sums.items() if key != "T_kv_pull_s")
    print()
    print("Layerwise-addressable vs queue-residual:")
    print(f" T_kv_pull_s (addressable by layerwise) : {addressable:>8.2f} s "
          f"({share(addressable):5.1f} %)")
    print(f" everything else (queue/admission/iter) : {residual:>8.2f} s "
          f"({share(residual):5.1f} %)")
def write_csv(rows: list[dict], path: Path) -> None:
    """Write ``rows`` to ``path`` as CSV with a header row.

    The header is the union of all rows' keys in first-seen order.
    ``csv.DictWriter`` raises on any key missing from ``fieldnames``, so a
    header taken from the first row alone would break on heterogeneous
    rows. Missing cells are written empty. An empty ``rows`` produces an
    empty file.

    Args:
        rows: Records to write, one CSV line each.
        path: Destination file (overwritten).
    """
    import csv
    if not rows:
        path.write_text("")
        print(f"[analyze] wrote CSV: {path} (n=0)")
        return
    # dict.fromkeys keeps insertion order, giving a stable ordered union.
    fields = list(dict.fromkeys(k for row in rows for k in row))
    with path.open("w", newline="", encoding="utf-8") as fh:
        w = csv.DictWriter(fh, fieldnames=fields)
        w.writeheader()
        w.writerows(rows)
    print(f"[analyze] wrote CSV: {path} (n={len(rows)})")
def maybe_plot(rows: list[dict], out_path: Path) -> None:
    """Save a stacked-bar chart of each migration's dst phases, if possible.

    Best effort: if importing or configuring matplotlib raises, prints a
    notice and returns. Nothing is drawn for an empty ``rows``. Bars are
    ordered by ``T_proxy_total_dst_first_token_s`` ascending, and any
    missing (``None``) value is drawn as 0.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # non-interactive backend; we only write a file
        import matplotlib.pyplot as plt
    except Exception as e:
        print(f"[analyze] matplotlib unavailable ({e}); skipping plot.")
        return
    if not rows:
        return

    ordered = sorted(
        rows,
        key=lambda row: row.get("T_proxy_total_dst_first_token_s") or 0.0,
    )
    positions = list(range(len(ordered)))
    # (row key, legend label, bar color), stacked bottom to top.
    layers = (
        ("T_relay_s", "HTTP relay", "#cccccc"),
        ("T_admission_pre_kv_s", "admission pre-KV", "#f4a261"),
        ("T_kv_pull_s", "KV pull (layerwise-addressable)", "#e76f51"),
        ("T_admission_post_kv_s", "admission post-KV", "#2a9d8f"),
        ("T_first_iter_s", "first decode iter", "#264653"),
    )

    fig, ax = plt.subplots(figsize=(11, 5))
    base = [0.0] * len(ordered)
    for key, label, color in layers:
        heights = [(row.get(key) or 0.0) for row in ordered]
        ax.bar(positions, heights, bottom=base, color=color, label=label, width=0.85)
        base = [b + h for b, h in zip(base, heights)]

    ax.set_xticks(positions)
    ax.set_xticklabels([str(pos + 1) for pos in positions], rotation=0, fontsize=8)
    ax.set_xlabel("Migrated request (sorted by total dst wait, ascending)")
    ax.set_ylabel("Time (s)")
    ax.set_title("Per-migration dst-side phase breakdown (v3 unified_v3 run)")
    ax.legend(loc="upper left", fontsize=9)
    ax.grid(axis="y", linestyle=":", alpha=0.5)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f"[analyze] wrote plot: {out_path}")
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point.

    Parses arguments, loads the proxy breakdown and dst logs, prints the
    phase summary, and writes the per-migration CSV and the stacked-bar
    PNG (the plot is skipped if matplotlib is unavailable). Output paths
    default to files next to the proxy ``breakdown.json``.

    Args:
        argv: Arguments to parse. ``None`` means ``sys.argv[1:]`` (the
            argparse default), so existing ``main()`` callers are unaffected.

    Raises:
        SystemExit: on argument errors or when an input path is missing.
    """
    p = argparse.ArgumentParser(
        # Show the module docstring (phase definitions + usage) in --help.
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--proxy-breakdown", type=Path, required=True)
    p.add_argument("--dst-log-dir", type=Path, required=True)
    p.add_argument("--output", type=Path, default=None,
                   help="CSV path (default: <run>/dst_migration_breakdown.csv)")
    p.add_argument("--plot", type=Path, default=None,
                   help="PNG path (default: <run>/dst_migration_breakdown.png)")
    args = p.parse_args(argv)
    if not args.proxy_breakdown.is_file():
        sys.exit(f"missing proxy breakdown: {args.proxy_breakdown}")
    if not args.dst_log_dir.is_dir():
        sys.exit(f"missing dst log dir: {args.dst_log_dir}")
    run_dir = args.proxy_breakdown.parent
    out_csv = args.output or (run_dir / "dst_migration_breakdown.csv")
    out_png = args.plot or (run_dir / "dst_migration_breakdown.png")
    proxy_recs = load_proxy_breakdown(args.proxy_breakdown)
    dst_by_req = load_dst_log(args.dst_log_dir)
    rows = decompose(proxy_recs, dst_by_req)
    emit_summary(rows)
    write_csv(rows, out_csv)
    maybe_plot(rows, out_png)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()