diff --git a/src/agentic_pd_hybrid/replay.py b/src/agentic_pd_hybrid/replay.py index f77c32d..554f7b6 100644 --- a/src/agentic_pd_hybrid/replay.py +++ b/src/agentic_pd_hybrid/replay.py @@ -2187,9 +2187,9 @@ async def _attempt_d_to_p_sync( json={ "session_id": request.session_id, "target_snapshot_session_id": prep["snapshot_session_id"], - "target_k_base_ptrs": prep["k_base_ptrs"], - "target_v_base_ptrs": prep["v_base_ptrs"], - "target_slot_indices": prep["slot_indices"], + "target_snapshot_buf_base": prep["snapshot_buf_base_ptr"], + "target_k_layer_offsets": prep["k_layer_offsets"], + "target_v_layer_offsets": prep["v_layer_offsets"], "target_stride_k_bytes": prep["stride_k_bytes"], "target_stride_v_bytes": prep["stride_v_bytes"], }, @@ -2220,15 +2220,13 @@ async def _attempt_d_to_p_sync( # for the first N — use that as best-available approximation. tokens = list(getattr(request, "input_token_ids", []) or []) if not tokens: - # No token_ids available — can't insert into radix. P will fall back - # to normal prefill but will have wasted slots. Discard. + # No token_ids → can't insert into radix; tell P to free the slab. try: await client.post( f"{prefill_url}/_snapshot/finalize_ingest", json={ "session_id": request.session_id, "token_ids": [], - "slot_indices": prep["slot_indices"], }, timeout=15.0, ) @@ -2242,7 +2240,7 @@ async def _attempt_d_to_p_sync( ) return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)} - n = min(len(tokens), len(prep["slot_indices"])) + n = min(len(tokens), int(prep.get("num_tokens", 0))) t_fin0 = time.perf_counter() try: fin_resp = await client.post( @@ -2250,7 +2248,6 @@ async def _attempt_d_to_p_sync( json={ "session_id": request.session_id, "token_ids": tokens[:n], - "slot_indices": prep["slot_indices"][:n], }, timeout=30.0, ) diff --git a/third_party/sglang/python/sglang/srt/disaggregation/snapshot/controller.py b/third_party/sglang/python/sglang/srt/disaggregation/snapshot/controller.py index 6a93bb7..897dc10 100644 --- a/third_party/sglang/python/sglang/srt/disaggregation/snapshot/controller.py +++ b/third_party/sglang/python/sglang/srt/disaggregation/snapshot/controller.py @@ -1,34 +1,29 @@ -"""SnapshotLinkController — drives D→P RDMA snapshot pushes. +"""SnapshotLinkController — D→P RDMA snapshot pushes with dedicated GPU buffer. -This class owns: - * A dedicated ``mooncake.engine.TransferEngine`` (independent of the PD - pipeline engine). - * Memory-region registrations covering the worker's KV pool layer buffers - (registered once at startup for zero per-snapshot overhead). - * A side-table mapping ``session_id → SnapshotIngestRecord`` for P-side - receivers tracking outstanding ingests until ``snapshot_finalize_ingest`` - is called. +Per `docs/SNAPSHOT_STORE_REFACTOR_ZH.md`, this controller now reserves a +dedicated GPU tensor (``snapshot_buf``) for receiving D→P snapshots, instead +of competing with the worker's ``token_to_kv_pool_allocator`` at +prepare_receive time. The kv_pool alloc is deferred to ``finalize_ingest`` +when the bytes are already in hand — if that alloc fails we drop the +snapshot but RDMA reception itself succeeded. -Lifecycle: - SGLang scheduler instantiates one of these per worker if - ``SGLANG_SNAPSHOT_LINK_ENABLE=1``. The scheduler is responsible for - feeding kv_pool buffer descriptors and the kv_pool_allocator so the - controller can pre-register memory and (on P) allocate slots. +Layout of the snapshot_buf for one session reception (chosen for +mooncake's batch_transfer_sync_write friendliness — every layer maps to +a single contiguous slab): -Direction symmetry: - Both D and P workers run identical controllers. Who sends and who - receives is determined by who calls ``push_session_kv`` vs - ``prepare_receive`` / ``finalize_ingest`` — there's no PREFILL/DECODE - role baked into the snapshot side. + [K_layer_0: num_tokens × stride_k_bytes] + [K_layer_1: num_tokens × stride_k_bytes] + ... + [K_layer_L-1] + [V_layer_0: num_tokens × stride_v_bytes] + ... + [V_layer_L-1] -This is the **vendored** copy living alongside SGLang internals so the -scheduler can import it directly. ``src/agentic_pd_hybrid/snapshot_link.py`` -contains the same primitives for stand-alone smoke testing. +The buffer is split into multiple such slabs via ``SnapshotBufAllocator``. """ from __future__ import annotations -import ctypes import logging import os import threading @@ -44,6 +39,10 @@ SNAPSHOT_LINK_HOST_ENV = "SGLANG_SNAPSHOT_LINK_HOST" SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT" SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE" +# Default snapshot_buf size: 8 GB. Enough for ~1.5 Qwen3-30B 50k-token sessions. +SNAPSHOT_BUF_BYTES_ENV = "SGLANG_SNAPSHOT_LINK_BUF_BYTES" +DEFAULT_SNAPSHOT_BUF_BYTES = 8 * 1024 * 1024 * 1024 + @dataclass class _LayerBufferDesc: @@ -56,13 +55,75 @@ class _LayerBufferDesc: @dataclass class SnapshotIngestRecord: - """P-side: bookkeeping for an outstanding incoming snapshot.""" + """P-side bookkeeping for one in-flight snapshot reception.""" session_id: str - slot_indices: List[int] # kv_pool slots reserved for this ingest + slab_offset: int # offset within snapshot_buf + slab_size: int # total bytes for this slab num_tokens: int + k_layer_offsets: List[int] # absolute byte offsets of K layers in snapshot_buf + v_layer_offsets: List[int] + per_token_k_bytes: int + per_token_v_bytes: int created_at: float = field(default_factory=time.time) +class SnapshotBufAllocator: + """First-fit free-list allocator over a single contiguous byte range. + + Tracks gaps in a sorted list. Merges adjacent free regions on free(). + """ + + def __init__(self, capacity_bytes: int): + self.capacity = capacity_bytes + # Free regions sorted by offset: [(offset, size), ...] + self._free: List[Tuple[int, int]] = [(0, capacity_bytes)] + self._lock = threading.Lock() + self._inflight: dict[int, int] = {} # offset → size for sanity check + + def alloc(self, size: int) -> Optional[int]: + """Return offset of allocated region, or None if no fit available.""" + if size <= 0: + return None + # Page-align allocations to 4 KB for RDMA-friendly alignment. + size = (size + 4095) & ~4095 + with self._lock: + for i, (off, sz) in enumerate(self._free): + if sz >= size: + if sz == size: + self._free.pop(i) + else: + self._free[i] = (off + size, sz - size) + self._inflight[off] = size + return off + return None + + def free(self, offset: int) -> bool: + """Return True if the offset was successfully freed.""" + with self._lock: + size = self._inflight.pop(offset, None) + if size is None: + return False + # Insert sorted and merge adjacents + self._free.append((offset, size)) + self._free.sort() + merged: List[Tuple[int, int]] = [] + for off, sz in self._free: + if merged and merged[-1][0] + merged[-1][1] == off: + merged[-1] = (merged[-1][0], merged[-1][1] + sz) + else: + merged.append((off, sz)) + self._free = merged + return True + + def available_bytes(self) -> int: + with self._lock: + return sum(sz for _, sz in self._free) + + def in_use_bytes(self) -> int: + with self._lock: + return sum(self._inflight.values()) + + def _import_transfer_engine(): try: from mooncake.engine import TransferEngine @@ -75,7 +136,13 @@ def _import_transfer_engine(): class SnapshotLinkController: - """Mooncake engine + per-layer mem registrations + ingest bookkeeping.""" + """Owns mooncake engine + kv_pool registrations + snapshot_buf + records. + + D-side use: push session KV via ``push_session_to_snapshot_buf``. + P-side use: ``prepare_receive`` → caller pushes via RDMA → + ``ingest_snapshot_into_kvpool`` (does GPU memcpy + + radix insert) → ``finalize_record`` (frees the slab). + """ def __init__( self, @@ -84,24 +151,16 @@ class SnapshotLinkController: ib_device: Optional[str], kv_pool_layer_buffers: List[Tuple[int, int, int, bool]], token_to_kv_pool_allocator, + tree_cache=None, protocol: Optional[str] = None, + snapshot_buf_bytes: Optional[int] = None, ): - """ - Parameters - ---------- - host, port : where this worker binds its snapshot engine. - ib_device : preferred IB HCA (e.g. "mlx5_60"). - kv_pool_layer_buffers : list of ``(base_ptr, bytes_per_token, - capacity_bytes, is_k)`` tuples. Order should be K-layers - first then V-layers (matches MHATokenToKVPool layout). - token_to_kv_pool_allocator : the worker's allocator. We use - ``.alloc(N)`` on the P side to reserve receive slots. - """ TransferEngine = _import_transfer_engine() self.host = host self.port = port self.ib_device = ib_device self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.tree_cache = tree_cache self.layer_buffers: List[_LayerBufferDesc] = [ _LayerBufferDesc( base_ptr=base, bytes_per_token=btok, @@ -121,38 +180,67 @@ class SnapshotLinkController: ) self._session_id = f"{host}:{self.engine.get_rpc_port()}" - # Register all layer buffers up-front (one-shot) + # Register existing kv_pool layer buffers (needed for D-side send and + # for P-side ingest copy source = snapshot_buf, destination = kv_pool) ptrs = [d.base_ptr for d in self.layer_buffers] lens = [d.capacity_bytes for d in self.layer_buffers] try: reg_ret = self.engine.batch_register_memory(ptrs, lens) except Exception: - reg_ret = -1 - # Fall back to individual register_memory calls reg_ret = 0 for ptr, length in zip(ptrs, lens): r = self.engine.register_memory(ptr, length) if r != 0: - logger.warning( - "SnapshotLinkController register_memory(%s, %d) returned %d", - hex(ptr), length, r, - ) reg_ret = r if reg_ret != 0: logger.warning( - "SnapshotLinkController batch_register_memory returned %d " - "(continuing — individual registrations may have succeeded)", - reg_ret, + "SnapshotLinkController kv_pool batch_register returned %d", reg_ret ) + # Allocate + register the dedicated snapshot reception buffer (P-side) + # This decouples reception from kv_pool, avoiding the alloc-failed + # death loop that killed E4-v4/v5. + import torch + + if snapshot_buf_bytes is None: + snapshot_buf_bytes = int( + os.environ.get(SNAPSHOT_BUF_BYTES_ENV, DEFAULT_SNAPSHOT_BUF_BYTES) + ) + device = self._allocator_device() + try: + self.snapshot_buf = torch.zeros( + snapshot_buf_bytes, dtype=torch.uint8, device=device, + ) + except RuntimeError as e: + logger.warning( + "Could not allocate snapshot_buf of %d bytes on %s: %s. " + "Falling back to 1 GB.", snapshot_buf_bytes, device, e, + ) + snapshot_buf_bytes = 1024 * 1024 * 1024 + self.snapshot_buf = torch.zeros( + snapshot_buf_bytes, dtype=torch.uint8, device=device, + ) + self._snapshot_buf_bytes = snapshot_buf_bytes + self._snapshot_buf_ptr = self.snapshot_buf.data_ptr() + ret = self.engine.register_memory(self._snapshot_buf_ptr, snapshot_buf_bytes) + if ret != 0: + logger.warning( + "SnapshotLinkController snapshot_buf register_memory(%s, %d) ret=%d", + hex(self._snapshot_buf_ptr), snapshot_buf_bytes, ret, + ) + self.snapshot_buf_alloc = SnapshotBufAllocator(snapshot_buf_bytes) + # Receive-side bookkeeping self._ingest_records: dict[str, SnapshotIngestRecord] = {} + self._records_by_handle: dict[int, SnapshotIngestRecord] = {} + self._next_handle = 1 self._lock = threading.Lock() logger.info( - "SnapshotLinkController up at %s (snapshot_session_id=%s, " - "%d layer buffers registered, ib=%s)", - listen, self._session_id, len(self.layer_buffers), ib_device, + "SnapshotLinkController up at %s (sid=%s, %d kv layer bufs, " + "snapshot_buf=%.1f GB on %s)", + listen, self._session_id, len(self.layer_buffers), + snapshot_buf_bytes / 1e9, device, ) # ----- accessors ---------------------------------------------------- @@ -161,9 +249,16 @@ class SnapshotLinkController: def snapshot_session_id(self) -> str: return self._session_id + @property + def snapshot_buf_ptr(self) -> int: + return self._snapshot_buf_ptr + + @property + def snapshot_buf_bytes(self) -> int: + return self._snapshot_buf_bytes + @property def layer_num(self) -> int: - """Number of K layers (= number of V layers = total / 2).""" return len(self.layer_buffers) // 2 def get_k_base_ptrs(self) -> List[int]: @@ -184,66 +279,6 @@ class SnapshotLinkController: return d.bytes_per_token return 0 - # ----- P-side: prepare to receive ---------------------------------- - - def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]: - """Allocate ``num_tokens`` slots in kv_pool for an incoming snapshot. - - Returns a record with the slot indices, or ``None`` if no capacity. - Mooncake registration is already in place (whole-buffer registration - at startup), so no per-snapshot register is needed. - """ - try: - indices_tensor = self.token_to_kv_pool_allocator.alloc(num_tokens) - except Exception as e: - logger.exception("SnapshotLinkController.prepare_receive alloc failed: %s", e) - return None - if indices_tensor is None: - return None - try: - slot_indices = [int(x) for x in indices_tensor.tolist()] - except Exception: - # If allocator returns a python list directly - slot_indices = list(map(int, indices_tensor)) - record = SnapshotIngestRecord( - session_id=session_id, - slot_indices=slot_indices, - num_tokens=num_tokens, - ) - with self._lock: - old = self._ingest_records.pop(session_id, None) - if old is not None: - # Best-effort: free old slots if we're overwriting - try: - self._free_slots(old.slot_indices) - except Exception: - pass - self._ingest_records[session_id] = record - return record - - def take_record(self, session_id: str) -> Optional[SnapshotIngestRecord]: - with self._lock: - return self._ingest_records.pop(session_id, None) - - def discard_record(self, session_id: str) -> None: - """Drop a pending ingest (e.g. on timeout or D-side failure).""" - with self._lock: - rec = self._ingest_records.pop(session_id, None) - if rec is not None: - try: - self._free_slots(rec.slot_indices) - except Exception: - pass - - def _free_slots(self, slot_indices: List[int]) -> None: - import torch - t = torch.tensor(slot_indices, dtype=torch.int64, - device=self._allocator_device()) - try: - self.token_to_kv_pool_allocator.free(t) - except Exception as e: - logger.warning("SnapshotLinkController._free_slots failed: %s", e) - def _allocator_device(self): # Best-effort: pull device from one of the buffer tensors via the allocator try: @@ -251,91 +286,291 @@ class SnapshotLinkController: except AttributeError: return "cuda" - # ----- D-side: push session KV -------------------------------------- + # ----- P-side: prepare to receive ---------------------------------- - def push_session_kv( + def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]: + """Carve a slab out of snapshot_buf large enough for num_tokens of K+V. + + Returns the record describing the slab layout, or None if snapshot_buf + is full. This does NOT touch kv_pool — alloc happens at ingest time. + """ + if num_tokens <= 0: + return None + stride_k = self.get_stride_k_bytes() + stride_v = self.get_stride_v_bytes() + L = self.layer_num + slab_bytes = L * num_tokens * stride_k + L * num_tokens * stride_v + offset = self.snapshot_buf_alloc.alloc(slab_bytes) + if offset is None: + logger.info( + "prepare_receive: snapshot_buf full (sid=%s n=%d need=%d B available=%d B)", + session_id, num_tokens, slab_bytes, + self.snapshot_buf_alloc.available_bytes(), + ) + return None + # Layout: K0..KL-1, then V0..VL-1 + k_offs = [offset + i * num_tokens * stride_k for i in range(L)] + v_offs = [offset + L * num_tokens * stride_k + i * num_tokens * stride_v + for i in range(L)] + record = SnapshotIngestRecord( + session_id=session_id, + slab_offset=offset, + slab_size=slab_bytes, + num_tokens=num_tokens, + k_layer_offsets=k_offs, + v_layer_offsets=v_offs, + per_token_k_bytes=stride_k, + per_token_v_bytes=stride_v, + ) + with self._lock: + # Evict prior record for the same session (best-effort) + old = self._ingest_records.pop(session_id, None) + if old is not None: + self.snapshot_buf_alloc.free(old.slab_offset) + self._records_by_handle.pop(id(old), None) + self._ingest_records[session_id] = record + self._records_by_handle[id(record)] = record + return record + + def lookup_by_handle(self, handle: int) -> Optional[SnapshotIngestRecord]: + with self._lock: + return self._records_by_handle.get(handle) + + def discard_record(self, session_id: str) -> None: + with self._lock: + rec = self._ingest_records.pop(session_id, None) + if rec is not None: + self.snapshot_buf_alloc.free(rec.slab_offset) + with self._lock: + self._records_by_handle.pop(id(rec), None) + + def total_pending_snapshot_bytes(self) -> int: + with self._lock: + return sum(rec.slab_size for rec in self._ingest_records.values()) + + # ----- P-side: ingest snapshot into kv_pool + radix tree ----------- + + def ingest_snapshot_into_kvpool( + self, + session_id: str, + token_ids: List[int], + ) -> Tuple[bool, str, int]: + """Copy snapshot_buf bytes into kv_pool slots and insert into radix. + + Returns (ok, reason, inserted_prefix_len). + """ + with self._lock: + record = self._ingest_records.pop(session_id, None) + if record is not None: + self._records_by_handle.pop(id(record), None) + if record is None: + return False, "no-pending-ingest", 0 + + try: + n = min(len(token_ids), record.num_tokens) + if n == 0: + self.snapshot_buf_alloc.free(record.slab_offset) + return False, "empty-token-ids", 0 + + # Alloc kv_pool slots NOW that the snapshot bytes are in hand. + try: + indices_tensor = self.token_to_kv_pool_allocator.alloc(n) + except Exception as exc: + self.snapshot_buf_alloc.free(record.slab_offset) + return False, f"kvpool-alloc-threw:{exc!r}", 0 + if indices_tensor is None: + self.snapshot_buf_alloc.free(record.slab_offset) + return False, "kvpool-alloc-failed-at-ingest", 0 + + # GPU→GPU copy from snapshot_buf into kv_pool layer buffers + try: + self._copy_snapshot_to_kvpool(record, indices_tensor) + except Exception as exc: + logger.exception("snapshot→kvpool copy failed: %s", exc) + # Free both allocations + self._free_slot_indices(indices_tensor) + self.snapshot_buf_alloc.free(record.slab_offset) + return False, f"copy-failed:{exc!r}", 0 + + # Insert into radix tree + try: + inserted_prefix_len = self._radix_insert(token_ids[:n], indices_tensor) + except Exception as exc: + logger.exception("radix insert failed: %s", exc) + self._free_slot_indices(indices_tensor) + self.snapshot_buf_alloc.free(record.slab_offset) + return False, f"radix-insert-failed:{exc!r}", 0 + + # Snapshot is now persisted into kv_pool + radix; the slab is no + # longer needed. + self.snapshot_buf_alloc.free(record.slab_offset) + return True, "ok", int(inserted_prefix_len) + except Exception as exc: + # Belt-and-braces cleanup + try: + self.snapshot_buf_alloc.free(record.slab_offset) + except Exception: + pass + return False, f"unexpected:{exc!r}", 0 + + def _copy_snapshot_to_kvpool( + self, + record: SnapshotIngestRecord, + slot_indices_tensor, + ) -> None: + """For each layer L: copy snapshot_buf[K_off[L]..] → k_buffer[L][slots].""" + import torch + + n = record.num_tokens + stride_k = record.per_token_k_bytes + stride_v = record.per_token_v_bytes + # View snapshot_buf as a 1-D byte tensor; slice by offsets. + for L in range(self.layer_num): + # K + k_slab_start = record.k_layer_offsets[L] - record.slab_offset + record.slab_offset + # NOTE: above is equivalent to record.k_layer_offsets[L] but kept for clarity + k_slab_start = record.k_layer_offsets[L] + k_layer_bytes = self.snapshot_buf[ + k_slab_start : k_slab_start + n * stride_k + ].view(n, stride_k) + # Compute destination tensor on kv_pool: dst[slot_indices] = src + # We need access to the actual k_buffer[L] tensor. The controller + # only has the raw ptr — so we materialize a view via from_blob-ish + # trick. Easier: get the tensor from token_to_kv_pool_allocator's kvcache. + kv_cache = self.token_to_kv_pool_allocator.get_kvcache() + k_buf = kv_cache.k_buffer[L] # (max_tokens, head, dim) + # Flatten per-token to bytes + flat = k_buf.view(k_buf.shape[0], -1) + assert flat.shape[1] * flat.element_size() >= stride_k, ( + f"K layer {L} stride mismatch: pool {flat.shape[1] * flat.element_size()} vs snapshot {stride_k}" + ) + # Copy: dst[slot_indices] ← src[:n] + src_reshape = k_layer_bytes.view(n, flat.shape[1] * flat.element_size()) + # Byte-level view of destination rows + dst_view = flat.view(torch.uint8) + dst_view[slot_indices_tensor] = src_reshape + + # V + v_slab_start = record.v_layer_offsets[L] + v_layer_bytes = self.snapshot_buf[ + v_slab_start : v_slab_start + n * stride_v + ] + v_buf = kv_cache.v_buffer[L] + v_flat = v_buf.view(v_buf.shape[0], -1) + src_v = v_layer_bytes.view(n, v_flat.shape[1] * v_flat.element_size()) + v_dst_view = v_flat.view(torch.uint8) + v_dst_view[slot_indices_tensor] = src_v + + def _radix_insert(self, token_ids: List[int], indices_tensor) -> int: + """Insert (token_ids, kv_indices) into the underlying radix tree.""" + from sglang.srt.mem_cache.base_prefix_cache import InsertParams + from sglang.srt.mem_cache.radix_cache import RadixKey + from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache + + inner = self.tree_cache + if isinstance(inner, SessionAwareCache): + inner = inner.inner + if inner is None: + raise RuntimeError("tree_cache not provided to SnapshotLinkController") + radix_key = RadixKey(token_ids, None) + result = inner.insert(InsertParams(key=radix_key, value=indices_tensor)) + return int(getattr(result, "prefix_len", 0)) + + def _free_slot_indices(self, indices_tensor) -> None: + try: + self.token_to_kv_pool_allocator.free(indices_tensor) + except Exception as e: + logger.warning("_free_slot_indices failed: %s", e) + + # ----- D-side: push session KV to a peer's snapshot_buf ------------ + + def push_session_to_snapshot_buf( self, *, target_snapshot_session_id: str, src_slot_indices: List[int], - target_k_base_ptrs: List[int], - target_v_base_ptrs: List[int], - target_slot_indices: List[int], - target_stride_k_bytes: int, - target_stride_v_bytes: int, + target_snapshot_buf_base: int, + target_k_layer_offsets: List[int], + target_v_layer_offsets: List[int], + target_per_token_k_bytes: int, + target_per_token_v_bytes: int, ) -> Tuple[int, int]: - """Push the KV bytes at src_slot_indices to the remote slots. + """Push session KV from local kv_pool into a peer's snapshot_buf slab. - Returns ``(mooncake_return_code, bytes_pushed)``. Caller is - responsible for any post-push handshake (e.g. finalize_ingest RPC). + For each layer: gather src ranges (possibly scattered slot indices) + and write to a contiguous range in the peer's snapshot_buf. + Returns (mooncake_return_code, bytes_pushed). """ + if not src_slot_indices: + return 0, 0 layer_num = self.layer_num k_src_bases = self.get_k_base_ptrs() v_src_bases = self.get_v_base_ptrs() stride_k = self.get_stride_k_bytes() stride_v = self.get_stride_v_bytes() - if len(target_k_base_ptrs) != layer_num or len(target_v_base_ptrs) != layer_num: + if (len(target_k_layer_offsets) != layer_num + or len(target_v_layer_offsets) != layer_num): raise ValueError( - f"target K/V base ptr count {len(target_k_base_ptrs)}/{len(target_v_base_ptrs)} " - f"!= local layer_num {layer_num}" + f"target K/V layer offset count {len(target_k_layer_offsets)}/" + f"{len(target_v_layer_offsets)} != local layer_num {layer_num}" ) - if stride_k != target_stride_k_bytes or stride_v != target_stride_v_bytes: + if (stride_k != target_per_token_k_bytes + or stride_v != target_per_token_v_bytes): raise ValueError( - f"stride mismatch: local k={stride_k}, v={stride_v}; " - f"target k={target_stride_k_bytes}, v={target_stride_v_bytes}" - ) - if len(src_slot_indices) != len(target_slot_indices): - raise ValueError( - f"slot count mismatch: src={len(src_slot_indices)}, " - f"target={len(target_slot_indices)}" + f"stride mismatch: local k={stride_k}/v={stride_v}, " + f"target k={target_per_token_k_bytes}/v={target_per_token_v_bytes}" ) + n = len(src_slot_indices) local_addrs: List[int] = [] remote_addrs: List[int] = [] lengths: List[int] = [] - # Group contiguous runs on the target side to coalesce ops. - # Simple approach: per (layer, K/V) pair, walk src/target index - # tuples; merge runs where both src and target are sequential. - for layer_id in range(layer_num): - for kv_bases, stride, kv_label in ( - (k_src_bases[layer_id], stride_k, "K"), - (v_src_bases[layer_id], stride_v, "V"), - ): - src_base = kv_bases - if kv_label == "K": - tgt_base = target_k_base_ptrs[layer_id] + # Coalesce contiguous src runs. + # Inner-loop helper to walk indices and emit run boundaries. + def _emit_runs(src_base: int, tgt_base: int, stride: int) -> None: + run_src_start = run_tgt_start = run_len = None + for tgt_idx, src in enumerate(src_slot_indices): + if run_src_start is None: + run_src_start, run_tgt_start, run_len = src, tgt_idx, 1 + elif src == run_src_start + run_len: + run_len += 1 else: - tgt_base = target_v_base_ptrs[layer_id] - run_src_start = run_tgt_start = run_len = None - for s, t in zip(src_slot_indices, target_slot_indices): - if run_src_start is None: - run_src_start, run_tgt_start, run_len = s, t, 1 - elif s == run_src_start + run_len and t == run_tgt_start + run_len: - run_len += 1 - else: - local_addrs.append(src_base + run_src_start * stride) - remote_addrs.append(tgt_base + run_tgt_start * stride) - lengths.append(run_len * stride) - run_src_start, run_tgt_start, run_len = s, t, 1 - if run_src_start is not None: local_addrs.append(src_base + run_src_start * stride) remote_addrs.append(tgt_base + run_tgt_start * stride) lengths.append(run_len * stride) + run_src_start, run_tgt_start, run_len = src, tgt_idx, 1 + if run_src_start is not None: + local_addrs.append(src_base + run_src_start * stride) + remote_addrs.append(tgt_base + run_tgt_start * stride) + lengths.append(run_len * stride) + + for L in range(layer_num): + _emit_runs( + k_src_bases[L], + target_snapshot_buf_base + target_k_layer_offsets[L], + stride_k, + ) + _emit_runs( + v_src_bases[L], + target_snapshot_buf_base + target_v_layer_offsets[L], + stride_v, + ) t0 = time.perf_counter() try: ret = self.engine.batch_transfer_sync_write( - target_snapshot_session_id, local_addrs, remote_addrs, lengths + target_snapshot_session_id, local_addrs, remote_addrs, lengths, ) except Exception as e: - logger.exception("SnapshotLinkController.push_session_kv threw: %s", e) + logger.exception( + "SnapshotLinkController.push_session_to_snapshot_buf threw: %s", e + ) return -1, 0 t1 = time.perf_counter() bytes_pushed = sum(lengths) logger.info( - "SnapshotLinkController.push_session_kv → %s: %d ops, %d B, " - "ret=%d, %.2f ms", + "push_session_to_snapshot_buf → %s: %d ops, %d B, ret=%d, %.2f ms", target_snapshot_session_id, len(lengths), bytes_pushed, ret, (t1 - t0) * 1000.0, ) diff --git a/third_party/sglang/python/sglang/srt/managers/io_struct.py b/third_party/sglang/python/sglang/srt/managers/io_struct.py index 5b2c6c0..1d605da 100644 --- a/third_party/sglang/python/sglang/srt/managers/io_struct.py +++ b/third_party/sglang/python/sglang/srt/managers/io_struct.py @@ -1662,38 +1662,38 @@ class SnapshotPrepareReceiveReqInput(BaseReq): @dataclass class SnapshotPrepareReceiveReqOutput(BaseReq): + """P-side response. New schema points D at P's dedicated snapshot_buf.""" + ok: bool reason: Optional[str] = None - # Layout the D side needs to address P's kv_pool slots: - # k_base_ptrs[L] = base device address of layer L's K buffer - # v_base_ptrs[L] = base device address of layer L's V buffer - # slot_indices = the contiguous range P allocated (list[int], length=num_tokens) - # stride_k_bytes = bytes per token K = head_num * head_dim * dtype.itemsize - # stride_v_bytes = bytes per token V (often equals stride_k_bytes) - # P also registers these slot regions with mooncake before returning. - k_base_ptrs: List[int] = field(default_factory=list) - v_base_ptrs: List[int] = field(default_factory=list) - slot_indices: List[int] = field(default_factory=list) + # P's mooncake snapshot session id (host:rpc_port) for D's batch write target + snapshot_session_id: str = "" + # snapshot_buf base pointer + per-layer offsets, replacing the old + # kv_pool slot_indices scheme that competed with P's prefill work and + # always hit alloc-failed. See docs/SNAPSHOT_STORE_REFACTOR_ZH.md. + snapshot_buf_base_ptr: int = 0 + snapshot_buf_capacity_bytes: int = 0 + k_layer_offsets: List[int] = field(default_factory=list) # bytes within snapshot_buf + v_layer_offsets: List[int] = field(default_factory=list) + num_tokens: int = 0 stride_k_bytes: int = 0 stride_v_bytes: int = 0 layer_num: int = 0 - # P's mooncake snapshot session id (host:rpc_port) for D's batch write target - snapshot_session_id: str = "" available_tokens: int = 0 @dataclass class SnapshotDumpReqInput(BaseReq): - """D-side: dump session KV via snapshot_link to a target P endpoint.""" + """D-side: dump session KV via snapshot_link into P's snapshot_buf slab.""" session_id: str - target_snapshot_session_id: str # P's mooncake snapshot session id - target_k_base_ptrs: List[int] = field(default_factory=list) - target_v_base_ptrs: List[int] = field(default_factory=list) - target_slot_indices: List[int] = field(default_factory=list) + target_snapshot_session_id: str + target_snapshot_buf_base: int = 0 + target_k_layer_offsets: List[int] = field(default_factory=list) + target_v_layer_offsets: List[int] = field(default_factory=list) target_stride_k_bytes: int = 0 target_stride_v_bytes: int = 0 - ib_device: Optional[str] = None # for the D-side SnapshotPeer initialization + ib_device: Optional[str] = None @dataclass @@ -1709,11 +1709,10 @@ class SnapshotDumpReqOutput(BaseReq): @dataclass class SnapshotFinalizeIngestReqInput(BaseReq): - """P-side: insert (token_ids, kv_indices) into radix tree after D's push.""" + """P-side: copy snapshot_buf slab into kv_pool + insert into radix tree.""" session_id: str token_ids: List[int] - slot_indices: List[int] @dataclass diff --git a/third_party/sglang/python/sglang/srt/managers/scheduler.py b/third_party/sglang/python/sglang/srt/managers/scheduler.py index b180a1c..477003c 100644 --- a/third_party/sglang/python/sglang/srt/managers/scheduler.py +++ b/third_party/sglang/python/sglang/srt/managers/scheduler.py @@ -902,6 +902,7 @@ class Scheduler( ib_device=ib, kv_pool_layer_buffers=layer_buffers, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + tree_cache=self.tree_cache, ) logger.info( "Snapshot link controller initialized: %s, sid=%s, %d layer bufs", @@ -3750,7 +3751,12 @@ class Scheduler( def snapshot_prepare_receive( self, recv_req: SnapshotPrepareReceiveReqInput ) -> SnapshotPrepareReceiveReqOutput: - """P-side: alloc kv_pool slots, return slot/buffer layout for D's batch_push.""" + """P-side: carve snapshot_buf slab + return its layout to caller. + + Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: this no longer + touches the kv_pool allocator. The slab is in a dedicated + snapshot_buf so prepare can never lose to P's prefill work. + """ ctrl = self.snapshot_link_controller if ctrl is None: return SnapshotPrepareReceiveReqOutput( @@ -3765,25 +3771,27 @@ class Scheduler( record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens) if record is None: return SnapshotPrepareReceiveReqOutput( - ok=False, reason="alloc-failed", + ok=False, reason="snapshot-buf-full", available_tokens=available, ) return SnapshotPrepareReceiveReqOutput( ok=True, - k_base_ptrs=ctrl.get_k_base_ptrs(), - v_base_ptrs=ctrl.get_v_base_ptrs(), - slot_indices=record.slot_indices, - stride_k_bytes=ctrl.get_stride_k_bytes(), - stride_v_bytes=ctrl.get_stride_v_bytes(), - layer_num=ctrl.layer_num, snapshot_session_id=ctrl.snapshot_session_id, + snapshot_buf_base_ptr=ctrl.snapshot_buf_ptr, + snapshot_buf_capacity_bytes=ctrl.snapshot_buf_bytes, + k_layer_offsets=record.k_layer_offsets, + v_layer_offsets=record.v_layer_offsets, + num_tokens=record.num_tokens, + stride_k_bytes=record.per_token_k_bytes, + stride_v_bytes=record.per_token_v_bytes, + layer_num=ctrl.layer_num, available_tokens=available, ) def snapshot_dump( self, recv_req: SnapshotDumpReqInput ) -> SnapshotDumpReqOutput: - """D-side: gather session KV from kv_pool, RDMA-write to remote slots.""" + """D-side: gather session KV from kv_pool, RDMA-write into P's snapshot_buf.""" ctrl = self.snapshot_link_controller if ctrl is None: return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled") @@ -3795,110 +3803,60 @@ class Scheduler( kv_committed_len = int(slot.kv_committed_len) if kv_committed_len == 0: return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len") - # Read kv_indices for the session's prefix try: kv_idx_tensor = self.req_to_token_pool.req_to_token[ slot.req_pool_idx, :kv_committed_len ] src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()] except Exception as e: - logger.exception("snapshot_dump: failed to read kv_indices: %s", e) - return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed: {e!r}") - - # Truncate to the count P prepared for (must match) - target_n = len(recv_req.target_slot_indices) - if target_n > kv_committed_len: - return SnapshotDumpReqOutput( - ok=False, - reason=f"target-larger-than-source({target_n}>{kv_committed_len})", - ) - src_slot_indices = src_slot_indices[:target_n] + return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed:{e!r}") try: - ret, bytes_pushed = ctrl.push_session_kv( + ret, bytes_pushed = ctrl.push_session_to_snapshot_buf( target_snapshot_session_id=recv_req.target_snapshot_session_id, src_slot_indices=src_slot_indices, - target_k_base_ptrs=recv_req.target_k_base_ptrs, - target_v_base_ptrs=recv_req.target_v_base_ptrs, - target_slot_indices=recv_req.target_slot_indices[:target_n], - target_stride_k_bytes=recv_req.target_stride_k_bytes, - target_stride_v_bytes=recv_req.target_stride_v_bytes, + target_snapshot_buf_base=recv_req.target_snapshot_buf_base, + target_k_layer_offsets=recv_req.target_k_layer_offsets, + target_v_layer_offsets=recv_req.target_v_layer_offsets, + target_per_token_k_bytes=recv_req.target_stride_k_bytes, + target_per_token_v_bytes=recv_req.target_stride_v_bytes, ) except Exception as e: - logger.exception("snapshot_dump: push_session_kv threw: %s", e) - return SnapshotDumpReqOutput(ok=False, reason=f"push-failed: {e!r}") + return SnapshotDumpReqOutput(ok=False, reason=f"push-failed:{e!r}") if ret != 0: return SnapshotDumpReqOutput( - ok=False, - reason=f"mooncake-batch-write-ret={ret}", + ok=False, reason=f"mooncake-batch-write-ret={ret}", bytes_pushed=int(bytes_pushed), - kv_committed_len=int(kv_committed_len), + kv_committed_len=kv_committed_len, ) return SnapshotDumpReqOutput( ok=True, bytes_pushed=int(bytes_pushed), - kv_committed_len=int(kv_committed_len), - token_ids=[], # caller already has token_ids + kv_committed_len=kv_committed_len, + token_ids=[], ) def snapshot_finalize_ingest( self, recv_req: SnapshotFinalizeIngestReqInput ) -> SnapshotFinalizeIngestReqOutput: - """P-side: insert (token_ids, slot_indices) into radix tree.""" + """P-side: copy snapshot_buf slab into kv_pool + insert into radix tree. + + Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: kv_pool alloc + happens HERE (deferred from prepare_receive), so we never block + D's RDMA write on kv_pool contention. + """ ctrl = self.snapshot_link_controller if ctrl is None: return SnapshotFinalizeIngestReqOutput( ok=False, reason="snapshot-link-disabled", ) - record = ctrl.take_record(recv_req.session_id) - if record is None: - return SnapshotFinalizeIngestReqOutput( - ok=False, reason="no-pending-ingest", - ) - # Sanity: the slot indices we're about to insert should match the ones we reserved. - if list(recv_req.slot_indices) != record.slot_indices: - # The caller passed back the slot indices we returned in prepare; if they - # don't match, something's gone wrong. Free reserved slots and bail. - try: - ctrl._free_slots(record.slot_indices) - except Exception: - pass - return SnapshotFinalizeIngestReqOutput( - ok=False, - reason="slot-indices-mismatch", - ) - n_tokens = min(len(recv_req.token_ids), len(record.slot_indices)) - if n_tokens == 0: - ctrl._free_slots(record.slot_indices) - return SnapshotFinalizeIngestReqOutput(ok=False, reason="empty-token-ids") - try: - import torch - from sglang.srt.mem_cache.base_prefix_cache import InsertParams - from sglang.srt.mem_cache.radix_cache import RadixKey - kv_indices = torch.tensor( - record.slot_indices[:n_tokens], - dtype=torch.int64, - device=self.tree_cache.token_to_kv_pool_allocator.device, - ) - radix_key = RadixKey(recv_req.token_ids[:n_tokens], None) - inner = ( - self.tree_cache.inner - if isinstance(self.tree_cache, SessionAwareCache) - else self.tree_cache - ) - result = inner.insert(InsertParams(key=radix_key, value=kv_indices)) - inserted = int(result.prefix_len) - except Exception as e: - logger.exception("snapshot_finalize_ingest: radix insert failed: %s", e) - try: - ctrl._free_slots(record.slot_indices) - except Exception: - pass - return SnapshotFinalizeIngestReqOutput( - ok=False, reason=f"radix-insert-failed: {e!r}", - ) + ok, reason, inserted_prefix_len = ctrl.ingest_snapshot_into_kvpool( + session_id=recv_req.session_id, + token_ids=list(recv_req.token_ids), + ) return SnapshotFinalizeIngestReqOutput( - ok=True, inserted_prefix_len=inserted, + ok=bool(ok), reason=reason if not ok else None, + inserted_prefix_len=int(inserted_prefix_len), ) def _compute_backpressure_pause_hint( diff --git a/third_party/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/third_party/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index e6d487b..c5f3f1f 100644 --- a/third_party/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/third_party/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -181,27 +181,18 @@ class SchedulerRuntimeCheckerMixin: return memory_leak, token_msg def _check_radix_cache_memory(self: Scheduler): + # NB: as of SnapshotStore refactor (see docs/SNAPSHOT_STORE_REFACTOR_ZH.md) + # prepare_receive no longer touches kv_pool — slots are alloc'd from + # a dedicated snapshot_buf. So no snapshot_reserved accounting needed. _, _, available_size, evictable_size = self._get_token_info() protected_size = self.tree_cache.protected_size() session_held = self._session_held_tokens() - # Snapshot link prepare_receive reserves slots that aren't yet visible - # to radix / session bookkeeping until finalize_ingest. Count them so - # the leak check doesn't fire while a snapshot ingest is in-flight. - snapshot_reserved = 0 - ctrl = getattr(self, "snapshot_link_controller", None) - if ctrl is not None: - try: - snapshot_reserved = sum( - len(rec.slot_indices) for rec in ctrl._ingest_records.values() - ) - except Exception: - snapshot_reserved = 0 memory_leak = (available_size + evictable_size) != ( - self.max_total_num_tokens - protected_size - session_held - snapshot_reserved + self.max_total_num_tokens - protected_size - session_held ) token_msg = ( f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, " - f"{protected_size=}, {session_held=}, {snapshot_reserved=}\n" + f"{protected_size=}, {session_held=}\n" ) return memory_leak, token_msg