proxy: split session_affinity per mode + vLLM patch self-check (M4, S2)
- Replace the global session_affinity dict with two namespace-isolated dicts (combined / prefill) so that a session_id never indexes the wrong instance list across mode switches. Keep `session_affinity` as an alias of the combined dict, intended for read-only use by existing tooling.
- Add a startup _verify_vllm_patch() that inspects the source of vllm.v1.core.sched.scheduler.Scheduler for the original `assert req_id in self.requests` line. If the patch was not re-applied after a vLLM upgrade, we now print a loud warning at lifespan startup instead of dying mid-experiment on a KV-transfer abort race.
This commit is contained in:
@@ -169,7 +169,12 @@ global_args = None
|
||||
combined_instances: list[InstanceState] = []
|
||||
prefill_instances: list[InstanceState] = []
|
||||
decode_instances: list[InstanceState] = []
|
||||
session_affinity: dict[str, int] = {}
|
||||
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
|
||||
# different instance lists, so a shared dict could mis-route after a mode switch.
|
||||
session_affinity_combined: dict[str, int] = {}
|
||||
session_affinity_prefill: dict[str, int] = {}
|
||||
# Backwards-compat alias used by /stats etc.
|
||||
session_affinity = session_affinity_combined
|
||||
is_pd_sep = False
|
||||
_breakdown_log: list[dict] = []
|
||||
|
||||
@@ -224,11 +229,37 @@ async def _reconcile_loop():
|
||||
inst.active_p_offloads = 0
|
||||
|
||||
|
||||
def _verify_vllm_patch():
    """Report at startup whether the vLLM KV-transfer abort-race patch is applied.

    patches/0001-fix-kv-transfer-abort-race.patch replaces the scheduler's
    `assert req_id in self.requests` with a soft warning. This check reads
    the source of `vllm.v1.core.sched.scheduler.Scheduler` and looks for that
    original assert text:

    * found     -> print a WARNING telling the operator to re-apply the patch
                   (typically needed after a vLLM upgrade);
    * not found -> print a one-line confirmation that the assert is patched.

    The check is best-effort: any exception (vLLM not importable, source not
    retrievable, ...) is caught and reported as "skipped" rather than
    propagated. Returns None; the only effect is the printed line.
    """
    unpatched_marker = "assert req_id in self.requests"
    try:
        import inspect
        from vllm.v1.core.sched.scheduler import Scheduler

        scheduler_source = inspect.getsource(Scheduler)
        if unpatched_marker not in scheduler_source:
            print("vLLM patch self-check: kv-transfer-abort assert is patched.")
            return

        # The original assert is still present, so the patch is missing.
        print(
            "WARNING: vLLM scheduler still contains the unpatched "
            "`assert req_id in self.requests` line; expect engine "
            "death on KV-transfer abort race. Apply "
            "patches/0001-fix-kv-transfer-abort-race.patch."
        )
    except Exception as err:
        # Deliberately broad: a failed self-check must not block startup.
        print(f"vLLM patch self-check skipped: {err!r}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global is_pd_sep
|
||||
app.state.ready = asyncio.Event()
|
||||
|
||||
_verify_vllm_patch()
|
||||
|
||||
reconcile_task = asyncio.create_task(_reconcile_loop())
|
||||
|
||||
if global_args.combined:
|
||||
@@ -313,7 +344,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear'
|
||||
picker = pick_instance_lmetric if policy == 'lmetric' else pick_instance
|
||||
best_inst, best_idx = picker(combined_instances, token_ids, session_id,
|
||||
input_length, session_affinity)
|
||||
input_length, session_affinity_combined)
|
||||
cache_hit = best_inst.estimate_cache_hit(token_ids)
|
||||
estimated_new = max(0, input_length - cache_hit)
|
||||
|
||||
@@ -395,7 +426,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
breakdown["cache_hit_tokens"] = cache_hit
|
||||
|
||||
if session_id:
|
||||
session_affinity[session_id] = d_idx
|
||||
session_affinity_combined[session_id] = d_idx
|
||||
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
@@ -509,7 +540,7 @@ async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
|
||||
}
|
||||
|
||||
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
|
||||
input_length, session_affinity)
|
||||
input_length, session_affinity_prefill)
|
||||
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
|
||||
breakdown["p_inst"] = p_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
|
||||
Reference in New Issue
Block a user