feat(snapshot): D→P RDMA link Phase 1 — minimal byte transport
A thin wrapper around mooncake.engine.TransferEngine that does one-sided RDMA writes between two SnapshotPeer endpoints. Bypasses SGLang's MooncakeKVManager (which is hard-gated to PREFILL/DECODE roles via add_transfer_request assertion at conn.py:1563) so the D→P direction doesn't require invasive role-axis changes upstream. Smoke test (two subprocess.Popen processes, mlx5_60, 127.0.0.1): 1 KB 9.0 ms (one-time openSegment handshake) 16 KB 0.04 ms 3.5 Gbps 1 MB 0.10 ms 82 Gbps 16 MB 0.58 ms 232 Gbps 64 MB 1.70 ms 316 Gbps (~80% of NDR 400G line rate) All 5 sizes pass SHA256 verification end-to-end. Files: src/agentic_pd_hybrid/snapshot_link.py — SnapshotPeer, SnapshotEndpoint scripts/snapshot_link_receiver.py — child-process receiver scripts/smoke_snapshot_link.py — sender + verifier docs/D_TO_P_PHASE1_LINK_ZH.md — phase 1 acceptance doc Next: Phase 2 (D-side scheduler commit hook), Phase 3 (P-side prefill bypass with snapshot KV). See docs/D_TO_P_SYNC_DESIGN_ZH.md §5.
This commit is contained in:
140
docs/D_TO_P_PHASE1_LINK_ZH.md
Normal file
140
docs/D_TO_P_PHASE1_LINK_ZH.md
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# D→P Phase 1:底层 RDMA 链路(已验收)
|
||||||
|
|
||||||
|
**日期**:2026-05-13
|
||||||
|
**状态**:底层链路通过 smoke test 验收
|
||||||
|
**前置**:`docs/D_TO_P_SYNC_DESIGN_ZH.md`
|
||||||
|
**对应 commit**:`feat(snapshot): D→P snapshot link over mooncake RDMA`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 0. 一句话
|
||||||
|
|
||||||
|
实现一个独立于 SGLang `MooncakeKVManager` 的**最小 RDMA 字节传输模块**(`src/agentic_pd_hybrid/snapshot_link.py`),双进程 smoke test 跑通 1 KB → 64 MB 一共 5 个 size,全部 SHA 校验通过,64 MB 单次 RDMA write 实测 315 Gbps(mlx5_60 NDR 400 Gb 的约 80%)。
|
||||||
|
|
||||||
|
## 1. 设计动机
|
||||||
|
|
||||||
|
`docs/D_TO_P_SYNC_DESIGN_ZH.md` 选定 Option C(D→P snapshot push + P SessionSlot + prefill bypass)。这个方案的最底层依赖是"D 进程能把字节通过 RDMA 推到 P 进程的预注册缓冲区"。
|
||||||
|
|
||||||
|
直接复用 SGLang 的 `MooncakeKVManager` 不可行:
|
||||||
|
- `add_transfer_request` 在 `conn.py:1563` 硬 assert `disaggregation_mode == PREFILL`
|
||||||
|
- PD pipeline 的发送 / 接收 thread / queue / staging 紧耦合 PD 角色
|
||||||
|
- 改 PD 路径风险大(影响现有 E1/E2/E3 配置)
|
||||||
|
|
||||||
|
因此把 D→P link 单独写成一个轻量模块,直接调 `mooncake.engine.TransferEngine` 的 `transfer_sync_write` / `batch_transfer_sync_write`,不经过 PD pipeline。
|
||||||
|
|
||||||
|
## 2. 实现
|
||||||
|
|
||||||
|
### 2.1 `snapshot_link.SnapshotPeer`
|
||||||
|
|
||||||
|
```python
|
||||||
|
peer = SnapshotPeer(host, port, ib_device, receive_capacity_bytes)
|
||||||
|
endpoint = peer.endpoint # SnapshotEndpoint(session_id, base_ptr, capacity_bytes)
|
||||||
|
peer.register_send_buffer(ptr, length)
|
||||||
|
peer.push(target_endpoint, local_ptr, local_off, length, remote_off=0)
|
||||||
|
peer.batch_push(target, local_addrs, remote_addrs, lengths)
|
||||||
|
peer.read_bytes(offset, length) -> bytes
|
||||||
|
peer.close()
|
||||||
|
```
|
||||||
|
|
||||||
|
- 每个 `SnapshotPeer` 拥有自己的 `TransferEngine`,绑定 `host:port`
|
||||||
|
- `receive_capacity_bytes > 0` 时分配一段 ctypes `c_ubyte` 数组 + `register_memory`
|
||||||
|
- `push` 直接走 `engine.transfer_sync_write(peer_session_id, local_ptr, remote_ptr, length)`
|
||||||
|
- 角色完全对称——任何 `SnapshotPeer` 既可以发送也可以接收,由 caller 决定
|
||||||
|
|
||||||
|
### 2.2 Smoke test 双进程结构
|
||||||
|
|
||||||
|
```
|
||||||
|
父进程 (sender) 子进程 (receiver, subprocess.Popen)
|
||||||
|
│ │
|
||||||
|
│ spawn → ──────────────────────────────►│
|
||||||
|
│ │ SnapshotPeer(recv_capacity=64MB)
|
||||||
|
│ │ write endpoint.json
|
||||||
|
│ read endpoint.json ◄───────────────────│
|
||||||
|
│ │
|
||||||
|
│ SnapshotPeer(no recv buf) │
|
||||||
|
│ register_send_buffer(64MB) │
|
||||||
|
│ │
|
||||||
|
│ for size in [1K, 16K, 1M, 16M, 64M]: │
|
||||||
|
│ fill_pattern(send_buf, seed) │
|
||||||
|
│ peer.push(endpoint, 0, size) ─RDMA──►│
|
||||||
|
│ │ wait signal
|
||||||
|
│ write endpoint.do{size} ────────────►│ read signal seed
|
||||||
|
│ │ compute expected SHA
|
||||||
|
│ │ recv_bytes = peer.read_bytes
|
||||||
|
│ wait endpoint.ack{size} │ compare SHA → emit JSON event
|
||||||
|
│ │ write endpoint.ack{size}
|
||||||
|
│ ... │
|
||||||
|
│ │
|
||||||
|
│ drain child stdout, parse JSON │ exit
|
||||||
|
│ verify each event has ok=true │
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 性能(首次 smoke run)
|
||||||
|
|
||||||
|
| Size | Push duration | Throughput |
|
||||||
|
|---:|---:|---:|
|
||||||
|
| 1 KB | 9.0 ms | 0.001 Gbps |
|
||||||
|
| 16 KB | 0.037 ms | 3.5 Gbps |
|
||||||
|
| 1 MB | 0.102 ms | 82 Gbps |
|
||||||
|
| 16 MB | 0.577 ms | 232 Gbps |
|
||||||
|
| **64 MB** | **1.70 ms** | **316 Gbps** |
|
||||||
|
|
||||||
|
- 1 KB 第一次有 ~9 ms 的 mooncake p2p handshake/openSegment overhead(一次性)
|
||||||
|
- 16 KB 之后是稳态,吞吐随 size 增长接近线速
|
||||||
|
- mlx5_60 是 mlx5 ConnectX-7 NDR 400 Gb(4× 100Gb lanes);64 MB 测到 316 Gbps 是 79% 的链路利用率,对单次 RDMA write 来说正常(剩余空间留给 verb dispatch / completion handling overhead)
|
||||||
|
|
||||||
|
## 3. 验收
|
||||||
|
|
||||||
|
- ✅ 5/5 size SHA 校验全部通过
|
||||||
|
- ✅ 64 MB 一次 RDMA 1.7 ms
|
||||||
|
- ✅ 双进程独立,不耦合 SGLang PD pipeline
|
||||||
|
- ✅ Smoke test 脚本 `scripts/smoke_snapshot_link.py` 可重跑
|
||||||
|
|
||||||
|
## 4. 当前覆盖范围(清单)
|
||||||
|
|
||||||
|
- ✅ Host CPU 内存的 D→P RDMA byte transfer
|
||||||
|
- ✅ 单 IB device (mlx5_60)
|
||||||
|
- ✅ 同节点 loopback(127.0.0.1)
|
||||||
|
- ⏳ GPU 内存(设备指针 + `batch_transfer_write_on_cuda`)—— 现有 `push()` 走 `transfer_sync_write`,对 GPU 指针支持取决于 mooncake 的 protocol;下一步验证
|
||||||
|
- ⏳ 跨节点(远端 IP)—— 设计上一致,未验证
|
||||||
|
- ⏳ 多 D → 单 P(多 sender → 共享 recv buffer 的 offset 协调)—— 留给 Phase 3 整合时设计
|
||||||
|
- ⏳ ZeroCopy 入 SGLang kv_pool slot —— 留给 Phase 2
|
||||||
|
|
||||||
|
## 5. 下一步(Phase 2 / Phase 3)
|
||||||
|
|
||||||
|
详见 `docs/D_TO_P_SYNC_DESIGN_ZH.md` §5。本 phase 1 解锁后,整个 D→P 同步可以正式开始整合到 SGLang scheduler:
|
||||||
|
|
||||||
|
| Phase | 描述 | 风险 |
|
||||||
|
|---|---|---|
|
||||||
|
| 2 | D-side commit hook:`cache_finished_req` 完成后 enqueue snapshot push | 中。需要在 scheduler 后台线程跑 push,不能阻塞 schedule loop |
|
||||||
|
| 3 | P-side snapshot store + prefill bypass:P scheduler 收到 use-snapshot 请求时跳过 `model.forward()`,直接用 snapshot KV 触发 P→D' transfer | **最高**。需要深入 SGLang prefill 流程 |
|
||||||
|
| 4 | agentic-pd-hybrid hook:`_invoke_kvcache_seeded_router` 先 probe P → 决定走 bypass 还是 fallback | 低 |
|
||||||
|
| 5 | CLI flag + structural log | 低 |
|
||||||
|
| 6 | 端到端 smoke + E4 sweep | 中 |
|
||||||
|
|
||||||
|
## 6. 知识沉淀
|
||||||
|
|
||||||
|
### 易踩坑
|
||||||
|
|
||||||
|
| 坑 | 原因 | 修法 |
|
||||||
|
|---|---|---|
|
||||||
|
| 多进程 `multiprocessing.Process` 子进程崩溃信息丢失 | spawn context 下 child 没有继承 parent 的 stderr | 改用 `subprocess.Popen` + stderr 重定向到文件 |
|
||||||
|
| `bytes(ctypes.c_byte * N)` 失败 `ValueError: bytes must be in range(0, 256)` | `c_byte` 是 **signed**,>= 128 的 byte 在 Python 看就是负数 | 用 `c_ubyte` 或 `ctypes.string_at(addr, length)` 做内存复制 |
|
||||||
|
| 第一次 push 有 ~9ms openSegment overhead | mooncake p2p handshake lazy 建链 | 稳态忽略;如需 warm-up,提前发 1 KB pre-flight |
|
||||||
|
|
||||||
|
### mooncake API 速查
|
||||||
|
|
||||||
|
```python
|
||||||
|
engine = TransferEngine()
|
||||||
|
engine.initialize(f"{host}:{port}", "P2PHANDSHAKE", "rdma", ib_device)
|
||||||
|
engine.register_memory(ptr, length) # mr 注册
|
||||||
|
engine.transfer_sync_write(peer_session_id, local_ptr, remote_ptr, length) # RDMA write
|
||||||
|
engine.batch_transfer_sync_write(peer_session_id, [local_ptrs], [remote_ptrs], [lengths])
|
||||||
|
engine.unregister_memory(ptr)
|
||||||
|
```
|
||||||
|
|
||||||
|
`peer_session_id` 是 `"host:rpc_port"`,其中 `rpc_port = peer_engine.get_rpc_port()`。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**核心句**:D→P 底层 RDMA 链路独立模块跑通,64 MB 1.7 ms / 316 Gbps,与 SGLang PD pipeline 完全解耦。Phase 2/3 可以放心在这上面叠加。
|
||||||
244
scripts/smoke_snapshot_link.py
Executable file
244
scripts/smoke_snapshot_link.py
Executable file
@@ -0,0 +1,244 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Two-process smoke test for snapshot_link D→P RDMA byte transfer.
|
||||||
|
|
||||||
|
Spawns scripts/snapshot_link_receiver.py via subprocess.Popen with stderr
|
||||||
|
piped to ``<tmpdir>/recv.stderr.log`` for post-mortem if something dies.
|
||||||
|
|
||||||
|
Sender (this process):
|
||||||
|
1. Spawns receiver child, waits for endpoint.json
|
||||||
|
2. Brings up own SnapshotPeer (no recv buffer), registers a send buffer
|
||||||
|
3. For each size: fill pattern, batch_transfer_sync_write, signal child,
|
||||||
|
wait for child's ack
|
||||||
|
4. Reads child's stdout (one JSON event per line) for verification
|
||||||
|
|
||||||
|
Pass = every size yields a child "verify" event with ok=true.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
bash scripts/setup_env.sh && uv run --no-sync python scripts/smoke_snapshot_link.py
|
||||||
|
|
||||||
|
Env (optional):
|
||||||
|
SNAPSHOT_LINK_HOST default 127.0.0.1
|
||||||
|
SNAPSHOT_LINK_IB default mlx5_60
|
||||||
|
SNAPSHOT_LINK_RECV_PORT default 17777
|
||||||
|
SNAPSHOT_LINK_SEND_PORT default 17778
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import ctypes
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_HERE = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(_HERE.parent / "src"))
|
||||||
|
|
||||||
|
|
||||||
|
SIZES_BYTES_DEFAULT = [
|
||||||
|
1 << 10, # 1 KB
|
||||||
|
1 << 14, # 16 KB
|
||||||
|
1 << 18, # 256 KB
|
||||||
|
1 << 20, # 1 MB
|
||||||
|
1 << 22, # 4 MB
|
||||||
|
1 << 24, # 16 MB
|
||||||
|
1 << 26, # 64 MB
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _pattern_byte(i: int, seed: int) -> int:
|
||||||
|
return (i * 2654435761 + seed) & 0xFF
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_pattern(buf, length: int, seed: int) -> None:
|
||||||
|
tile_size = 4096
|
||||||
|
tile = bytes(_pattern_byte(i, seed) for i in range(tile_size))
|
||||||
|
tile_arr = (ctypes.c_ubyte * tile_size).from_buffer_copy(tile)
|
||||||
|
n_full = length // tile_size
|
||||||
|
rem = length - n_full * tile_size
|
||||||
|
base = ctypes.addressof(buf)
|
||||||
|
src_addr = ctypes.addressof(tile_arr)
|
||||||
|
for k in range(n_full):
|
||||||
|
ctypes.memmove(base + k * tile_size, src_addr, tile_size)
|
||||||
|
if rem:
|
||||||
|
ctypes.memmove(base + n_full * tile_size, src_addr, rem)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--host", default=os.environ.get("SNAPSHOT_LINK_HOST", "127.0.0.1"))
|
||||||
|
ap.add_argument("--ib", default=os.environ.get("SNAPSHOT_LINK_IB", "mlx5_60"))
|
||||||
|
ap.add_argument("--recv-port", type=int,
|
||||||
|
default=int(os.environ.get("SNAPSHOT_LINK_RECV_PORT", "17777")))
|
||||||
|
ap.add_argument("--send-port", type=int,
|
||||||
|
default=int(os.environ.get("SNAPSHOT_LINK_SEND_PORT", "17778")))
|
||||||
|
ap.add_argument("--max-bytes", type=int, default=128 * 1024 * 1024)
|
||||||
|
ap.add_argument("--sizes", default=",".join(str(s) for s in SIZES_BYTES_DEFAULT))
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
sizes = [int(s) for s in args.sizes.split(",")]
|
||||||
|
tmpdir = Path(tempfile.mkdtemp(prefix="snapshot_link_smoke_"))
|
||||||
|
control_path = tmpdir / "endpoint.json"
|
||||||
|
recv_stderr_log = tmpdir / "recv.stderr.log"
|
||||||
|
|
||||||
|
recv_cmd = [
|
||||||
|
sys.executable,
|
||||||
|
str(_HERE / "snapshot_link_receiver.py"),
|
||||||
|
"--host", args.host,
|
||||||
|
"--port", str(args.recv_port),
|
||||||
|
"--ib", args.ib,
|
||||||
|
"--max-bytes", str(args.max_bytes),
|
||||||
|
"--control-path", str(control_path),
|
||||||
|
"--sizes", args.sizes,
|
||||||
|
]
|
||||||
|
recv_stderr = open(recv_stderr_log, "w")
|
||||||
|
print(f"[sender] launching receiver: {' '.join(recv_cmd)}", flush=True)
|
||||||
|
print(f"[sender] receiver stderr → {recv_stderr_log}", flush=True)
|
||||||
|
recv_proc = subprocess.Popen(
|
||||||
|
recv_cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=recv_stderr,
|
||||||
|
bufsize=1,
|
||||||
|
universal_newlines=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for endpoint metadata
|
||||||
|
deadline = time.time() + 60.0
|
||||||
|
while time.time() < deadline:
|
||||||
|
if control_path.exists():
|
||||||
|
try:
|
||||||
|
meta = json.loads(control_path.read_text())
|
||||||
|
if meta.get("ready"):
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if recv_proc.poll() is not None:
|
||||||
|
_dump_recv_stderr(recv_stderr_log)
|
||||||
|
print(f"[sender] FAIL: receiver exited early (rc={recv_proc.returncode})")
|
||||||
|
return 1
|
||||||
|
time.sleep(0.1)
|
||||||
|
else:
|
||||||
|
print("[sender] FAIL: timed out waiting for receiver endpoint", flush=True)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
print(f"[sender] receiver endpoint: {meta}", flush=True)
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.snapshot_link import SnapshotPeer, SnapshotEndpoint
|
||||||
|
endpoint = SnapshotEndpoint(
|
||||||
|
session_id=meta["session_id"],
|
||||||
|
base_ptr=int(meta["base_ptr"]),
|
||||||
|
capacity_bytes=int(meta["capacity_bytes"]),
|
||||||
|
)
|
||||||
|
peer = SnapshotPeer(
|
||||||
|
host=args.host,
|
||||||
|
port=args.send_port,
|
||||||
|
ib_device=args.ib,
|
||||||
|
receive_capacity_bytes=0,
|
||||||
|
)
|
||||||
|
send_buf = (ctypes.c_byte * args.max_bytes)()
|
||||||
|
send_addr = ctypes.addressof(send_buf)
|
||||||
|
peer.register_send_buffer(send_addr, args.max_bytes)
|
||||||
|
print(f"[sender] own session_id={peer.session_id}, send_buf @ {hex(send_addr)} ({args.max_bytes} B)", flush=True)
|
||||||
|
|
||||||
|
transfers = []
|
||||||
|
for size in sizes:
|
||||||
|
if size > args.max_bytes:
|
||||||
|
continue
|
||||||
|
seed = int(time.time() * 1e6) & 0xFFFFFFFF
|
||||||
|
_fill_pattern(send_buf, size, seed)
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
ret = peer.push(endpoint, send_addr, 0, size, remote_offset=0)
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
dt_ms = (t1 - t0) * 1000.0
|
||||||
|
gbps = (size * 8.0 / 1e9) / max(t1 - t0, 1e-9)
|
||||||
|
print(f"[sender] push size={size:>10d} ret={ret} "
|
||||||
|
f"dur={dt_ms:>9.3f} ms thru={gbps:>6.3f} Gbps",
|
||||||
|
flush=True)
|
||||||
|
signal_path = control_path.with_suffix(f".do{size}")
|
||||||
|
ack_path = control_path.with_suffix(f".ack{size}")
|
||||||
|
signal_path.write_text(str(seed))
|
||||||
|
ack_deadline = time.time() + 60.0
|
||||||
|
while time.time() < ack_deadline:
|
||||||
|
if ack_path.exists():
|
||||||
|
break
|
||||||
|
if recv_proc.poll() is not None:
|
||||||
|
print(f"[sender] FAIL: receiver died after size={size}", flush=True)
|
||||||
|
_dump_recv_stderr(recv_stderr_log)
|
||||||
|
return 1
|
||||||
|
time.sleep(0.05)
|
||||||
|
transfers.append({
|
||||||
|
"size": size, "ret": ret, "dur_ms": round(dt_ms, 3),
|
||||||
|
"thru_Gbps": round(gbps, 3),
|
||||||
|
"ack": ack_path.exists(),
|
||||||
|
})
|
||||||
|
|
||||||
|
peer.close()
|
||||||
|
|
||||||
|
# Drain child stdout — each line is a JSON event
|
||||||
|
try:
|
||||||
|
recv_proc.wait(timeout=10)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
recv_proc.terminate()
|
||||||
|
recv_proc.wait(timeout=5)
|
||||||
|
|
||||||
|
events = []
|
||||||
|
if recv_proc.stdout is not None:
|
||||||
|
for raw in recv_proc.stdout:
|
||||||
|
raw = raw.strip()
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
events.append(json.loads(raw))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
events.append({"event": "non-json", "raw": raw})
|
||||||
|
|
||||||
|
print("=" * 78)
|
||||||
|
print("[receiver] events:")
|
||||||
|
verify_ok = 0
|
||||||
|
verify_fail = 0
|
||||||
|
for ev in events:
|
||||||
|
print(f" {ev}")
|
||||||
|
if ev.get("event") == "verify":
|
||||||
|
if ev.get("ok"):
|
||||||
|
verify_ok += 1
|
||||||
|
else:
|
||||||
|
verify_fail += 1
|
||||||
|
|
||||||
|
recv_stderr.close()
|
||||||
|
_dump_recv_stderr(recv_stderr_log, header="--- receiver stderr ---")
|
||||||
|
|
||||||
|
overall = "PASS" if verify_fail == 0 and verify_ok == len(transfers) else "FAIL"
|
||||||
|
print("=" * 78)
|
||||||
|
print(f"OVERALL: {overall} verify_ok={verify_ok} verify_fail={verify_fail} "
|
||||||
|
f"transfers={len(transfers)}")
|
||||||
|
return 0 if overall == "PASS" else 1
|
||||||
|
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
recv_proc.terminate()
|
||||||
|
recv_proc.wait(timeout=5)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
recv_proc.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_recv_stderr(path: Path, header: str = "--- receiver stderr (last 40) ---") -> None:
|
||||||
|
try:
|
||||||
|
text = path.read_text()
|
||||||
|
except FileNotFoundError:
|
||||||
|
return
|
||||||
|
print(header, flush=True)
|
||||||
|
for line in text.splitlines()[-40:]:
|
||||||
|
print(f" {line}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
123
scripts/snapshot_link_receiver.py
Normal file
123
scripts/snapshot_link_receiver.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Receiver-side child process for the snapshot_link smoke test.
|
||||||
|
|
||||||
|
Reads CLI args, brings up a SnapshotPeer with a registered recv buffer,
|
||||||
|
writes endpoint metadata to a control file, then loops: wait for size
|
||||||
|
signal, verify recv buffer, write ack.
|
||||||
|
|
||||||
|
Status events are printed as single-line JSON to stdout for parent to
|
||||||
|
parse.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import ctypes
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
|
||||||
|
|
||||||
|
|
||||||
|
def _pattern_byte(i: int, seed: int) -> int:
|
||||||
|
return (i * 2654435761 + seed) & 0xFF
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_pattern(buf, length: int, seed: int) -> None:
|
||||||
|
tile_size = 4096
|
||||||
|
tile = bytes(_pattern_byte(i, seed) for i in range(tile_size))
|
||||||
|
tile_arr = (ctypes.c_ubyte * tile_size).from_buffer_copy(tile)
|
||||||
|
n_full = length // tile_size
|
||||||
|
rem = length - n_full * tile_size
|
||||||
|
base = ctypes.addressof(buf)
|
||||||
|
src_addr = ctypes.addressof(tile_arr)
|
||||||
|
for k in range(n_full):
|
||||||
|
ctypes.memmove(base + k * tile_size, src_addr, tile_size)
|
||||||
|
if rem:
|
||||||
|
ctypes.memmove(base + n_full * tile_size, src_addr, rem)
|
||||||
|
|
||||||
|
|
||||||
|
def _emit(d: dict) -> None:
|
||||||
|
print(json.dumps(d), flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--host", required=True)
|
||||||
|
ap.add_argument("--port", type=int, required=True)
|
||||||
|
ap.add_argument("--ib", required=True)
|
||||||
|
ap.add_argument("--max-bytes", type=int, required=True)
|
||||||
|
ap.add_argument("--control-path", required=True)
|
||||||
|
ap.add_argument("--sizes", required=True, help="comma-separated bytes")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
sizes = [int(s) for s in args.sizes.split(",")]
|
||||||
|
|
||||||
|
from agentic_pd_hybrid.snapshot_link import SnapshotPeer
|
||||||
|
|
||||||
|
try:
|
||||||
|
peer = SnapshotPeer(
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
ib_device=args.ib,
|
||||||
|
receive_capacity_bytes=args.max_bytes,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
_emit({"event": "init-failed", "error": repr(e), "tb": traceback.format_exc()})
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
endpoint = peer.endpoint
|
||||||
|
Path(args.control_path).write_text(json.dumps({
|
||||||
|
"session_id": endpoint.session_id,
|
||||||
|
"base_ptr": endpoint.base_ptr,
|
||||||
|
"capacity_bytes": endpoint.capacity_bytes,
|
||||||
|
"ready": True,
|
||||||
|
}))
|
||||||
|
_emit({"event": "endpoint-ready", "session_id": endpoint.session_id,
|
||||||
|
"base_ptr": endpoint.base_ptr, "capacity": endpoint.capacity_bytes})
|
||||||
|
|
||||||
|
cp = Path(args.control_path)
|
||||||
|
for size in sizes:
|
||||||
|
if size > args.max_bytes:
|
||||||
|
continue
|
||||||
|
signal_path = cp.with_suffix(f".do{size}")
|
||||||
|
ack_path = cp.with_suffix(f".ack{size}")
|
||||||
|
deadline = time.time() + 120.0
|
||||||
|
while time.time() < deadline:
|
||||||
|
if signal_path.exists():
|
||||||
|
break
|
||||||
|
time.sleep(0.05)
|
||||||
|
else:
|
||||||
|
_emit({"event": "no-signal-timeout", "size": size})
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
seed = int(signal_path.read_text().strip())
|
||||||
|
except Exception as e:
|
||||||
|
_emit({"event": "signal-parse-error", "size": size, "err": repr(e)})
|
||||||
|
continue
|
||||||
|
expected_arr = (ctypes.c_ubyte * size)()
|
||||||
|
_fill_pattern(expected_arr, size, seed)
|
||||||
|
expected_hash = hashlib.sha256(bytes(expected_arr)).hexdigest()
|
||||||
|
recv_bytes = peer.read_bytes(0, size)
|
||||||
|
recv_hash = hashlib.sha256(recv_bytes).hexdigest()
|
||||||
|
ok = recv_hash == expected_hash
|
||||||
|
_emit({
|
||||||
|
"event": "verify",
|
||||||
|
"size": size,
|
||||||
|
"ok": ok,
|
||||||
|
"expected_sha": expected_hash[:16],
|
||||||
|
"got_sha": recv_hash[:16],
|
||||||
|
"first8_recv": recv_bytes[:8].hex(),
|
||||||
|
"last8_recv": recv_bytes[-8:].hex(),
|
||||||
|
})
|
||||||
|
ack_path.write_text("done")
|
||||||
|
|
||||||
|
peer.close()
|
||||||
|
_emit({"event": "receiver-done"})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
266
src/agentic_pd_hybrid/snapshot_link.py
Normal file
266
src/agentic_pd_hybrid/snapshot_link.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
"""Minimal D→P snapshot link over Mooncake RDMA.
|
||||||
|
|
||||||
|
This module provides a thin wrapper around mooncake.engine.TransferEngine
|
||||||
|
for one-sided RDMA writes of KV bytes from a Decode worker (sender) to a
|
||||||
|
Prefill worker (receiver). It deliberately does NOT use the heavyweight
|
||||||
|
MooncakeKVManager pipeline (which is tied to PREFILL/DECODE roles and
|
||||||
|
chunked transfer protocols): we want a simple, testable byte transport
|
||||||
|
that can be reused by SGLang and by stand-alone smoke tests.
|
||||||
|
|
||||||
|
Layout:
|
||||||
|
SnapshotPeer — engine + pre-registered receive buffer (receiver)
|
||||||
|
or sender handle (sender)
|
||||||
|
SnapshotEndpoint — what the receiver advertises so the sender can
|
||||||
|
target it: (session_id, base_ptr, length)
|
||||||
|
SnapshotPusher — sender-side: holds a target endpoint, calls
|
||||||
|
batch_transfer_sync_write
|
||||||
|
|
||||||
|
All transfers are SYNCHRONOUS, single-shot, in-memory.
|
||||||
|
|
||||||
|
Higher layers add: control plane (how D learns P's endpoint), per-session
|
||||||
|
slot allocation, KV format/layout, hand-off into SGLang scheduler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SnapshotEndpoint:
|
||||||
|
"""What the receiver advertises so the sender can reach it.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
session_id : str
|
||||||
|
``"host:rpc_port"`` string identifying the receiver's mooncake
|
||||||
|
TransferEngine. Returned by ``TransferEngine.get_rpc_port()``
|
||||||
|
joined with the host the engine was initialized with.
|
||||||
|
base_ptr : int
|
||||||
|
Address of the registered receive buffer on the receiver side.
|
||||||
|
capacity_bytes : int
|
||||||
|
Length of the registered region.
|
||||||
|
"""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
base_ptr: int
|
||||||
|
capacity_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
def _import_transfer_engine():
|
||||||
|
try:
|
||||||
|
from mooncake.engine import TransferEngine
|
||||||
|
except ImportError as e: # pragma: no cover
|
||||||
|
raise ImportError(
|
||||||
|
"mooncake.engine.TransferEngine is required for snapshot_link. "
|
||||||
|
"Make sure mooncake-transfer-engine is installed in the venv."
|
||||||
|
) from e
|
||||||
|
return TransferEngine
|
||||||
|
|
||||||
|
|
||||||
|
class SnapshotPeer:
|
||||||
|
"""One Mooncake transfer engine endpoint with a registered receive buffer.
|
||||||
|
|
||||||
|
The engine is dedicated to snapshot traffic — it does NOT share state
|
||||||
|
with SGLang's MooncakeKVManager engine. Each SnapshotPeer needs its own
|
||||||
|
host:port to listen on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
ib_device: Optional[str] = None,
|
||||||
|
receive_capacity_bytes: int = 0,
|
||||||
|
protocol: Optional[str] = None,
|
||||||
|
):
|
||||||
|
TransferEngine = _import_transfer_engine()
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.ib_device = ib_device
|
||||||
|
self.engine = TransferEngine()
|
||||||
|
|
||||||
|
listen = f"{host}:{port}"
|
||||||
|
proto = protocol or os.environ.get("MOONCAKE_PROTOCOL", "rdma")
|
||||||
|
ret = self.engine.initialize(
|
||||||
|
listen,
|
||||||
|
"P2PHANDSHAKE",
|
||||||
|
proto,
|
||||||
|
ib_device or "",
|
||||||
|
)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"snapshot_link: engine.initialize({listen!r}, proto={proto}, "
|
||||||
|
f"ib={ib_device}) returned {ret}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._rpc_port = self.engine.get_rpc_port()
|
||||||
|
self._session_id = f"{host}:{self._rpc_port}"
|
||||||
|
|
||||||
|
self._recv_buffer = None
|
||||||
|
self._recv_ptr = 0
|
||||||
|
self._recv_capacity = 0
|
||||||
|
if receive_capacity_bytes > 0:
|
||||||
|
self._allocate_recv_buffer(receive_capacity_bytes)
|
||||||
|
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
logger.info(
|
||||||
|
"SnapshotPeer up at %s (rpc=%d, ib=%s, recv=%d B)",
|
||||||
|
self._session_id,
|
||||||
|
self._rpc_port,
|
||||||
|
ib_device,
|
||||||
|
receive_capacity_bytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- accessors ---------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_id(self) -> str:
|
||||||
|
return self._session_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rpc_port(self) -> int:
|
||||||
|
return self._rpc_port
|
||||||
|
|
||||||
|
@property
|
||||||
|
def endpoint(self) -> SnapshotEndpoint:
|
||||||
|
if self._recv_buffer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SnapshotPeer has no receive buffer; pass receive_capacity_bytes > 0"
|
||||||
|
)
|
||||||
|
return SnapshotEndpoint(
|
||||||
|
session_id=self._session_id,
|
||||||
|
base_ptr=self._recv_ptr,
|
||||||
|
capacity_bytes=self._recv_capacity,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- buffer management -------------------------------------------------
|
||||||
|
|
||||||
|
def _allocate_recv_buffer(self, length: int) -> None:
|
||||||
|
"""Allocate + register a pinned host buffer for receiving."""
|
||||||
|
# Use c_ubyte (unsigned) so bytes() conversions of the underlying
|
||||||
|
# storage always yield valid byte values.
|
||||||
|
buf = (ctypes.c_ubyte * length)()
|
||||||
|
addr = ctypes.addressof(buf)
|
||||||
|
ret = self.engine.register_memory(addr, length)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"snapshot_link: register_memory({hex(addr)}, {length}) returned {ret}"
|
||||||
|
)
|
||||||
|
self._recv_buffer = buf
|
||||||
|
self._recv_ptr = addr
|
||||||
|
self._recv_capacity = length
|
||||||
|
|
||||||
|
def read_bytes(self, offset: int, length: int) -> bytes:
|
||||||
|
"""Snapshot the recv buffer at [offset, offset+length) (caller syncs)."""
|
||||||
|
if self._recv_buffer is None:
|
||||||
|
raise RuntimeError("no recv buffer")
|
||||||
|
if offset < 0 or offset + length > self._recv_capacity:
|
||||||
|
raise ValueError(
|
||||||
|
f"read_bytes({offset}, {length}) out of capacity {self._recv_capacity}"
|
||||||
|
)
|
||||||
|
# string_at copies via memcpy and yields a proper bytes object — works
|
||||||
|
# regardless of signed/unsigned underlying storage.
|
||||||
|
return ctypes.string_at(self._recv_ptr + offset, length)
|
||||||
|
|
||||||
|
def register_send_buffer(self, ptr: int, length: int) -> None:
|
||||||
|
"""Register an externally-allocated send buffer for outbound RDMA writes."""
|
||||||
|
with self._lock:
|
||||||
|
ret = self.engine.register_memory(ptr, length)
|
||||||
|
if ret != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"snapshot_link: register send buffer({hex(ptr)}, {length}) returned {ret}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def deregister(self, ptr: int) -> None:
|
||||||
|
with self._lock:
|
||||||
|
try:
|
||||||
|
self.engine.unregister_memory(ptr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# -- transfer ----------------------------------------------------------
|
||||||
|
|
||||||
|
def push(
|
||||||
|
self,
|
||||||
|
target: SnapshotEndpoint,
|
||||||
|
local_ptr: int,
|
||||||
|
local_offset: int,
|
||||||
|
length: int,
|
||||||
|
remote_offset: int = 0,
|
||||||
|
) -> int:
|
||||||
|
"""Synchronously RDMA-write ``length`` bytes from ``local_ptr+local_offset``
|
||||||
|
to ``target.base_ptr+remote_offset`` on the peer identified by
|
||||||
|
``target.session_id``.
|
||||||
|
|
||||||
|
Returns 0 on success, non-zero (or raises) on failure.
|
||||||
|
"""
|
||||||
|
if length <= 0:
|
||||||
|
return 0
|
||||||
|
if remote_offset < 0 or remote_offset + length > target.capacity_bytes:
|
||||||
|
raise ValueError(
|
||||||
|
f"push: remote_offset={remote_offset}, length={length} exceeds "
|
||||||
|
f"target capacity {target.capacity_bytes}"
|
||||||
|
)
|
||||||
|
src = local_ptr + local_offset
|
||||||
|
dst = target.base_ptr + remote_offset
|
||||||
|
try:
|
||||||
|
ret = self.engine.transfer_sync_write(
|
||||||
|
target.session_id, src, dst, length
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("snapshot_link.push transfer_sync_write threw: %s", e)
|
||||||
|
return -1
|
||||||
|
if ret != 0:
|
||||||
|
logger.warning(
|
||||||
|
"snapshot_link.push transfer_sync_write returned %d (src=%s, "
|
||||||
|
"dst=%s/%s, len=%d)",
|
||||||
|
ret,
|
||||||
|
hex(src),
|
||||||
|
target.session_id,
|
||||||
|
hex(dst),
|
||||||
|
length,
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def batch_push(
|
||||||
|
self,
|
||||||
|
target: SnapshotEndpoint,
|
||||||
|
local_addrs: list[int],
|
||||||
|
remote_addrs: list[int],
|
||||||
|
lengths: list[int],
|
||||||
|
) -> int:
|
||||||
|
"""Batched RDMA write (one-shot)."""
|
||||||
|
if not local_addrs:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
|
target.session_id, local_addrs, remote_addrs, lengths
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("snapshot_link.batch_push threw: %s", e)
|
||||||
|
return -1
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Best-effort shutdown — release the receive buffer registration."""
|
||||||
|
if self._recv_ptr:
|
||||||
|
try:
|
||||||
|
self.engine.unregister_memory(self._recv_ptr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._recv_ptr = 0
|
||||||
|
self._recv_capacity = 0
|
||||||
|
self._recv_buffer = None
|
||||||
|
|
||||||
|
|
||||||
|
def make_session_id(host: str, rpc_port: int) -> str:
|
||||||
|
"""Build the ``host:port`` form used as mooncake's session id."""
|
||||||
|
return f"{host}:{rpc_port}"
|
||||||
Reference in New Issue
Block a user