feat(sglang): support decode session cache admission
This commit is contained in:
@@ -140,9 +140,9 @@ class DecodeReqToTokenPool:
|
|||||||
# Indices of reqs that already have a req_pool_idx and will reuse
|
# Indices of reqs that already have a req_pool_idx and will reuse
|
||||||
# their existing slot (e.g. chunked prefill continuing across chunks).
|
# their existing slot (e.g. chunked prefill continuing across chunks).
|
||||||
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
|
reusing = [i for i, r in enumerate(reqs) if r.req_pool_idx is not None]
|
||||||
assert (
|
# Keep decode-side allocation behavior aligned with the shared pool:
|
||||||
len(reusing) <= 1
|
# local append-prefill can legitimately batch multiple resident-session
|
||||||
), "only one chunked request may reuse req_pool_idx in a batch"
|
# requests that each reuse their existing req_pool_idx.
|
||||||
assert all(
|
assert all(
|
||||||
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing
|
reqs[i].is_chunked > 0 or reqs[i].kv_committed_len > 0 for i in reusing
|
||||||
), "reusing request must be chunked or have committed KV"
|
), "reusing request must be chunked or have committed KV"
|
||||||
@@ -657,6 +657,12 @@ class DecodePreallocQueue:
|
|||||||
allocatable_tokens = self._allocatable_tokens(
|
allocatable_tokens = self._allocatable_tokens(
|
||||||
retractable_tokens=retractable_tokens, count_retracted=True
|
retractable_tokens=retractable_tokens, count_retracted=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def refresh_allocatable_tokens() -> int:
|
||||||
|
return self._allocatable_tokens(
|
||||||
|
retractable_tokens=retractable_tokens, count_retracted=True
|
||||||
|
)
|
||||||
|
|
||||||
# First, remove all failed requests from the queue
|
# First, remove all failed requests from the queue
|
||||||
for i, decode_req in enumerate(self.queue):
|
for i, decode_req in enumerate(self.queue):
|
||||||
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
|
if rids_to_check is not None and decode_req.req.rid not in rids_to_check:
|
||||||
@@ -680,7 +686,14 @@ class DecodePreallocQueue:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if self.req_to_token_pool.available_size() <= 0:
|
if self.req_to_token_pool.available_size() <= 0:
|
||||||
break
|
self.scheduler.maybe_trim_decode_session_cache(
|
||||||
|
required_tokens=max(1, self.num_reserved_decode_tokens),
|
||||||
|
force=True,
|
||||||
|
max_sessions=1,
|
||||||
|
)
|
||||||
|
allocatable_tokens = refresh_allocatable_tokens()
|
||||||
|
if self.req_to_token_pool.available_size() <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
||||||
break
|
break
|
||||||
@@ -704,9 +717,46 @@ class DecodePreallocQueue:
|
|||||||
)
|
)
|
||||||
> allocatable_tokens
|
> allocatable_tokens
|
||||||
):
|
):
|
||||||
break
|
required_tokens = max(
|
||||||
|
0,
|
||||||
|
max(
|
||||||
|
required_tokens_for_request,
|
||||||
|
origin_input_len
|
||||||
|
+ min(
|
||||||
|
decode_req.req.sampling_params.max_new_tokens,
|
||||||
|
CLIP_MAX_NEW_TOKEN,
|
||||||
|
)
|
||||||
|
- retractable_tokens,
|
||||||
|
)
|
||||||
|
- allocatable_tokens,
|
||||||
|
)
|
||||||
|
if required_tokens > 0:
|
||||||
|
self.scheduler.maybe_trim_decode_session_cache(
|
||||||
|
required_tokens=required_tokens,
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
allocatable_tokens = refresh_allocatable_tokens()
|
||||||
|
if (
|
||||||
|
max(
|
||||||
|
required_tokens_for_request,
|
||||||
|
origin_input_len
|
||||||
|
+ min(
|
||||||
|
decode_req.req.sampling_params.max_new_tokens,
|
||||||
|
CLIP_MAX_NEW_TOKEN,
|
||||||
|
)
|
||||||
|
- retractable_tokens,
|
||||||
|
)
|
||||||
|
> allocatable_tokens
|
||||||
|
):
|
||||||
|
break
|
||||||
if required_tokens_for_request > allocatable_tokens:
|
if required_tokens_for_request > allocatable_tokens:
|
||||||
break
|
self.scheduler.maybe_trim_decode_session_cache(
|
||||||
|
required_tokens=required_tokens_for_request - allocatable_tokens,
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
allocatable_tokens = refresh_allocatable_tokens()
|
||||||
|
if required_tokens_for_request > allocatable_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
allocatable_tokens -= required_tokens_for_request
|
allocatable_tokens -= required_tokens_for_request
|
||||||
dst_kv_indices = self._pre_alloc(decode_req.req)
|
dst_kv_indices = self._pre_alloc(decode_req.req)
|
||||||
@@ -1135,6 +1185,27 @@ class DecodeTransferQueue:
|
|||||||
|
|
||||||
class SchedulerDisaggregationDecodeMixin:
|
class SchedulerDisaggregationDecodeMixin:
|
||||||
|
|
||||||
|
def _merge_last_local_prefill_batch(self: Scheduler):
    """Fold the previous extend-mode batch into the running batch.

    Does nothing when there is no previous batch or when the previous batch
    was not run in extend mode. Otherwise the previous batch is filtered
    (with its pending chunked request, if any, passed as the request to
    exclude) and whatever remains is either adopted as the running batch or
    merged into it.
    """
    prev_batch = self.last_batch
    if prev_batch is None:
        return
    if not prev_batch.forward_mode.is_extend():
        return

    # At most one chunked request is tracked on the batch; hand it to
    # filter_batch as the exclusion list.
    pending_chunked = prev_batch.chunked_req
    excluded = [] if pending_chunked is None else [pending_chunked]

    size_before = prev_batch.batch_size()
    prev_batch.filter_batch(chunked_req_to_exclude=excluded)
    if prev_batch.batch_size() < size_before:
        # Filtering removed at least one request, so the running batch is
        # no longer flagged as full.
        self.running_batch.batch_is_full = False

    if prev_batch.is_empty():
        return
    if self.running_batch.is_empty():
        self.running_batch = prev_batch
    else:
        self.running_batch.merge_batch(prev_batch)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def event_loop_normal_disagg_decode(self: Scheduler):
|
def event_loop_normal_disagg_decode(self: Scheduler):
|
||||||
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
||||||
@@ -1213,6 +1284,8 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
) -> Optional[ScheduleBatch]:
|
) -> Optional[ScheduleBatch]:
|
||||||
"""Process prebuilt batch and schedule the next decode batch."""
|
"""Process prebuilt batch and schedule the next decode batch."""
|
||||||
|
self._merge_last_local_prefill_batch()
|
||||||
|
|
||||||
# Process pending prebuilt batch: output processing + filter + merge
|
# Process pending prebuilt batch: output processing + filter + merge
|
||||||
new_prebuilt_batch = self.get_new_prebuilt_batch()
|
new_prebuilt_batch = self.get_new_prebuilt_batch()
|
||||||
if new_prebuilt_batch:
|
if new_prebuilt_batch:
|
||||||
@@ -1229,6 +1302,13 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
else:
|
else:
|
||||||
self.running_batch.merge_batch(new_prebuilt_batch)
|
self.running_batch.merge_batch(new_prebuilt_batch)
|
||||||
|
|
||||||
|
new_local_batch = self.get_new_local_extend_batch()
|
||||||
|
if new_local_batch is not None:
|
||||||
|
ret = self.maybe_prepare_mlp_sync_batch(new_local_batch)
|
||||||
|
if ret:
|
||||||
|
set_schedule_time_batch(ret)
|
||||||
|
return ret
|
||||||
|
|
||||||
# Schedule decode batch
|
# Schedule decode batch
|
||||||
if self.running_batch.is_empty():
|
if self.running_batch.is_empty():
|
||||||
ret = None
|
ret = None
|
||||||
@@ -1241,6 +1321,22 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
set_schedule_time_batch(ret)
|
set_schedule_time_batch(ret)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def get_new_local_extend_batch(self: Scheduler) -> Optional[ScheduleBatch]:
    """Build a prefill batch from the decode-local waiting queue.

    Returns None when local prefill on the decode worker is disabled or when
    ``decode_direct_waiting_queue`` is empty. Otherwise the regular
    ``waiting_queue`` is temporarily swapped for the decode-local queue so
    that ``get_new_batch_prefill`` draws from it; the regular queue is always
    restored afterwards, even if batch construction raises.
    """
    local_prefill_off = (
        not self.server_args.disaggregation_decode_allow_local_prefill
    )
    if local_prefill_off or len(self.decode_direct_waiting_queue) == 0:
        return None

    saved_queue = self.waiting_queue
    try:
        self.waiting_queue = self.decode_direct_waiting_queue
        batch = self.get_new_batch_prefill()
        # get_new_batch_prefill may rebind self.waiting_queue; keep the
        # decode-local queue pointing at whatever it left behind.
        self.decode_direct_waiting_queue = self.waiting_queue
    finally:
        self.waiting_queue = saved_queue

    return batch
|
||||||
|
|
||||||
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
||||||
"""Create a schedulebatch for fake completed prefill"""
|
"""Create a schedulebatch for fake completed prefill"""
|
||||||
if self.grammar_manager.has_waiting_grammars():
|
if self.grammar_manager.has_waiting_grammars():
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
LoadLoRAAdapterFromTensorsReqInput,
|
LoadLoRAAdapterFromTensorsReqInput,
|
||||||
LoadLoRAAdapterReqInput,
|
LoadLoRAAdapterReqInput,
|
||||||
|
DirectAppendAdmissionReqInput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
ParseFunctionCallReq,
|
ParseFunctionCallReq,
|
||||||
PauseGenerationReqInput,
|
PauseGenerationReqInput,
|
||||||
@@ -1289,6 +1290,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
|||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/session_cache/admit_direct_append")
async def admit_direct_append(obj: DirectAppendAdmissionReqInput):
    """Forward a direct-append admission query to the tokenizer manager.

    The request body is passed through unchanged and the tokenizer manager's
    result is returned as the response.
    """
    tokenizer_manager = _global_state.tokenizer_manager
    admission = await tokenizer_manager.admit_direct_append(obj)
    return admission
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||||
@auth_level(AuthLevel.ADMIN_OPTIONAL)
|
@auth_level(AuthLevel.ADMIN_OPTIONAL)
|
||||||
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||||
|
|||||||
@@ -1597,6 +1597,30 @@ class OpenSessionReqOutput(BaseReq):
|
|||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DirectAppendAdmissionReqInput(BaseReq):
    """Query asking whether a session's append request can be admitted."""

    # Identifier of the session whose residency is being checked.
    session_id: str
    # Number of input tokens not yet cached for this request; negative
    # values are treated as zero by the scheduler's admission check.
    uncached_input_tokens: int
    # Number of tokens expected to be generated; negative values are
    # likewise treated as zero by the scheduler's admission check.
    output_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DirectAppendAdmissionReqOutput(BaseReq):
    """Result of a direct-append admission query."""

    # Whether the request may be admitted.
    can_admit: bool
    # Whether the target session was reported as resident.
    resident: bool
    # Short machine-readable explanation when admission is refused; None
    # when the request is admitted.
    reason: Optional[str] = None
    # Token requirement that was evaluated for the request.
    required_tokens: int = 0
    # Free KV tokens observed before and after any session-cache trimming.
    available_tokens_before: int = 0
    available_tokens_after: int = 0
    # Outcome of session-cache trimming performed for this query.
    evicted_session_count: int = 0
    freed_tokens: int = 0
    # Fraction of the token pool in use (1.0 - available / capacity).
    token_usage: float = 0.0
    # Load snapshot: running requests and decode-side queue lengths.
    num_running_reqs: int = 0
    decode_prealloc_queue_reqs: int = 0
    decode_transfer_queue_reqs: int = 0
    decode_retracted_queue_reqs: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HealthCheckOutput(BaseReq):
|
class HealthCheckOutput(BaseReq):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1566,6 +1566,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
# Init tensors
|
# Init tensors
|
||||||
reqs = self.reqs
|
reqs = self.reqs
|
||||||
|
for req in reqs:
|
||||||
|
if req.session is None or not req.session.streaming:
|
||||||
|
continue
|
||||||
|
actual_extend_len = max(0, len(req.fill_ids) - len(req.prefix_indices))
|
||||||
|
if req.extend_input_len != actual_extend_len:
|
||||||
|
logger.warning(
|
||||||
|
"Correcting streaming-session extend_input_len from %d to %d "
|
||||||
|
"(rid=%s, session_id=%s, fill_len=%d, prefix_len=%d, kv_committed_len=%d)",
|
||||||
|
req.extend_input_len,
|
||||||
|
actual_extend_len,
|
||||||
|
req.rid,
|
||||||
|
req.session.session_id,
|
||||||
|
len(req.fill_ids),
|
||||||
|
len(req.prefix_indices),
|
||||||
|
req.kv_committed_len,
|
||||||
|
)
|
||||||
|
req.set_extend_input_len(actual_extend_len)
|
||||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||||
seq_lens = [len(r.fill_ids) for r in reqs]
|
seq_lens = [len(r.fill_ids) for r in reqs]
|
||||||
|
|||||||
@@ -94,6 +94,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ClearHiCacheReqOutput,
|
ClearHiCacheReqOutput,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ContinueGenerationReqInput,
|
ContinueGenerationReqInput,
|
||||||
|
DirectAppendAdmissionReqInput,
|
||||||
|
DirectAppendAdmissionReqOutput,
|
||||||
DestroyWeightsUpdateGroupReqInput,
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
DetachHiCacheStorageReqInput,
|
DetachHiCacheStorageReqInput,
|
||||||
DetachHiCacheStorageReqOutput,
|
DetachHiCacheStorageReqOutput,
|
||||||
@@ -844,6 +846,7 @@ class Scheduler(
|
|||||||
|
|
||||||
def init_running_status(self):
|
def init_running_status(self):
|
||||||
self.waiting_queue: List[Req] = []
|
self.waiting_queue: List[Req] = []
|
||||||
|
self.decode_direct_waiting_queue: List[Req] = []
|
||||||
# The running decoding batch for continuous batching
|
# The running decoding batch for continuous batching
|
||||||
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
|
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
|
||||||
# The current forward batch
|
# The current forward batch
|
||||||
@@ -1215,6 +1218,7 @@ class Scheduler(
|
|||||||
(AbortReq, self.abort_request),
|
(AbortReq, self.abort_request),
|
||||||
(OpenSessionReqInput, self.open_session),
|
(OpenSessionReqInput, self.open_session),
|
||||||
(CloseSessionReqInput, self.close_session),
|
(CloseSessionReqInput, self.close_session),
|
||||||
|
(DirectAppendAdmissionReqInput, self.admit_direct_append),
|
||||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||||
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
|
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
|
||||||
@@ -1589,6 +1593,7 @@ class Scheduler(
|
|||||||
def process_input_requests(self, recv_reqs: List):
|
def process_input_requests(self, recv_reqs: List):
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
self.session_controller.maybe_reap(now)
|
self.session_controller.maybe_reap(now)
|
||||||
|
self.maybe_trim_decode_session_cache()
|
||||||
for recv_req in recv_reqs:
|
for recv_req in recv_reqs:
|
||||||
# Skip health check when server is busy — ongoing requests already carry health info.
|
# Skip health check when server is busy — ongoing requests already carry health info.
|
||||||
if is_health_check_generate_req(recv_req) and not self.is_fully_idle(
|
if is_health_check_generate_req(recv_req) and not self.is_fully_idle(
|
||||||
@@ -1781,6 +1786,10 @@ class Scheduler(
|
|||||||
# Invalid request for disaggregated mode
|
# Invalid request for disaggregated mode
|
||||||
if (
|
if (
|
||||||
recv_req.bootstrap_room is None
|
recv_req.bootstrap_room is None
|
||||||
|
and not (
|
||||||
|
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||||
|
and self.server_args.disaggregation_decode_allow_local_prefill
|
||||||
|
)
|
||||||
and self.transfer_backend != TransferBackend.FAKE
|
and self.transfer_backend != TransferBackend.FAKE
|
||||||
):
|
):
|
||||||
error_msg = (
|
error_msg = (
|
||||||
@@ -1949,6 +1958,12 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
req.time_stats.set_prefill_bootstrap_queue_entry_time()
|
req.time_stats.set_prefill_bootstrap_queue_entry_time()
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
if self._should_allow_local_prefill_on_decode(req):
|
||||||
|
if not self._set_or_validate_priority(req):
|
||||||
|
return
|
||||||
|
self.decode_direct_waiting_queue.append(req)
|
||||||
|
req.time_stats.set_wait_queue_entry_time()
|
||||||
|
return
|
||||||
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
|
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
|
||||||
if not is_retracted:
|
if not is_retracted:
|
||||||
req.time_stats.set_decode_prealloc_queue_entry_time()
|
req.time_stats.set_decode_prealloc_queue_entry_time()
|
||||||
@@ -1957,6 +1972,88 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid {self.disaggregation_mode=}")
|
raise ValueError(f"Invalid {self.disaggregation_mode=}")
|
||||||
|
|
||||||
|
def _should_allow_local_prefill_on_decode(self, req: Req) -> bool:
    """Tell whether ``req`` may be prefilled locally on this decode worker.

    All three conditions must hold: the scheduler runs in DECODE
    disaggregation mode, the ``disaggregation_decode_allow_local_prefill``
    server argument is set, and the request carries no bootstrap room.
    """
    in_decode_mode = self.disaggregation_mode == DisaggregationMode.DECODE
    local_prefill_enabled = (
        self.server_args.disaggregation_decode_allow_local_prefill
    )
    has_no_bootstrap_room = req.bootstrap_room is None
    return in_decode_mode and local_prefill_enabled and has_no_bootstrap_room
|
||||||
|
|
||||||
|
def _decode_session_cache_low_watermark_tokens(self) -> int:
|
||||||
|
return min(
|
||||||
|
self.max_total_num_tokens,
|
||||||
|
max(
|
||||||
|
self.server_args.num_reserved_decode_tokens * 16,
|
||||||
|
self.max_total_num_tokens // 12,
|
||||||
|
16384,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode_session_cache_target_available_tokens(self) -> int:
|
||||||
|
return min(
|
||||||
|
self.max_total_num_tokens,
|
||||||
|
max(
|
||||||
|
self._decode_session_cache_low_watermark_tokens(),
|
||||||
|
self.server_args.num_reserved_decode_tokens * 24,
|
||||||
|
self.max_total_num_tokens // 8,
|
||||||
|
24576,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def maybe_trim_decode_session_cache(
    self,
    required_tokens: int = 0,
    force: bool = False,
    max_sessions: Optional[int] = None,
    exclude_session_ids: Optional[set[str]] = None,
) -> Dict[str, int]:
    """Evict idle streaming sessions when free KV tokens run low.

    Args:
        required_tokens: Extra tokens wanted on top of what is currently
            free; non-positive values add nothing to the availability goal.
        force: When False, nothing is evicted while the free-token count is
            at or above the low watermark. When True, eviction is attempted
            regardless of the watermark.
        max_sessions: Forwarded to the session controller's LRU eviction.
        exclude_session_ids: Forwarded to the session controller's LRU
            eviction.

    Returns:
        A dict with ``evicted_session_count``, ``freed_tokens``,
        ``available_tokens_before`` and ``available_tokens_after``. All four
        are zero when the scheduler is not in DECODE mode or the tree cache
        is not a ``SessionAwareCache``.
    """
    decode_mode = self.disaggregation_mode == DisaggregationMode.DECODE
    if not (decode_mode and isinstance(self.tree_cache, SessionAwareCache)):
        return {
            "evicted_session_count": 0,
            "freed_tokens": 0,
            "available_tokens_before": 0,
            "available_tokens_after": 0,
        }

    free_tokens = self.token_to_kv_pool_allocator.available_size()
    low_watermark = self._decode_session_cache_low_watermark_tokens()

    if not force and free_tokens >= low_watermark:
        # Enough headroom and nobody insisted: report the current state.
        return {
            "evicted_session_count": 0,
            "freed_tokens": 0,
            "available_tokens_before": int(free_tokens),
            "available_tokens_after": int(free_tokens),
        }

    # The floor is at least the watermark, and at least "what is free now
    # plus what the caller asked for"; the target is never below the floor.
    floor_tokens = max(low_watermark, free_tokens + max(0, required_tokens))
    goal_tokens = max(
        self._decode_session_cache_target_available_tokens(),
        floor_tokens,
    )

    outcome = self.session_controller.evict_idle_streaming_sessions_lru(
        required_tokens=required_tokens,
        min_available_tokens=floor_tokens,
        target_available_tokens=goal_tokens,
        max_sessions=max_sessions,
        exclude_session_ids=exclude_session_ids,
    )
    if outcome["evicted_session_count"] > 0:
        summary = (
            "Trimmed decode session cache via LRU. "
            f"#evicted_sessions: {outcome['evicted_session_count']}, "
            f"#freed_tokens: {outcome['freed_tokens']}, "
            f"#available_tokens: {outcome['available_tokens_before']} -> "
            f"{outcome['available_tokens_after']}"
        )
        logger.info(summary)
    return outcome
|
||||||
|
|
||||||
def _set_or_validate_priority(self, req: Req) -> bool:
|
def _set_or_validate_priority(self, req: Req) -> bool:
|
||||||
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
|
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
|
||||||
if self.enable_priority_scheduling and req.priority is None:
|
if self.enable_priority_scheduling and req.priority is None:
|
||||||
@@ -2558,8 +2655,19 @@ class Scheduler(
|
|||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
self.tree_cache.flush_write_through_acks()
|
self.tree_cache.flush_write_through_acks()
|
||||||
|
|
||||||
# Check if decode out of memory
|
# Check if decode out of memory. Before retracting active decode work,
|
||||||
if (kv_full_retract_flag := not batch.check_decode_mem()) or (
|
# let the worker evict idle streaming-session KV held outside the radix tree.
|
||||||
|
kv_full_retract_flag = not batch.check_decode_mem()
|
||||||
|
idle_session_eviction = None
|
||||||
|
if kv_full_retract_flag:
|
||||||
|
idle_session_eviction = self.maybe_trim_decode_session_cache(
|
||||||
|
required_tokens=max(1, self.server_args.num_reserved_decode_tokens),
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
if idle_session_eviction["freed_tokens"] > 0:
|
||||||
|
kv_full_retract_flag = not batch.check_decode_mem()
|
||||||
|
|
||||||
|
if kv_full_retract_flag or (
|
||||||
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
|
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
|
||||||
):
|
):
|
||||||
old_available_tokens = self.token_to_kv_pool_allocator.available_size()
|
old_available_tokens = self.token_to_kv_pool_allocator.available_size()
|
||||||
@@ -2602,6 +2710,13 @@ class Scheduler(
|
|||||||
msg_details += (
|
msg_details += (
|
||||||
f", #new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
|
f", #new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
|
||||||
)
|
)
|
||||||
|
if idle_session_eviction is not None:
|
||||||
|
msg_details += (
|
||||||
|
", #idle_session_evicted: "
|
||||||
|
f"{idle_session_eviction['evicted_session_count']}"
|
||||||
|
", #idle_tokens_freed: "
|
||||||
|
f"{idle_session_eviction['freed_tokens']}"
|
||||||
|
)
|
||||||
logger.warning(msg_prefix + msg_details)
|
logger.warning(msg_prefix + msg_details)
|
||||||
|
|
||||||
for req in retracted_reqs:
|
for req in retracted_reqs:
|
||||||
@@ -2916,6 +3031,7 @@ class Scheduler(
|
|||||||
|
|
||||||
# Waiting queues: waiting + bootstrapping + preallocation + kv transfer (decode)
|
# Waiting queues: waiting + bootstrapping + preallocation + kv transfer (decode)
|
||||||
idle &= len(self.waiting_queue) == 0
|
idle &= len(self.waiting_queue) == 0
|
||||||
|
idle &= len(self.decode_direct_waiting_queue) == 0
|
||||||
|
|
||||||
if not for_health_check:
|
if not for_health_check:
|
||||||
# Grammar queue and prefill inflight queue may not produce batch
|
# Grammar queue and prefill inflight queue may not produce batch
|
||||||
@@ -3077,6 +3193,9 @@ class Scheduler(
|
|||||||
"graph": round(self.tp_worker.model_runner.graph_mem_usage, 2),
|
"graph": round(self.tp_worker.model_runner.graph_mem_usage, 2),
|
||||||
}
|
}
|
||||||
ret["effective_max_running_requests_per_dp"] = self.max_running_requests
|
ret["effective_max_running_requests_per_dp"] = self.max_running_requests
|
||||||
|
ret["session_cache"] = (
|
||||||
|
self.session_controller.get_streaming_session_cache_status()
|
||||||
|
)
|
||||||
|
|
||||||
if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
|
if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
|
||||||
ret["avg_spec_accept_length"] = (
|
ret["avg_spec_accept_length"] = (
|
||||||
@@ -3375,6 +3494,86 @@ class Scheduler(
|
|||||||
def close_session(self, recv_req: CloseSessionReqInput):
|
def close_session(self, recv_req: CloseSessionReqInput):
|
||||||
self.session_controller.close(recv_req)
|
self.session_controller.close(recv_req)
|
||||||
|
|
||||||
|
def admit_direct_append(
    self, recv_req: DirectAppendAdmissionReqInput
) -> DirectAppendAdmissionReqOutput:
    """Decide whether an append request for a resident session can be admitted.

    The reply is ``reason="unsupported"`` unless the scheduler is in DECODE
    mode, local prefill on decode is enabled, and the tree cache is a
    ``SessionAwareCache``. It is ``reason="session-not-resident"`` when the
    target session is not reported as resident. Otherwise the session cache
    may be trimmed (never evicting the target session itself) and the request
    is admitted when the retracted queue is empty and enough tokens are free;
    a refusal at this stage carries ``reason="no-space"``.
    """
    feature_enabled = (
        self.disaggregation_mode == DisaggregationMode.DECODE
        and self.server_args.disaggregation_decode_allow_local_prefill
        and isinstance(self.tree_cache, SessionAwareCache)
    )
    if not feature_enabled:
        return DirectAppendAdmissionReqOutput(
            can_admit=False,
            resident=False,
            reason="unsupported",
        )

    kv_allocator = self.token_to_kv_pool_allocator
    pool_capacity = max(1, self.max_total_num_tokens)

    def queue_depths() -> dict:
        # Load snapshot shared by both replies below.
        prealloc = self.disagg_decode_prealloc_queue
        return {
            "num_running_reqs": len(self.running_batch.reqs),
            "decode_prealloc_queue_reqs": len(prealloc.queue),
            "decode_transfer_queue_reqs": len(
                self.disagg_decode_transfer_queue.queue
            ),
            "decode_retracted_queue_reqs": len(prealloc.retracted_queue),
        }

    cache_status = self.session_controller.get_streaming_session_cache_status(
        recv_req.session_id
    )
    target = cache_status.get("target_session")
    if not (isinstance(target, dict) and target.get("resident")):
        free_now = kv_allocator.available_size()
        return DirectAppendAdmissionReqOutput(
            can_admit=False,
            resident=False,
            reason="session-not-resident",
            available_tokens_before=int(free_now),
            available_tokens_after=int(free_now),
            token_usage=1.0 - free_now / pool_capacity,
            **queue_depths(),
        )

    needed = max(0, recv_req.uncached_input_tokens) + max(
        0, recv_req.output_tokens
    )
    free_before = int(kv_allocator.available_size())
    # NOTE(review): the full requirement (not the shortfall versus what is
    # already free) is passed as required_tokens here — confirm this extra
    # headroom is intended.
    trim = self.maybe_trim_decode_session_cache(
        required_tokens=needed,
        force=free_before < needed,
        exclude_session_ids={recv_req.session_id},
    )
    free_after = int(kv_allocator.available_size())

    depths = queue_depths()
    admitted = (
        depths["decode_retracted_queue_reqs"] == 0 and free_after >= needed
    )

    return DirectAppendAdmissionReqOutput(
        can_admit=admitted,
        resident=True,
        reason=None if admitted else "no-space",
        required_tokens=int(needed),
        available_tokens_before=free_before,
        available_tokens_after=free_after,
        evicted_session_count=int(trim["evicted_session_count"]),
        freed_tokens=int(trim["freed_tokens"]),
        token_usage=1.0 - free_after / pool_capacity,
        **depths,
    )
|
||||||
|
|
||||||
def maybe_sleep_on_idle(self):
|
def maybe_sleep_on_idle(self):
|
||||||
if self.idle_sleeper is not None:
|
if self.idle_sleeper is not None:
|
||||||
self.idle_sleeper.maybe_sleep()
|
self.idle_sleeper.maybe_sleep()
|
||||||
|
|||||||
@@ -229,6 +229,12 @@ class Session:
|
|||||||
priority=req.priority,
|
priority=req.priority,
|
||||||
routing_key=req.routing_key,
|
routing_key=req.routing_key,
|
||||||
http_worker_ipc=req.http_worker_ipc,
|
http_worker_ipc=req.http_worker_ipc,
|
||||||
|
bootstrap_host=req.bootstrap_host,
|
||||||
|
bootstrap_port=req.bootstrap_port,
|
||||||
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
routed_dp_rank=req.routed_dp_rank,
|
||||||
|
disagg_prefill_dp_rank=req.disagg_prefill_dp_rank,
|
||||||
|
extra_key=req.extra_key,
|
||||||
time_stats=req.time_stats,
|
time_stats=req.time_stats,
|
||||||
)
|
)
|
||||||
if last_req is not None:
|
if last_req is not None:
|
||||||
@@ -284,7 +290,7 @@ class SessionController:
|
|||||||
else:
|
else:
|
||||||
self._close(session_id)
|
self._close(session_id)
|
||||||
|
|
||||||
def _close(self, session_id: str):
|
def _close(self, session_id: str) -> int:
|
||||||
session = self.sessions[session_id]
|
session = self.sessions[session_id]
|
||||||
if session.streaming and session.req_nodes:
|
if session.streaming and session.req_nodes:
|
||||||
assert len(session.req_nodes) == 1
|
assert len(session.req_nodes) == 1
|
||||||
@@ -303,9 +309,166 @@ class SessionController:
|
|||||||
mm.release_features()
|
mm.release_features()
|
||||||
node.req.multimodal_inputs = None
|
node.req.multimodal_inputs = None
|
||||||
|
|
||||||
|
freed_tokens = 0
|
||||||
if isinstance(self.tree_cache, SessionAwareCache):
|
if isinstance(self.tree_cache, SessionAwareCache):
|
||||||
self.tree_cache.release_session(session_id)
|
freed_tokens = self.tree_cache.release_session(session_id)
|
||||||
del self.sessions[session_id]
|
del self.sessions[session_id]
|
||||||
|
return freed_tokens
|
||||||
|
|
||||||
|
def _is_idle_streaming_session(self, session: Session) -> bool:
|
||||||
|
if not session.streaming:
|
||||||
|
return False
|
||||||
|
if not session.req_nodes:
|
||||||
|
return False
|
||||||
|
return all(node.req.finished() for node in session.req_nodes.values())
|
||||||
|
|
||||||
|
def get_streaming_session_cache_status(
    self, session_id: Optional[str] = None
) -> Dict[str, object]:
    """Summarise the streaming-session cache, optionally for one session.

    Args:
        session_id: When given, the ``sessions`` list only contains entries
            for this session and ``target_session`` describes it (or is None
            if the cache does not list it). When omitted, every session is
            listed and ``target_session`` is None.

    Returns:
        A dict with aggregate counters, the per-session entries, and the
        target session. When the tree cache is not a ``SessionAwareCache``
        the dict reports ``enabled=False`` with zeroed counters.
    """
    cache = self.tree_cache
    if not isinstance(cache, SessionAwareCache):
        return {
            "enabled": False,
            "eviction_policy": None,
            "session_count": 0,
            "resident_session_count": 0,
            "held_tokens": 0,
            "available_tokens": 0,
            "capacity_tokens": 0,
            "idle_evictable_session_count": 0,
            "idle_evictable_tokens": 0,
            "sessions": [],
            "target_session": None,
        }

    statuses = cache.list_session_statuses()
    by_id = {entry["session_id"]: entry for entry in statuses}

    evictable_count = 0
    evictable_tokens = 0
    listed = []
    for entry in statuses:
        sid = entry["session_id"]
        owner = self.sessions.get(sid)
        known = owner is not None
        is_evictable = known and self._is_idle_streaming_session(owner)
        if is_evictable and entry["resident"]:
            evictable_count += 1
            evictable_tokens += int(entry["resident_tokens"])
        if session_id is not None and sid != session_id:
            continue
        listed.append(
            {
                **entry,
                # Sessions the controller does not know are reported as
                # streaming and not timed out.
                "streaming": bool(owner.streaming) if known else True,
                "idle_evictable": is_evictable,
                "timed_out": owner.is_timed_out() if known else False,
            }
        )

    target = None
    if session_id is not None:
        match = by_id.get(session_id)
        if match is not None:
            target_id = match["session_id"]
            if target_id in self.sessions:
                target_idle = self._is_idle_streaming_session(
                    self.sessions[target_id]
                )
            else:
                target_idle = False
            target = {**match, "idle_evictable": target_idle}

    resident_count = sum(1 for entry in statuses if entry["resident"])
    allocator = cache.token_to_kv_pool_allocator
    return {
        "enabled": True,
        "eviction_policy": "lru",
        "session_count": len(statuses),
        "resident_session_count": resident_count,
        "held_tokens": int(cache.session_held_tokens()),
        "available_tokens": int(allocator.available_size()),
        "capacity_tokens": int(allocator.size),
        "idle_evictable_session_count": evictable_count,
        "idle_evictable_tokens": int(evictable_tokens),
        "sessions": listed,
        "target_session": target,
    }
|
||||||
|
|
||||||
|
def evict_idle_streaming_sessions(
    self,
    required_tokens: int = 0,
    max_sessions: Optional[int] = None,
) -> Dict[str, int]:
    """Evict idle streaming sessions via the LRU eviction routine.

    This is a thin wrapper: both arguments are forwarded by keyword to
    ``evict_idle_streaming_sessions_lru`` and its result is returned
    unchanged.

    Args:
        required_tokens: Forwarded as ``required_tokens``.
        max_sessions: Forwarded as ``max_sessions``.

    Returns:
        Whatever ``evict_idle_streaming_sessions_lru`` returns.
    """
    forwarded = {
        "required_tokens": required_tokens,
        "max_sessions": max_sessions,
    }
    return self.evict_idle_streaming_sessions_lru(**forwarded)
|
||||||
|
|
||||||
|
def evict_idle_streaming_sessions_lru(
    self,
    required_tokens: int = 0,
    min_available_tokens: int = 0,
    target_available_tokens: Optional[int] = None,
    max_sessions: Optional[int] = None,
    exclude_session_ids: Optional[set[str]] = None,
) -> Dict[str, int]:
    """Close idle, resident streaming sessions until a token goal is met.

    Candidates are visited in the order produced by
    ``self.tree_cache.list_session_statuses()``. A session is a candidate
    when it is not excluded, is known to ``self.sessions``, passes
    ``self._is_idle_streaming_session`` and its status is ``resident``.

    Args:
        required_tokens: When positive, keep evicting until at least this
            many tokens were freed and the estimated available count has
            reached ``min_available_tokens``.
        min_available_tokens: Lower bound on the estimated available tokens.
        target_available_tokens: Optional goal; the effective goal is the
            larger of this and ``min_available_tokens``. It is only
            consulted when ``required_tokens`` is not positive.
        max_sessions: Optional cap on the number of sessions closed.
        exclude_session_ids: Session ids that must never be evicted.

    Returns:
        A dict with ``evicted_session_count``, ``freed_tokens``,
        ``available_tokens_before`` and ``available_tokens_after``. The
        "after" figure is ``before + freed`` rather than a fresh allocator
        reading. All four are 0 when the tree cache is not a
        ``SessionAwareCache``.
    """
    cache = self.tree_cache
    if not isinstance(cache, SessionAwareCache):
        return {
            "evicted_session_count": 0,
            "freed_tokens": 0,
            "available_tokens_before": 0,
            "available_tokens_after": 0,
        }

    tokens_before = int(cache.token_to_kv_pool_allocator.available_size())
    watermark = max(int(min_available_tokens), int(target_available_tokens or 0))
    skipped_ids = exclude_session_ids or ()

    def is_evictable(status) -> bool:
        # Check order matters: exclusion first, then the idle test, and
        # only then the residency flag.
        sid = status["session_id"]
        if sid in skipped_ids:
            return False
        session = self.sessions.get(sid)
        if session is None or not self._is_idle_streaming_session(session):
            return False
        return bool(status["resident"])

    victims = [
        status for status in cache.list_session_statuses() if is_evictable(status)
    ]

    evicted = 0
    reclaimed = 0
    for status in victims:
        if max_sessions is not None and evicted >= max_sessions:
            break

        # Estimate only: the allocator is not queried again inside the loop.
        tokens_now = tokens_before + reclaimed

        # Explicit token demand: stop once it is covered and the minimum
        # availability holds.
        demand_met = (
            required_tokens > 0
            and reclaimed >= required_tokens
            and tokens_now >= min_available_tokens
        )
        # No explicit demand: stop when nothing was requested at all, or
        # when both availability thresholds are satisfied.
        watermark_met = required_tokens <= 0 and (
            (min_available_tokens <= 0 and watermark <= 0)
            or (tokens_now >= min_available_tokens and tokens_now >= watermark)
        )
        if demand_met or watermark_met:
            break

        # NOTE(review): assumes self._close returns the freed token count
        # as something int() accepts -- confirm against its definition.
        reclaimed += int(self._close(status["session_id"]))
        evicted += 1

    return {
        "evicted_session_count": evicted,
        "freed_tokens": reclaimed,
        "available_tokens_before": tokens_before,
        "available_tokens_after": tokens_before + reclaimed,
    }
|
||||||
|
|
||||||
def maybe_reap(self, now: float, interval: float = 1.0):
|
def maybe_reap(self, now: float, interval: float = 1.0):
|
||||||
# reap sessions every second
|
# reap sessions every second
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ClearHiCacheReqInput,
|
ClearHiCacheReqInput,
|
||||||
ClearHiCacheReqOutput,
|
ClearHiCacheReqOutput,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
|
DirectAppendAdmissionReqInput,
|
||||||
|
DirectAppendAdmissionReqOutput,
|
||||||
DestroyWeightsUpdateGroupReqInput,
|
DestroyWeightsUpdateGroupReqInput,
|
||||||
DestroyWeightsUpdateGroupReqOutput,
|
DestroyWeightsUpdateGroupReqOutput,
|
||||||
DetachHiCacheStorageReqInput,
|
DetachHiCacheStorageReqInput,
|
||||||
@@ -220,6 +222,9 @@ class TokenizerCommunicatorMixin:
|
|||||||
self.get_internal_state_communicator = _Communicator(
|
self.get_internal_state_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.direct_append_admission_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.set_internal_state_communicator = _Communicator(
|
self.set_internal_state_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -316,6 +321,10 @@ class TokenizerCommunicatorMixin:
|
|||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
self.get_internal_state_communicator.handle_recv,
|
self.get_internal_state_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
DirectAppendAdmissionReqOutput,
|
||||||
|
self.direct_append_admission_communicator.handle_recv,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
SetInternalStateReqOutput,
|
SetInternalStateReqOutput,
|
||||||
self.set_internal_state_communicator.handle_recv,
|
self.set_internal_state_communicator.handle_recv,
|
||||||
@@ -871,6 +880,16 @@ class TokenizerCommunicatorMixin:
|
|||||||
# Many DP ranks
|
# Many DP ranks
|
||||||
return [res.internal_state for res in responses]
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
|
async def admit_direct_append(
    self: TokenizerManager,
    obj: DirectAppendAdmissionReqInput,
) -> DirectAppendAdmissionReqOutput:
    """Send a direct-append admission request and return the first reply.

    Calls ``auto_create_handle_loop()``, awaits
    ``direct_append_admission_communicator(obj)``, and returns element 0
    of the awaited result; any further elements are ignored.
    """
    self.auto_create_handle_loop()
    replies = await self.direct_append_admission_communicator(obj)
    return replies[0]
|
||||||
|
|
||||||
async def set_internal_state(
|
async def set_internal_state(
|
||||||
self: TokenizerManager, obj: SetInternalStateReq
|
self: TokenizerManager, obj: SetInternalStateReq
|
||||||
) -> List[bool]:
|
) -> List[bool]:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
@@ -57,6 +58,7 @@ class SessionSlot:
|
|||||||
mamba_next_track_idx: Any = None
|
mamba_next_track_idx: Any = None
|
||||||
mamba_last_track_seqlen: Any = None
|
mamba_last_track_seqlen: Any = None
|
||||||
mamba_branching_seqlen: Any = None
|
mamba_branching_seqlen: Any = None
|
||||||
|
last_access_time: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_holding_kv(self) -> bool:
|
def is_holding_kv(self) -> bool:
|
||||||
@@ -80,6 +82,7 @@ class SessionSlot:
|
|||||||
self.mamba_next_track_idx = req.mamba_next_track_idx
|
self.mamba_next_track_idx = req.mamba_next_track_idx
|
||||||
self.mamba_last_track_seqlen = req.mamba_last_track_seqlen
|
self.mamba_last_track_seqlen = req.mamba_last_track_seqlen
|
||||||
self.mamba_branching_seqlen = req.mamba_branching_seqlen
|
self.mamba_branching_seqlen = req.mamba_branching_seqlen
|
||||||
|
self.last_access_time = time.monotonic()
|
||||||
|
|
||||||
req.req_pool_idx = None
|
req.req_pool_idx = None
|
||||||
req.mamba_pool_idx = None
|
req.mamba_pool_idx = None
|
||||||
@@ -97,6 +100,7 @@ class SessionSlot:
|
|||||||
req.mamba_next_track_idx = self.mamba_next_track_idx
|
req.mamba_next_track_idx = self.mamba_next_track_idx
|
||||||
req.mamba_last_track_seqlen = self.mamba_last_track_seqlen
|
req.mamba_last_track_seqlen = self.mamba_last_track_seqlen
|
||||||
req.mamba_branching_seqlen = self.mamba_branching_seqlen
|
req.mamba_branching_seqlen = self.mamba_branching_seqlen
|
||||||
|
self.last_access_time = time.monotonic()
|
||||||
|
|
||||||
# NOTE: req_pool_idx and mamba_pool_idx are intentionally NOT cleared
|
# NOTE: req_pool_idx and mamba_pool_idx are intentionally NOT cleared
|
||||||
# from the slot. During chunked prefill, a request may be rejected by
|
# from the slot. During chunked prefill, a request may be rejected by
|
||||||
@@ -247,7 +251,7 @@ class SessionAwareCache(BasePrefixCache):
|
|||||||
"""Release all KV resources held by a streaming session."""
|
"""Release all KV resources held by a streaming session."""
|
||||||
slot = self.slots.pop(session_id, None)
|
slot = self.slots.pop(session_id, None)
|
||||||
if slot is None:
|
if slot is None:
|
||||||
return
|
return 0
|
||||||
|
|
||||||
if slot.last_node is not None:
|
if slot.last_node is not None:
|
||||||
if slot.swa_uuid_for_lock is not None:
|
if slot.swa_uuid_for_lock is not None:
|
||||||
@@ -261,20 +265,58 @@ class SessionAwareCache(BasePrefixCache):
|
|||||||
if slot.is_holding_kv:
|
if slot.is_holding_kv:
|
||||||
start = slot.cache_protected_len
|
start = slot.cache_protected_len
|
||||||
end = slot.kv_allocated_len
|
end = slot.kv_allocated_len
|
||||||
|
freed_tokens = max(0, end - start)
|
||||||
if start < end:
|
if start < end:
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
slot.req_pool_idx, start:end
|
slot.req_pool_idx, start:end
|
||||||
]
|
]
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
self.req_to_token_pool.free_slots.append(slot.req_pool_idx)
|
self.req_to_token_pool.free_slots.append(slot.req_pool_idx)
|
||||||
|
return freed_tokens
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def slot_held_tokens(self, slot: SessionSlot) -> int:
    """Return the number of KV tokens attributed to ``slot``.

    A slot that is not holding KV contributes 0. Otherwise the allocated
    length is rounded up with ``ceil_align`` to ``self.page_size``, the
    cache-protected length is subtracted, and the result is clamped at 0.
    """
    if slot.is_holding_kv:
        page_aligned_len = ceil_align(slot.kv_allocated_len, self.page_size)
        return max(0, page_aligned_len - slot.cache_protected_len)
    return 0
|
||||||
|
|
||||||
|
def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
    """Describe the slot registered for ``session_id``.

    Returns:
        ``None`` when ``self.slots`` has no entry for the id; otherwise a
        dict with the session id, the slot's ``is_holding_kv`` flag as
        ``resident``, ``self.slot_held_tokens(slot)`` as
        ``resident_tokens``, the three length counters as ints and
        ``last_access_time`` as a float.
    """
    slot = self.slots.get(session_id)
    if slot is None:
        return None

    status = {
        "session_id": session_id,
        "resident": slot.is_holding_kv,
        "resident_tokens": int(self.slot_held_tokens(slot)),
    }
    # The integer counters are copied under their attribute names.
    for counter in ("kv_committed_len", "kv_allocated_len", "cache_protected_len"):
        status[counter] = int(getattr(slot, counter))
    status["last_access_time"] = float(slot.last_access_time)
    return status
|
||||||
|
|
||||||
|
def list_session_statuses(self) -> list[Dict[str, Any]]:
    """Return one status dict per slot, least recently accessed first.

    Every entry in ``self.slots`` yields a dict with the session id,
    residency flag, held-token count, the three length counters and
    ``last_access_time``. The list is ordered by ascending
    ``last_access_time``; entries with equal times keep the iteration
    order of ``self.slots``.
    """
    unordered = [
        {
            "session_id": session_id,
            "resident": slot.is_holding_kv,
            "resident_tokens": int(self.slot_held_tokens(slot)),
            "kv_committed_len": int(slot.kv_committed_len),
            "kv_allocated_len": int(slot.kv_allocated_len),
            "cache_protected_len": int(slot.cache_protected_len),
            "last_access_time": float(slot.last_access_time),
        }
        for session_id, slot in self.slots.items()
    ]
    return sorted(unordered, key=lambda entry: entry["last_access_time"])
|
||||||
|
|
||||||
def session_held_tokens(self) -> int:
    """Total KV tokens held by session slots, not tracked by the tree.

    Sums ``self.slot_held_tokens(slot)`` over every slot in
    ``self.slots``; the result is 0 when there are no slots.
    """
    return sum(self.slot_held_tokens(slot) for slot in self.slots.values())
|
||||||
|
|
||||||
def session_held_full_tokens(self) -> int:
|
def session_held_full_tokens(self) -> int:
|
||||||
|
|||||||
@@ -697,6 +697,7 @@ class ServerArgs:
|
|||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
disaggregation_ib_device: Optional[str] = None
|
disaggregation_ib_device: Optional[str] = None
|
||||||
disaggregation_decode_enable_offload_kvcache: bool = False
|
disaggregation_decode_enable_offload_kvcache: bool = False
|
||||||
|
disaggregation_decode_allow_local_prefill: bool = False
|
||||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||||
# FIXME: hack to reduce ITL when decode bs is small
|
# FIXME: hack to reduce ITL when decode bs is small
|
||||||
disaggregation_decode_polling_interval: int = 1
|
disaggregation_decode_polling_interval: int = 1
|
||||||
@@ -5772,6 +5773,14 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable async KV cache offloading on decode server (PD mode).",
|
help="Enable async KV cache offloading on decode server (PD mode).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-decode-allow-local-prefill",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Allow decode workers in PD mode to accept direct local requests "
|
||||||
|
"without bootstrap metadata and run local append-prefill."
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-reserved-decode-tokens",
|
"--num-reserved-decode-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
Reference in New Issue
Block a user