feat(sglang): D→P snapshot link integration — controller + RPC handlers
Phase 2 of the D→P sync feature (Phase 1 in dc4867c verified the
underlying RDMA link in isolation). This commit wires that link into
each SGLang worker's scheduler so D and P can exchange session KV
without going through the PD prefill pipeline.
New module:
third_party/sglang/python/sglang/srt/disaggregation/snapshot/
controller.py — SnapshotLinkController owns one mooncake transfer
engine per worker, pre-registers all kv_pool layer
buffers, and exposes prepare_receive() and
push_session_kv() APIs. Receive bookkeeping via
a session_id → SnapshotIngestRecord side-table.
Three RPC types added to io_struct.py and full plumbing wired through:
SnapshotPrepareReceiveReqInput/Output P-side alloc + return layout
SnapshotDumpReqInput/Output D-side read kv_pool + RDMA push
SnapshotFinalizeIngestReqInput/Output P-side radix tree insert
Files touched:
managers/io_struct.py 3 new ReqInput/ReqOutput pairs
managers/tokenizer_communicator_mixin.py 3 communicators, 3 awaitables
managers/scheduler.py init controller + 3 handlers
entrypoints/http_server.py 3 HTTP endpoints under /_snapshot
Activation: set SGLANG_SNAPSHOT_LINK_ENABLE=1 (and
SGLANG_SNAPSHOT_LINK_HOST / _PORT / _IB_DEVICE) per worker. Controller
init is opt-in and defaults off, so production PD pipeline is
untouched.
Subsequent work (Phase 3): agentic-pd-hybrid orchestration in
_invoke_kvcache_seeded_router to call prepare_receive on P, dump on
D-old, finalize_ingest on P, then trigger the existing P→D' transfer
which will now hit P's radix cache (skipping re-prefill).
This commit is contained in:
27
third_party/sglang/python/sglang/srt/disaggregation/snapshot/__init__.py
vendored
Normal file
27
third_party/sglang/python/sglang/srt/disaggregation/snapshot/__init__.py
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
"""D→P RDMA snapshot push subsystem.
|
||||
|
||||
A minimal, role-symmetric mooncake transport that runs alongside SGLang's
|
||||
existing PD pipeline. Both D and P workers can both send and receive
|
||||
snapshots — direction is determined by which kv_pool we read from /
|
||||
write into.
|
||||
|
||||
See ``docs/D_TO_P_SYNC_DESIGN_ZH.md`` for the full design.
|
||||
"""
|
||||
|
||||
from sglang.srt.disaggregation.snapshot.controller import (
|
||||
SnapshotLinkController,
|
||||
SnapshotIngestRecord,
|
||||
SNAPSHOT_LINK_ENABLE_ENV,
|
||||
SNAPSHOT_LINK_HOST_ENV,
|
||||
SNAPSHOT_LINK_PORT_ENV,
|
||||
SNAPSHOT_LINK_IB_DEVICE_ENV,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SnapshotLinkController",
|
||||
"SnapshotIngestRecord",
|
||||
"SNAPSHOT_LINK_ENABLE_ENV",
|
||||
"SNAPSHOT_LINK_HOST_ENV",
|
||||
"SNAPSHOT_LINK_PORT_ENV",
|
||||
"SNAPSHOT_LINK_IB_DEVICE_ENV",
|
||||
]
|
||||
342
third_party/sglang/python/sglang/srt/disaggregation/snapshot/controller.py
vendored
Normal file
342
third_party/sglang/python/sglang/srt/disaggregation/snapshot/controller.py
vendored
Normal file
@@ -0,0 +1,342 @@
|
||||
"""SnapshotLinkController — drives D→P RDMA snapshot pushes.
|
||||
|
||||
This class owns:
|
||||
* A dedicated ``mooncake.engine.TransferEngine`` (independent of the PD
|
||||
pipeline engine).
|
||||
* Memory-region registrations covering the worker's KV pool layer buffers
|
||||
(registered once at startup for zero per-snapshot overhead).
|
||||
* A side-table mapping ``session_id → SnapshotIngestRecord`` for P-side
|
||||
receivers tracking outstanding ingests until ``snapshot_finalize_ingest``
|
||||
is called.
|
||||
|
||||
Lifecycle:
|
||||
SGLang scheduler instantiates one of these per worker if
|
||||
``SGLANG_SNAPSHOT_LINK_ENABLE=1``. The scheduler is responsible for
|
||||
feeding kv_pool buffer descriptors and the kv_pool_allocator so the
|
||||
controller can pre-register memory and (on P) allocate slots.
|
||||
|
||||
Direction symmetry:
|
||||
Both D and P workers run identical controllers. Who sends and who
|
||||
receives is determined by who calls ``push_session_kv`` vs
|
||||
``prepare_receive`` / ``finalize_ingest`` — there's no PREFILL/DECODE
|
||||
role baked into the snapshot side.
|
||||
|
||||
This is the **vendored** copy living alongside SGLang internals so the
|
||||
scheduler can import it directly. ``src/agentic_pd_hybrid/snapshot_link.py``
|
||||
contains the same primitives for stand-alone smoke testing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Env-var names (also exported from package __init__)
|
||||
SNAPSHOT_LINK_ENABLE_ENV = "SGLANG_SNAPSHOT_LINK_ENABLE"
|
||||
SNAPSHOT_LINK_HOST_ENV = "SGLANG_SNAPSHOT_LINK_HOST"
|
||||
SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT"
|
||||
SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LayerBufferDesc:
|
||||
"""Per-layer KV buffer descriptor on this worker."""
|
||||
base_ptr: int # data pointer of the layer's full buffer tensor
|
||||
bytes_per_token: int # head_num * head_dim * dtype.itemsize
|
||||
capacity_bytes: int # full buffer size in bytes
|
||||
is_k: bool # True for K-buffer, False for V
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotIngestRecord:
|
||||
"""P-side: bookkeeping for an outstanding incoming snapshot."""
|
||||
session_id: str
|
||||
slot_indices: List[int] # kv_pool slots reserved for this ingest
|
||||
num_tokens: int
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
def _import_transfer_engine():
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"mooncake.engine.TransferEngine is required for the snapshot "
|
||||
"link. Install mooncake-transfer-engine in the venv."
|
||||
) from e
|
||||
return TransferEngine
|
||||
|
||||
|
||||
class SnapshotLinkController:
|
||||
"""Mooncake engine + per-layer mem registrations + ingest bookkeeping."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
ib_device: Optional[str],
|
||||
kv_pool_layer_buffers: List[Tuple[int, int, int, bool]],
|
||||
token_to_kv_pool_allocator,
|
||||
protocol: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
host, port : where this worker binds its snapshot engine.
|
||||
ib_device : preferred IB HCA (e.g. "mlx5_60").
|
||||
kv_pool_layer_buffers : list of ``(base_ptr, bytes_per_token,
|
||||
capacity_bytes, is_k)`` tuples. Order should be K-layers
|
||||
first then V-layers (matches MHATokenToKVPool layout).
|
||||
token_to_kv_pool_allocator : the worker's allocator. We use
|
||||
``.alloc(N)`` on the P side to reserve receive slots.
|
||||
"""
|
||||
TransferEngine = _import_transfer_engine()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.ib_device = ib_device
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.layer_buffers: List[_LayerBufferDesc] = [
|
||||
_LayerBufferDesc(
|
||||
base_ptr=base, bytes_per_token=btok,
|
||||
capacity_bytes=cap, is_k=is_k,
|
||||
)
|
||||
for (base, btok, cap, is_k) in kv_pool_layer_buffers
|
||||
]
|
||||
|
||||
self.engine = TransferEngine()
|
||||
proto = protocol or os.environ.get("MOONCAKE_PROTOCOL", "rdma")
|
||||
listen = f"{host}:{port}"
|
||||
ret = self.engine.initialize(listen, "P2PHANDSHAKE", proto, ib_device or "")
|
||||
if ret != 0:
|
||||
raise RuntimeError(
|
||||
f"SnapshotLinkController.initialize({listen}, {proto}, "
|
||||
f"ib={ib_device}) returned {ret}"
|
||||
)
|
||||
self._session_id = f"{host}:{self.engine.get_rpc_port()}"
|
||||
|
||||
# Register all layer buffers up-front (one-shot)
|
||||
ptrs = [d.base_ptr for d in self.layer_buffers]
|
||||
lens = [d.capacity_bytes for d in self.layer_buffers]
|
||||
try:
|
||||
reg_ret = self.engine.batch_register_memory(ptrs, lens)
|
||||
except Exception:
|
||||
reg_ret = -1
|
||||
# Fall back to individual register_memory calls
|
||||
reg_ret = 0
|
||||
for ptr, length in zip(ptrs, lens):
|
||||
r = self.engine.register_memory(ptr, length)
|
||||
if r != 0:
|
||||
logger.warning(
|
||||
"SnapshotLinkController register_memory(%s, %d) returned %d",
|
||||
hex(ptr), length, r,
|
||||
)
|
||||
reg_ret = r
|
||||
if reg_ret != 0:
|
||||
logger.warning(
|
||||
"SnapshotLinkController batch_register_memory returned %d "
|
||||
"(continuing — individual registrations may have succeeded)",
|
||||
reg_ret,
|
||||
)
|
||||
|
||||
# Receive-side bookkeeping
|
||||
self._ingest_records: dict[str, SnapshotIngestRecord] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
logger.info(
|
||||
"SnapshotLinkController up at %s (snapshot_session_id=%s, "
|
||||
"%d layer buffers registered, ib=%s)",
|
||||
listen, self._session_id, len(self.layer_buffers), ib_device,
|
||||
)
|
||||
|
||||
# ----- accessors ----------------------------------------------------
|
||||
|
||||
@property
|
||||
def snapshot_session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
@property
|
||||
def layer_num(self) -> int:
|
||||
"""Number of K layers (= number of V layers = total / 2)."""
|
||||
return len(self.layer_buffers) // 2
|
||||
|
||||
def get_k_base_ptrs(self) -> List[int]:
|
||||
return [d.base_ptr for d in self.layer_buffers if d.is_k]
|
||||
|
||||
def get_v_base_ptrs(self) -> List[int]:
|
||||
return [d.base_ptr for d in self.layer_buffers if not d.is_k]
|
||||
|
||||
def get_stride_k_bytes(self) -> int:
|
||||
for d in self.layer_buffers:
|
||||
if d.is_k:
|
||||
return d.bytes_per_token
|
||||
return 0
|
||||
|
||||
def get_stride_v_bytes(self) -> int:
|
||||
for d in self.layer_buffers:
|
||||
if not d.is_k:
|
||||
return d.bytes_per_token
|
||||
return 0
|
||||
|
||||
# ----- P-side: prepare to receive ----------------------------------
|
||||
|
||||
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
|
||||
"""Allocate ``num_tokens`` slots in kv_pool for an incoming snapshot.
|
||||
|
||||
Returns a record with the slot indices, or ``None`` if no capacity.
|
||||
Mooncake registration is already in place (whole-buffer registration
|
||||
at startup), so no per-snapshot register is needed.
|
||||
"""
|
||||
try:
|
||||
indices_tensor = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
except Exception as e:
|
||||
logger.exception("SnapshotLinkController.prepare_receive alloc failed: %s", e)
|
||||
return None
|
||||
if indices_tensor is None:
|
||||
return None
|
||||
try:
|
||||
slot_indices = [int(x) for x in indices_tensor.tolist()]
|
||||
except Exception:
|
||||
# If allocator returns a python list directly
|
||||
slot_indices = list(map(int, indices_tensor))
|
||||
record = SnapshotIngestRecord(
|
||||
session_id=session_id,
|
||||
slot_indices=slot_indices,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
with self._lock:
|
||||
old = self._ingest_records.pop(session_id, None)
|
||||
if old is not None:
|
||||
# Best-effort: free old slots if we're overwriting
|
||||
try:
|
||||
self._free_slots(old.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
self._ingest_records[session_id] = record
|
||||
return record
|
||||
|
||||
def take_record(self, session_id: str) -> Optional[SnapshotIngestRecord]:
|
||||
with self._lock:
|
||||
return self._ingest_records.pop(session_id, None)
|
||||
|
||||
def discard_record(self, session_id: str) -> None:
|
||||
"""Drop a pending ingest (e.g. on timeout or D-side failure)."""
|
||||
with self._lock:
|
||||
rec = self._ingest_records.pop(session_id, None)
|
||||
if rec is not None:
|
||||
try:
|
||||
self._free_slots(rec.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _free_slots(self, slot_indices: List[int]) -> None:
|
||||
import torch
|
||||
t = torch.tensor(slot_indices, dtype=torch.int64,
|
||||
device=self._allocator_device())
|
||||
try:
|
||||
self.token_to_kv_pool_allocator.free(t)
|
||||
except Exception as e:
|
||||
logger.warning("SnapshotLinkController._free_slots failed: %s", e)
|
||||
|
||||
def _allocator_device(self):
|
||||
# Best-effort: pull device from one of the buffer tensors via the allocator
|
||||
try:
|
||||
return self.token_to_kv_pool_allocator.device
|
||||
except AttributeError:
|
||||
return "cuda"
|
||||
|
||||
# ----- D-side: push session KV --------------------------------------
|
||||
|
||||
def push_session_kv(
|
||||
self,
|
||||
*,
|
||||
target_snapshot_session_id: str,
|
||||
src_slot_indices: List[int],
|
||||
target_k_base_ptrs: List[int],
|
||||
target_v_base_ptrs: List[int],
|
||||
target_slot_indices: List[int],
|
||||
target_stride_k_bytes: int,
|
||||
target_stride_v_bytes: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""Push the KV bytes at src_slot_indices to the remote slots.
|
||||
|
||||
Returns ``(mooncake_return_code, bytes_pushed)``. Caller is
|
||||
responsible for any post-push handshake (e.g. finalize_ingest RPC).
|
||||
"""
|
||||
layer_num = self.layer_num
|
||||
k_src_bases = self.get_k_base_ptrs()
|
||||
v_src_bases = self.get_v_base_ptrs()
|
||||
stride_k = self.get_stride_k_bytes()
|
||||
stride_v = self.get_stride_v_bytes()
|
||||
if len(target_k_base_ptrs) != layer_num or len(target_v_base_ptrs) != layer_num:
|
||||
raise ValueError(
|
||||
f"target K/V base ptr count {len(target_k_base_ptrs)}/{len(target_v_base_ptrs)} "
|
||||
f"!= local layer_num {layer_num}"
|
||||
)
|
||||
if stride_k != target_stride_k_bytes or stride_v != target_stride_v_bytes:
|
||||
raise ValueError(
|
||||
f"stride mismatch: local k={stride_k}, v={stride_v}; "
|
||||
f"target k={target_stride_k_bytes}, v={target_stride_v_bytes}"
|
||||
)
|
||||
if len(src_slot_indices) != len(target_slot_indices):
|
||||
raise ValueError(
|
||||
f"slot count mismatch: src={len(src_slot_indices)}, "
|
||||
f"target={len(target_slot_indices)}"
|
||||
)
|
||||
|
||||
local_addrs: List[int] = []
|
||||
remote_addrs: List[int] = []
|
||||
lengths: List[int] = []
|
||||
|
||||
# Group contiguous runs on the target side to coalesce ops.
|
||||
# Simple approach: per (layer, K/V) pair, walk src/target index
|
||||
# tuples; merge runs where both src and target are sequential.
|
||||
for layer_id in range(layer_num):
|
||||
for kv_bases, stride, kv_label in (
|
||||
(k_src_bases[layer_id], stride_k, "K"),
|
||||
(v_src_bases[layer_id], stride_v, "V"),
|
||||
):
|
||||
src_base = kv_bases
|
||||
if kv_label == "K":
|
||||
tgt_base = target_k_base_ptrs[layer_id]
|
||||
else:
|
||||
tgt_base = target_v_base_ptrs[layer_id]
|
||||
run_src_start = run_tgt_start = run_len = None
|
||||
for s, t in zip(src_slot_indices, target_slot_indices):
|
||||
if run_src_start is None:
|
||||
run_src_start, run_tgt_start, run_len = s, t, 1
|
||||
elif s == run_src_start + run_len and t == run_tgt_start + run_len:
|
||||
run_len += 1
|
||||
else:
|
||||
local_addrs.append(src_base + run_src_start * stride)
|
||||
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
||||
lengths.append(run_len * stride)
|
||||
run_src_start, run_tgt_start, run_len = s, t, 1
|
||||
if run_src_start is not None:
|
||||
local_addrs.append(src_base + run_src_start * stride)
|
||||
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
||||
lengths.append(run_len * stride)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
target_snapshot_session_id, local_addrs, remote_addrs, lengths
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("SnapshotLinkController.push_session_kv threw: %s", e)
|
||||
return -1, 0
|
||||
t1 = time.perf_counter()
|
||||
bytes_pushed = sum(lengths)
|
||||
logger.info(
|
||||
"SnapshotLinkController.push_session_kv → %s: %d ops, %d B, "
|
||||
"ret=%d, %.2f ms",
|
||||
target_snapshot_session_id, len(lengths), bytes_pushed, ret,
|
||||
(t1 - t0) * 1000.0,
|
||||
)
|
||||
return ret, bytes_pushed
|
||||
@@ -125,6 +125,9 @@ from sglang.srt.managers.io_struct import (
|
||||
LoadLoRAAdapterFromTensorsReqInput,
|
||||
LoadLoRAAdapterReqInput,
|
||||
DirectAppendAdmissionReqInput,
|
||||
SnapshotDumpReqInput,
|
||||
SnapshotFinalizeIngestReqInput,
|
||||
SnapshotPrepareReceiveReqInput,
|
||||
OpenSessionReqInput,
|
||||
ParseFunctionCallReq,
|
||||
PauseGenerationReqInput,
|
||||
@@ -1295,6 +1298,21 @@ async def admit_direct_append(obj: DirectAppendAdmissionReqInput):
|
||||
return await _global_state.tokenizer_manager.admit_direct_append(obj)
|
||||
|
||||
|
||||
@app.post("/_snapshot/prepare_receive")
|
||||
async def snapshot_prepare_receive(obj: SnapshotPrepareReceiveReqInput):
|
||||
return await _global_state.tokenizer_manager.snapshot_prepare_receive(obj)
|
||||
|
||||
|
||||
@app.post("/_snapshot/dump")
|
||||
async def snapshot_dump(obj: SnapshotDumpReqInput):
|
||||
return await _global_state.tokenizer_manager.snapshot_dump(obj)
|
||||
|
||||
|
||||
@app.post("/_snapshot/finalize_ingest")
|
||||
async def snapshot_finalize_ingest(obj: SnapshotFinalizeIngestReqInput):
|
||||
return await _global_state.tokenizer_manager.snapshot_finalize_ingest(obj)
|
||||
|
||||
|
||||
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||
@auth_level(AuthLevel.ADMIN_OPTIONAL)
|
||||
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||
|
||||
@@ -1632,6 +1632,97 @@ class HealthCheckOutput(BaseReq):
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# D→P snapshot ingest (Phase 2 of D→P sync feature; see
|
||||
# docs/D_TO_P_SYNC_DESIGN_ZH.md).
|
||||
#
|
||||
# Three-step protocol orchestrated by agentic-pd-hybrid:
|
||||
# 1. PrepareReceive → P allocates kv_pool slots + returns destination
|
||||
# addresses for D's RDMA writes.
|
||||
# 2. (out-of-band) → D uses snapshot_link to RDMA-push KV bytes
|
||||
# directly to P's slot addresses.
|
||||
# 3. FinalizeIngest → P inserts (token_ids, kv_indices) into its radix
|
||||
# tree so subsequent prefill requests for this
|
||||
# session see a cache hit.
|
||||
#
|
||||
# Each step is its own ReqInput/ReqOutput pair so the scheduler handlers can
|
||||
# be written stateless and the orchestrator can retry / abort cleanly.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotPrepareReceiveReqInput(BaseReq):
|
||||
"""P-side: allocate slots + register them with mooncake for D to push into."""
|
||||
|
||||
session_id: str
|
||||
num_tokens: int # P will alloc this many contiguous slots
|
||||
expected_bytes_per_layer_k: int = 0 # per-token K bytes × num_tokens (sanity)
|
||||
expected_bytes_per_layer_v: int = 0 # per-token V bytes × num_tokens (sanity)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotPrepareReceiveReqOutput(BaseReq):
|
||||
ok: bool
|
||||
reason: Optional[str] = None
|
||||
# Layout the D side needs to address P's kv_pool slots:
|
||||
# k_base_ptrs[L] = base device address of layer L's K buffer
|
||||
# v_base_ptrs[L] = base device address of layer L's V buffer
|
||||
# slot_indices = the contiguous range P allocated (list[int], length=num_tokens)
|
||||
# stride_k_bytes = bytes per token K = head_num * head_dim * dtype.itemsize
|
||||
# stride_v_bytes = bytes per token V (often equals stride_k_bytes)
|
||||
# P also registers these slot regions with mooncake before returning.
|
||||
k_base_ptrs: List[int] = field(default_factory=list)
|
||||
v_base_ptrs: List[int] = field(default_factory=list)
|
||||
slot_indices: List[int] = field(default_factory=list)
|
||||
stride_k_bytes: int = 0
|
||||
stride_v_bytes: int = 0
|
||||
layer_num: int = 0
|
||||
# P's mooncake snapshot session id (host:rpc_port) for D's batch write target
|
||||
snapshot_session_id: str = ""
|
||||
available_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotDumpReqInput(BaseReq):
|
||||
"""D-side: dump session KV via snapshot_link to a target P endpoint."""
|
||||
|
||||
session_id: str
|
||||
target_snapshot_session_id: str # P's mooncake snapshot session id
|
||||
target_k_base_ptrs: List[int] = field(default_factory=list)
|
||||
target_v_base_ptrs: List[int] = field(default_factory=list)
|
||||
target_slot_indices: List[int] = field(default_factory=list)
|
||||
target_stride_k_bytes: int = 0
|
||||
target_stride_v_bytes: int = 0
|
||||
ib_device: Optional[str] = None # for the D-side SnapshotPeer initialization
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotDumpReqOutput(BaseReq):
|
||||
ok: bool
|
||||
reason: Optional[str] = None
|
||||
bytes_pushed: int = 0
|
||||
transfer_duration_ms: float = 0.0
|
||||
kv_committed_len: int = 0 # the actual number of tokens D had for this session
|
||||
# The token_ids that go with the KV (so P can call radix_cache.insert)
|
||||
token_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotFinalizeIngestReqInput(BaseReq):
|
||||
"""P-side: insert (token_ids, kv_indices) into radix tree after D's push."""
|
||||
|
||||
session_id: str
|
||||
token_ids: List[int]
|
||||
slot_indices: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotFinalizeIngestReqOutput(BaseReq):
|
||||
ok: bool
|
||||
reason: Optional[str] = None
|
||||
inserted_prefix_len: int = 0
|
||||
|
||||
|
||||
class ExpertDistributionReqType(Enum):
|
||||
START_RECORD = 1
|
||||
STOP_RECORD = 2
|
||||
|
||||
@@ -96,6 +96,12 @@ from sglang.srt.managers.io_struct import (
|
||||
ContinueGenerationReqInput,
|
||||
DirectAppendAdmissionReqInput,
|
||||
DirectAppendAdmissionReqOutput,
|
||||
SnapshotDumpReqInput,
|
||||
SnapshotDumpReqOutput,
|
||||
SnapshotFinalizeIngestReqInput,
|
||||
SnapshotFinalizeIngestReqOutput,
|
||||
SnapshotPrepareReceiveReqInput,
|
||||
SnapshotPrepareReceiveReqOutput,
|
||||
DestroyWeightsUpdateGroupReqInput,
|
||||
DetachHiCacheStorageReqInput,
|
||||
DetachHiCacheStorageReqOutput,
|
||||
@@ -844,6 +850,69 @@ class Scheduler(
|
||||
embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get()
|
||||
init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)
|
||||
|
||||
# ---- D→P snapshot link (Phase 2 of D→P sync feature) ------------
|
||||
# Enabled per-worker via SGLANG_SNAPSHOT_LINK_ENABLE=1. Each worker
|
||||
# binds an independent mooncake transfer engine on
|
||||
# SGLANG_SNAPSHOT_LINK_HOST:SGLANG_SNAPSHOT_LINK_PORT and pre-
|
||||
# registers the kv_pool layer buffers for one-shot RDMA pushes /
|
||||
# receives. See docs/D_TO_P_SYNC_DESIGN_ZH.md.
|
||||
self.snapshot_link_controller = None
|
||||
from sglang.srt.disaggregation.snapshot import (
|
||||
SnapshotLinkController as _SnapLinkCtrl,
|
||||
SNAPSHOT_LINK_ENABLE_ENV,
|
||||
SNAPSHOT_LINK_HOST_ENV,
|
||||
SNAPSHOT_LINK_PORT_ENV,
|
||||
SNAPSHOT_LINK_IB_DEVICE_ENV,
|
||||
)
|
||||
if os.environ.get(SNAPSHOT_LINK_ENABLE_ENV, "0") == "1":
|
||||
host = os.environ.get(SNAPSHOT_LINK_HOST_ENV, server_args.host)
|
||||
port = int(os.environ.get(SNAPSHOT_LINK_PORT_ENV,
|
||||
str(server_args.disaggregation_bootstrap_port + 1000)))
|
||||
ib = os.environ.get(SNAPSHOT_LINK_IB_DEVICE_ENV, server_args.disaggregation_ib_device)
|
||||
try:
|
||||
kv_pool = self.token_to_kv_pool_allocator.get_kvcache()
|
||||
except AttributeError:
|
||||
# Some allocators expose the pool directly
|
||||
kv_pool = getattr(self.token_to_kv_pool_allocator, "kvcache", None)
|
||||
if kv_pool is None:
|
||||
logger.warning("SNAPSHOT_LINK_ENABLE=1 but kv_pool unavailable; skipping init")
|
||||
else:
|
||||
try:
|
||||
kv_data_ptrs, kv_data_lens, kv_item_lens = kv_pool.get_contiguous_buf_infos()
|
||||
layer_n = len(kv_data_ptrs) // 2
|
||||
layer_buffers = []
|
||||
# K layers first, then V layers (matches MHATokenToKVPool.get_contiguous_buf_infos)
|
||||
for i in range(layer_n):
|
||||
layer_buffers.append((
|
||||
kv_data_ptrs[i],
|
||||
kv_item_lens[i] // max(1, kv_pool.page_size),
|
||||
kv_data_lens[i],
|
||||
True, # is_k
|
||||
))
|
||||
for i in range(layer_n):
|
||||
layer_buffers.append((
|
||||
kv_data_ptrs[layer_n + i],
|
||||
kv_item_lens[layer_n + i] // max(1, kv_pool.page_size),
|
||||
kv_data_lens[layer_n + i],
|
||||
False, # is_k=False (V)
|
||||
))
|
||||
self.snapshot_link_controller = _SnapLinkCtrl(
|
||||
host=host,
|
||||
port=port,
|
||||
ib_device=ib,
|
||||
kv_pool_layer_buffers=layer_buffers,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
logger.info(
|
||||
"Snapshot link controller initialized: %s, sid=%s, %d layer bufs",
|
||||
f"{host}:{port}",
|
||||
self.snapshot_link_controller.snapshot_session_id,
|
||||
len(layer_buffers),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Snapshot link init failed: %s; continuing without it", e)
|
||||
self.snapshot_link_controller = None
|
||||
|
||||
def init_running_status(self):
|
||||
self.waiting_queue: List[Req] = []
|
||||
self.decode_direct_waiting_queue: List[Req] = []
|
||||
@@ -1219,6 +1288,9 @@ class Scheduler(
|
||||
(OpenSessionReqInput, self.open_session),
|
||||
(CloseSessionReqInput, self.close_session),
|
||||
(DirectAppendAdmissionReqInput, self.admit_direct_append),
|
||||
(SnapshotPrepareReceiveReqInput, self.snapshot_prepare_receive),
|
||||
(SnapshotDumpReqInput, self.snapshot_dump),
|
||||
(SnapshotFinalizeIngestReqInput, self.snapshot_finalize_ingest),
|
||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
|
||||
@@ -3673,6 +3745,162 @@ class Scheduler(
|
||||
),
|
||||
)
|
||||
|
||||
# ----- D→P snapshot link handlers (Phase 2/3) ---------------------
|
||||
|
||||
def snapshot_prepare_receive(
|
||||
self, recv_req: SnapshotPrepareReceiveReqInput
|
||||
) -> SnapshotPrepareReceiveReqOutput:
|
||||
"""P-side: alloc kv_pool slots, return slot/buffer layout for D's batch_push."""
|
||||
ctrl = self.snapshot_link_controller
|
||||
if ctrl is None:
|
||||
return SnapshotPrepareReceiveReqOutput(
|
||||
ok=False, reason="snapshot-link-disabled",
|
||||
)
|
||||
try:
|
||||
available = int(self.token_to_kv_pool_allocator.available_size())
|
||||
except Exception:
|
||||
available = -1
|
||||
if recv_req.num_tokens <= 0:
|
||||
return SnapshotPrepareReceiveReqOutput(ok=False, reason="zero-tokens")
|
||||
record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens)
|
||||
if record is None:
|
||||
return SnapshotPrepareReceiveReqOutput(
|
||||
ok=False, reason="alloc-failed",
|
||||
available_tokens=available,
|
||||
)
|
||||
return SnapshotPrepareReceiveReqOutput(
|
||||
ok=True,
|
||||
k_base_ptrs=ctrl.get_k_base_ptrs(),
|
||||
v_base_ptrs=ctrl.get_v_base_ptrs(),
|
||||
slot_indices=record.slot_indices,
|
||||
stride_k_bytes=ctrl.get_stride_k_bytes(),
|
||||
stride_v_bytes=ctrl.get_stride_v_bytes(),
|
||||
layer_num=ctrl.layer_num,
|
||||
snapshot_session_id=ctrl.snapshot_session_id,
|
||||
available_tokens=available,
|
||||
)
|
||||
|
||||
def snapshot_dump(
|
||||
self, recv_req: SnapshotDumpReqInput
|
||||
) -> SnapshotDumpReqOutput:
|
||||
"""D-side: gather session KV from kv_pool, RDMA-write to remote slots."""
|
||||
ctrl = self.snapshot_link_controller
|
||||
if ctrl is None:
|
||||
return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled")
|
||||
if not isinstance(self.tree_cache, SessionAwareCache):
|
||||
return SnapshotDumpReqOutput(ok=False, reason="tree-cache-not-session-aware")
|
||||
slot = self.tree_cache.slots.get(recv_req.session_id)
|
||||
if slot is None or slot.req_pool_idx is None:
|
||||
return SnapshotDumpReqOutput(ok=False, reason="session-not-resident")
|
||||
kv_committed_len = int(slot.kv_committed_len)
|
||||
if kv_committed_len == 0:
|
||||
return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len")
|
||||
# Read kv_indices for the session's prefix
|
||||
try:
|
||||
kv_idx_tensor = self.req_to_token_pool.req_to_token[
|
||||
slot.req_pool_idx, :kv_committed_len
|
||||
]
|
||||
src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()]
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_dump: failed to read kv_indices: %s", e)
|
||||
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed: {e!r}")
|
||||
|
||||
# Truncate to the count P prepared for (must match)
|
||||
target_n = len(recv_req.target_slot_indices)
|
||||
if target_n > kv_committed_len:
|
||||
return SnapshotDumpReqOutput(
|
||||
ok=False,
|
||||
reason=f"target-larger-than-source({target_n}>{kv_committed_len})",
|
||||
)
|
||||
src_slot_indices = src_slot_indices[:target_n]
|
||||
|
||||
try:
|
||||
ret, bytes_pushed = ctrl.push_session_kv(
|
||||
target_snapshot_session_id=recv_req.target_snapshot_session_id,
|
||||
src_slot_indices=src_slot_indices,
|
||||
target_k_base_ptrs=recv_req.target_k_base_ptrs,
|
||||
target_v_base_ptrs=recv_req.target_v_base_ptrs,
|
||||
target_slot_indices=recv_req.target_slot_indices[:target_n],
|
||||
target_stride_k_bytes=recv_req.target_stride_k_bytes,
|
||||
target_stride_v_bytes=recv_req.target_stride_v_bytes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_dump: push_session_kv threw: %s", e)
|
||||
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed: {e!r}")
|
||||
|
||||
if ret != 0:
|
||||
return SnapshotDumpReqOutput(
|
||||
ok=False,
|
||||
reason=f"mooncake-batch-write-ret={ret}",
|
||||
bytes_pushed=int(bytes_pushed),
|
||||
kv_committed_len=int(kv_committed_len),
|
||||
)
|
||||
return SnapshotDumpReqOutput(
|
||||
ok=True, bytes_pushed=int(bytes_pushed),
|
||||
kv_committed_len=int(kv_committed_len),
|
||||
token_ids=[], # caller already has token_ids
|
||||
)
|
||||
|
||||
def snapshot_finalize_ingest(
|
||||
self, recv_req: SnapshotFinalizeIngestReqInput
|
||||
) -> SnapshotFinalizeIngestReqOutput:
|
||||
"""P-side: insert (token_ids, slot_indices) into radix tree."""
|
||||
ctrl = self.snapshot_link_controller
|
||||
if ctrl is None:
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False, reason="snapshot-link-disabled",
|
||||
)
|
||||
record = ctrl.take_record(recv_req.session_id)
|
||||
if record is None:
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False, reason="no-pending-ingest",
|
||||
)
|
||||
# Sanity: the slot indices we're about to insert should match the ones we reserved.
|
||||
if list(recv_req.slot_indices) != record.slot_indices:
|
||||
# The caller passed back the slot indices we returned in prepare; if they
|
||||
# don't match, something's gone wrong. Free reserved slots and bail.
|
||||
try:
|
||||
ctrl._free_slots(record.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False,
|
||||
reason="slot-indices-mismatch",
|
||||
)
|
||||
n_tokens = min(len(recv_req.token_ids), len(record.slot_indices))
|
||||
if n_tokens == 0:
|
||||
ctrl._free_slots(record.slot_indices)
|
||||
return SnapshotFinalizeIngestReqOutput(ok=False, reason="empty-token-ids")
|
||||
try:
|
||||
import torch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import InsertParams
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
kv_indices = torch.tensor(
|
||||
record.slot_indices[:n_tokens],
|
||||
dtype=torch.int64,
|
||||
device=self.tree_cache.token_to_kv_pool_allocator.device,
|
||||
)
|
||||
radix_key = RadixKey(recv_req.token_ids[:n_tokens], None)
|
||||
inner = (
|
||||
self.tree_cache.inner
|
||||
if isinstance(self.tree_cache, SessionAwareCache)
|
||||
else self.tree_cache
|
||||
)
|
||||
result = inner.insert(InsertParams(key=radix_key, value=kv_indices))
|
||||
inserted = int(result.prefix_len)
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_finalize_ingest: radix insert failed: %s", e)
|
||||
try:
|
||||
ctrl._free_slots(record.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False, reason=f"radix-insert-failed: {e!r}",
|
||||
)
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=True, inserted_prefix_len=inserted,
|
||||
)
|
||||
|
||||
def _compute_backpressure_pause_hint(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -74,6 +74,12 @@ from sglang.srt.managers.io_struct import (
|
||||
SetInternalStateReqOutput,
|
||||
SlowDownReqInput,
|
||||
SlowDownReqOutput,
|
||||
SnapshotDumpReqInput,
|
||||
SnapshotDumpReqOutput,
|
||||
SnapshotFinalizeIngestReqInput,
|
||||
SnapshotFinalizeIngestReqOutput,
|
||||
SnapshotPrepareReceiveReqInput,
|
||||
SnapshotPrepareReceiveReqOutput,
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
@@ -225,6 +231,15 @@ class TokenizerCommunicatorMixin:
|
||||
self.direct_append_admission_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.snapshot_prepare_receive_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.snapshot_dump_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.snapshot_finalize_ingest_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.set_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -325,6 +340,18 @@ class TokenizerCommunicatorMixin:
|
||||
DirectAppendAdmissionReqOutput,
|
||||
self.direct_append_admission_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SnapshotPrepareReceiveReqOutput,
|
||||
self.snapshot_prepare_receive_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SnapshotDumpReqOutput,
|
||||
self.snapshot_dump_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SnapshotFinalizeIngestReqOutput,
|
||||
self.snapshot_finalize_ingest_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SetInternalStateReqOutput,
|
||||
self.set_internal_state_communicator.handle_recv,
|
||||
@@ -890,6 +917,36 @@ class TokenizerCommunicatorMixin:
|
||||
)
|
||||
return responses[0]
|
||||
|
||||
async def snapshot_prepare_receive(
|
||||
self: TokenizerManager,
|
||||
obj: SnapshotPrepareReceiveReqInput,
|
||||
) -> SnapshotPrepareReceiveReqOutput:
|
||||
self.auto_create_handle_loop()
|
||||
responses: List[SnapshotPrepareReceiveReqOutput] = (
|
||||
await self.snapshot_prepare_receive_communicator(obj)
|
||||
)
|
||||
return responses[0]
|
||||
|
||||
async def snapshot_dump(
|
||||
self: TokenizerManager,
|
||||
obj: SnapshotDumpReqInput,
|
||||
) -> SnapshotDumpReqOutput:
|
||||
self.auto_create_handle_loop()
|
||||
responses: List[SnapshotDumpReqOutput] = (
|
||||
await self.snapshot_dump_communicator(obj)
|
||||
)
|
||||
return responses[0]
|
||||
|
||||
async def snapshot_finalize_ingest(
|
||||
self: TokenizerManager,
|
||||
obj: SnapshotFinalizeIngestReqInput,
|
||||
) -> SnapshotFinalizeIngestReqOutput:
|
||||
self.auto_create_handle_loop()
|
||||
responses: List[SnapshotFinalizeIngestReqOutput] = (
|
||||
await self.snapshot_finalize_ingest_communicator(obj)
|
||||
)
|
||||
return responses[0]
|
||||
|
||||
async def set_internal_state(
|
||||
self: TokenizerManager, obj: SetInternalStateReq
|
||||
) -> List[bool]:
|
||||
|
||||
Reference in New Issue
Block a user