diff --git a/third_party/sglang/python/sglang/srt/disaggregation/decode.py b/third_party/sglang/python/sglang/srt/disaggregation/decode.py index f54c882..a906a1d 100644 --- a/third_party/sglang/python/sglang/srt/disaggregation/decode.py +++ b/third_party/sglang/python/sglang/srt/disaggregation/decode.py @@ -140,9 +140,9 @@ class DecodeReqToTokenPool: # Indices of reqs that already have a req_pool_idx and will reuse # their existing slot (e.g. chunked prefill continuing across chunks). reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None] - assert ( - len(reusing) <= 1 - ), "only one chunked request may reuse req_pool_idx in a batch" + # Keep decode-side allocation behavior aligned with the shared pool: + # local append-prefill can legitimately batch multiple resident-session + # requests that each reuse their existing req_pool_idx. assert all( reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing ), "reusing request must be chunked or have committed KV" @@ -657,6 +657,12 @@ class DecodePreallocQueue: allocatable_tokens = self._allocatable_tokens( retractable_tokens=retractable_tokens, count_retracted=True ) + + def refresh_allocatable_tokens() -> int: + return self._allocatable_tokens( + retractable_tokens=retractable_tokens, count_retracted=True + ) + # First, remove all failed requests from the queue for i, decode_req in enumerate(self.queue): if rids_to_check is not None and decode_req.req.rid not in rids_to_check: @@ -680,7 +686,14 @@ class DecodePreallocQueue: continue if self.req_to_token_pool.available_size() <= 0: - break + self.scheduler.maybe_trim_decode_session_cache( + required_tokens=max(1, self.num_reserved_decode_tokens), + force=True, + max_sessions=1, + ) + allocatable_tokens = refresh_allocatable_tokens() + if self.req_to_token_pool.available_size() <= 0: + break if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0: break @@ -704,9 +717,46 @@ class DecodePreallocQueue: ) > allocatable_tokens ): - break + required_tokens = max( + 0, + max( + required_tokens_for_request, + origin_input_len + + min( + decode_req.req.sampling_params.max_new_tokens, + CLIP_MAX_NEW_TOKEN, + ) + - retractable_tokens, + ) + - allocatable_tokens, + ) + if required_tokens > 0: + self.scheduler.maybe_trim_decode_session_cache( + required_tokens=required_tokens, + force=True, + ) + allocatable_tokens = refresh_allocatable_tokens() + if ( + max( + required_tokens_for_request, + origin_input_len + + min( + decode_req.req.sampling_params.max_new_tokens, + CLIP_MAX_NEW_TOKEN, + ) + - retractable_tokens, + ) + > allocatable_tokens + ): + break if required_tokens_for_request > allocatable_tokens: - break + self.scheduler.maybe_trim_decode_session_cache( + required_tokens=required_tokens_for_request - allocatable_tokens, + force=True, + ) + allocatable_tokens = refresh_allocatable_tokens() + if required_tokens_for_request > allocatable_tokens: + break allocatable_tokens -= required_tokens_for_request dst_kv_indices = self._pre_alloc(decode_req.req) @@ -1135,6 +1185,27 @@ class DecodeTransferQueue: class SchedulerDisaggregationDecodeMixin: + def _merge_last_local_prefill_batch(self: Scheduler): + if self.last_batch is None or not self.last_batch.forward_mode.is_extend(): + return + + chunked_req_to_exclude = set() + if self.last_batch.chunked_req is not None: + chunked_req_to_exclude.add(self.last_batch.chunked_req) + + last_bs = self.last_batch.batch_size() + self.last_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) + if self.last_batch.batch_size() < last_bs: + self.running_batch.batch_is_full = False + + if not self.last_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = self.last_batch + else: + self.running_batch.merge_batch(self.last_batch) + @torch.no_grad() def event_loop_normal_disagg_decode(self: Scheduler): """A normal scheduler loop for decode worker in disaggregation mode.""" @@ -1213,6 +1284,8 @@ class SchedulerDisaggregationDecodeMixin: self: Scheduler, ) -> Optional[ScheduleBatch]: """Process prebuilt batch and schedule the next decode batch.""" + self._merge_last_local_prefill_batch() + # Process pending prebuilt batch: output processing + filter + merge new_prebuilt_batch = self.get_new_prebuilt_batch() if new_prebuilt_batch: @@ -1229,6 +1302,13 @@ class SchedulerDisaggregationDecodeMixin: else: self.running_batch.merge_batch(new_prebuilt_batch) + new_local_batch = self.get_new_local_extend_batch() + if new_local_batch is not None: + ret = self.maybe_prepare_mlp_sync_batch(new_local_batch) + if ret: + set_schedule_time_batch(ret) + return ret + # Schedule decode batch if self.running_batch.is_empty(): ret = None @@ -1241,6 +1321,22 @@ class SchedulerDisaggregationDecodeMixin: set_schedule_time_batch(ret) return ret + def get_new_local_extend_batch(self: Scheduler) -> Optional[ScheduleBatch]: + if not self.server_args.disaggregation_decode_allow_local_prefill: + return None + if len(self.decode_direct_waiting_queue) == 0: + return None + + original_waiting_queue = self.waiting_queue + try: + self.waiting_queue = self.decode_direct_waiting_queue + new_batch = self.get_new_batch_prefill() + self.decode_direct_waiting_queue = self.waiting_queue + finally: + self.waiting_queue = original_waiting_queue + + return new_batch + def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: """Create a schedulebatch for fake completed prefill""" if self.grammar_manager.has_waiting_grammars(): diff --git a/third_party/sglang/python/sglang/srt/entrypoints/http_server.py b/third_party/sglang/python/sglang/srt/entrypoints/http_server.py index 6978e0c..07f5521 100644 --- a/third_party/sglang/python/sglang/srt/entrypoints/http_server.py +++ b/third_party/sglang/python/sglang/srt/entrypoints/http_server.py @@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import ( InitWeightsUpdateGroupReqInput, LoadLoRAAdapterFromTensorsReqInput, LoadLoRAAdapterReqInput, + DirectAppendAdmissionReqInput, OpenSessionReqInput, ParseFunctionCallReq, PauseGenerationReqInput, @@ -1289,6 +1290,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request): return _create_error_response(e) +@app.post("/session_cache/admit_direct_append") +async def admit_direct_append(obj: DirectAppendAdmissionReqInput): + return await _global_state.tokenizer_manager.admit_direct_append(obj) + + @app.api_route("/configure_logging", methods=["GET", "POST"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) async def configure_logging(obj: ConfigureLoggingReq, request: Request): diff --git a/third_party/sglang/python/sglang/srt/managers/io_struct.py b/third_party/sglang/python/sglang/srt/managers/io_struct.py index bd97965..c36cc4d 100644 --- a/third_party/sglang/python/sglang/srt/managers/io_struct.py +++ b/third_party/sglang/python/sglang/srt/managers/io_struct.py @@ -1597,6 +1597,30 @@ class OpenSessionReqOutput(BaseReq): success: bool +@dataclass +class DirectAppendAdmissionReqInput(BaseReq): + session_id: str + uncached_input_tokens: int + output_tokens: int + + +@dataclass +class DirectAppendAdmissionReqOutput(BaseReq): + can_admit: bool + resident: bool + reason: Optional[str] = None + required_tokens: int = 0 + available_tokens_before: int = 0 + available_tokens_after: int = 0 + evicted_session_count: int = 0 + freed_tokens: int = 0 + token_usage: float = 0.0 + num_running_reqs: int = 0 + decode_prealloc_queue_reqs: int = 0 + decode_transfer_queue_reqs: int = 0 + decode_retracted_queue_reqs: int = 0 + + @dataclass class HealthCheckOutput(BaseReq): pass diff --git a/third_party/sglang/python/sglang/srt/managers/schedule_batch.py b/third_party/sglang/python/sglang/srt/managers/schedule_batch.py index 0b26be6..f42bc0e 100644 --- a/third_party/sglang/python/sglang/srt/managers/schedule_batch.py +++ b/third_party/sglang/python/sglang/srt/managers/schedule_batch.py @@ -1566,6 +1566,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Init tensors reqs = self.reqs + for req in reqs: + if req.session is None or not req.session.streaming: + continue + actual_extend_len = max(0, len(req.fill_ids) - len(req.prefix_indices)) + if req.extend_input_len != actual_extend_len: + logger.warning( + "Correcting streaming-session extend_input_len from %d to %d " + "(rid=%s, session_id=%s, fill_len=%d, prefix_len=%d, kv_committed_len=%d)", + req.extend_input_len, + actual_extend_len, + req.rid, + req.session.session_id, + len(req.fill_ids), + len(req.prefix_indices), + req.kv_committed_len, + ) + req.set_extend_input_len(actual_extend_len) input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [len(r.fill_ids) for r in reqs] diff --git a/third_party/sglang/python/sglang/srt/managers/scheduler.py b/third_party/sglang/python/sglang/srt/managers/scheduler.py index 67af2d0..2eac6ea 100644 --- a/third_party/sglang/python/sglang/srt/managers/scheduler.py +++ b/third_party/sglang/python/sglang/srt/managers/scheduler.py @@ -94,6 +94,8 @@ from sglang.srt.managers.io_struct import ( ClearHiCacheReqOutput, CloseSessionReqInput, ContinueGenerationReqInput, + DirectAppendAdmissionReqInput, + DirectAppendAdmissionReqOutput, DestroyWeightsUpdateGroupReqInput, DetachHiCacheStorageReqInput, DetachHiCacheStorageReqOutput, @@ -844,6 +846,7 @@ class Scheduler( def init_running_status(self): self.waiting_queue: List[Req] = [] + self.decode_direct_waiting_queue: List[Req] = [] # The running decoding batch for continuous batching self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False) # The current forward batch @@ -1215,6 +1218,7 @@ class Scheduler( (AbortReq, self.abort_request), (OpenSessionReqInput, self.open_session), (CloseSessionReqInput, self.close_session), + (DirectAppendAdmissionReqInput, self.admit_direct_append), (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group), @@ -1589,6 +1593,7 @@ class Scheduler( def process_input_requests(self, recv_reqs: List): now = time.monotonic() self.session_controller.maybe_reap(now) + self.maybe_trim_decode_session_cache() for recv_req in recv_reqs: # Skip health check when server is busy — ongoing requests already carry health info. if is_health_check_generate_req(recv_req) and not self.is_fully_idle( @@ -1781,6 +1786,10 @@ class Scheduler( # Invalid request for disaggregated mode if ( recv_req.bootstrap_room is None + and not ( + self.disaggregation_mode == DisaggregationMode.DECODE + and self.server_args.disaggregation_decode_allow_local_prefill + ) and self.transfer_backend != TransferBackend.FAKE ): error_msg = ( @@ -1949,6 +1958,12 @@ class Scheduler( ) req.time_stats.set_prefill_bootstrap_queue_entry_time() elif self.disaggregation_mode == DisaggregationMode.DECODE: + if self._should_allow_local_prefill_on_decode(req): + if not self._set_or_validate_priority(req): + return + self.decode_direct_waiting_queue.append(req) + req.time_stats.set_wait_queue_entry_time() + return self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted) if not is_retracted: req.time_stats.set_decode_prealloc_queue_entry_time() @@ -1957,6 +1972,88 @@ class Scheduler( else: raise ValueError(f"Invalid {self.disaggregation_mode=}") + def _should_allow_local_prefill_on_decode(self, req: Req) -> bool: + return ( + self.disaggregation_mode == DisaggregationMode.DECODE + and self.server_args.disaggregation_decode_allow_local_prefill + and req.bootstrap_room is None + ) + + def _decode_session_cache_low_watermark_tokens(self) -> int: + return min( + self.max_total_num_tokens, + max( + self.server_args.num_reserved_decode_tokens * 16, + self.max_total_num_tokens // 12, + 16384, + ), + ) + + def _decode_session_cache_target_available_tokens(self) -> int: + return min( + self.max_total_num_tokens, + max( + self._decode_session_cache_low_watermark_tokens(), + self.server_args.num_reserved_decode_tokens * 24, + self.max_total_num_tokens // 8, + 24576, + ), + ) + + def maybe_trim_decode_session_cache( + self, + required_tokens: int = 0, + force: bool = False, + max_sessions: Optional[int] = None, + exclude_session_ids: Optional[set[str]] = None, + ) -> Dict[str, int]: + if ( + self.disaggregation_mode != DisaggregationMode.DECODE + or not isinstance(self.tree_cache, SessionAwareCache) + ): + return { + "evicted_session_count": 0, + "freed_tokens": 0, + "available_tokens_before": 0, + "available_tokens_after": 0, + } + + available_tokens = self.token_to_kv_pool_allocator.available_size() + low_watermark_tokens = self._decode_session_cache_low_watermark_tokens() + target_available_tokens = self._decode_session_cache_target_available_tokens() + min_available_tokens = max( + low_watermark_tokens, + available_tokens + max(0, required_tokens), + ) + target_available_tokens = max( + target_available_tokens, + min_available_tokens, + ) + if not force and available_tokens >= low_watermark_tokens: + return { + "evicted_session_count": 0, + "freed_tokens": 0, + "available_tokens_before": int(available_tokens), + "available_tokens_after": int(available_tokens), + } + + result = self.session_controller.evict_idle_streaming_sessions_lru( + required_tokens=required_tokens, + min_available_tokens=min_available_tokens, + target_available_tokens=target_available_tokens, + max_sessions=max_sessions, + exclude_session_ids=exclude_session_ids, + ) + if result["evicted_session_count"] > 0: + logger.info( + "Trimmed decode session cache via LRU. " + f"#evicted_sessions: {result['evicted_session_count']}, " + f"#freed_tokens: {result['freed_tokens']}, " + f"#available_tokens: {result['available_tokens_before']} -> " + f"{result['available_tokens_after']}" + ) + return result + def _set_or_validate_priority(self, req: Req) -> bool: """Set the default priority value, or abort the request based on the priority scheduling mode.""" if self.enable_priority_scheduling and req.priority is None: @@ -2558,8 +2655,19 @@ class Scheduler( if self.enable_hierarchical_cache: self.tree_cache.flush_write_through_acks() - # Check if decode out of memory - if (kv_full_retract_flag := not batch.check_decode_mem()) or ( + # Check if decode out of memory. Before retracting active decode work, + # let the worker evict idle streaming-session KV held outside the radix tree. + kv_full_retract_flag = not batch.check_decode_mem() + idle_session_eviction = None + if kv_full_retract_flag: + idle_session_eviction = self.maybe_trim_decode_session_cache( + required_tokens=max(1, self.server_args.num_reserved_decode_tokens), + force=True, + ) + if idle_session_eviction["freed_tokens"] > 0: + kv_full_retract_flag = not batch.check_decode_mem() + + if kv_full_retract_flag or ( TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0 ): old_available_tokens = self.token_to_kv_pool_allocator.available_size() @@ -2602,6 +2710,13 @@ class Scheduler( msg_details += ( f", #new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}" ) + if idle_session_eviction is not None: + msg_details += ( + ", #idle_session_evicted: " + f"{idle_session_eviction['evicted_session_count']}" + ", #idle_tokens_freed: " + f"{idle_session_eviction['freed_tokens']}" + ) logger.warning(msg_prefix + msg_details) for req in retracted_reqs: @@ -2916,6 +3031,7 @@ class Scheduler( # Waiting queues: waiting + bootstrapping + preallocation + kv transfer (decode) idle &= len(self.waiting_queue) == 0 + idle &= len(self.decode_direct_waiting_queue) == 0 if not for_health_check: # Grammar queue and prefill inflight queue may not produce batch @@ -3077,6 +3193,9 @@ class Scheduler( "graph": round(self.tp_worker.model_runner.graph_mem_usage, 2), } ret["effective_max_running_requests_per_dp"] = self.max_running_requests + ret["session_cache"] = ( + self.session_controller.get_streaming_session_cache_status() + ) if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0: ret["avg_spec_accept_length"] = ( @@ -3375,6 +3494,86 @@ class Scheduler( def close_session(self, recv_req: CloseSessionReqInput): self.session_controller.close(recv_req) + def admit_direct_append( + self, recv_req: DirectAppendAdmissionReqInput + ) -> DirectAppendAdmissionReqOutput: + if ( + self.disaggregation_mode != DisaggregationMode.DECODE + or not self.server_args.disaggregation_decode_allow_local_prefill + or not isinstance(self.tree_cache, SessionAwareCache) + ): + return DirectAppendAdmissionReqOutput( + can_admit=False, + resident=False, + reason="unsupported", + ) + + session_cache_status = self.session_controller.get_streaming_session_cache_status( + recv_req.session_id + ) + target_session = session_cache_status.get("target_session") + resident = bool( + isinstance(target_session, dict) and target_session.get("resident") + ) + if not resident: + return DirectAppendAdmissionReqOutput( + can_admit=False, + resident=False, + reason="session-not-resident", + available_tokens_before=int( + self.token_to_kv_pool_allocator.available_size() + ), + available_tokens_after=int( + self.token_to_kv_pool_allocator.available_size() + ), + token_usage=( + 1.0 + - self.token_to_kv_pool_allocator.available_size() + / max(1, self.max_total_num_tokens) + ), + num_running_reqs=len(self.running_batch.reqs), + decode_prealloc_queue_reqs=len(self.disagg_decode_prealloc_queue.queue), + decode_transfer_queue_reqs=len(self.disagg_decode_transfer_queue.queue), + decode_retracted_queue_reqs=len( + self.disagg_decode_prealloc_queue.retracted_queue + ), + ) + + required_tokens = max(0, recv_req.uncached_input_tokens) + max( + 0, recv_req.output_tokens + ) + available_tokens_before = int(self.token_to_kv_pool_allocator.available_size()) + trim_result = self.maybe_trim_decode_session_cache( + required_tokens=required_tokens, + force=available_tokens_before < required_tokens, + exclude_session_ids={recv_req.session_id}, + ) + available_tokens_after = int(self.token_to_kv_pool_allocator.available_size()) + decode_retracted_queue_reqs = len(self.disagg_decode_prealloc_queue.retracted_queue) + can_admit = ( + decode_retracted_queue_reqs == 0 + and available_tokens_after >= required_tokens + ) + reason = None if can_admit else "no-space" + + return DirectAppendAdmissionReqOutput( + can_admit=can_admit, + resident=True, + reason=reason, + required_tokens=int(required_tokens), + available_tokens_before=available_tokens_before, + available_tokens_after=available_tokens_after, + evicted_session_count=int(trim_result["evicted_session_count"]), + freed_tokens=int(trim_result["freed_tokens"]), + token_usage=( + 1.0 - available_tokens_after / max(1, self.max_total_num_tokens) + ), + num_running_reqs=len(self.running_batch.reqs), + decode_prealloc_queue_reqs=len(self.disagg_decode_prealloc_queue.queue), + decode_transfer_queue_reqs=len(self.disagg_decode_transfer_queue.queue), + decode_retracted_queue_reqs=decode_retracted_queue_reqs, + ) + def maybe_sleep_on_idle(self): if self.idle_sleeper is not None: self.idle_sleeper.maybe_sleep() diff --git a/third_party/sglang/python/sglang/srt/managers/session_controller.py b/third_party/sglang/python/sglang/srt/managers/session_controller.py index caf165c..e9ca529 100644 --- a/third_party/sglang/python/sglang/srt/managers/session_controller.py +++ b/third_party/sglang/python/sglang/srt/managers/session_controller.py @@ -229,6 +229,12 @@ class Session: priority=req.priority, routing_key=req.routing_key, http_worker_ipc=req.http_worker_ipc, + bootstrap_host=req.bootstrap_host, + bootstrap_port=req.bootstrap_port, + bootstrap_room=req.bootstrap_room, + routed_dp_rank=req.routed_dp_rank, + disagg_prefill_dp_rank=req.disagg_prefill_dp_rank, + extra_key=req.extra_key, time_stats=req.time_stats, ) if last_req is not None: @@ -284,7 +290,7 @@ class SessionController: else: self._close(session_id) - def _close(self, session_id: str): + def _close(self, session_id: str) -> int: session = self.sessions[session_id] if session.streaming and session.req_nodes: assert len(session.req_nodes) == 1 @@ -303,9 +309,166 @@ class SessionController: mm.release_features() node.req.multimodal_inputs = None + freed_tokens = 0 if isinstance(self.tree_cache, SessionAwareCache): - self.tree_cache.release_session(session_id) + freed_tokens = self.tree_cache.release_session(session_id) del self.sessions[session_id] + return freed_tokens + + def _is_idle_streaming_session(self, session: Session) -> bool: + if not session.streaming: + return False + if not session.req_nodes: + return False + return all(node.req.finished() for node in session.req_nodes.values()) + + def get_streaming_session_cache_status( + self, session_id: Optional[str] = None + ) -> Dict[str, object]: + if not isinstance(self.tree_cache, SessionAwareCache): + return { + "enabled": False, + "eviction_policy": None, + "session_count": 0, + "resident_session_count": 0, + "held_tokens": 0, + "available_tokens": 0, + "capacity_tokens": 0, + "idle_evictable_session_count": 0, + "idle_evictable_tokens": 0, + "sessions": [], + "target_session": None, + } + + all_statuses = self.tree_cache.list_session_statuses() + session_map = {status["session_id"]: status for status in all_statuses} + idle_evictable_tokens = 0 + idle_evictable_session_count = 0 + sessions = [] + for status in all_statuses: + session = self.sessions.get(status["session_id"]) + idle_evictable = session is not None and self._is_idle_streaming_session( + session + ) + if idle_evictable and status["resident"]: + idle_evictable_session_count += 1 + idle_evictable_tokens += int(status["resident_tokens"]) + if session_id is None or status["session_id"] == session_id: + sessions.append( + { + **status, + "streaming": bool(session.streaming) if session is not None else True, + "idle_evictable": idle_evictable, + "timed_out": session.is_timed_out() if session is not None else False, + } + ) + + target_session = session_map.get(session_id) if session_id is not None else None + return { + "enabled": True, + "eviction_policy": "lru", + "session_count": len(all_statuses), + "resident_session_count": sum(1 for status in all_statuses if status["resident"]), + "held_tokens": int(self.tree_cache.session_held_tokens()), + "available_tokens": int(self.tree_cache.token_to_kv_pool_allocator.available_size()), + "capacity_tokens": int(self.tree_cache.token_to_kv_pool_allocator.size), + "idle_evictable_session_count": idle_evictable_session_count, + "idle_evictable_tokens": int(idle_evictable_tokens), + "sessions": sessions, + "target_session": ( + { + **target_session, + "idle_evictable": ( + self._is_idle_streaming_session(self.sessions[target_session["session_id"]]) + if target_session is not None + and target_session["session_id"] in self.sessions + else False + ), + } + if target_session is not None + else None + ), + } + + def evict_idle_streaming_sessions( + self, + required_tokens: int = 0, + max_sessions: Optional[int] = None, + ) -> Dict[str, int]: + return self.evict_idle_streaming_sessions_lru( + required_tokens=required_tokens, + max_sessions=max_sessions, + ) + + def evict_idle_streaming_sessions_lru( + self, + required_tokens: int = 0, + min_available_tokens: int = 0, + target_available_tokens: Optional[int] = None, + max_sessions: Optional[int] = None, + exclude_session_ids: Optional[set[str]] = None, + ) -> Dict[str, int]: + if not isinstance(self.tree_cache, SessionAwareCache): + return { + "evicted_session_count": 0, + "freed_tokens": 0, + "available_tokens_before": 0, + "available_tokens_after": 0, + } + + available_tokens_before = int( + self.tree_cache.token_to_kv_pool_allocator.available_size() + ) + evicted_session_count = 0 + freed_tokens = 0 + target_available_tokens = max( + int(min_available_tokens), + int(target_available_tokens or 0), + ) + candidates = [] + for status in self.tree_cache.list_session_statuses(): + if exclude_session_ids and status["session_id"] in exclude_session_ids: + continue + session = self.sessions.get(status["session_id"]) + if session is None or not self._is_idle_streaming_session(session): + continue + if not status["resident"]: + continue + candidates.append(status) + + for status in candidates: + if max_sessions is not None and evicted_session_count >= max_sessions: + break + available_tokens_now = available_tokens_before + freed_tokens + if ( + required_tokens > 0 + and freed_tokens >= required_tokens + and available_tokens_now >= min_available_tokens + ): + break + if ( + required_tokens <= 0 + and min_available_tokens <= 0 + and target_available_tokens <= 0 + ): + break + if ( + required_tokens <= 0 + and available_tokens_now >= min_available_tokens + and available_tokens_now >= target_available_tokens + ): + break + + freed = self._close(status["session_id"]) + evicted_session_count += 1 + freed_tokens += int(freed) + + return { + "evicted_session_count": evicted_session_count, + "freed_tokens": freed_tokens, + "available_tokens_before": available_tokens_before, + "available_tokens_after": available_tokens_before + freed_tokens, + } def maybe_reap(self, now: float, interval: float = 1.0): # reap sessions every second diff --git a/third_party/sglang/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/third_party/sglang/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 544c609..a1a3c1b 100644 --- a/third_party/sglang/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/third_party/sglang/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import ( ClearHiCacheReqInput, ClearHiCacheReqOutput, CloseSessionReqInput, + DirectAppendAdmissionReqInput, + DirectAppendAdmissionReqOutput, DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqOutput, DetachHiCacheStorageReqInput, @@ -220,6 +222,9 @@ class TokenizerCommunicatorMixin: self.get_internal_state_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.direct_append_admission_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.set_internal_state_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -316,6 +321,10 @@ class TokenizerCommunicatorMixin: GetInternalStateReqOutput, self.get_internal_state_communicator.handle_recv, ), + ( + DirectAppendAdmissionReqOutput, + self.direct_append_admission_communicator.handle_recv, + ), ( SetInternalStateReqOutput, self.set_internal_state_communicator.handle_recv, @@ -871,6 +880,16 @@ class TokenizerCommunicatorMixin: # Many DP ranks return [res.internal_state for res in responses] + async def admit_direct_append( + self: TokenizerManager, + obj: DirectAppendAdmissionReqInput, + ) -> DirectAppendAdmissionReqOutput: + self.auto_create_handle_loop() + responses: List[DirectAppendAdmissionReqOutput] = ( + await self.direct_append_admission_communicator(obj) + ) + return responses[0] + async def set_internal_state( self: TokenizerManager, obj: SetInternalStateReq ) -> List[bool]: diff --git a/third_party/sglang/python/sglang/srt/mem_cache/session_aware_cache.py b/third_party/sglang/python/sglang/srt/mem_cache/session_aware_cache.py index 0298a3a..e06c867 100644 --- a/third_party/sglang/python/sglang/srt/mem_cache/session_aware_cache.py +++ b/third_party/sglang/python/sglang/srt/mem_cache/session_aware_cache.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Optional @@ -57,6 +58,7 @@ class SessionSlot: mamba_next_track_idx: Any = None mamba_last_track_seqlen: Any = None mamba_branching_seqlen: Any = None + last_access_time: float = field(default_factory=time.monotonic) @property def is_holding_kv(self) -> bool: @@ -80,6 +82,7 @@ class SessionSlot: self.mamba_next_track_idx = req.mamba_next_track_idx self.mamba_last_track_seqlen = req.mamba_last_track_seqlen self.mamba_branching_seqlen = req.mamba_branching_seqlen + self.last_access_time = time.monotonic() req.req_pool_idx = None req.mamba_pool_idx = None @@ -97,6 +100,7 @@ class SessionSlot: req.mamba_next_track_idx = self.mamba_next_track_idx req.mamba_last_track_seqlen = self.mamba_last_track_seqlen req.mamba_branching_seqlen = self.mamba_branching_seqlen + self.last_access_time = time.monotonic() # NOTE: req_pool_idx and mamba_pool_idx are intentionally NOT cleared # from the slot. During chunked prefill, a request may be rejected by @@ -247,7 +251,7 @@ class SessionAwareCache(BasePrefixCache): """Release all KV resources held by a streaming session.""" slot = self.slots.pop(session_id, None) if slot is None: - return + return 0 if slot.last_node is not None: if slot.swa_uuid_for_lock is not None: @@ -261,20 +265,58 @@ class SessionAwareCache(BasePrefixCache): if slot.is_holding_kv: start = slot.cache_protected_len end = slot.kv_allocated_len + freed_tokens = max(0, end - start) if start < end: kv_indices = self.req_to_token_pool.req_to_token[ slot.req_pool_idx, start:end ] self.token_to_kv_pool_allocator.free(kv_indices) self.req_to_token_pool.free_slots.append(slot.req_pool_idx) + return freed_tokens + return 0 + + def slot_held_tokens(self, slot: SessionSlot) -> int: + if not slot.is_holding_kv: + return 0 + allocated = ceil_align(slot.kv_allocated_len, self.page_size) + return max(0, allocated - slot.cache_protected_len) + + def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]: + slot = self.slots.get(session_id) + if slot is None: + return None + return { + "session_id": session_id, + "resident": slot.is_holding_kv, + "resident_tokens": int(self.slot_held_tokens(slot)), + "kv_committed_len": int(slot.kv_committed_len), + "kv_allocated_len": int(slot.kv_allocated_len), + "cache_protected_len": int(slot.cache_protected_len), + "last_access_time": float(slot.last_access_time), + } + + def list_session_statuses(self) -> list[Dict[str, Any]]: + statuses = [] + for session_id, slot in self.slots.items(): + statuses.append( + { + "session_id": session_id, + "resident": slot.is_holding_kv, + "resident_tokens": int(self.slot_held_tokens(slot)), + "kv_committed_len": int(slot.kv_committed_len), + "kv_allocated_len": int(slot.kv_allocated_len), + "cache_protected_len": int(slot.cache_protected_len), + "last_access_time": float(slot.last_access_time), + } + ) + statuses.sort(key=lambda item: item["last_access_time"]) + return statuses def session_held_tokens(self) -> int: """Total KV tokens held by session slots, not tracked by the tree.""" total = 0 for slot in self.slots.values(): - if slot.is_holding_kv: - allocated = ceil_align(slot.kv_allocated_len, self.page_size) - total += allocated - slot.cache_protected_len + total += self.slot_held_tokens(slot) return total def session_held_full_tokens(self) -> int: diff --git a/third_party/sglang/python/sglang/srt/server_args.py b/third_party/sglang/python/sglang/srt/server_args.py index d91ced8..26a993d 100644 --- a/third_party/sglang/python/sglang/srt/server_args.py +++ b/third_party/sglang/python/sglang/srt/server_args.py @@ -697,6 +697,7 @@ class ServerArgs: disaggregation_bootstrap_port: int = 8998 disaggregation_ib_device: Optional[str] = None disaggregation_decode_enable_offload_kvcache: bool = False + disaggregation_decode_allow_local_prefill: bool = False num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD # FIXME: hack to reduce ITL when decode bs is small disaggregation_decode_polling_interval: int = 1 @@ -5772,6 +5773,14 @@ class ServerArgs: action="store_true", help="Enable async KV cache offloading on decode server (PD mode).", ) + parser.add_argument( + "--disaggregation-decode-allow-local-prefill", + action="store_true", + help=( + "Allow decode workers in PD mode to accept direct local requests " + "without bootstrap metadata and run local append-prefill." + ), + ) parser.add_argument( "--num-reserved-decode-tokens", type=int,