From 52a54e44afe81332131d0da1bae3af2265817a20 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Sat, 23 May 2026 21:12:56 +0800 Subject: [PATCH] proxy: split session_affinity per mode + vLLM patch self-check (M4, S2) - Replace the global session_affinity dict with two namespace-isolated ones (combined / prefill) so a session_id never indexes the wrong instance list across mode switches. Keep `session_affinity` as a read-only alias to the combined dict for any existing tooling. - Add a startup _verify_vllm_patch() that scans vllm.v1.core.sched.scheduler.Scheduler for the original `assert req_id in self.requests` line. If the patch was not re-applied after a vLLM upgrade we now print a loud warning at lifespan startup instead of dying mid-experiment on a KV-transfer abort race. --- scripts/cache_aware_proxy.py | 39 ++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/scripts/cache_aware_proxy.py b/scripts/cache_aware_proxy.py index 8fa62ea..7b4be53 100644 --- a/scripts/cache_aware_proxy.py +++ b/scripts/cache_aware_proxy.py @@ -169,7 +169,12 @@ global_args = None combined_instances: list[InstanceState] = [] prefill_instances: list[InstanceState] = [] decode_instances: list[InstanceState] = [] -session_affinity: dict[str, int] = {} +# Session affinity is namespace-isolated: combined-mode and pd-sep mode index +# different instance lists, so a shared dict could mis-route after a mode switch. +session_affinity_combined: dict[str, int] = {} +session_affinity_prefill: dict[str, int] = {} +# Backwards-compat alias used by /stats etc. +session_affinity = session_affinity_combined is_pd_sep = False _breakdown_log: list[dict] = [] @@ -224,11 +229,37 @@ async def _reconcile_loop(): inst.active_p_offloads = 0 +def _verify_vllm_patch(): + """Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch. 
+ + The patch turns an `assert req_id in self.requests` into a soft warn so + that engines do not crash on the KV-transfer abort race (see REPORT + §3.x). If somebody upgrades vLLM without re-applying the patch, the + assert returns and elastic mode dies under load. Print a loud warning + so we catch the regression before the first HEAVY request. + """ + try: + import inspect + from vllm.v1.core.sched.scheduler import Scheduler + src = inspect.getsource(Scheduler) + if "assert req_id in self.requests" in src: + print("WARNING: vLLM scheduler still contains the unpatched " + "`assert req_id in self.requests` line; expect engine " + "death on KV-transfer abort race. Apply " + "patches/0001-fix-kv-transfer-abort-race.patch.") + else: + print("vLLM patch self-check: kv-transfer-abort assert is patched.") + except Exception as exc: + print(f"vLLM patch self-check skipped: {exc!r}") + + @asynccontextmanager async def lifespan(app: FastAPI): global is_pd_sep app.state.ready = asyncio.Event() + _verify_vllm_patch() + reconcile_task = asyncio.create_task(_reconcile_loop()) if global_args.combined: @@ -313,7 +344,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear' picker = pick_instance_lmetric if policy == 'lmetric' else pick_instance best_inst, best_idx = picker(combined_instances, token_ids, session_id, - input_length, session_affinity) + input_length, session_affinity_combined) cache_hit = best_inst.estimate_cache_hit(token_ids) estimated_new = max(0, input_length - cache_hit) @@ -395,7 +426,7 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h breakdown["cache_hit_tokens"] = cache_hit if session_id: - session_affinity[session_id] = d_idx + session_affinity_combined[session_id] = d_idx return await _handle_direct_read_offload( api, req_data, headers, token_ids, input_length, @@ -509,7 +540,7 @@ async def _handle_pd_sep(api, 
req_data, request_id, token_ids, input_length, } p_inst, _ = pick_instance(prefill_instances, token_ids, session_id, - input_length, session_affinity) + input_length, session_affinity_prefill) d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens) breakdown["p_inst"] = p_inst.url breakdown["d_inst"] = d_inst.url