feat(snapshot): D→P RDMA link Phase 1 — minimal byte transport
A thin wrapper around mooncake.engine.TransferEngine that does one-sided RDMA writes between two SnapshotPeer endpoints. Bypasses SGLang's MooncakeKVManager (which is hard-gated to PREFILL/DECODE roles via add_transfer_request assertion at conn.py:1563) so the D→P direction doesn't require invasive role-axis changes upstream. Smoke test (two subprocess.Popen processes, mlx5_60, 127.0.0.1): 1 KB 9.0 ms (one-time openSegment handshake) 16 KB 0.04 ms 3.5 Gbps 1 MB 0.10 ms 82 Gbps 16 MB 0.58 ms 232 Gbps 64 MB 1.70 ms 316 Gbps (~80% of NDR 400G line rate) All 5 sizes pass SHA256 verification end-to-end. Files: src/agentic_pd_hybrid/snapshot_link.py — SnapshotPeer, SnapshotEndpoint scripts/snapshot_link_receiver.py — child-process receiver scripts/smoke_snapshot_link.py — sender + verifier docs/D_TO_P_PHASE1_LINK_ZH.md — phase 1 acceptance doc Next: Phase 2 (D-side scheduler commit hook), Phase 3 (P-side prefill bypass with snapshot KV). See docs/D_TO_P_SYNC_DESIGN_ZH.md §5.
This commit is contained in:
140
docs/D_TO_P_PHASE1_LINK_ZH.md
Normal file
140
docs/D_TO_P_PHASE1_LINK_ZH.md
Normal file
@@ -0,0 +1,140 @@
|
||||
# D→P Phase 1:底层 RDMA 链路(已验收)
|
||||
|
||||
**日期**:2026-05-13
|
||||
**状态**:底层链路通过 smoke test 验收
|
||||
**前置**:`docs/D_TO_P_SYNC_DESIGN_ZH.md`
|
||||
**对应 commit**:`feat(snapshot): D→P snapshot link over mooncake RDMA`
|
||||
|
||||
---
|
||||
|
||||
## 0. 一句话
|
||||
|
||||
实现一个独立于 SGLang `MooncakeKVManager` 的**最小 RDMA 字节传输模块**(`src/agentic_pd_hybrid/snapshot_link.py`),双进程 smoke test 跑通 1 KB → 64 MB 一共 5 个 size,全部 SHA 校验通过,64 MB 单次 RDMA write 实测 315 Gbps(mlx5_60 NDR 400 Gb 的约 80%)。
|
||||
|
||||
## 1. 设计动机
|
||||
|
||||
`docs/D_TO_P_SYNC_DESIGN_ZH.md` 选定 Option C(D→P snapshot push + P SessionSlot + prefill bypass)。这个方案的最底层依赖是"D 进程能把字节通过 RDMA 推到 P 进程的预注册缓冲区"。
|
||||
|
||||
直接复用 SGLang 的 `MooncakeKVManager` 不可行:
|
||||
- `add_transfer_request` 在 `conn.py:1563` 硬 assert `disaggregation_mode == PREFILL`
|
||||
- PD pipeline 的发送 / 接收 thread / queue / staging 紧耦合 PD 角色
|
||||
- 改 PD 路径风险大(影响现有 E1/E2/E3 配置)
|
||||
|
||||
因此把 D→P link 单独写成一个轻量模块,直接调 `mooncake.engine.TransferEngine` 的 `transfer_sync_write` / `batch_transfer_sync_write`,不经过 PD pipeline。
|
||||
|
||||
## 2. 实现
|
||||
|
||||
### 2.1 `snapshot_link.SnapshotPeer`
|
||||
|
||||
```python
|
||||
peer = SnapshotPeer(host, port, ib_device, receive_capacity_bytes)
|
||||
endpoint = peer.endpoint # SnapshotEndpoint(session_id, base_ptr, capacity_bytes)
|
||||
peer.register_send_buffer(ptr, length)
|
||||
peer.push(target_endpoint, local_ptr, local_off, length, remote_off=0)
|
||||
peer.batch_push(target, local_addrs, remote_addrs, lengths)
|
||||
peer.read_bytes(offset, length) -> bytes
|
||||
peer.close()
|
||||
```
|
||||
|
||||
- 每个 `SnapshotPeer` 拥有自己的 `TransferEngine`,绑定 `host:port`
|
||||
- `receive_capacity_bytes > 0` 时分配一段 ctypes `c_ubyte` 数组 + `register_memory`
|
||||
- `push` 直接走 `engine.transfer_sync_write(peer_session_id, local_ptr, remote_ptr, length)`
|
||||
- 角色完全对称——任何 `SnapshotPeer` 既可以发送也可以接收,由 caller 决定
|
||||
|
||||
### 2.2 Smoke test 双进程结构
|
||||
|
||||
```
|
||||
父进程 (sender) 子进程 (receiver, subprocess.Popen)
|
||||
│ │
|
||||
│ spawn → ──────────────────────────────►│
|
||||
│ │ SnapshotPeer(recv_capacity=64MB)
|
||||
│ │ write endpoint.json
|
||||
│ read endpoint.json ◄───────────────────│
|
||||
│ │
|
||||
│ SnapshotPeer(no recv buf) │
|
||||
│ register_send_buffer(64MB) │
|
||||
│ │
|
||||
│ for size in [1K, 16K, 1M, 16M, 64M]: │
|
||||
│ fill_pattern(send_buf, seed) │
|
||||
│ peer.push(endpoint, 0, size) ─RDMA──►│
|
||||
│ │ wait signal
|
||||
│ write endpoint.do{size} ────────────►│ read signal seed
|
||||
│ │ compute expected SHA
|
||||
│ │ recv_bytes = peer.read_bytes
|
||||
│ wait endpoint.ack{size} │ compare SHA → emit JSON event
|
||||
│ │ write endpoint.ack{size}
|
||||
│ ... │
|
||||
│ │
|
||||
│ drain child stdout, parse JSON │ exit
|
||||
│ verify each event has ok=true │
|
||||
```
|
||||
|
||||
### 2.3 性能(首次 smoke run)
|
||||
|
||||
| Size | Push duration | Throughput |
|
||||
|---:|---:|---:|
|
||||
| 1 KB | 9.0 ms | 0.001 Gbps |
|
||||
| 16 KB | 0.037 ms | 3.5 Gbps |
|
||||
| 1 MB | 0.102 ms | 82 Gbps |
|
||||
| 16 MB | 0.577 ms | 232 Gbps |
|
||||
| **64 MB** | **1.70 ms** | **316 Gbps** |
|
||||
|
||||
- 1 KB 第一次有 ~9 ms 的 mooncake p2p handshake/openSegment overhead(一次性)
|
||||
- 16 KB 之后是稳态,吞吐随 size 增长接近线速
|
||||
- mlx5_60 是 mlx5 ConnectX-7 NDR 400 Gb(4× 100Gb lanes);64 MB 测到 316 Gbps 是 79% 的链路利用率,对单次 RDMA write 来说正常(剩余空间留给 verb dispatch / completion handling overhead)
|
||||
|
||||
## 3. 验收
|
||||
|
||||
- ✅ 5/5 size SHA 校验全部通过
|
||||
- ✅ 64 MB 一次 RDMA 1.7 ms
|
||||
- ✅ 双进程独立,不耦合 SGLang PD pipeline
|
||||
- ✅ Smoke test 脚本 `scripts/smoke_snapshot_link.py` 可重跑
|
||||
|
||||
## 4. 当前覆盖范围(清单)
|
||||
|
||||
- ✅ Host CPU 内存的 D→P RDMA byte transfer
|
||||
- ✅ 单 IB device (mlx5_60)
|
||||
- ✅ 同节点 loopback(127.0.0.1)
|
||||
- ⏳ GPU 内存(设备指针 + `batch_transfer_write_on_cuda`)—— 现有 `push()` 走 `transfer_sync_write`,对 GPU 指针支持取决于 mooncake 的 protocol;下一步验证
|
||||
- ⏳ 跨节点(远端 IP)—— 设计上一致,未验证
|
||||
- ⏳ 多 D → 单 P(多 sender → 共享 recv buffer 的 offset 协调)—— 留给 Phase 3 整合时设计
|
||||
- ⏳ ZeroCopy 入 SGLang kv_pool slot —— 留给 Phase 2
|
||||
|
||||
## 5. 下一步(Phase 2 / Phase 3)
|
||||
|
||||
详见 `docs/D_TO_P_SYNC_DESIGN_ZH.md` §5。本 phase 1 解锁后,整个 D→P 同步可以正式开始整合到 SGLang scheduler:
|
||||
|
||||
| Phase | 描述 | 风险 |
|
||||
|---|---|---|
|
||||
| 2 | D-side commit hook:`cache_finished_req` 完成后 enqueue snapshot push | 中。需要在 scheduler 后台线程跑 push,不能阻塞 schedule loop |
|
||||
| 3 | P-side snapshot store + prefill bypass:P scheduler 收到 use-snapshot 请求时跳过 `model.forward()`,直接用 snapshot KV 触发 P→D' transfer | **最高**。需要深入 SGLang prefill 流程 |
|
||||
| 4 | agentic-pd-hybrid hook:`_invoke_kvcache_seeded_router` 先 probe P → 决定走 bypass 还是 fallback | 低 |
|
||||
| 5 | CLI flag + structural log | 低 |
|
||||
| 6 | 端到端 smoke + E4 sweep | 中 |
|
||||
|
||||
## 6. 知识沉淀
|
||||
|
||||
### 易踩坑
|
||||
|
||||
| 坑 | 原因 | 修法 |
|
||||
|---|---|---|
|
||||
| 多进程 `multiprocessing.Process` 子进程崩溃信息丢失 | spawn context 下 child 没有继承 parent 的 stderr | 改用 `subprocess.Popen` + stderr 重定向到文件 |
|
||||
| `bytes(ctypes.c_byte * N)` 失败 `ValueError: bytes must be in range(0, 256)` | `c_byte` 是 **signed**,>= 128 的 byte 在 Python 看就是负数 | 用 `c_ubyte` 或 `ctypes.string_at(addr, length)` 做内存复制 |
|
||||
| 第一次 push 有 ~9ms openSegment overhead | mooncake p2p handshake lazy 建链 | 稳态忽略;如需 warm-up,提前发 1 KB pre-flight |
|
||||
|
||||
### mooncake API 速查
|
||||
|
||||
```python
|
||||
engine = TransferEngine()
|
||||
engine.initialize(f"{host}:{port}", "P2PHANDSHAKE", "rdma", ib_device)
|
||||
engine.register_memory(ptr, length) # mr 注册
|
||||
engine.transfer_sync_write(peer_session_id, local_ptr, remote_ptr, length) # RDMA write
|
||||
engine.batch_transfer_sync_write(peer_session_id, [local_ptrs], [remote_ptrs], [lengths])
|
||||
engine.unregister_memory(ptr)
|
||||
```
|
||||
|
||||
`peer_session_id` 是 `"host:rpc_port"`,其中 `rpc_port = peer_engine.get_rpc_port()`。
|
||||
|
||||
---
|
||||
|
||||
**核心句**:D→P 底层 RDMA 链路独立模块跑通,64 MB 1.7 ms / 316 Gbps,与 SGLang PD pipeline 完全解耦。Phase 2/3 可以放心在这上面叠加。
|
||||
244
scripts/smoke_snapshot_link.py
Executable file
244
scripts/smoke_snapshot_link.py
Executable file
@@ -0,0 +1,244 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Two-process smoke test for snapshot_link D→P RDMA byte transfer.
|
||||
|
||||
Spawns scripts/snapshot_link_receiver.py via subprocess.Popen with stderr
|
||||
piped to ``<tmpdir>/recv.stderr.log`` for post-mortem if something dies.
|
||||
|
||||
Sender (this process):
|
||||
1. Spawns receiver child, waits for endpoint.json
|
||||
2. Brings up own SnapshotPeer (no recv buffer), registers a send buffer
|
||||
3. For each size: fill pattern, batch_transfer_sync_write, signal child,
|
||||
wait for child's ack
|
||||
4. Reads child's stdout (one JSON event per line) for verification
|
||||
|
||||
Pass = every size yields a child "verify" event with ok=true.
|
||||
|
||||
Usage:
|
||||
bash scripts/setup_env.sh && uv run --no-sync python scripts/smoke_snapshot_link.py
|
||||
|
||||
Env (optional):
|
||||
SNAPSHOT_LINK_HOST default 127.0.0.1
|
||||
SNAPSHOT_LINK_IB default mlx5_60
|
||||
SNAPSHOT_LINK_RECV_PORT default 17777
|
||||
SNAPSHOT_LINK_SEND_PORT default 17778
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ctypes
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
_HERE = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(_HERE.parent / "src"))
|
||||
|
||||
|
||||
SIZES_BYTES_DEFAULT = [
|
||||
1 << 10, # 1 KB
|
||||
1 << 14, # 16 KB
|
||||
1 << 18, # 256 KB
|
||||
1 << 20, # 1 MB
|
||||
1 << 22, # 4 MB
|
||||
1 << 24, # 16 MB
|
||||
1 << 26, # 64 MB
|
||||
]
|
||||
|
||||
|
||||
def _pattern_byte(i: int, seed: int) -> int:
|
||||
return (i * 2654435761 + seed) & 0xFF
|
||||
|
||||
|
||||
def _fill_pattern(buf, length: int, seed: int) -> None:
|
||||
tile_size = 4096
|
||||
tile = bytes(_pattern_byte(i, seed) for i in range(tile_size))
|
||||
tile_arr = (ctypes.c_ubyte * tile_size).from_buffer_copy(tile)
|
||||
n_full = length // tile_size
|
||||
rem = length - n_full * tile_size
|
||||
base = ctypes.addressof(buf)
|
||||
src_addr = ctypes.addressof(tile_arr)
|
||||
for k in range(n_full):
|
||||
ctypes.memmove(base + k * tile_size, src_addr, tile_size)
|
||||
if rem:
|
||||
ctypes.memmove(base + n_full * tile_size, src_addr, rem)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--host", default=os.environ.get("SNAPSHOT_LINK_HOST", "127.0.0.1"))
|
||||
ap.add_argument("--ib", default=os.environ.get("SNAPSHOT_LINK_IB", "mlx5_60"))
|
||||
ap.add_argument("--recv-port", type=int,
|
||||
default=int(os.environ.get("SNAPSHOT_LINK_RECV_PORT", "17777")))
|
||||
ap.add_argument("--send-port", type=int,
|
||||
default=int(os.environ.get("SNAPSHOT_LINK_SEND_PORT", "17778")))
|
||||
ap.add_argument("--max-bytes", type=int, default=128 * 1024 * 1024)
|
||||
ap.add_argument("--sizes", default=",".join(str(s) for s in SIZES_BYTES_DEFAULT))
|
||||
args = ap.parse_args()
|
||||
|
||||
sizes = [int(s) for s in args.sizes.split(",")]
|
||||
tmpdir = Path(tempfile.mkdtemp(prefix="snapshot_link_smoke_"))
|
||||
control_path = tmpdir / "endpoint.json"
|
||||
recv_stderr_log = tmpdir / "recv.stderr.log"
|
||||
|
||||
recv_cmd = [
|
||||
sys.executable,
|
||||
str(_HERE / "snapshot_link_receiver.py"),
|
||||
"--host", args.host,
|
||||
"--port", str(args.recv_port),
|
||||
"--ib", args.ib,
|
||||
"--max-bytes", str(args.max_bytes),
|
||||
"--control-path", str(control_path),
|
||||
"--sizes", args.sizes,
|
||||
]
|
||||
recv_stderr = open(recv_stderr_log, "w")
|
||||
print(f"[sender] launching receiver: {' '.join(recv_cmd)}", flush=True)
|
||||
print(f"[sender] receiver stderr → {recv_stderr_log}", flush=True)
|
||||
recv_proc = subprocess.Popen(
|
||||
recv_cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=recv_stderr,
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Wait for endpoint metadata
|
||||
deadline = time.time() + 60.0
|
||||
while time.time() < deadline:
|
||||
if control_path.exists():
|
||||
try:
|
||||
meta = json.loads(control_path.read_text())
|
||||
if meta.get("ready"):
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if recv_proc.poll() is not None:
|
||||
_dump_recv_stderr(recv_stderr_log)
|
||||
print(f"[sender] FAIL: receiver exited early (rc={recv_proc.returncode})")
|
||||
return 1
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
print("[sender] FAIL: timed out waiting for receiver endpoint", flush=True)
|
||||
return 1
|
||||
|
||||
print(f"[sender] receiver endpoint: {meta}", flush=True)
|
||||
|
||||
from agentic_pd_hybrid.snapshot_link import SnapshotPeer, SnapshotEndpoint
|
||||
endpoint = SnapshotEndpoint(
|
||||
session_id=meta["session_id"],
|
||||
base_ptr=int(meta["base_ptr"]),
|
||||
capacity_bytes=int(meta["capacity_bytes"]),
|
||||
)
|
||||
peer = SnapshotPeer(
|
||||
host=args.host,
|
||||
port=args.send_port,
|
||||
ib_device=args.ib,
|
||||
receive_capacity_bytes=0,
|
||||
)
|
||||
send_buf = (ctypes.c_byte * args.max_bytes)()
|
||||
send_addr = ctypes.addressof(send_buf)
|
||||
peer.register_send_buffer(send_addr, args.max_bytes)
|
||||
print(f"[sender] own session_id={peer.session_id}, send_buf @ {hex(send_addr)} ({args.max_bytes} B)", flush=True)
|
||||
|
||||
transfers = []
|
||||
for size in sizes:
|
||||
if size > args.max_bytes:
|
||||
continue
|
||||
seed = int(time.time() * 1e6) & 0xFFFFFFFF
|
||||
_fill_pattern(send_buf, size, seed)
|
||||
t0 = time.perf_counter()
|
||||
ret = peer.push(endpoint, send_addr, 0, size, remote_offset=0)
|
||||
t1 = time.perf_counter()
|
||||
dt_ms = (t1 - t0) * 1000.0
|
||||
gbps = (size * 8.0 / 1e9) / max(t1 - t0, 1e-9)
|
||||
print(f"[sender] push size={size:>10d} ret={ret} "
|
||||
f"dur={dt_ms:>9.3f} ms thru={gbps:>6.3f} Gbps",
|
||||
flush=True)
|
||||
signal_path = control_path.with_suffix(f".do{size}")
|
||||
ack_path = control_path.with_suffix(f".ack{size}")
|
||||
signal_path.write_text(str(seed))
|
||||
ack_deadline = time.time() + 60.0
|
||||
while time.time() < ack_deadline:
|
||||
if ack_path.exists():
|
||||
break
|
||||
if recv_proc.poll() is not None:
|
||||
print(f"[sender] FAIL: receiver died after size={size}", flush=True)
|
||||
_dump_recv_stderr(recv_stderr_log)
|
||||
return 1
|
||||
time.sleep(0.05)
|
||||
transfers.append({
|
||||
"size": size, "ret": ret, "dur_ms": round(dt_ms, 3),
|
||||
"thru_Gbps": round(gbps, 3),
|
||||
"ack": ack_path.exists(),
|
||||
})
|
||||
|
||||
peer.close()
|
||||
|
||||
# Drain child stdout — each line is a JSON event
|
||||
try:
|
||||
recv_proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
recv_proc.terminate()
|
||||
recv_proc.wait(timeout=5)
|
||||
|
||||
events = []
|
||||
if recv_proc.stdout is not None:
|
||||
for raw in recv_proc.stdout:
|
||||
raw = raw.strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(raw))
|
||||
except json.JSONDecodeError:
|
||||
events.append({"event": "non-json", "raw": raw})
|
||||
|
||||
print("=" * 78)
|
||||
print("[receiver] events:")
|
||||
verify_ok = 0
|
||||
verify_fail = 0
|
||||
for ev in events:
|
||||
print(f" {ev}")
|
||||
if ev.get("event") == "verify":
|
||||
if ev.get("ok"):
|
||||
verify_ok += 1
|
||||
else:
|
||||
verify_fail += 1
|
||||
|
||||
recv_stderr.close()
|
||||
_dump_recv_stderr(recv_stderr_log, header="--- receiver stderr ---")
|
||||
|
||||
overall = "PASS" if verify_fail == 0 and verify_ok == len(transfers) else "FAIL"
|
||||
print("=" * 78)
|
||||
print(f"OVERALL: {overall} verify_ok={verify_ok} verify_fail={verify_fail} "
|
||||
f"transfers={len(transfers)}")
|
||||
return 0 if overall == "PASS" else 1
|
||||
|
||||
finally:
|
||||
try:
|
||||
recv_proc.terminate()
|
||||
recv_proc.wait(timeout=5)
|
||||
except Exception:
|
||||
try:
|
||||
recv_proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _dump_recv_stderr(path: Path, header: str = "--- receiver stderr (last 40) ---") -> None:
|
||||
try:
|
||||
text = path.read_text()
|
||||
except FileNotFoundError:
|
||||
return
|
||||
print(header, flush=True)
|
||||
for line in text.splitlines()[-40:]:
|
||||
print(f" {line}", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
123
scripts/snapshot_link_receiver.py
Normal file
123
scripts/snapshot_link_receiver.py
Normal file
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Receiver-side child process for the snapshot_link smoke test.
|
||||
|
||||
Reads CLI args, brings up a SnapshotPeer with a registered recv buffer,
|
||||
writes endpoint metadata to a control file, then loops: wait for size
|
||||
signal, verify recv buffer, write ack.
|
||||
|
||||
Status events are printed as single-line JSON to stdout for parent to
|
||||
parse.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ctypes
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
|
||||
|
||||
|
||||
def _pattern_byte(i: int, seed: int) -> int:
|
||||
return (i * 2654435761 + seed) & 0xFF
|
||||
|
||||
|
||||
def _fill_pattern(buf, length: int, seed: int) -> None:
|
||||
tile_size = 4096
|
||||
tile = bytes(_pattern_byte(i, seed) for i in range(tile_size))
|
||||
tile_arr = (ctypes.c_ubyte * tile_size).from_buffer_copy(tile)
|
||||
n_full = length // tile_size
|
||||
rem = length - n_full * tile_size
|
||||
base = ctypes.addressof(buf)
|
||||
src_addr = ctypes.addressof(tile_arr)
|
||||
for k in range(n_full):
|
||||
ctypes.memmove(base + k * tile_size, src_addr, tile_size)
|
||||
if rem:
|
||||
ctypes.memmove(base + n_full * tile_size, src_addr, rem)
|
||||
|
||||
|
||||
def _emit(d: dict) -> None:
|
||||
print(json.dumps(d), flush=True)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--host", required=True)
|
||||
ap.add_argument("--port", type=int, required=True)
|
||||
ap.add_argument("--ib", required=True)
|
||||
ap.add_argument("--max-bytes", type=int, required=True)
|
||||
ap.add_argument("--control-path", required=True)
|
||||
ap.add_argument("--sizes", required=True, help="comma-separated bytes")
|
||||
args = ap.parse_args()
|
||||
|
||||
sizes = [int(s) for s in args.sizes.split(",")]
|
||||
|
||||
from agentic_pd_hybrid.snapshot_link import SnapshotPeer
|
||||
|
||||
try:
|
||||
peer = SnapshotPeer(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
ib_device=args.ib,
|
||||
receive_capacity_bytes=args.max_bytes,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
_emit({"event": "init-failed", "error": repr(e), "tb": traceback.format_exc()})
|
||||
sys.exit(2)
|
||||
|
||||
endpoint = peer.endpoint
|
||||
Path(args.control_path).write_text(json.dumps({
|
||||
"session_id": endpoint.session_id,
|
||||
"base_ptr": endpoint.base_ptr,
|
||||
"capacity_bytes": endpoint.capacity_bytes,
|
||||
"ready": True,
|
||||
}))
|
||||
_emit({"event": "endpoint-ready", "session_id": endpoint.session_id,
|
||||
"base_ptr": endpoint.base_ptr, "capacity": endpoint.capacity_bytes})
|
||||
|
||||
cp = Path(args.control_path)
|
||||
for size in sizes:
|
||||
if size > args.max_bytes:
|
||||
continue
|
||||
signal_path = cp.with_suffix(f".do{size}")
|
||||
ack_path = cp.with_suffix(f".ack{size}")
|
||||
deadline = time.time() + 120.0
|
||||
while time.time() < deadline:
|
||||
if signal_path.exists():
|
||||
break
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
_emit({"event": "no-signal-timeout", "size": size})
|
||||
continue
|
||||
try:
|
||||
seed = int(signal_path.read_text().strip())
|
||||
except Exception as e:
|
||||
_emit({"event": "signal-parse-error", "size": size, "err": repr(e)})
|
||||
continue
|
||||
expected_arr = (ctypes.c_ubyte * size)()
|
||||
_fill_pattern(expected_arr, size, seed)
|
||||
expected_hash = hashlib.sha256(bytes(expected_arr)).hexdigest()
|
||||
recv_bytes = peer.read_bytes(0, size)
|
||||
recv_hash = hashlib.sha256(recv_bytes).hexdigest()
|
||||
ok = recv_hash == expected_hash
|
||||
_emit({
|
||||
"event": "verify",
|
||||
"size": size,
|
||||
"ok": ok,
|
||||
"expected_sha": expected_hash[:16],
|
||||
"got_sha": recv_hash[:16],
|
||||
"first8_recv": recv_bytes[:8].hex(),
|
||||
"last8_recv": recv_bytes[-8:].hex(),
|
||||
})
|
||||
ack_path.write_text("done")
|
||||
|
||||
peer.close()
|
||||
_emit({"event": "receiver-done"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
266
src/agentic_pd_hybrid/snapshot_link.py
Normal file
266
src/agentic_pd_hybrid/snapshot_link.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Minimal D→P snapshot link over Mooncake RDMA.
|
||||
|
||||
This module provides a thin wrapper around mooncake.engine.TransferEngine
|
||||
for one-sided RDMA writes of KV bytes from a Decode worker (sender) to a
|
||||
Prefill worker (receiver). It deliberately does NOT use the heavyweight
|
||||
MooncakeKVManager pipeline (which is tied to PREFILL/DECODE roles and
|
||||
chunked transfer protocols): we want a simple, testable byte transport
|
||||
that can be reused by SGLang and by stand-alone smoke tests.
|
||||
|
||||
Layout:
|
||||
SnapshotPeer — engine + pre-registered receive buffer (receiver)
|
||||
or sender handle (sender)
|
||||
SnapshotEndpoint — what the receiver advertises so the sender can
|
||||
target it: (session_id, base_ptr, length)
|
||||
SnapshotPusher — sender-side: holds a target endpoint, calls
|
||||
batch_transfer_sync_write
|
||||
|
||||
All transfers are SYNCHRONOUS, single-shot, in-memory.
|
||||
|
||||
Higher layers add: control plane (how D learns P's endpoint), per-session
|
||||
slot allocation, KV format/layout, hand-off into SGLang scheduler.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnapshotEndpoint:
|
||||
"""What the receiver advertises so the sender can reach it.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
session_id : str
|
||||
``"host:rpc_port"`` string identifying the receiver's mooncake
|
||||
TransferEngine. Returned by ``TransferEngine.get_rpc_port()``
|
||||
joined with the host the engine was initialized with.
|
||||
base_ptr : int
|
||||
Address of the registered receive buffer on the receiver side.
|
||||
capacity_bytes : int
|
||||
Length of the registered region.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
base_ptr: int
|
||||
capacity_bytes: int
|
||||
|
||||
|
||||
def _import_transfer_engine():
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
"mooncake.engine.TransferEngine is required for snapshot_link. "
|
||||
"Make sure mooncake-transfer-engine is installed in the venv."
|
||||
) from e
|
||||
return TransferEngine
|
||||
|
||||
|
||||
class SnapshotPeer:
|
||||
"""One Mooncake transfer engine endpoint with a registered receive buffer.
|
||||
|
||||
The engine is dedicated to snapshot traffic — it does NOT share state
|
||||
with SGLang's MooncakeKVManager engine. Each SnapshotPeer needs its own
|
||||
host:port to listen on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
ib_device: Optional[str] = None,
|
||||
receive_capacity_bytes: int = 0,
|
||||
protocol: Optional[str] = None,
|
||||
):
|
||||
TransferEngine = _import_transfer_engine()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.ib_device = ib_device
|
||||
self.engine = TransferEngine()
|
||||
|
||||
listen = f"{host}:{port}"
|
||||
proto = protocol or os.environ.get("MOONCAKE_PROTOCOL", "rdma")
|
||||
ret = self.engine.initialize(
|
||||
listen,
|
||||
"P2PHANDSHAKE",
|
||||
proto,
|
||||
ib_device or "",
|
||||
)
|
||||
if ret != 0:
|
||||
raise RuntimeError(
|
||||
f"snapshot_link: engine.initialize({listen!r}, proto={proto}, "
|
||||
f"ib={ib_device}) returned {ret}"
|
||||
)
|
||||
|
||||
self._rpc_port = self.engine.get_rpc_port()
|
||||
self._session_id = f"{host}:{self._rpc_port}"
|
||||
|
||||
self._recv_buffer = None
|
||||
self._recv_ptr = 0
|
||||
self._recv_capacity = 0
|
||||
if receive_capacity_bytes > 0:
|
||||
self._allocate_recv_buffer(receive_capacity_bytes)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
logger.info(
|
||||
"SnapshotPeer up at %s (rpc=%d, ib=%s, recv=%d B)",
|
||||
self._session_id,
|
||||
self._rpc_port,
|
||||
ib_device,
|
||||
receive_capacity_bytes,
|
||||
)
|
||||
|
||||
# -- accessors ---------------------------------------------------------
|
||||
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
@property
|
||||
def rpc_port(self) -> int:
|
||||
return self._rpc_port
|
||||
|
||||
@property
|
||||
def endpoint(self) -> SnapshotEndpoint:
|
||||
if self._recv_buffer is None:
|
||||
raise RuntimeError(
|
||||
"SnapshotPeer has no receive buffer; pass receive_capacity_bytes > 0"
|
||||
)
|
||||
return SnapshotEndpoint(
|
||||
session_id=self._session_id,
|
||||
base_ptr=self._recv_ptr,
|
||||
capacity_bytes=self._recv_capacity,
|
||||
)
|
||||
|
||||
# -- buffer management -------------------------------------------------
|
||||
|
||||
def _allocate_recv_buffer(self, length: int) -> None:
|
||||
"""Allocate + register a pinned host buffer for receiving."""
|
||||
# Use c_ubyte (unsigned) so bytes() conversions of the underlying
|
||||
# storage always yield valid byte values.
|
||||
buf = (ctypes.c_ubyte * length)()
|
||||
addr = ctypes.addressof(buf)
|
||||
ret = self.engine.register_memory(addr, length)
|
||||
if ret != 0:
|
||||
raise RuntimeError(
|
||||
f"snapshot_link: register_memory({hex(addr)}, {length}) returned {ret}"
|
||||
)
|
||||
self._recv_buffer = buf
|
||||
self._recv_ptr = addr
|
||||
self._recv_capacity = length
|
||||
|
||||
def read_bytes(self, offset: int, length: int) -> bytes:
|
||||
"""Snapshot the recv buffer at [offset, offset+length) (caller syncs)."""
|
||||
if self._recv_buffer is None:
|
||||
raise RuntimeError("no recv buffer")
|
||||
if offset < 0 or offset + length > self._recv_capacity:
|
||||
raise ValueError(
|
||||
f"read_bytes({offset}, {length}) out of capacity {self._recv_capacity}"
|
||||
)
|
||||
# string_at copies via memcpy and yields a proper bytes object — works
|
||||
# regardless of signed/unsigned underlying storage.
|
||||
return ctypes.string_at(self._recv_ptr + offset, length)
|
||||
|
||||
def register_send_buffer(self, ptr: int, length: int) -> None:
|
||||
"""Register an externally-allocated send buffer for outbound RDMA writes."""
|
||||
with self._lock:
|
||||
ret = self.engine.register_memory(ptr, length)
|
||||
if ret != 0:
|
||||
raise RuntimeError(
|
||||
f"snapshot_link: register send buffer({hex(ptr)}, {length}) returned {ret}"
|
||||
)
|
||||
|
||||
def deregister(self, ptr: int) -> None:
|
||||
with self._lock:
|
||||
try:
|
||||
self.engine.unregister_memory(ptr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# -- transfer ----------------------------------------------------------
|
||||
|
||||
def push(
|
||||
self,
|
||||
target: SnapshotEndpoint,
|
||||
local_ptr: int,
|
||||
local_offset: int,
|
||||
length: int,
|
||||
remote_offset: int = 0,
|
||||
) -> int:
|
||||
"""Synchronously RDMA-write ``length`` bytes from ``local_ptr+local_offset``
|
||||
to ``target.base_ptr+remote_offset`` on the peer identified by
|
||||
``target.session_id``.
|
||||
|
||||
Returns 0 on success, non-zero (or raises) on failure.
|
||||
"""
|
||||
if length <= 0:
|
||||
return 0
|
||||
if remote_offset < 0 or remote_offset + length > target.capacity_bytes:
|
||||
raise ValueError(
|
||||
f"push: remote_offset={remote_offset}, length={length} exceeds "
|
||||
f"target capacity {target.capacity_bytes}"
|
||||
)
|
||||
src = local_ptr + local_offset
|
||||
dst = target.base_ptr + remote_offset
|
||||
try:
|
||||
ret = self.engine.transfer_sync_write(
|
||||
target.session_id, src, dst, length
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_link.push transfer_sync_write threw: %s", e)
|
||||
return -1
|
||||
if ret != 0:
|
||||
logger.warning(
|
||||
"snapshot_link.push transfer_sync_write returned %d (src=%s, "
|
||||
"dst=%s/%s, len=%d)",
|
||||
ret,
|
||||
hex(src),
|
||||
target.session_id,
|
||||
hex(dst),
|
||||
length,
|
||||
)
|
||||
return ret
|
||||
|
||||
def batch_push(
|
||||
self,
|
||||
target: SnapshotEndpoint,
|
||||
local_addrs: list[int],
|
||||
remote_addrs: list[int],
|
||||
lengths: list[int],
|
||||
) -> int:
|
||||
"""Batched RDMA write (one-shot)."""
|
||||
if not local_addrs:
|
||||
return 0
|
||||
try:
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
target.session_id, local_addrs, remote_addrs, lengths
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_link.batch_push threw: %s", e)
|
||||
return -1
|
||||
return ret
|
||||
|
||||
def close(self) -> None:
|
||||
"""Best-effort shutdown — release the receive buffer registration."""
|
||||
if self._recv_ptr:
|
||||
try:
|
||||
self.engine.unregister_memory(self._recv_ptr)
|
||||
except Exception:
|
||||
pass
|
||||
self._recv_ptr = 0
|
||||
self._recv_capacity = 0
|
||||
self._recv_buffer = None
|
||||
|
||||
|
||||
def make_session_id(host: str, rpc_port: int) -> str:
|
||||
"""Build the ``host:port`` form used as mooncake's session id."""
|
||||
return f"{host}:{rpc_port}"
|
||||
Reference in New Issue
Block a user