refactor(snapshot): dedicated GPU snapshot_buf replaces kv_pool alloc
Implements the design in docs/SNAPSHOT_STORE_REFACTOR_ZH.md to fix
the alloc-failed death loop that killed D→P in E4-v4/v5 (167 sync
attempts, 0 OK because P's kv_pool was busy with its own prefill).
Mechanism change:
OLD prepare_receive: token_to_kv_pool_allocator.alloc(N) — 90%+ failure
NEW prepare_receive: SnapshotBufAllocator.alloc(slab_bytes) carves a
range from an 8 GB GPU buffer dedicated to
snapshot reception, decoupled from kv_pool
OLD finalize_ingest: just radix.insert with pre-alloc'd slots
NEW finalize_ingest: kv_pool.alloc NOW + GPU memcpy snapshot_buf →
k_buffer/v_buffer + radix.insert
Wire schema changed (clean break, no back-compat):
PrepareReceiveReqOutput swaps k/v_base_ptrs + slot_indices for
snapshot_buf_base_ptr + k/v_layer_offsets +
num_tokens
DumpReqInput swaps target_k/v_base_ptrs + target_slot_indices
for target_snapshot_buf_base +
target_k/v_layer_offsets
FinalizeIngestReqInput drops slot_indices (P resolves at ingest)
Controller adds:
SnapshotBufAllocator: first-fit free-list with 4 KB alignment
ingest_snapshot_into_kvpool: GPU→GPU copy + radix insert
Configurable buffer size via SGLANG_SNAPSHOT_LINK_BUF_BYTES env
(default 8 GB, scales down to 1 GB if alloc fails).
Removed runtime leak-check accommodation since prepare_receive no
longer touches kv_pool.
Total: ~365 LOC including alloc helper; smoke-test verification next.
This commit is contained in:
@@ -2187,9 +2187,9 @@ async def _attempt_d_to_p_sync(
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"target_snapshot_session_id": prep["snapshot_session_id"],
|
||||
"target_k_base_ptrs": prep["k_base_ptrs"],
|
||||
"target_v_base_ptrs": prep["v_base_ptrs"],
|
||||
"target_slot_indices": prep["slot_indices"],
|
||||
"target_snapshot_buf_base": prep["snapshot_buf_base_ptr"],
|
||||
"target_k_layer_offsets": prep["k_layer_offsets"],
|
||||
"target_v_layer_offsets": prep["v_layer_offsets"],
|
||||
"target_stride_k_bytes": prep["stride_k_bytes"],
|
||||
"target_stride_v_bytes": prep["stride_v_bytes"],
|
||||
},
|
||||
@@ -2220,15 +2220,13 @@ async def _attempt_d_to_p_sync(
|
||||
# for the first N — use that as best-available approximation.
|
||||
tokens = list(getattr(request, "input_token_ids", []) or [])
|
||||
if not tokens:
|
||||
# No token_ids available — can't insert into radix. P will fall back
|
||||
# to normal prefill but will have wasted slots. Discard.
|
||||
# No token_ids → can't insert into radix; tell P to free the slab.
|
||||
try:
|
||||
await client.post(
|
||||
f"{prefill_url}/_snapshot/finalize_ingest",
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"token_ids": [],
|
||||
"slot_indices": prep["slot_indices"],
|
||||
},
|
||||
timeout=15.0,
|
||||
)
|
||||
@@ -2242,7 +2240,7 @@ async def _attempt_d_to_p_sync(
|
||||
)
|
||||
return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
|
||||
|
||||
n = min(len(tokens), len(prep["slot_indices"]))
|
||||
n = min(len(tokens), int(prep.get("num_tokens", 0)))
|
||||
t_fin0 = time.perf_counter()
|
||||
try:
|
||||
fin_resp = await client.post(
|
||||
@@ -2250,7 +2248,6 @@ async def _attempt_d_to_p_sync(
|
||||
json={
|
||||
"session_id": request.session_id,
|
||||
"token_ids": tokens[:n],
|
||||
"slot_indices": prep["slot_indices"][:n],
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
@@ -1,34 +1,29 @@
|
||||
"""SnapshotLinkController — drives D→P RDMA snapshot pushes.
|
||||
"""SnapshotLinkController — D→P RDMA snapshot pushes with dedicated GPU buffer.
|
||||
|
||||
This class owns:
|
||||
* A dedicated ``mooncake.engine.TransferEngine`` (independent of the PD
|
||||
pipeline engine).
|
||||
* Memory-region registrations covering the worker's KV pool layer buffers
|
||||
(registered once at startup for zero per-snapshot overhead).
|
||||
* A side-table mapping ``session_id → SnapshotIngestRecord`` for P-side
|
||||
receivers tracking outstanding ingests until ``snapshot_finalize_ingest``
|
||||
is called.
|
||||
Per `docs/SNAPSHOT_STORE_REFACTOR_ZH.md`, this controller now reserves a
|
||||
dedicated GPU tensor (``snapshot_buf``) for receiving D→P snapshots, instead
|
||||
of competing with the worker's ``token_to_kv_pool_allocator`` at
|
||||
prepare_receive time. The kv_pool alloc is deferred to ``finalize_ingest``
|
||||
when the bytes are already in hand — if that alloc fails we drop the
|
||||
snapshot but RDMA reception itself succeeded.
|
||||
|
||||
Lifecycle:
|
||||
SGLang scheduler instantiates one of these per worker if
|
||||
``SGLANG_SNAPSHOT_LINK_ENABLE=1``. The scheduler is responsible for
|
||||
feeding kv_pool buffer descriptors and the kv_pool_allocator so the
|
||||
controller can pre-register memory and (on P) allocate slots.
|
||||
Layout of the snapshot_buf for one session reception (chosen for
|
||||
mooncake's batch_transfer_sync_write friendliness — every layer maps to
|
||||
a single contiguous slab):
|
||||
|
||||
Direction symmetry:
|
||||
Both D and P workers run identical controllers. Who sends and who
|
||||
receives is determined by who calls ``push_session_kv`` vs
|
||||
``prepare_receive`` / ``finalize_ingest`` — there's no PREFILL/DECODE
|
||||
role baked into the snapshot side.
|
||||
[K_layer_0: num_tokens × stride_k_bytes]
|
||||
[K_layer_1: num_tokens × stride_k_bytes]
|
||||
...
|
||||
[K_layer_L-1]
|
||||
[V_layer_0: num_tokens × stride_v_bytes]
|
||||
...
|
||||
[V_layer_L-1]
|
||||
|
||||
This is the **vendored** copy living alongside SGLang internals so the
|
||||
scheduler can import it directly. ``src/agentic_pd_hybrid/snapshot_link.py``
|
||||
contains the same primitives for stand-alone smoke testing.
|
||||
The buffer is split into multiple such slabs via ``SnapshotBufAllocator``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -44,6 +39,10 @@ SNAPSHOT_LINK_HOST_ENV = "SGLANG_SNAPSHOT_LINK_HOST"
|
||||
SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT"
|
||||
SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE"
|
||||
|
||||
# Default snapshot_buf size: 8 GB. Enough for ~1.5 Qwen3-30B 50k-token sessions.
|
||||
SNAPSHOT_BUF_BYTES_ENV = "SGLANG_SNAPSHOT_LINK_BUF_BYTES"
|
||||
DEFAULT_SNAPSHOT_BUF_BYTES = 8 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LayerBufferDesc:
|
||||
@@ -56,13 +55,75 @@ class _LayerBufferDesc:
|
||||
|
||||
@dataclass
|
||||
class SnapshotIngestRecord:
|
||||
"""P-side: bookkeeping for an outstanding incoming snapshot."""
|
||||
"""P-side bookkeeping for one in-flight snapshot reception."""
|
||||
session_id: str
|
||||
slot_indices: List[int] # kv_pool slots reserved for this ingest
|
||||
slab_offset: int # offset within snapshot_buf
|
||||
slab_size: int # total bytes for this slab
|
||||
num_tokens: int
|
||||
k_layer_offsets: List[int] # absolute byte offsets of K layers in snapshot_buf
|
||||
v_layer_offsets: List[int]
|
||||
per_token_k_bytes: int
|
||||
per_token_v_bytes: int
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class SnapshotBufAllocator:
|
||||
"""First-fit free-list allocator over a single contiguous byte range.
|
||||
|
||||
Tracks gaps in a sorted list. Merges adjacent free regions on free().
|
||||
"""
|
||||
|
||||
def __init__(self, capacity_bytes: int):
|
||||
self.capacity = capacity_bytes
|
||||
# Free regions sorted by offset: [(offset, size), ...]
|
||||
self._free: List[Tuple[int, int]] = [(0, capacity_bytes)]
|
||||
self._lock = threading.Lock()
|
||||
self._inflight: dict[int, int] = {} # offset → size for sanity check
|
||||
|
||||
def alloc(self, size: int) -> Optional[int]:
|
||||
"""Return offset of allocated region, or None if no fit available."""
|
||||
if size <= 0:
|
||||
return None
|
||||
# Page-align allocations to 4 KB for RDMA-friendly alignment.
|
||||
size = (size + 4095) & ~4095
|
||||
with self._lock:
|
||||
for i, (off, sz) in enumerate(self._free):
|
||||
if sz >= size:
|
||||
if sz == size:
|
||||
self._free.pop(i)
|
||||
else:
|
||||
self._free[i] = (off + size, sz - size)
|
||||
self._inflight[off] = size
|
||||
return off
|
||||
return None
|
||||
|
||||
def free(self, offset: int) -> bool:
|
||||
"""Return True if the offset was successfully freed."""
|
||||
with self._lock:
|
||||
size = self._inflight.pop(offset, None)
|
||||
if size is None:
|
||||
return False
|
||||
# Insert sorted and merge adjacents
|
||||
self._free.append((offset, size))
|
||||
self._free.sort()
|
||||
merged: List[Tuple[int, int]] = []
|
||||
for off, sz in self._free:
|
||||
if merged and merged[-1][0] + merged[-1][1] == off:
|
||||
merged[-1] = (merged[-1][0], merged[-1][1] + sz)
|
||||
else:
|
||||
merged.append((off, sz))
|
||||
self._free = merged
|
||||
return True
|
||||
|
||||
def available_bytes(self) -> int:
|
||||
with self._lock:
|
||||
return sum(sz for _, sz in self._free)
|
||||
|
||||
def in_use_bytes(self) -> int:
|
||||
with self._lock:
|
||||
return sum(self._inflight.values())
|
||||
|
||||
|
||||
def _import_transfer_engine():
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
@@ -75,7 +136,13 @@ def _import_transfer_engine():
|
||||
|
||||
|
||||
class SnapshotLinkController:
|
||||
"""Mooncake engine + per-layer mem registrations + ingest bookkeeping."""
|
||||
"""Owns mooncake engine + kv_pool registrations + snapshot_buf + records.
|
||||
|
||||
D-side use: push session KV via ``push_session_to_snapshot_buf``.
|
||||
P-side use: ``prepare_receive`` → caller pushes via RDMA →
|
||||
``ingest_snapshot_into_kvpool`` (does GPU memcpy +
|
||||
radix insert) → ``finalize_record`` (frees the slab).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,24 +151,16 @@ class SnapshotLinkController:
|
||||
ib_device: Optional[str],
|
||||
kv_pool_layer_buffers: List[Tuple[int, int, int, bool]],
|
||||
token_to_kv_pool_allocator,
|
||||
tree_cache=None,
|
||||
protocol: Optional[str] = None,
|
||||
snapshot_buf_bytes: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
host, port : where this worker binds its snapshot engine.
|
||||
ib_device : preferred IB HCA (e.g. "mlx5_60").
|
||||
kv_pool_layer_buffers : list of ``(base_ptr, bytes_per_token,
|
||||
capacity_bytes, is_k)`` tuples. Order should be K-layers
|
||||
first then V-layers (matches MHATokenToKVPool layout).
|
||||
token_to_kv_pool_allocator : the worker's allocator. We use
|
||||
``.alloc(N)`` on the P side to reserve receive slots.
|
||||
"""
|
||||
TransferEngine = _import_transfer_engine()
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.ib_device = ib_device
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.tree_cache = tree_cache
|
||||
self.layer_buffers: List[_LayerBufferDesc] = [
|
||||
_LayerBufferDesc(
|
||||
base_ptr=base, bytes_per_token=btok,
|
||||
@@ -121,38 +180,67 @@ class SnapshotLinkController:
|
||||
)
|
||||
self._session_id = f"{host}:{self.engine.get_rpc_port()}"
|
||||
|
||||
# Register all layer buffers up-front (one-shot)
|
||||
# Register existing kv_pool layer buffers (needed for D-side send and
|
||||
# for P-side ingest copy source = snapshot_buf, destination = kv_pool)
|
||||
ptrs = [d.base_ptr for d in self.layer_buffers]
|
||||
lens = [d.capacity_bytes for d in self.layer_buffers]
|
||||
try:
|
||||
reg_ret = self.engine.batch_register_memory(ptrs, lens)
|
||||
except Exception:
|
||||
reg_ret = -1
|
||||
# Fall back to individual register_memory calls
|
||||
reg_ret = 0
|
||||
for ptr, length in zip(ptrs, lens):
|
||||
r = self.engine.register_memory(ptr, length)
|
||||
if r != 0:
|
||||
logger.warning(
|
||||
"SnapshotLinkController register_memory(%s, %d) returned %d",
|
||||
hex(ptr), length, r,
|
||||
)
|
||||
reg_ret = r
|
||||
if reg_ret != 0:
|
||||
logger.warning(
|
||||
"SnapshotLinkController batch_register_memory returned %d "
|
||||
"(continuing — individual registrations may have succeeded)",
|
||||
reg_ret,
|
||||
"SnapshotLinkController kv_pool batch_register returned %d", reg_ret
|
||||
)
|
||||
|
||||
# Allocate + register the dedicated snapshot reception buffer (P-side)
|
||||
# This decouples reception from kv_pool, avoiding the alloc-failed
|
||||
# death loop that killed E4-v4/v5.
|
||||
import torch
|
||||
|
||||
if snapshot_buf_bytes is None:
|
||||
snapshot_buf_bytes = int(
|
||||
os.environ.get(SNAPSHOT_BUF_BYTES_ENV, DEFAULT_SNAPSHOT_BUF_BYTES)
|
||||
)
|
||||
device = self._allocator_device()
|
||||
try:
|
||||
self.snapshot_buf = torch.zeros(
|
||||
snapshot_buf_bytes, dtype=torch.uint8, device=device,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(
|
||||
"Could not allocate snapshot_buf of %d bytes on %s: %s. "
|
||||
"Falling back to 1 GB.", snapshot_buf_bytes, device, e,
|
||||
)
|
||||
snapshot_buf_bytes = 1024 * 1024 * 1024
|
||||
self.snapshot_buf = torch.zeros(
|
||||
snapshot_buf_bytes, dtype=torch.uint8, device=device,
|
||||
)
|
||||
self._snapshot_buf_bytes = snapshot_buf_bytes
|
||||
self._snapshot_buf_ptr = self.snapshot_buf.data_ptr()
|
||||
ret = self.engine.register_memory(self._snapshot_buf_ptr, snapshot_buf_bytes)
|
||||
if ret != 0:
|
||||
logger.warning(
|
||||
"SnapshotLinkController snapshot_buf register_memory(%s, %d) ret=%d",
|
||||
hex(self._snapshot_buf_ptr), snapshot_buf_bytes, ret,
|
||||
)
|
||||
self.snapshot_buf_alloc = SnapshotBufAllocator(snapshot_buf_bytes)
|
||||
|
||||
# Receive-side bookkeeping
|
||||
self._ingest_records: dict[str, SnapshotIngestRecord] = {}
|
||||
self._records_by_handle: dict[int, SnapshotIngestRecord] = {}
|
||||
self._next_handle = 1
|
||||
self._lock = threading.Lock()
|
||||
|
||||
logger.info(
|
||||
"SnapshotLinkController up at %s (snapshot_session_id=%s, "
|
||||
"%d layer buffers registered, ib=%s)",
|
||||
listen, self._session_id, len(self.layer_buffers), ib_device,
|
||||
"SnapshotLinkController up at %s (sid=%s, %d kv layer bufs, "
|
||||
"snapshot_buf=%.1f GB on %s)",
|
||||
listen, self._session_id, len(self.layer_buffers),
|
||||
snapshot_buf_bytes / 1e9, device,
|
||||
)
|
||||
|
||||
# ----- accessors ----------------------------------------------------
|
||||
@@ -161,9 +249,16 @@ class SnapshotLinkController:
|
||||
def snapshot_session_id(self) -> str:
|
||||
return self._session_id
|
||||
|
||||
@property
|
||||
def snapshot_buf_ptr(self) -> int:
|
||||
return self._snapshot_buf_ptr
|
||||
|
||||
@property
|
||||
def snapshot_buf_bytes(self) -> int:
|
||||
return self._snapshot_buf_bytes
|
||||
|
||||
@property
|
||||
def layer_num(self) -> int:
|
||||
"""Number of K layers (= number of V layers = total / 2)."""
|
||||
return len(self.layer_buffers) // 2
|
||||
|
||||
def get_k_base_ptrs(self) -> List[int]:
|
||||
@@ -184,66 +279,6 @@ class SnapshotLinkController:
|
||||
return d.bytes_per_token
|
||||
return 0
|
||||
|
||||
# ----- P-side: prepare to receive ----------------------------------
|
||||
|
||||
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
|
||||
"""Allocate ``num_tokens`` slots in kv_pool for an incoming snapshot.
|
||||
|
||||
Returns a record with the slot indices, or ``None`` if no capacity.
|
||||
Mooncake registration is already in place (whole-buffer registration
|
||||
at startup), so no per-snapshot register is needed.
|
||||
"""
|
||||
try:
|
||||
indices_tensor = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
except Exception as e:
|
||||
logger.exception("SnapshotLinkController.prepare_receive alloc failed: %s", e)
|
||||
return None
|
||||
if indices_tensor is None:
|
||||
return None
|
||||
try:
|
||||
slot_indices = [int(x) for x in indices_tensor.tolist()]
|
||||
except Exception:
|
||||
# If allocator returns a python list directly
|
||||
slot_indices = list(map(int, indices_tensor))
|
||||
record = SnapshotIngestRecord(
|
||||
session_id=session_id,
|
||||
slot_indices=slot_indices,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
with self._lock:
|
||||
old = self._ingest_records.pop(session_id, None)
|
||||
if old is not None:
|
||||
# Best-effort: free old slots if we're overwriting
|
||||
try:
|
||||
self._free_slots(old.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
self._ingest_records[session_id] = record
|
||||
return record
|
||||
|
||||
def take_record(self, session_id: str) -> Optional[SnapshotIngestRecord]:
|
||||
with self._lock:
|
||||
return self._ingest_records.pop(session_id, None)
|
||||
|
||||
def discard_record(self, session_id: str) -> None:
|
||||
"""Drop a pending ingest (e.g. on timeout or D-side failure)."""
|
||||
with self._lock:
|
||||
rec = self._ingest_records.pop(session_id, None)
|
||||
if rec is not None:
|
||||
try:
|
||||
self._free_slots(rec.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _free_slots(self, slot_indices: List[int]) -> None:
|
||||
import torch
|
||||
t = torch.tensor(slot_indices, dtype=torch.int64,
|
||||
device=self._allocator_device())
|
||||
try:
|
||||
self.token_to_kv_pool_allocator.free(t)
|
||||
except Exception as e:
|
||||
logger.warning("SnapshotLinkController._free_slots failed: %s", e)
|
||||
|
||||
def _allocator_device(self):
|
||||
# Best-effort: pull device from one of the buffer tensors via the allocator
|
||||
try:
|
||||
@@ -251,91 +286,291 @@ class SnapshotLinkController:
|
||||
except AttributeError:
|
||||
return "cuda"
|
||||
|
||||
# ----- D-side: push session KV --------------------------------------
|
||||
# ----- P-side: prepare to receive ----------------------------------
|
||||
|
||||
def push_session_kv(
|
||||
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
|
||||
"""Carve a slab out of snapshot_buf large enough for num_tokens of K+V.
|
||||
|
||||
Returns the record describing the slab layout, or None if snapshot_buf
|
||||
is full. This does NOT touch kv_pool — alloc happens at ingest time.
|
||||
"""
|
||||
if num_tokens <= 0:
|
||||
return None
|
||||
stride_k = self.get_stride_k_bytes()
|
||||
stride_v = self.get_stride_v_bytes()
|
||||
L = self.layer_num
|
||||
slab_bytes = L * num_tokens * stride_k + L * num_tokens * stride_v
|
||||
offset = self.snapshot_buf_alloc.alloc(slab_bytes)
|
||||
if offset is None:
|
||||
logger.info(
|
||||
"prepare_receive: snapshot_buf full (sid=%s n=%d need=%d B available=%d B)",
|
||||
session_id, num_tokens, slab_bytes,
|
||||
self.snapshot_buf_alloc.available_bytes(),
|
||||
)
|
||||
return None
|
||||
# Layout: K0..KL-1, then V0..VL-1
|
||||
k_offs = [offset + i * num_tokens * stride_k for i in range(L)]
|
||||
v_offs = [offset + L * num_tokens * stride_k + i * num_tokens * stride_v
|
||||
for i in range(L)]
|
||||
record = SnapshotIngestRecord(
|
||||
session_id=session_id,
|
||||
slab_offset=offset,
|
||||
slab_size=slab_bytes,
|
||||
num_tokens=num_tokens,
|
||||
k_layer_offsets=k_offs,
|
||||
v_layer_offsets=v_offs,
|
||||
per_token_k_bytes=stride_k,
|
||||
per_token_v_bytes=stride_v,
|
||||
)
|
||||
with self._lock:
|
||||
# Evict prior record for the same session (best-effort)
|
||||
old = self._ingest_records.pop(session_id, None)
|
||||
if old is not None:
|
||||
self.snapshot_buf_alloc.free(old.slab_offset)
|
||||
self._records_by_handle.pop(id(old), None)
|
||||
self._ingest_records[session_id] = record
|
||||
self._records_by_handle[id(record)] = record
|
||||
return record
|
||||
|
||||
def lookup_by_handle(self, handle: int) -> Optional[SnapshotIngestRecord]:
|
||||
with self._lock:
|
||||
return self._records_by_handle.get(handle)
|
||||
|
||||
def discard_record(self, session_id: str) -> None:
|
||||
with self._lock:
|
||||
rec = self._ingest_records.pop(session_id, None)
|
||||
if rec is not None:
|
||||
self.snapshot_buf_alloc.free(rec.slab_offset)
|
||||
with self._lock:
|
||||
self._records_by_handle.pop(id(rec), None)
|
||||
|
||||
def total_pending_snapshot_bytes(self) -> int:
|
||||
with self._lock:
|
||||
return sum(rec.slab_size for rec in self._ingest_records.values())
|
||||
|
||||
# ----- P-side: ingest snapshot into kv_pool + radix tree -----------
|
||||
|
||||
def ingest_snapshot_into_kvpool(
|
||||
self,
|
||||
session_id: str,
|
||||
token_ids: List[int],
|
||||
) -> Tuple[bool, str, int]:
|
||||
"""Copy snapshot_buf bytes into kv_pool slots and insert into radix.
|
||||
|
||||
Returns (ok, reason, inserted_prefix_len).
|
||||
"""
|
||||
with self._lock:
|
||||
record = self._ingest_records.pop(session_id, None)
|
||||
if record is not None:
|
||||
self._records_by_handle.pop(id(record), None)
|
||||
if record is None:
|
||||
return False, "no-pending-ingest", 0
|
||||
|
||||
try:
|
||||
n = min(len(token_ids), record.num_tokens)
|
||||
if n == 0:
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
return False, "empty-token-ids", 0
|
||||
|
||||
# Alloc kv_pool slots NOW that the snapshot bytes are in hand.
|
||||
try:
|
||||
indices_tensor = self.token_to_kv_pool_allocator.alloc(n)
|
||||
except Exception as exc:
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
return False, f"kvpool-alloc-threw:{exc!r}", 0
|
||||
if indices_tensor is None:
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
return False, "kvpool-alloc-failed-at-ingest", 0
|
||||
|
||||
# GPU→GPU copy from snapshot_buf into kv_pool layer buffers
|
||||
try:
|
||||
self._copy_snapshot_to_kvpool(record, indices_tensor)
|
||||
except Exception as exc:
|
||||
logger.exception("snapshot→kvpool copy failed: %s", exc)
|
||||
# Free both allocations
|
||||
self._free_slot_indices(indices_tensor)
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
return False, f"copy-failed:{exc!r}", 0
|
||||
|
||||
# Insert into radix tree
|
||||
try:
|
||||
inserted_prefix_len = self._radix_insert(token_ids[:n], indices_tensor)
|
||||
except Exception as exc:
|
||||
logger.exception("radix insert failed: %s", exc)
|
||||
self._free_slot_indices(indices_tensor)
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
return False, f"radix-insert-failed:{exc!r}", 0
|
||||
|
||||
# Snapshot is now persisted into kv_pool + radix; the slab is no
|
||||
# longer needed.
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
return True, "ok", int(inserted_prefix_len)
|
||||
except Exception as exc:
|
||||
# Belt-and-braces cleanup
|
||||
try:
|
||||
self.snapshot_buf_alloc.free(record.slab_offset)
|
||||
except Exception:
|
||||
pass
|
||||
return False, f"unexpected:{exc!r}", 0
|
||||
|
||||
def _copy_snapshot_to_kvpool(
|
||||
self,
|
||||
record: SnapshotIngestRecord,
|
||||
slot_indices_tensor,
|
||||
) -> None:
|
||||
"""For each layer L: copy snapshot_buf[K_off[L]..] → k_buffer[L][slots]."""
|
||||
import torch
|
||||
|
||||
n = record.num_tokens
|
||||
stride_k = record.per_token_k_bytes
|
||||
stride_v = record.per_token_v_bytes
|
||||
# View snapshot_buf as a 1-D byte tensor; slice by offsets.
|
||||
for L in range(self.layer_num):
|
||||
# K
|
||||
k_slab_start = record.k_layer_offsets[L] - record.slab_offset + record.slab_offset
|
||||
# NOTE: above is equivalent to record.k_layer_offsets[L] but kept for clarity
|
||||
k_slab_start = record.k_layer_offsets[L]
|
||||
k_layer_bytes = self.snapshot_buf[
|
||||
k_slab_start : k_slab_start + n * stride_k
|
||||
].view(n, stride_k)
|
||||
# Compute destination tensor on kv_pool: dst[slot_indices] = src
|
||||
# We need access to the actual k_buffer[L] tensor. The controller
|
||||
# only has the raw ptr — so we materialize a view via from_blob-ish
|
||||
# trick. Easier: get the tensor from token_to_kv_pool_allocator's kvcache.
|
||||
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
|
||||
k_buf = kv_cache.k_buffer[L] # (max_tokens, head, dim)
|
||||
# Flatten per-token to bytes
|
||||
flat = k_buf.view(k_buf.shape[0], -1)
|
||||
assert flat.shape[1] * flat.element_size() >= stride_k, (
|
||||
f"K layer {L} stride mismatch: pool {flat.shape[1] * flat.element_size()} vs snapshot {stride_k}"
|
||||
)
|
||||
# Copy: dst[slot_indices] ← src[:n]
|
||||
src_reshape = k_layer_bytes.view(n, flat.shape[1] * flat.element_size())
|
||||
# Byte-level view of destination rows
|
||||
dst_view = flat.view(torch.uint8)
|
||||
dst_view[slot_indices_tensor] = src_reshape
|
||||
|
||||
# V
|
||||
v_slab_start = record.v_layer_offsets[L]
|
||||
v_layer_bytes = self.snapshot_buf[
|
||||
v_slab_start : v_slab_start + n * stride_v
|
||||
]
|
||||
v_buf = kv_cache.v_buffer[L]
|
||||
v_flat = v_buf.view(v_buf.shape[0], -1)
|
||||
src_v = v_layer_bytes.view(n, v_flat.shape[1] * v_flat.element_size())
|
||||
v_dst_view = v_flat.view(torch.uint8)
|
||||
v_dst_view[slot_indices_tensor] = src_v
|
||||
|
||||
def _radix_insert(self, token_ids: List[int], indices_tensor) -> int:
|
||||
"""Insert (token_ids, kv_indices) into the underlying radix tree."""
|
||||
from sglang.srt.mem_cache.base_prefix_cache import InsertParams
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache
|
||||
|
||||
inner = self.tree_cache
|
||||
if isinstance(inner, SessionAwareCache):
|
||||
inner = inner.inner
|
||||
if inner is None:
|
||||
raise RuntimeError("tree_cache not provided to SnapshotLinkController")
|
||||
radix_key = RadixKey(token_ids, None)
|
||||
result = inner.insert(InsertParams(key=radix_key, value=indices_tensor))
|
||||
return int(getattr(result, "prefix_len", 0))
|
||||
|
||||
def _free_slot_indices(self, indices_tensor) -> None:
|
||||
try:
|
||||
self.token_to_kv_pool_allocator.free(indices_tensor)
|
||||
except Exception as e:
|
||||
logger.warning("_free_slot_indices failed: %s", e)
|
||||
|
||||
# ----- D-side: push session KV to a peer's snapshot_buf ------------
|
||||
|
||||
def push_session_to_snapshot_buf(
|
||||
self,
|
||||
*,
|
||||
target_snapshot_session_id: str,
|
||||
src_slot_indices: List[int],
|
||||
target_k_base_ptrs: List[int],
|
||||
target_v_base_ptrs: List[int],
|
||||
target_slot_indices: List[int],
|
||||
target_stride_k_bytes: int,
|
||||
target_stride_v_bytes: int,
|
||||
target_snapshot_buf_base: int,
|
||||
target_k_layer_offsets: List[int],
|
||||
target_v_layer_offsets: List[int],
|
||||
target_per_token_k_bytes: int,
|
||||
target_per_token_v_bytes: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""Push the KV bytes at src_slot_indices to the remote slots.
|
||||
"""Push session KV from local kv_pool into a peer's snapshot_buf slab.
|
||||
|
||||
Returns ``(mooncake_return_code, bytes_pushed)``. Caller is
|
||||
responsible for any post-push handshake (e.g. finalize_ingest RPC).
|
||||
For each layer: gather src ranges (possibly scattered slot indices)
|
||||
and write to a contiguous range in the peer's snapshot_buf.
|
||||
Returns (mooncake_return_code, bytes_pushed).
|
||||
"""
|
||||
if not src_slot_indices:
|
||||
return 0, 0
|
||||
layer_num = self.layer_num
|
||||
k_src_bases = self.get_k_base_ptrs()
|
||||
v_src_bases = self.get_v_base_ptrs()
|
||||
stride_k = self.get_stride_k_bytes()
|
||||
stride_v = self.get_stride_v_bytes()
|
||||
if len(target_k_base_ptrs) != layer_num or len(target_v_base_ptrs) != layer_num:
|
||||
if (len(target_k_layer_offsets) != layer_num
|
||||
or len(target_v_layer_offsets) != layer_num):
|
||||
raise ValueError(
|
||||
f"target K/V base ptr count {len(target_k_base_ptrs)}/{len(target_v_base_ptrs)} "
|
||||
f"!= local layer_num {layer_num}"
|
||||
f"target K/V layer offset count {len(target_k_layer_offsets)}/"
|
||||
f"{len(target_v_layer_offsets)} != local layer_num {layer_num}"
|
||||
)
|
||||
if stride_k != target_stride_k_bytes or stride_v != target_stride_v_bytes:
|
||||
if (stride_k != target_per_token_k_bytes
|
||||
or stride_v != target_per_token_v_bytes):
|
||||
raise ValueError(
|
||||
f"stride mismatch: local k={stride_k}, v={stride_v}; "
|
||||
f"target k={target_stride_k_bytes}, v={target_stride_v_bytes}"
|
||||
)
|
||||
if len(src_slot_indices) != len(target_slot_indices):
|
||||
raise ValueError(
|
||||
f"slot count mismatch: src={len(src_slot_indices)}, "
|
||||
f"target={len(target_slot_indices)}"
|
||||
f"stride mismatch: local k={stride_k}/v={stride_v}, "
|
||||
f"target k={target_per_token_k_bytes}/v={target_per_token_v_bytes}"
|
||||
)
|
||||
n = len(src_slot_indices)
|
||||
|
||||
local_addrs: List[int] = []
|
||||
remote_addrs: List[int] = []
|
||||
lengths: List[int] = []
|
||||
|
||||
# Group contiguous runs on the target side to coalesce ops.
|
||||
# Simple approach: per (layer, K/V) pair, walk src/target index
|
||||
# tuples; merge runs where both src and target are sequential.
|
||||
for layer_id in range(layer_num):
|
||||
for kv_bases, stride, kv_label in (
|
||||
(k_src_bases[layer_id], stride_k, "K"),
|
||||
(v_src_bases[layer_id], stride_v, "V"),
|
||||
):
|
||||
src_base = kv_bases
|
||||
if kv_label == "K":
|
||||
tgt_base = target_k_base_ptrs[layer_id]
|
||||
else:
|
||||
tgt_base = target_v_base_ptrs[layer_id]
|
||||
# Coalesce contiguous src runs.
|
||||
# Inner-loop helper to walk indices and emit run boundaries.
|
||||
def _emit_runs(src_base: int, tgt_base: int, stride: int) -> None:
|
||||
run_src_start = run_tgt_start = run_len = None
|
||||
for s, t in zip(src_slot_indices, target_slot_indices):
|
||||
for tgt_idx, src in enumerate(src_slot_indices):
|
||||
if run_src_start is None:
|
||||
run_src_start, run_tgt_start, run_len = s, t, 1
|
||||
elif s == run_src_start + run_len and t == run_tgt_start + run_len:
|
||||
run_src_start, run_tgt_start, run_len = src, tgt_idx, 1
|
||||
elif src == run_src_start + run_len:
|
||||
run_len += 1
|
||||
else:
|
||||
local_addrs.append(src_base + run_src_start * stride)
|
||||
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
||||
lengths.append(run_len * stride)
|
||||
run_src_start, run_tgt_start, run_len = s, t, 1
|
||||
run_src_start, run_tgt_start, run_len = src, tgt_idx, 1
|
||||
if run_src_start is not None:
|
||||
local_addrs.append(src_base + run_src_start * stride)
|
||||
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
||||
lengths.append(run_len * stride)
|
||||
|
||||
for L in range(layer_num):
|
||||
_emit_runs(
|
||||
k_src_bases[L],
|
||||
target_snapshot_buf_base + target_k_layer_offsets[L],
|
||||
stride_k,
|
||||
)
|
||||
_emit_runs(
|
||||
v_src_bases[L],
|
||||
target_snapshot_buf_base + target_v_layer_offsets[L],
|
||||
stride_v,
|
||||
)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
target_snapshot_session_id, local_addrs, remote_addrs, lengths
|
||||
target_snapshot_session_id, local_addrs, remote_addrs, lengths,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("SnapshotLinkController.push_session_kv threw: %s", e)
|
||||
logger.exception(
|
||||
"SnapshotLinkController.push_session_to_snapshot_buf threw: %s", e
|
||||
)
|
||||
return -1, 0
|
||||
t1 = time.perf_counter()
|
||||
bytes_pushed = sum(lengths)
|
||||
logger.info(
|
||||
"SnapshotLinkController.push_session_kv → %s: %d ops, %d B, "
|
||||
"ret=%d, %.2f ms",
|
||||
"push_session_to_snapshot_buf → %s: %d ops, %d B, ret=%d, %.2f ms",
|
||||
target_snapshot_session_id, len(lengths), bytes_pushed, ret,
|
||||
(t1 - t0) * 1000.0,
|
||||
)
|
||||
|
||||
@@ -1662,38 +1662,38 @@ class SnapshotPrepareReceiveReqInput(BaseReq):
|
||||
|
||||
@dataclass
|
||||
class SnapshotPrepareReceiveReqOutput(BaseReq):
|
||||
"""P-side response. New schema points D at P's dedicated snapshot_buf."""
|
||||
|
||||
ok: bool
|
||||
reason: Optional[str] = None
|
||||
# Layout the D side needs to address P's kv_pool slots:
|
||||
# k_base_ptrs[L] = base device address of layer L's K buffer
|
||||
# v_base_ptrs[L] = base device address of layer L's V buffer
|
||||
# slot_indices = the contiguous range P allocated (list[int], length=num_tokens)
|
||||
# stride_k_bytes = bytes per token K = head_num * head_dim * dtype.itemsize
|
||||
# stride_v_bytes = bytes per token V (often equals stride_k_bytes)
|
||||
# P also registers these slot regions with mooncake before returning.
|
||||
k_base_ptrs: List[int] = field(default_factory=list)
|
||||
v_base_ptrs: List[int] = field(default_factory=list)
|
||||
slot_indices: List[int] = field(default_factory=list)
|
||||
# P's mooncake snapshot session id (host:rpc_port) for D's batch write target
|
||||
snapshot_session_id: str = ""
|
||||
# snapshot_buf base pointer + per-layer offsets, replacing the old
|
||||
# kv_pool slot_indices scheme that competed with P's prefill work and
|
||||
# always hit alloc-failed. See docs/SNAPSHOT_STORE_REFACTOR_ZH.md.
|
||||
snapshot_buf_base_ptr: int = 0
|
||||
snapshot_buf_capacity_bytes: int = 0
|
||||
k_layer_offsets: List[int] = field(default_factory=list) # bytes within snapshot_buf
|
||||
v_layer_offsets: List[int] = field(default_factory=list)
|
||||
num_tokens: int = 0
|
||||
stride_k_bytes: int = 0
|
||||
stride_v_bytes: int = 0
|
||||
layer_num: int = 0
|
||||
# P's mooncake snapshot session id (host:rpc_port) for D's batch write target
|
||||
snapshot_session_id: str = ""
|
||||
available_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SnapshotDumpReqInput(BaseReq):
|
||||
"""D-side: dump session KV via snapshot_link to a target P endpoint."""
|
||||
"""D-side: dump session KV via snapshot_link into P's snapshot_buf slab."""
|
||||
|
||||
session_id: str
|
||||
target_snapshot_session_id: str # P's mooncake snapshot session id
|
||||
target_k_base_ptrs: List[int] = field(default_factory=list)
|
||||
target_v_base_ptrs: List[int] = field(default_factory=list)
|
||||
target_slot_indices: List[int] = field(default_factory=list)
|
||||
target_snapshot_session_id: str
|
||||
target_snapshot_buf_base: int = 0
|
||||
target_k_layer_offsets: List[int] = field(default_factory=list)
|
||||
target_v_layer_offsets: List[int] = field(default_factory=list)
|
||||
target_stride_k_bytes: int = 0
|
||||
target_stride_v_bytes: int = 0
|
||||
ib_device: Optional[str] = None # for the D-side SnapshotPeer initialization
|
||||
ib_device: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1709,11 +1709,10 @@ class SnapshotDumpReqOutput(BaseReq):
|
||||
|
||||
@dataclass
|
||||
class SnapshotFinalizeIngestReqInput(BaseReq):
|
||||
"""P-side: insert (token_ids, kv_indices) into radix tree after D's push."""
|
||||
"""P-side: copy snapshot_buf slab into kv_pool + insert into radix tree."""
|
||||
|
||||
session_id: str
|
||||
token_ids: List[int]
|
||||
slot_indices: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -902,6 +902,7 @@ class Scheduler(
|
||||
ib_device=ib,
|
||||
kv_pool_layer_buffers=layer_buffers,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
tree_cache=self.tree_cache,
|
||||
)
|
||||
logger.info(
|
||||
"Snapshot link controller initialized: %s, sid=%s, %d layer bufs",
|
||||
@@ -3750,7 +3751,12 @@ class Scheduler(
|
||||
def snapshot_prepare_receive(
|
||||
self, recv_req: SnapshotPrepareReceiveReqInput
|
||||
) -> SnapshotPrepareReceiveReqOutput:
|
||||
"""P-side: alloc kv_pool slots, return slot/buffer layout for D's batch_push."""
|
||||
"""P-side: carve snapshot_buf slab + return its layout to caller.
|
||||
|
||||
Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: this no longer
|
||||
touches the kv_pool allocator. The slab is in a dedicated
|
||||
snapshot_buf so prepare can never lose to P's prefill work.
|
||||
"""
|
||||
ctrl = self.snapshot_link_controller
|
||||
if ctrl is None:
|
||||
return SnapshotPrepareReceiveReqOutput(
|
||||
@@ -3765,25 +3771,27 @@ class Scheduler(
|
||||
record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens)
|
||||
if record is None:
|
||||
return SnapshotPrepareReceiveReqOutput(
|
||||
ok=False, reason="alloc-failed",
|
||||
ok=False, reason="snapshot-buf-full",
|
||||
available_tokens=available,
|
||||
)
|
||||
return SnapshotPrepareReceiveReqOutput(
|
||||
ok=True,
|
||||
k_base_ptrs=ctrl.get_k_base_ptrs(),
|
||||
v_base_ptrs=ctrl.get_v_base_ptrs(),
|
||||
slot_indices=record.slot_indices,
|
||||
stride_k_bytes=ctrl.get_stride_k_bytes(),
|
||||
stride_v_bytes=ctrl.get_stride_v_bytes(),
|
||||
layer_num=ctrl.layer_num,
|
||||
snapshot_session_id=ctrl.snapshot_session_id,
|
||||
snapshot_buf_base_ptr=ctrl.snapshot_buf_ptr,
|
||||
snapshot_buf_capacity_bytes=ctrl.snapshot_buf_bytes,
|
||||
k_layer_offsets=record.k_layer_offsets,
|
||||
v_layer_offsets=record.v_layer_offsets,
|
||||
num_tokens=record.num_tokens,
|
||||
stride_k_bytes=record.per_token_k_bytes,
|
||||
stride_v_bytes=record.per_token_v_bytes,
|
||||
layer_num=ctrl.layer_num,
|
||||
available_tokens=available,
|
||||
)
|
||||
|
||||
def snapshot_dump(
|
||||
self, recv_req: SnapshotDumpReqInput
|
||||
) -> SnapshotDumpReqOutput:
|
||||
"""D-side: gather session KV from kv_pool, RDMA-write to remote slots."""
|
||||
"""D-side: gather session KV from kv_pool, RDMA-write into P's snapshot_buf."""
|
||||
ctrl = self.snapshot_link_controller
|
||||
if ctrl is None:
|
||||
return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled")
|
||||
@@ -3795,110 +3803,60 @@ class Scheduler(
|
||||
kv_committed_len = int(slot.kv_committed_len)
|
||||
if kv_committed_len == 0:
|
||||
return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len")
|
||||
# Read kv_indices for the session's prefix
|
||||
try:
|
||||
kv_idx_tensor = self.req_to_token_pool.req_to_token[
|
||||
slot.req_pool_idx, :kv_committed_len
|
||||
]
|
||||
src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()]
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_dump: failed to read kv_indices: %s", e)
|
||||
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed: {e!r}")
|
||||
|
||||
# Truncate to the count P prepared for (must match)
|
||||
target_n = len(recv_req.target_slot_indices)
|
||||
if target_n > kv_committed_len:
|
||||
return SnapshotDumpReqOutput(
|
||||
ok=False,
|
||||
reason=f"target-larger-than-source({target_n}>{kv_committed_len})",
|
||||
)
|
||||
src_slot_indices = src_slot_indices[:target_n]
|
||||
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed:{e!r}")
|
||||
|
||||
try:
|
||||
ret, bytes_pushed = ctrl.push_session_kv(
|
||||
ret, bytes_pushed = ctrl.push_session_to_snapshot_buf(
|
||||
target_snapshot_session_id=recv_req.target_snapshot_session_id,
|
||||
src_slot_indices=src_slot_indices,
|
||||
target_k_base_ptrs=recv_req.target_k_base_ptrs,
|
||||
target_v_base_ptrs=recv_req.target_v_base_ptrs,
|
||||
target_slot_indices=recv_req.target_slot_indices[:target_n],
|
||||
target_stride_k_bytes=recv_req.target_stride_k_bytes,
|
||||
target_stride_v_bytes=recv_req.target_stride_v_bytes,
|
||||
target_snapshot_buf_base=recv_req.target_snapshot_buf_base,
|
||||
target_k_layer_offsets=recv_req.target_k_layer_offsets,
|
||||
target_v_layer_offsets=recv_req.target_v_layer_offsets,
|
||||
target_per_token_k_bytes=recv_req.target_stride_k_bytes,
|
||||
target_per_token_v_bytes=recv_req.target_stride_v_bytes,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_dump: push_session_kv threw: %s", e)
|
||||
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed: {e!r}")
|
||||
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed:{e!r}")
|
||||
|
||||
if ret != 0:
|
||||
return SnapshotDumpReqOutput(
|
||||
ok=False,
|
||||
reason=f"mooncake-batch-write-ret={ret}",
|
||||
ok=False, reason=f"mooncake-batch-write-ret={ret}",
|
||||
bytes_pushed=int(bytes_pushed),
|
||||
kv_committed_len=int(kv_committed_len),
|
||||
kv_committed_len=kv_committed_len,
|
||||
)
|
||||
return SnapshotDumpReqOutput(
|
||||
ok=True, bytes_pushed=int(bytes_pushed),
|
||||
kv_committed_len=int(kv_committed_len),
|
||||
token_ids=[], # caller already has token_ids
|
||||
kv_committed_len=kv_committed_len,
|
||||
token_ids=[],
|
||||
)
|
||||
|
||||
def snapshot_finalize_ingest(
|
||||
self, recv_req: SnapshotFinalizeIngestReqInput
|
||||
) -> SnapshotFinalizeIngestReqOutput:
|
||||
"""P-side: insert (token_ids, slot_indices) into radix tree."""
|
||||
"""P-side: copy snapshot_buf slab into kv_pool + insert into radix tree.
|
||||
|
||||
Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: kv_pool alloc
|
||||
happens HERE (deferred from prepare_receive), so we never block
|
||||
D's RDMA write on kv_pool contention.
|
||||
"""
|
||||
ctrl = self.snapshot_link_controller
|
||||
if ctrl is None:
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False, reason="snapshot-link-disabled",
|
||||
)
|
||||
record = ctrl.take_record(recv_req.session_id)
|
||||
if record is None:
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False, reason="no-pending-ingest",
|
||||
)
|
||||
# Sanity: the slot indices we're about to insert should match the ones we reserved.
|
||||
if list(recv_req.slot_indices) != record.slot_indices:
|
||||
# The caller passed back the slot indices we returned in prepare; if they
|
||||
# don't match, something's gone wrong. Free reserved slots and bail.
|
||||
try:
|
||||
ctrl._free_slots(record.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False,
|
||||
reason="slot-indices-mismatch",
|
||||
)
|
||||
n_tokens = min(len(recv_req.token_ids), len(record.slot_indices))
|
||||
if n_tokens == 0:
|
||||
ctrl._free_slots(record.slot_indices)
|
||||
return SnapshotFinalizeIngestReqOutput(ok=False, reason="empty-token-ids")
|
||||
try:
|
||||
import torch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import InsertParams
|
||||
from sglang.srt.mem_cache.radix_cache import RadixKey
|
||||
kv_indices = torch.tensor(
|
||||
record.slot_indices[:n_tokens],
|
||||
dtype=torch.int64,
|
||||
device=self.tree_cache.token_to_kv_pool_allocator.device,
|
||||
)
|
||||
radix_key = RadixKey(recv_req.token_ids[:n_tokens], None)
|
||||
inner = (
|
||||
self.tree_cache.inner
|
||||
if isinstance(self.tree_cache, SessionAwareCache)
|
||||
else self.tree_cache
|
||||
)
|
||||
result = inner.insert(InsertParams(key=radix_key, value=kv_indices))
|
||||
inserted = int(result.prefix_len)
|
||||
except Exception as e:
|
||||
logger.exception("snapshot_finalize_ingest: radix insert failed: %s", e)
|
||||
try:
|
||||
ctrl._free_slots(record.slot_indices)
|
||||
except Exception:
|
||||
pass
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=False, reason=f"radix-insert-failed: {e!r}",
|
||||
ok, reason, inserted_prefix_len = ctrl.ingest_snapshot_into_kvpool(
|
||||
session_id=recv_req.session_id,
|
||||
token_ids=list(recv_req.token_ids),
|
||||
)
|
||||
return SnapshotFinalizeIngestReqOutput(
|
||||
ok=True, inserted_prefix_len=inserted,
|
||||
ok=bool(ok), reason=reason if not ok else None,
|
||||
inserted_prefix_len=int(inserted_prefix_len),
|
||||
)
|
||||
|
||||
def _compute_backpressure_pause_hint(
|
||||
|
||||
@@ -181,27 +181,18 @@ class SchedulerRuntimeCheckerMixin:
|
||||
return memory_leak, token_msg
|
||||
|
||||
def _check_radix_cache_memory(self: Scheduler):
|
||||
# NB: as of SnapshotStore refactor (see docs/SNAPSHOT_STORE_REFACTOR_ZH.md)
|
||||
# prepare_receive no longer touches kv_pool — slots are alloc'd from
|
||||
# a dedicated snapshot_buf. So no snapshot_reserved accounting needed.
|
||||
_, _, available_size, evictable_size = self._get_token_info()
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
session_held = self._session_held_tokens()
|
||||
# Snapshot link prepare_receive reserves slots that aren't yet visible
|
||||
# to radix / session bookkeeping until finalize_ingest. Count them so
|
||||
# the leak check doesn't fire while a snapshot ingest is in-flight.
|
||||
snapshot_reserved = 0
|
||||
ctrl = getattr(self, "snapshot_link_controller", None)
|
||||
if ctrl is not None:
|
||||
try:
|
||||
snapshot_reserved = sum(
|
||||
len(rec.slot_indices) for rec in ctrl._ingest_records.values()
|
||||
)
|
||||
except Exception:
|
||||
snapshot_reserved = 0
|
||||
memory_leak = (available_size + evictable_size) != (
|
||||
self.max_total_num_tokens - protected_size - session_held - snapshot_reserved
|
||||
self.max_total_num_tokens - protected_size - session_held
|
||||
)
|
||||
token_msg = (
|
||||
f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, "
|
||||
f"{protected_size=}, {session_held=}, {snapshot_reserved=}\n"
|
||||
f"{protected_size=}, {session_held=}\n"
|
||||
)
|
||||
return memory_leak, token_msg
|
||||
|
||||
|
||||
Reference in New Issue
Block a user