feat(sglang): drop streaming-session reqs with fill_ids < prefix_indices

Fix A from docs/E3_FINDINGS_ZH.md §3. The existing streaming-session
correction at the top of ScheduleBatch.prepare_for_extend zeroes
req.extend_input_len when len(fill_ids) <= len(prefix_indices), but
the per-req invariant later in the same function (assert
seq_len - pre_len == req.extend_input_len) is computed from raw
fill_ids/prefix_indices lengths and has no path to be satisfied
when fill_len < prefix_len. The result is an AssertionError that
crashes the entire decode worker.
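
A minimal sketch of the mismatch, using the lengths from the E3
reproduction recorded further down in this message. The function names
are hypothetical; only the two formulas (the max(0, ...) clamp and the
raw-length difference) come from this message, and seq_len/pre_len are
assumed to equal the raw fill_ids/prefix_indices lengths.

```python
def corrected_extend_len(fill_len: int, prefix_len: int) -> int:
    """Hypothetical stand-in for the streaming-session correction (clamps at 0)."""
    return max(0, fill_len - prefix_len)


def invariant_holds(fill_len: int, prefix_len: int) -> bool:
    """Hypothetical stand-in for `assert seq_len - pre_len == req.extend_input_len`."""
    # Assumption: seq_len and pre_len are the raw, uncorrected lengths.
    seq_len, pre_len = fill_len, prefix_len
    return seq_len - pre_len == corrected_extend_len(fill_len, prefix_len)


# Normal extend: 10 - 4 == 6 on both sides.
assert invariant_holds(10, 4)
# Boundary: equal lengths give 0 on both sides, so the invariant still holds.
assert invariant_holds(10, 10)
# E3 lengths: the raw difference is -36811, the clamped value is 0.
assert not invariant_holds(6648, 43459)
```

The equal-length case passes, so only the strict fill_len < prefix_len
case is unsatisfiable.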

Add a pre-filter pass at the start of prepare_for_extend that
detects this state, marks the affected reqs with FINISH_ABORT (so
the client gets an error response instead of the worker crashing),
and drops them from the batch before the correction loop runs. If
all reqs are filtered, populate empty tensor/list state and return
early so downstream model.forward sees a valid no-op batch.

This treats len(fill_ids) < len(prefix_indices) as an upstream state
inconsistency that should be reported to the client rather than
silently miscomputed. After this filter a narrower invariant holds:
prepare_for_extend's body only ever sees streaming-session reqs
where actual_extend_len >= 0 (the filter is strict, so reqs with
fill_len == prefix_len are kept), which is the regime the existing
correction logic was designed for.
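
The partition step can be modelled without torch or sglang. This is a
sketch under stated assumptions, not the patched code: FakeReq and
prefilter are hypothetical names, the streaming flag stands in for
`req.session is not None and req.session.streaming`, and abort_message
stands in for the FINISH_ABORT reason.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeReq:
    """Hypothetical, simplified request: lengths only, no tensors."""
    rid: str
    streaming: bool
    fill_len: int
    prefix_len: int
    abort_message: Optional[str] = None


def prefilter(reqs: List[FakeReq]) -> Tuple[List[FakeReq], List[FakeReq]]:
    """Split reqs into (kept, dropped) using the strict fill_len < prefix_len test."""
    kept: List[FakeReq] = []
    dropped: List[FakeReq] = []
    for req in reqs:
        if req.streaming and req.fill_len < req.prefix_len:
            # Record why the req was aborted so it can be reported to the client.
            req.abort_message = (
                "streaming-session inconsistency: fill_ids "
                f"({req.fill_len}) < prefix_indices ({req.prefix_len})"
            )
            dropped.append(req)
        else:
            kept.append(req)
    return kept, dropped


reqs = [
    FakeReq("a", streaming=True, fill_len=10, prefix_len=4),    # normal extend
    FakeReq("b", streaming=True, fill_len=10, prefix_len=10),   # boundary, kept
    FakeReq("c", streaming=True, fill_len=5, prefix_len=9),     # inconsistent
    FakeReq("d", streaming=False, fill_len=5, prefix_len=9),    # not streaming, kept
]
kept, dropped = prefilter(reqs)
```

Here only req "c" is dropped. Req "d" has the same lengths but is not
a streaming-session req, so the filter leaves it alone.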

Reproduced by the first E3 run on 2026-05-12 02:51:21 UTC (rid
6f4318e93dd543a49dbf19248cfc1e6f, session 1000195, fill_len=6648,
prefix_len=43459). It was masked in E1/E2 because the cap-out failure
cascade prevented sessions from accumulating a committed prefix deep
enough to trigger the inconsistency.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tim
2026-05-12 12:12:14 +08:00
parent d40db1f117
commit 986f351365

@@ -1564,6 +1564,74 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            # For DLLM, we use a separate forward mode
            self.forward_mode = ForwardMode.DLLM_EXTEND
        # Pre-filter pass: drop streaming-session reqs whose committed prefix
        # is strictly longer than fill_ids. The streaming-session correction
        # below would set extend_input_len = max(0, fill_len - prefix_len) = 0
        # for these reqs, but the downstream invariant at the per-req loop
        # (`assert seq_len - pre_len == req.extend_input_len`) is computed from
        # raw fill_ids/prefix_indices lengths and has no path to be satisfied
        # when fill_len < prefix_len. Treat the condition as an upstream state
        # inconsistency, abort the affected reqs (so the client sees an error
        # response instead of the worker crashing), and continue with the
        # remaining batch. See docs/E3_FINDINGS_ZH.md for the failure mode
        # this guards against.
        if self.reqs:
            kept_reqs = []
            for req in self.reqs:
                if (
                    req.session is not None
                    and req.session.streaming
                    and len(req.fill_ids) < len(req.prefix_indices)
                ):
                    logger.error(
                        "Dropping streaming-session req with fill_ids shorter than "
                        "prefix_indices (rid=%s, session_id=%s, fill_len=%d, "
                        "prefix_len=%d, kv_committed_len=%d). Upstream state "
                        "inconsistency would crash prepare_for_extend's invariant; "
                        "aborting this req. See docs/E3_FINDINGS_ZH.md.",
                        req.rid,
                        req.session.session_id,
                        len(req.fill_ids),
                        len(req.prefix_indices),
                        req.kv_committed_len,
                    )
                    req.finished_reason = FINISH_ABORT(
                        message=(
                            "streaming-session inconsistency: fill_ids "
                            f"({len(req.fill_ids)}) < prefix_indices "
                            f"({len(req.prefix_indices)})"
                        ),
                    )
                else:
                    kept_reqs.append(req)
            if len(kept_reqs) != len(self.reqs):
                self.reqs = kept_reqs
            if not self.reqs:
                # Whole batch filtered. Set empty tensor / list state so
                # downstream callers (model_runner.forward, batch_result handlers)
                # see a valid no-op batch and skip the model pass cleanly.
                _pin = is_pin_memory_available(self.device)
                empty_long = torch.zeros(0, dtype=torch.int64, pin_memory=_pin).to(
                    self.device, non_blocking=True
                )
                empty_int = torch.zeros(0, dtype=torch.int32, pin_memory=_pin).to(
                    self.device, non_blocking=True
                )
                self.input_ids = empty_long
                self.req_pool_indices = empty_int
                self.seq_lens = empty_long
                self.seq_lens_cpu = torch.zeros(0, dtype=torch.int64)
                self.orig_seq_lens = empty_int
                self.prefix_lens = []
                self.extend_lens = []
                self.extend_num_tokens = 0
                self.out_cache_loc = empty_int
                self.input_embeds = None
                self.multimodal_inputs = []
                self.token_type_ids = None
                return
        # Init tensors
        reqs = self.reqs
        for req in reqs: