feat(sglang): support decode session cache admission

This commit is contained in:
2026-04-24 12:30:41 +00:00
parent bded08301f
commit b8e6f13c20
9 changed files with 589 additions and 14 deletions

View File

@@ -140,9 +140,9 @@ class DecodeReqToTokenPool:
# Indices of reqs that already have a req_pool_idx and will reuse
# their existing slot (e.g. chunked prefill continuing across chunks).
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
assert (
len(reusing) <= 1
), "only one chunked request may reuse req_pool_idx in a batch"
# Keep decode-side allocation behavior aligned with the shared pool:
# local append-prefill can legitimately batch multiple resident-session
# requests that each reuse their existing req_pool_idx.
assert all(
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing
), "reusing request must be chunked or have committed KV"
@@ -657,6 +657,12 @@ class DecodePreallocQueue:
allocatable_tokens = self._allocatable_tokens(
retractable_tokens=retractable_tokens, count_retracted=True
)
def refresh_allocatable_tokens() -> int:
return self._allocatable_tokens(
retractable_tokens=retractable_tokens, count_retracted=True
)
# First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
@@ -679,6 +685,13 @@ class DecodePreallocQueue:
if not decode_req.waiting_for_input:
continue
if self.req_to_token_pool.available_size() <= 0:
self.scheduler.maybe_trim_decode_session_cache(
required_tokens=max(1, self.num_reserved_decode_tokens),
force=True,
max_sessions=1,
)
allocatable_tokens = refresh_allocatable_tokens()
if self.req_to_token_pool.available_size() <= 0:
break
@@ -692,6 +705,37 @@ class DecodePreallocQueue:
origin_input_len + self.num_reserved_decode_tokens
)
if (
max(
required_tokens_for_request,
origin_input_len
+ min(
decode_req.req.sampling_params.max_new_tokens,
CLIP_MAX_NEW_TOKEN,
)
- retractable_tokens,
)
> allocatable_tokens
):
required_tokens = max(
0,
max(
required_tokens_for_request,
origin_input_len
+ min(
decode_req.req.sampling_params.max_new_tokens,
CLIP_MAX_NEW_TOKEN,
)
- retractable_tokens,
)
- allocatable_tokens,
)
if required_tokens > 0:
self.scheduler.maybe_trim_decode_session_cache(
required_tokens=required_tokens,
force=True,
)
allocatable_tokens = refresh_allocatable_tokens()
if (
max(
required_tokens_for_request,
@@ -705,6 +749,12 @@ class DecodePreallocQueue:
> allocatable_tokens
):
break
if required_tokens_for_request > allocatable_tokens:
self.scheduler.maybe_trim_decode_session_cache(
required_tokens=required_tokens_for_request - allocatable_tokens,
force=True,
)
allocatable_tokens = refresh_allocatable_tokens()
if required_tokens_for_request > allocatable_tokens:
break
@@ -1135,6 +1185,27 @@ class DecodeTransferQueue:
class SchedulerDisaggregationDecodeMixin:
def _merge_last_local_prefill_batch(self: Scheduler):
    """Fold the previous extend-mode batch into the running batch.

    Nothing happens unless ``self.last_batch`` exists and its forward mode
    reports ``is_extend()``. The batch's ``chunked_req`` (if any) is filtered
    out first; when that shrinks the batch, ``running_batch.batch_is_full``
    is cleared. Whatever is left either becomes the running batch (when the
    running batch is empty) or is merged into it.
    """
    prev_batch = self.last_batch
    if prev_batch is None:
        return
    if not prev_batch.forward_mode.is_extend():
        return

    # A request that is still being chunked must not join the running batch.
    pending_chunked = prev_batch.chunked_req
    excluded = [pending_chunked] if pending_chunked is not None else []

    size_before = prev_batch.batch_size()
    prev_batch.filter_batch(chunked_req_to_exclude=excluded)
    if prev_batch.batch_size() < size_before:
        self.running_batch.batch_is_full = False

    if prev_batch.is_empty():
        return
    if self.running_batch.is_empty():
        self.running_batch = prev_batch
    else:
        self.running_batch.merge_batch(prev_batch)
@torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode."""
@@ -1213,6 +1284,8 @@ class SchedulerDisaggregationDecodeMixin:
self: Scheduler,
) -> Optional[ScheduleBatch]:
"""Process prebuilt batch and schedule the next decode batch."""
self._merge_last_local_prefill_batch()
# Process pending prebuilt batch: output processing + filter + merge
new_prebuilt_batch = self.get_new_prebuilt_batch()
if new_prebuilt_batch:
@@ -1229,6 +1302,13 @@ class SchedulerDisaggregationDecodeMixin:
else:
self.running_batch.merge_batch(new_prebuilt_batch)
new_local_batch = self.get_new_local_extend_batch()
if new_local_batch is not None:
ret = self.maybe_prepare_mlp_sync_batch(new_local_batch)
if ret:
set_schedule_time_batch(ret)
return ret
# Schedule decode batch
if self.running_batch.is_empty():
ret = None
@@ -1241,6 +1321,22 @@ class SchedulerDisaggregationDecodeMixin:
set_schedule_time_batch(ret)
return ret
def get_new_local_extend_batch(self: Scheduler) -> Optional[ScheduleBatch]:
    """Build a prefill batch from the decode-side direct waiting queue.

    Returns ``None`` when local prefill on decode is disabled or the direct
    queue is empty. Otherwise ``self.waiting_queue`` is temporarily pointed
    at ``self.decode_direct_waiting_queue`` so that
    ``get_new_batch_prefill()`` draws from it; the (possibly rebound) queue
    is written back afterwards and the regular waiting queue is always
    restored.
    """
    allow_local = self.server_args.disaggregation_decode_allow_local_prefill
    if not allow_local or len(self.decode_direct_waiting_queue) == 0:
        return None

    saved_queue = self.waiting_queue
    self.waiting_queue = self.decode_direct_waiting_queue
    try:
        batch = self.get_new_batch_prefill()
        # get_new_batch_prefill may rebind self.waiting_queue; keep its result.
        self.decode_direct_waiting_queue = self.waiting_queue
    finally:
        self.waiting_queue = saved_queue
    return batch
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
"""Create a schedulebatch for fake completed prefill"""
if self.grammar_manager.has_waiting_grammars():

View File

@@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterFromTensorsReqInput,
LoadLoRAAdapterReqInput,
DirectAppendAdmissionReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
PauseGenerationReqInput,
@@ -1289,6 +1290,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
return _create_error_response(e)
@app.post("/session_cache/admit_direct_append")
async def admit_direct_append(obj: DirectAppendAdmissionReqInput):
    """Forward a direct-append admission query to the tokenizer manager."""
    # NOTE(review): unlike configure_logging below, this route carries no
    # @auth_level decorator, and the scheduler-side handler may evict idle
    # sessions — confirm whether it should be access-restricted.
    tokenizer_manager = _global_state.tokenizer_manager
    admission = await tokenizer_manager.admit_direct_append(obj)
    return admission
@app.api_route("/configure_logging", methods=["GET", "POST"])
@auth_level(AuthLevel.ADMIN_OPTIONAL)
async def configure_logging(obj: ConfigureLoggingReq, request: Request):

View File

@@ -1597,6 +1597,30 @@ class OpenSessionReqOutput(BaseReq):
success: bool
# Request asking the scheduler whether a direct append to a streaming session
# can be admitted (handled by Scheduler.admit_direct_append).
@dataclass
class DirectAppendAdmissionReqInput(BaseReq):
# Session whose cache status is looked up.
session_id: str
# Input tokens of the append that are not cached yet; the scheduler clamps
# this at 0 and adds it to output_tokens to get the required token count.
uncached_input_tokens: int
# Expected output tokens; clamped at 0 and added to uncached_input_tokens.
output_tokens: int
# Scheduler reply to a DirectAppendAdmissionReqInput.
@dataclass
class DirectAppendAdmissionReqOutput(BaseReq):
# True when the retracted queue is empty and enough KV tokens are available.
can_admit: bool
# Whether the target session was reported as resident in the session cache.
resident: bool
# None when admitted; otherwise the scheduler sets "unsupported",
# "session-not-resident" or "no-space".
reason: Optional[str] = None
# Clamped uncached_input_tokens + output_tokens from the request.
required_tokens: int = 0
# KV allocator available size before / after the optional session trim.
available_tokens_before: int = 0
available_tokens_after: int = 0
# Results of the idle-session trim performed while answering the request.
evicted_session_count: int = 0
freed_tokens: int = 0
# 1.0 - available_tokens / max(1, max_total_num_tokens).
token_usage: float = 0.0
# Load indicators sampled from the scheduler when the reply is built.
num_running_reqs: int = 0
decode_prealloc_queue_reqs: int = 0
decode_transfer_queue_reqs: int = 0
decode_retracted_queue_reqs: int = 0
@dataclass
class HealthCheckOutput(BaseReq):
pass

View File

@@ -1566,6 +1566,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Init tensors
reqs = self.reqs
for req in reqs:
if req.session is None or not req.session.streaming:
continue
actual_extend_len = max(0, len(req.fill_ids) - len(req.prefix_indices))
if req.extend_input_len != actual_extend_len:
logger.warning(
"Correcting streaming-session extend_input_len from %d to %d "
"(rid=%s, session_id=%s, fill_len=%d, prefix_len=%d, kv_committed_len=%d)",
req.extend_input_len,
actual_extend_len,
req.rid,
req.session.session_id,
len(req.fill_ids),
len(req.prefix_indices),
req.kv_committed_len,
)
req.set_extend_input_len(actual_extend_len)
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = [len(r.fill_ids) for r in reqs]

View File

@@ -94,6 +94,8 @@ from sglang.srt.managers.io_struct import (
ClearHiCacheReqOutput,
CloseSessionReqInput,
ContinueGenerationReqInput,
DirectAppendAdmissionReqInput,
DirectAppendAdmissionReqOutput,
DestroyWeightsUpdateGroupReqInput,
DetachHiCacheStorageReqInput,
DetachHiCacheStorageReqOutput,
@@ -844,6 +846,7 @@ class Scheduler(
def init_running_status(self):
self.waiting_queue: List[Req] = []
self.decode_direct_waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
# The current forward batch
@@ -1215,6 +1218,7 @@ class Scheduler(
(AbortReq, self.abort_request),
(OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session),
(DirectAppendAdmissionReqInput, self.admit_direct_append),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
@@ -1589,6 +1593,7 @@ class Scheduler(
def process_input_requests(self, recv_reqs: List):
now = time.monotonic()
self.session_controller.maybe_reap(now)
self.maybe_trim_decode_session_cache()
for recv_req in recv_reqs:
# Skip health check when server is busy — ongoing requests already carry health info.
if is_health_check_generate_req(recv_req) and not self.is_fully_idle(
@@ -1781,6 +1786,10 @@ class Scheduler(
# Invalid request for disaggregated mode
if (
recv_req.bootstrap_room is None
and not (
self.disaggregation_mode == DisaggregationMode.DECODE
and self.server_args.disaggregation_decode_allow_local_prefill
)
and self.transfer_backend != TransferBackend.FAKE
):
error_msg = (
@@ -1949,6 +1958,12 @@ class Scheduler(
)
req.time_stats.set_prefill_bootstrap_queue_entry_time()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
if self._should_allow_local_prefill_on_decode(req):
if not self._set_or_validate_priority(req):
return
self.decode_direct_waiting_queue.append(req)
req.time_stats.set_wait_queue_entry_time()
return
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
if not is_retracted:
req.time_stats.set_decode_prealloc_queue_entry_time()
@@ -1957,6 +1972,88 @@ class Scheduler(
else:
raise ValueError(f"Invalid {self.disaggregation_mode=}")
def _should_allow_local_prefill_on_decode(self, req: Req) -> bool:
    """Return True when ``req`` may be prefilled locally on a decode worker.

    Requires decode disaggregation mode, the
    ``disaggregation_decode_allow_local_prefill`` server flag, and a request
    without a ``bootstrap_room``.
    """
    if self.disaggregation_mode != DisaggregationMode.DECODE:
        return False
    if not self.server_args.disaggregation_decode_allow_local_prefill:
        return False
    return req.bootstrap_room is None
def _decode_session_cache_low_watermark_tokens(self) -> int:
    """Return the free-token level below which session trimming kicks in.

    The watermark is the largest of 16x the reserved decode tokens, 1/12 of
    the pool capacity and 16384, capped at the pool capacity.
    """
    capacity = self.max_total_num_tokens
    reserve_based = self.server_args.num_reserved_decode_tokens * 16
    capacity_based = capacity // 12
    watermark = max(reserve_based, capacity_based, 16384)
    return min(capacity, watermark)
def _decode_session_cache_target_available_tokens(self) -> int:
    """Return the free-token level a session trim tries to restore.

    The target is the largest of the low watermark, 24x the reserved decode
    tokens, 1/8 of the pool capacity and 24576, capped at the pool capacity.
    """
    capacity = self.max_total_num_tokens
    candidates = (
        self._decode_session_cache_low_watermark_tokens(),
        self.server_args.num_reserved_decode_tokens * 24,
        capacity // 8,
        24576,
    )
    return min(capacity, max(candidates))
def maybe_trim_decode_session_cache(
    self,
    required_tokens: int = 0,
    force: bool = False,
    max_sessions: Optional[int] = None,
    exclude_session_ids: Optional[set[str]] = None,
) -> Dict[str, int]:
    """Evict idle streaming sessions (LRU) to restore free KV tokens.

    Args:
        required_tokens: Extra tokens the caller wants freed on top of what
            is currently available.
        force: Evict even when the free-token count is at or above the low
            watermark.
        max_sessions: Upper bound on evicted sessions, forwarded to the
            session controller.
        exclude_session_ids: Sessions that must not be evicted.

    Returns:
        A dict with ``evicted_session_count``, ``freed_tokens``,
        ``available_tokens_before`` and ``available_tokens_after``. All four
        are 0 when the worker is not in decode mode or the tree cache is not
        a ``SessionAwareCache``.
    """

    def _nothing_evicted(available: int) -> Dict[str, int]:
        return {
            "evicted_session_count": 0,
            "freed_tokens": 0,
            "available_tokens_before": available,
            "available_tokens_after": available,
        }

    session_cache_active = (
        self.disaggregation_mode == DisaggregationMode.DECODE
        and isinstance(self.tree_cache, SessionAwareCache)
    )
    if not session_cache_active:
        return _nothing_evicted(0)

    available_tokens = self.token_to_kv_pool_allocator.available_size()
    low_watermark_tokens = self._decode_session_cache_low_watermark_tokens()
    baseline_target_tokens = self._decode_session_cache_target_available_tokens()

    # Above the watermark there is nothing to do unless the caller insists.
    if not force and available_tokens >= low_watermark_tokens:
        return _nothing_evicted(int(available_tokens))

    # The floor covers both the watermark and the caller's extra demand; the
    # target is never allowed to fall below that floor.
    min_available_tokens = max(
        low_watermark_tokens,
        available_tokens + max(0, required_tokens),
    )
    target_available_tokens = max(baseline_target_tokens, min_available_tokens)

    result = self.session_controller.evict_idle_streaming_sessions_lru(
        required_tokens=required_tokens,
        min_available_tokens=min_available_tokens,
        target_available_tokens=target_available_tokens,
        max_sessions=max_sessions,
        exclude_session_ids=exclude_session_ids,
    )
    if result["evicted_session_count"] > 0:
        logger.info(
            "Trimmed decode session cache via LRU. "
            f"#evicted_sessions: {result['evicted_session_count']}, "
            f"#freed_tokens: {result['freed_tokens']}, "
            f"#available_tokens: {result['available_tokens_before']} -> "
            f"{result['available_tokens_after']}"
        )
    return result
def _set_or_validate_priority(self, req: Req) -> bool:
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
if self.enable_priority_scheduling and req.priority is None:
@@ -2558,8 +2655,19 @@ class Scheduler(
if self.enable_hierarchical_cache:
self.tree_cache.flush_write_through_acks()
# Check if decode out of memory
if (kv_full_retract_flag := not batch.check_decode_mem()) or (
# Check if decode out of memory. Before retracting active decode work,
# let the worker evict idle streaming-session KV held outside the radix tree.
kv_full_retract_flag = not batch.check_decode_mem()
idle_session_eviction = None
if kv_full_retract_flag:
idle_session_eviction = self.maybe_trim_decode_session_cache(
required_tokens=max(1, self.server_args.num_reserved_decode_tokens),
force=True,
)
if idle_session_eviction["freed_tokens"] > 0:
kv_full_retract_flag = not batch.check_decode_mem()
if kv_full_retract_flag or (
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
):
old_available_tokens = self.token_to_kv_pool_allocator.available_size()
@@ -2602,6 +2710,13 @@ class Scheduler(
msg_details += (
f", #new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
)
if idle_session_eviction is not None:
msg_details += (
", #idle_session_evicted: "
f"{idle_session_eviction['evicted_session_count']}"
", #idle_tokens_freed: "
f"{idle_session_eviction['freed_tokens']}"
)
logger.warning(msg_prefix + msg_details)
for req in retracted_reqs:
@@ -2916,6 +3031,7 @@ class Scheduler(
# Waiting queues: waiting + bootstrapping + preallocation + kv transfer (decode)
idle &= len(self.waiting_queue) == 0
idle &= len(self.decode_direct_waiting_queue) == 0
if not for_health_check:
# Grammar queue and prefill inflight queue may not produce batch
@@ -3077,6 +3193,9 @@ class Scheduler(
"graph": round(self.tp_worker.model_runner.graph_mem_usage, 2),
}
ret["effective_max_running_requests_per_dp"] = self.max_running_requests
ret["session_cache"] = (
self.session_controller.get_streaming_session_cache_status()
)
if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
ret["avg_spec_accept_length"] = (
@@ -3375,6 +3494,86 @@ class Scheduler(
def close_session(self, recv_req: CloseSessionReqInput):
self.session_controller.close(recv_req)
def admit_direct_append(
    self, recv_req: DirectAppendAdmissionReqInput
) -> DirectAppendAdmissionReqOutput:
    """Decide whether a direct append to a resident streaming session fits.

    The reply is ``reason="unsupported"`` unless this worker is a decode
    worker with local prefill enabled and a ``SessionAwareCache``. If the
    target session is not resident the reply is
    ``reason="session-not-resident"`` and nothing is evicted. Otherwise idle
    sessions other than the target may be evicted to make room, and the
    append is admitted when the retracted queue is empty and enough KV
    tokens are available afterwards.

    Args:
        recv_req: Session id plus the uncached-input and output token
            counts of the planned append (negative values count as 0).

    Returns:
        The admission decision together with token and queue statistics.
    """
    if (
        self.disaggregation_mode != DisaggregationMode.DECODE
        or not self.server_args.disaggregation_decode_allow_local_prefill
        or not isinstance(self.tree_cache, SessionAwareCache)
    ):
        return DirectAppendAdmissionReqOutput(
            can_admit=False,
            resident=False,
            reason="unsupported",
        )

    allocator = self.token_to_kv_pool_allocator
    prealloc_queue = self.disagg_decode_prealloc_queue
    capacity_tokens = max(1, self.max_total_num_tokens)

    def _load_snapshot(available_tokens: int) -> dict:
        # Load indicators shared by both reply shapes.
        return {
            "token_usage": 1.0 - available_tokens / capacity_tokens,
            "num_running_reqs": len(self.running_batch.reqs),
            "decode_prealloc_queue_reqs": len(prealloc_queue.queue),
            "decode_transfer_queue_reqs": len(
                self.disagg_decode_transfer_queue.queue
            ),
            "decode_retracted_queue_reqs": len(prealloc_queue.retracted_queue),
        }

    session_cache_status = (
        self.session_controller.get_streaming_session_cache_status(
            recv_req.session_id
        )
    )
    target_session = session_cache_status.get("target_session")
    resident = bool(
        isinstance(target_session, dict) and target_session.get("resident")
    )
    if not resident:
        available_tokens = int(allocator.available_size())
        return DirectAppendAdmissionReqOutput(
            can_admit=False,
            resident=False,
            reason="session-not-resident",
            available_tokens_before=available_tokens,
            available_tokens_after=available_tokens,
            **_load_snapshot(available_tokens),
        )

    required_tokens = max(0, recv_req.uncached_input_tokens) + max(
        0, recv_req.output_tokens
    )
    available_tokens_before = int(allocator.available_size())

    # The evictor treats required_tokens as "tokens to free", so pass only
    # the shortfall. Passing the full requirement would evict idle sessions
    # for capacity that is already available.
    shortfall_tokens = max(0, required_tokens - available_tokens_before)
    trim_result = self.maybe_trim_decode_session_cache(
        required_tokens=shortfall_tokens,
        force=shortfall_tokens > 0,
        exclude_session_ids={recv_req.session_id},
    )

    available_tokens_after = int(allocator.available_size())
    load = _load_snapshot(available_tokens_after)
    can_admit = (
        load["decode_retracted_queue_reqs"] == 0
        and available_tokens_after >= required_tokens
    )
    return DirectAppendAdmissionReqOutput(
        can_admit=can_admit,
        resident=True,
        reason=None if can_admit else "no-space",
        required_tokens=int(required_tokens),
        available_tokens_before=available_tokens_before,
        available_tokens_after=available_tokens_after,
        evicted_session_count=int(trim_result["evicted_session_count"]),
        freed_tokens=int(trim_result["freed_tokens"]),
        **load,
    )
def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep()

View File

@@ -229,6 +229,12 @@ class Session:
priority=req.priority,
routing_key=req.routing_key,
http_worker_ipc=req.http_worker_ipc,
bootstrap_host=req.bootstrap_host,
bootstrap_port=req.bootstrap_port,
bootstrap_room=req.bootstrap_room,
routed_dp_rank=req.routed_dp_rank,
disagg_prefill_dp_rank=req.disagg_prefill_dp_rank,
extra_key=req.extra_key,
time_stats=req.time_stats,
)
if last_req is not None:
@@ -284,7 +290,7 @@ class SessionController:
else:
self._close(session_id)
def _close(self, session_id: str):
def _close(self, session_id: str) -> int:
session = self.sessions[session_id]
if session.streaming and session.req_nodes:
assert len(session.req_nodes) == 1
@@ -303,9 +309,166 @@ class SessionController:
mm.release_features()
node.req.multimodal_inputs = None
freed_tokens = 0
if isinstance(self.tree_cache, SessionAwareCache):
self.tree_cache.release_session(session_id)
freed_tokens = self.tree_cache.release_session(session_id)
del self.sessions[session_id]
return freed_tokens
def _is_idle_streaming_session(self, session: Session) -> bool:
    """Return True for a streaming session whose requests have all finished.

    Non-streaming sessions and sessions without request nodes are never
    considered idle.
    """
    if not (session.streaming and session.req_nodes):
        return False
    for node in session.req_nodes.values():
        if not node.req.finished():
            return False
    return True
def get_streaming_session_cache_status(
    self, session_id: Optional[str] = None
) -> Dict[str, object]:
    """Summarize the streaming-session cache, optionally for one session.

    Args:
        session_id: When given, ``sessions`` lists only this session and
            ``target_session`` carries its status; otherwise every session
            is listed and ``target_session`` is ``None``.

    Returns:
        A dict of aggregate counters (sessions, resident sessions, held /
        available / capacity tokens, idle-evictable sessions and tokens)
        plus the per-session entries. When the tree cache is not a
        ``SessionAwareCache`` the dict reports ``enabled=False`` with
        zeroed counters.
    """
    cache = self.tree_cache
    if not isinstance(cache, SessionAwareCache):
        return {
            "enabled": False,
            "eviction_policy": None,
            "session_count": 0,
            "resident_session_count": 0,
            "held_tokens": 0,
            "available_tokens": 0,
            "capacity_tokens": 0,
            "idle_evictable_session_count": 0,
            "idle_evictable_tokens": 0,
            "sessions": [],
            "target_session": None,
        }

    statuses = cache.list_session_statuses()
    status_by_id = {entry["session_id"]: entry for entry in statuses}

    idle_session_count = 0
    idle_token_total = 0
    reported_sessions = []
    for entry in statuses:
        entry_id = entry["session_id"]
        session = self.sessions.get(entry_id)
        is_idle = session is not None and self._is_idle_streaming_session(session)
        if is_idle and entry["resident"]:
            idle_session_count += 1
            idle_token_total += int(entry["resident_tokens"])

        # Aggregates cover every session; the listing honours the filter.
        if session_id is not None and entry_id != session_id:
            continue
        if session is None:
            # A cache slot without a controller-side session.
            streaming, timed_out = True, False
        else:
            streaming, timed_out = bool(session.streaming), session.is_timed_out()
        reported_sessions.append(
            {
                **entry,
                "streaming": streaming,
                "idle_evictable": is_idle,
                "timed_out": timed_out,
            }
        )

    target_status = status_by_id.get(session_id) if session_id is not None else None
    if target_status is None:
        target_view = None
    else:
        target_id = target_status["session_id"]
        target_view = {
            **target_status,
            "idle_evictable": (
                self._is_idle_streaming_session(self.sessions[target_id])
                if target_id in self.sessions
                else False
            ),
        }

    allocator = cache.token_to_kv_pool_allocator
    return {
        "enabled": True,
        "eviction_policy": "lru",
        "session_count": len(statuses),
        "resident_session_count": sum(1 for entry in statuses if entry["resident"]),
        "held_tokens": int(cache.session_held_tokens()),
        "available_tokens": int(allocator.available_size()),
        "capacity_tokens": int(allocator.size),
        "idle_evictable_session_count": idle_session_count,
        "idle_evictable_tokens": int(idle_token_total),
        "sessions": reported_sessions,
        "target_session": target_view,
    }
def evict_idle_streaming_sessions(
    self,
    required_tokens: int = 0,
    max_sessions: Optional[int] = None,
) -> Dict[str, int]:
    """Delegate to ``evict_idle_streaming_sessions_lru`` with default thresholds."""
    lru_kwargs = {
        "required_tokens": required_tokens,
        "max_sessions": max_sessions,
    }
    return self.evict_idle_streaming_sessions_lru(**lru_kwargs)
def evict_idle_streaming_sessions_lru(
    self,
    required_tokens: int = 0,
    min_available_tokens: int = 0,
    target_available_tokens: Optional[int] = None,
    max_sessions: Optional[int] = None,
    exclude_session_ids: Optional[set[str]] = None,
) -> Dict[str, int]:
    """Close idle, resident streaming sessions until enough tokens are free.

    Candidates are taken in the order returned by
    ``tree_cache.list_session_statuses()`` (sorted by last access time,
    oldest first).

    Args:
        required_tokens: Tokens that must be freed by this call (when > 0).
        min_available_tokens: Free-token floor to reach before stopping.
        target_available_tokens: Free-token goal used when no specific
            amount is required; never lower than ``min_available_tokens``.
        max_sessions: Upper bound on the number of sessions closed.
        exclude_session_ids: Sessions that must be kept.

    Returns:
        A dict with ``evicted_session_count``, ``freed_tokens``,
        ``available_tokens_before`` and ``available_tokens_after`` (the
        latter is the starting availability plus the freed tokens).
    """
    cache = self.tree_cache
    if not isinstance(cache, SessionAwareCache):
        return {
            "evicted_session_count": 0,
            "freed_tokens": 0,
            "available_tokens_before": 0,
            "available_tokens_after": 0,
        }

    free_at_start = int(cache.token_to_kv_pool_allocator.available_size())
    goal_tokens = max(
        int(min_available_tokens),
        int(target_available_tokens or 0),
    )
    protected_ids = exclude_session_ids or ()

    def _evictable(entry) -> bool:
        entry_id = entry["session_id"]
        if entry_id in protected_ids:
            return False
        session = self.sessions.get(entry_id)
        if session is None or not self._is_idle_streaming_session(session):
            return False
        return bool(entry["resident"])

    victims = [entry for entry in cache.list_session_statuses() if _evictable(entry)]

    closed_sessions = 0
    freed_total = 0
    for entry in victims:
        if max_sessions is not None and closed_sessions >= max_sessions:
            break

        free_now = free_at_start + freed_total
        if required_tokens > 0:
            # Explicit demand: stop once it is met and the floor is reached.
            if freed_total >= required_tokens and free_now >= min_available_tokens:
                break
        else:
            # No explicit demand: stop when no threshold was requested, or
            # when both the floor and the goal are satisfied.
            if min_available_tokens <= 0 and goal_tokens <= 0:
                break
            if free_now >= min_available_tokens and free_now >= goal_tokens:
                break

        released = self._close(entry["session_id"])
        closed_sessions += 1
        freed_total += int(released)

    return {
        "evicted_session_count": closed_sessions,
        "freed_tokens": freed_total,
        "available_tokens_before": free_at_start,
        "available_tokens_after": free_at_start + freed_total,
    }
def maybe_reap(self, now: float, interval: float = 1.0):
# reap sessions every second

View File

@@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput,
DirectAppendAdmissionReqInput,
DirectAppendAdmissionReqOutput,
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
DetachHiCacheStorageReqInput,
@@ -220,6 +222,9 @@ class TokenizerCommunicatorMixin:
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.direct_append_admission_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -316,6 +321,10 @@ class TokenizerCommunicatorMixin:
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
DirectAppendAdmissionReqOutput,
self.direct_append_admission_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
@@ -871,6 +880,16 @@ class TokenizerCommunicatorMixin:
# Many DP ranks
return [res.internal_state for res in responses]
async def admit_direct_append(
    self: TokenizerManager,
    obj: DirectAppendAdmissionReqInput,
) -> DirectAppendAdmissionReqOutput:
    """Ask the scheduler(s) whether a direct append can be admitted.

    The communicator is created with a fan-out of ``dp_size``, so more than
    one reply may come back. A reply that reports the session as resident
    is preferred; if none does, the first reply is returned.

    Args:
        obj: The admission request to forward.

    Returns:
        The selected scheduler reply.
    """
    self.auto_create_handle_loop()
    responses: List[DirectAppendAdmissionReqOutput] = (
        await self.direct_append_admission_communicator(obj)
    )
    # Schedulers that do not hold the session answer "session-not-resident"
    # without evicting anything; such a reply must not mask the answer from
    # the scheduler that actually holds the session.
    # NOTE(review): presumably one reply per DP rank in arrival order —
    # confirm against _Communicator.
    for response in responses:
        if response.resident:
            return response
    return responses[0]
async def set_internal_state(
self: TokenizerManager, obj: SetInternalStateReq
) -> List[bool]:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional
@@ -57,6 +58,7 @@ class SessionSlot:
mamba_next_track_idx: Any = None
mamba_last_track_seqlen: Any = None
mamba_branching_seqlen: Any = None
last_access_time: float = field(default_factory=time.monotonic)
@property
def is_holding_kv(self) -> bool:
@@ -80,6 +82,7 @@ class SessionSlot:
self.mamba_next_track_idx = req.mamba_next_track_idx
self.mamba_last_track_seqlen = req.mamba_last_track_seqlen
self.mamba_branching_seqlen = req.mamba_branching_seqlen
self.last_access_time = time.monotonic()
req.req_pool_idx = None
req.mamba_pool_idx = None
@@ -97,6 +100,7 @@ class SessionSlot:
req.mamba_next_track_idx = self.mamba_next_track_idx
req.mamba_last_track_seqlen = self.mamba_last_track_seqlen
req.mamba_branching_seqlen = self.mamba_branching_seqlen
self.last_access_time = time.monotonic()
# NOTE: req_pool_idx and mamba_pool_idx are intentionally NOT cleared
# from the slot. During chunked prefill, a request may be rejected by
@@ -247,7 +251,7 @@ class SessionAwareCache(BasePrefixCache):
"""Release all KV resources held by a streaming session."""
slot = self.slots.pop(session_id, None)
if slot is None:
return
return 0
if slot.last_node is not None:
if slot.swa_uuid_for_lock is not None:
@@ -261,20 +265,58 @@ class SessionAwareCache(BasePrefixCache):
if slot.is_holding_kv:
start = slot.cache_protected_len
end = slot.kv_allocated_len
freed_tokens = max(0, end - start)
if start < end:
kv_indices = self.req_to_token_pool.req_to_token[
slot.req_pool_idx, start:end
]
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free_slots.append(slot.req_pool_idx)
return freed_tokens
return 0
def slot_held_tokens(self, slot: SessionSlot) -> int:
    """Return the KV tokens a slot holds beyond its cache-protected prefix.

    The allocated length is rounded up to the page size first; slots that
    hold no KV report 0, and the result is never negative.
    """
    if slot.is_holding_kv:
        page_aligned_len = ceil_align(slot.kv_allocated_len, self.page_size)
        return max(0, page_aligned_len - slot.cache_protected_len)
    return 0
def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Return a status snapshot for one session slot, or None if unknown."""
    slot = self.slots.get(session_id)
    if slot is None:
        return None

    snapshot: Dict[str, Any] = {
        "session_id": session_id,
        "resident": slot.is_holding_kv,
    }
    snapshot["resident_tokens"] = int(self.slot_held_tokens(slot))
    for length_field in ("kv_committed_len", "kv_allocated_len", "cache_protected_len"):
        snapshot[length_field] = int(getattr(slot, length_field))
    snapshot["last_access_time"] = float(slot.last_access_time)
    return snapshot
def list_session_statuses(self) -> list[Dict[str, Any]]:
    """Return status snapshots for all session slots, least recently used first."""

    def _snapshot(session_id, slot) -> Dict[str, Any]:
        return {
            "session_id": session_id,
            "resident": slot.is_holding_kv,
            "resident_tokens": int(self.slot_held_tokens(slot)),
            "kv_committed_len": int(slot.kv_committed_len),
            "kv_allocated_len": int(slot.kv_allocated_len),
            "cache_protected_len": int(slot.cache_protected_len),
            "last_access_time": float(slot.last_access_time),
        }

    snapshots = [_snapshot(sid, slot) for sid, slot in self.slots.items()]
    # Stable sort: slots with equal timestamps keep their insertion order.
    return sorted(snapshots, key=lambda item: item["last_access_time"])
def session_held_tokens(self) -> int:
    """Total KV tokens held by session slots, not tracked by the tree."""
    # Each holding slot is counted exactly once via slot_held_tokens. The
    # block as shown also added an inline page-aligned count per slot, which
    # doubled the total.
    return sum(
        self.slot_held_tokens(slot)
        for slot in self.slots.values()
        if slot.is_holding_kv
    )
def session_held_full_tokens(self) -> int:

View File

@@ -697,6 +697,7 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998
disaggregation_ib_device: Optional[str] = None
disaggregation_decode_enable_offload_kvcache: bool = False
disaggregation_decode_allow_local_prefill: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
@@ -5772,6 +5773,14 @@ class ServerArgs:
action="store_true",
help="Enable async KV cache offloading on decode server (PD mode).",
)
parser.add_argument(
"--disaggregation-decode-allow-local-prefill",
action="store_true",
help=(
"Allow decode workers in PD mode to accept direct local requests "
"without bootstrap metadata and run local append-prefill."
),
)
parser.add_argument(
"--num-reserved-decode-tokens",
type=int,