refactor(snapshot): dedicated GPU snapshot_buf replaces kv_pool alloc

Implements the design in docs/SNAPSHOT_STORE_REFACTOR_ZH.md to fix
the alloc-failed death loop that killed D→P in E4-v4/v5 (167 sync
attempts, 0 OK because P's kv_pool was busy with its own prefill).

Mechanism change:
  OLD prepare_receive: token_to_kv_pool_allocator.alloc(N) — 90%+ failure
  NEW prepare_receive: SnapshotBufAllocator.alloc(slab_bytes) carves a
                       range from an 8 GB GPU buffer dedicated to
                       snapshot reception, decoupled from kv_pool

  OLD finalize_ingest: just radix.insert with pre-alloc'd slots
  NEW finalize_ingest: kv_pool.alloc NOW + GPU memcpy snapshot_buf →
                       k_buffer/v_buffer + radix.insert

Wire schema changed (clean break, no back-compat):
  PrepareReceiveReqOutput  swaps k/v_base_ptrs + slot_indices  for
                           snapshot_buf_base_ptr + k/v_layer_offsets +
                           num_tokens
  DumpReqInput             swaps target_k/v_base_ptrs + target_slot_indices
                           for target_snapshot_buf_base +
                           target_k/v_layer_offsets
  FinalizeIngestReqInput   drops slot_indices (P resolves at ingest)

Controller adds:
  SnapshotBufAllocator: first-fit free-list with 4 KB alignment
  ingest_snapshot_into_kvpool: GPU→GPU copy + radix insert

Configurable buffer size via SGLANG_SNAPSHOT_LINK_BUF_BYTES env
(default 8 GB, scales down to 1 GB if alloc fails).

Removed runtime leak-check accommodation since prepare_receive no
longer touches kv_pool.

Total: ~365 LOC including alloc helper; smoke-test verification next.
This commit is contained in:
Claude Code Agent
2026-05-13 14:18:23 +08:00
parent 6be5f9b57e
commit 2dfe22ab20
5 changed files with 465 additions and 285 deletions

View File

@@ -2187,9 +2187,9 @@ async def _attempt_d_to_p_sync(
json={ json={
"session_id": request.session_id, "session_id": request.session_id,
"target_snapshot_session_id": prep["snapshot_session_id"], "target_snapshot_session_id": prep["snapshot_session_id"],
"target_k_base_ptrs": prep["k_base_ptrs"], "target_snapshot_buf_base": prep["snapshot_buf_base_ptr"],
"target_v_base_ptrs": prep["v_base_ptrs"], "target_k_layer_offsets": prep["k_layer_offsets"],
"target_slot_indices": prep["slot_indices"], "target_v_layer_offsets": prep["v_layer_offsets"],
"target_stride_k_bytes": prep["stride_k_bytes"], "target_stride_k_bytes": prep["stride_k_bytes"],
"target_stride_v_bytes": prep["stride_v_bytes"], "target_stride_v_bytes": prep["stride_v_bytes"],
}, },
@@ -2220,15 +2220,13 @@ async def _attempt_d_to_p_sync(
# for the first N — use that as best-available approximation. # for the first N — use that as best-available approximation.
tokens = list(getattr(request, "input_token_ids", []) or []) tokens = list(getattr(request, "input_token_ids", []) or [])
if not tokens: if not tokens:
# No token_ids available — can't insert into radix. P will fall back # No token_ids can't insert into radix; tell P to free the slab.
# to normal prefill but will have wasted slots. Discard.
try: try:
await client.post( await client.post(
f"{prefill_url}/_snapshot/finalize_ingest", f"{prefill_url}/_snapshot/finalize_ingest",
json={ json={
"session_id": request.session_id, "session_id": request.session_id,
"token_ids": [], "token_ids": [],
"slot_indices": prep["slot_indices"],
}, },
timeout=15.0, timeout=15.0,
) )
@@ -2242,7 +2240,7 @@ async def _attempt_d_to_p_sync(
) )
return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)} return {"status": "no-tokens-discard", "bytes_pushed": dump.get("bytes_pushed", 0)}
n = min(len(tokens), len(prep["slot_indices"])) n = min(len(tokens), int(prep.get("num_tokens", 0)))
t_fin0 = time.perf_counter() t_fin0 = time.perf_counter()
try: try:
fin_resp = await client.post( fin_resp = await client.post(
@@ -2250,7 +2248,6 @@ async def _attempt_d_to_p_sync(
json={ json={
"session_id": request.session_id, "session_id": request.session_id,
"token_ids": tokens[:n], "token_ids": tokens[:n],
"slot_indices": prep["slot_indices"][:n],
}, },
timeout=30.0, timeout=30.0,
) )

View File

@@ -1,34 +1,29 @@
"""SnapshotLinkController — drives D→P RDMA snapshot pushes. """SnapshotLinkController — D→P RDMA snapshot pushes with dedicated GPU buffer.
This class owns: Per `docs/SNAPSHOT_STORE_REFACTOR_ZH.md`, this controller now reserves a
* A dedicated ``mooncake.engine.TransferEngine`` (independent of the PD dedicated GPU tensor (``snapshot_buf``) for receiving D→P snapshots, instead
pipeline engine). of competing with the worker's ``token_to_kv_pool_allocator`` at
* Memory-region registrations covering the worker's KV pool layer buffers prepare_receive time. The kv_pool alloc is deferred to ``finalize_ingest``
(registered once at startup for zero per-snapshot overhead). when the bytes are already in hand — if that alloc fails we drop the
* A side-table mapping ``session_id → SnapshotIngestRecord`` for P-side snapshot but RDMA reception itself succeeded.
receivers tracking outstanding ingests until ``snapshot_finalize_ingest``
is called.
Lifecycle: Layout of the snapshot_buf for one session reception (chosen for
SGLang scheduler instantiates one of these per worker if mooncake's batch_transfer_sync_write friendliness — every layer maps to
``SGLANG_SNAPSHOT_LINK_ENABLE=1``. The scheduler is responsible for a single contiguous slab):
feeding kv_pool buffer descriptors and the kv_pool_allocator so the
controller can pre-register memory and (on P) allocate slots.
Direction symmetry: [K_layer_0: num_tokens × stride_k_bytes]
Both D and P workers run identical controllers. Who sends and who [K_layer_1: num_tokens × stride_k_bytes]
receives is determined by who calls ``push_session_kv`` vs ...
``prepare_receive`` / ``finalize_ingest`` — there's no PREFILL/DECODE [K_layer_L-1]
role baked into the snapshot side. [V_layer_0: num_tokens × stride_v_bytes]
...
[V_layer_L-1]
This is the **vendored** copy living alongside SGLang internals so the The buffer is split into multiple such slabs via ``SnapshotBufAllocator``.
scheduler can import it directly. ``src/agentic_pd_hybrid/snapshot_link.py``
contains the same primitives for stand-alone smoke testing.
""" """
from __future__ import annotations from __future__ import annotations
import ctypes
import logging import logging
import os import os
import threading import threading
@@ -44,6 +39,10 @@ SNAPSHOT_LINK_HOST_ENV = "SGLANG_SNAPSHOT_LINK_HOST"
SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT" SNAPSHOT_LINK_PORT_ENV = "SGLANG_SNAPSHOT_LINK_PORT"
SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE" SNAPSHOT_LINK_IB_DEVICE_ENV = "SGLANG_SNAPSHOT_LINK_IB_DEVICE"
# Default snapshot_buf size: 8 GB. Enough for ~1.5 Qwen3-30B 50k-token sessions.
SNAPSHOT_BUF_BYTES_ENV = "SGLANG_SNAPSHOT_LINK_BUF_BYTES"
DEFAULT_SNAPSHOT_BUF_BYTES = 8 * 1024 * 1024 * 1024
@dataclass @dataclass
class _LayerBufferDesc: class _LayerBufferDesc:
@@ -56,13 +55,75 @@ class _LayerBufferDesc:
@dataclass @dataclass
class SnapshotIngestRecord: class SnapshotIngestRecord:
"""P-side: bookkeeping for an outstanding incoming snapshot.""" """P-side bookkeeping for one in-flight snapshot reception."""
session_id: str session_id: str
slot_indices: List[int] # kv_pool slots reserved for this ingest slab_offset: int # offset within snapshot_buf
slab_size: int # total bytes for this slab
num_tokens: int num_tokens: int
k_layer_offsets: List[int] # absolute byte offsets of K layers in snapshot_buf
v_layer_offsets: List[int]
per_token_k_bytes: int
per_token_v_bytes: int
created_at: float = field(default_factory=time.time) created_at: float = field(default_factory=time.time)
class SnapshotBufAllocator:
"""First-fit free-list allocator over a single contiguous byte range.
Tracks gaps in a sorted list. Merges adjacent free regions on free().
"""
def __init__(self, capacity_bytes: int):
self.capacity = capacity_bytes
# Free regions sorted by offset: [(offset, size), ...]
self._free: List[Tuple[int, int]] = [(0, capacity_bytes)]
self._lock = threading.Lock()
self._inflight: dict[int, int] = {} # offset → size for sanity check
def alloc(self, size: int) -> Optional[int]:
"""Return offset of allocated region, or None if no fit available."""
if size <= 0:
return None
# Page-align allocations to 4 KB for RDMA-friendly alignment.
size = (size + 4095) & ~4095
with self._lock:
for i, (off, sz) in enumerate(self._free):
if sz >= size:
if sz == size:
self._free.pop(i)
else:
self._free[i] = (off + size, sz - size)
self._inflight[off] = size
return off
return None
def free(self, offset: int) -> bool:
"""Return True if the offset was successfully freed."""
with self._lock:
size = self._inflight.pop(offset, None)
if size is None:
return False
# Insert sorted and merge adjacents
self._free.append((offset, size))
self._free.sort()
merged: List[Tuple[int, int]] = []
for off, sz in self._free:
if merged and merged[-1][0] + merged[-1][1] == off:
merged[-1] = (merged[-1][0], merged[-1][1] + sz)
else:
merged.append((off, sz))
self._free = merged
return True
def available_bytes(self) -> int:
with self._lock:
return sum(sz for _, sz in self._free)
def in_use_bytes(self) -> int:
with self._lock:
return sum(self._inflight.values())
def _import_transfer_engine(): def _import_transfer_engine():
try: try:
from mooncake.engine import TransferEngine from mooncake.engine import TransferEngine
@@ -75,7 +136,13 @@ def _import_transfer_engine():
class SnapshotLinkController: class SnapshotLinkController:
"""Mooncake engine + per-layer mem registrations + ingest bookkeeping.""" """Owns mooncake engine + kv_pool registrations + snapshot_buf + records.
D-side use: push session KV via ``push_session_to_snapshot_buf``.
P-side use: ``prepare_receive`` → caller pushes via RDMA →
``ingest_snapshot_into_kvpool`` (does GPU memcpy +
radix insert) → ``finalize_record`` (frees the slab).
"""
def __init__( def __init__(
self, self,
@@ -84,24 +151,16 @@ class SnapshotLinkController:
ib_device: Optional[str], ib_device: Optional[str],
kv_pool_layer_buffers: List[Tuple[int, int, int, bool]], kv_pool_layer_buffers: List[Tuple[int, int, int, bool]],
token_to_kv_pool_allocator, token_to_kv_pool_allocator,
tree_cache=None,
protocol: Optional[str] = None, protocol: Optional[str] = None,
snapshot_buf_bytes: Optional[int] = None,
): ):
"""
Parameters
----------
host, port : where this worker binds its snapshot engine.
ib_device : preferred IB HCA (e.g. "mlx5_60").
kv_pool_layer_buffers : list of ``(base_ptr, bytes_per_token,
capacity_bytes, is_k)`` tuples. Order should be K-layers
first then V-layers (matches MHATokenToKVPool layout).
token_to_kv_pool_allocator : the worker's allocator. We use
``.alloc(N)`` on the P side to reserve receive slots.
"""
TransferEngine = _import_transfer_engine() TransferEngine = _import_transfer_engine()
self.host = host self.host = host
self.port = port self.port = port
self.ib_device = ib_device self.ib_device = ib_device
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.tree_cache = tree_cache
self.layer_buffers: List[_LayerBufferDesc] = [ self.layer_buffers: List[_LayerBufferDesc] = [
_LayerBufferDesc( _LayerBufferDesc(
base_ptr=base, bytes_per_token=btok, base_ptr=base, bytes_per_token=btok,
@@ -121,38 +180,67 @@ class SnapshotLinkController:
) )
self._session_id = f"{host}:{self.engine.get_rpc_port()}" self._session_id = f"{host}:{self.engine.get_rpc_port()}"
# Register all layer buffers up-front (one-shot) # Register existing kv_pool layer buffers (needed for D-side send and
# for P-side ingest copy source = snapshot_buf, destination = kv_pool)
ptrs = [d.base_ptr for d in self.layer_buffers] ptrs = [d.base_ptr for d in self.layer_buffers]
lens = [d.capacity_bytes for d in self.layer_buffers] lens = [d.capacity_bytes for d in self.layer_buffers]
try: try:
reg_ret = self.engine.batch_register_memory(ptrs, lens) reg_ret = self.engine.batch_register_memory(ptrs, lens)
except Exception: except Exception:
reg_ret = -1
# Fall back to individual register_memory calls
reg_ret = 0 reg_ret = 0
for ptr, length in zip(ptrs, lens): for ptr, length in zip(ptrs, lens):
r = self.engine.register_memory(ptr, length) r = self.engine.register_memory(ptr, length)
if r != 0: if r != 0:
logger.warning(
"SnapshotLinkController register_memory(%s, %d) returned %d",
hex(ptr), length, r,
)
reg_ret = r reg_ret = r
if reg_ret != 0: if reg_ret != 0:
logger.warning( logger.warning(
"SnapshotLinkController batch_register_memory returned %d " "SnapshotLinkController kv_pool batch_register returned %d", reg_ret
"(continuing — individual registrations may have succeeded)",
reg_ret,
) )
# Allocate + register the dedicated snapshot reception buffer (P-side)
# This decouples reception from kv_pool, avoiding the alloc-failed
# death loop that killed E4-v4/v5.
import torch
if snapshot_buf_bytes is None:
snapshot_buf_bytes = int(
os.environ.get(SNAPSHOT_BUF_BYTES_ENV, DEFAULT_SNAPSHOT_BUF_BYTES)
)
device = self._allocator_device()
try:
self.snapshot_buf = torch.zeros(
snapshot_buf_bytes, dtype=torch.uint8, device=device,
)
except RuntimeError as e:
logger.warning(
"Could not allocate snapshot_buf of %d bytes on %s: %s. "
"Falling back to 1 GB.", snapshot_buf_bytes, device, e,
)
snapshot_buf_bytes = 1024 * 1024 * 1024
self.snapshot_buf = torch.zeros(
snapshot_buf_bytes, dtype=torch.uint8, device=device,
)
self._snapshot_buf_bytes = snapshot_buf_bytes
self._snapshot_buf_ptr = self.snapshot_buf.data_ptr()
ret = self.engine.register_memory(self._snapshot_buf_ptr, snapshot_buf_bytes)
if ret != 0:
logger.warning(
"SnapshotLinkController snapshot_buf register_memory(%s, %d) ret=%d",
hex(self._snapshot_buf_ptr), snapshot_buf_bytes, ret,
)
self.snapshot_buf_alloc = SnapshotBufAllocator(snapshot_buf_bytes)
# Receive-side bookkeeping # Receive-side bookkeeping
self._ingest_records: dict[str, SnapshotIngestRecord] = {} self._ingest_records: dict[str, SnapshotIngestRecord] = {}
self._records_by_handle: dict[int, SnapshotIngestRecord] = {}
self._next_handle = 1
self._lock = threading.Lock() self._lock = threading.Lock()
logger.info( logger.info(
"SnapshotLinkController up at %s (snapshot_session_id=%s, " "SnapshotLinkController up at %s (sid=%s, %d kv layer bufs, "
"%d layer buffers registered, ib=%s)", "snapshot_buf=%.1f GB on %s)",
listen, self._session_id, len(self.layer_buffers), ib_device, listen, self._session_id, len(self.layer_buffers),
snapshot_buf_bytes / 1e9, device,
) )
# ----- accessors ---------------------------------------------------- # ----- accessors ----------------------------------------------------
@@ -161,9 +249,16 @@ class SnapshotLinkController:
def snapshot_session_id(self) -> str: def snapshot_session_id(self) -> str:
return self._session_id return self._session_id
@property
def snapshot_buf_ptr(self) -> int:
return self._snapshot_buf_ptr
@property
def snapshot_buf_bytes(self) -> int:
return self._snapshot_buf_bytes
@property @property
def layer_num(self) -> int: def layer_num(self) -> int:
"""Number of K layers (= number of V layers = total / 2)."""
return len(self.layer_buffers) // 2 return len(self.layer_buffers) // 2
def get_k_base_ptrs(self) -> List[int]: def get_k_base_ptrs(self) -> List[int]:
@@ -184,66 +279,6 @@ class SnapshotLinkController:
return d.bytes_per_token return d.bytes_per_token
return 0 return 0
# ----- P-side: prepare to receive ----------------------------------
def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
"""Allocate ``num_tokens`` slots in kv_pool for an incoming snapshot.
Returns a record with the slot indices, or ``None`` if no capacity.
Mooncake registration is already in place (whole-buffer registration
at startup), so no per-snapshot register is needed.
"""
try:
indices_tensor = self.token_to_kv_pool_allocator.alloc(num_tokens)
except Exception as e:
logger.exception("SnapshotLinkController.prepare_receive alloc failed: %s", e)
return None
if indices_tensor is None:
return None
try:
slot_indices = [int(x) for x in indices_tensor.tolist()]
except Exception:
# If allocator returns a python list directly
slot_indices = list(map(int, indices_tensor))
record = SnapshotIngestRecord(
session_id=session_id,
slot_indices=slot_indices,
num_tokens=num_tokens,
)
with self._lock:
old = self._ingest_records.pop(session_id, None)
if old is not None:
# Best-effort: free old slots if we're overwriting
try:
self._free_slots(old.slot_indices)
except Exception:
pass
self._ingest_records[session_id] = record
return record
def take_record(self, session_id: str) -> Optional[SnapshotIngestRecord]:
with self._lock:
return self._ingest_records.pop(session_id, None)
def discard_record(self, session_id: str) -> None:
"""Drop a pending ingest (e.g. on timeout or D-side failure)."""
with self._lock:
rec = self._ingest_records.pop(session_id, None)
if rec is not None:
try:
self._free_slots(rec.slot_indices)
except Exception:
pass
def _free_slots(self, slot_indices: List[int]) -> None:
import torch
t = torch.tensor(slot_indices, dtype=torch.int64,
device=self._allocator_device())
try:
self.token_to_kv_pool_allocator.free(t)
except Exception as e:
logger.warning("SnapshotLinkController._free_slots failed: %s", e)
def _allocator_device(self): def _allocator_device(self):
# Best-effort: pull device from one of the buffer tensors via the allocator # Best-effort: pull device from one of the buffer tensors via the allocator
try: try:
@@ -251,91 +286,291 @@ class SnapshotLinkController:
except AttributeError: except AttributeError:
return "cuda" return "cuda"
# ----- D-side: push session KV -------------------------------------- # ----- P-side: prepare to receive ----------------------------------
def push_session_kv( def prepare_receive(self, session_id: str, num_tokens: int) -> Optional[SnapshotIngestRecord]:
"""Carve a slab out of snapshot_buf large enough for num_tokens of K+V.
Returns the record describing the slab layout, or None if snapshot_buf
is full. This does NOT touch kv_pool — alloc happens at ingest time.
"""
if num_tokens <= 0:
return None
stride_k = self.get_stride_k_bytes()
stride_v = self.get_stride_v_bytes()
L = self.layer_num
slab_bytes = L * num_tokens * stride_k + L * num_tokens * stride_v
offset = self.snapshot_buf_alloc.alloc(slab_bytes)
if offset is None:
logger.info(
"prepare_receive: snapshot_buf full (sid=%s n=%d need=%d B available=%d B)",
session_id, num_tokens, slab_bytes,
self.snapshot_buf_alloc.available_bytes(),
)
return None
# Layout: K0..KL-1, then V0..VL-1
k_offs = [offset + i * num_tokens * stride_k for i in range(L)]
v_offs = [offset + L * num_tokens * stride_k + i * num_tokens * stride_v
for i in range(L)]
record = SnapshotIngestRecord(
session_id=session_id,
slab_offset=offset,
slab_size=slab_bytes,
num_tokens=num_tokens,
k_layer_offsets=k_offs,
v_layer_offsets=v_offs,
per_token_k_bytes=stride_k,
per_token_v_bytes=stride_v,
)
with self._lock:
# Evict prior record for the same session (best-effort)
old = self._ingest_records.pop(session_id, None)
if old is not None:
self.snapshot_buf_alloc.free(old.slab_offset)
self._records_by_handle.pop(id(old), None)
self._ingest_records[session_id] = record
self._records_by_handle[id(record)] = record
return record
def lookup_by_handle(self, handle: int) -> Optional[SnapshotIngestRecord]:
with self._lock:
return self._records_by_handle.get(handle)
def discard_record(self, session_id: str) -> None:
with self._lock:
rec = self._ingest_records.pop(session_id, None)
if rec is not None:
self.snapshot_buf_alloc.free(rec.slab_offset)
with self._lock:
self._records_by_handle.pop(id(rec), None)
def total_pending_snapshot_bytes(self) -> int:
with self._lock:
return sum(rec.slab_size for rec in self._ingest_records.values())
# ----- P-side: ingest snapshot into kv_pool + radix tree -----------
def ingest_snapshot_into_kvpool(
self,
session_id: str,
token_ids: List[int],
) -> Tuple[bool, str, int]:
"""Copy snapshot_buf bytes into kv_pool slots and insert into radix.
Returns (ok, reason, inserted_prefix_len).
"""
with self._lock:
record = self._ingest_records.pop(session_id, None)
if record is not None:
self._records_by_handle.pop(id(record), None)
if record is None:
return False, "no-pending-ingest", 0
try:
n = min(len(token_ids), record.num_tokens)
if n == 0:
self.snapshot_buf_alloc.free(record.slab_offset)
return False, "empty-token-ids", 0
# Alloc kv_pool slots NOW that the snapshot bytes are in hand.
try:
indices_tensor = self.token_to_kv_pool_allocator.alloc(n)
except Exception as exc:
self.snapshot_buf_alloc.free(record.slab_offset)
return False, f"kvpool-alloc-threw:{exc!r}", 0
if indices_tensor is None:
self.snapshot_buf_alloc.free(record.slab_offset)
return False, "kvpool-alloc-failed-at-ingest", 0
# GPU→GPU copy from snapshot_buf into kv_pool layer buffers
try:
self._copy_snapshot_to_kvpool(record, indices_tensor)
except Exception as exc:
logger.exception("snapshot→kvpool copy failed: %s", exc)
# Free both allocations
self._free_slot_indices(indices_tensor)
self.snapshot_buf_alloc.free(record.slab_offset)
return False, f"copy-failed:{exc!r}", 0
# Insert into radix tree
try:
inserted_prefix_len = self._radix_insert(token_ids[:n], indices_tensor)
except Exception as exc:
logger.exception("radix insert failed: %s", exc)
self._free_slot_indices(indices_tensor)
self.snapshot_buf_alloc.free(record.slab_offset)
return False, f"radix-insert-failed:{exc!r}", 0
# Snapshot is now persisted into kv_pool + radix; the slab is no
# longer needed.
self.snapshot_buf_alloc.free(record.slab_offset)
return True, "ok", int(inserted_prefix_len)
except Exception as exc:
# Belt-and-braces cleanup
try:
self.snapshot_buf_alloc.free(record.slab_offset)
except Exception:
pass
return False, f"unexpected:{exc!r}", 0
def _copy_snapshot_to_kvpool(
self,
record: SnapshotIngestRecord,
slot_indices_tensor,
) -> None:
"""For each layer L: copy snapshot_buf[K_off[L]..] → k_buffer[L][slots]."""
import torch
n = record.num_tokens
stride_k = record.per_token_k_bytes
stride_v = record.per_token_v_bytes
# View snapshot_buf as a 1-D byte tensor; slice by offsets.
for L in range(self.layer_num):
# K
k_slab_start = record.k_layer_offsets[L] - record.slab_offset + record.slab_offset
# NOTE: above is equivalent to record.k_layer_offsets[L] but kept for clarity
k_slab_start = record.k_layer_offsets[L]
k_layer_bytes = self.snapshot_buf[
k_slab_start : k_slab_start + n * stride_k
].view(n, stride_k)
# Compute destination tensor on kv_pool: dst[slot_indices] = src
# We need access to the actual k_buffer[L] tensor. The controller
# only has the raw ptr — so we materialize a view via from_blob-ish
# trick. Easier: get the tensor from token_to_kv_pool_allocator's kvcache.
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
k_buf = kv_cache.k_buffer[L] # (max_tokens, head, dim)
# Flatten per-token to bytes
flat = k_buf.view(k_buf.shape[0], -1)
assert flat.shape[1] * flat.element_size() >= stride_k, (
f"K layer {L} stride mismatch: pool {flat.shape[1] * flat.element_size()} vs snapshot {stride_k}"
)
# Copy: dst[slot_indices] ← src[:n]
src_reshape = k_layer_bytes.view(n, flat.shape[1] * flat.element_size())
# Byte-level view of destination rows
dst_view = flat.view(torch.uint8)
dst_view[slot_indices_tensor] = src_reshape
# V
v_slab_start = record.v_layer_offsets[L]
v_layer_bytes = self.snapshot_buf[
v_slab_start : v_slab_start + n * stride_v
]
v_buf = kv_cache.v_buffer[L]
v_flat = v_buf.view(v_buf.shape[0], -1)
src_v = v_layer_bytes.view(n, v_flat.shape[1] * v_flat.element_size())
v_dst_view = v_flat.view(torch.uint8)
v_dst_view[slot_indices_tensor] = src_v
def _radix_insert(self, token_ids: List[int], indices_tensor) -> int:
"""Insert (token_ids, kv_indices) into the underlying radix tree."""
from sglang.srt.mem_cache.base_prefix_cache import InsertParams
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache
inner = self.tree_cache
if isinstance(inner, SessionAwareCache):
inner = inner.inner
if inner is None:
raise RuntimeError("tree_cache not provided to SnapshotLinkController")
radix_key = RadixKey(token_ids, None)
result = inner.insert(InsertParams(key=radix_key, value=indices_tensor))
return int(getattr(result, "prefix_len", 0))
def _free_slot_indices(self, indices_tensor) -> None:
try:
self.token_to_kv_pool_allocator.free(indices_tensor)
except Exception as e:
logger.warning("_free_slot_indices failed: %s", e)
# ----- D-side: push session KV to a peer's snapshot_buf ------------
def push_session_to_snapshot_buf(
self, self,
*, *,
target_snapshot_session_id: str, target_snapshot_session_id: str,
src_slot_indices: List[int], src_slot_indices: List[int],
target_k_base_ptrs: List[int], target_snapshot_buf_base: int,
target_v_base_ptrs: List[int], target_k_layer_offsets: List[int],
target_slot_indices: List[int], target_v_layer_offsets: List[int],
target_stride_k_bytes: int, target_per_token_k_bytes: int,
target_stride_v_bytes: int, target_per_token_v_bytes: int,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
"""Push the KV bytes at src_slot_indices to the remote slots. """Push session KV from local kv_pool into a peer's snapshot_buf slab.
Returns ``(mooncake_return_code, bytes_pushed)``. Caller is For each layer: gather src ranges (possibly scattered slot indices)
responsible for any post-push handshake (e.g. finalize_ingest RPC). and write to a contiguous range in the peer's snapshot_buf.
Returns (mooncake_return_code, bytes_pushed).
""" """
if not src_slot_indices:
return 0, 0
layer_num = self.layer_num layer_num = self.layer_num
k_src_bases = self.get_k_base_ptrs() k_src_bases = self.get_k_base_ptrs()
v_src_bases = self.get_v_base_ptrs() v_src_bases = self.get_v_base_ptrs()
stride_k = self.get_stride_k_bytes() stride_k = self.get_stride_k_bytes()
stride_v = self.get_stride_v_bytes() stride_v = self.get_stride_v_bytes()
if len(target_k_base_ptrs) != layer_num or len(target_v_base_ptrs) != layer_num: if (len(target_k_layer_offsets) != layer_num
or len(target_v_layer_offsets) != layer_num):
raise ValueError( raise ValueError(
f"target K/V base ptr count {len(target_k_base_ptrs)}/{len(target_v_base_ptrs)} " f"target K/V layer offset count {len(target_k_layer_offsets)}/"
f"!= local layer_num {layer_num}" f"{len(target_v_layer_offsets)} != local layer_num {layer_num}"
) )
if stride_k != target_stride_k_bytes or stride_v != target_stride_v_bytes: if (stride_k != target_per_token_k_bytes
or stride_v != target_per_token_v_bytes):
raise ValueError( raise ValueError(
f"stride mismatch: local k={stride_k}, v={stride_v}; " f"stride mismatch: local k={stride_k}/v={stride_v}, "
f"target k={target_stride_k_bytes}, v={target_stride_v_bytes}" f"target k={target_per_token_k_bytes}/v={target_per_token_v_bytes}"
)
if len(src_slot_indices) != len(target_slot_indices):
raise ValueError(
f"slot count mismatch: src={len(src_slot_indices)}, "
f"target={len(target_slot_indices)}"
) )
n = len(src_slot_indices)
local_addrs: List[int] = [] local_addrs: List[int] = []
remote_addrs: List[int] = [] remote_addrs: List[int] = []
lengths: List[int] = [] lengths: List[int] = []
# Group contiguous runs on the target side to coalesce ops. # Coalesce contiguous src runs.
# Simple approach: per (layer, K/V) pair, walk src/target index # Inner-loop helper to walk indices and emit run boundaries.
# tuples; merge runs where both src and target are sequential. def _emit_runs(src_base: int, tgt_base: int, stride: int) -> None:
for layer_id in range(layer_num): run_src_start = run_tgt_start = run_len = None
for kv_bases, stride, kv_label in ( for tgt_idx, src in enumerate(src_slot_indices):
(k_src_bases[layer_id], stride_k, "K"), if run_src_start is None:
(v_src_bases[layer_id], stride_v, "V"), run_src_start, run_tgt_start, run_len = src, tgt_idx, 1
): elif src == run_src_start + run_len:
src_base = kv_bases run_len += 1
if kv_label == "K":
tgt_base = target_k_base_ptrs[layer_id]
else: else:
tgt_base = target_v_base_ptrs[layer_id]
run_src_start = run_tgt_start = run_len = None
for s, t in zip(src_slot_indices, target_slot_indices):
if run_src_start is None:
run_src_start, run_tgt_start, run_len = s, t, 1
elif s == run_src_start + run_len and t == run_tgt_start + run_len:
run_len += 1
else:
local_addrs.append(src_base + run_src_start * stride)
remote_addrs.append(tgt_base + run_tgt_start * stride)
lengths.append(run_len * stride)
run_src_start, run_tgt_start, run_len = s, t, 1
if run_src_start is not None:
local_addrs.append(src_base + run_src_start * stride) local_addrs.append(src_base + run_src_start * stride)
remote_addrs.append(tgt_base + run_tgt_start * stride) remote_addrs.append(tgt_base + run_tgt_start * stride)
lengths.append(run_len * stride) lengths.append(run_len * stride)
run_src_start, run_tgt_start, run_len = src, tgt_idx, 1
if run_src_start is not None:
local_addrs.append(src_base + run_src_start * stride)
remote_addrs.append(tgt_base + run_tgt_start * stride)
lengths.append(run_len * stride)
for L in range(layer_num):
_emit_runs(
k_src_bases[L],
target_snapshot_buf_base + target_k_layer_offsets[L],
stride_k,
)
_emit_runs(
v_src_bases[L],
target_snapshot_buf_base + target_v_layer_offsets[L],
stride_v,
)
t0 = time.perf_counter() t0 = time.perf_counter()
try: try:
ret = self.engine.batch_transfer_sync_write( ret = self.engine.batch_transfer_sync_write(
target_snapshot_session_id, local_addrs, remote_addrs, lengths target_snapshot_session_id, local_addrs, remote_addrs, lengths,
) )
except Exception as e: except Exception as e:
logger.exception("SnapshotLinkController.push_session_kv threw: %s", e) logger.exception(
"SnapshotLinkController.push_session_to_snapshot_buf threw: %s", e
)
return -1, 0 return -1, 0
t1 = time.perf_counter() t1 = time.perf_counter()
bytes_pushed = sum(lengths) bytes_pushed = sum(lengths)
logger.info( logger.info(
"SnapshotLinkController.push_session_kv%s: %d ops, %d B, " "push_session_to_snapshot_buf%s: %d ops, %d B, ret=%d, %.2f ms",
"ret=%d, %.2f ms",
target_snapshot_session_id, len(lengths), bytes_pushed, ret, target_snapshot_session_id, len(lengths), bytes_pushed, ret,
(t1 - t0) * 1000.0, (t1 - t0) * 1000.0,
) )

View File

@@ -1662,38 +1662,38 @@ class SnapshotPrepareReceiveReqInput(BaseReq):
@dataclass @dataclass
class SnapshotPrepareReceiveReqOutput(BaseReq): class SnapshotPrepareReceiveReqOutput(BaseReq):
"""P-side response. New schema points D at P's dedicated snapshot_buf."""
ok: bool ok: bool
reason: Optional[str] = None reason: Optional[str] = None
# Layout the D side needs to address P's kv_pool slots: # P's mooncake snapshot session id (host:rpc_port) for D's batch write target
# k_base_ptrs[L] = base device address of layer L's K buffer snapshot_session_id: str = ""
# v_base_ptrs[L] = base device address of layer L's V buffer # snapshot_buf base pointer + per-layer offsets, replacing the old
# slot_indices = the contiguous range P allocated (list[int], length=num_tokens) # kv_pool slot_indices scheme that competed with P's prefill work and
# stride_k_bytes = bytes per token K = head_num * head_dim * dtype.itemsize # always hit alloc-failed. See docs/SNAPSHOT_STORE_REFACTOR_ZH.md.
# stride_v_bytes = bytes per token V (often equals stride_k_bytes) snapshot_buf_base_ptr: int = 0
# P also registers these slot regions with mooncake before returning. snapshot_buf_capacity_bytes: int = 0
k_base_ptrs: List[int] = field(default_factory=list) k_layer_offsets: List[int] = field(default_factory=list) # bytes within snapshot_buf
v_base_ptrs: List[int] = field(default_factory=list) v_layer_offsets: List[int] = field(default_factory=list)
slot_indices: List[int] = field(default_factory=list) num_tokens: int = 0
stride_k_bytes: int = 0 stride_k_bytes: int = 0
stride_v_bytes: int = 0 stride_v_bytes: int = 0
layer_num: int = 0 layer_num: int = 0
# P's mooncake snapshot session id (host:rpc_port) for D's batch write target
snapshot_session_id: str = ""
available_tokens: int = 0 available_tokens: int = 0
@dataclass @dataclass
class SnapshotDumpReqInput(BaseReq): class SnapshotDumpReqInput(BaseReq):
"""D-side: dump session KV via snapshot_link to a target P endpoint.""" """D-side: dump session KV via snapshot_link into P's snapshot_buf slab."""
session_id: str session_id: str
target_snapshot_session_id: str # P's mooncake snapshot session id target_snapshot_session_id: str
target_k_base_ptrs: List[int] = field(default_factory=list) target_snapshot_buf_base: int = 0
target_v_base_ptrs: List[int] = field(default_factory=list) target_k_layer_offsets: List[int] = field(default_factory=list)
target_slot_indices: List[int] = field(default_factory=list) target_v_layer_offsets: List[int] = field(default_factory=list)
target_stride_k_bytes: int = 0 target_stride_k_bytes: int = 0
target_stride_v_bytes: int = 0 target_stride_v_bytes: int = 0
ib_device: Optional[str] = None # for the D-side SnapshotPeer initialization ib_device: Optional[str] = None
@dataclass @dataclass
@@ -1709,11 +1709,10 @@ class SnapshotDumpReqOutput(BaseReq):
@dataclass @dataclass
class SnapshotFinalizeIngestReqInput(BaseReq): class SnapshotFinalizeIngestReqInput(BaseReq):
"""P-side: insert (token_ids, kv_indices) into radix tree after D's push.""" """P-side: copy snapshot_buf slab into kv_pool + insert into radix tree."""
session_id: str session_id: str
token_ids: List[int] token_ids: List[int]
slot_indices: List[int]
@dataclass @dataclass

View File

@@ -902,6 +902,7 @@ class Scheduler(
ib_device=ib, ib_device=ib,
kv_pool_layer_buffers=layer_buffers, kv_pool_layer_buffers=layer_buffers,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tree_cache=self.tree_cache,
) )
logger.info( logger.info(
"Snapshot link controller initialized: %s, sid=%s, %d layer bufs", "Snapshot link controller initialized: %s, sid=%s, %d layer bufs",
@@ -3750,7 +3751,12 @@ class Scheduler(
def snapshot_prepare_receive( def snapshot_prepare_receive(
self, recv_req: SnapshotPrepareReceiveReqInput self, recv_req: SnapshotPrepareReceiveReqInput
) -> SnapshotPrepareReceiveReqOutput: ) -> SnapshotPrepareReceiveReqOutput:
"""P-side: alloc kv_pool slots, return slot/buffer layout for D's batch_push.""" """P-side: carve snapshot_buf slab + return its layout to caller.
Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: this no longer
touches the kv_pool allocator. The slab is in a dedicated
snapshot_buf so prepare can never lose to P's prefill work.
"""
ctrl = self.snapshot_link_controller ctrl = self.snapshot_link_controller
if ctrl is None: if ctrl is None:
return SnapshotPrepareReceiveReqOutput( return SnapshotPrepareReceiveReqOutput(
@@ -3765,25 +3771,27 @@ class Scheduler(
record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens) record = ctrl.prepare_receive(recv_req.session_id, recv_req.num_tokens)
if record is None: if record is None:
return SnapshotPrepareReceiveReqOutput( return SnapshotPrepareReceiveReqOutput(
ok=False, reason="alloc-failed", ok=False, reason="snapshot-buf-full",
available_tokens=available, available_tokens=available,
) )
return SnapshotPrepareReceiveReqOutput( return SnapshotPrepareReceiveReqOutput(
ok=True, ok=True,
k_base_ptrs=ctrl.get_k_base_ptrs(),
v_base_ptrs=ctrl.get_v_base_ptrs(),
slot_indices=record.slot_indices,
stride_k_bytes=ctrl.get_stride_k_bytes(),
stride_v_bytes=ctrl.get_stride_v_bytes(),
layer_num=ctrl.layer_num,
snapshot_session_id=ctrl.snapshot_session_id, snapshot_session_id=ctrl.snapshot_session_id,
snapshot_buf_base_ptr=ctrl.snapshot_buf_ptr,
snapshot_buf_capacity_bytes=ctrl.snapshot_buf_bytes,
k_layer_offsets=record.k_layer_offsets,
v_layer_offsets=record.v_layer_offsets,
num_tokens=record.num_tokens,
stride_k_bytes=record.per_token_k_bytes,
stride_v_bytes=record.per_token_v_bytes,
layer_num=ctrl.layer_num,
available_tokens=available, available_tokens=available,
) )
def snapshot_dump( def snapshot_dump(
self, recv_req: SnapshotDumpReqInput self, recv_req: SnapshotDumpReqInput
) -> SnapshotDumpReqOutput: ) -> SnapshotDumpReqOutput:
"""D-side: gather session KV from kv_pool, RDMA-write to remote slots.""" """D-side: gather session KV from kv_pool, RDMA-write into P's snapshot_buf."""
ctrl = self.snapshot_link_controller ctrl = self.snapshot_link_controller
if ctrl is None: if ctrl is None:
return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled") return SnapshotDumpReqOutput(ok=False, reason="snapshot-link-disabled")
@@ -3795,110 +3803,60 @@ class Scheduler(
kv_committed_len = int(slot.kv_committed_len) kv_committed_len = int(slot.kv_committed_len)
if kv_committed_len == 0: if kv_committed_len == 0:
return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len") return SnapshotDumpReqOutput(ok=False, reason="zero-committed-len")
# Read kv_indices for the session's prefix
try: try:
kv_idx_tensor = self.req_to_token_pool.req_to_token[ kv_idx_tensor = self.req_to_token_pool.req_to_token[
slot.req_pool_idx, :kv_committed_len slot.req_pool_idx, :kv_committed_len
] ]
src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()] src_slot_indices = [int(x) for x in kv_idx_tensor.tolist()]
except Exception as e: except Exception as e:
logger.exception("snapshot_dump: failed to read kv_indices: %s", e) return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed:{e!r}")
return SnapshotDumpReqOutput(ok=False, reason=f"read-indices-failed: {e!r}")
# Truncate to the count P prepared for (must match)
target_n = len(recv_req.target_slot_indices)
if target_n > kv_committed_len:
return SnapshotDumpReqOutput(
ok=False,
reason=f"target-larger-than-source({target_n}>{kv_committed_len})",
)
src_slot_indices = src_slot_indices[:target_n]
try: try:
ret, bytes_pushed = ctrl.push_session_kv( ret, bytes_pushed = ctrl.push_session_to_snapshot_buf(
target_snapshot_session_id=recv_req.target_snapshot_session_id, target_snapshot_session_id=recv_req.target_snapshot_session_id,
src_slot_indices=src_slot_indices, src_slot_indices=src_slot_indices,
target_k_base_ptrs=recv_req.target_k_base_ptrs, target_snapshot_buf_base=recv_req.target_snapshot_buf_base,
target_v_base_ptrs=recv_req.target_v_base_ptrs, target_k_layer_offsets=recv_req.target_k_layer_offsets,
target_slot_indices=recv_req.target_slot_indices[:target_n], target_v_layer_offsets=recv_req.target_v_layer_offsets,
target_stride_k_bytes=recv_req.target_stride_k_bytes, target_per_token_k_bytes=recv_req.target_stride_k_bytes,
target_stride_v_bytes=recv_req.target_stride_v_bytes, target_per_token_v_bytes=recv_req.target_stride_v_bytes,
) )
except Exception as e: except Exception as e:
logger.exception("snapshot_dump: push_session_kv threw: %s", e) return SnapshotDumpReqOutput(ok=False, reason=f"push-failed:{e!r}")
return SnapshotDumpReqOutput(ok=False, reason=f"push-failed: {e!r}")
if ret != 0: if ret != 0:
return SnapshotDumpReqOutput( return SnapshotDumpReqOutput(
ok=False, ok=False, reason=f"mooncake-batch-write-ret={ret}",
reason=f"mooncake-batch-write-ret={ret}",
bytes_pushed=int(bytes_pushed), bytes_pushed=int(bytes_pushed),
kv_committed_len=int(kv_committed_len), kv_committed_len=kv_committed_len,
) )
return SnapshotDumpReqOutput( return SnapshotDumpReqOutput(
ok=True, bytes_pushed=int(bytes_pushed), ok=True, bytes_pushed=int(bytes_pushed),
kv_committed_len=int(kv_committed_len), kv_committed_len=kv_committed_len,
token_ids=[], # caller already has token_ids token_ids=[],
) )
def snapshot_finalize_ingest( def snapshot_finalize_ingest(
self, recv_req: SnapshotFinalizeIngestReqInput self, recv_req: SnapshotFinalizeIngestReqInput
) -> SnapshotFinalizeIngestReqOutput: ) -> SnapshotFinalizeIngestReqOutput:
"""P-side: insert (token_ids, slot_indices) into radix tree.""" """P-side: copy snapshot_buf slab into kv_pool + insert into radix tree.
Refactored per docs/SNAPSHOT_STORE_REFACTOR_ZH.md: kv_pool alloc
happens HERE (deferred from prepare_receive), so we never block
D's RDMA write on kv_pool contention.
"""
ctrl = self.snapshot_link_controller ctrl = self.snapshot_link_controller
if ctrl is None: if ctrl is None:
return SnapshotFinalizeIngestReqOutput( return SnapshotFinalizeIngestReqOutput(
ok=False, reason="snapshot-link-disabled", ok=False, reason="snapshot-link-disabled",
) )
record = ctrl.take_record(recv_req.session_id) ok, reason, inserted_prefix_len = ctrl.ingest_snapshot_into_kvpool(
if record is None: session_id=recv_req.session_id,
return SnapshotFinalizeIngestReqOutput( token_ids=list(recv_req.token_ids),
ok=False, reason="no-pending-ingest", )
)
# Sanity: the slot indices we're about to insert should match the ones we reserved.
if list(recv_req.slot_indices) != record.slot_indices:
# The caller passed back the slot indices we returned in prepare; if they
# don't match, something's gone wrong. Free reserved slots and bail.
try:
ctrl._free_slots(record.slot_indices)
except Exception:
pass
return SnapshotFinalizeIngestReqOutput(
ok=False,
reason="slot-indices-mismatch",
)
n_tokens = min(len(recv_req.token_ids), len(record.slot_indices))
if n_tokens == 0:
ctrl._free_slots(record.slot_indices)
return SnapshotFinalizeIngestReqOutput(ok=False, reason="empty-token-ids")
try:
import torch
from sglang.srt.mem_cache.base_prefix_cache import InsertParams
from sglang.srt.mem_cache.radix_cache import RadixKey
kv_indices = torch.tensor(
record.slot_indices[:n_tokens],
dtype=torch.int64,
device=self.tree_cache.token_to_kv_pool_allocator.device,
)
radix_key = RadixKey(recv_req.token_ids[:n_tokens], None)
inner = (
self.tree_cache.inner
if isinstance(self.tree_cache, SessionAwareCache)
else self.tree_cache
)
result = inner.insert(InsertParams(key=radix_key, value=kv_indices))
inserted = int(result.prefix_len)
except Exception as e:
logger.exception("snapshot_finalize_ingest: radix insert failed: %s", e)
try:
ctrl._free_slots(record.slot_indices)
except Exception:
pass
return SnapshotFinalizeIngestReqOutput(
ok=False, reason=f"radix-insert-failed: {e!r}",
)
return SnapshotFinalizeIngestReqOutput( return SnapshotFinalizeIngestReqOutput(
ok=True, inserted_prefix_len=inserted, ok=bool(ok), reason=reason if not ok else None,
inserted_prefix_len=int(inserted_prefix_len),
) )
def _compute_backpressure_pause_hint( def _compute_backpressure_pause_hint(

View File

@@ -181,27 +181,18 @@ class SchedulerRuntimeCheckerMixin:
return memory_leak, token_msg return memory_leak, token_msg
def _check_radix_cache_memory(self: Scheduler): def _check_radix_cache_memory(self: Scheduler):
# NB: as of SnapshotStore refactor (see docs/SNAPSHOT_STORE_REFACTOR_ZH.md)
# prepare_receive no longer touches kv_pool — slots are alloc'd from
# a dedicated snapshot_buf. So no snapshot_reserved accounting needed.
_, _, available_size, evictable_size = self._get_token_info() _, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size() protected_size = self.tree_cache.protected_size()
session_held = self._session_held_tokens() session_held = self._session_held_tokens()
# Snapshot link prepare_receive reserves slots that aren't yet visible
# to radix / session bookkeeping until finalize_ingest. Count them so
# the leak check doesn't fire while a snapshot ingest is in-flight.
snapshot_reserved = 0
ctrl = getattr(self, "snapshot_link_controller", None)
if ctrl is not None:
try:
snapshot_reserved = sum(
len(rec.slot_indices) for rec in ctrl._ingest_records.values()
)
except Exception:
snapshot_reserved = 0
memory_leak = (available_size + evictable_size) != ( memory_leak = (available_size + evictable_size) != (
self.max_total_num_tokens - protected_size - session_held - snapshot_reserved self.max_total_num_tokens - protected_size - session_held
) )
token_msg = ( token_msg = (
f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, " f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, "
f"{protected_size=}, {session_held=}, {snapshot_reserved=}\n" f"{protected_size=}, {session_held=}\n"
) )
return memory_leak, token_msg return memory_leak, token_msg