feat(sglang): D→P snapshot link integration — controller + RPC handlers

Phase 2 of the D→P sync feature (Phase 1 in dc4867c verified the
underlying RDMA link in isolation). This commit wires that link into
each SGLang worker's scheduler so D and P can exchange session KV
without going through the PD prefill pipeline.

New module:
  third_party/sglang/python/sglang/srt/disaggregation/snapshot/
    controller.py — SnapshotLinkController owns one mooncake transfer
                    engine per worker, pre-registers all kv_pool layer
                    buffers, and exposes prepare_receive() and
                    push_session_kv() APIs. Receive bookkeeping via
                    a session_id → SnapshotIngestRecord side-table.

Three RPC types added to io_struct.py and full plumbing wired through:
  SnapshotPrepareReceiveReqInput/Output   P-side alloc + return layout
  SnapshotDumpReqInput/Output             D-side read kv_pool + RDMA push
  SnapshotFinalizeIngestReqInput/Output   P-side radix tree insert

Files touched:
  managers/io_struct.py                   3 new ReqInput/ReqOutput pairs
  managers/tokenizer_communicator_mixin.py  3 communicators, 3 awaitables
  managers/scheduler.py                   init controller + 3 handlers
  entrypoints/http_server.py              3 HTTP endpoints under /_snapshot

Activation: set SGLANG_SNAPSHOT_LINK_ENABLE=1 (and
SGLANG_SNAPSHOT_LINK_HOST / _PORT / _IB_DEVICE) per worker. Controller
init is opt-in and defaults off, so production PD pipeline is
untouched.

Subsequent work (Phase 3): agentic-pd-hybrid orchestration in
_invoke_kvcache_seeded_router to call prepare_receive on P, dump on
D-old, finalize_ingest on P, then trigger the existing P→D' transfer
which will now hit P's radix cache (skipping re-prefill).
This commit is contained in:
Claude Code Agent
2026-05-13 08:12:04 +08:00
parent 7216507773
commit 86412bb174
6 changed files with 763 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
"""D→P RDMA snapshot push subsystem.
A minimal, role-symmetric mooncake transport that runs alongside SGLang's
existing PD pipeline. Both D and P workers can both send and receive
snapshots — direction is determined by which kv_pool we read from /
write into.
See ``docs/D_TO_P_SYNC_DESIGN_ZH.md`` for the full design.
"""
from sglang.srt.disaggregation.snapshot.controller import (
SnapshotLinkController,
SnapshotIngestRecord,
SNAPSHOT_LINK_ENABLE_ENV,
SNAPSHOT_LINK_HOST_ENV,
SNAPSHOT_LINK_PORT_ENV,
SNAPSHOT_LINK_IB_DEVICE_ENV,
)
__all__ = [
"SnapshotLinkController",
"SnapshotIngestRecord",
"SNAPSHOT_LINK_ENABLE_ENV",
"SNAPSHOT_LINK_HOST_ENV",
"SNAPSHOT_LINK_PORT_ENV",
"SNAPSHOT_LINK_IB_DEVICE_ENV",
]

View File

@@ -0,0 +1,342 @@
"""SnapshotLinkController — drives D→P RDMA snapshot pushes.
This class owns:
* A dedicated ``mooncake.engine.TransferEngine`` (independent of the PD
pipeline engine).
* Memory-region registrations covering the worker's KV pool layer buffers
(registered once at startup for zero per-snapshot overhead).
* A side-table mapping ``session_id → SnapshotIngestRecord`` for P-side
receivers tracking outstanding ingests until ``snapshot_finalize_ingest``
is called.
Lifecycle:
SGLang scheduler instantiates one of these per worker if
``SGLANG_SNAPSHOT_LINK_ENABLE=1``. The scheduler is responsible for
feeding kv_pool buffer descriptors and the kv_pool_allocator so the
controller can pre-register memory and (on P) allocate slots.
Direction symmetry:
Both D and P workers run identical controllers. Who sends and who
receives is determined by who calls ``push_session_kv`` vs
``prepare_receive`` / ``finalize_ingest`` — there's no PREFILL/DECODE
role baked into the snapshot side.
This is the **vendored** copy living alongside SGLang internals so the
scheduler can import it directly. ``src/agentic_pd_hybrid/snapshot_link.py``
contains the same primitives for stand-alone smoke testing.
"""
from __future__ import annotations
import ctypes
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
logger = logging.getLogger(__name__)
# Env-var names (also exported from package __init__)
SNAPSHOT_LINK_ENABLE_ENV = "SGLANG_SNAPSHOT_LINK_ENABLE"
SNAPSHOT_LINK_HOST_ENV = "SGLANG_SNAPSHOT_LINK_HOST"
SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT"
SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE"
@dataclass
class _LayerBufferDesc:
"""Per-layer KV buffer descriptor on this worker."""
base_ptr: int # data pointer of the layer's full buffer tensor
bytes_per_token: int # head_num * head_dim * dtype.itemsize
capacity_bytes: int # full buffer size in bytes
is_k: bool # True for K-buffer, False for V
@dataclass
class SnapshotIngestRecord:
"""P-side: bookkeeping for an outstanding incoming snapshot."""
session_id: str
slot_indices: List[int] # kv_pool slots reserved for this ingest
num_tokens: int
created_at: float = field(default_factory=time.time)
def _import_transfer_engine():
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"mooncake.engine.TransferEngine is required for the snapshot "
"link. Install mooncake-transfer-engine in the venv."
) from e
return TransferEngine
class SnapshotLinkController:
"""Mooncake engine + per-layer mem registrations + ingest bookkeeping."""
def __init__(
self,
host: str,
port: int,
ib_device: Optional[str],
kv_pool_layer_buffers: List[Tuple[int, int, int, bool]],
token_to_kv_pool_allocator,
protocol: Optional[str] = None,
):
"""
Parameters
----------
host, port : where this worker binds its snapshot engine.
ib_device : preferred IB HCA (e.g. "mlx5_60").
kv_pool_layer_buffers : list of ``(base_ptr, bytes_per_token,
capacity_bytes, is_k)`` tuples. Order should be K-layers
first then V-layers (matches MHATokenToKVPool layout).
token_to_kv_pool_allocator : the worker's allocator. We use
``.alloc(N)`` on the P side to reserve receive slots.
"""
TransferEngine = _import_transfer_engine()
self.host = host
self.port = port
self.ib_device = ib_device
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.layer_buffers: List[_LayerBufferDesc] = [
_LayerBufferDesc(
base_ptr=base, bytes_per_token=btok,
capacity_bytes=cap, is_k=is_k,
)
for (base, btok, cap, is_k) in kv_pool_layer_buffers
]
self.engine = TransferEngine()
proto = protocol or os.environ.get("MOONCAKE_PROTOCOL", "rdma")
listen = f"{host}:{port}"
ret = self.engine.initialize(listen, "P2PHANDSHAKE", proto, ib_device or "")
if ret != 0:
raise RuntimeError(
f"SnapshotLinkController.initialize({listen}, {proto}, "
f"ib={ib_device}) returned {ret}"
)
self._session_id = f"{host}:{self.engine.get_rpc_port()}"
# Register all layer buffers up-front (one-shot)
ptrs = [d.base_ptr for d in self.layer_buffers]
lens = [d.capacity_bytes for d in self.layer_buffers]
try:
reg_ret = self.engine.batch_register_memory(ptrs, lens)
except Exception:
reg_ret = -1
# Fall back to individual register_memory calls
reg_ret = 0
for ptr, length in zip(ptrs, lens):
r = self.engine.register_memory(ptr, length)
if r != 0:
logger.warning(
"SnapshotLinkController register_memory(%s, %d) returned %d",
hex(ptr), length, r,
)
reg_ret = r
if reg_ret != 0:
logger.warning(
"SnapshotLinkController batch_register_memory returned %d "
"(continuing — individual registrations may have succeeded)",
reg_ret,
)
# Receive-side bookkeeping
self._ingest_records: dict[str, SnapshotIngestRecord] = {}
self._lock = threading.Lock()
logger.info(
"SnapshotLinkController up at %s (snapshot_session_id=%s, "
"%d layer buffers registered, ib=%s)",
listen, self._session_id, len(self.layer_buffers), ib_device,
)
# ----- accessors ----------------------------------------------------
@property
def snapshot_session_id(self) -> str:
return self._session_id
@property
def layer_num(self) -> int:
"""Number of K layers (= number of V layers = total / 2)."""
return len(self.layer_buffers) // 2
def get_k_base_ptrs(self) -> List[int]:
return [d.base_ptr for d in self.layer_buffers if d.is_k]
def get_v_base_ptrs(self) -> List[int]:
return [d.base_ptr for d in self.layer_buffers if not d.is_k]
def get_stride_k_bytes(self) -> int:
for d in self.layer_buffers:
if d.is_k:
return d.bytes_per_token
return 0
def get_stride_v_bytes(self) -> int:
for d in self.layer_buffers:
if not d.is_k:
return d.bytes_per_token
return 0
# ----- P-side: prepare to receive ----------------------------------
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
"""Allocate ``num_tokens`` slots in kv_pool for an incoming snapshot.
Returns a record with the slot indices, or ``None`` if no capacity.
Mooncake registration is already in place (whole-buffer registration
at startup), so no per-snapshot register is needed.
"""
try:
indices_tensor = self.token_to_kv_pool_allocator.alloc(num_tokens)
except Exception as e:
logger.exception("SnapshotLinkController.prepare_receive alloc failed: %s", e)
return None
if indices_tensor is None:
return None
try:
slot_indices = [int(x) for x in indices_tensor.tolist()]
except Exception:
# If allocator returns a python list directly
slot_indices = list(map(int, indices_tensor))
record = SnapshotIngestRecord(
session_id=session_id,
slot_indices=slot_indices,
num_tokens=num_tokens,
)
with self._lock:
old = self._ingest_records.pop(session_id, None)
if old is not None:
# Best-effort: free old slots if we're overwriting
try:
self._free_slots(old.slot_indices)
except Exception:
pass
self._ingest_records[session_id] = record
return record
def take_record(self, session_id: str) -> Optional[SnapshotIngestRecord]:
with self._lock:
return self._ingest_records.pop(session_id, None)
def discard_record(self, session_id: str) -> None:
"""Drop a pending ingest (e.g. on timeout or D-side failure)."""
with self._lock:
rec = self._ingest_records.pop(session_id, None)
if rec is not None:
try:
self._free_slots(rec.slot_indices)
except Exception:
pass
def _free_slots(self, slot_indices: List[int]) -> None:
import torch
t = torch.tensor(slot_indices, dtype=torch.int64,
device=self._allocator_device())
try:
self.token_to_kv_pool_allocator.free(t)
except Exception as e:
logger.warning("SnapshotLinkController._free_slots failed: %s", e)
def _allocator_device(self):
# Best-effort: pull device from one of the buffer tensors via the allocator
try:
return self.token_to_kv_pool_allocator.device
except AttributeError:
return "cuda"
# ----- D-side: push session KV --------------------------------------
def push_session_kv(
self,
*,
target_snapshot_session_id: str,
src_slot_indices: List[int],
target_k_base_ptrs: List[int],
target_v_base_ptrs: List[int],
target_slot_indices: List[int],
target_stride_k_bytes: int,
target_stride_v_bytes: int,
) -> Tuple[int, int]:
"""Push the KV bytes at src_slot_indices to the remote slots.
Returns ``(mooncake_return_code, bytes_pushed)``. Caller is
responsible for any post-push handshake (e.g. finalize_ingest RPC).
"""
layer_num = self.layer_num
k_src_bases = self.get_k_base_ptrs()
v_src_bases = self.get_v_base_ptrs()
stride_k = self.get_stride_k_bytes()
stride_v = self.get_stride_v_bytes()
if len(target_k_base_ptrs) != layer_num or len(target_v_base_ptrs) != layer_num:
raise ValueError(
f"target K/V base ptr count {len(target_k_base_ptrs)}/{len(target_v_base_ptrs)} "
f"!= local layer_num {layer_num}"
)
if stride_k != target_stride_k_bytes or stride_v != target_stride_v_bytes:
raise ValueError(
f"stride mismatch: local k={stride_k}, v={stride_v}; "
f"target k={target_stride_k_bytes}, v={target_stride_v_bytes}"
)
if len(src_slot_indices) != len(target_slot_indices):
raise ValueError(
f"slot count mismatch: src={len(src_slot_indices)}, "
f"target={len(target_slot_indices)}"
)
local_addrs: List[int] = []
remote_addrs: List[int] = []
lengths: List[int] = []
# Group contiguous runs on the target side to coalesce ops.
# Simple approach: per (layer, K/V) pair, walk src/target index
# tuples; merge runs where both src and target are sequential.
for layer_id in range(layer_num):
for kv_bases, stride, kv_label in (
(k_src_bases[layer_id], stride_k, "K"),
(v_src_bases[layer_id], stride_v, "V"),
):
src_base = kv_bases
if kv_label == "K":
tgt_base = target_k_base_ptrs[layer_id]
else:
tgt_base = target_v_base_ptrs[layer_id]
run_src_start = run_tgt_start = run_len = None
for s, t in zip(src_slot_indices, target_slot_indices):
if run_src_start is None:
run_src_start, run_tgt_start, run_len = s, t, 1
elif s == run_src_start + run_len and t == run_tgt_start + run_len:
run_len += 1
else:
local_addrs.append(src_base + run_src_start * stride)
remote_addrs.append(tgt_base + run_tgt_start * stride)
lengths.append(run_len * stride)
run_src_start, run_tgt_start, run_len = s, t, 1
if run_src_start is not None:
local_addrs.append(src_base + run_src_start * stride)
remote_addrs.append(tgt_base + run_tgt_start * stride)
lengths.append(run_len * stride)
t0 = time.perf_counter()
try:
ret = self.engine.batch_transfer_sync_write(
target_snapshot_session_id, local_addrs, remote_addrs, lengths
)
except Exception as e:
logger.exception("SnapshotLinkController.push_session_kv threw: %s", e)
return -1, 0
t1 = time.perf_counter()
bytes_pushed = sum(lengths)
logger.info(
"SnapshotLinkController.push_session_kv → %s: %d ops, %d B, "
"ret=%d, %.2f ms",
target_snapshot_session_id, len(lengths), bytes_pushed, ret,
(t1 - t0) * 1000.0,
)
return ret, bytes_pushed

View File

@@ -125,6 +125,9 @@ from sglang.srt.managers.io_struct import (
LoadLoRAAdapterFromTensorsReqInput,
LoadLoRAAdapterReqInput,
DirectAppendAdmissionReqInput,
SnapshotDumpReqInput,
SnapshotFinalizeIngestReqInput,
SnapshotPrepareReceiveReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
PauseGenerationReqInput,
@@ -1295,6 +1298,21 @@ async def admit_direct_append(obj: DirectAppendAdmissionReqInput):
return await _global_state.tokenizer_manager.admit_direct_append(obj)
@app.post("/_snapshot/prepare_receive")
async def snapshot_prepare_receive(obj: SnapshotPrepareReceiveReqInput):
return await _global_state.tokenizer_manager.snapshot_prepare_receive(obj)
@app.post("/_snapshot/dump")
async def snapshot_dump(obj: SnapshotDumpReqInput):
return await _global_state.tokenizer_manager.snapshot_dump(obj)
@app.post("/_snapshot/finalize_ingest")
async def snapshot_finalize_ingest(obj: SnapshotFinalizeIngestReqInput):
return await _global_state.tokenizer_manager.snapshot_finalize_ingest(obj)
@app.api_route("/configure_logging", methods=["GET", "POST"])
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def configure_logging(obj: ConfigureLoggingReq, request: Request):

View File

@@ -1632,6 +1632,97 @@ class HealthCheckOutput(BaseReq):
pass
# ---------------------------------------------------------------------------
# D→P snapshot ingest (Phase 2 of D→P sync feature; see
# docs/D_TO_P_SYNC_DESIGN_ZH.md).
#
# Three-step protocol orchestrated by agentic-pd-hybrid:
# 1. PrepareReceive → P allocates kv_pool slots + returns destination
# addresses for D's RDMA writes.
# 2. (out-of-band) → D uses snapshot_link to RDMA-push KV bytes
# directly to P's slot addresses.
# 3. FinalizeIngest → P inserts (token_ids, kv_indices) into its radix
# tree so subsequent prefill requests for this
# session see a cache hit.
#
# Each step is its own ReqInput/ReqOutput pair so the scheduler handlers can
# be written stateless and the orchestrator can retry / abort cleanly.
# ---------------------------------------------------------------------------
@dataclass
class SnapshotPrepareReceiveReqInput(BaseReq):
"""P-side: allocate slots + register them with mooncake for D to push into."""
session_id: str
num_tokens: int # P will alloc this many contiguous slots
expected_bytes_per_layer_k: int = 0 # per-token K bytes × num_tokens (sanity)
expected_bytes_per_layer_v: int = 0 # per-token V bytes × num_tokens (sanity)
@dataclass
class SnapshotPrepareReceiveReqOutput(BaseReq):
ok: bool
reason: Optional[str] = None
# Layout the D side needs to address P's kv_pool slots:
# k_base_ptrs[L] = base device address of layer L's K buffer
# v_base_ptrs[L] = base device address of layer L's V buffer
# slot_indices = the contiguous range P allocated (list[int], length=num_tokens)
# stride_k_bytes = bytes per token K = head_num * head_dim * dtype.itemsize
# stride_v_bytes = bytes per token V (often equals stride_k_bytes)
# P also registers these slot regions with mooncake before returning.
k_base_ptrs: List[int] = field(default_factory=list)
v_base_ptrs: List[int] = field(default_factory=list)
slot_indices: List[int] = field(default_factory=list)
stride_k_bytes: int = 0
stride_v_bytes: int = 0
layer_num: int = 0
# P's mooncake snapshot session id (host:rpc_port) for D's batch write target
snapshot_session_id: str = ""
available_tokens: int = 0
@dataclass
class SnapshotDumpReqInput(BaseReq):
"""D-side: dump session KV via snapshot_link to a target P endpoint."""
session_id: str
target_snapshot_session_id: str # P's mooncake snapshot session id
target_k_base_ptrs: List[int] = field(default_factory=list)
target_v_base_ptrs: List[int] = field(default_factory=list)
target_slot_indices: List[int] = field(default_factory=list)
target_stride_k_bytes: int = 0
target_stride_v_bytes: int = 0
ib_device: Optional[str] = None # for the D-side SnapshotPeer initialization
@dataclass
class SnapshotDumpReqOutput(BaseReq):
ok: bool
reason: Optional[str] = None
bytes_pushed: int = 0
transfer_duration_ms: float = 0.0
kv_committed_len: int = 0 # the actual number of tokens D had for this session
# The token_ids that go with the KV (so P can call radix_cache.insert)
token_ids: List[int] = field(default_factory=list)
@dataclass
class SnapshotFinalizeIngestReqInput(BaseReq):
"""P-side: insert (token_ids, kv_indices) into radix tree after D's push."""
session_id: str
token_ids: List[int]
slot_indices: List[int]
@dataclass
class SnapshotFinalizeIngestReqOutput(BaseReq):
ok: bool
reason: Optional[str] = None
inserted_prefix_len: int = 0
class ExpertDistributionReqType(Enum):
START_RECORD = 1
STOP_RECORD = 2

View File

@@ -96,6 +96,12 @@ from sglang.srt.managers.io_struct import (
ContinueGenerationReqInput,
DirectAppendAdmissionReqInput,
DirectAppendAdmissionReqOutput,
SnapshotDumpReqInput,
SnapshotDumpReqOutput,
SnapshotFinalizeIngestReqInput,
SnapshotFinalizeIngestReqOutput,
SnapshotPrepareReceiveReqInput,
SnapshotPrepareReceiveReqOutput,
DestroyWeightsUpdateGroupReqInput,
DetachHiCacheStorageReqInput,
DetachHiCacheStorageReqOutput,
@@ -844,6 +850,69 @@ class Scheduler(
embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get()
init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)
# ---- D→P snapshot link (Phase 2 of D→P sync feature) ------------
# Enabled per-worker via SGLANG_SNAPSHOT_LINK_ENABLE=1. Each worker
# binds an independent mooncake transfer engine on
# SGLANG_SNAPSHOT_LINK_HOST:SGLANG_SNAPSHOT_LINK_PORT and pre-
# registers the kv_pool layer buffers for one-shot RDMA pushes /
# receives. See docs/D_TO_P_SYNC_DESIGN_ZH.md.
self.snapshot_link_controller = None
from sglang.srt.disaggregation.snapshot import (
SnapshotLinkController as _SnapLinkCtrl,
SNAPSHOT_LINK_ENABLE_ENV,
SNAPSHOT_LINK_HOST_ENV,
SNAPSHOT_LINK_PORT_ENV,
SNAPSHOT_LINK_IB_DEVICE_ENV,
)
if os.environ.get(SNAPSHOT_LINK_ENABLE_ENV, "0") == "1":
host = os.environ.get(SNAPSHOT_LINK_HOST_ENV, server_args.host)
port = int(os.environ.get(SNAPSHOT_LINK_PORT_ENV,
str(server_args.disaggregation_bootstrap_port + 1000)))
ib = os.environ.get(SNAPSHOT_LINK_IB_DEVICE_ENV, server_args.disaggregation_ib_device)
try:
kv_pool = self.token_to_kv_pool_allocator.get_kvcache()
except AttributeError:
# Some allocators expose the pool directly
kv_pool = getattr(self.token_to_kv_pool_allocator, "kvcache", None)
if kv_pool is None:
logger.warning("SNAPSHOT_LINK_ENABLE=1 but kv_pool unavailable; skipping init")
else:
try:
kv_data_ptrs, kv_data_lens, kv_item_lens = kv_pool.get_contiguous_buf_infos()
layer_n = len(kv_data_ptrs) // 2
layer_buffers = []
# K layers first, then V layers (matches MHATokenToKVPool.get_contiguous_buf_infos)
for i in range(layer_n):
layer_buffers.append((
kv_data_ptrs[i],
kv_item_lens[i] // max(1, kv_pool.page_size),
kv_data_lens[i],
True, # is_k
))
for i in range(layer_n):
layer_buffers.append((
kv_data_ptrs[layer_n + i],
kv_item_lens[layer_n + i] // max(1, kv_pool.page_size),
kv_data_lens[layer_n + i],
False, # is_k=False (V)
))
self.snapshot_link_controller = _SnapLinkCtrl(
host=host,
port=port,
ib_device=ib,
kv_pool_layer_buffers=layer_buffers,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
logger.info(
"Snapshot link controller initialized: %s, sid=%s, %d layer bufs",
f"{host}:{port}",
self.snapshot_link_controller.snapshot_session_id,
len(layer_buffers),
)
except Exception as e:
logger.warning("Snapshot link init failed: %s; continuing without it", e)
self.snapshot_link_controller = None
def init_running_status(self):
self.waiting_queue: List[Req] = []
self.decode_direct_waiting_queue: List[Req] = []
@@ -1219,6 +1288,9 @@ class Scheduler(
(OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session),
(DirectAppendAdmissionReqInput, self.admit_direct_append),
(SnapshotPrepareReceiveReqInput, self.snapshot_prepare_receive),
(SnapshotDumpReqInput, self.snapshot_dump),
(SnapshotFinalizeIngestReqInput, self.snapshot_finalize_ingest),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
@@ -3673,6 +3745,162 @@ class Scheduler(
),
)
# ----- D→P snapshot link handlers (Phase 2/3) ---------------------
def snapshot_prepare_receive(
self, recv_req: SnapshotPrepareReceiveReqInput
) -> SnapshotPrepareReceiveReqOutput:
"""P-side: alloc kv_pool slots, return slot/buffer layout for D's batch_push."""
ctrl = self.snapshot_link_controller
if ctrl is None:
return SnapshotPrepareReceiveReqOutput(
ok=False, reason="snapshot-link-disabled",
)
try:
available = int(self.token_to_kv_pool_allocator.available_size())
except Exception:
available = -1
if recv_req.num_tokens <= 0:
return SnapshotPrepareReceiveReqOutput(ok=False, reason="zero-tokens")
record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens)
if record is None:
return SnapshotPrepareReceiveReqOutput(
ok=False, reason="alloc-failed",
available_tokens=available,
)
return SnapshotPrepareReceiveReqOutput(
ok=True,
k_base_ptrs=ctrl.get_k_base_ptrs(),
v_base_ptrs=ctrl.get_v_base_ptrs(),
slot_indices=record.slot_indices,
stride_k_bytes=ctrl.get_stride_k_bytes(),
stride_v_bytes=ctrl.get_stride_v_bytes(),
layer_num=ctrl.layer_num,
snapshot_session_id=ctrl.snapshot_session_id,
available_tokens=available,
)
def snapshot_dump(
self, recv_req: SnapshotDumpReqInput
) -> SnapshotDumpReqOutput:
"""D-side: gather session KV from kv_pool, RDMA-write to remote slots."""
ctrl = self.snapshot_link_controller
if ctrl is None:
return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled")
if not isinstance(self.tree_cache, SessionAwareCache):
return SnapshotDumpReqOutput(ok=False, reason="tree-cache-not-session-aware")
slot = self.tree_cache.slots.get(recv_req.session_id)
if slot is None or slot.req_pool_idx is None:
return SnapshotDumpReqOutput(ok=False, reason="session-not-resident")
kv_committed_len = int(slot.kv_committed_len)
if kv_committed_len == 0:
return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len")
# Read kv_indices for the session's prefix
try:
kv_idx_tensor = self.req_to_token_pool.req_to_token[
slot.req_pool_idx, :kv_committed_len
]
src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()]
except Exception as e:
logger.exception("snapshot_dump: failed to read kv_indices: %s", e)
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed: {e!r}")
# Truncate to the count P prepared for (must match)
target_n = len(recv_req.target_slot_indices)
if target_n > kv_committed_len:
return SnapshotDumpReqOutput(
ok=False,
reason=f"target-larger-than-source({target_n}>{kv_committed_len})",
)
src_slot_indices = src_slot_indices[:target_n]
try:
ret, bytes_pushed = ctrl.push_session_kv(
target_snapshot_session_id=recv_req.target_snapshot_session_id,
src_slot_indices=src_slot_indices,
target_k_base_ptrs=recv_req.target_k_base_ptrs,
target_v_base_ptrs=recv_req.target_v_base_ptrs,
target_slot_indices=recv_req.target_slot_indices[:target_n],
target_stride_k_bytes=recv_req.target_stride_k_bytes,
target_stride_v_bytes=recv_req.target_stride_v_bytes,
)
except Exception as e:
logger.exception("snapshot_dump: push_session_kv threw: %s", e)
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed: {e!r}")
if ret != 0:
return SnapshotDumpReqOutput(
ok=False,
reason=f"mooncake-batch-write-ret={ret}",
bytes_pushed=int(bytes_pushed),
kv_committed_len=int(kv_committed_len),
)
return SnapshotDumpReqOutput(
ok=True, bytes_pushed=int(bytes_pushed),
kv_committed_len=int(kv_committed_len),
token_ids=[], # caller already has token_ids
)
def snapshot_finalize_ingest(
self, recv_req: SnapshotFinalizeIngestReqInput
) -> SnapshotFinalizeIngestReqOutput:
"""P-side: insert (token_ids, slot_indices) into radix tree."""
ctrl = self.snapshot_link_controller
if ctrl is None:
return SnapshotFinalizeIngestReqOutput(
ok=False, reason="snapshot-link-disabled",
)
record = ctrl.take_record(recv_req.session_id)
if record is None:
return SnapshotFinalizeIngestReqOutput(
ok=False, reason="no-pending-ingest",
)
# Sanity: the slot indices we're about to insert should match the ones we reserved.
if list(recv_req.slot_indices) != record.slot_indices:
# The caller passed back the slot indices we returned in prepare; if they
# don't match, something's gone wrong. Free reserved slots and bail.
try:
ctrl._free_slots(record.slot_indices)
except Exception:
pass
return SnapshotFinalizeIngestReqOutput(
ok=False,
reason="slot-indices-mismatch",
)
n_tokens = min(len(recv_req.token_ids), len(record.slot_indices))
if n_tokens == 0:
ctrl._free_slots(record.slot_indices)
return SnapshotFinalizeIngestReqOutput(ok=False, reason="empty-token-ids")
try:
import torch
from sglang.srt.mem_cache.base_prefix_cache import InsertParams
from sglang.srt.mem_cache.radix_cache import RadixKey
kv_indices = torch.tensor(
record.slot_indices[:n_tokens],
dtype=torch.int64,
device=self.tree_cache.token_to_kv_pool_allocator.device,
)
radix_key = RadixKey(recv_req.token_ids[:n_tokens], None)
inner = (
self.tree_cache.inner
if isinstance(self.tree_cache, SessionAwareCache)
else self.tree_cache
)
result = inner.insert(InsertParams(key=radix_key, value=kv_indices))
inserted = int(result.prefix_len)
except Exception as e:
logger.exception("snapshot_finalize_ingest: radix insert failed: %s", e)
try:
ctrl._free_slots(record.slot_indices)
except Exception:
pass
return SnapshotFinalizeIngestReqOutput(
ok=False, reason=f"radix-insert-failed: {e!r}",
)
return SnapshotFinalizeIngestReqOutput(
ok=True, inserted_prefix_len=inserted,
)
def _compute_backpressure_pause_hint(
self,
*,

View File

@@ -74,6 +74,12 @@ from sglang.srt.managers.io_struct import (
SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
SnapshotDumpReqInput,
SnapshotDumpReqOutput,
SnapshotFinalizeIngestReqInput,
SnapshotFinalizeIngestReqOutput,
SnapshotPrepareReceiveReqInput,
SnapshotPrepareReceiveReqOutput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightsFromDistributedReqInput,
@@ -225,6 +231,15 @@ class TokenizerCommunicatorMixin:
self.direct_append_admission_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.snapshot_prepare_receive_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.snapshot_dump_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.snapshot_finalize_ingest_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -325,6 +340,18 @@ class TokenizerCommunicatorMixin:
DirectAppendAdmissionReqOutput,
self.direct_append_admission_communicator.handle_recv,
),
(
SnapshotPrepareReceiveReqOutput,
self.snapshot_prepare_receive_communicator.handle_recv,
),
(
SnapshotDumpReqOutput,
self.snapshot_dump_communicator.handle_recv,
),
(
SnapshotFinalizeIngestReqOutput,
self.snapshot_finalize_ingest_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
@@ -890,6 +917,36 @@ class TokenizerCommunicatorMixin:
)
return responses[0]
async def snapshot_prepare_receive(
self: TokenizerManager,
obj: SnapshotPrepareReceiveReqInput,
) -> SnapshotPrepareReceiveReqOutput:
self.auto_create_handle_loop()
responses: List[SnapshotPrepareReceiveReqOutput] = (
await self.snapshot_prepare_receive_communicator(obj)
)
return responses[0]
async def snapshot_dump(
self: TokenizerManager,
obj: SnapshotDumpReqInput,
) -> SnapshotDumpReqOutput:
self.auto_create_handle_loop()
responses: List[SnapshotDumpReqOutput] = (
await self.snapshot_dump_communicator(obj)
)
return responses[0]
async def snapshot_finalize_ingest(
self: TokenizerManager,
obj: SnapshotFinalizeIngestReqInput,
) -> SnapshotFinalizeIngestReqOutput:
self.auto_create_handle_loop()
responses: List[SnapshotFinalizeIngestReqOutput] = (
await self.snapshot_finalize_ingest_communicator(obj)
)
return responses[0]
async def set_internal_state(
self: TokenizerManager, obj: SetInternalStateReq
) -> List[bool]: