Files
agentic-kvc/microbench/connector_tax/cache_sweep/analyze_migration_log.py
Gahow Wang 1262c9c22e Migration transfer-cost study: KV transfer is slow on busy GPUs
MIGRATION_TRANSFER_COST.md: under real load, migration KV transfer runs at
~3 GB/s vs ~10 GB/s idle. Decomposed (instruments + MB6 microbench) into
~55% RDMA-actual (HBM/PCIe contention with running kernels: 7.6->4.0 GB/s)
+ ~45% control-plane GIL starvation during long prefills. Reproduced on a
fresh upstream venv (byte-identical transfer path) -> upstream/hardware
inherent, not our patch. Layerwise is the wrong lever; the tax is structural
on a loaded agentic cluster. Includes mb6_transfer_under_load + run_mb6,
instrument_dst_migration/mooncake, and the dst/transfer decomposition analyzers.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-29 11:53:01 +08:00

134 lines
5.6 KiB
Python

#!/usr/bin/env python3
"""Per-migration log + per-instance summary for a v3 trace replay.
Reads <run_dir>/breakdown.json and <run_dir>/metrics.jsonl and emits:
1. A row per migration showing src→dst, per-side state snapshots, and
the resulting TTFT.
2. Histograms: migrations received per inst, sent per inst, all
(src→dst) pairs.
3. Post-rotation tail: how many turns of migrated sessions ended up on
each inst (downstream impact of rotation).
4. Anti-hotspot signal: recent_mig_received_in_window at decision time.
Run any v3 replay through this to spot pathological clustering of
migrations on the same dst within a short window.
Usage:
python analyze_migration_log.py <RUN_DIR>
where <RUN_DIR> contains breakdown.json + metrics.jsonl (i.e. the proxy's
per-policy output folder, e.g. .../b3_v3_20260527_1344/unified_v3).
"""
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path
# The per-instance sections always list at least inst_0..inst_7; more rows are
# added automatically when the log references a higher instance index.
_MIN_INSTANCE_ROWS = 8
_COL_WIDTH = 13


def _load_metrics(path) -> dict:
    """Index the rows of a metrics.jsonl file by their ``request_id``.

    Each line is parsed once; blank lines (e.g. a trailing newline) are
    skipped instead of raising ``json.JSONDecodeError``.
    """
    by_rid = {}
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            row = json.loads(raw)
            by_rid[row["request_id"]] = row
    return by_rid


def _t_decision(entry) -> float:
    """Decision timestamp of a breakdown entry, 0 when missing or null."""
    return entry.get("t_decision_unix") or 0


def _src_idx(entry):
    """Source instance: ``v3_src_idx``, else ``chosen_idx``, else -1."""
    # Look up the fallback lazily so an entry that has v3_src_idx but no
    # chosen_idx does not raise KeyError.
    if "v3_src_idx" in entry:
        return entry["v3_src_idx"]
    return entry.get("chosen_idx", -1)


def _dst_idx(entry):
    """Destination instance: ``v3_target_idx``, or -1 when not recorded."""
    return entry.get("v3_target_idx", -1)


def _recent_received(entry):
    """recent-migrations-received counter for the target, or None.

    Prefers the dedicated ``v3_target_recent_received`` field and falls back
    to ``v3_target_state["recent_mig_received_in_window"]``.
    """
    value = entry.get("v3_target_recent_received")
    if value is None:
        state = entry.get("v3_target_state") or {}
        value = state.get("recent_mig_received_in_window")
    return value


def _cell(value) -> str:
    """Right-align *value* in a table column; None is shown as '-'."""
    if value is None:
        value = "-"
    elif not isinstance(value, (int, float, str)):
        # Containers and other objects reject the '>' format spec.
        value = str(value)
    return f"{value:>{_COL_WIDTH}}"


def _instance_count(indices) -> int:
    """Number of inst_N rows needed to cover every integer index seen."""
    highest = max((i for i in indices if isinstance(i, int)), default=-1)
    return max(_MIN_INSTANCE_ROWS, highest + 1)


def _print_per_instance(counts, n_inst: int, label: str) -> None:
    """Print one ``inst_N: count label`` line per instance index."""
    for idx in range(n_inst):
        print(f" inst_{idx}: {counts.get(idx, 0)} {label}")
    # Anything keyed outside 0..n_inst-1 (e.g. the -1 "not recorded"
    # sentinel) is reported once so the section totals still add up.
    stray = sum(
        n for idx, n in counts.items()
        if not (isinstance(idx, int) and 0 <= idx < n_inst)
    )
    if stray:
        print(f" (unknown inst): {stray} {label}")


def main(run_dir: Path) -> None:
    """Print the migration report for one replay output directory.

    Reads ``<run_dir>/breakdown.json`` (a list of per-request dicts) and
    ``<run_dir>/metrics.jsonl`` (one JSON object per line, keyed by
    ``request_id``) and writes to stdout:

    * one table row per entry flagged ``v3_migrate``, ordered by
      ``t_decision_unix``;
    * migrations received / sent per instance and per (src, dst) pair;
    * sessions that migrated more than once, with their hop chain;
    * a summary of the recent-received hotspot counter;
    * per instance, how many turns of migrated sessions were dispatched
      after the session's last migration.

    Args:
        run_dir: directory containing both input files.

    Raises:
        OSError: if either input file cannot be opened.
        json.JSONDecodeError: if either file holds malformed JSON.
        KeyError: if a migration entry lacks ``request_id`` or
            ``session_id``, or a metrics row lacks ``request_id``.
    """
    run_dir = Path(run_dir)
    with open(run_dir / "breakdown.json", encoding="utf-8") as fh:
        breakdown = json.load(fh)
    metrics = _load_metrics(run_dir / "metrics.jsonl")

    mig = sorted((e for e in breakdown if e.get("v3_migrate")), key=_t_decision)
    print(f"=== {len(mig)} migrations in {run_dir.name} ===\n")

    cols = (
        "#", "t_rel", "session", "turn",
        "src", "dst", "src_nreq", "src_dec_tok",
        "dst_nreq", "dst_cache", "dst_recent_recv",
        "inlen", "self_ttft_ms",
    )
    print(" " + " ".join(f"{c:>13}" for c in cols))
    print("-" * (15 * len(cols)))

    # Times in the table are relative to the earliest migration decision.
    t0 = _t_decision(mig[0]) if mig else 0
    for row_no, e in enumerate(mig, start=1):
        req = metrics.get(e["request_id"], {})
        src_idx = _src_idx(e)
        dst_idx = _dst_idx(e)
        src_state = e.get("v3_src_state") or {}
        dst_state = e.get("v3_target_state") or {}
        # Fall back to candidate_scores if dedicated v3_*_state fields
        # aren't present.
        cands = {c.get("idx"): c for c in e.get("candidate_scores") or []}
        src_cand = cands.get(src_idx, {})
        dst_cand = cands.get(dst_idx, {})

        src_nreq = src_state.get("num_requests", src_cand.get("num_requests", "-"))
        src_dec_tok = src_state.get(
            "ongoing_decode_tokens", src_cand.get("ongoing_decode_tokens", "-"))
        dst_nreq = dst_state.get("num_requests", dst_cand.get("num_requests", "-"))
        dst_cache = e.get("v3_target_cache_hit", dst_state.get("cache_hit_estimate", 0))
        inlen = e.get("input_length") or req.get("input_length", 0)
        ttft_ms = (req.get("ttft_s") or 0) * 1000
        t_rel = _t_decision(e) - t0

        cells = [
            _cell(row_no),
            f"{t_rel:>13.1f}",
            _cell(e["session_id"]),
            _cell(req.get("turn_id", "?")),
            _cell(src_idx),
            _cell(dst_idx),
            _cell(src_nreq),
            _cell(src_dec_tok),
            _cell(dst_nreq),
            _cell(dst_cache),
            _cell(_recent_received(e)),
            _cell(inlen),
            f"{ttft_ms:>13.0f}",
        ]
        print(" " + " ".join(cells))

    # Aggregate counts
    hops = [(_src_idx(e), _dst_idx(e)) for e in mig]
    n_inst = _instance_count(idx for hop in hops for idx in hop)

    print("\n=== Migrations TO each instance ===")
    _print_per_instance(Counter(d for _, d in hops), n_inst, "migrations received")

    print("\n=== Migrations FROM each instance ===")
    _print_per_instance(Counter(s for s, _ in hops), n_inst, "migrations sent")

    print("\n=== Migration pairs (src→dst, count) ===")
    pair_count = Counter(hops)
    for (s, d), n in sorted(pair_count.items(), key=lambda item: -item[1]):
        # An explicit arrow keeps e.g. 1→12 and 11→2 distinguishable.
        print(f" {s}→{d}: {n}")

    print("\n=== Sessions migrating multiple times ===")
    sess_mig = defaultdict(list)
    for e, (s, d) in zip(mig, hops):
        sess_mig[e["session_id"]].append((_t_decision(e), s, d))
    multi = {sess: ev for sess, ev in sess_mig.items() if len(ev) > 1}
    if not multi:
        print(" (none)")
    for sess, events in sorted(multi.items()):
        ordered = sorted(events, key=lambda ev: ev[0])
        chain = ", ".join(f"{s}->{d}" for _, s, d in ordered)
        print(f" session {sess}: {chain}")

    # Recent-received hotspot signal — non-zero values mean the picker
    # accepted a target that recently got another migration.
    print("\n=== Anti-hotspot signal: dst.recent_mig_received_in_window ===")
    # Same field resolution as the per-row table; a missing value counts as 0.
    rec = [_recent_received(e) or 0 for e in mig]
    if rec:
        nonzero = sum(1 for v in rec if v)
        print(f" total migrations: {len(rec)}, "
              f"with recent_received > 0: {nonzero}, "
              f"max recent_received: {max(rec)}")

    # Post-rotation tail: turns of migrated sessions after their LAST mig
    print("\n=== Post-rotation tail per inst (turns of migrated sessions after last mig) ===")
    # Group dispatch times by session once instead of rescanning every
    # metrics row for each migrated session.
    dispatch_by_session = defaultdict(list)
    for req in metrics.values():
        dispatch_by_session[req.get("session_id")].append(
            req.get("t_dispatch_unix") or 0)
    tail = Counter()
    for sess, events in sess_mig.items():
        last_t, _, final_dst = sorted(events, key=lambda ev: ev[0])[-1]
        tail[final_dst] += sum(
            1 for t in dispatch_by_session.get(sess, ()) if t > last_t)
    _print_per_instance(tail, n_inst, "tail turns")
# Script entry point: the first command-line argument is the run directory
# that main() should analyze.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if not cli_args:
        # No run directory given: show usage on stderr and exit with status 1.
        print("usage: analyze_migration_log.py <run_dir>", file=sys.stderr)
        sys.exit(1)
    main(Path(cli_args[0]))