refactor(snapshot): dedicated GPU snapshot_buf replaces kv_pool alloc
Implements the design in docs/SNAPSHOT_STORE_REFACTOR_ZH.md to fix
the alloc-failed death loop that killed D→P in E4-v4/v5 (167 sync
attempts, 0 OK because P's kv_pool was busy with its own prefill).
Mechanism change:
OLD prepare_receive: token_to_kv_pool_allocator.alloc(N) — 90%+ failure
NEW prepare_receive: SnapshotBufAllocator.alloc(slab_bytes) carves a
range from an 8 GB GPU buffer dedicated to
snapshot reception, decoupled from kv_pool
OLD finalize_ingest: just radix.insert with pre-alloc'd slots
NEW finalize_ingest: kv_pool.alloc NOW + GPU memcpy snapshot_buf →
k_buffer/v_buffer + radix.insert
Wire schema changed (clean break, no back-compat):
PrepareReceiveReqOutput swaps k/v_base_ptrs + slot_indices for
snapshot_buf_base_ptr + k/v_layer_offsets +
num_tokens
DumpReqInput swaps target_k/v_base_ptrs + target_slot_indices
for target_snapshot_buf_base +
target_k/v_layer_offsets
FinalizeIngestReqInput drops slot_indices (P resolves at ingest)
Controller adds:
SnapshotBufAllocator: first-fit free-list with 4 KB alignment
ingest_snapshot_into_kvpool: GPU→GPU copy + radix insert
Buffer size is configurable via the SGLANG_SNAPSHOT_LINK_BUF_BYTES env var
(default 8 GiB; falls back to a fixed 1 GiB if the initial allocation fails).
Removed runtime leak-check accommodation since prepare_receive no
longer touches kv_pool.
Total: ~365 LOC including alloc helper; smoke-test verification next.
This commit is contained in:
@@ -2187,9 +2187,9 @@ async def _attempt_d_to_p_sync(
|
|||||||
json={
|
json={
|
||||||
"session_id": request.session_id,
|
"session_id": request.session_id,
|
||||||
"target_snapshot_session_id": prep["snapshot_session_id"],
|
"target_snapshot_session_id": prep["snapshot_session_id"],
|
||||||
"target_k_base_ptrs": prep["k_base_ptrs"],
|
"target_snapshot_buf_base": prep["snapshot_buf_base_ptr"],
|
||||||
"target_v_base_ptrs": prep["v_base_ptrs"],
|
"target_k_layer_offsets": prep["k_layer_offsets"],
|
||||||
"target_slot_indices": prep["slot_indices"],
|
"target_v_layer_offsets": prep["v_layer_offsets"],
|
||||||
"target_stride_k_bytes": prep["stride_k_bytes"],
|
"target_stride_k_bytes": prep["stride_k_bytes"],
|
||||||
"target_stride_v_bytes": prep["stride_v_bytes"],
|
"target_stride_v_bytes": prep["stride_v_bytes"],
|
||||||
},
|
},
|
||||||
@@ -2220,15 +2220,13 @@ async def _attempt_d_to_p_sync(
|
|||||||
# for the first N — use that as best-available approximation.
|
# for the first N — use that as best-available approximation.
|
||||||
tokens = list(getattr(request, "input_token_ids", []) or [])
|
tokens = list(getattr(request, "input_token_ids", []) or [])
|
||||||
if not tokens:
|
if not tokens:
|
||||||
# No token_ids available — can't insert into radix. P will fall back
|
# No token_ids → can't insert into radix; tell P to free the slab.
|
||||||
# to normal prefill but will have wasted slots. Discard.
|
|
||||||
try:
|
try:
|
||||||
await client.post(
|
await client.post(
|
||||||
f"{prefill_url}/_snapshot/finalize_ingest",
|
f"{prefill_url}/_snapshot/finalize_ingest",
|
||||||
json={
|
json={
|
||||||
"session_id": request.session_id,
|
"session_id": request.session_id,
|
||||||
"token_ids": [],
|
"token_ids": [],
|
||||||
"slot_indices": prep["slot_indices"],
|
|
||||||
},
|
},
|
||||||
timeout=15.0,
|
timeout=15.0,
|
||||||
)
|
)
|
||||||
@@ -2242,7 +2240,7 @@ async def _attempt_d_to_p_sync(
|
|||||||
)
|
)
|
||||||
return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
|
return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
|
||||||
|
|
||||||
n = min(len(tokens), len(prep["slot_indices"]))
|
n = min(len(tokens), int(prep.get("num_tokens", 0)))
|
||||||
t_fin0 = time.perf_counter()
|
t_fin0 = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
fin_resp = await client.post(
|
fin_resp = await client.post(
|
||||||
@@ -2250,7 +2248,6 @@ async def _attempt_d_to_p_sync(
|
|||||||
json={
|
json={
|
||||||
"session_id": request.session_id,
|
"session_id": request.session_id,
|
||||||
"token_ids": tokens[:n],
|
"token_ids": tokens[:n],
|
||||||
"slot_indices": prep["slot_indices"][:n],
|
|
||||||
},
|
},
|
||||||
timeout=30.0,
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,34 +1,29 @@
|
|||||||
"""SnapshotLinkController — drives D→P RDMA snapshot pushes.
|
"""SnapshotLinkController — D→P RDMA snapshot pushes with dedicated GPU buffer.
|
||||||
|
|
||||||
This class owns:
|
Per `docs/SNAPSHOT_STORE_REFACTOR_ZH.md`, this controller now reserves a
|
||||||
* A dedicated ``mooncake.engine.TransferEngine`` (independent of the PD
|
dedicated GPU tensor (``snapshot_buf``) for receiving D→P snapshots, instead
|
||||||
pipeline engine).
|
of competing with the worker's ``token_to_kv_pool_allocator`` at
|
||||||
* Memory-region registrations covering the worker's KV pool layer buffers
|
prepare_receive time. The kv_pool alloc is deferred to ``finalize_ingest``
|
||||||
(registered once at startup for zero per-snapshot overhead).
|
when the bytes are already in hand — if that alloc fails we drop the
|
||||||
* A side-table mapping ``session_id → SnapshotIngestRecord`` for P-side
|
snapshot but RDMA reception itself succeeded.
|
||||||
receivers tracking outstanding ingests until ``snapshot_finalize_ingest``
|
|
||||||
is called.
|
|
||||||
|
|
||||||
Lifecycle:
|
Layout of the snapshot_buf for one session reception (chosen for
|
||||||
SGLang scheduler instantiates one of these per worker if
|
mooncake's batch_transfer_sync_write friendliness — every layer maps to
|
||||||
``SGLANG_SNAPSHOT_LINK_ENABLE=1``. The scheduler is responsible for
|
a single contiguous slab):
|
||||||
feeding kv_pool buffer descriptors and the kv_pool_allocator so the
|
|
||||||
controller can pre-register memory and (on P) allocate slots.
|
|
||||||
|
|
||||||
Direction symmetry:
|
[K_layer_0: num_tokens × stride_k_bytes]
|
||||||
Both D and P workers run identical controllers. Who sends and who
|
[K_layer_1: num_tokens × stride_k_bytes]
|
||||||
receives is determined by who calls ``push_session_kv`` vs
|
...
|
||||||
``prepare_receive`` / ``finalize_ingest`` — there's no PREFILL/DECODE
|
[K_layer_L-1]
|
||||||
role baked into the snapshot side.
|
[V_layer_0: num_tokens × stride_v_bytes]
|
||||||
|
...
|
||||||
|
[V_layer_L-1]
|
||||||
|
|
||||||
This is the **vendored** copy living alongside SGLang internals so the
|
The buffer is split into multiple such slabs via ``SnapshotBufAllocator``.
|
||||||
scheduler can import it directly. ``src/agentic_pd_hybrid/snapshot_link.py``
|
|
||||||
contains the same primitives for stand-alone smoke testing.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import ctypes
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
@@ -44,6 +39,10 @@ SNAPSHOT_LINK_HOST_ENV = "SGLANG_SNAPSHOT_LINK_HOST"
|
|||||||
SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT"
|
SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT"
|
||||||
SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE"
|
SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE"
|
||||||
|
|
||||||
|
# Default snapshot_buf size: 8 GiB (8 * 1024**3 bytes). Enough for ~1.5 Qwen3-30B 50k-token sessions.
|
||||||
|
SNAPSHOT_BUF_BYTES_ENV = "SGLANG_SNAPSHOT_LINK_BUF_BYTES"
|
||||||
|
DEFAULT_SNAPSHOT_BUF_BYTES = 8 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _LayerBufferDesc:
|
class _LayerBufferDesc:
|
||||||
@@ -56,13 +55,75 @@ class _LayerBufferDesc:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SnapshotIngestRecord:
|
class SnapshotIngestRecord:
|
||||||
"""P-side: bookkeeping for an outstanding incoming snapshot."""
|
"""P-side bookkeeping for one in-flight snapshot reception."""
|
||||||
session_id: str
|
session_id: str
|
||||||
slot_indices: List[int] # kv_pool slots reserved for this ingest
|
slab_offset: int # offset within snapshot_buf
|
||||||
|
slab_size: int # total bytes for this slab
|
||||||
num_tokens: int
|
num_tokens: int
|
||||||
|
k_layer_offsets: List[int] # absolute byte offsets of K layers in snapshot_buf
|
||||||
|
v_layer_offsets: List[int]
|
||||||
|
per_token_k_bytes: int
|
||||||
|
per_token_v_bytes: int
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
class SnapshotBufAllocator:
    """First-fit free-list allocator over a single contiguous byte range.

    The allocator only hands out *offsets*; it never touches the memory it
    describes. Free regions are kept as a list of ``(offset, size)`` tuples
    sorted by offset, with no two regions adjacent (they are coalesced on
    ``free()``). All public methods take an internal lock.
    """

    # Every request is rounded up to a multiple of this many bytes
    # (page-sized, for RDMA-friendly alignment).
    _ALIGN = 4096

    def __init__(self, capacity_bytes: int):
        """Create an allocator managing ``[0, capacity_bytes)``.

        Raises:
            ValueError: if ``capacity_bytes`` is negative.
        """
        if capacity_bytes < 0:
            raise ValueError(
                f"capacity_bytes must be non-negative, got {capacity_bytes}"
            )
        self.capacity = capacity_bytes
        # Free regions sorted by offset: [(offset, size), ...]. A zero-capacity
        # allocator starts with no free region rather than an empty (0, 0) one.
        self._free: List[Tuple[int, int]] = (
            [(0, capacity_bytes)] if capacity_bytes > 0 else []
        )
        self._lock = threading.Lock()
        # offset -> rounded size of every live allocation; lets free() take
        # only the offset and reject unknown / double frees.
        self._inflight: dict[int, int] = {}

    def alloc(self, size: int) -> Optional[int]:
        """Return the offset of a newly reserved region, or None if nothing fits.

        ``size`` is rounded up to a multiple of 4 KB before searching; a
        non-positive ``size`` always yields None.
        """
        if size <= 0:
            return None
        size = (size + self._ALIGN - 1) & ~(self._ALIGN - 1)
        with self._lock:
            for idx, (region_off, region_size) in enumerate(self._free):
                if region_size < size:
                    continue
                if region_size == size:
                    # Exact fit: the region disappears from the free list.
                    del self._free[idx]
                else:
                    # Carve from the front; the tail stays free.
                    self._free[idx] = (region_off + size, region_size - size)
                self._inflight[region_off] = size
                return region_off
        return None

    def free(self, offset: int) -> bool:
        """Release the allocation that starts at ``offset``.

        Returns True if ``offset`` was a live allocation, False otherwise
        (unknown offset or double free).
        """
        from bisect import bisect_left

        with self._lock:
            size = self._inflight.pop(offset, None)
            if size is None:
                return False
            regions = self._free
            # The list is sorted and already coalesced, so only the two
            # neighbours of the insertion point can touch the freed range.
            pos = bisect_left(regions, (offset, 0))
            start, end = offset, offset + size
            if pos < len(regions) and regions[pos][0] == end:
                end += regions[pos][1]
                del regions[pos]
            if pos > 0 and regions[pos - 1][0] + regions[pos - 1][1] == start:
                pos -= 1
                start = regions[pos][0]
                del regions[pos]
            regions.insert(pos, (start, end - start))
            return True

    def available_bytes(self) -> int:
        """Total bytes currently free (summed over all free regions)."""
        with self._lock:
            return sum(sz for _, sz in self._free)

    def in_use_bytes(self) -> int:
        """Total bytes currently reserved, counted at their rounded-up sizes."""
        with self._lock:
            return sum(self._inflight.values())
|
||||||
|
|
||||||
|
|
||||||
def _import_transfer_engine():
|
def _import_transfer_engine():
|
||||||
try:
|
try:
|
||||||
from mooncake.engine import TransferEngine
|
from mooncake.engine import TransferEngine
|
||||||
@@ -75,7 +136,13 @@ def _import_transfer_engine():
|
|||||||
|
|
||||||
|
|
||||||
class SnapshotLinkController:
|
class SnapshotLinkController:
|
||||||
"""Mooncake engine + per-layer mem registrations + ingest bookkeeping."""
|
"""Owns mooncake engine + kv_pool registrations + snapshot_buf + records.
|
||||||
|
|
||||||
|
D-side use: push session KV via ``push_session_to_snapshot_buf``.
|
||||||
|
P-side use: ``prepare_receive`` → caller pushes via RDMA →
|
||||||
|
``ingest_snapshot_into_kvpool`` (does GPU memcpy +
|
||||||
|
radix insert) → ``finalize_record`` (frees the slab).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -84,24 +151,16 @@ class SnapshotLinkController:
|
|||||||
ib_device: Optional[str],
|
ib_device: Optional[str],
|
||||||
kv_pool_layer_buffers: List[Tuple[int, int, int, bool]],
|
kv_pool_layer_buffers: List[Tuple[int, int, int, bool]],
|
||||||
token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator,
|
||||||
|
tree_cache=None,
|
||||||
protocol: Optional[str] = None,
|
protocol: Optional[str] = None,
|
||||||
|
snapshot_buf_bytes: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
host, port : where this worker binds its snapshot engine.
|
|
||||||
ib_device : preferred IB HCA (e.g. "mlx5_60").
|
|
||||||
kv_pool_layer_buffers : list of ``(base_ptr, bytes_per_token,
|
|
||||||
capacity_bytes, is_k)`` tuples. Order should be K-layers
|
|
||||||
first then V-layers (matches MHATokenToKVPool layout).
|
|
||||||
token_to_kv_pool_allocator : the worker's allocator. We use
|
|
||||||
``.alloc(N)`` on the P side to reserve receive slots.
|
|
||||||
"""
|
|
||||||
TransferEngine = _import_transfer_engine()
|
TransferEngine = _import_transfer_engine()
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.ib_device = ib_device
|
self.ib_device = ib_device
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
|
self.tree_cache = tree_cache
|
||||||
self.layer_buffers: List[_LayerBufferDesc] = [
|
self.layer_buffers: List[_LayerBufferDesc] = [
|
||||||
_LayerBufferDesc(
|
_LayerBufferDesc(
|
||||||
base_ptr=base, bytes_per_token=btok,
|
base_ptr=base, bytes_per_token=btok,
|
||||||
@@ -121,38 +180,67 @@ class SnapshotLinkController:
|
|||||||
)
|
)
|
||||||
self._session_id = f"{host}:{self.engine.get_rpc_port()}"
|
self._session_id = f"{host}:{self.engine.get_rpc_port()}"
|
||||||
|
|
||||||
# Register all layer buffers up-front (one-shot)
|
# Register existing kv_pool layer buffers (needed for D-side send and
|
||||||
|
# for P-side ingest copy source = snapshot_buf, destination = kv_pool)
|
||||||
ptrs = [d.base_ptr for d in self.layer_buffers]
|
ptrs = [d.base_ptr for d in self.layer_buffers]
|
||||||
lens = [d.capacity_bytes for d in self.layer_buffers]
|
lens = [d.capacity_bytes for d in self.layer_buffers]
|
||||||
try:
|
try:
|
||||||
reg_ret = self.engine.batch_register_memory(ptrs, lens)
|
reg_ret = self.engine.batch_register_memory(ptrs, lens)
|
||||||
except Exception:
|
except Exception:
|
||||||
reg_ret = -1
|
|
||||||
# Fall back to individual register_memory calls
|
|
||||||
reg_ret = 0
|
reg_ret = 0
|
||||||
for ptr, length in zip(ptrs, lens):
|
for ptr, length in zip(ptrs, lens):
|
||||||
r = self.engine.register_memory(ptr, length)
|
r = self.engine.register_memory(ptr, length)
|
||||||
if r != 0:
|
if r != 0:
|
||||||
logger.warning(
|
|
||||||
"SnapshotLinkController register_memory(%s, %d) returned %d",
|
|
||||||
hex(ptr), length, r,
|
|
||||||
)
|
|
||||||
reg_ret = r
|
reg_ret = r
|
||||||
if reg_ret != 0:
|
if reg_ret != 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"SnapshotLinkController batch_register_memory returned %d "
|
"SnapshotLinkController kv_pool batch_register returned %d", reg_ret
|
||||||
"(continuing — individual registrations may have succeeded)",
|
|
||||||
reg_ret,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Allocate + register the dedicated snapshot reception buffer (P-side)
|
||||||
|
# This decouples reception from kv_pool, avoiding the alloc-failed
|
||||||
|
# death loop that killed E4-v4/v5.
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if snapshot_buf_bytes is None:
|
||||||
|
snapshot_buf_bytes = int(
|
||||||
|
os.environ.get(SNAPSHOT_BUF_BYTES_ENV, DEFAULT_SNAPSHOT_BUF_BYTES)
|
||||||
|
)
|
||||||
|
device = self._allocator_device()
|
||||||
|
try:
|
||||||
|
self.snapshot_buf = torch.zeros(
|
||||||
|
snapshot_buf_bytes, dtype=torch.uint8, device=device,
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(
|
||||||
|
"Could not allocate snapshot_buf of %d bytes on %s: %s. "
|
||||||
|
"Falling back to 1 GB.", snapshot_buf_bytes, device, e,
|
||||||
|
)
|
||||||
|
snapshot_buf_bytes = 1024 * 1024 * 1024
|
||||||
|
self.snapshot_buf = torch.zeros(
|
||||||
|
snapshot_buf_bytes, dtype=torch.uint8, device=device,
|
||||||
|
)
|
||||||
|
self._snapshot_buf_bytes = snapshot_buf_bytes
|
||||||
|
self._snapshot_buf_ptr = self.snapshot_buf.data_ptr()
|
||||||
|
ret = self.engine.register_memory(self._snapshot_buf_ptr, snapshot_buf_bytes)
|
||||||
|
if ret != 0:
|
||||||
|
logger.warning(
|
||||||
|
"SnapshotLinkController snapshot_buf register_memory(%s, %d) ret=%d",
|
||||||
|
hex(self._snapshot_buf_ptr), snapshot_buf_bytes, ret,
|
||||||
|
)
|
||||||
|
self.snapshot_buf_alloc = SnapshotBufAllocator(snapshot_buf_bytes)
|
||||||
|
|
||||||
# Receive-side bookkeeping
|
# Receive-side bookkeeping
|
||||||
self._ingest_records: dict[str, SnapshotIngestRecord] = {}
|
self._ingest_records: dict[str, SnapshotIngestRecord] = {}
|
||||||
|
self._records_by_handle: dict[int, SnapshotIngestRecord] = {}
|
||||||
|
self._next_handle = 1
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"SnapshotLinkController up at %s (snapshot_session_id=%s, "
|
"SnapshotLinkController up at %s (sid=%s, %d kv layer bufs, "
|
||||||
"%d layer buffers registered, ib=%s)",
|
"snapshot_buf=%.1f GB on %s)",
|
||||||
listen, self._session_id, len(self.layer_buffers), ib_device,
|
listen, self._session_id, len(self.layer_buffers),
|
||||||
|
snapshot_buf_bytes / 1e9, device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ----- accessors ----------------------------------------------------
|
# ----- accessors ----------------------------------------------------
|
||||||
@@ -161,9 +249,16 @@ class SnapshotLinkController:
|
|||||||
def snapshot_session_id(self) -> str:
|
def snapshot_session_id(self) -> str:
|
||||||
return self._session_id
|
return self._session_id
|
||||||
|
|
||||||
|
@property
def snapshot_buf_ptr(self) -> int:
    """Base address of the snapshot buffer, as cached in ``_snapshot_buf_ptr``."""
    base_address = self._snapshot_buf_ptr
    return base_address
|
||||||
|
|
||||||
|
@property
def snapshot_buf_bytes(self) -> int:
    """Size in bytes of the snapshot buffer, as cached in ``_snapshot_buf_bytes``."""
    size_in_bytes = self._snapshot_buf_bytes
    return size_in_bytes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layer_num(self) -> int:
|
def layer_num(self) -> int:
|
||||||
"""Number of K layers (= number of V layers = total / 2)."""
|
|
||||||
return len(self.layer_buffers) // 2
|
return len(self.layer_buffers) // 2
|
||||||
|
|
||||||
def get_k_base_ptrs(self) -> List[int]:
|
def get_k_base_ptrs(self) -> List[int]:
|
||||||
@@ -184,66 +279,6 @@ class SnapshotLinkController:
|
|||||||
return d.bytes_per_token
|
return d.bytes_per_token
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# ----- P-side: prepare to receive ----------------------------------
|
|
||||||
|
|
||||||
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
|
|
||||||
"""Allocate ``num_tokens`` slots in kv_pool for an incoming snapshot.
|
|
||||||
|
|
||||||
Returns a record with the slot indices, or ``None`` if no capacity.
|
|
||||||
Mooncake registration is already in place (whole-buffer registration
|
|
||||||
at startup), so no per-snapshot register is needed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
indices_tensor = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("SnapshotLinkController.prepare_receive alloc failed: %s", e)
|
|
||||||
return None
|
|
||||||
if indices_tensor is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
slot_indices = [int(x) for x in indices_tensor.tolist()]
|
|
||||||
except Exception:
|
|
||||||
# If allocator returns a python list directly
|
|
||||||
slot_indices = list(map(int, indices_tensor))
|
|
||||||
record = SnapshotIngestRecord(
|
|
||||||
session_id=session_id,
|
|
||||||
slot_indices=slot_indices,
|
|
||||||
num_tokens=num_tokens,
|
|
||||||
)
|
|
||||||
with self._lock:
|
|
||||||
old = self._ingest_records.pop(session_id, None)
|
|
||||||
if old is not None:
|
|
||||||
# Best-effort: free old slots if we're overwriting
|
|
||||||
try:
|
|
||||||
self._free_slots(old.slot_indices)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._ingest_records[session_id] = record
|
|
||||||
return record
|
|
||||||
|
|
||||||
def take_record(self, session_id: str) -> Optional[SnapshotIngestRecord]:
|
|
||||||
with self._lock:
|
|
||||||
return self._ingest_records.pop(session_id, None)
|
|
||||||
|
|
||||||
def discard_record(self, session_id: str) -> None:
|
|
||||||
"""Drop a pending ingest (e.g. on timeout or D-side failure)."""
|
|
||||||
with self._lock:
|
|
||||||
rec = self._ingest_records.pop(session_id, None)
|
|
||||||
if rec is not None:
|
|
||||||
try:
|
|
||||||
self._free_slots(rec.slot_indices)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _free_slots(self, slot_indices: List[int]) -> None:
|
|
||||||
import torch
|
|
||||||
t = torch.tensor(slot_indices, dtype=torch.int64,
|
|
||||||
device=self._allocator_device())
|
|
||||||
try:
|
|
||||||
self.token_to_kv_pool_allocator.free(t)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("SnapshotLinkController._free_slots failed: %s", e)
|
|
||||||
|
|
||||||
def _allocator_device(self):
|
def _allocator_device(self):
|
||||||
# Best-effort: pull device from one of the buffer tensors via the allocator
|
# Best-effort: pull device from one of the buffer tensors via the allocator
|
||||||
try:
|
try:
|
||||||
@@ -251,91 +286,291 @@ class SnapshotLinkController:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
# ----- D-side: push session KV --------------------------------------
|
# ----- P-side: prepare to receive ----------------------------------
|
||||||
|
|
||||||
def push_session_kv(
|
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
    """Carve a slab out of snapshot_buf large enough for num_tokens of K+V.

    The slab is laid out as K layer 0..L-1 followed by V layer 0..L-1, each
    layer occupying ``num_tokens * stride`` contiguous bytes. This does NOT
    touch kv_pool — kv_pool slots are allocated at ingest time.

    If the first allocation attempt fails and this session still has a
    pending record from an earlier prepare, that stale record is evicted and
    the allocation is retried once; otherwise a re-prepare for a large
    session could never succeed while its own stale slab occupies the buffer.

    Args:
        session_id: session the incoming snapshot belongs to.
        num_tokens: number of tokens the sender will write per layer.

    Returns:
        The record describing the slab layout, or None if ``num_tokens`` is
        not positive or no slab of the required size could be allocated.
    """
    if num_tokens <= 0:
        return None

    stride_k = self.get_stride_k_bytes()
    stride_v = self.get_stride_v_bytes()
    num_layers = self.layer_num
    k_layer_bytes = num_tokens * stride_k
    v_layer_bytes = num_tokens * stride_v
    slab_bytes = num_layers * k_layer_bytes + num_layers * v_layer_bytes

    offset = self.snapshot_buf_alloc.alloc(slab_bytes)
    if offset is None:
        # Second chance: drop this session's own stale reservation and retry.
        # NOTE(review): the retried slab may overlap the stale one; if a write
        # from the superseded attempt could still be in flight, confirm the
        # caller serializes attempts per session.
        with self._lock:
            stale = self._ingest_records.pop(session_id, None)
            if stale is not None:
                self._records_by_handle.pop(id(stale), None)
        if stale is not None:
            self.snapshot_buf_alloc.free(stale.slab_offset)
            offset = self.snapshot_buf_alloc.alloc(slab_bytes)
    if offset is None:
        logger.info(
            "prepare_receive: snapshot_buf full (sid=%s n=%d need=%d B available=%d B)",
            session_id, num_tokens, slab_bytes,
            self.snapshot_buf_alloc.available_bytes(),
        )
        return None

    # Layout: K0..KL-1, then V0..VL-1 (byte offsets within snapshot_buf).
    v_region_start = offset + num_layers * k_layer_bytes
    k_offs = [offset + i * k_layer_bytes for i in range(num_layers)]
    v_offs = [v_region_start + i * v_layer_bytes for i in range(num_layers)]
    record = SnapshotIngestRecord(
        session_id=session_id,
        slab_offset=offset,
        slab_size=slab_bytes,
        num_tokens=num_tokens,
        k_layer_offsets=k_offs,
        v_layer_offsets=v_offs,
        per_token_k_bytes=stride_k,
        per_token_v_bytes=stride_v,
    )
    with self._lock:
        # Evict a prior record for the same session (best-effort); the new
        # slab was allocated first, so it never overlaps the evicted one here.
        old = self._ingest_records.pop(session_id, None)
        if old is not None:
            self.snapshot_buf_alloc.free(old.slab_offset)
            self._records_by_handle.pop(id(old), None)
        self._ingest_records[session_id] = record
        self._records_by_handle[id(record)] = record
    return record
|
||||||
|
|
||||||
|
def lookup_by_handle(self, handle: int) -> Optional[SnapshotIngestRecord]:
    """Return the pending record registered under ``handle``, or None.

    The lookup is performed while holding the controller lock.
    """
    with self._lock:
        found = self._records_by_handle.get(handle)
    return found
|
||||||
|
|
||||||
|
def discard_record(self, session_id: str) -> None:
    """Drop the pending ingest for ``session_id`` and release its slab.

    A no-op when the session has no pending record. The record is removed
    from both the per-session table and the handle table in a single
    critical section, so a concurrent ``lookup_by_handle`` cannot return a
    record whose slab has already been handed back to the allocator.
    """
    with self._lock:
        rec = self._ingest_records.pop(session_id, None)
        if rec is not None:
            self._records_by_handle.pop(id(rec), None)
    if rec is None:
        return
    # Released outside the controller lock; the allocator has its own lock.
    self.snapshot_buf_alloc.free(rec.slab_offset)
|
||||||
|
|
||||||
|
def total_pending_snapshot_bytes(self) -> int:
    """Sum of ``slab_size`` over every record still awaiting ingest."""
    total = 0
    with self._lock:
        for pending in self._ingest_records.values():
            total += pending.slab_size
    return total
|
||||||
|
|
||||||
|
# ----- P-side: ingest snapshot into kv_pool + radix tree -----------
|
||||||
|
|
||||||
|
def ingest_snapshot_into_kvpool(
    self,
    session_id: str,
    token_ids: List[int],
) -> Tuple[bool, str, int]:
    """Copy snapshot_buf bytes into kv_pool slots and insert into radix.

    Consumes the pending record for ``session_id``: whatever the outcome, the
    record is removed and its snapshot_buf slab is released before returning.
    kv_pool slots are allocated here (not at prepare time) and are freed
    again if the copy or the radix insert fails.

    Args:
        session_id: session whose pending snapshot should be ingested.
        token_ids: token ids the snapshot covers; only the first
            ``min(len(token_ids), record.num_tokens)`` are ingested.

    Returns:
        ``(ok, reason, inserted_prefix_len)``. ``reason`` is ``"ok"`` on
        success, otherwise one of ``"no-pending-ingest"``,
        ``"empty-token-ids"``, ``"kvpool-alloc-threw:..."``,
        ``"kvpool-alloc-failed-at-ingest"``, ``"copy-failed:..."``,
        ``"radix-insert-failed:..."`` or ``"unexpected:..."``. This method
        does not raise for failures inside the ingest steps.
    """
    with self._lock:
        record = self._ingest_records.pop(session_id, None)
        if record is not None:
            self._records_by_handle.pop(id(record), None)
    if record is None:
        return False, "no-pending-ingest", 0

    try:
        n = min(len(token_ids), record.num_tokens)
        if n == 0:
            return False, "empty-token-ids", 0

        # Alloc kv_pool slots NOW that the snapshot bytes are in hand.
        try:
            indices_tensor = self.token_to_kv_pool_allocator.alloc(n)
        except Exception as exc:
            return False, f"kvpool-alloc-threw:{exc!r}", 0
        if indices_tensor is None:
            return False, "kvpool-alloc-failed-at-ingest", 0

        # GPU→GPU copy from snapshot_buf into kv_pool layer buffers.
        # NOTE(review): n may be smaller than record.num_tokens; the copy
        # helper gets n slots and must only copy the first n tokens — verify.
        try:
            self._copy_snapshot_to_kvpool(record, indices_tensor)
        except Exception as exc:
            logger.exception("snapshot→kvpool copy failed: %s", exc)
            self._free_slot_indices(indices_tensor)
            return False, f"copy-failed:{exc!r}", 0

        # Insert into radix tree.
        try:
            inserted_prefix_len = self._radix_insert(token_ids[:n], indices_tensor)
        except Exception as exc:
            logger.exception("radix insert failed: %s", exc)
            self._free_slot_indices(indices_tensor)
            return False, f"radix-insert-failed:{exc!r}", 0

        # NOTE(review): if prefix_len reports tokens that were already present
        # in the tree, the slots allocated above for that prefix are presumably
        # not adopted by the tree and may need freeing — confirm against the
        # cache's insert() contract.
        return True, "ok", int(inserted_prefix_len)
    except Exception as exc:
        # Anything not handled above; log it so the failure is not silent.
        logger.exception(
            "unexpected failure while ingesting snapshot (sid=%s): %s",
            session_id, exc,
        )
        return False, f"unexpected:{exc!r}", 0
    finally:
        # Single release point for the slab: once we get here the bytes are
        # either persisted into kv_pool + radix or the snapshot is dropped.
        try:
            self.snapshot_buf_alloc.free(record.slab_offset)
        except Exception as exc:
            logger.warning("snapshot_buf slab release failed: %s", exc)
|
||||||
|
|
||||||
|
def _copy_snapshot_to_kvpool(
    self,
    record: SnapshotIngestRecord,
    slot_indices_tensor,
) -> None:
    """Scatter a received slab from snapshot_buf into the kv_pool buffers.

    For each layer, token ``j`` of the slab (the ``j``-th ``stride``-sized
    row of that layer's K / V range) is written to row
    ``slot_indices_tensor[j]`` of ``k_buffer[layer]`` / ``v_buffer[layer]``,
    byte for byte.

    Only ``len(slot_indices_tensor)`` tokens are copied. The caller may have
    allocated fewer slots than ``record.num_tokens`` (it ingests
    ``min(len(token_ids), num_tokens)`` tokens), in which case the leading
    tokens of each layer are used.

    Args:
        record: slab layout (per-layer byte offsets and per-token strides).
        slot_indices_tensor: 1-D destination row indices into the kv_pool
            buffers, one per token to copy.

    Raises:
        ValueError: if more slots than ``record.num_tokens`` are supplied, or
            a kv_pool row is not exactly ``per_token_*_bytes`` wide.
    """
    import torch

    n = len(slot_indices_tensor)
    if n == 0:
        return
    if n > record.num_tokens:
        raise ValueError(
            f"got {n} kv_pool slots but the snapshot holds only "
            f"{record.num_tokens} tokens"
        )

    # The kv cache object is the same for every layer; fetch it once.
    kv_cache = self.token_to_kv_pool_allocator.get_kvcache()

    def scatter(kind: str, layer: int, src_start: int, stride: int, pool_buf) -> None:
        # Flatten each token row of the pool buffer, then reinterpret as bytes.
        rows = pool_buf.view(pool_buf.shape[0], -1)
        row_bytes = rows.shape[1] * rows.element_size()
        if row_bytes != stride:
            # The byte views below only line up when both sides agree exactly.
            raise ValueError(
                f"{kind} layer {layer} stride mismatch: "
                f"pool {row_bytes} vs snapshot {stride}"
            )
        src = self.snapshot_buf[src_start : src_start + n * stride].view(n, stride)
        rows.view(torch.uint8)[slot_indices_tensor] = src

    for layer in range(self.layer_num):
        scatter(
            "K", layer, record.k_layer_offsets[layer],
            record.per_token_k_bytes, kv_cache.k_buffer[layer],
        )
        scatter(
            "V", layer, record.v_layer_offsets[layer],
            record.per_token_v_bytes, kv_cache.v_buffer[layer],
        )
    # NOTE(review): on a CUDA device these indexed writes are presumably
    # enqueued asynchronously; confirm the slab cannot be reused for a new
    # reception before the copy has actually completed.
|
||||||
|
|
||||||
|
def _radix_insert(self, token_ids: List[int], indices_tensor) -> int:
    """Insert ``(token_ids, indices_tensor)`` into the underlying radix tree.

    A ``SessionAwareCache`` wrapper is unwrapped to its ``inner`` cache
    before inserting.

    Returns:
        The ``prefix_len`` attribute of the insert result, or 0 if the result
        has no such attribute.

    Raises:
        RuntimeError: if no tree cache is available to insert into.
    """
    from sglang.srt.mem_cache.base_prefix_cache import InsertParams
    from sglang.srt.mem_cache.radix_cache import RadixKey
    from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache

    cache = self.tree_cache
    if isinstance(cache, SessionAwareCache):
        cache = cache.inner
    if cache is None:
        raise RuntimeError("tree_cache not provided to SnapshotLinkController")

    params = InsertParams(key=RadixKey(token_ids, None), value=indices_tensor)
    outcome = cache.insert(params)
    return int(getattr(outcome, "prefix_len", 0))
|
||||||
|
|
||||||
|
def _free_slot_indices(self, indices_tensor) -> None:
    """Hand ``indices_tensor`` back to the kv_pool allocator, best-effort.

    Any exception from the allocator is logged as a warning and suppressed,
    so this is safe to call from cleanup paths.
    """
    try:
        release = self.token_to_kv_pool_allocator.free
        release(indices_tensor)
    except Exception as exc:
        logger.warning("_free_slot_indices failed: %s", exc)
|
||||||
|
|
||||||
|
# ----- D-side: push session KV to a peer's snapshot_buf ------------
|
||||||
|
|
||||||
|
def push_session_to_snapshot_buf(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
target_snapshot_session_id: str,
|
target_snapshot_session_id: str,
|
||||||
src_slot_indices: List[int],
|
src_slot_indices: List[int],
|
||||||
target_k_base_ptrs: List[int],
|
target_snapshot_buf_base: int,
|
||||||
target_v_base_ptrs: List[int],
|
target_k_layer_offsets: List[int],
|
||||||
target_slot_indices: List[int],
|
target_v_layer_offsets: List[int],
|
||||||
target_stride_k_bytes: int,
|
target_per_token_k_bytes: int,
|
||||||
target_stride_v_bytes: int,
|
target_per_token_v_bytes: int,
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
"""Push the KV bytes at src_slot_indices to the remote slots.
|
"""Push session KV from local kv_pool into a peer's snapshot_buf slab.
|
||||||
|
|
||||||
Returns ``(mooncake_return_code, bytes_pushed)``. Caller is
|
For each layer: gather src ranges (possibly scattered slot indices)
|
||||||
responsible for any post-push handshake (e.g. finalize_ingest RPC).
|
and write to a contiguous range in the peer's snapshot_buf.
|
||||||
|
Returns (mooncake_return_code, bytes_pushed).
|
||||||
"""
|
"""
|
||||||
|
if not src_slot_indices:
|
||||||
|
return 0, 0
|
||||||
layer_num = self.layer_num
|
layer_num = self.layer_num
|
||||||
k_src_bases = self.get_k_base_ptrs()
|
k_src_bases = self.get_k_base_ptrs()
|
||||||
v_src_bases = self.get_v_base_ptrs()
|
v_src_bases = self.get_v_base_ptrs()
|
||||||
stride_k = self.get_stride_k_bytes()
|
stride_k = self.get_stride_k_bytes()
|
||||||
stride_v = self.get_stride_v_bytes()
|
stride_v = self.get_stride_v_bytes()
|
||||||
if len(target_k_base_ptrs) != layer_num or len(target_v_base_ptrs) != layer_num:
|
if (len(target_k_layer_offsets) != layer_num
|
||||||
|
or len(target_v_layer_offsets) != layer_num):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"target K/V base ptr count {len(target_k_base_ptrs)}/{len(target_v_base_ptrs)} "
|
f"target K/V layer offset count {len(target_k_layer_offsets)}/"
|
||||||
f"!= local layer_num {layer_num}"
|
f"{len(target_v_layer_offsets)} != local layer_num {layer_num}"
|
||||||
)
|
)
|
||||||
if stride_k != target_stride_k_bytes or stride_v != target_stride_v_bytes:
|
if (stride_k != target_per_token_k_bytes
|
||||||
|
or stride_v != target_per_token_v_bytes):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"stride mismatch: local k={stride_k}, v={stride_v}; "
|
f"stride mismatch: local k={stride_k}/v={stride_v}, "
|
||||||
f"target k={target_stride_k_bytes}, v={target_stride_v_bytes}"
|
f"target k={target_per_token_k_bytes}/v={target_per_token_v_bytes}"
|
||||||
)
|
|
||||||
if len(src_slot_indices) != len(target_slot_indices):
|
|
||||||
raise ValueError(
|
|
||||||
f"slot count mismatch: src={len(src_slot_indices)}, "
|
|
||||||
f"target={len(target_slot_indices)}"
|
|
||||||
)
|
)
|
||||||
|
n = len(src_slot_indices)
|
||||||
|
|
||||||
local_addrs: List[int] = []
|
local_addrs: List[int] = []
|
||||||
remote_addrs: List[int] = []
|
remote_addrs: List[int] = []
|
||||||
lengths: List[int] = []
|
lengths: List[int] = []
|
||||||
|
|
||||||
# Group contiguous runs on the target side to coalesce ops.
|
# Coalesce contiguous src runs.
|
||||||
# Simple approach: per (layer, K/V) pair, walk src/target index
|
# Inner-loop helper to walk indices and emit run boundaries.
|
||||||
# tuples; merge runs where both src and target are sequential.
|
def _emit_runs(src_base: int, tgt_base: int, stride: int) -> None:
|
||||||
for layer_id in range(layer_num):
|
run_src_start = run_tgt_start = run_len = None
|
||||||
for kv_bases, stride, kv_label in (
|
for tgt_idx, src in enumerate(src_slot_indices):
|
||||||
(k_src_bases[layer_id], stride_k, "K"),
|
if run_src_start is None:
|
||||||
(v_src_bases[layer_id], stride_v, "V"),
|
run_src_start, run_tgt_start, run_len = src, tgt_idx, 1
|
||||||
):
|
elif src == run_src_start + run_len:
|
||||||
src_base = kv_bases
|
run_len += 1
|
||||||
if kv_label == "K":
|
|
||||||
tgt_base = target_k_base_ptrs[layer_id]
|
|
||||||
else:
|
else:
|
||||||
tgt_base = target_v_base_ptrs[layer_id]
|
|
||||||
run_src_start = run_tgt_start = run_len = None
|
|
||||||
for s, t in zip(src_slot_indices, target_slot_indices):
|
|
||||||
if run_src_start is None:
|
|
||||||
run_src_start, run_tgt_start, run_len = s, t, 1
|
|
||||||
elif s == run_src_start + run_len and t == run_tgt_start + run_len:
|
|
||||||
run_len += 1
|
|
||||||
else:
|
|
||||||
local_addrs.append(src_base + run_src_start * stride)
|
|
||||||
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
|
||||||
lengths.append(run_len * stride)
|
|
||||||
run_src_start, run_tgt_start, run_len = s, t, 1
|
|
||||||
if run_src_start is not None:
|
|
||||||
local_addrs.append(src_base + run_src_start * stride)
|
local_addrs.append(src_base + run_src_start * stride)
|
||||||
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
||||||
lengths.append(run_len * stride)
|
lengths.append(run_len * stride)
|
||||||
|
run_src_start, run_tgt_start, run_len = src, tgt_idx, 1
|
||||||
|
if run_src_start is not None:
|
||||||
|
local_addrs.append(src_base + run_src_start * stride)
|
||||||
|
remote_addrs.append(tgt_base + run_tgt_start * stride)
|
||||||
|
lengths.append(run_len * stride)
|
||||||
|
|
||||||
|
for L in range(layer_num):
|
||||||
|
_emit_runs(
|
||||||
|
k_src_bases[L],
|
||||||
|
target_snapshot_buf_base + target_k_layer_offsets[L],
|
||||||
|
stride_k,
|
||||||
|
)
|
||||||
|
_emit_runs(
|
||||||
|
v_src_bases[L],
|
||||||
|
target_snapshot_buf_base + target_v_layer_offsets[L],
|
||||||
|
stride_v,
|
||||||
|
)
|
||||||
|
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
ret = self.engine.batch_transfer_sync_write(
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
target_snapshot_session_id, local_addrs, remote_addrs, lengths
|
target_snapshot_session_id, local_addrs, remote_addrs, lengths,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("SnapshotLinkController.push_session_kv threw: %s", e)
|
logger.exception(
|
||||||
|
"SnapshotLinkController.push_session_to_snapshot_buf threw: %s", e
|
||||||
|
)
|
||||||
return -1, 0
|
return -1, 0
|
||||||
t1 = time.perf_counter()
|
t1 = time.perf_counter()
|
||||||
bytes_pushed = sum(lengths)
|
bytes_pushed = sum(lengths)
|
||||||
logger.info(
|
logger.info(
|
||||||
"SnapshotLinkController.push_session_kv → %s: %d ops, %d B, "
|
"push_session_to_snapshot_buf → %s: %d ops, %d B, ret=%d, %.2f ms",
|
||||||
"ret=%d, %.2f ms",
|
|
||||||
target_snapshot_session_id, len(lengths), bytes_pushed, ret,
|
target_snapshot_session_id, len(lengths), bytes_pushed, ret,
|
||||||
(t1 - t0) * 1000.0,
|
(t1 - t0) * 1000.0,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1662,38 +1662,38 @@ class SnapshotPrepareReceiveReqInput(BaseReq):
|
|||||||
|
|
||||||
@dataclass
class SnapshotPrepareReceiveReqOutput(BaseReq):
    """P-side reply to a prepare-receive request.

    Describes where in P's dedicated snapshot_buf the D side should write:
    a base pointer plus one byte offset per layer for K and for V. This
    replaces the earlier kv_pool slot_indices scheme
    (see docs/SNAPSHOT_STORE_REFACTOR_ZH.md).
    """

    ok: bool
    reason: Optional[str] = None

    # P's mooncake snapshot session id (host:rpc_port); D uses it as the
    # target of its batch write.
    snapshot_session_id: str = ""

    # Location of the receive slab inside snapshot_buf.
    snapshot_buf_base_ptr: int = 0
    snapshot_buf_capacity_bytes: int = 0
    # Per-layer offsets, in bytes, within snapshot_buf.
    k_layer_offsets: List[int] = field(default_factory=list)
    v_layer_offsets: List[int] = field(default_factory=list)

    num_tokens: int = 0
    # Bytes per token for K and for V respectively.
    stride_k_bytes: int = 0
    stride_v_bytes: int = 0
    layer_num: int = 0
    available_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SnapshotDumpReqInput(BaseReq):
    """D-side request: push a session's KV through snapshot_link into a
    slab of P's snapshot_buf."""

    session_id: str
    # P's mooncake snapshot session id.
    target_snapshot_session_id: str
    # Base pointer of P's snapshot_buf, as returned by prepare-receive.
    target_snapshot_buf_base: int = 0
    # Per-layer offsets for K and V, added to the base pointer on the D side.
    target_k_layer_offsets: List[int] = field(default_factory=list)
    target_v_layer_offsets: List[int] = field(default_factory=list)
    target_stride_k_bytes: int = 0
    target_stride_v_bytes: int = 0
    ib_device: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1709,11 +1709,10 @@ class SnapshotDumpReqOutput(BaseReq):
|
|||||||
|
|
||||||
@dataclass
class SnapshotFinalizeIngestReqInput(BaseReq):
    """P-side request: move the received snapshot_buf slab into kv_pool and
    insert it into the radix tree.

    Slot indices are no longer part of the request.
    """

    session_id: str
    token_ids: List[int]
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -902,6 +902,7 @@ class Scheduler(
|
|||||||
ib_device=ib,
|
ib_device=ib,
|
||||||
kv_pool_layer_buffers=layer_buffers,
|
kv_pool_layer_buffers=layer_buffers,
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
tree_cache=self.tree_cache,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Snapshot link controller initialized: %s, sid=%s, %d layer bufs",
|
"Snapshot link controller initialized: %s, sid=%s, %d layer bufs",
|
||||||
@@ -3750,7 +3751,12 @@ class Scheduler(
|
|||||||
def snapshot_prepare_receive(
|
def snapshot_prepare_receive(
|
||||||
self, recv_req: SnapshotPrepareReceiveReqInput
|
self, recv_req: SnapshotPrepareReceiveReqInput
|
||||||
) -> SnapshotPrepareReceiveReqOutput:
|
) -> SnapshotPrepareReceiveReqOutput:
|
||||||
"""P-side: alloc kv_pool slots, return slot/buffer layout for D's batch_push."""
|
"""P-side: carve snapshot_buf slab + return its layout to caller.
|
||||||
|
|
||||||
|
Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: this no longer
|
||||||
|
touches the kv_pool allocator. The slab is in a dedicated
|
||||||
|
snapshot_buf so prepare can never lose to P's prefill work.
|
||||||
|
"""
|
||||||
ctrl = self.snapshot_link_controller
|
ctrl = self.snapshot_link_controller
|
||||||
if ctrl is None:
|
if ctrl is None:
|
||||||
return SnapshotPrepareReceiveReqOutput(
|
return SnapshotPrepareReceiveReqOutput(
|
||||||
@@ -3765,25 +3771,27 @@ class Scheduler(
|
|||||||
record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens)
|
record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens)
|
||||||
if record is None:
|
if record is None:
|
||||||
return SnapshotPrepareReceiveReqOutput(
|
return SnapshotPrepareReceiveReqOutput(
|
||||||
ok=False, reason="alloc-failed",
|
ok=False, reason="snapshot-buf-full",
|
||||||
available_tokens=available,
|
available_tokens=available,
|
||||||
)
|
)
|
||||||
return SnapshotPrepareReceiveReqOutput(
|
return SnapshotPrepareReceiveReqOutput(
|
||||||
ok=True,
|
ok=True,
|
||||||
k_base_ptrs=ctrl.get_k_base_ptrs(),
|
|
||||||
v_base_ptrs=ctrl.get_v_base_ptrs(),
|
|
||||||
slot_indices=record.slot_indices,
|
|
||||||
stride_k_bytes=ctrl.get_stride_k_bytes(),
|
|
||||||
stride_v_bytes=ctrl.get_stride_v_bytes(),
|
|
||||||
layer_num=ctrl.layer_num,
|
|
||||||
snapshot_session_id=ctrl.snapshot_session_id,
|
snapshot_session_id=ctrl.snapshot_session_id,
|
||||||
|
snapshot_buf_base_ptr=ctrl.snapshot_buf_ptr,
|
||||||
|
snapshot_buf_capacity_bytes=ctrl.snapshot_buf_bytes,
|
||||||
|
k_layer_offsets=record.k_layer_offsets,
|
||||||
|
v_layer_offsets=record.v_layer_offsets,
|
||||||
|
num_tokens=record.num_tokens,
|
||||||
|
stride_k_bytes=record.per_token_k_bytes,
|
||||||
|
stride_v_bytes=record.per_token_v_bytes,
|
||||||
|
layer_num=ctrl.layer_num,
|
||||||
available_tokens=available,
|
available_tokens=available,
|
||||||
)
|
)
|
||||||
|
|
||||||
def snapshot_dump(
|
def snapshot_dump(
|
||||||
self, recv_req: SnapshotDumpReqInput
|
self, recv_req: SnapshotDumpReqInput
|
||||||
) -> SnapshotDumpReqOutput:
|
) -> SnapshotDumpReqOutput:
|
||||||
"""D-side: gather session KV from kv_pool, RDMA-write to remote slots."""
|
"""D-side: gather session KV from kv_pool, RDMA-write into P's snapshot_buf."""
|
||||||
ctrl = self.snapshot_link_controller
|
ctrl = self.snapshot_link_controller
|
||||||
if ctrl is None:
|
if ctrl is None:
|
||||||
return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled")
|
return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled")
|
||||||
@@ -3795,110 +3803,60 @@ class Scheduler(
|
|||||||
kv_committed_len = int(slot.kv_committed_len)
|
kv_committed_len = int(slot.kv_committed_len)
|
||||||
if kv_committed_len == 0:
|
if kv_committed_len == 0:
|
||||||
return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len")
|
return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len")
|
||||||
# Read kv_indices for the session's prefix
|
|
||||||
try:
|
try:
|
||||||
kv_idx_tensor = self.req_to_token_pool.req_to_token[
|
kv_idx_tensor = self.req_to_token_pool.req_to_token[
|
||||||
slot.req_pool_idx, :kv_committed_len
|
slot.req_pool_idx, :kv_committed_len
|
||||||
]
|
]
|
||||||
src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()]
|
src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("snapshot_dump: failed to read kv_indices: %s", e)
|
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed:{e!r}")
|
||||||
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed: {e!r}")
|
|
||||||
|
|
||||||
# Truncate to the count P prepared for (must match)
|
|
||||||
target_n = len(recv_req.target_slot_indices)
|
|
||||||
if target_n > kv_committed_len:
|
|
||||||
return SnapshotDumpReqOutput(
|
|
||||||
ok=False,
|
|
||||||
reason=f"target-larger-than-source({target_n}>{kv_committed_len})",
|
|
||||||
)
|
|
||||||
src_slot_indices = src_slot_indices[:target_n]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret, bytes_pushed = ctrl.push_session_kv(
|
ret, bytes_pushed = ctrl.push_session_to_snapshot_buf(
|
||||||
target_snapshot_session_id=recv_req.target_snapshot_session_id,
|
target_snapshot_session_id=recv_req.target_snapshot_session_id,
|
||||||
src_slot_indices=src_slot_indices,
|
src_slot_indices=src_slot_indices,
|
||||||
target_k_base_ptrs=recv_req.target_k_base_ptrs,
|
target_snapshot_buf_base=recv_req.target_snapshot_buf_base,
|
||||||
target_v_base_ptrs=recv_req.target_v_base_ptrs,
|
target_k_layer_offsets=recv_req.target_k_layer_offsets,
|
||||||
target_slot_indices=recv_req.target_slot_indices[:target_n],
|
target_v_layer_offsets=recv_req.target_v_layer_offsets,
|
||||||
target_stride_k_bytes=recv_req.target_stride_k_bytes,
|
target_per_token_k_bytes=recv_req.target_stride_k_bytes,
|
||||||
target_stride_v_bytes=recv_req.target_stride_v_bytes,
|
target_per_token_v_bytes=recv_req.target_stride_v_bytes,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("snapshot_dump: push_session_kv threw: %s", e)
|
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed:{e!r}")
|
||||||
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed: {e!r}")
|
|
||||||
|
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
return SnapshotDumpReqOutput(
|
return SnapshotDumpReqOutput(
|
||||||
ok=False,
|
ok=False, reason=f"mooncake-batch-write-ret={ret}",
|
||||||
reason=f"mooncake-batch-write-ret={ret}",
|
|
||||||
bytes_pushed=int(bytes_pushed),
|
bytes_pushed=int(bytes_pushed),
|
||||||
kv_committed_len=int(kv_committed_len),
|
kv_committed_len=kv_committed_len,
|
||||||
)
|
)
|
||||||
return SnapshotDumpReqOutput(
|
return SnapshotDumpReqOutput(
|
||||||
ok=True, bytes_pushed=int(bytes_pushed),
|
ok=True, bytes_pushed=int(bytes_pushed),
|
||||||
kv_committed_len=int(kv_committed_len),
|
kv_committed_len=kv_committed_len,
|
||||||
token_ids=[], # caller already has token_ids
|
token_ids=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def snapshot_finalize_ingest(
    self, recv_req: SnapshotFinalizeIngestReqInput
) -> SnapshotFinalizeIngestReqOutput:
    """P-side: move a received snapshot_buf slab into kv_pool and the radix tree.

    The work is delegated to the controller's
    ``ingest_snapshot_into_kvpool``. Per docs/SNAPSHOT_STORE_REFACTOR_ZH.md
    the kv_pool allocation is deferred to this step instead of
    prepare_receive, so D's RDMA write is never blocked on kv_pool
    contention.

    Returns:
        A ``SnapshotFinalizeIngestReqOutput``. ``reason`` is populated only
        when the ingest did not succeed; with no controller configured the
        reason is ``"snapshot-link-disabled"``.
    """
    controller = self.snapshot_link_controller
    if controller is None:
        return SnapshotFinalizeIngestReqOutput(
            ok=False, reason="snapshot-link-disabled",
        )

    ok, reason, inserted_prefix_len = controller.ingest_snapshot_into_kvpool(
        session_id=recv_req.session_id,
        token_ids=list(recv_req.token_ids),
    )
    # Only surface the reason on failure; a successful ingest reports None.
    failure_reason = None if ok else reason
    return SnapshotFinalizeIngestReqOutput(
        ok=bool(ok),
        reason=failure_reason,
        inserted_prefix_len=int(inserted_prefix_len),
    )
|
||||||
|
|
||||||
def _compute_backpressure_pause_hint(
|
def _compute_backpressure_pause_hint(
|
||||||
|
|||||||
@@ -181,27 +181,18 @@ class SchedulerRuntimeCheckerMixin:
|
|||||||
return memory_leak, token_msg
|
return memory_leak, token_msg
|
||||||
|
|
||||||
def _check_radix_cache_memory(self: Scheduler):
    """Check that radix-cache token accounting balances.

    Compares ``available_size + evictable_size`` against
    ``max_total_num_tokens - protected_size - session_held``.

    Since the SnapshotStore refactor (docs/SNAPSHOT_STORE_REFACTOR_ZH.md)
    prepare_receive takes its space from a dedicated snapshot_buf rather
    than kv_pool, so no snapshot-reserved term appears in the balance.

    Returns:
        ``(memory_leak, token_msg)``: a bool that is True when the two
        sides differ, and a one-line string listing the inputs.
    """
    _, _, available_size, evictable_size = self._get_token_info()
    protected_size = self.tree_cache.protected_size()
    session_held = self._session_held_tokens()

    # NOTE: the local names above are echoed verbatim by the ``{name=}``
    # f-string specifiers below, so they must not be renamed.
    accounted = available_size + evictable_size
    expected = self.max_total_num_tokens - protected_size - session_held
    memory_leak = accounted != expected

    token_msg = "".join(
        (
            f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, ",
            f"{protected_size=}, {session_held=}\n",
        )
    )
    return memory_leak, token_msg
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user