feat(sglang): support decode session cache admission

This commit is contained in:
2026-04-24 12:30:41 +00:00
parent bded08301f
commit b8e6f13c20
9 changed files with 589 additions and 14 deletions

View File

@@ -140,9 +140,9 @@ class DecodeReqToTokenPool:
# Indices of reqs that already have a req_pool_idx and will reuse # Indices of reqs that already have a req_pool_idx and will reuse
# their existing slot (e.g. chunked prefill continuing across chunks). # their existing slot (e.g. chunked prefill continuing across chunks).
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None] reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
assert ( # Keep decode-side allocation behavior aligned with the shared pool:
len(reusing) <= 1 # local append-prefill can legitimately batch multiple resident-session
), "only one chunked request may reuse req_pool_idx in a batch" # requests that each reuse their existing req_pool_idx.
assert all( assert all(
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing
), "reusing request must be chunked or have committed KV" ), "reusing request must be chunked or have committed KV"
@@ -657,6 +657,12 @@ class DecodePreallocQueue:
allocatable_tokens = self._allocatable_tokens( allocatable_tokens = self._allocatable_tokens(
retractable_tokens=retractable_tokens, count_retracted=True retractable_tokens=retractable_tokens, count_retracted=True
) )
def refresh_allocatable_tokens() -> int:
return self._allocatable_tokens(
retractable_tokens=retractable_tokens, count_retracted=True
)
# First, remove all failed requests from the queue # First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue): for i, decode_req in enumerate(self.queue):
if rids_to_check is not None and decode_req.req.rid not in rids_to_check: if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
@@ -680,7 +686,14 @@ class DecodePreallocQueue:
continue continue
if self.req_to_token_pool.available_size() <= 0: if self.req_to_token_pool.available_size() <= 0:
break self.scheduler.maybe_trim_decode_session_cache(
required_tokens=max(1, self.num_reserved_decode_tokens),
force=True,
max_sessions=1,
)
allocatable_tokens = refresh_allocatable_tokens()
if self.req_to_token_pool.available_size() <= 0:
break
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0: if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
break break
@@ -704,9 +717,46 @@ class DecodePreallocQueue:
) )
> allocatable_tokens > allocatable_tokens
): ):
break required_tokens = max(
0,
max(
required_tokens_for_request,
origin_input_len
+ min(
decode_req.req.sampling_params.max_new_tokens,
CLIP_MAX_NEW_TOKEN,
)
- retractable_tokens,
)
- allocatable_tokens,
)
if required_tokens > 0:
self.scheduler.maybe_trim_decode_session_cache(
required_tokens=required_tokens,
force=True,
)
allocatable_tokens = refresh_allocatable_tokens()
if (
max(
required_tokens_for_request,
origin_input_len
+ min(
decode_req.req.sampling_params.max_new_tokens,
CLIP_MAX_NEW_TOKEN,
)
- retractable_tokens,
)
> allocatable_tokens
):
break
if required_tokens_for_request > allocatable_tokens: if required_tokens_for_request > allocatable_tokens:
break self.scheduler.maybe_trim_decode_session_cache(
required_tokens=required_tokens_for_request - allocatable_tokens,
force=True,
)
allocatable_tokens = refresh_allocatable_tokens()
if required_tokens_for_request > allocatable_tokens:
break
allocatable_tokens -= required_tokens_for_request allocatable_tokens -= required_tokens_for_request
dst_kv_indices = self._pre_alloc(decode_req.req) dst_kv_indices = self._pre_alloc(decode_req.req)
@@ -1135,6 +1185,27 @@ class DecodeTransferQueue:
class SchedulerDisaggregationDecodeMixin: class SchedulerDisaggregationDecodeMixin:
def _merge_last_local_prefill_batch(self: Scheduler):
if self.last_batch is None or not self.last_batch.forward_mode.is_extend():
return
chunked_req_to_exclude = set()
if self.last_batch.chunked_req is not None:
chunked_req_to_exclude.add(self.last_batch.chunked_req)
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
if self.last_batch.batch_size() < last_bs:
self.running_batch.batch_is_full = False
if not self.last_batch.is_empty():
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:
self.running_batch.merge_batch(self.last_batch)
@torch.no_grad() @torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler): def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode.""" """A normal scheduler loop for decode worker in disaggregation mode."""
@@ -1213,6 +1284,8 @@ class SchedulerDisaggregationDecodeMixin:
self: Scheduler, self: Scheduler,
) -> Optional[ScheduleBatch]: ) -> Optional[ScheduleBatch]:
"""Process prebuilt batch and schedule the next decode batch.""" """Process prebuilt batch and schedule the next decode batch."""
self._merge_last_local_prefill_batch()
# Process pending prebuilt batch: output processing + filter + merge # Process pending prebuilt batch: output processing + filter + merge
new_prebuilt_batch = self.get_new_prebuilt_batch() new_prebuilt_batch = self.get_new_prebuilt_batch()
if new_prebuilt_batch: if new_prebuilt_batch:
@@ -1229,6 +1302,13 @@ class SchedulerDisaggregationDecodeMixin:
else: else:
self.running_batch.merge_batch(new_prebuilt_batch) self.running_batch.merge_batch(new_prebuilt_batch)
new_local_batch = self.get_new_local_extend_batch()
if new_local_batch is not None:
ret = self.maybe_prepare_mlp_sync_batch(new_local_batch)
if ret:
set_schedule_time_batch(ret)
return ret
# Schedule decode batch # Schedule decode batch
if self.running_batch.is_empty(): if self.running_batch.is_empty():
ret = None ret = None
@@ -1241,6 +1321,22 @@ class SchedulerDisaggregationDecodeMixin:
set_schedule_time_batch(ret) set_schedule_time_batch(ret)
return ret return ret
def get_new_local_extend_batch(self: Scheduler) -> Optional[ScheduleBatch]:
if not self.server_args.disaggregation_decode_allow_local_prefill:
return None
if len(self.decode_direct_waiting_queue) == 0:
return None
original_waiting_queue = self.waiting_queue
try:
self.waiting_queue = self.decode_direct_waiting_queue
new_batch = self.get_new_batch_prefill()
self.decode_direct_waiting_queue = self.waiting_queue
finally:
self.waiting_queue = original_waiting_queue
return new_batch
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
"""Create a schedulebatch for fake completed prefill""" """Create a schedulebatch for fake completed prefill"""
if self.grammar_manager.has_waiting_grammars(): if self.grammar_manager.has_waiting_grammars():

View File

@@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterFromTensorsReqInput, LoadLoRAAdapterFromTensorsReqInput,
LoadLoRAAdapterReqInput, LoadLoRAAdapterReqInput,
DirectAppendAdmissionReqInput,
OpenSessionReqInput, OpenSessionReqInput,
ParseFunctionCallReq, ParseFunctionCallReq,
PauseGenerationReqInput, PauseGenerationReqInput,
@@ -1289,6 +1290,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
return _create_error_response(e) return _create_error_response(e)
@app.post("/session_cache/admit_direct_append")
async def admit_direct_append(obj: DirectAppendAdmissionReqInput):
return await _global_state.tokenizer_manager.admit_direct_append(obj)
@app.api_route("/configure_logging", methods=["GET", "POST"]) @app.api_route("/configure_logging", methods=["GET", "POST"])
@auth_level(AuthLevel.ADMIN_OPTIONAL) @auth_level(AuthLevel.ADMIN_OPTIONAL)
async def configure_logging(obj: ConfigureLoggingReq, request: Request): async def configure_logging(obj: ConfigureLoggingReq, request: Request):

View File

@@ -1597,6 +1597,30 @@ class OpenSessionReqOutput(BaseReq):
success: bool success: bool
@dataclass
class DirectAppendAdmissionReqInput(BaseReq):
session_id: str
uncached_input_tokens: int
output_tokens: int
@dataclass
class DirectAppendAdmissionReqOutput(BaseReq):
can_admit: bool
resident: bool
reason: Optional[str] = None
required_tokens: int = 0
available_tokens_before: int = 0
available_tokens_after: int = 0
evicted_session_count: int = 0
freed_tokens: int = 0
token_usage: float = 0.0
num_running_reqs: int = 0
decode_prealloc_queue_reqs: int = 0
decode_transfer_queue_reqs: int = 0
decode_retracted_queue_reqs: int = 0
@dataclass @dataclass
class HealthCheckOutput(BaseReq): class HealthCheckOutput(BaseReq):
pass pass

View File

@@ -1566,6 +1566,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Init tensors # Init tensors
reqs = self.reqs reqs = self.reqs
for req in reqs:
if req.session is None or not req.session.streaming:
continue
actual_extend_len = max(0, len(req.fill_ids) - len(req.prefix_indices))
if req.extend_input_len != actual_extend_len:
logger.warning(
"Correcting streaming-session extend_input_len from %d to %d "
"(rid=%s, session_id=%s, fill_len=%d, prefix_len=%d, kv_committed_len=%d)",
req.extend_input_len,
actual_extend_len,
req.rid,
req.session.session_id,
len(req.fill_ids),
len(req.prefix_indices),
req.kv_committed_len,
)
req.set_extend_input_len(actual_extend_len)
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids) extend_num_tokens = sum(len(ids) for ids in input_ids)
seq_lens = [len(r.fill_ids) for r in reqs] seq_lens = [len(r.fill_ids) for r in reqs]

View File

@@ -94,6 +94,8 @@ from sglang.srt.managers.io_struct import (
ClearHiCacheReqOutput, ClearHiCacheReqOutput,
CloseSessionReqInput, CloseSessionReqInput,
ContinueGenerationReqInput, ContinueGenerationReqInput,
DirectAppendAdmissionReqInput,
DirectAppendAdmissionReqOutput,
DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqInput,
DetachHiCacheStorageReqInput, DetachHiCacheStorageReqInput,
DetachHiCacheStorageReqOutput, DetachHiCacheStorageReqOutput,
@@ -844,6 +846,7 @@ class Scheduler(
def init_running_status(self): def init_running_status(self):
self.waiting_queue: List[Req] = [] self.waiting_queue: List[Req] = []
self.decode_direct_waiting_queue: List[Req] = []
# The running decoding batch for continuous batching # The running decoding batch for continuous batching
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False) self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
# The current forward batch # The current forward batch
@@ -1215,6 +1218,7 @@ class Scheduler(
(AbortReq, self.abort_request), (AbortReq, self.abort_request),
(OpenSessionReqInput, self.open_session), (OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session), (CloseSessionReqInput, self.close_session),
(DirectAppendAdmissionReqInput, self.admit_direct_append),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk), (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group), (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group), (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
@@ -1589,6 +1593,7 @@ class Scheduler(
def process_input_requests(self, recv_reqs: List): def process_input_requests(self, recv_reqs: List):
now = time.monotonic() now = time.monotonic()
self.session_controller.maybe_reap(now) self.session_controller.maybe_reap(now)
self.maybe_trim_decode_session_cache()
for recv_req in recv_reqs: for recv_req in recv_reqs:
# Skip health check when server is busy — ongoing requests already carry health info. # Skip health check when server is busy — ongoing requests already carry health info.
if is_health_check_generate_req(recv_req) and not self.is_fully_idle( if is_health_check_generate_req(recv_req) and not self.is_fully_idle(
@@ -1781,6 +1786,10 @@ class Scheduler(
# Invalid request for disaggregated mode # Invalid request for disaggregated mode
if ( if (
recv_req.bootstrap_room is None recv_req.bootstrap_room is None
and not (
self.disaggregation_mode == DisaggregationMode.DECODE
and self.server_args.disaggregation_decode_allow_local_prefill
)
and self.transfer_backend != TransferBackend.FAKE and self.transfer_backend != TransferBackend.FAKE
): ):
error_msg = ( error_msg = (
@@ -1949,6 +1958,12 @@ class Scheduler(
) )
req.time_stats.set_prefill_bootstrap_queue_entry_time() req.time_stats.set_prefill_bootstrap_queue_entry_time()
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
if self._should_allow_local_prefill_on_decode(req):
if not self._set_or_validate_priority(req):
return
self.decode_direct_waiting_queue.append(req)
req.time_stats.set_wait_queue_entry_time()
return
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted) self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
if not is_retracted: if not is_retracted:
req.time_stats.set_decode_prealloc_queue_entry_time() req.time_stats.set_decode_prealloc_queue_entry_time()
@@ -1957,6 +1972,88 @@ class Scheduler(
else: else:
raise ValueError(f"Invalid {self.disaggregation_mode=}") raise ValueError(f"Invalid {self.disaggregation_mode=}")
def _should_allow_local_prefill_on_decode(self, req: Req) -> bool:
return (
self.disaggregation_mode == DisaggregationMode.DECODE
and self.server_args.disaggregation_decode_allow_local_prefill
and req.bootstrap_room is None
)
def _decode_session_cache_low_watermark_tokens(self) -> int:
return min(
self.max_total_num_tokens,
max(
self.server_args.num_reserved_decode_tokens * 16,
self.max_total_num_tokens // 12,
16384,
),
)
def _decode_session_cache_target_available_tokens(self) -> int:
return min(
self.max_total_num_tokens,
max(
self._decode_session_cache_low_watermark_tokens(),
self.server_args.num_reserved_decode_tokens * 24,
self.max_total_num_tokens // 8,
24576,
),
)
def maybe_trim_decode_session_cache(
self,
required_tokens: int = 0,
force: bool = False,
max_sessions: Optional[int] = None,
exclude_session_ids: Optional[set[str]] = None,
) -> Dict[str, int]:
if (
self.disaggregation_mode != DisaggregationMode.DECODE
or not isinstance(self.tree_cache, SessionAwareCache)
):
return {
"evicted_session_count": 0,
"freed_tokens": 0,
"available_tokens_before": 0,
"available_tokens_after": 0,
}
available_tokens = self.token_to_kv_pool_allocator.available_size()
low_watermark_tokens = self._decode_session_cache_low_watermark_tokens()
target_available_tokens = self._decode_session_cache_target_available_tokens()
min_available_tokens = max(
low_watermark_tokens,
available_tokens + max(0, required_tokens),
)
target_available_tokens = max(
target_available_tokens,
min_available_tokens,
)
if not force and available_tokens >= low_watermark_tokens:
return {
"evicted_session_count": 0,
"freed_tokens": 0,
"available_tokens_before": int(available_tokens),
"available_tokens_after": int(available_tokens),
}
result = self.session_controller.evict_idle_streaming_sessions_lru(
required_tokens=required_tokens,
min_available_tokens=min_available_tokens,
target_available_tokens=target_available_tokens,
max_sessions=max_sessions,
exclude_session_ids=exclude_session_ids,
)
if result["evicted_session_count"] > 0:
logger.info(
"Trimmed decode session cache via LRU. "
f"#evicted_sessions: {result['evicted_session_count']}, "
f"#freed_tokens: {result['freed_tokens']}, "
f"#available_tokens: {result['available_tokens_before']} -> "
f"{result['available_tokens_after']}"
)
return result
def _set_or_validate_priority(self, req: Req) -> bool: def _set_or_validate_priority(self, req: Req) -> bool:
"""Set the default priority value, or abort the request based on the priority scheduling mode.""" """Set the default priority value, or abort the request based on the priority scheduling mode."""
if self.enable_priority_scheduling and req.priority is None: if self.enable_priority_scheduling and req.priority is None:
@@ -2558,8 +2655,19 @@ class Scheduler(
if self.enable_hierarchical_cache: if self.enable_hierarchical_cache:
self.tree_cache.flush_write_through_acks() self.tree_cache.flush_write_through_acks()
# Check if decode out of memory # Check if decode out of memory. Before retracting active decode work,
if (kv_full_retract_flag := not batch.check_decode_mem()) or ( # let the worker evict idle streaming-session KV held outside the radix tree.
kv_full_retract_flag = not batch.check_decode_mem()
idle_session_eviction = None
if kv_full_retract_flag:
idle_session_eviction = self.maybe_trim_decode_session_cache(
required_tokens=max(1, self.server_args.num_reserved_decode_tokens),
force=True,
)
if idle_session_eviction["freed_tokens"] > 0:
kv_full_retract_flag = not batch.check_decode_mem()
if kv_full_retract_flag or (
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0 TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
): ):
old_available_tokens = self.token_to_kv_pool_allocator.available_size() old_available_tokens = self.token_to_kv_pool_allocator.available_size()
@@ -2602,6 +2710,13 @@ class Scheduler(
msg_details += ( msg_details += (
f", #new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}" f", #new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
) )
if idle_session_eviction is not None:
msg_details += (
", #idle_session_evicted: "
f"{idle_session_eviction['evicted_session_count']}"
", #idle_tokens_freed: "
f"{idle_session_eviction['freed_tokens']}"
)
logger.warning(msg_prefix + msg_details) logger.warning(msg_prefix + msg_details)
for req in retracted_reqs: for req in retracted_reqs:
@@ -2916,6 +3031,7 @@ class Scheduler(
# Waiting queues: waiting + bootstrapping + preallocation + kv transfer (decode) # Waiting queues: waiting + bootstrapping + preallocation + kv transfer (decode)
idle &= len(self.waiting_queue) == 0 idle &= len(self.waiting_queue) == 0
idle &= len(self.decode_direct_waiting_queue) == 0
if not for_health_check: if not for_health_check:
# Grammar queue and prefill inflight queue may not produce batch # Grammar queue and prefill inflight queue may not produce batch
@@ -3077,6 +3193,9 @@ class Scheduler(
"graph": round(self.tp_worker.model_runner.graph_mem_usage, 2), "graph": round(self.tp_worker.model_runner.graph_mem_usage, 2),
} }
ret["effective_max_running_requests_per_dp"] = self.max_running_requests ret["effective_max_running_requests_per_dp"] = self.max_running_requests
ret["session_cache"] = (
self.session_controller.get_streaming_session_cache_status()
)
if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0: if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
ret["avg_spec_accept_length"] = ( ret["avg_spec_accept_length"] = (
@@ -3375,6 +3494,86 @@ class Scheduler(
def close_session(self, recv_req: CloseSessionReqInput): def close_session(self, recv_req: CloseSessionReqInput):
self.session_controller.close(recv_req) self.session_controller.close(recv_req)
def admit_direct_append(
self, recv_req: DirectAppendAdmissionReqInput
) -> DirectAppendAdmissionReqOutput:
if (
self.disaggregation_mode != DisaggregationMode.DECODE
or not self.server_args.disaggregation_decode_allow_local_prefill
or not isinstance(self.tree_cache, SessionAwareCache)
):
return DirectAppendAdmissionReqOutput(
can_admit=False,
resident=False,
reason="unsupported",
)
session_cache_status = self.session_controller.get_streaming_session_cache_status(
recv_req.session_id
)
target_session = session_cache_status.get("target_session")
resident = bool(
isinstance(target_session, dict) and target_session.get("resident")
)
if not resident:
return DirectAppendAdmissionReqOutput(
can_admit=False,
resident=False,
reason="session-not-resident",
available_tokens_before=int(
self.token_to_kv_pool_allocator.available_size()
),
available_tokens_after=int(
self.token_to_kv_pool_allocator.available_size()
),
token_usage=(
1.0
- self.token_to_kv_pool_allocator.available_size()
/ max(1, self.max_total_num_tokens)
),
num_running_reqs=len(self.running_batch.reqs),
decode_prealloc_queue_reqs=len(self.disagg_decode_prealloc_queue.queue),
decode_transfer_queue_reqs=len(self.disagg_decode_transfer_queue.queue),
decode_retracted_queue_reqs=len(
self.disagg_decode_prealloc_queue.retracted_queue
),
)
required_tokens = max(0, recv_req.uncached_input_tokens) + max(
0, recv_req.output_tokens
)
available_tokens_before = int(self.token_to_kv_pool_allocator.available_size())
trim_result = self.maybe_trim_decode_session_cache(
required_tokens=required_tokens,
force=available_tokens_before < required_tokens,
exclude_session_ids={recv_req.session_id},
)
available_tokens_after = int(self.token_to_kv_pool_allocator.available_size())
decode_retracted_queue_reqs = len(self.disagg_decode_prealloc_queue.retracted_queue)
can_admit = (
decode_retracted_queue_reqs == 0
and available_tokens_after >= required_tokens
)
reason = None if can_admit else "no-space"
return DirectAppendAdmissionReqOutput(
can_admit=can_admit,
resident=True,
reason=reason,
required_tokens=int(required_tokens),
available_tokens_before=available_tokens_before,
available_tokens_after=available_tokens_after,
evicted_session_count=int(trim_result["evicted_session_count"]),
freed_tokens=int(trim_result["freed_tokens"]),
token_usage=(
1.0 - available_tokens_after / max(1, self.max_total_num_tokens)
),
num_running_reqs=len(self.running_batch.reqs),
decode_prealloc_queue_reqs=len(self.disagg_decode_prealloc_queue.queue),
decode_transfer_queue_reqs=len(self.disagg_decode_transfer_queue.queue),
decode_retracted_queue_reqs=decode_retracted_queue_reqs,
)
def maybe_sleep_on_idle(self): def maybe_sleep_on_idle(self):
if self.idle_sleeper is not None: if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep() self.idle_sleeper.maybe_sleep()

View File

@@ -229,6 +229,12 @@ class Session:
priority=req.priority, priority=req.priority,
routing_key=req.routing_key, routing_key=req.routing_key,
http_worker_ipc=req.http_worker_ipc, http_worker_ipc=req.http_worker_ipc,
bootstrap_host=req.bootstrap_host,
bootstrap_port=req.bootstrap_port,
bootstrap_room=req.bootstrap_room,
routed_dp_rank=req.routed_dp_rank,
disagg_prefill_dp_rank=req.disagg_prefill_dp_rank,
extra_key=req.extra_key,
time_stats=req.time_stats, time_stats=req.time_stats,
) )
if last_req is not None: if last_req is not None:
@@ -284,7 +290,7 @@ class SessionController:
else: else:
self._close(session_id) self._close(session_id)
def _close(self, session_id: str): def _close(self, session_id: str) -> int:
session = self.sessions[session_id] session = self.sessions[session_id]
if session.streaming and session.req_nodes: if session.streaming and session.req_nodes:
assert len(session.req_nodes) == 1 assert len(session.req_nodes) == 1
@@ -303,9 +309,166 @@ class SessionController:
mm.release_features() mm.release_features()
node.req.multimodal_inputs = None node.req.multimodal_inputs = None
freed_tokens = 0
if isinstance(self.tree_cache, SessionAwareCache): if isinstance(self.tree_cache, SessionAwareCache):
self.tree_cache.release_session(session_id) freed_tokens = self.tree_cache.release_session(session_id)
del self.sessions[session_id] del self.sessions[session_id]
return freed_tokens
def _is_idle_streaming_session(self, session: Session) -> bool:
if not session.streaming:
return False
if not session.req_nodes:
return False
return all(node.req.finished() for node in session.req_nodes.values())
def get_streaming_session_cache_status(
self, session_id: Optional[str] = None
) -> Dict[str, object]:
if not isinstance(self.tree_cache, SessionAwareCache):
return {
"enabled": False,
"eviction_policy": None,
"session_count": 0,
"resident_session_count": 0,
"held_tokens": 0,
"available_tokens": 0,
"capacity_tokens": 0,
"idle_evictable_session_count": 0,
"idle_evictable_tokens": 0,
"sessions": [],
"target_session": None,
}
all_statuses = self.tree_cache.list_session_statuses()
session_map = {status["session_id"]: status for status in all_statuses}
idle_evictable_tokens = 0
idle_evictable_session_count = 0
sessions = []
for status in all_statuses:
session = self.sessions.get(status["session_id"])
idle_evictable = session is not None and self._is_idle_streaming_session(
session
)
if idle_evictable and status["resident"]:
idle_evictable_session_count += 1
idle_evictable_tokens += int(status["resident_tokens"])
if session_id is None or status["session_id"] == session_id:
sessions.append(
{
**status,
"streaming": bool(session.streaming) if session is not None else True,
"idle_evictable": idle_evictable,
"timed_out": session.is_timed_out() if session is not None else False,
}
)
target_session = session_map.get(session_id) if session_id is not None else None
return {
"enabled": True,
"eviction_policy": "lru",
"session_count": len(all_statuses),
"resident_session_count": sum(1 for status in all_statuses if status["resident"]),
"held_tokens": int(self.tree_cache.session_held_tokens()),
"available_tokens": int(self.tree_cache.token_to_kv_pool_allocator.available_size()),
"capacity_tokens": int(self.tree_cache.token_to_kv_pool_allocator.size),
"idle_evictable_session_count": idle_evictable_session_count,
"idle_evictable_tokens": int(idle_evictable_tokens),
"sessions": sessions,
"target_session": (
{
**target_session,
"idle_evictable": (
self._is_idle_streaming_session(self.sessions[target_session["session_id"]])
if target_session is not None
and target_session["session_id"] in self.sessions
else False
),
}
if target_session is not None
else None
),
}
def evict_idle_streaming_sessions(
self,
required_tokens: int = 0,
max_sessions: Optional[int] = None,
) -> Dict[str, int]:
return self.evict_idle_streaming_sessions_lru(
required_tokens=required_tokens,
max_sessions=max_sessions,
)
def evict_idle_streaming_sessions_lru(
self,
required_tokens: int = 0,
min_available_tokens: int = 0,
target_available_tokens: Optional[int] = None,
max_sessions: Optional[int] = None,
exclude_session_ids: Optional[set[str]] = None,
) -> Dict[str, int]:
if not isinstance(self.tree_cache, SessionAwareCache):
return {
"evicted_session_count": 0,
"freed_tokens": 0,
"available_tokens_before": 0,
"available_tokens_after": 0,
}
available_tokens_before = int(
self.tree_cache.token_to_kv_pool_allocator.available_size()
)
evicted_session_count = 0
freed_tokens = 0
target_available_tokens = max(
int(min_available_tokens),
int(target_available_tokens or 0),
)
candidates = []
for status in self.tree_cache.list_session_statuses():
if exclude_session_ids and status["session_id"] in exclude_session_ids:
continue
session = self.sessions.get(status["session_id"])
if session is None or not self._is_idle_streaming_session(session):
continue
if not status["resident"]:
continue
candidates.append(status)
for status in candidates:
if max_sessions is not None and evicted_session_count >= max_sessions:
break
available_tokens_now = available_tokens_before + freed_tokens
if (
required_tokens > 0
and freed_tokens >= required_tokens
and available_tokens_now >= min_available_tokens
):
break
if (
required_tokens <= 0
and min_available_tokens <= 0
and target_available_tokens <= 0
):
break
if (
required_tokens <= 0
and available_tokens_now >= min_available_tokens
and available_tokens_now >= target_available_tokens
):
break
freed = self._close(status["session_id"])
evicted_session_count += 1
freed_tokens += int(freed)
return {
"evicted_session_count": evicted_session_count,
"freed_tokens": freed_tokens,
"available_tokens_before": available_tokens_before,
"available_tokens_after": available_tokens_before + freed_tokens,
}
def maybe_reap(self, now: float, interval: float = 1.0): def maybe_reap(self, now: float, interval: float = 1.0):
# reap sessions every second # reap sessions every second

View File

@@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput, ClearHiCacheReqInput,
ClearHiCacheReqOutput, ClearHiCacheReqOutput,
CloseSessionReqInput, CloseSessionReqInput,
DirectAppendAdmissionReqInput,
DirectAppendAdmissionReqOutput,
DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput, DestroyWeightsUpdateGroupReqOutput,
DetachHiCacheStorageReqInput, DetachHiCacheStorageReqInput,
@@ -220,6 +222,9 @@ class TokenizerCommunicatorMixin:
self.get_internal_state_communicator = _Communicator( self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.direct_append_admission_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator( self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
@@ -316,6 +321,10 @@ class TokenizerCommunicatorMixin:
GetInternalStateReqOutput, GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv, self.get_internal_state_communicator.handle_recv,
), ),
(
DirectAppendAdmissionReqOutput,
self.direct_append_admission_communicator.handle_recv,
),
( (
SetInternalStateReqOutput, SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv, self.set_internal_state_communicator.handle_recv,
@@ -871,6 +880,16 @@ class TokenizerCommunicatorMixin:
# Many DP ranks # Many DP ranks
return [res.internal_state for res in responses] return [res.internal_state for res in responses]
async def admit_direct_append(
self: TokenizerManager,
obj: DirectAppendAdmissionReqInput,
) -> DirectAppendAdmissionReqOutput:
self.auto_create_handle_loop()
responses: List[DirectAppendAdmissionReqOutput] = (
await self.direct_append_admission_communicator(obj)
)
return responses[0]
async def set_internal_state( async def set_internal_state(
self: TokenizerManager, obj: SetInternalStateReq self: TokenizerManager, obj: SetInternalStateReq
) -> List[bool]: ) -> List[bool]:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
@@ -57,6 +58,7 @@ class SessionSlot:
mamba_next_track_idx: Any = None mamba_next_track_idx: Any = None
mamba_last_track_seqlen: Any = None mamba_last_track_seqlen: Any = None
mamba_branching_seqlen: Any = None mamba_branching_seqlen: Any = None
last_access_time: float = field(default_factory=time.monotonic)
@property @property
def is_holding_kv(self) -> bool: def is_holding_kv(self) -> bool:
@@ -80,6 +82,7 @@ class SessionSlot:
self.mamba_next_track_idx = req.mamba_next_track_idx self.mamba_next_track_idx = req.mamba_next_track_idx
self.mamba_last_track_seqlen = req.mamba_last_track_seqlen self.mamba_last_track_seqlen = req.mamba_last_track_seqlen
self.mamba_branching_seqlen = req.mamba_branching_seqlen self.mamba_branching_seqlen = req.mamba_branching_seqlen
self.last_access_time = time.monotonic()
req.req_pool_idx = None req.req_pool_idx = None
req.mamba_pool_idx = None req.mamba_pool_idx = None
@@ -97,6 +100,7 @@ class SessionSlot:
req.mamba_next_track_idx = self.mamba_next_track_idx req.mamba_next_track_idx = self.mamba_next_track_idx
req.mamba_last_track_seqlen = self.mamba_last_track_seqlen req.mamba_last_track_seqlen = self.mamba_last_track_seqlen
req.mamba_branching_seqlen = self.mamba_branching_seqlen req.mamba_branching_seqlen = self.mamba_branching_seqlen
self.last_access_time = time.monotonic()
# NOTE: req_pool_idx and mamba_pool_idx are intentionally NOT cleared # NOTE: req_pool_idx and mamba_pool_idx are intentionally NOT cleared
# from the slot. During chunked prefill, a request may be rejected by # from the slot. During chunked prefill, a request may be rejected by
@@ -247,7 +251,7 @@ class SessionAwareCache(BasePrefixCache):
"""Release all KV resources held by a streaming session.""" """Release all KV resources held by a streaming session."""
slot = self.slots.pop(session_id, None) slot = self.slots.pop(session_id, None)
if slot is None: if slot is None:
return return 0
if slot.last_node is not None: if slot.last_node is not None:
if slot.swa_uuid_for_lock is not None: if slot.swa_uuid_for_lock is not None:
@@ -261,20 +265,58 @@ class SessionAwareCache(BasePrefixCache):
if slot.is_holding_kv: if slot.is_holding_kv:
start = slot.cache_protected_len start = slot.cache_protected_len
end = slot.kv_allocated_len end = slot.kv_allocated_len
freed_tokens = max(0, end - start)
if start < end: if start < end:
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
slot.req_pool_idx, start:end slot.req_pool_idx, start:end
] ]
self.token_to_kv_pool_allocator.free(kv_indices) self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free_slots.append(slot.req_pool_idx) self.req_to_token_pool.free_slots.append(slot.req_pool_idx)
return freed_tokens
return 0
def slot_held_tokens(self, slot: SessionSlot) -> int:
if not slot.is_holding_kv:
return 0
allocated = ceil_align(slot.kv_allocated_len, self.page_size)
return max(0, allocated - slot.cache_protected_len)
def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
slot = self.slots.get(session_id)
if slot is None:
return None
return {
"session_id": session_id,
"resident": slot.is_holding_kv,
"resident_tokens": int(self.slot_held_tokens(slot)),
"kv_committed_len": int(slot.kv_committed_len),
"kv_allocated_len": int(slot.kv_allocated_len),
"cache_protected_len": int(slot.cache_protected_len),
"last_access_time": float(slot.last_access_time),
}
def list_session_statuses(self) -> list[Dict[str, Any]]:
statuses = []
for session_id, slot in self.slots.items():
statuses.append(
{
"session_id": session_id,
"resident": slot.is_holding_kv,
"resident_tokens": int(self.slot_held_tokens(slot)),
"kv_committed_len": int(slot.kv_committed_len),
"kv_allocated_len": int(slot.kv_allocated_len),
"cache_protected_len": int(slot.cache_protected_len),
"last_access_time": float(slot.last_access_time),
}
)
statuses.sort(key=lambda item: item["last_access_time"])
return statuses
def session_held_tokens(self) -> int: def session_held_tokens(self) -> int:
"""Total KV tokens held by session slots, not tracked by the tree.""" """Total KV tokens held by session slots, not tracked by the tree."""
total = 0 total = 0
for slot in self.slots.values(): for slot in self.slots.values():
if slot.is_holding_kv: total += self.slot_held_tokens(slot)
allocated = ceil_align(slot.kv_allocated_len, self.page_size)
total += allocated - slot.cache_protected_len
return total return total
def session_held_full_tokens(self) -> int: def session_held_full_tokens(self) -> int:

View File

@@ -697,6 +697,7 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
disaggregation_ib_device: Optional[str] = None disaggregation_ib_device: Optional[str] = None
disaggregation_decode_enable_offload_kvcache: bool = False disaggregation_decode_enable_offload_kvcache: bool = False
disaggregation_decode_allow_local_prefill: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small # FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1 disaggregation_decode_polling_interval: int = 1
@@ -5772,6 +5773,14 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable async KV cache offloading on decode server (PD mode).", help="Enable async KV cache offloading on decode server (PD mode).",
) )
parser.add_argument(
"--disaggregation-decode-allow-local-prefill",
action="store_true",
help=(
"Allow decode workers in PD mode to accept direct local requests "
"without bootstrap metadata and run local append-prefill."
),
)
parser.add_argument( parser.add_argument(
"--num-reserved-decode-tokens", "--num-reserved-decode-tokens",
type=int, type=int,