Files
agentic-pd-hybrid/scripts/smoke_snapshot_link_gpu.py
Claude Code Agent 7216507773 feat(snapshot): D→P RDMA Phase 1b — GPU pointer path verified
Confirms snapshot_link works for cuda device pointers, not just host
memory. Sender on cuda:0 pushes to receiver on cuda:1 via RDMA over
mlx5_60. All 5 sizes (16K, 1M, 16M, 64M, 256M) pass SHA verification.

  16 KB     8.3 ms   0.016 Gbps  (cold openSegment)
  1 MB      0.10 ms  87.6 Gbps
  16 MB     0.84 ms  159 Gbps
  64 MB     2.52 ms  213 Gbps
  256 MB    8.54 ms  251 Gbps    (~60% NDR400 line rate)

For Inferact-scale sessions (~50K tokens × ~80 KB layer-per-token =
~4 GB), this projects D→P transfer time at ~130 ms — within the
"reseed-savings" envelope sketched in design doc §3.2.

Files:
  scripts/snapshot_link_receiver_gpu.py
  scripts/smoke_snapshot_link_gpu.py

Next: SGLang scheduler integration for D-side dump + P-side ingest.
2026-05-13 00:59:43 +08:00

237 lines
8.4 KiB
Python

#!/usr/bin/env python3
"""GPU-aware smoke test for snapshot_link RDMA byte transfer.

Sender on cuda:0, receiver subprocess on cuda:1. Tests whether
mooncake's transfer_sync_write can move bytes between two GPUs via
RDMA (which is what the real D→P flow will need for KV bytes).

Usage:
    bash scripts/setup_env.sh && uv run --no-sync python scripts/smoke_snapshot_link_gpu.py

By default the sender uses cuda:0 (--send-gpu) and the receiver
subprocess uses cuda:1 (--recv-gpu).
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
# Directory that contains this script; sibling scripts are located relative
# to it.
_HERE = Path(__file__).resolve().parent

# Prepend "<parent of this directory>/src" to the import search path.
# NOTE(review): presumably this is what lets main() import the in-tree
# ``agentic_pd_hybrid`` package without installing it -- confirm.
sys.path.insert(0, str(_HERE.parent.joinpath("src")))

# Default payload sizes (in bytes) exercised by the smoke test, smallest first.
SIZES_BYTES_DEFAULT = [
    16 * 1024,            # 16 KB
    1024 * 1024,          # 1 MB
    16 * 1024 * 1024,     # 16 MB
    64 * 1024 * 1024,     # 64 MB
    256 * 1024 * 1024,    # 256 MB
]
def main() -> int:
    """Run the GPU-to-GPU snapshot_link smoke test and return an exit code.

    Steps:
      1. Spawn ``snapshot_link_receiver_gpu.py`` as a subprocess and wait
         (up to 90 s) for it to publish a control JSON with ``ready`` set.
      2. Register a CUDA buffer on ``--send-gpu`` with the peer's engine.
      3. For every requested size that fits in ``--max-bytes``: fill the
         buffer with random bytes, push it to the receiver endpoint, write
         the expected SHA-256 to a ``.do<size>`` file and wait (up to 90 s)
         for the matching ``.ack<size>`` file.
      4. Collect the receiver's JSON-lines stdout and count ``verify`` events.

    Returns:
        0 when no ``verify`` event failed and the number of successful
        ``verify`` events equals the number of transfers attempted;
        1 on any failure (including CUDA being unavailable, the receiver
        exiting early, or the endpoint never becoming ready).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--host", default=os.environ.get("SNAPSHOT_LINK_HOST", "127.0.0.1"))
    ap.add_argument("--ib", default=os.environ.get("SNAPSHOT_LINK_IB", "mlx5_60"))
    ap.add_argument("--recv-port", type=int,
                    default=int(os.environ.get("SNAPSHOT_LINK_RECV_PORT", "17787")))
    ap.add_argument("--send-port", type=int,
                    default=int(os.environ.get("SNAPSHOT_LINK_SEND_PORT", "17788")))
    ap.add_argument("--max-bytes", type=int, default=256 * 1024 * 1024)
    ap.add_argument("--sizes", default=",".join(str(s) for s in SIZES_BYTES_DEFAULT))
    ap.add_argument("--send-gpu", type=int, default=0)
    ap.add_argument("--recv-gpu", type=int, default=1)
    args = ap.parse_args()
    sizes = [int(s) for s in args.sizes.split(",")]

    # Scratch directory shared with the receiver: control JSON, per-size
    # signal/ack files and the receiver's stderr log all live here.
    tmpdir = Path(tempfile.mkdtemp(prefix="snapshot_link_gpu_smoke_"))
    control_path = tmpdir / "endpoint.json"
    recv_stderr_log = tmpdir / "recv.stderr.log"

    recv_cmd = [
        sys.executable,
        str(_HERE / "snapshot_link_receiver_gpu.py"),
        "--host", args.host,
        "--port", str(args.recv_port),
        "--ib", args.ib,
        "--max-bytes", str(args.max_bytes),
        "--control-path", str(control_path),
        "--sizes", args.sizes,
        "--gpu-id", str(args.recv_gpu),
    ]
    print(f"[sender] receiver cmd: {' '.join(recv_cmd)}", flush=True)

    # The child inherits its own copy of the log descriptor, so the parent's
    # handle can be closed as soon as Popen returns. Using ``with`` means the
    # handle is released on every exit path, not only on success.
    with open(recv_stderr_log, "w") as recv_stderr:
        recv_proc = subprocess.Popen(
            recv_cmd, stdout=subprocess.PIPE, stderr=recv_stderr, bufsize=1,
            universal_newlines=True,
        )

    try:
        import torch
        if not torch.cuda.is_available():
            print("[sender] FAIL: cuda not available")
            return 1
        torch.cuda.set_device(args.send_gpu)

        # Wait for the receiver to publish a *ready* endpoint. ``meta`` is
        # only assigned once ``ready`` is truthy, so a control file that never
        # reaches that state is treated as a timeout.
        deadline = time.monotonic() + 90.0
        meta = None
        while time.monotonic() < deadline:
            if control_path.exists():
                try:
                    candidate = json.loads(control_path.read_text())
                except (OSError, ValueError):
                    # File may be partially written or briefly unreadable;
                    # retry on the next poll.
                    candidate = None
                if isinstance(candidate, dict) and candidate.get("ready"):
                    meta = candidate
                    break
            if recv_proc.poll() is not None:
                _dump_recv_stderr(recv_stderr_log)
                print(f"[sender] FAIL: receiver exited (rc={recv_proc.returncode})")
                return 1
            time.sleep(0.1)
        if meta is None:
            print("[sender] FAIL: receiver endpoint timeout")
            return 1
        print(f"[sender] receiver endpoint: gpu={meta['gpu_id']}, "
              f"sid={meta['session_id']}, ptr={hex(int(meta['base_ptr']))}, "
              f"cap={meta['capacity_bytes']}", flush=True)

        from agentic_pd_hybrid.snapshot_link import SnapshotPeer, SnapshotEndpoint
        endpoint = SnapshotEndpoint(
            session_id=meta["session_id"],
            base_ptr=int(meta["base_ptr"]),
            capacity_bytes=int(meta["capacity_bytes"]),
        )
        peer = SnapshotPeer(
            host=args.host,
            port=args.send_port,
            ib_device=args.ib,
            receive_capacity_bytes=0,
        )

        # Allocate the sender buffer on the sending GPU and register its
        # device pointer with the peer's engine.
        send_device = f"cuda:{args.send_gpu}"
        send_tensor = torch.zeros(args.max_bytes, dtype=torch.uint8,
                                  device=send_device)
        send_ptr = send_tensor.data_ptr()
        ret = peer.engine.register_memory(send_ptr, args.max_bytes)
        if ret != 0:
            print(f"[sender] FAIL: register_memory ret={ret}")
            return 1
        print(f"[sender] own gpu={args.send_gpu}, sid={peer.session_id}, "
              f"buf @ {hex(send_ptr)} ({args.max_bytes} B)", flush=True)

        transfers = []
        for size in sizes:
            # Sizes that do not fit in the registered buffer are skipped.
            if size > args.max_bytes:
                continue

            # Fill the first ``size`` bytes with seeded random data on the GPU.
            seed = int(time.time() * 1e6) & 0xFFFFFFFF
            gen = torch.Generator(device=send_device)
            gen.manual_seed(seed)
            send_tensor[:size] = torch.randint(0, 256, (size,), dtype=torch.uint8,
                                               device=send_device,
                                               generator=gen)
            torch.cuda.synchronize(args.send_gpu)

            # Expected hash is computed on a host copy of the payload.
            host_view = send_tensor[:size].cpu().numpy().tobytes()
            expected_sha = hashlib.sha256(host_view).hexdigest()

            # Timed push to the receiver endpoint.
            # NOTE(review): ``ret`` is recorded but not part of the verdict;
            # presumably non-zero means failure -- confirm against
            # SnapshotPeer.push before gating on it.
            t0 = time.perf_counter()
            ret = peer.push(endpoint, send_ptr, 0, size, remote_offset=0)
            t1 = time.perf_counter()
            dt_ms = (t1 - t0) * 1000.0
            gbps = (size * 8.0 / 1e9) / max(t1 - t0, 1e-9)
            print(f"[sender] push size={size:>10d} ret={ret} "
                  f"dur={dt_ms:>9.3f} ms thru={gbps:>6.3f} Gbps",
                  flush=True)

            # Tell the receiver which hash to verify, then wait for its ack.
            signal_path = control_path.with_suffix(f".do{size}")
            ack_path = control_path.with_suffix(f".ack{size}")
            signal_path.write_text(json.dumps({"sha": expected_sha}))
            ack_deadline = time.monotonic() + 90.0
            while time.monotonic() < ack_deadline:
                if ack_path.exists():
                    break
                if recv_proc.poll() is not None:
                    print(f"[sender] FAIL: receiver died after size={size}")
                    _dump_recv_stderr(recv_stderr_log)
                    return 1
                time.sleep(0.05)
            transfers.append({
                "size": size, "ret": ret, "dur_ms": round(dt_ms, 3),
                "thru_Gbps": round(gbps, 3), "ack": ack_path.exists(),
            })

        # Drain stdout while waiting for exit: wait() on a child whose stdout
        # is an unread PIPE can deadlock once the pipe buffer fills.
        # communicate() keeps already-read output across a timeout retry.
        try:
            recv_stdout, _ = recv_proc.communicate(timeout=10)
        except subprocess.TimeoutExpired:
            recv_proc.terminate()
            recv_stdout, _ = recv_proc.communicate(timeout=5)

        # The receiver reports one JSON object per stdout line; anything that
        # does not parse is kept as a "non-json" event for display.
        events = []
        for raw in (recv_stdout or "").split("\n"):
            raw = raw.strip()
            if not raw:
                continue
            try:
                events.append(json.loads(raw))
            except json.JSONDecodeError:
                events.append({"event": "non-json", "raw": raw})

        print("=" * 78)
        print("[receiver] events:")
        verify_ok = 0
        verify_fail = 0
        for ev in events:
            print(f" {ev}")
            if ev.get("event") == "verify":
                if ev.get("ok"):
                    verify_ok += 1
                else:
                    verify_fail += 1
        _dump_recv_stderr(recv_stderr_log, header="--- receiver stderr ---")

        # PASS requires one successful verify per attempted transfer.
        overall = "PASS" if verify_fail == 0 and verify_ok == len(transfers) else "FAIL"
        print("=" * 78)
        print(f"OVERALL: {overall} verify_ok={verify_ok} verify_fail={verify_fail} "
              f"transfers={len(transfers)} send_gpu={args.send_gpu} recv_gpu={args.recv_gpu}")
        return 0 if overall == "PASS" else 1
    finally:
        # Best-effort teardown of the receiver; escalate to kill() if a
        # graceful terminate does not finish within 5 s.
        try:
            recv_proc.terminate()
            recv_proc.wait(timeout=5)
        except Exception:
            try:
                recv_proc.kill()
            except Exception:
                pass
def _dump_recv_stderr(path: Path, header: str = "--- receiver stderr (last 60) ---") -> None:
    """Print ``header`` followed by the last 60 lines of the log at ``path``.

    Args:
        path: Log file to read. If it does not exist, nothing is printed.
        header: Line printed before the excerpt.

    The file is decoded as UTF-8 with undecodable bytes replaced instead of
    raising, so that stray binary output in the log cannot abort the
    diagnostic dump (and hide the failure being reported).
    """
    try:
        text = path.read_text(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        return
    print(header, flush=True)
    for line in text.splitlines()[-60:]:
        print(f" {line}", flush=True)
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())