feat(sglang): drop streaming-session reqs with fill_ids < prefix_indices
Fix A from docs/E3_FINDINGS_ZH.md §3. The existing streaming-session correction at the top of ScheduleBatch.prepare_for_extend zeroes req.extend_input_len when len(fill_ids) <= len(prefix_indices), but the per-req invariant later in the same function (assert seq_len - pre_len == req.extend_input_len) is computed from raw fill_ids/prefix_indices lengths and has no path to be satisfied when fill_len < prefix_len. The result is an AssertionError that crashes the entire decode worker. Add a pre-filter pass at the start of prepare_for_extend that detects this state, marks the affected reqs with FINISH_ABORT (so the client gets an error response instead of the worker hanging), and drops them from the batch before the correction loop runs. If all reqs are filtered, populate empty tensor/list state and return early so downstream model.forward sees a valid no-op batch. This treats fill_ids < prefix_indices as upstream state inconsistency that should be reported to the client rather than silently miscomputed. The narrower invariant after this filter: prepare_for_extend's body only ever sees streaming-session reqs where actual_extend_len >= 0 (i.e. fill_len >= prefix_len; the filter uses a strict <, so the equal-length case still passes through and is zeroed by the existing correction), which is the regime the existing correction logic was designed for. Reproduced by E3 first run on 2026-05-12 02:51:21 UTC (rid 6f4318e93dd543a49dbf19248cfc1e6f, session 1000195, fill_len=6648, prefix_len=43459) — masked in E1/E2 because the cap-out failure cascade prevented sessions from accumulating deep enough committed prefix to trigger the inconsistency. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1564,6 +1564,74 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# For DLLM, we use a separate forward mode
|
||||
self.forward_mode = ForwardMode.DLLM_EXTEND
|
||||
|
||||
# Pre-filter pass: drop streaming-session reqs whose committed prefix
|
||||
# is strictly longer than fill_ids. The streaming-session correction below would
|
||||
# set extend_input_len = max(0, fill_len - prefix_len) = 0 for these
|
||||
# reqs, but the downstream invariant at the per-req loop
|
||||
# (`assert seq_len - pre_len == req.extend_input_len`) is computed from
|
||||
# raw fill_ids/prefix_indices lengths and has no path to be satisfied
|
||||
# when fill_len < prefix_len. Treat the condition as upstream state
|
||||
# inconsistency, abort the affected reqs (so the client sees an error
|
||||
# response instead of the worker crashing), and continue with the
|
||||
# remaining batch. See docs/E3_FINDINGS_ZH.md for the failure mode
|
||||
# this guards against.
|
||||
if self.reqs:
|
||||
kept_reqs = []
|
||||
for req in self.reqs:
|
||||
if (
|
||||
req.session is not None
|
||||
and req.session.streaming
|
||||
and len(req.fill_ids) < len(req.prefix_indices)
|
||||
):
|
||||
logger.error(
|
||||
"Dropping streaming-session req with fill_ids shorter than "
|
||||
"prefix_indices (rid=%s, session_id=%s, fill_len=%d, "
|
||||
"prefix_len=%d, kv_committed_len=%d). Upstream state "
|
||||
"inconsistency would crash prepare_for_extend's invariant; "
|
||||
"aborting this req. See docs/E3_FINDINGS_ZH.md.",
|
||||
req.rid,
|
||||
req.session.session_id,
|
||||
len(req.fill_ids),
|
||||
len(req.prefix_indices),
|
||||
req.kv_committed_len,
|
||||
)
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
message=(
|
||||
"streaming-session inconsistency: fill_ids "
|
||||
f"({len(req.fill_ids)}) < prefix_indices "
|
||||
f"({len(req.prefix_indices)})"
|
||||
),
|
||||
)
|
||||
else:
|
||||
kept_reqs.append(req)
|
||||
if len(kept_reqs) != len(self.reqs):
|
||||
self.reqs = kept_reqs
|
||||
|
||||
if not self.reqs:
|
||||
# Whole batch filtered. Set empty tensor / list state so
|
||||
# downstream callers (model_runner.forward, batch_result handlers)
|
||||
# see a valid no-op batch and skip the model pass cleanly.
|
||||
_pin = is_pin_memory_available(self.device)
|
||||
empty_long = torch.zeros(0, dtype=torch.int64, pin_memory=_pin).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
empty_int = torch.zeros(0, dtype=torch.int32, pin_memory=_pin).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.input_ids = empty_long
|
||||
self.req_pool_indices = empty_int
|
||||
self.seq_lens = empty_long
|
||||
self.seq_lens_cpu = torch.zeros(0, dtype=torch.int64)
|
||||
self.orig_seq_lens = empty_int
|
||||
self.prefix_lens = []
|
||||
self.extend_lens = []
|
||||
self.extend_num_tokens = 0
|
||||
self.out_cache_loc = empty_int
|
||||
self.input_embeds = None
|
||||
self.multimodal_inputs = []
|
||||
self.token_type_ids = None
|
||||
return
|
||||
|
||||
# Init tensors
|
||||
reqs = self.reqs
|
||||
for req in reqs:
|
||||
|
||||
Reference in New Issue
Block a user