"""Unified cache-aware + token-level load-balanced global scheduler. Supports two modes: --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer) --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer) Routing policies (--policy): linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens lmetric: score = P_tokens * BS (LMetric, OSDI'26) P_tokens = pending_prefill_tokens + new_uncached_tokens BS = num_requests (waiting + running) Session affinity: multi-turn sessions stick to same instance (all policies). """ import argparse import asyncio import json import os import time as _time import urllib.parse import uuid from collections import OrderedDict from contextlib import asynccontextmanager from dataclasses import dataclass import httpx MAX_STREAM_RETRIES = 3 RETRY_DELAY_S = 0.5 import uvicorn from fastapi import FastAPI, HTTPException, Request from fastapi.responses import StreamingResponse BLOCK_SIZE = 512 CACHE_HIT_ALPHA = 1.0 @dataclass class Settings: """Runtime-tunable knobs. Populated from argparse in __main__. All routing/offload code reads from the SETTINGS singleton so that CLI overrides survive even when the module is imported as a library (e.g. by tests/) and __main__ does not run. """ prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20) rdma_overhead_s: float = 0.1 # legacy floor; v2 uses estimate_transfer_cost cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks heavy_threshold: int = 20000 overload_factor: float = 2.0 max_offload_inflight: int = 4 cache_gate_ratio: float = 0.0 decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20) # --- Patch 6.9: cost-model calibration for unified_v2 --- # Throughput when the engine runs in kv_both mode. Lower than the # pure-decode 7000 tok/s because kv_both adds always-on overhead # (REPORT §3.8 documents ~+16% TPOT vs plain). prefill_throughput_kv_both: float = 4000.0 # Calibrated RDMA transfer cost: base + bandwidth term. # Floor from isolated test ≈ 0.3 s (handshake + scheduler step). # Bandwidth term reflects realized effective throughput, not # theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s # effective on the contended kv_both path. v2 uses this lookup # rather than the constant rdma_overhead_s. rdma_base_overhead_s: float = 0.3 rdma_effective_gb_per_s: float = 2.7 # Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2): # 2 × 48 × 4 × 128 × 2 = 98304 bytes per token. kv_bytes_per_token: int = 98304 # --- unified_v2 gating knobs (relaxed in v2.1 after the v1 0.2% trigger rate) --- # B2 microbench shows TPOT idx 1.9x already at new_tokens=8k and TTFT # idx ~12x; the previous 16k threshold was too conservative and # rejected 88.7% of candidates (window_1_results/v2_breakdown). pd_sep_min_new_tokens: int = 8000 pd_sep_min_decodes_protected: int = 1 # any in-flight work on chosen counts pd_sep_min_src_cache_tokens: int = 4000 # half a block; was 8000 pd_sep_min_extra_cache_tokens: int = 2000 # half a block; was 4000 pd_sep_margin_s: float = 0.2 # require cost gap > 0.2 s before migrating # Patch 6.6: per-request KV-xfer wall-clock timeout (proxy side). pd_sep_xfer_timeout_s: float = 60.0 SETTINGS = Settings() def estimate_transfer_cost(transfer_bytes: int) -> float: """Calibrated RDMA transfer cost as a function of bytes. Replaces the legacy constant rdma_overhead_s. Calibration sources: - Floor: isolated-test ~0.3 s for a few-block PUSH (scripts/test_direct_read.py) - Bandwidth term: outputs/contention_16s_elastic/breakdown.json shows decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, giving ~2.7 GB/s effective on the contended kv_both path. The p90 in that same run is 6.7 s (D-side block reservation + scheduler step delays). v2's cost model uses the *median* — being too pessimistic would suppress all PD-sep triggers. The risk of underestimation is mitigated by the pd_sep_margin_s safety factor. """ base = SETTINGS.rdma_base_overhead_s bw_term = transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024 ** 3) return base + bw_term def estimate_same_worker_interference_s( new_tokens: int, num_decodes: int, ) -> float: """Estimated additional latency on `num_decodes` co-located decodes when a `new_tokens`-token prefill runs on the same worker. Derived from B2 microbench (analysis/characterization/window_1_results.md): same-worker prefill of size N steals decode capacity for the prefill's duration. The penalty factor is the fraction of decode steps stolen during the prefill window. For new_tokens < 4k: ~0.2 (chunked prefill leaves room) For new_tokens 16k: ~0.5 (mid-regime, B2 TPOT idx 3.4×) For new_tokens 32k: ~0.8 (B2 peak TPOT idx 7.9×) For new_tokens > 32k: ~0.95 (B2 TTFT regime — decodes are nearly fully blocked) The cost in seconds is roughly: prefill_duration × penalty × n_decodes, because each affected decode loses ~penalty fraction of its capacity during the prefill window. """ if num_decodes <= 0: return 0.0 prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both if new_tokens < 4000: penalty = 0.2 elif new_tokens < 16000: penalty = 0.5 elif new_tokens < 32000: penalty = 0.8 else: penalty = 0.95 return prefill_dur_s * penalty * num_decodes class InstanceState: def __init__(self, url: str, bootstrap_port: int | None = None): self.url = url self.bootstrap_port = bootstrap_port self.client = httpx.AsyncClient( timeout=None, base_url=url, limits=httpx.Limits(max_connections=None, max_keepalive_connections=None), ) self.ongoing_tokens = 0 self.ongoing_decode_tokens = 0 # subset: tokens in decode phase self.pending_prefill_tokens = 0 # tokens for requests still in prefill self.num_requests = 0 # total in-flight requests (waiting + running) self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others self.engine_id: dict[int, str] = {} self.dp_size = 1 # OrderedDict acts as an LRU keyed by block hash; value is unused. self.cached_blocks: OrderedDict[int, None] = OrderedDict() def estimate_cache_hit(self, token_ids: list[int] | None) -> int: if not token_ids or len(token_ids) < BLOCK_SIZE: return 0 hit = 0 for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): bh = hash(tuple(token_ids[i:i + BLOCK_SIZE])) if bh in self.cached_blocks: self.cached_blocks.move_to_end(bh) # LRU touch on hit hit += BLOCK_SIZE else: break return hit def record_prefix(self, token_ids: list[int] | None): if not token_ids: return for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): bh = hash(tuple(token_ids[i:i + BLOCK_SIZE])) if bh in self.cached_blocks: self.cached_blocks.move_to_end(bh) else: self.cached_blocks[bh] = None if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks: self.cached_blocks.popitem(last=False) def _p_offload_penalty(inst: InstanceState) -> int: """Penalty for PD-sep mode routing (legacy).""" if inst.active_p_offloads <= 0: return 0 return inst.active_p_offloads * SETTINGS.heavy_threshold def snapshot_workers( instances: list[InstanceState], token_ids: list[int] | None = None, input_length: int = 0, ) -> list[dict]: """Per-worker state at route-decision time. All routing-relevant counters plus the score each policy would have produced for `input_length` if it were dispatched now. Cheap enough to call on every request; B3 hot-spot analysis depends on this being captured per decision. """ snap: list[dict] = [] for i, inst in enumerate(instances): cache_hit = inst.estimate_cache_hit(token_ids) if token_ids else 0 new_prefill = max(0, input_length - cache_hit) snap.append({ "idx": i, "url": inst.url, "ongoing_tokens": inst.ongoing_tokens, "ongoing_decode_tokens": inst.ongoing_decode_tokens, "pending_prefill_tokens": inst.pending_prefill_tokens, "num_requests": inst.num_requests, "active_p_offloads": inst.active_p_offloads, "cached_blocks": len(inst.cached_blocks), "cache_hit": cache_hit, "new_prefill": new_prefill, "score_linear": (inst.ongoing_tokens + _p_offload_penalty(inst) - CACHE_HIT_ALPHA * cache_hit), "score_lmetric": (inst.pending_prefill_tokens + new_prefill) * inst.num_requests, }) return snap def pick_instance(instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int]) -> tuple[InstanceState, int]: """Session-sticky with load-aware override. Turn 2+: use session affinity UNLESS pinned instance is overloaded or busy with P-role offloads, in which case pick least-loaded. Turn 1: pick instance with best score (load + cache combined). Instances doing P-role offloads get a large penalty to steer WARM/MEDIUM traffic away. """ avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0) if session_id and session_id in affinity: idx = affinity[session_id] if idx < len(instances): inst = instances[idx] if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor and inst.active_p_offloads == 0): return inst, idx best_idx, best_score = 0, float("inf") for i, inst in enumerate(instances): cache_hit = inst.estimate_cache_hit(token_ids) score = (inst.ongoing_tokens + _p_offload_penalty(inst) - CACHE_HIT_ALPHA * cache_hit) if score < best_score: best_score = score best_idx = i if session_id: affinity[session_id] = best_idx return instances[best_idx], best_idx def pick_instance_load_only( instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int], ) -> tuple[InstanceState, int]: """Pure load balancing: pick instance with fewest in-flight requests. Ignores cache hits and session affinity. Used as a B3 control to isolate the locality contribution of cache-aware policies. """ best_idx = min(range(len(instances)), key=lambda i: instances[i].num_requests) return instances[best_idx], best_idx def pick_instance_sticky( instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int], ) -> tuple[InstanceState, int]: """Hard session affinity: once assigned, never break. First turn of a session picks the instance with the lowest num_requests; subsequent turns always return to the same instance regardless of load. Used as a B3 control to isolate the hot-spot cost of perfect locality. """ if session_id and session_id in affinity: idx = affinity[session_id] if idx < len(instances): return instances[idx], idx best_idx = min(range(len(instances)), key=lambda i: instances[i].num_requests) if session_id: affinity[session_id] = best_idx return instances[best_idx], best_idx def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int]) -> tuple[InstanceState, int]: """LMetric routing: score = P_tokens × BS (OSDI'26). Pure per-request load-based routing, no session affinity (the session_id/affinity args are accepted for signature compatibility with pick_instance/pick_instance_unified_hybrid but ignored). P = pending_prefill_tokens + (input_length - cache_hit) BS = num_requests (current batch size) """ best_idx, best_score = 0, float("inf") for i, inst in enumerate(instances): cache_hit = inst.estimate_cache_hit(token_ids) new_prefill = max(0, input_length - cache_hit) p_tokens = inst.pending_prefill_tokens + new_prefill bs = inst.num_requests score = p_tokens * bs if score < best_score: best_score = score best_idx = i return instances[best_idx], best_idx _unified_fallback_rr_counter = 0 def pick_instance_unified_hybrid( instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int], ) -> tuple[InstanceState, int, dict]: """Hybrid routing: high-cache affinity, else LMetric with tie-breaker. Affinity gate (both must hold to stick): - affinity instance cache_hit / input_length > 0.5 - affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor Fallback ordering (when affinity not used): primary: score = P_tokens * BS (LMetric) secondary: new_uncached_tokens (prefer instance with most cache) tertiary: num_requests (prefer least-loaded) quaternary: round-robin (avoid degenerate inst-0 pinning when BS=0 across the board) Returns (chosen, idx, decision_dict). decision_dict carries the review #7 breakdown fields so the caller can merge them verbatim. """ global _unified_fallback_rr_counter n = len(instances) avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0) decision: dict = { "decision": "lmetric_fallback", "affinity_idx": None, "chosen_idx": None, "affinity_cache_hit": None, "affinity_cache_ratio": None, "affinity_num_requests": None, "avg_num_requests": avg_reqs, "fallback_score": None, "tie_break_used": False, } if session_id and session_id in affinity: a_idx = affinity[session_id] if a_idx < n: a_inst = instances[a_idx] a_hit = a_inst.estimate_cache_hit(token_ids) a_ratio = a_hit / max(input_length, 1) decision["affinity_idx"] = a_idx decision["affinity_cache_hit"] = a_hit decision["affinity_cache_ratio"] = a_ratio decision["affinity_num_requests"] = a_inst.num_requests if (a_ratio > 0.5 and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor): decision["decision"] = "affinity" decision["chosen_idx"] = a_idx return a_inst, a_idx, decision keys: list[tuple[int, int, int, int]] = [] for i, inst in enumerate(instances): cache_hit = inst.estimate_cache_hit(token_ids) new_prefill = max(0, input_length - cache_hit) p_tokens = inst.pending_prefill_tokens + new_prefill bs = inst.num_requests score = p_tokens * bs keys.append((score, new_prefill, bs, i)) best_triple = min(k[:3] for k in keys) tied = [k for k in keys if k[:3] == best_triple] if len(tied) > 1: decision["tie_break_used"] = True _unified_fallback_rr_counter += 1 winner = tied[_unified_fallback_rr_counter % len(tied)] else: winner = tied[0] chosen_idx = winner[3] decision["fallback_score"] = winner[0] decision["chosen_idx"] = chosen_idx return instances[chosen_idx], chosen_idx, decision def pick_instance_unified_v2( instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int], ) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]: """unified_v2 = unified hybrid + selective per-request PD-sep trigger. Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`. Stage 2 asks: is there another instance with materially more cache for this request? If yes, would doing prefill on that instance and transferring KV to `chosen` for decode be cheaper than just doing everything on `chosen`? The cost model compares two scenarios in seconds-of-decode-disruption: local: same-worker prefill on chosen of (input - chosen.cache_hit) tokens interferes with chosen.num_decodes co-located decodes. pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens (smaller, because src has more cache) interferes with src.num_decodes co-located decodes, plus we pay RDMA transfer of src.cache_hit blocks to chosen. We migrate only when local cost > pd-sep cost + safety margin AND a set of hard gates (size, cache, decodes) are met. Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None the handler should do local routing on `chosen`. When pd_sep is (src_inst, src_idx) the handler should do prefill-on-src, decode-on-chosen via Mooncake. """ chosen, chosen_idx, decision = pick_instance_unified_hybrid( instances, token_ids, session_id, input_length, affinity) decision["v2_pd_sep"] = False decision["v2_decision"] = "local" decision["v2_reason"] = None if not token_ids: decision["v2_reason"] = "no_token_ids" return chosen, chosen_idx, decision, None chosen_cache_hit = chosen.estimate_cache_hit(token_ids) new_local = max(0, input_length - chosen_cache_hit) # Hard gate 1: prefill must be large enough that interference # outweighs the fixed RDMA setup cost. if new_local < SETTINGS.pd_sep_min_new_tokens: decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})" return chosen, chosen_idx, decision, None # Hard gate 2: chosen must have live decoding work to protect. # v2.1 simplification: pure ongoing_decode_tokens check. The previous # gate combined num_requests and decode_tokens with AND, but # num_requests includes requests still in prefill — adding a prefill # to a chosen that has only its own prefill running doesn't disrupt # any decode, so skipping makes sense. The right semantic is "skip # iff no decode is currently happening on chosen". if chosen.ongoing_decode_tokens == 0: decision["v2_reason"] = ( f"chosen_no_active_decode " f"(num_req={chosen.num_requests} decode_tok={chosen.ongoing_decode_tokens})" ) return chosen, chosen_idx, decision, None # Find best alternative cache source. best_src_idx, best_src_hit = -1, 0 for i, inst in enumerate(instances): if i == chosen_idx: continue h = inst.estimate_cache_hit(token_ids) if h > best_src_hit: best_src_idx, best_src_hit = i, h # Hard gate 3: src must hold meaningful cache. if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens: decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})" return chosen, chosen_idx, decision, None # Hard gate 4: src must hold materially more cache than chosen. if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens: decision["v2_reason"] = ( f"src_not_meaningfully_more_cache " f"(src={best_src_hit} chosen={chosen_cache_hit})" ) return chosen, chosen_idx, decision, None src = instances[best_src_idx] new_src = max(0, input_length - best_src_hit) # Cost-benefit in seconds-of-decode-disruption. cost_local = estimate_same_worker_interference_s( new_local, chosen.num_requests) cost_src_interf = estimate_same_worker_interference_s( new_src, src.num_requests) transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token cost_xfer = estimate_transfer_cost(transfer_bytes) cost_migrate = cost_src_interf + cost_xfer decision["v2_chosen_cache_hit"] = chosen_cache_hit decision["v2_src_idx"] = best_src_idx decision["v2_src_cache_hit"] = best_src_hit decision["v2_new_local"] = new_local decision["v2_new_src"] = new_src decision["v2_cost_local_s"] = cost_local decision["v2_cost_src_interf_s"] = cost_src_interf decision["v2_cost_xfer_s"] = cost_xfer decision["v2_cost_migrate_s"] = cost_migrate if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s: decision["v2_pd_sep"] = True decision["v2_decision"] = "pd_sep" decision["v2_reason"] = ( f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s " f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s" ) return chosen, chosen_idx, decision, (src, best_src_idx) decision["v2_reason"] = ( f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s " f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s" ) return chosen, chosen_idx, decision, None def _extract_output_token_ids_from_sse( buffer: str, chunk: bytes, ) -> tuple[str, list[int]]: """Extract vLLM streaming token_ids while preserving the raw stream.""" buffer += chunk.decode("utf-8", errors="ignore") complete = buffer.endswith("\n") or buffer.endswith("\r") lines = buffer.splitlines() if complete: buffer = "" elif lines: buffer = lines.pop() else: return buffer, [] output_ids: list[int] = [] for line in lines: line = line.strip() if not line.startswith("data:"): continue data = line[5:].strip() if not data or data == "[DONE]": continue try: payload = json.loads(data) except json.JSONDecodeError: continue choices = payload.get("choices", []) for choice in choices: token_ids = choice.get("token_ids") if isinstance(token_ids, list): output_ids.extend( int(t) for t in token_ids if isinstance(t, int) ) return buffer, output_ids def _realized_tokens( prompt_token_ids: list[int] | None, output_token_ids: list[int], ) -> list[int] | None: if prompt_token_ids is None: return None if not output_token_ids: return prompt_token_ids return prompt_token_ids + output_token_ids global_args = None combined_instances: list[InstanceState] = [] prefill_instances: list[InstanceState] = [] decode_instances: list[InstanceState] = [] # Session affinity is namespace-isolated: combined-mode and pd-sep mode index # different instance lists, so a shared dict could mis-route after a mode switch. session_affinity_combined: dict[str, int] = {} session_affinity_prefill: dict[str, int] = {} # Backwards-compat alias used by /stats etc. session_affinity = session_affinity_combined is_pd_sep = False _breakdown_log: list[dict] = [] _worker_state_log: list[dict] = [] async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event): for inst in instances: if inst.bootstrap_port is None: continue while True: try: await inst.client.get("/health") except Exception: await asyncio.sleep(1) continue parsed = urllib.parse.urlparse(str(inst.client.base_url)) url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query" resp = await inst.client.get(url) resp.raise_for_status() data = resp.json() for dp_rank, dp_entry in data.items(): inst.engine_id[int(dp_rank)] = dp_entry["engine_id"] inst.dp_size = len(data) print(f"Inited {inst.url} engine_ids={inst.engine_id}") break ready.set() async def _fetch_vllm_inflight(inst: "InstanceState") -> tuple[int, int] | None: """Read vLLM's truth: (num_running, num_waiting). Returns None on failure.""" try: resp = await asyncio.wait_for(inst.client.get("/metrics"), timeout=5.0) if resp.status_code != 200: return None text = resp.text except Exception: return None running = 0 waiting = 0 for line in text.splitlines(): if line.startswith("vllm:num_requests_running"): try: running = int(float(line.split()[-1])) except (ValueError, IndexError): pass elif line.startswith("vllm:num_requests_waiting"): try: waiting = int(float(line.split()[-1])) except (ValueError, IndexError): pass return running, waiting async def _reconcile_loop(): """Periodic shadow-state reconciliation against vLLM /metrics truth. The proxy maintains shadow counters (num_requests, ongoing_tokens, pending_prefill_tokens, ongoing_decode_tokens) by incrementing in `_handle_local_request` and decrementing in the generator's finally block. When the generator never enters (client disconnect between StreamingResponse construction and Starlette starting iteration, or Starlette failing before iteration), the decrement never fires and the counter stays elevated forever. Over a long run the shadow accumulates "phantom" load that biases routing decisions away from the affected instance. Two-pass fix: 1. Clamp negatives (defensive; rare in practice). 2. Sample vLLM's actual num_running + num_waiting via /metrics. If the proxy's num_requests has been *higher* than vLLM's truth for two consecutive cycles, reconcile downward to vLLM's count. Two-cycle persistence avoids correcting transient mismatches (e.g., proxy just incremented but vLLM hasn't scheduled the request yet). Cycle period: 30 s. Two-cycle persistence threshold: 60 s of stable drift before correction. """ prev_phantom: dict[str, int] = {} while True: try: await asyncio.sleep(30) except asyncio.CancelledError: return for inst in combined_instances + prefill_instances + decode_instances: # Pass 1: clamp negatives (cheap, always do). if inst.ongoing_tokens < 0: inst.ongoing_tokens = 0 if inst.ongoing_decode_tokens < 0: inst.ongoing_decode_tokens = 0 if inst.pending_prefill_tokens < 0: inst.pending_prefill_tokens = 0 if inst.num_requests < 0: inst.num_requests = 0 if inst.active_p_offloads < 0: inst.active_p_offloads = 0 # Pass 2: detect phantom positives by polling vLLM truth. metrics = await _fetch_vllm_inflight(inst) if metrics is None: continue running, waiting = metrics actual_inflight = running + waiting phantom = inst.num_requests - actual_inflight prev = prev_phantom.get(inst.url, 0) if phantom > 0 and prev > 0: # Drift held across two consecutive cycles (~60 s). # Reconcile shadow to vLLM's truth. old_num = inst.num_requests inst.num_requests = actual_inflight if actual_inflight == 0: # No requests in flight; zero all per-request counters. inst.ongoing_tokens = 0 inst.ongoing_decode_tokens = 0 inst.pending_prefill_tokens = 0 print( f"[reconcile] {inst.url}: phantom drift " f"num_requests {old_num} -> {actual_inflight} " f"(vllm running={running} waiting={waiting})" ) prev_phantom[inst.url] = phantom def _verify_vllm_patch(): """Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch. The patch turns an `assert req_id in self.requests` into a soft warn so that engines do not crash on the KV-transfer abort race (see REPORT §3.x). If somebody upgrades vLLM without re-applying the patch, the assert returns and elastic mode dies under load. Print a loud warning so we catch the regression before the first HEAVY request. """ try: import inspect from vllm.v1.core.sched.scheduler import Scheduler src = inspect.getsource(Scheduler) if "assert req_id in self.requests" in src: print("WARNING: vLLM scheduler still contains the unpatched " "`assert req_id in self.requests` line; expect engine " "death on KV-transfer abort race. Apply " "patches/0001-fix-kv-transfer-abort-race.patch.") else: print("vLLM patch self-check: kv-transfer-abort assert is patched.") except Exception as exc: print(f"vLLM patch self-check skipped: {exc!r}") @asynccontextmanager async def lifespan(app: FastAPI): global is_pd_sep app.state.ready = asyncio.Event() _verify_vllm_patch() reconcile_task = asyncio.create_task(_reconcile_loop()) if global_args.combined: is_pd_sep = False bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else [] for i, url in enumerate(global_args.combined): bp = bp_list[i] if i < len(bp_list) else None combined_instances.append(InstanceState(url, bp)) # Bootstrap combined instances for offload (need engine_ids for KV transfer) policy = getattr(global_args, 'policy', 'linear') # Mooncake-based modes still need bootstrap discovery; NIXL uses # its own UCX side-channel and doesn't go through our proxy # bootstrap path (and unified_nixl_both never PD-seps anyway). needs_bootstrap = ( global_args.offload or policy in ("unified_v2", "unified_kv_both") ) if needs_bootstrap and bp_list: await init_prefill_bootstrap(combined_instances, app.state.ready) elif needs_bootstrap and not bp_list: raise RuntimeError( f"--policy {policy} requires --bootstrap-ports for KV transfer; " "got empty bootstrap list." ) else: app.state.ready.set() policy = getattr(global_args, 'policy', 'linear') print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}") else: is_pd_sep = True for url, bp in global_args.prefill: prefill_instances.append(InstanceState(url, bp)) for url in global_args.decode: decode_instances.append(InstanceState(url)) await init_prefill_bootstrap(prefill_instances, app.state.ready) print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D") yield reconcile_task.cancel() try: await reconcile_task except asyncio.CancelledError: pass for inst in combined_instances + prefill_instances + decode_instances: await inst.client.aclose() app = FastAPI(lifespan=lifespan) @app.post("/v1/completions") async def handle_completions(request: Request): return await _handle(request, "/v1/completions") @app.post("/v1/chat/completions") async def handle_chat(request: Request): return await _handle(request, "/v1/chat/completions") async def _handle(request: Request, api: str): if not app.state.ready.is_set(): raise HTTPException(status_code=503, detail="Service Unavailable") req_data = await request.json() incoming_rid = request.headers.get("X-Request-Id") request_id = incoming_rid or str(uuid.uuid4()) prompt = req_data.get("prompt") token_ids = prompt if isinstance(prompt, list) else None input_length = len(token_ids) if token_ids else 0 session_id = request.headers.get("X-Session-Id") headers = {"X-Request-Id": request_id} api_key = os.environ.get("OPENAI_API_KEY") if api_key: headers["Authorization"] = f"Bearer {api_key}" if is_pd_sep: return await _handle_pd_sep(api, req_data, request_id, token_ids, input_length, session_id, headers) else: return await _handle_combined(api, req_data, token_ids, input_length, session_id, headers) async def _handle_local_request(api, req_data, headers, token_ids, input_length, chosen: InstanceState, estimated_new: int, breakdown: dict): breakdown.setdefault("route_class", "LOCAL") breakdown.setdefault("routed_to", chosen.url) chosen.ongoing_tokens += input_length chosen.pending_prefill_tokens += estimated_new chosen.num_requests += 1 async def generate(): prefill_done = False sse_buffer = "" output_token_ids: list[int] = [] try: for attempt in range(MAX_STREAM_RETRIES): try: async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(): sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( sse_buffer, chunk) output_token_ids.extend(new_output_ids) if not prefill_done: chosen.pending_prefill_tokens -= estimated_new chosen.ongoing_decode_tokens += input_length breakdown["t_first_token"] = _time.monotonic() breakdown["t_first_token_unix"] = _time.time() prefill_done = True yield chunk chosen.record_prefix( _realized_tokens(token_ids, output_token_ids)) break except (httpx.ConnectError, httpx.RemoteProtocolError): if prefill_done or attempt >= MAX_STREAM_RETRIES - 1: raise await asyncio.sleep(RETRY_DELAY_S) finally: if not prefill_done: chosen.pending_prefill_tokens -= estimated_new else: chosen.ongoing_decode_tokens -= input_length chosen.ongoing_tokens -= input_length chosen.num_requests -= 1 breakdown["t_done"] = _time.monotonic() breakdown["t_done_unix"] = _time.time() _breakdown_log.append(breakdown) return StreamingResponse(generate(), media_type="text/event-stream") async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers): """Route a /v1/* request among combined (PD-colocated) instances. --policy options: linear: cache_hit-aware load score + sticky session affinity. lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity. unified: hybrid — stick to affinity instance when cache_ratio > 0.5 and it is not overloaded; otherwise fall back to LMetric with a multi-key tie-breaker. PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants both regressed E2E tail). Re-enabling requires a new transfer mechanism. """ policy = getattr(global_args, 'policy', 'linear') t_decision_unix = _time.time() request_id = headers.get("X-Request-Id", "") breakdown: dict = { "request_id": request_id, "session_id": session_id, "input_length": input_length, "t_proxy_recv": _time.monotonic(), "t_decision_unix": t_decision_unix, "policy": policy, } pre_decision_workers = snapshot_workers( combined_instances, token_ids, input_length) pd_sep_v2: tuple[InstanceState, int] | None = None if policy == "lmetric": chosen, best_idx = pick_instance_lmetric( combined_instances, token_ids, session_id, input_length, session_affinity_combined) elif policy == "load_only": chosen, best_idx = pick_instance_load_only( combined_instances, token_ids, session_id, input_length, session_affinity_combined) elif policy == "sticky": chosen, best_idx = pick_instance_sticky( combined_instances, token_ids, session_id, input_length, session_affinity_combined) elif policy in ("unified", "unified_kv_both", "unified_nixl_both"): # unified_kv_both: same picker as `unified`, but the vLLMs are # launched in kv_role=kv_both with MooncakeConnector. Use this # as an isolation control for `unified_v2` so the v2-vs-v1 gap # reflects only the PD-sep branch, not the kv_both always-on # overhead. # unified_nixl_both: identical to unified_kv_both but with # NixlConnector at the vLLM layer. Used to attribute the # kv_both overhead to either Mooncake-specific code or a # generic v1-connector cost. chosen, best_idx, decision = pick_instance_unified_hybrid( combined_instances, token_ids, session_id, input_length, session_affinity_combined) breakdown.update(decision) if session_id: session_affinity_combined[session_id] = best_idx elif policy == "unified_v2": chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2( combined_instances, token_ids, session_id, input_length, session_affinity_combined) breakdown.update(decision) if session_id: session_affinity_combined[session_id] = best_idx else: # linear (default) chosen, best_idx = pick_instance( combined_instances, token_ids, session_id, input_length, session_affinity_combined) chosen_snap = pre_decision_workers[best_idx] cache_hit = chosen_snap["cache_hit"] estimated_new = chosen_snap["new_prefill"] breakdown.update({ "cache_hit": cache_hit, "estimated_new_tokens": estimated_new, "route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2", "routed_to": chosen.url, "chosen_idx": best_idx, "candidate_scores": pre_decision_workers, "chosen_score_linear": chosen_snap["score_linear"], "chosen_score_lmetric": chosen_snap["score_lmetric"], }) _worker_state_log.append({ "t_decision_unix": t_decision_unix, "request_id": request_id, "session_id": session_id, "policy": policy, "chosen_idx": best_idx, "v2_pd_sep": pd_sep_v2 is not None, "workers": pre_decision_workers, }) if pd_sep_v2 is not None: src_inst, src_idx = pd_sep_v2 breakdown["v2_src_url"] = src_inst.url breakdown["v2_src_idx"] = src_idx return await _handle_combined_pd_sep_v2( api, req_data, headers, token_ids, input_length, src_inst, chosen, breakdown, request_id=request_id) return await _handle_local_request( api, req_data, headers, token_ids, input_length, chosen, estimated_new, breakdown) async def _handle_combined_pd_sep_v2( api, req_data, headers, token_ids, input_length, src: InstanceState, dst: InstanceState, breakdown: dict, *, request_id: str, ): """Per-request PD-sep among combined instances (unified_v2 path). src does cached prefill (max_tokens=1) and ships KV to dst via Mooncake; dst pulls KV and decodes. Both instances must run in kv_role=kv_both with bootstrap server enabled. Patch 6.6: the dst streaming call uses a per-chunk read timeout of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails the request instead of hanging for 600 s. """ if src.bootstrap_port is None: raise HTTPException( status_code=500, detail=( "unified_v2 PD-sep triggered but src instance " f"{src.url} has no bootstrap_port; launch with " "kv_role=kv_both and pass --bootstrap-ports" ), ) # Reserve load on both endpoints. src.ongoing_tokens += input_length src.num_requests += 1 dst.ongoing_tokens += input_length dst.num_requests += 1 src_load_held = True dst_load_held = True prefill_data = req_data.copy() prefill_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, "transfer_id": f"xfer-{request_id}", } prefill_data["stream"] = False prefill_data["max_tokens"] = 1 prefill_data["min_tokens"] = 1 prefill_data.pop("max_completion_tokens", None) prefill_data.pop("stream_options", None) p_headers = {**headers, "X-data-parallel-rank": "0"} breakdown["t_prefill_sent"] = _time.monotonic() breakdown["t_prefill_sent_unix"] = _time.time() try: resp = await src.client.post(api, json=prefill_data, headers=p_headers) breakdown["t_prefill_done"] = _time.monotonic() breakdown["t_prefill_done_unix"] = _time.time() resp.raise_for_status() await resp.aclose() src.record_prefix(token_ids) except Exception as e: breakdown["t_prefill_done"] = _time.monotonic() breakdown["t_prefill_done_unix"] = _time.time() breakdown["prefill_error"] = True breakdown["error_detail"] = repr(e)[:300] _breakdown_log.append(breakdown) # Release reservations on failure. src.ongoing_tokens -= input_length src.num_requests -= 1 dst.ongoing_tokens -= input_length dst.num_requests -= 1 raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") finally: if src_load_held: src.ongoing_tokens -= input_length src.num_requests -= 1 src_load_held = False parsed = urllib.parse.urlparse(str(src.client.base_url)) bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}" decode_data = req_data.copy() decode_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, "remote_bootstrap_addr": bootstrap_addr, "remote_engine_id": src.engine_id.get(0, ""), "transfer_id": f"xfer-{request_id}", } breakdown["t_decode_sent"] = _time.monotonic() breakdown["t_decode_sent_unix"] = _time.time() xfer_timeout = httpx.Timeout( connect=10.0, read=SETTINGS.pd_sep_xfer_timeout_s, write=10.0, pool=10.0, ) async def generate(): nonlocal dst_load_held first_token = True sse_buffer = "" output_token_ids: list[int] = [] try: async with dst.client.stream( "POST", api, json=decode_data, headers=headers, timeout=xfer_timeout, ) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(): sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( sse_buffer, chunk) output_token_ids.extend(new_output_ids) if first_token: breakdown["t_first_token"] = _time.monotonic() breakdown["t_first_token_unix"] = _time.time() first_token = False yield chunk dst.record_prefix(_realized_tokens(token_ids, output_token_ids)) finally: breakdown["t_done"] = _time.monotonic() breakdown["t_done_unix"] = _time.time() if dst_load_held: dst.ongoing_tokens -= input_length dst.num_requests -= 1 dst_load_held = False _breakdown_log.append(breakdown) return StreamingResponse(generate(), media_type="text/event-stream") async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length, session_id, headers): """PD-Sep mode with per-stage breakdown profiling.""" t_decision_unix = _time.time() breakdown = { "request_id": request_id, "session_id": session_id, "input_length": input_length, "t_proxy_recv": _time.monotonic(), "t_decision_unix": t_decision_unix, "policy": "pd_sep", } pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length) pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length) p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id, input_length, session_affinity_prefill) d_idx = min(range(len(decode_instances)), key=lambda i: decode_instances[i].ongoing_tokens) d_inst = decode_instances[d_idx] breakdown["p_inst"] = p_inst.url breakdown["d_inst"] = d_inst.url breakdown["candidate_scores_prefill"] = pre_decision_p breakdown["candidate_scores_decode"] = pre_decision_d breakdown["chosen_p_idx"] = p_idx breakdown["chosen_d_idx"] = d_idx _worker_state_log.append({ "t_decision_unix": t_decision_unix, "request_id": request_id, "session_id": session_id, "policy": "pd_sep", "chosen_p_idx": p_idx, "chosen_d_idx": d_idx, "workers_prefill": pre_decision_p, "workers_decode": pre_decision_d, }) prefill_data = req_data.copy() prefill_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, "transfer_id": f"xfer-{request_id}", } prefill_data["stream"] = False prefill_data["max_tokens"] = 1 prefill_data["min_tokens"] = 1 prefill_data.pop("max_completion_tokens", None) prefill_data.pop("stream_options", None) p_headers = {**headers, "X-data-parallel-rank": "0"} p_inst.ongoing_tokens += input_length p_inst.num_requests += 1 breakdown["t_prefill_sent"] = _time.monotonic() breakdown["t_prefill_sent_unix"] = _time.time() try: resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) breakdown["t_prefill_done"] = _time.monotonic() breakdown["t_prefill_done_unix"] = _time.time() resp.raise_for_status() await resp.aclose() p_inst.record_prefix(token_ids) except Exception as e: breakdown["t_prefill_done"] = _time.monotonic() breakdown["t_prefill_done_unix"] = _time.time() breakdown["prefill_error"] = True _breakdown_log.append(breakdown) raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") finally: p_inst.ongoing_tokens -= input_length p_inst.num_requests -= 1 # Send decode d_inst.ongoing_tokens += input_length d_inst.num_requests += 1 parsed = urllib.parse.urlparse(str(p_inst.client.base_url)) bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}" decode_data = req_data.copy() decode_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, "remote_bootstrap_addr": bootstrap_addr, "remote_engine_id": p_inst.engine_id.get(0, ""), "transfer_id": f"xfer-{request_id}", } breakdown["t_decode_sent"] = _time.monotonic() breakdown["t_decode_sent_unix"] = _time.time() async def generate(): first_token = True sse_buffer = "" output_token_ids: list[int] = [] try: async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(): sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( sse_buffer, chunk) output_token_ids.extend(new_output_ids) if first_token: breakdown["t_first_token"] = _time.monotonic() breakdown["t_first_token_unix"] = _time.time() first_token = False yield chunk d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids)) finally: breakdown["t_done"] = _time.monotonic() breakdown["t_done_unix"] = _time.time() d_inst.ongoing_tokens -= input_length d_inst.num_requests -= 1 _breakdown_log.append(breakdown) return StreamingResponse(generate(), media_type="text/event-stream") @app.get("/breakdown") async def get_breakdown(): """Return per-request breakdown data for analysis.""" return _breakdown_log @app.get("/worker_state") async def get_worker_state(): """Return per-decision worker-state snapshot log (one entry per route decision).""" return _worker_state_log @app.get("/worker_state/latest") async def get_worker_state_latest(): """Return current per-worker state snapshot without recording it.""" if combined_instances: return { "t_unix": _time.time(), "mode": "combined", "workers": snapshot_workers(combined_instances), } return { "t_unix": _time.time(), "mode": "pd_sep", "workers_prefill": snapshot_workers(prefill_instances), "workers_decode": snapshot_workers(decode_instances), } @app.get("/stats") async def get_stats(): """Return per-instance live state for debugging.""" instances = combined_instances or prefill_instances + decode_instances return [{ "url": inst.url, "role": "combined", "ongoing_tokens": inst.ongoing_tokens, "pending_prefill_tokens": inst.pending_prefill_tokens, "ongoing_decode_tokens": inst.ongoing_decode_tokens, "num_requests": inst.num_requests, "active_p_offloads": inst.active_p_offloads, "cached_blocks": len(inst.cached_blocks), } for inst in instances] def parse_args(): p = argparse.ArgumentParser(description="Unified cache-aware global scheduler") p.add_argument("--port", type=int, default=8000) p.add_argument("--host", type=str, default="0.0.0.0") p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs") p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw", help="PD-Sep prefill: URL [bootstrap_port]") p.add_argument("--decode", nargs=1, action="append", dest="decode_raw", help="PD-Sep decode: URL") p.add_argument("--heavy-threshold", type=int, default=20000, help="New tokens threshold for HEAVY classification (adaptive offload)") p.add_argument("--offload", action="store_true", help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)") p.add_argument("--bootstrap-ports", type=str, default="", help="Comma-separated bootstrap ports for combined instances (for offload mode)") p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric", "load_only", "sticky", "unified", "unified_kv_both", "unified_nixl_both", "unified_v2"], help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), " "load_only (B3 control: pure min-num_requests), " "sticky (B3 control: hard session affinity), " "unified (hybrid affinity + LMetric fallback), " "unified_kv_both (unified picker on kv_both Mooncake " "vLLMs; isolation control for unified_v2), " "unified_nixl_both (same as unified_kv_both but using " "NixlConnector instead of MooncakeConnector; isolates " "connector implementation from policy effect), " "or unified_v2 (unified + selective per-request PD-sep " "via Mooncake; requires --bootstrap-ports and " "kv_role=kv_both vLLM launch)") p.add_argument("--overload-factor", type=float, default=2.0, help="Break session affinity when instance load > factor * avg") # The four flags below are accepted for bench.sh backward compatibility but # have no effect after the PD-sep offload path was retired (REPORT §3.9, # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and # scripts/legacy/*.sh which still pass them through. p.add_argument("--max-offload-inflight", type=int, default=4, help="[DEPRECATED] PUSH offload retired; no effect") p.add_argument("--offload-mode", type=str, default="cached_prefill", choices=["direct_read", "cached_prefill"], help="[DEPRECATED] PUSH offload retired; no effect") p.add_argument("--cache-gate-ratio", type=float, default=0.0, help="[DEPRECATED] PUSH offload retired; no effect") p.add_argument("--decode-iteration-s", type=float, default=0.05, help="[DEPRECATED] PUSH offload retired; no effect") args = p.parse_args() args.prefill = [] if args.prefill_raw: for entry in args.prefill_raw: url = entry[0] bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None args.prefill.append((url, bp)) args.decode = [e[0] for e in (args.decode_raw or [])] if not args.combined and not args.prefill: p.error("Must specify either --combined or --prefill/--decode") return args if __name__ == "__main__": global_args = parse_args() SETTINGS.heavy_threshold = global_args.heavy_threshold SETTINGS.overload_factor = global_args.overload_factor SETTINGS.max_offload_inflight = global_args.max_offload_inflight SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05) print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % ( SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s, getattr(global_args, 'offload', False))) uvicorn.run(app, host=global_args.host, port=global_args.port)