diff --git a/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py b/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py new file mode 100644 index 0000000..27a607b --- /dev/null +++ b/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py @@ -0,0 +1,1770 @@ +"""Unified cache-aware + token-level load-balanced global scheduler. + +Supports two modes: + --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer) + --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer) + +Routing policies (--policy): + linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens + lmetric: score = P_tokens * BS (LMetric, OSDI'26) + P_tokens = pending_prefill_tokens + new_uncached_tokens + BS = num_requests (waiting + running) + Session affinity: multi-turn sessions stick to same instance (all policies). +""" + +import argparse +import asyncio +import json +import os +import time as _time +import urllib.parse +import uuid +from collections import OrderedDict, deque +from contextlib import asynccontextmanager +from dataclasses import dataclass + +import httpx + +MAX_STREAM_RETRIES = 3 +RETRY_DELAY_S = 0.5 +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse + +BLOCK_SIZE = 512 +CACHE_HIT_ALPHA = 1.0 + + +@dataclass +class Settings: + """Runtime-tunable knobs. Populated from argparse in __main__. + + All routing/offload code reads from the SETTINGS singleton so that + CLI overrides survive even when the module is imported as a library + (e.g. by tests/) and __main__ does not run. + """ + prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20) + rdma_overhead_s: float = 0.1 # legacy floor; v2 uses estimate_transfer_cost + cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks + heavy_threshold: int = 20000 + overload_factor: float = 2.0 + max_offload_inflight: int = 4 + cache_gate_ratio: float = 0.0 + decode_iteration_s: float = 0.05 # per-request decode iteration cost (H20) + + # --- Patch 6.9: cost-model calibration for unified_v2 --- + # Throughput when the engine runs in kv_both mode. Lower than the + # pure-decode 7000 tok/s because kv_both adds always-on overhead + # (REPORT §3.8 documents ~+16% TPOT vs plain). + prefill_throughput_kv_both: float = 4000.0 + # Calibrated RDMA transfer cost: base + bandwidth term. + # Floor from isolated test ≈ 0.3 s (handshake + scheduler step). + # Bandwidth term reflects realized effective throughput, not + # theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s + # effective on the contended kv_both path. v2 uses this lookup + # rather than the constant rdma_overhead_s. + rdma_base_overhead_s: float = 0.3 + rdma_effective_gb_per_s: float = 2.7 + + # Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2): + # 2 × 48 × 4 × 128 × 2 = 98304 bytes per token. + kv_bytes_per_token: int = 98304 + + # --- unified_v2 gating knobs (relaxed in v2.1 after the v1 0.2% trigger rate) --- + # B2 microbench shows TPOT idx 1.9x already at new_tokens=8k and TTFT + # idx ~12x; the previous 16k threshold was too conservative and + # rejected 88.7% of candidates (window_1_results/v2_breakdown). + pd_sep_min_new_tokens: int = 8000 + pd_sep_min_decodes_protected: int = 1 # any in-flight work on chosen counts + pd_sep_min_src_cache_tokens: int = 4000 # half a block; was 8000 + pd_sep_min_extra_cache_tokens: int = 2000 # half a block; was 4000 + pd_sep_margin_s: float = 0.2 # require cost gap > 0.2 s before migrating + # Patch 6.6: per-request KV-xfer wall-clock timeout (proxy side). + pd_sep_xfer_timeout_s: float = 60.0 + + # --- unified_v3 (offload-decode) gating knobs ----------------------- + # v3 differs from v2 in *direction*: prefill stays on the session- + # affinity host (which holds the prefix cache); decode is migrated to + # a less-loaded target. KV transfer flows prefill_host → decode_target. + # The target doesn't need cache — we're shipping the post-prefill KV + # over anyway. After successful migration the session affinity table + # rotates to decode_target so the *next* turn lands where the KV now + # lives. + v3_min_new_tokens: int = 8000 # same as v2: don't migrate tiny prefills + v3_min_prefill_decode_busy: int = 1 # prefill_host must have ≥ this many concurrent decode tokens to justify migrating + v3_target_load_ratio: float = 0.7 # target.num_requests must be < prefill_host.num_requests × this + v3_min_load_gap: int = 1 # target.num_requests must also be ≤ prefill_host - this (absolute slack) + v3_rotate_affinity: bool = True # after migration, set session affinity to decode_target. + # Empirically False is better — see cache_miss_audit (next turn hits 9.5% + # with rotation vs ~80% without), because delay_free_blocks doesn't + # actually preserve cross-turn KV on decode_target. + v3_prefer_cache_target: bool = True # Mechanism B: among low-load candidates, prefer the one + # with the most prefix cache for this prompt — vLLM's connector + # auto-transfers only the missing portion (verified via + # smoke_partial_transfer: cache-rich dst is 77% faster than + # cold dst at 33k tokens, +512 ext). + # Anti-hotspot: picker scores effective_load = num_requests + (recent + # migrations received within window). Prevents clustering migrations on + # one instance in rapid succession (observed in Mech B run: inst_5 became + # a hotspot via post-rotation tail accumulation). + v3_recent_mig_window_s: float = 10.0 # sliding window + v3_recent_mig_weight: float = 1.0 # how many "virtual requests" each + # recent migration counts as + + # Direction B knob: LMetric fallback adds decode-token penalty to score. + # score = (pending_prefill + new + lmetric_decode_weight * ongoing_decode_tok) * num_req + # Empirical iter-time slope on H100 + Qwen3-30B-A3B: each decode token in + # batch costs ~0.01 prefill-token-equivalent in scheduler time, so 0.01 is + # a reasonable starting weight. Set 0 to disable (original behavior). + lmetric_decode_weight: float = 0.0 + + # --- KV connector selection (governs PD-sep handshake) ------------- + # "mooncake": pre-baked kv_transfer_params (bootstrap_addr+engine_id+transfer_id). + # Requires --bootstrap-ports and vLLMs launched with MooncakeConnector. + # "nixl" : response-forward handshake. src returns kv_transfer_params via + # response body, proxy forwards to dst. Nixl auto-selects transport + # via UCX (CUDA IPC / NVLink on intra-node, RDMA across nodes). + connector_type: str = "mooncake" + + +SETTINGS = Settings() + + +def estimate_transfer_cost(transfer_bytes: int) -> float: + """Calibrated RDMA transfer cost as a function of bytes. + + Replaces the legacy constant rdma_overhead_s. Calibration sources: + - Floor: isolated-test ~0.3 s for a few-block PUSH (scripts/test_direct_read.py) + - Bandwidth term: outputs/contention_16s_elastic/breakdown.json shows + decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, giving + ~2.7 GB/s effective on the contended kv_both path. + + The p90 in that same run is 6.7 s (D-side block reservation + + scheduler step delays). v2's cost model uses the *median* — being + too pessimistic would suppress all PD-sep triggers. The risk of + underestimation is mitigated by the pd_sep_margin_s safety factor. + """ + base = SETTINGS.rdma_base_overhead_s + bw_term = transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024 ** 3) + return base + bw_term + + +def estimate_same_worker_interference_s( + new_tokens: int, + num_decodes: int, +) -> float: + """Estimated additional latency on `num_decodes` co-located decodes + when a `new_tokens`-token prefill runs on the same worker. + + Derived from B2 microbench (analysis/characterization/window_1_results.md): + same-worker prefill of size N steals decode capacity for the + prefill's duration. The penalty factor is the fraction of decode + steps stolen during the prefill window. + + For new_tokens < 4k: ~0.2 (chunked prefill leaves room) + For new_tokens 16k: ~0.5 (mid-regime, B2 TPOT idx 3.4×) + For new_tokens 32k: ~0.8 (B2 peak TPOT idx 7.9×) + For new_tokens > 32k: ~0.95 (B2 TTFT regime — decodes are nearly fully blocked) + + The cost in seconds is roughly: prefill_duration × penalty × n_decodes, + because each affected decode loses ~penalty fraction of its capacity + during the prefill window. + """ + if num_decodes <= 0: + return 0.0 + prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both + if new_tokens < 4000: + penalty = 0.2 + elif new_tokens < 16000: + penalty = 0.5 + elif new_tokens < 32000: + penalty = 0.8 + else: + penalty = 0.95 + return prefill_dur_s * penalty * num_decodes + + +class InstanceState: + def __init__(self, url: str, bootstrap_port: int | None = None): + self.url = url + self.bootstrap_port = bootstrap_port + self.client = httpx.AsyncClient( + timeout=None, base_url=url, + limits=httpx.Limits(max_connections=None, max_keepalive_connections=None), + ) + self.ongoing_tokens = 0 + self.ongoing_decode_tokens = 0 # subset: tokens in decode phase + self.pending_prefill_tokens = 0 # tokens for requests still in prefill + self.num_requests = 0 # total in-flight requests (waiting + running) + self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others + self.engine_id: dict[int, str] = {} + self.dp_size = 1 + # OrderedDict acts as an LRU keyed by block hash; value is unused. + self.cached_blocks: OrderedDict[int, None] = OrderedDict() + # v3 anti-hotspot: timestamps (monotonic) when this instance was picked + # as a v3 migration target. Used to compute effective_load = num_req + + # recent-migration count over a sliding window, preventing back-to-back + # decisions from clustering on the same dst. + self.recent_mig_targeted_at: deque[float] = deque(maxlen=64) + + def estimate_cache_hit(self, token_ids: list[int] | None) -> int: + if not token_ids or len(token_ids) < BLOCK_SIZE: + return 0 + hit = 0 + for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): + bh = hash(tuple(token_ids[i:i + BLOCK_SIZE])) + if bh in self.cached_blocks: + self.cached_blocks.move_to_end(bh) # LRU touch on hit + hit += BLOCK_SIZE + else: + break + return hit + + def record_prefix(self, token_ids: list[int] | None): + if not token_ids: + return + for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): + bh = hash(tuple(token_ids[i:i + BLOCK_SIZE])) + if bh in self.cached_blocks: + self.cached_blocks.move_to_end(bh) + else: + self.cached_blocks[bh] = None + if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks: + self.cached_blocks.popitem(last=False) + + +def _p_offload_penalty(inst: InstanceState) -> int: + """Penalty for PD-sep mode routing (legacy).""" + if inst.active_p_offloads <= 0: + return 0 + return inst.active_p_offloads * SETTINGS.heavy_threshold + + +def snapshot_workers( + instances: list[InstanceState], + token_ids: list[int] | None = None, + input_length: int = 0, +) -> list[dict]: + """Per-worker state at route-decision time. + + All routing-relevant counters plus the score each policy would + have produced for `input_length` if it were dispatched now. Cheap + enough to call on every request; B3 hot-spot analysis depends on + this being captured per decision. + """ + snap: list[dict] = [] + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) if token_ids else 0 + new_prefill = max(0, input_length - cache_hit) + snap.append({ + "idx": i, + "url": inst.url, + "ongoing_tokens": inst.ongoing_tokens, + "ongoing_decode_tokens": inst.ongoing_decode_tokens, + "pending_prefill_tokens": inst.pending_prefill_tokens, + "num_requests": inst.num_requests, + "active_p_offloads": inst.active_p_offloads, + "cached_blocks": len(inst.cached_blocks), + "cache_hit": cache_hit, + "new_prefill": new_prefill, + "score_linear": (inst.ongoing_tokens + + _p_offload_penalty(inst) + - CACHE_HIT_ALPHA * cache_hit), + "score_lmetric": (inst.pending_prefill_tokens + new_prefill) + * inst.num_requests, + }) + return snap + + +def pick_instance(instances: list[InstanceState], token_ids: list[int] | None, + session_id: str | None, input_length: int, + affinity: dict[str, int]) -> tuple[InstanceState, int]: + """Session-sticky with load-aware override. + + Turn 2+: use session affinity UNLESS pinned instance is overloaded + or busy with P-role offloads, in which case pick least-loaded. + Turn 1: pick instance with best score (load + cache combined). + Instances doing P-role offloads get a large penalty to steer + WARM/MEDIUM traffic away. + """ + avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0) + + if session_id and session_id in affinity: + idx = affinity[session_id] + if idx < len(instances): + inst = instances[idx] + if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor + and inst.active_p_offloads == 0): + return inst, idx + + best_idx, best_score = 0, float("inf") + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) + score = (inst.ongoing_tokens + _p_offload_penalty(inst) + - CACHE_HIT_ALPHA * cache_hit) + if score < best_score: + best_score = score + best_idx = i + + if session_id: + affinity[session_id] = best_idx + return instances[best_idx], best_idx + + +def pick_instance_load_only( + instances: list[InstanceState], + token_ids: list[int] | None, + session_id: str | None, + input_length: int, + affinity: dict[str, int], +) -> tuple[InstanceState, int]: + """Pure load balancing: pick instance with fewest in-flight requests. + + Ignores cache hits and session affinity. Used as a B3 control to + isolate the locality contribution of cache-aware policies. + """ + best_idx = min(range(len(instances)), + key=lambda i: instances[i].num_requests) + return instances[best_idx], best_idx + + +def pick_instance_sticky( + instances: list[InstanceState], + token_ids: list[int] | None, + session_id: str | None, + input_length: int, + affinity: dict[str, int], +) -> tuple[InstanceState, int]: + """Hard session affinity: once assigned, never break. + + First turn of a session picks the instance with the lowest + num_requests; subsequent turns always return to the same instance + regardless of load. Used as a B3 control to isolate the hot-spot + cost of perfect locality. + """ + if session_id and session_id in affinity: + idx = affinity[session_id] + if idx < len(instances): + return instances[idx], idx + best_idx = min(range(len(instances)), + key=lambda i: instances[i].num_requests) + if session_id: + affinity[session_id] = best_idx + return instances[best_idx], best_idx + + +def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None, + session_id: str | None, input_length: int, + affinity: dict[str, int]) -> tuple[InstanceState, int]: + """LMetric routing: score = P_tokens × BS (OSDI'26). + + Pure per-request load-based routing, no session affinity (the + session_id/affinity args are accepted for signature compatibility + with pick_instance/pick_instance_unified_hybrid but ignored). + P = pending_prefill_tokens + (input_length - cache_hit) + BS = num_requests (current batch size) + """ + best_idx, best_score = 0, float("inf") + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) + new_prefill = max(0, input_length - cache_hit) + p_tokens = inst.pending_prefill_tokens + new_prefill + bs = inst.num_requests + score = p_tokens * bs + if score < best_score: + best_score = score + best_idx = i + + return instances[best_idx], best_idx + + +_unified_fallback_rr_counter = 0 + + +def pick_instance_unified_hybrid( + instances: list[InstanceState], + token_ids: list[int] | None, + session_id: str | None, + input_length: int, + affinity: dict[str, int], +) -> tuple[InstanceState, int, dict]: + """Hybrid routing: high-cache affinity, else LMetric with tie-breaker. + + Affinity gate (both must hold to stick): + - affinity instance cache_hit / input_length > 0.5 + - affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor + + Fallback ordering (when affinity not used): + primary: score = P_tokens * BS (LMetric) + secondary: new_uncached_tokens (prefer instance with most cache) + tertiary: num_requests (prefer least-loaded) + quaternary: round-robin (avoid degenerate inst-0 pinning + when BS=0 across the board) + + Returns (chosen, idx, decision_dict). decision_dict carries the + review #7 breakdown fields so the caller can merge them verbatim. + """ + global _unified_fallback_rr_counter + n = len(instances) + avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0) + + decision: dict = { + "decision": "lmetric_fallback", + "affinity_idx": None, + "chosen_idx": None, + "affinity_cache_hit": None, + "affinity_cache_ratio": None, + "affinity_num_requests": None, + "avg_num_requests": avg_reqs, + "fallback_score": None, + "tie_break_used": False, + } + + if session_id and session_id in affinity: + a_idx = affinity[session_id] + if a_idx < n: + a_inst = instances[a_idx] + a_hit = a_inst.estimate_cache_hit(token_ids) + a_ratio = a_hit / max(input_length, 1) + decision["affinity_idx"] = a_idx + decision["affinity_cache_hit"] = a_hit + decision["affinity_cache_ratio"] = a_ratio + decision["affinity_num_requests"] = a_inst.num_requests + if (a_ratio > 0.5 + and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor): + decision["decision"] = "affinity" + decision["chosen_idx"] = a_idx + return a_inst, a_idx, decision + + # Direction B: extend LMetric with decode-load awareness. + # Original score = (pending_prefill + new_uncached) * num_requests, which + # ignores ongoing decode work. A host with 200k decode tokens looks "ideal" + # (P_tokens=0) but its decode iters are slow due to large batch KV reads. + # + # First attempt (BUG): score = (p_tokens + decode_pen) * num_req — when + # num_req=0 the decode_pen is zeroed out, so idle-but-decoding hosts still + # look free and accumulate cold prefills (8007 hotspot in A+B v1 run). + # + # Fix: max(num_req, 1) so decode_pen contributes on idle hosts too. + keys: list[tuple[float, int, int, int]] = [] + for i, inst in enumerate(instances): + cache_hit = inst.estimate_cache_hit(token_ids) + new_prefill = max(0, input_length - cache_hit) + p_tokens = inst.pending_prefill_tokens + new_prefill + decode_pen = SETTINGS.lmetric_decode_weight * inst.ongoing_decode_tokens + bs = inst.num_requests + score = (p_tokens + decode_pen) * max(bs, 1) + keys.append((score, new_prefill, bs, i)) + + best_triple = min(k[:3] for k in keys) + tied = [k for k in keys if k[:3] == best_triple] + if len(tied) > 1: + decision["tie_break_used"] = True + _unified_fallback_rr_counter += 1 + winner = tied[_unified_fallback_rr_counter % len(tied)] + else: + winner = tied[0] + chosen_idx = winner[3] + decision["fallback_score"] = winner[0] + decision["chosen_idx"] = chosen_idx + return instances[chosen_idx], chosen_idx, decision + + +def pick_instance_unified_v2( + instances: list[InstanceState], + token_ids: list[int] | None, + session_id: str | None, + input_length: int, + affinity: dict[str, int], +) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]: + """unified_v2 = unified hybrid + selective per-request PD-sep trigger. + + Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`. + + Stage 2 asks: is there another instance with materially more cache + for this request? If yes, would doing prefill on that instance and + transferring KV to `chosen` for decode be cheaper than just doing + everything on `chosen`? + + The cost model compares two scenarios in seconds-of-decode-disruption: + + local: same-worker prefill on chosen of (input - chosen.cache_hit) + tokens interferes with chosen.num_decodes co-located decodes. + + pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens + (smaller, because src has more cache) interferes with + src.num_decodes co-located decodes, plus we pay RDMA + transfer of src.cache_hit blocks to chosen. + + We migrate only when local cost > pd-sep cost + safety margin AND + a set of hard gates (size, cache, decodes) are met. + + Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None + the handler should do local routing on `chosen`. When pd_sep is + (src_inst, src_idx) the handler should do prefill-on-src, + decode-on-chosen via Mooncake. + """ + chosen, chosen_idx, decision = pick_instance_unified_hybrid( + instances, token_ids, session_id, input_length, affinity) + + decision["v2_pd_sep"] = False + decision["v2_decision"] = "local" + decision["v2_reason"] = None + + if not token_ids: + decision["v2_reason"] = "no_token_ids" + return chosen, chosen_idx, decision, None + + chosen_cache_hit = chosen.estimate_cache_hit(token_ids) + new_local = max(0, input_length - chosen_cache_hit) + + # Hard gate 1: prefill must be large enough that interference + # outweighs the fixed RDMA setup cost. + if new_local < SETTINGS.pd_sep_min_new_tokens: + decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})" + return chosen, chosen_idx, decision, None + + # Hard gate 2: chosen must have live decoding work to protect. + # v2.1 simplification: pure ongoing_decode_tokens check. The previous + # gate combined num_requests and decode_tokens with AND, but + # num_requests includes requests still in prefill — adding a prefill + # to a chosen that has only its own prefill running doesn't disrupt + # any decode, so skipping makes sense. The right semantic is "skip + # iff no decode is currently happening on chosen". + if chosen.ongoing_decode_tokens == 0: + decision["v2_reason"] = ( + f"chosen_no_active_decode " + f"(num_req={chosen.num_requests} decode_tok={chosen.ongoing_decode_tokens})" + ) + return chosen, chosen_idx, decision, None + + # Find best alternative cache source. + best_src_idx, best_src_hit = -1, 0 + for i, inst in enumerate(instances): + if i == chosen_idx: + continue + h = inst.estimate_cache_hit(token_ids) + if h > best_src_hit: + best_src_idx, best_src_hit = i, h + + # Hard gate 3: src must hold meaningful cache. + if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens: + decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})" + return chosen, chosen_idx, decision, None + + # Hard gate 4: src must hold materially more cache than chosen. + if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens: + decision["v2_reason"] = ( + f"src_not_meaningfully_more_cache " + f"(src={best_src_hit} chosen={chosen_cache_hit})" + ) + return chosen, chosen_idx, decision, None + + src = instances[best_src_idx] + new_src = max(0, input_length - best_src_hit) + + # Cost-benefit in seconds-of-decode-disruption. + cost_local = estimate_same_worker_interference_s( + new_local, chosen.num_requests) + cost_src_interf = estimate_same_worker_interference_s( + new_src, src.num_requests) + transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token + cost_xfer = estimate_transfer_cost(transfer_bytes) + cost_migrate = cost_src_interf + cost_xfer + + decision["v2_chosen_cache_hit"] = chosen_cache_hit + decision["v2_src_idx"] = best_src_idx + decision["v2_src_cache_hit"] = best_src_hit + decision["v2_new_local"] = new_local + decision["v2_new_src"] = new_src + decision["v2_cost_local_s"] = cost_local + decision["v2_cost_src_interf_s"] = cost_src_interf + decision["v2_cost_xfer_s"] = cost_xfer + decision["v2_cost_migrate_s"] = cost_migrate + + if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s: + decision["v2_pd_sep"] = True + decision["v2_decision"] = "pd_sep" + decision["v2_reason"] = ( + f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s " + f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s" + ) + return chosen, chosen_idx, decision, (src, best_src_idx) + + decision["v2_reason"] = ( + f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s " + f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s" + ) + return chosen, chosen_idx, decision, None + + +def pick_instance_unified_v3( + instances: list[InstanceState], + token_ids: list[int] | None, + session_id: str | None, + input_length: int, + affinity: dict[str, int], +) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]: + """unified_v3 = unified hybrid + selective DECODE migration. + + Direction-reversed from unified_v2: + - prefill stays on session-affinity host (`prefill_host`) so we keep + the 93%-intra-session prefix-cache reuse intact. + - decode is migrated to a lower-load `decode_target` when the + affinity host is busy with concurrent decodes. + - KV transfer flows prefill_host → decode_target (the opposite of + v2's src → chosen). + - target does NOT need pre-existing cache — we're shipping the + post-prefill KV over anyway. + - On successful migration the *caller* must rotate + `affinity[session_id] = decode_target_idx` so the next turn lands + where the KV now lives (decode_target retains the blocks after + completion, since mooncake defaults to delay_free_blocks=True). + + Decision is purely load-based on the target side: + 1. new_local ≥ v3_min_new_tokens (don't pay RDMA for tiny prefills) + 2. prefill_host.ongoing_decode_tokens ≥ v3_min_prefill_decode_busy + (the host is actually busy decoding; migration buys decode-bw) + 3. ∃ target with both num_requests < prefill_host.num_requests × ratio + and num_requests ≤ prefill_host.num_requests − v3_min_load_gap + + Returns (prefill_host, prefill_idx, decision, migrate). When migrate + is None the request is fully local on prefill_host. When migrate is + (decode_target_inst, decode_target_idx), the handler should run + prefill on prefill_host and ship KV to decode_target for decode. + """ + prefill_host, prefill_idx, decision = pick_instance_unified_hybrid( + instances, token_ids, session_id, input_length, affinity) + + decision["v3_migrate"] = False + decision["v3_decision"] = "local" + decision["v3_reason"] = None + + if not token_ids: + decision["v3_reason"] = "no_token_ids" + return prefill_host, prefill_idx, decision, None + + prefill_cache_hit = prefill_host.estimate_cache_hit(token_ids) + new_local = max(0, input_length - prefill_cache_hit) + decision["v3_prefill_cache_hit"] = prefill_cache_hit + decision["v3_new_local"] = new_local + + # Gate 1: prefill must be large enough to amortise RDMA setup. + if new_local < SETTINGS.v3_min_new_tokens: + decision["v3_reason"] = ( + f"new_local_below_threshold ({new_local} < {SETTINGS.v3_min_new_tokens})" + ) + return prefill_host, prefill_idx, decision, None + + # Gate 2: affinity host must be busy with concurrent decodes — that's + # what migrating decode-traffic-away buys us. If the host is idle + # there's no point. + if prefill_host.ongoing_decode_tokens < SETTINGS.v3_min_prefill_decode_busy: + decision["v3_reason"] = ( + f"prefill_host_not_busy " + f"(ongoing_decode_tokens={prefill_host.ongoing_decode_tokens} < " + f"{SETTINGS.v3_min_prefill_decode_busy})" + ) + return prefill_host, prefill_idx, decision, None + + # Gate 3: pick the lowest-effective-load target. effective_load adds a + # penalty for recent migrations the instance has received (anti-hotspot). + now_mono = _time.monotonic() + cutoff = now_mono - SETTINGS.v3_recent_mig_window_s + + def effective_load(inst): + # Drop expired entries lazily. + while inst.recent_mig_targeted_at and inst.recent_mig_targeted_at[0] < cutoff: + inst.recent_mig_targeted_at.popleft() + recent = len(inst.recent_mig_targeted_at) + return inst.num_requests + recent * SETTINGS.v3_recent_mig_weight + + threshold_loaded = max(1, + int(prefill_host.num_requests * SETTINGS.v3_target_load_ratio)) + candidates = [ + (i, inst) for i, inst in enumerate(instances) + if i != prefill_idx + and effective_load(inst) < threshold_loaded + and effective_load(inst) <= prefill_host.num_requests - SETTINGS.v3_min_load_gap + ] + if not candidates: + decision["v3_reason"] = ( + f"no_low_load_target " + f"(prefill_host.num_req={prefill_host.num_requests} " + f"threshold={threshold_loaded} " + f"eff_loads=[{','.join(f'{int(effective_load(i))}' for i in instances)}])" + ) + return prefill_host, prefill_idx, decision, None + + # Mechanism B (v3_prefer_cache_target=True): rank candidates first by + # cache_hit DESC (more cache = less KV to transfer), then by effective_load + # (which includes recent-migration penalty), then by ongoing_tokens. + if SETTINGS.v3_prefer_cache_target: + decode_target_idx, decode_target = min( + candidates, + key=lambda x: (-x[1].estimate_cache_hit(token_ids), + effective_load(x[1]), + x[1].ongoing_tokens)) + else: + decode_target_idx, decode_target = min( + candidates, key=lambda x: (effective_load(x[1]), x[1].ongoing_tokens)) + + target_cache_hit = decode_target.estimate_cache_hit(token_ids) + target_recent_received = len(decode_target.recent_mig_targeted_at) + # Record this decision for the anti-hotspot accounting. + decode_target.recent_mig_targeted_at.append(now_mono) + + decision["v3_migrate"] = True + decision["v3_decision"] = "migrate_decode" + decision["v3_src_idx"] = prefill_idx + decision["v3_target_idx"] = decode_target_idx + decision["v3_target_num_req"] = decode_target.num_requests + decision["v3_target_cache_hit"] = target_cache_hit + decision["v3_target_recent_received"] = target_recent_received + decision["v3_prefill_num_req"] = prefill_host.num_requests + # Snapshot of src state at the moment of decision (for postmortem). + decision["v3_src_state"] = { + "num_requests": prefill_host.num_requests, + "ongoing_tokens": prefill_host.ongoing_tokens, + "ongoing_decode_tokens": prefill_host.ongoing_decode_tokens, + "pending_prefill_tokens": prefill_host.pending_prefill_tokens, + } + decision["v3_target_state"] = { + "num_requests": decode_target.num_requests, + "ongoing_tokens": decode_target.ongoing_tokens, + "ongoing_decode_tokens": decode_target.ongoing_decode_tokens, + "pending_prefill_tokens": decode_target.pending_prefill_tokens, + "cache_hit_estimate": target_cache_hit, + "recent_mig_received_in_window": target_recent_received, + } + decision["v3_reason"] = ( + f"prefill_host.num_req={prefill_host.num_requests} busy; " + f"target.num_req={decode_target.num_requests} cache_hit={target_cache_hit} " + f"recent_received={target_recent_received}, " + f"transferring KV after prefill" + ) + return prefill_host, prefill_idx, decision, (decode_target, decode_target_idx) + + +def _extract_output_token_ids_from_sse( + buffer: str, + chunk: bytes, +) -> tuple[str, list[int]]: + """Extract vLLM streaming token_ids while preserving the raw stream.""" + buffer += chunk.decode("utf-8", errors="ignore") + complete = buffer.endswith("\n") or buffer.endswith("\r") + lines = buffer.splitlines() + if complete: + buffer = "" + elif lines: + buffer = lines.pop() + else: + return buffer, [] + + output_ids: list[int] = [] + for line in lines: + line = line.strip() + if not line.startswith("data:"): + continue + data = line[5:].strip() + if not data or data == "[DONE]": + continue + try: + payload = json.loads(data) + except json.JSONDecodeError: + continue + choices = payload.get("choices", []) + for choice in choices: + token_ids = choice.get("token_ids") + if isinstance(token_ids, list): + output_ids.extend( + int(t) for t in token_ids if isinstance(t, int) + ) + return buffer, output_ids + + +def _realized_tokens( + prompt_token_ids: list[int] | None, + output_token_ids: list[int], +) -> list[int] | None: + if prompt_token_ids is None: + return None + if not output_token_ids: + return prompt_token_ids + return prompt_token_ids + output_token_ids + + +global_args = None +combined_instances: list[InstanceState] = [] +prefill_instances: list[InstanceState] = [] +decode_instances: list[InstanceState] = [] +# Session affinity is namespace-isolated: combined-mode and pd-sep mode index +# different instance lists, so a shared dict could mis-route after a mode switch. +session_affinity_combined: dict[str, int] = {} +session_affinity_prefill: dict[str, int] = {} +# Backwards-compat alias used by /stats etc. +session_affinity = session_affinity_combined +is_pd_sep = False +_breakdown_log: list[dict] = [] +_worker_state_log: list[dict] = [] + + +async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event): + for inst in instances: + if inst.bootstrap_port is None: + continue + while True: + try: + await inst.client.get("/health") + except Exception: + await asyncio.sleep(1) + continue + parsed = urllib.parse.urlparse(str(inst.client.base_url)) + url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query" + resp = await inst.client.get(url) + resp.raise_for_status() + data = resp.json() + for dp_rank, dp_entry in data.items(): + inst.engine_id[int(dp_rank)] = dp_entry["engine_id"] + inst.dp_size = len(data) + print(f"Inited {inst.url} engine_ids={inst.engine_id}") + break + ready.set() + + +async def _fetch_vllm_inflight(inst: "InstanceState") -> tuple[int, int] | None: + """Read vLLM's truth: (num_running, num_waiting). Returns None on failure.""" + try: + resp = await asyncio.wait_for(inst.client.get("/metrics"), timeout=5.0) + if resp.status_code != 200: + return None + text = resp.text + except Exception: + return None + running = 0 + waiting = 0 + for line in text.splitlines(): + if line.startswith("vllm:num_requests_running"): + try: + running = int(float(line.split()[-1])) + except (ValueError, IndexError): + pass + elif line.startswith("vllm:num_requests_waiting"): + try: + waiting = int(float(line.split()[-1])) + except (ValueError, IndexError): + pass + return running, waiting + + +async def _reconcile_loop(): + """Periodic shadow-state reconciliation against vLLM /metrics truth. + + The proxy maintains shadow counters (num_requests, ongoing_tokens, + pending_prefill_tokens, ongoing_decode_tokens) by incrementing in + `_handle_local_request` and decrementing in the generator's finally + block. When the generator never enters (client disconnect between + StreamingResponse construction and Starlette starting iteration, or + Starlette failing before iteration), the decrement never fires and + the counter stays elevated forever. Over a long run the shadow + accumulates "phantom" load that biases routing decisions away from + the affected instance. + + Two-pass fix: + + 1. Clamp negatives (defensive; rare in practice). + 2. Sample vLLM's actual num_running + num_waiting via /metrics. If + the proxy's num_requests has been *higher* than vLLM's truth for + two consecutive cycles, reconcile downward to vLLM's count. + Two-cycle persistence avoids correcting transient mismatches + (e.g., proxy just incremented but vLLM hasn't scheduled the + request yet). + + Cycle period: 30 s. Two-cycle persistence threshold: 60 s of stable + drift before correction. + """ + prev_phantom: dict[str, int] = {} + while True: + try: + await asyncio.sleep(30) + except asyncio.CancelledError: + return + for inst in combined_instances + prefill_instances + decode_instances: + # Pass 1: clamp negatives (cheap, always do). + if inst.ongoing_tokens < 0: + inst.ongoing_tokens = 0 + if inst.ongoing_decode_tokens < 0: + inst.ongoing_decode_tokens = 0 + if inst.pending_prefill_tokens < 0: + inst.pending_prefill_tokens = 0 + if inst.num_requests < 0: + inst.num_requests = 0 + if inst.active_p_offloads < 0: + inst.active_p_offloads = 0 + + # Pass 2: detect phantom positives by polling vLLM truth. + metrics = await _fetch_vllm_inflight(inst) + if metrics is None: + continue + running, waiting = metrics + actual_inflight = running + waiting + phantom = inst.num_requests - actual_inflight + prev = prev_phantom.get(inst.url, 0) + if phantom > 0 and prev > 0: + # Drift held across two consecutive cycles (~60 s). + # Reconcile shadow to vLLM's truth. + old_num = inst.num_requests + inst.num_requests = actual_inflight + if actual_inflight == 0: + # No requests in flight; zero all per-request counters. + inst.ongoing_tokens = 0 + inst.ongoing_decode_tokens = 0 + inst.pending_prefill_tokens = 0 + print( + f"[reconcile] {inst.url}: phantom drift " + f"num_requests {old_num} -> {actual_inflight} " + f"(vllm running={running} waiting={waiting})" + ) + prev_phantom[inst.url] = phantom + + +def _verify_vllm_patch(): + """Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch. + + The patch turns an `assert req_id in self.requests` into a soft warn so + that engines do not crash on the KV-transfer abort race (see REPORT + §3.x). If somebody upgrades vLLM without re-applying the patch, the + assert returns and elastic mode dies under load. Print a loud warning + so we catch the regression before the first HEAVY request. + """ + try: + import inspect + from vllm.v1.core.sched.scheduler import Scheduler + src = inspect.getsource(Scheduler) + if "assert req_id in self.requests" in src: + print("WARNING: vLLM scheduler still contains the unpatched " + "`assert req_id in self.requests` line; expect engine " + "death on KV-transfer abort race. Apply " + "patches/0001-fix-kv-transfer-abort-race.patch.") + else: + print("vLLM patch self-check: kv-transfer-abort assert is patched.") + except Exception as exc: + print(f"vLLM patch self-check skipped: {exc!r}") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + global is_pd_sep + app.state.ready = asyncio.Event() + + _verify_vllm_patch() + + reconcile_task = asyncio.create_task(_reconcile_loop()) + + if global_args.combined: + is_pd_sep = False + bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else [] + for i, url in enumerate(global_args.combined): + bp = bp_list[i] if i < len(bp_list) else None + combined_instances.append(InstanceState(url, bp)) + + # Bootstrap combined instances for offload (need engine_ids for KV transfer) + policy = getattr(global_args, 'policy', 'linear') + # Mooncake-based modes still need bootstrap discovery; NIXL uses + # its own UCX side-channel and doesn't go through our proxy + # bootstrap path. With --connector-type=nixl, v3 also skips bootstrap. + needs_bootstrap = ( + global_args.offload + or (policy in ("unified_v2", "unified_v3", "unified_kv_both") + and getattr(global_args, 'connector_type', 'mooncake') == 'mooncake') + ) + if needs_bootstrap and bp_list: + await init_prefill_bootstrap(combined_instances, app.state.ready) + elif needs_bootstrap and not bp_list: + raise RuntimeError( + f"--policy {policy} requires --bootstrap-ports for KV transfer; " + "got empty bootstrap list." + ) + else: + app.state.ready.set() + + policy = getattr(global_args, 'policy', 'linear') + print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}") + else: + is_pd_sep = True + for url, bp in global_args.prefill: + prefill_instances.append(InstanceState(url, bp)) + for url in global_args.decode: + decode_instances.append(InstanceState(url)) + await init_prefill_bootstrap(prefill_instances, app.state.ready) + print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D") + + yield + reconcile_task.cancel() + try: + await reconcile_task + except asyncio.CancelledError: + pass + for inst in combined_instances + prefill_instances + decode_instances: + await inst.client.aclose() + + +app = FastAPI(lifespan=lifespan) + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + return await _handle(request, "/v1/completions") + + +@app.post("/v1/chat/completions") +async def handle_chat(request: Request): + return await _handle(request, "/v1/chat/completions") + + +async def _handle(request: Request, api: str): + if not app.state.ready.is_set(): + raise HTTPException(status_code=503, detail="Service Unavailable") + + req_data = await request.json() + incoming_rid = request.headers.get("X-Request-Id") + request_id = incoming_rid or str(uuid.uuid4()) + prompt = req_data.get("prompt") + token_ids = prompt if isinstance(prompt, list) else None + input_length = len(token_ids) if token_ids else 0 + session_id = request.headers.get("X-Session-Id") + + headers = {"X-Request-Id": request_id} + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + if is_pd_sep: + return await _handle_pd_sep(api, req_data, request_id, token_ids, + input_length, session_id, headers) + else: + return await _handle_combined(api, req_data, token_ids, + input_length, session_id, headers) + + +async def _handle_local_request(api, req_data, headers, token_ids, input_length, + chosen: InstanceState, estimated_new: int, + breakdown: dict, *, _pre_reserved: bool = False): + breakdown.setdefault("route_class", "LOCAL") + breakdown.setdefault("routed_to", chosen.url) + # Skip reservation when called from _handle_combined (it already reserved + # synchronously to close the picker→await race). When called directly + # from non-combined paths (PD-Sep, offload), reserve here for safety. + if not _pre_reserved: + chosen.ongoing_tokens += input_length + chosen.pending_prefill_tokens += estimated_new + chosen.num_requests += 1 + + async def generate(): + prefill_done = False + sse_buffer = "" + output_token_ids: list[int] = [] + try: + for attempt in range(MAX_STREAM_RETRIES): + try: + async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( + sse_buffer, chunk) + output_token_ids.extend(new_output_ids) + if not prefill_done: + chosen.pending_prefill_tokens -= estimated_new + chosen.ongoing_decode_tokens += input_length + breakdown["t_first_token"] = _time.monotonic() + breakdown["t_first_token_unix"] = _time.time() + prefill_done = True + yield chunk + chosen.record_prefix( + _realized_tokens(token_ids, output_token_ids)) + break + except (httpx.ConnectError, httpx.RemoteProtocolError): + if prefill_done or attempt >= MAX_STREAM_RETRIES - 1: + raise + await asyncio.sleep(RETRY_DELAY_S) + finally: + if not prefill_done: + chosen.pending_prefill_tokens -= estimated_new + else: + chosen.ongoing_decode_tokens -= input_length + chosen.ongoing_tokens -= input_length + chosen.num_requests -= 1 + breakdown["t_done"] = _time.monotonic() + breakdown["t_done_unix"] = _time.time() + _breakdown_log.append(breakdown) + + return StreamingResponse(generate(), media_type="text/event-stream") + + +async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers): + """Route a /v1/* request among combined (PD-colocated) instances. + + --policy options: + linear: cache_hit-aware load score + sticky session affinity. + lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity. + unified: hybrid — stick to affinity instance when cache_ratio > 0.5 + and it is not overloaded; otherwise fall back to LMetric + with a multi-key tie-breaker. + + PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and + commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants + both regressed E2E tail). Re-enabling requires a new transfer mechanism. + """ + policy = getattr(global_args, 'policy', 'linear') + t_decision_unix = _time.time() + request_id = headers.get("X-Request-Id", "") + breakdown: dict = { + "request_id": request_id, + "session_id": session_id, + "input_length": input_length, + "t_proxy_recv": _time.monotonic(), + "t_decision_unix": t_decision_unix, + "policy": policy, + } + + pre_decision_workers = snapshot_workers( + combined_instances, token_ids, input_length) + + pd_sep_v2: tuple[InstanceState, int] | None = None + if policy == "lmetric": + chosen, best_idx = pick_instance_lmetric( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + elif policy == "load_only": + chosen, best_idx = pick_instance_load_only( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + elif policy == "sticky": + chosen, best_idx = pick_instance_sticky( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + elif policy in ("unified", "unified_kv_both", "unified_nixl_both"): + # unified_kv_both: same picker as `unified`, but the vLLMs are + # launched in kv_role=kv_both with MooncakeConnector. Use this + # as an isolation control for `unified_v2` so the v2-vs-v1 gap + # reflects only the PD-sep branch, not the kv_both always-on + # overhead. + # unified_nixl_both: identical to unified_kv_both but with + # NixlConnector at the vLLM layer. Used to attribute the + # kv_both overhead to either Mooncake-specific code or a + # generic v1-connector cost. + chosen, best_idx, decision = pick_instance_unified_hybrid( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + breakdown.update(decision) + if session_id: + session_affinity_combined[session_id] = best_idx + elif policy == "unified_v2": + chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + breakdown.update(decision) + if session_id: + session_affinity_combined[session_id] = best_idx + elif policy == "unified_v3": + # v3: prefill on affinity (cache reuse), decode migrated to a + # low-load target. KV flows prefill_host → decode_target. + # Reuses _handle_combined_pd_sep_v2 with src=prefill_host, + # dst=decode_target (the handler is direction-agnostic). + chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v3( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + breakdown.update(decision) + if session_id: + if pd_sep_v2 is not None and SETTINGS.v3_rotate_affinity: + # Migration + rotation: redirect next turn to decode_target, + # assuming KV will live there. (Empirically wrong — see + # cache_miss_audit. Keep behind a flag.) + _decode_target_inst, decode_target_idx = pd_sep_v2 + session_affinity_combined[session_id] = decode_target_idx + else: + # No rotation: keep affinity on prefill_host (where the prefix + # cache lives). This is the empirically correct choice. + session_affinity_combined[session_id] = best_idx + else: # linear (default) + chosen, best_idx = pick_instance( + combined_instances, token_ids, session_id, input_length, + session_affinity_combined) + + chosen_snap = pre_decision_workers[best_idx] + cache_hit = chosen_snap["cache_hit"] + estimated_new = chosen_snap["new_prefill"] + breakdown.update({ + "cache_hit": cache_hit, + "estimated_new_tokens": estimated_new, + "route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2", + "routed_to": chosen.url, + "chosen_idx": best_idx, + "candidate_scores": pre_decision_workers, + "chosen_score_linear": chosen_snap["score_linear"], + "chosen_score_lmetric": chosen_snap["score_lmetric"], + }) + + _worker_state_log.append({ + "t_decision_unix": t_decision_unix, + "request_id": request_id, + "session_id": session_id, + "policy": policy, + "chosen_idx": best_idx, + "v2_pd_sep": pd_sep_v2 is not None, + "workers": pre_decision_workers, + }) + + if pd_sep_v2 is not None: + # Handler contract: first arg = prefill source (does same-worker + # prefill with do_remote_decode=True, max_tokens=1), second arg = + # decode target (does do_remote_prefill=True, pulls KV via + # Mooncake, decodes). + # + # v2 contract: pd_sep_v2 = (src_inst, src_idx); chosen = decode + # → src does prefill (it has more cache), chosen decodes. + # v3 contract: chosen = prefill_host (affinity, has cache); + # pd_sep_v2 = (decode_target_inst, decode_target_idx) + # → chosen does prefill (cache reuse), decode_target decodes. + if policy == "unified_v3": + decode_target_inst, decode_target_idx = pd_sep_v2 + prefill_inst = chosen + breakdown["v2_src_url"] = prefill_inst.url + breakdown["v2_src_idx"] = best_idx + breakdown["v3_decode_target_url"] = decode_target_inst.url + breakdown["v3_decode_target_idx"] = decode_target_idx + return await _handle_combined_pd_sep_v2( + api, req_data, headers, token_ids, input_length, + prefill_inst, decode_target_inst, breakdown, + request_id=request_id) + else: + src_inst, src_idx = pd_sep_v2 + breakdown["v2_src_url"] = src_inst.url + breakdown["v2_src_idx"] = src_idx + return await _handle_combined_pd_sep_v2( + api, req_data, headers, token_ids, input_length, + src_inst, chosen, breakdown, + request_id=request_id) + + # Race fix: reserve load on `chosen` BEFORE the `await` so concurrent + # picker calls in the same asyncio event-loop tick see the updated + # counters. Without this, two requests arriving back-to-back can both + # pick the same "free" instance and both end up running there + # simultaneously (observed as 8007 hotspot in A+B run). + chosen.ongoing_tokens += input_length + chosen.pending_prefill_tokens += estimated_new + chosen.num_requests += 1 + breakdown.setdefault("route_class", "LOCAL") + breakdown.setdefault("routed_to", chosen.url) + return await _handle_local_request( + api, req_data, headers, token_ids, input_length, + chosen, estimated_new, breakdown, _pre_reserved=True) + + +async def _handle_combined_pd_sep_v2( + api, req_data, headers, token_ids, input_length, + src: InstanceState, dst: InstanceState, breakdown: dict, + *, request_id: str, +): + """Per-request PD-sep among combined instances (unified_v2 path). + + src does cached prefill (max_tokens=1) and ships KV to dst via + Mooncake; dst pulls KV and decodes. Both instances must run in + kv_role=kv_both with bootstrap server enabled. + + Patch 6.6: the dst streaming call uses a per-chunk read timeout + of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails + the request instead of hanging for 600 s. + """ + connector = SETTINGS.connector_type + if connector == "mooncake" and src.bootstrap_port is None: + raise HTTPException( + status_code=500, + detail=( + "Mooncake PD-sep triggered but src instance " + f"{src.url} has no bootstrap_port; launch with " + "kv_role=kv_both and pass --bootstrap-ports" + ), + ) + + # Reserve load on both endpoints. + src.ongoing_tokens += input_length + src.num_requests += 1 + dst.ongoing_tokens += input_length + dst.num_requests += 1 + src_load_held = True + dst_load_held = True + + # ---- LAYERWISE write-mode (opt-in EAR_WRITE_MODE=1, mooncake only) ------ + # Dispatch src prefill and dst decode CONCURRENTLY so the dst handshake + # reaches src during its prefill, letting the layer-wise connector push KV + # per-step (overlapped with prefill compute) instead of post-hoc. + if os.environ.get("EAR_WRITE_MODE", "0") == "1" and connector == "mooncake": + return await _handle_pdsep_v2_write_mode( + api, req_data, headers, token_ids, input_length, + src, dst, breakdown, request_id=request_id) + + # Build prefill kv_transfer_params per connector. + prefill_data = req_data.copy() + if connector == "mooncake": + prefill_data["kv_transfer_params"] = { + "do_remote_decode": True, + "do_remote_prefill": False, + "transfer_id": f"xfer-{request_id}", + } + else: # nixl: src just signals it'll produce KV for remote decode + prefill_data["kv_transfer_params"] = {"do_remote_decode": True} + prefill_data["stream"] = False + prefill_data["max_tokens"] = 1 + prefill_data["min_tokens"] = 1 + prefill_data.pop("max_completion_tokens", None) + prefill_data.pop("stream_options", None) + p_headers = {**headers, "X-data-parallel-rank": "0"} + + breakdown["t_prefill_sent"] = _time.monotonic() + breakdown["t_prefill_sent_unix"] = _time.time() + forwarded_params: dict | None = None + try: + resp = await src.client.post(api, json=prefill_data, headers=p_headers) + breakdown["t_prefill_done"] = _time.monotonic() + breakdown["t_prefill_done_unix"] = _time.time() + resp.raise_for_status() + if connector == "nixl": + # Nixl populates kv_transfer_params in the response body with + # remote_block_ids / remote_engine_id / remote_host / remote_port. + # We must read the body BEFORE aclose. + src_resp_json = resp.json() + forwarded_params = src_resp_json.get("kv_transfer_params") + if not forwarded_params or not forwarded_params.get("remote_block_ids"): + raise HTTPException( + status_code=502, + detail=f"Nixl src returned no remote_block_ids: {forwarded_params}", + ) + await resp.aclose() + src.record_prefix(token_ids) + except Exception as e: + breakdown["t_prefill_done"] = _time.monotonic() + breakdown["t_prefill_done_unix"] = _time.time() + breakdown["prefill_error"] = True + breakdown["error_detail"] = repr(e)[:300] + _breakdown_log.append(breakdown) + # Release reservations on failure. Clear load_held flags so the + # finally block below does not double-decrement (CRITICAL audit #1). + if src_load_held: + src.ongoing_tokens -= input_length + src.num_requests -= 1 + src_load_held = False + if dst_load_held: + dst.ongoing_tokens -= input_length + dst.num_requests -= 1 + dst_load_held = False + raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") + finally: + if src_load_held: + src.ongoing_tokens -= input_length + src.num_requests -= 1 + src_load_held = False + + decode_data = req_data.copy() + if connector == "mooncake": + parsed = urllib.parse.urlparse(str(src.client.base_url)) + bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}" + decode_data["kv_transfer_params"] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "remote_bootstrap_addr": bootstrap_addr, + "remote_engine_id": src.engine_id.get(0, ""), + "transfer_id": f"xfer-{request_id}", + } + else: # nixl: forward what src returned + decode_data["kv_transfer_params"] = forwarded_params + + breakdown["t_decode_sent"] = _time.monotonic() + breakdown["t_decode_sent_unix"] = _time.time() + + xfer_timeout = httpx.Timeout( + connect=10.0, + read=SETTINGS.pd_sep_xfer_timeout_s, + write=10.0, + pool=10.0, + ) + + async def generate(): + nonlocal dst_load_held + first_token = True + sse_buffer = "" + output_token_ids: list[int] = [] + try: + async with dst.client.stream( + "POST", api, json=decode_data, headers=headers, + timeout=xfer_timeout, + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( + sse_buffer, chunk) + output_token_ids.extend(new_output_ids) + if first_token: + breakdown["t_first_token"] = _time.monotonic() + breakdown["t_first_token_unix"] = _time.time() + first_token = False + yield chunk + dst.record_prefix(_realized_tokens(token_ids, output_token_ids)) + finally: + breakdown["t_done"] = _time.monotonic() + breakdown["t_done_unix"] = _time.time() + if dst_load_held: + dst.ongoing_tokens -= input_length + dst.num_requests -= 1 + dst_load_held = False + _breakdown_log.append(breakdown) + + return StreamingResponse(generate(), media_type="text/event-stream") + + +async def _handle_pdsep_v2_write_mode( + api, req_data, headers, token_ids, input_length, + src: InstanceState, dst: InstanceState, breakdown: dict, + *, request_id: str, +): + """Write-mode v3 (mooncake + MOONCAKE_LAYERWISE): dispatch src prefill and + dst decode CONCURRENTLY so the dst handshake reaches src during prefill and + KV is pushed per-layer (overlapped) instead of post-hoc. Caller has already + reserved load on both src and dst; we release it in generate()'s finally. + """ + parsed = urllib.parse.urlparse(str(src.client.base_url)) + bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}" + tid = f"xfer-{request_id}" + + prefill_data = req_data.copy() + prefill_data["kv_transfer_params"] = { + "do_remote_decode": True, "do_remote_prefill": False, "transfer_id": tid} + prefill_data["stream"] = False + prefill_data["max_tokens"] = 1 + prefill_data["min_tokens"] = 1 + prefill_data.pop("max_completion_tokens", None) + prefill_data.pop("stream_options", None) + p_headers = {**headers, "X-data-parallel-rank": "0"} + + decode_data = req_data.copy() + decode_data["kv_transfer_params"] = { + "do_remote_decode": False, "do_remote_prefill": True, + "remote_bootstrap_addr": bootstrap_addr, + "remote_engine_id": src.engine_id.get(0, ""), + "transfer_id": tid} + + xfer_timeout = httpx.Timeout( + connect=10.0, read=SETTINGS.pd_sep_xfer_timeout_s, write=10.0, pool=10.0) + + breakdown["write_mode"] = True + breakdown["t_prefill_sent"] = _time.monotonic() + breakdown["t_prefill_sent_unix"] = _time.time() + prefill_task = asyncio.create_task( + src.client.post(api, json=prefill_data, headers=p_headers)) + breakdown["t_decode_sent"] = _time.monotonic() + breakdown["t_decode_sent_unix"] = _time.time() + + async def generate(): + first_token = True + sse_buffer = "" + output_token_ids: list[int] = [] + try: + async with dst.client.stream( + "POST", api, json=decode_data, headers=headers, + timeout=xfer_timeout, + ) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( + sse_buffer, chunk) + output_token_ids.extend(new_output_ids) + if first_token: + breakdown["t_first_token"] = _time.monotonic() + breakdown["t_first_token_unix"] = _time.time() + first_token = False + yield chunk + dst.record_prefix(_realized_tokens(token_ids, output_token_ids)) + finally: + breakdown["t_done"] = _time.monotonic() + breakdown["t_done_unix"] = _time.time() + try: + presp = await prefill_task + breakdown["t_prefill_done"] = _time.monotonic() + breakdown["t_prefill_done_unix"] = _time.time() + presp.raise_for_status() + await presp.aclose() + src.record_prefix(token_ids) + except Exception as e: + breakdown["prefill_error"] = True + breakdown["error_detail"] = repr(e)[:300] + src.ongoing_tokens -= input_length + src.num_requests -= 1 + dst.ongoing_tokens -= input_length + dst.num_requests -= 1 + _breakdown_log.append(breakdown) + + return StreamingResponse(generate(), media_type="text/event-stream") + + +async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length, + session_id, headers): + """PD-Sep mode with per-stage breakdown profiling.""" + t_decision_unix = _time.time() + breakdown = { + "request_id": request_id, + "session_id": session_id, + "input_length": input_length, + "t_proxy_recv": _time.monotonic(), + "t_decision_unix": t_decision_unix, + "policy": "pd_sep", + } + + pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length) + pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length) + + p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id, + input_length, session_affinity_prefill) + d_idx = min(range(len(decode_instances)), + key=lambda i: decode_instances[i].ongoing_tokens) + d_inst = decode_instances[d_idx] + breakdown["p_inst"] = p_inst.url + breakdown["d_inst"] = d_inst.url + breakdown["candidate_scores_prefill"] = pre_decision_p + breakdown["candidate_scores_decode"] = pre_decision_d + breakdown["chosen_p_idx"] = p_idx + breakdown["chosen_d_idx"] = d_idx + + _worker_state_log.append({ + "t_decision_unix": t_decision_unix, + "request_id": request_id, + "session_id": session_id, + "policy": "pd_sep", + "chosen_p_idx": p_idx, + "chosen_d_idx": d_idx, + "workers_prefill": pre_decision_p, + "workers_decode": pre_decision_d, + }) + + prefill_data = req_data.copy() + prefill_data["kv_transfer_params"] = { + "do_remote_decode": True, "do_remote_prefill": False, + "transfer_id": f"xfer-{request_id}", + } + prefill_data["stream"] = False + prefill_data["max_tokens"] = 1 + prefill_data["min_tokens"] = 1 + prefill_data.pop("max_completion_tokens", None) + prefill_data.pop("stream_options", None) + p_headers = {**headers, "X-data-parallel-rank": "0"} + + p_inst.ongoing_tokens += input_length + p_inst.num_requests += 1 + breakdown["t_prefill_sent"] = _time.monotonic() + breakdown["t_prefill_sent_unix"] = _time.time() + + try: + resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) + breakdown["t_prefill_done"] = _time.monotonic() + breakdown["t_prefill_done_unix"] = _time.time() + resp.raise_for_status() + await resp.aclose() + p_inst.record_prefix(token_ids) + except Exception as e: + breakdown["t_prefill_done"] = _time.monotonic() + breakdown["t_prefill_done_unix"] = _time.time() + breakdown["prefill_error"] = True + _breakdown_log.append(breakdown) + raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") + finally: + p_inst.ongoing_tokens -= input_length + p_inst.num_requests -= 1 + + # Send decode + d_inst.ongoing_tokens += input_length + d_inst.num_requests += 1 + parsed = urllib.parse.urlparse(str(p_inst.client.base_url)) + bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}" + + decode_data = req_data.copy() + decode_data["kv_transfer_params"] = { + "do_remote_decode": False, "do_remote_prefill": True, + "remote_bootstrap_addr": bootstrap_addr, + "remote_engine_id": p_inst.engine_id.get(0, ""), + "transfer_id": f"xfer-{request_id}", + } + + breakdown["t_decode_sent"] = _time.monotonic() + breakdown["t_decode_sent_unix"] = _time.time() + + async def generate(): + first_token = True + sse_buffer = "" + output_token_ids: list[int] = [] + try: + async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(): + sse_buffer, new_output_ids = _extract_output_token_ids_from_sse( + sse_buffer, chunk) + output_token_ids.extend(new_output_ids) + if first_token: + breakdown["t_first_token"] = _time.monotonic() + breakdown["t_first_token_unix"] = _time.time() + first_token = False + yield chunk + d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids)) + finally: + breakdown["t_done"] = _time.monotonic() + breakdown["t_done_unix"] = _time.time() + d_inst.ongoing_tokens -= input_length + d_inst.num_requests -= 1 + _breakdown_log.append(breakdown) + + return StreamingResponse(generate(), media_type="text/event-stream") + + +@app.get("/breakdown") +async def get_breakdown(): + """Return per-request breakdown data for analysis.""" + return _breakdown_log + + +@app.get("/worker_state") +async def get_worker_state(): + """Return per-decision worker-state snapshot log (one entry per route decision).""" + return _worker_state_log + + +@app.get("/worker_state/latest") +async def get_worker_state_latest(): + """Return current per-worker state snapshot without recording it.""" + if combined_instances: + return { + "t_unix": _time.time(), + "mode": "combined", + "workers": snapshot_workers(combined_instances), + } + return { + "t_unix": _time.time(), + "mode": "pd_sep", + "workers_prefill": snapshot_workers(prefill_instances), + "workers_decode": snapshot_workers(decode_instances), + } + + +@app.get("/stats") +async def get_stats(): + """Return per-instance live state for debugging.""" + instances = combined_instances or prefill_instances + decode_instances + return [{ + "url": inst.url, + "role": "combined", + "ongoing_tokens": inst.ongoing_tokens, + "pending_prefill_tokens": inst.pending_prefill_tokens, + "ongoing_decode_tokens": inst.ongoing_decode_tokens, + "num_requests": inst.num_requests, + "active_p_offloads": inst.active_p_offloads, + "cached_blocks": len(inst.cached_blocks), + } for inst in instances] + + +def parse_args(): + p = argparse.ArgumentParser(description="Unified cache-aware global scheduler") + p.add_argument("--port", type=int, default=8000) + p.add_argument("--host", type=str, default="0.0.0.0") + p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs") + p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw", + help="PD-Sep prefill: URL [bootstrap_port]") + p.add_argument("--decode", nargs=1, action="append", dest="decode_raw", + help="PD-Sep decode: URL") + p.add_argument("--heavy-threshold", type=int, default=20000, + help="New tokens threshold for HEAVY classification (adaptive offload)") + p.add_argument("--offload", action="store_true", + help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)") + p.add_argument("--bootstrap-ports", type=str, default="", + help="Comma-separated bootstrap ports for combined instances (for offload mode)") + p.add_argument("--policy", type=str, default="linear", + choices=["linear", "lmetric", "load_only", "sticky", + "unified", "unified_kv_both", + "unified_nixl_both", "unified_v2", + "unified_v3"], + help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), " + "load_only (B3 control: pure min-num_requests), " + "sticky (B3 control: hard session affinity), " + "unified (hybrid affinity + LMetric fallback), " + "unified_kv_both (unified picker on kv_both Mooncake " + "vLLMs; isolation control for unified_v2), " + "unified_nixl_both (same as unified_kv_both but using " + "NixlConnector instead of MooncakeConnector; isolates " + "connector implementation from policy effect), " + "or unified_v2 (unified + selective per-request PD-sep " + "via Mooncake; requires --bootstrap-ports and " + "kv_role=kv_both vLLM launch)") + p.add_argument("--v3-rotate-affinity", type=int, default=1, choices=[0,1], + help="unified_v3 only: 1 = rotate session affinity to decode_target " + "after migration (original behavior, empirically loses prefix cache); " + "0 = keep affinity on prefill_host so next turn hits its cache.") + p.add_argument("--connector-type", type=str, default="mooncake", + choices=["mooncake", "nixl"], + help="PD-sep handshake protocol. 'mooncake' uses pre-baked engine_id" + " + bootstrap_addr (requires --bootstrap-ports). 'nixl' uses" + " response-forward (src returns kv_transfer_params, proxy" + " relays to dst; Nixl/UCX auto-picks NVLink intra-node).") + p.add_argument("--v3-prefer-cache-target", type=int, default=1, choices=[0,1], + help="Mechanism B: unified_v3 picks decode_target with the most" + " prefix cache among low-load candidates (default 1). Set 0" + " to fall back to pure-load tie-break (cache-blind).") + p.add_argument("--lmetric-decode-weight", type=float, default=0.0, + help="Direction B: LMetric fallback adds this × ongoing_decode_tokens" + " to the queue-depth score, so hosts with heavy decode load get" + " penalised. 0 = original behavior; 0.01 is a reasonable start.") + p.add_argument("--overload-factor", type=float, default=2.0, + help="Break session affinity when instance load > factor * avg") + # The four flags below are accepted for bench.sh backward compatibility but + # have no effect after the PD-sep offload path was retired (REPORT §3.9, + # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and + # scripts/legacy/*.sh which still pass them through. + p.add_argument("--max-offload-inflight", type=int, default=4, + help="[DEPRECATED] PUSH offload retired; no effect") + p.add_argument("--offload-mode", type=str, default="cached_prefill", + choices=["direct_read", "cached_prefill"], + help="[DEPRECATED] PUSH offload retired; no effect") + p.add_argument("--cache-gate-ratio", type=float, default=0.0, + help="[DEPRECATED] PUSH offload retired; no effect") + p.add_argument("--decode-iteration-s", type=float, default=0.05, + help="[DEPRECATED] PUSH offload retired; no effect") + args = p.parse_args() + + args.prefill = [] + if args.prefill_raw: + for entry in args.prefill_raw: + url = entry[0] + bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None + args.prefill.append((url, bp)) + args.decode = [e[0] for e in (args.decode_raw or [])] + + if not args.combined and not args.prefill: + p.error("Must specify either --combined or --prefill/--decode") + return args + + +if __name__ == "__main__": + global_args = parse_args() + SETTINGS.heavy_threshold = global_args.heavy_threshold + SETTINGS.overload_factor = global_args.overload_factor + SETTINGS.max_offload_inflight = global_args.max_offload_inflight + SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio + SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05) + SETTINGS.v3_rotate_affinity = bool(getattr(global_args, 'v3_rotate_affinity', 1)) + SETTINGS.connector_type = getattr(global_args, 'connector_type', 'mooncake') + SETTINGS.v3_prefer_cache_target = bool(getattr(global_args, 'v3_prefer_cache_target', 1)) + SETTINGS.lmetric_decode_weight = float(getattr(global_args, 'lmetric_decode_weight', 0.0)) + print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s v3_rotate_affinity=%s " + "connector_type=%s v3_prefer_cache_target=%s lmetric_decode_weight=%.3f" % ( + SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s, + getattr(global_args, 'offload', False), + SETTINGS.v3_rotate_affinity, + SETTINGS.connector_type, + SETTINGS.v3_prefer_cache_target, + SETTINGS.lmetric_decode_weight)) + uvicorn.run(app, host=global_args.host, port=global_args.port)