Confirms snapshot_link works for CUDA device pointers, not just host memory. The sender on cuda:0 pushes to the receiver on cuda:1 via RDMA over mlx5_60. All 5 sizes (16 KB, 1 MB, 16 MB, 64 MB, 256 MB) pass SHA verification.

Results (size: duration, throughput):
- 16 KB: 8.3 ms, 0.016 Gbps (cold openSegment)
- 1 MB: 0.10 ms, 87.6 Gbps
- 16 MB: 0.84 ms, 159 Gbps
- 64 MB: 2.52 ms, 213 Gbps
- 256 MB: 8.54 ms, 251 Gbps (~60% of NDR400 line rate)

For Inferact-scale sessions (~50K tokens × ~80 KB layer-per-token = ~4 GB), this projects a D→P transfer time of ~130 ms, which is within the "reseed-savings" envelope sketched in design doc §3.2.

Files:
- scripts/snapshot_link_receiver_gpu.py
- scripts/smoke_snapshot_link_gpu.py

Next: SGLang scheduler integration for D-side dump + P-side ingest.
237 lines
8.4 KiB
Python
237 lines
8.4 KiB
Python
#!/usr/bin/env python3
|
|
"""GPU-aware smoke test for snapshot_link RDMA byte transfer.
|
|
|
|
Sender on cuda:0, receiver subprocess on cuda:1. Tests whether
|
|
mooncake's transfer_sync_write can move bytes between two GPUs via
|
|
RDMA (which is what the real D→P flow will need for KV bytes).
|
|
|
|
Usage:
|
|
bash scripts/setup_env.sh && uv run --no-sync python scripts/smoke_snapshot_link_gpu.py
|
|
|
|
The sender uses cuda:0 (--send-gpu); the receiver subprocess uses
|
|
cuda:1 (--recv-gpu) by default.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Directory that contains this script; the receiver script is resolved
# relative to it.
_HERE = Path(__file__).resolve().parent

# Put the sibling "src" directory (one level above this script's folder) at
# the front of the import path so packages located there can be imported.
sys.path.insert(0, str(_HERE.parent.joinpath("src")))
|
|
|
|
|
|
# Default payload sizes, in bytes, pushed by the smoke test (ascending order).
# They are also used to build the default value of the --sizes option.
SIZES_BYTES_DEFAULT = [
    16 * 1024,         # 16 KB
    1024 ** 2,         # 1 MB
    16 * 1024 ** 2,    # 16 MB
    64 * 1024 ** 2,    # 64 MB
    256 * 1024 ** 2,   # 256 MB
]
|
|
|
|
|
|
def _parse_receiver_events(stdout_text):
    """Turn the receiver's captured stdout into a list of event dicts.

    Every non-blank line is parsed as JSON. A line that is not valid JSON, or
    that decodes to something other than a JSON object, is preserved as
    ``{"event": "non-json", "raw": <line>}`` so that no output is dropped and
    callers can always use ``dict.get`` on each event.
    """
    events = []
    for raw in stdout_text.splitlines():
        raw = raw.strip()
        if not raw:
            continue
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            events.append(parsed)
        else:
            events.append({"event": "non-json", "raw": raw})
    return events


def main():
    """Run the GPU snapshot_link smoke test and return a process exit code.

    Steps:
      1. Spawn ``snapshot_link_receiver_gpu.py`` as a subprocess (stdout is
         captured, stderr goes to a log file in a fresh temp directory).
      2. Poll the receiver's control file until it reports ``ready``.
      3. For every requested size that fits in ``--max-bytes``: fill the
         sender's CUDA buffer with random bytes, hash them on the host, push
         them to the receiver endpoint, then drop a ``.do<size>`` file holding
         the expected SHA-256 and wait for the matching ``.ack<size>`` file.
      4. Collect the receiver's JSON events and count ``verify`` results.

    Returns:
        0 when at least one transfer ran and every transfer was verified OK
        by the receiver; 1 on any failure (CUDA unavailable, receiver exit or
        timeout, memory registration failure, verification mismatch, or no
        transfer performed at all).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--host", default=os.environ.get("SNAPSHOT_LINK_HOST", "127.0.0.1"))
    ap.add_argument("--ib", default=os.environ.get("SNAPSHOT_LINK_IB", "mlx5_60"))
    ap.add_argument("--recv-port", type=int,
                    default=int(os.environ.get("SNAPSHOT_LINK_RECV_PORT", "17787")))
    ap.add_argument("--send-port", type=int,
                    default=int(os.environ.get("SNAPSHOT_LINK_SEND_PORT", "17788")))
    ap.add_argument("--max-bytes", type=int, default=256 * 1024 * 1024)
    ap.add_argument("--sizes", default=",".join(str(s) for s in SIZES_BYTES_DEFAULT))
    ap.add_argument("--send-gpu", type=int, default=0)
    ap.add_argument("--recv-gpu", type=int, default=1)
    args = ap.parse_args()

    sizes = [int(s) for s in args.sizes.split(",")]
    tmpdir = Path(tempfile.mkdtemp(prefix="snapshot_link_gpu_smoke_"))
    control_path = tmpdir / "endpoint.json"
    recv_stderr_log = tmpdir / "recv.stderr.log"

    recv_cmd = [
        sys.executable,
        str(_HERE / "snapshot_link_receiver_gpu.py"),
        "--host", args.host,
        "--port", str(args.recv_port),
        "--ib", args.ib,
        "--max-bytes", str(args.max_bytes),
        "--control-path", str(control_path),
        "--sizes", args.sizes,
        "--gpu-id", str(args.recv_gpu),
    ]
    print(f"[sender] receiver cmd: {' '.join(recv_cmd)}", flush=True)
    # The child process gets its own copy of the log file descriptor, so the
    # parent's handle is closed as soon as the process is spawned. This avoids
    # leaking the handle on the early-return paths below.
    with open(recv_stderr_log, "w") as recv_stderr:
        recv_proc = subprocess.Popen(
            recv_cmd, stdout=subprocess.PIPE, stderr=recv_stderr, bufsize=1,
            universal_newlines=True,
        )

    try:
        import torch
        if not torch.cuda.is_available():
            print("[sender] FAIL: cuda not available")
            return 1
        torch.cuda.set_device(args.send_gpu)

        # Wait for the receiver to publish a *ready* endpoint description.
        # A monotonic clock keeps the timeout immune to wall-clock jumps.
        deadline = time.monotonic() + 90.0
        meta = None
        while time.monotonic() < deadline:
            try:
                candidate = json.loads(control_path.read_text())
            except (OSError, ValueError):
                # File not there yet, or caught mid-write: retry next tick.
                candidate = None
            # Only accept the metadata once the receiver flags it as ready;
            # a present-but-not-ready file must still count as a timeout.
            if isinstance(candidate, dict) and candidate.get("ready"):
                meta = candidate
                break
            if recv_proc.poll() is not None:
                _dump_recv_stderr(recv_stderr_log)
                print(f"[sender] FAIL: receiver exited (rc={recv_proc.returncode})")
                return 1
            time.sleep(0.1)
        if meta is None:
            print("[sender] FAIL: receiver endpoint timeout")
            return 1
        print(f"[sender] receiver endpoint: gpu={meta['gpu_id']}, "
              f"sid={meta['session_id']}, ptr={hex(int(meta['base_ptr']))}, "
              f"cap={meta['capacity_bytes']}", flush=True)

        # Imported late: resolved through the sys.path entry added at module
        # import time.
        from agentic_pd_hybrid.snapshot_link import SnapshotPeer, SnapshotEndpoint

        endpoint = SnapshotEndpoint(
            session_id=meta["session_id"],
            base_ptr=int(meta["base_ptr"]),
            capacity_bytes=int(meta["capacity_bytes"]),
        )

        peer = SnapshotPeer(
            host=args.host,
            port=args.send_port,
            ib_device=args.ib,
            receive_capacity_bytes=0,
        )

        # Allocate the sender buffer on the sending GPU and register it with
        # the peer's engine so it can be used as a transfer source.
        device = f"cuda:{args.send_gpu}"
        send_tensor = torch.zeros(args.max_bytes, dtype=torch.uint8, device=device)
        send_ptr = send_tensor.data_ptr()
        ret = peer.engine.register_memory(send_ptr, args.max_bytes)
        if ret != 0:
            print(f"[sender] FAIL: register_memory ret={ret}")
            return 1
        print(f"[sender] own gpu={args.send_gpu}, sid={peer.session_id}, "
              f"buf @ {hex(send_ptr)} ({args.max_bytes} B)", flush=True)

        transfers = []
        for size in sizes:
            if size > args.max_bytes:
                print(f"[sender] skip size={size}: larger than "
                      f"--max-bytes={args.max_bytes}", flush=True)
                continue
            # Fill the buffer prefix with fresh random bytes on the GPU; a
            # time-derived seed makes each payload differ from the previous.
            seed = int(time.time() * 1e6) & 0xFFFFFFFF
            gen = torch.Generator(device=device)
            gen.manual_seed(seed)
            send_tensor[:size] = torch.randint(0, 256, (size,), dtype=torch.uint8,
                                               device=device, generator=gen)
            torch.cuda.synchronize(args.send_gpu)
            # Compute the expected hash on the host from a copy of the payload.
            host_view = send_tensor[:size].cpu().numpy().tobytes()
            expected_sha = hashlib.sha256(host_view).hexdigest()
            # Push via RDMA and time only the push call itself.
            t0 = time.perf_counter()
            ret = peer.push(endpoint, send_ptr, 0, size, remote_offset=0)
            t1 = time.perf_counter()
            dt_ms = (t1 - t0) * 1000.0
            gbps = (size * 8.0 / 1e9) / max(t1 - t0, 1e-9)
            print(f"[sender] push size={size:>10d} ret={ret} "
                  f"dur={dt_ms:>9.3f} ms thru={gbps:>6.3f} Gbps",
                  flush=True)

            # Signal the receiver to verify, then wait for its acknowledgement.
            signal_path = control_path.with_suffix(f".do{size}")
            ack_path = control_path.with_suffix(f".ack{size}")
            signal_path.write_text(json.dumps({"sha": expected_sha}))
            ack_deadline = time.monotonic() + 90.0
            while time.monotonic() < ack_deadline:
                if ack_path.exists():
                    break
                if recv_proc.poll() is not None:
                    print(f"[sender] FAIL: receiver died after size={size}")
                    _dump_recv_stderr(recv_stderr_log)
                    return 1
                time.sleep(0.05)
            else:
                print(f"[sender] WARN: no ack for size={size} within 90 s",
                      flush=True)
            transfers.append({
                "size": size, "ret": ret, "dur_ms": round(dt_ms, 3),
                "thru_Gbps": round(gbps, 3), "ack": ack_path.exists(),
            })

        # Drain stdout while waiting for exit: a plain wait() could deadlock
        # if the receiver blocks on a full stdout pipe.
        try:
            stdout_text, _ = recv_proc.communicate(timeout=10)
        except subprocess.TimeoutExpired:
            recv_proc.terminate()
            stdout_text, _ = recv_proc.communicate(timeout=5)

        events = _parse_receiver_events(stdout_text or "")

        print("=" * 78)
        print("[receiver] events:")
        verify_ok = 0
        verify_fail = 0
        for ev in events:
            print(f" {ev}")
            if ev.get("event") == "verify":
                if ev.get("ok"):
                    verify_ok += 1
                else:
                    verify_fail += 1

        _dump_recv_stderr(recv_stderr_log, header="--- receiver stderr ---")

        # A run that transferred nothing has verified nothing, so it must not
        # be reported as a pass.
        passed = bool(transfers) and verify_fail == 0 and verify_ok == len(transfers)
        overall = "PASS" if passed else "FAIL"
        print("=" * 78)
        print(f"OVERALL: {overall} verify_ok={verify_ok} verify_fail={verify_fail} "
              f"transfers={len(transfers)} send_gpu={args.send_gpu} recv_gpu={args.recv_gpu}")
        return 0 if overall == "PASS" else 1

    finally:
        # Best-effort cleanup: never leave the receiver running, whatever
        # path got us here. Errors are deliberately ignored at this point.
        try:
            recv_proc.terminate()
            recv_proc.wait(timeout=5)
        except Exception:
            try:
                recv_proc.kill()
                recv_proc.wait(timeout=5)
            except Exception:
                pass
|
|
|
|
|
|
def _dump_recv_stderr(path: Path, header: str = "--- receiver stderr (last 60) ---") -> None:
    """Print the tail of the receiver's stderr log to stdout.

    Args:
        path: Location of the log file to read.
        header: Line printed before the log excerpt.

    At most the last 60 lines of the file are printed, each prefixed with a
    space. If the file does not exist, nothing is printed at all.
    """
    try:
        # This is a diagnostics helper that runs on failure paths, so it must
        # not raise itself: bytes that cannot be decoded are replaced rather
        # than triggering UnicodeDecodeError.
        text = path.read_text(errors="replace")
    except FileNotFoundError:
        return
    print(header, flush=True)
    for line in text.splitlines()[-60:]:
        print(f" {line}", flush=True)
|
|
|
|
|
|
# Script entry point: run the smoke test and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|