Files
agentic-kvc/microbench/connector_tax/layerwise/cache_aware_proxy.WRITEMODE.py
Gahow Wang 5b26c345f4 P2: all routing policies read real state via eff_ accessors + ablation harness
InstanceState.eff_{num_requests,pending_prefill,ongoing_decode,ongoing_tokens}
= max(shadow, real) when feed fresh (fixes 30s-stale under-count, keeps
in-flight RaceFix), plus real-only r_max_prefill_remaining / r_kv_used_frac.
Wired into load_only, lmetric, sticky, unified(_kv_both), unified_v3, and
snapshot logging. Feed off => identical to before. run_v3_trace.sh gains ES=1
toggle (always deploys enhanced proxy); run_ablation_es.sh runs each config
ES0-vs-ES1 to test whether real state changes policy performance/ranking.
All unit-tested without GPU.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 20:21:12 +08:00

1908 lines
83 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests (waiting + running)
Session affinity: multi-turn sessions stick to same instance (all policies).
"""
import argparse
import asyncio
import json
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict, deque
from contextlib import asynccontextmanager
from dataclasses import dataclass
import httpx
# Retry knobs for upstream streaming. NOTE(review): the consumers of these two
# constants are outside the visible part of the file — confirm exact semantics
# (attempt count vs. extra retries; delay between attempts in seconds).
MAX_STREAM_RETRIES = 3
RETRY_DELAY_S = 0.5
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of tokens hashed together as one unit of the proxy-side (shadow)
# prefix cache; InstanceState hashes prompts in BLOCK_SIZE-token slices.
BLOCK_SIZE = 512
# Weight of estimated cache-hit tokens in the linear score:
#   score = ongoing_tokens (+ offload penalty) - CACHE_HIT_ALPHA * cache_hit
CACHE_HIT_ALPHA = 1.0
@dataclass
class Settings:
    """Runtime-tunable knobs. Populated from argparse in __main__.

    All routing/offload code reads from the SETTINGS singleton so that
    CLI overrides survive even when the module is imported as a library
    (e.g. by tests/) and __main__ does not run.

    Field order and defaults are part of the interface (dataclass positional
    construction); only add new fields at the end.
    """

    prefill_throughput: float = 7000.0  # tokens/s per GPU (measured on H20)
    rdma_overhead_s: float = 0.1  # legacy floor; v2 uses estimate_transfer_cost
    cache_capacity_blocks: int = 200000  # per-instance LRU cap on shadow cached_blocks
    # Score penalty applied per active P-role offload (see _p_offload_penalty).
    heavy_threshold: int = 20000
    # Session affinity is honoured only while the pinned instance's load is
    # <= average load x overload_factor (see pick_instance and
    # pick_instance_unified_hybrid).
    overload_factor: float = 2.0
    max_offload_inflight: int = 4
    cache_gate_ratio: float = 0.0
    decode_iteration_s: float = 0.05  # per-request decode iteration cost (H20)

    # --- Patch 6.9: cost-model calibration for unified_v2 ---
    # Throughput when the engine runs in kv_both mode. Lower than the
    # pure-decode 7000 tok/s because kv_both adds always-on overhead
    # (REPORT §3.8 documents ~+16% TPOT vs plain).
    prefill_throughput_kv_both: float = 4000.0
    # Calibrated RDMA transfer cost: base + bandwidth term.
    # Floor from isolated test ≈ 0.3 s (handshake + scheduler step).
    # Bandwidth term reflects realized effective throughput, not
    # theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s
    # effective on the contended kv_both path. v2 uses this lookup
    # rather than the constant rdma_overhead_s.
    rdma_base_overhead_s: float = 0.3
    rdma_effective_gb_per_s: float = 2.7
    # Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2):
    # 2 × 48 × 4 × 128 × 2 = 98304 bytes per token.
    kv_bytes_per_token: int = 98304

    # --- unified_v2 gating knobs (relaxed in v2.1 after the v1 0.2% trigger rate) ---
    # B2 microbench shows TPOT idx 1.9x already at new_tokens=8k and TTFT
    # idx ~12x; the previous 16k threshold was too conservative and
    # rejected 88.7% of candidates (window_1_results/v2_breakdown).
    pd_sep_min_new_tokens: int = 8000
    pd_sep_min_decodes_protected: int = 1  # any in-flight work on chosen counts
    pd_sep_min_src_cache_tokens: int = 4000  # half a block; was 8000
    pd_sep_min_extra_cache_tokens: int = 2000  # half a block; was 4000
    pd_sep_margin_s: float = 0.2  # require cost gap > 0.2 s before migrating
    # Patch 6.6: per-request KV-xfer wall-clock timeout (proxy side).
    pd_sep_xfer_timeout_s: float = 60.0

    # --- unified_v3 (offload-decode) gating knobs -----------------------
    # v3 differs from v2 in *direction*: prefill stays on the session-
    # affinity host (which holds the prefix cache); decode is migrated to
    # a less-loaded target. KV transfer flows prefill_host → decode_target.
    # The target doesn't need cache — we're shipping the post-prefill KV
    # over anyway. After successful migration the session affinity table
    # rotates to decode_target so the *next* turn lands where the KV now
    # lives.
    v3_min_new_tokens: int = 8000  # same as v2: don't migrate tiny prefills
    # prefill_host must have ≥ this many concurrent decode tokens to
    # justify migrating.
    v3_min_prefill_decode_busy: int = 1
    v3_target_load_ratio: float = 0.7  # target.num_requests must be < prefill_host.num_requests × this
    v3_min_load_gap: int = 1  # target.num_requests must also be ≤ prefill_host - this (absolute slack)
    # After migration, set session affinity to decode_target.
    # Empirically False is better — see cache_miss_audit (next turn hits 9.5%
    # with rotation vs ~80% without), because delay_free_blocks doesn't
    # actually preserve cross-turn KV on decode_target.
    # NOTE(review): the default is still True despite the finding above.
    v3_rotate_affinity: bool = True
    # Mechanism B: among low-load candidates, prefer the one
    # with the most prefix cache for this prompt — vLLM's connector
    # auto-transfers only the missing portion (verified via
    # smoke_partial_transfer: cache-rich dst is 77% faster than
    # cold dst at 33k tokens, +512 ext).
    v3_prefer_cache_target: bool = True
    # Anti-hotspot: picker scores effective_load = num_requests + (recent
    # migrations received within window). Prevents clustering migrations on
    # one instance in rapid succession (observed in Mech B run: inst_5 became
    # a hotspot via post-rotation tail accumulation).
    v3_recent_mig_window_s: float = 10.0  # sliding window
    # How many "virtual requests" each recent migration counts as.
    v3_recent_mig_weight: float = 1.0

    # P2: real engine-state feed (replaces 30s-stale shadow counters for
    # migration target selection). Empty = disabled (use shadow only).
    engine_state_uri: str = ""  # file:///dev/shm/... or redis://...
    engine_state_period_ms: int = 50  # router poll period
    es_big_prefill_threshold: int = 16000  # target mid-prefill >= this => avoid (GIL stall)
    es_kv_wall_frac: float = 0.90  # target KV usage >= this => avoid (capacity wall)

    # Direction B knob: LMetric fallback adds decode-token penalty to score.
    # score = (pending_prefill + new + lmetric_decode_weight * ongoing_decode_tok) * num_req
    # Empirical iter-time slope on H100 + Qwen3-30B-A3B: each decode token in
    # batch costs ~0.01 prefill-token-equivalent in scheduler time, so 0.01 is
    # a reasonable starting weight. Set 0 to disable (original behavior).
    lmetric_decode_weight: float = 0.0

    # --- KV connector selection (governs PD-sep handshake) -------------
    # "mooncake": pre-baked kv_transfer_params (bootstrap_addr+engine_id+transfer_id).
    # Requires --bootstrap-ports and vLLMs launched with MooncakeConnector.
    # "nixl" : response-forward handshake. src returns kv_transfer_params via
    # response body, proxy forwards to dst. Nixl auto-selects transport
    # via UCX (CUDA IPC / NVLink on intra-node, RDMA across nodes).
    connector_type: str = "mooncake"
# Process-wide settings singleton, created with the dataclass defaults; the
# cost-model and routing functions in this module read their knobs from it.
SETTINGS = Settings()
def estimate_transfer_cost(transfer_bytes: int) -> float:
    """Return the estimated KV-transfer time in seconds for `transfer_bytes`.

    The model is affine: a fixed setup cost plus a bandwidth term,

        SETTINGS.rdma_base_overhead_s
            + transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024**3)

    It supersedes the legacy constant rdma_overhead_s. Calibration notes:
    - Floor: isolated-test ~0.3 s for a few-block PUSH
      (scripts/test_direct_read.py).
    - Bandwidth: outputs/contention_16s_elastic/breakdown.json shows
      decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, i.e.
      ~2.7 GB/s effective on the contended kv_both path.

    The p90 in that same run is 6.7 s (D-side block reservation +
    scheduler step delays). The *median* is used on purpose: a pessimistic
    estimate would suppress all PD-sep triggers, and the risk of
    under-estimating is absorbed by the pd_sep_margin_s safety margin.
    """
    bytes_per_second = SETTINGS.rdma_effective_gb_per_s * 1024 ** 3
    return SETTINGS.rdma_base_overhead_s + transfer_bytes / bytes_per_second
def estimate_same_worker_interference_s(
    new_tokens: int,
    num_decodes: int,
) -> float:
    """Estimate the extra latency (seconds) a same-worker prefill inflicts.

    A prefill of `new_tokens` tokens running next to `num_decodes`
    co-located decodes steals decode capacity for as long as it runs.
    The result is

        (new_tokens / SETTINGS.prefill_throughput_kv_both)
            * penalty * num_decodes

    where `penalty` is the fraction of decode steps lost during the
    prefill window, stepped by prefill size (derived from the B2
    microbench, analysis/characterization/window_1_results.md):

        new_tokens <  4000           -> 0.2  (chunked prefill leaves room)
        4000  <= new_tokens < 16000  -> 0.5  (mid-regime)
        16000 <= new_tokens < 32000  -> 0.8  (towards B2 peak TPOT idx)
        new_tokens >= 32000          -> 0.95 (decodes nearly fully blocked)

    Returns 0.0 when there is nothing to disturb (`num_decodes` <= 0).
    """
    if num_decodes <= 0:
        return 0.0

    # (exclusive upper bound on new_tokens, penalty); the first match wins,
    # anything at or above the last bound gets the saturated penalty.
    penalty_steps = ((4000, 0.2), (16000, 0.5), (32000, 0.8))
    penalty = next(
        (factor for bound, factor in penalty_steps if new_tokens < bound),
        0.95,
    )

    prefill_window_s = new_tokens / SETTINGS.prefill_throughput_kv_both
    return prefill_window_s * penalty * num_decodes
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.pending_prefill_tokens = 0 # tokens for requests still in prefill
self.num_requests = 0 # total in-flight requests (waiting + running)
self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others
self.engine_id: dict[int, str] = {}
self.dp_size = 1
# OrderedDict acts as an LRU keyed by block hash; value is unused.
self.cached_blocks: OrderedDict[int, None] = OrderedDict()
# v3 anti-hotspot: timestamps (monotonic) when this instance was picked
# as a v3 migration target. Used to compute effective_load = num_req +
# recent-migration count over a sliding window, preventing back-to-back
# decisions from clustering on the same dst.
self.recent_mig_targeted_at: deque[float] = deque(maxlen=64)
# P2: latest real engine state (from the engine-state feed), or None
# when the feed is disabled/stale. Set by _engine_state_poll_loop.
self.real_state: dict | None = None
# ---- effective-load accessors (P2): prefer REAL engine state when the
# feed is fresh, else the proxy shadow counter. We take max(shadow, real)
# for load so we never under-count: REAL fixes the 30s-stale under-count,
# while the shadow's atomic pre-await reservation still covers the
# in-flight window (preserving the RaceFix against concurrent picks).
def eff_num_requests(self) -> float:
rs = self.real_state
if rs is not None:
return max(self.num_requests,
rs.get("num_running", 0) + rs.get("num_waiting", 0))
return self.num_requests
def eff_pending_prefill(self) -> float:
rs = self.real_state
if rs is not None:
return max(self.pending_prefill_tokens,
rs.get("pending_prefill_tokens", 0))
return self.pending_prefill_tokens
def eff_ongoing_decode(self) -> float:
rs = self.real_state
if rs is not None:
return max(self.ongoing_decode_tokens,
rs.get("ongoing_decode_tokens", 0))
return self.ongoing_decode_tokens
def eff_ongoing_tokens(self) -> float:
rs = self.real_state
if rs is not None:
return max(self.ongoing_tokens,
rs.get("pending_prefill_tokens", 0)
+ rs.get("ongoing_decode_tokens", 0))
return self.ongoing_tokens
def r_max_prefill_remaining(self) -> int:
rs = self.real_state
return int(rs.get("max_prefill_remaining", 0)) if rs is not None else 0
def r_kv_used_frac(self) -> float:
rs = self.real_state
if rs is not None:
return float(rs.get("gpu_kv_used_frac", 0.0) or 0.0)
return 0.0
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
self.cached_blocks.move_to_end(bh) # LRU touch on hit
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
self.cached_blocks.move_to_end(bh)
else:
self.cached_blocks[bh] = None
if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
self.cached_blocks.popitem(last=False)
def _p_offload_penalty(inst: InstanceState) -> int:
    """Score penalty for an instance doing P-role offloads (legacy PD-sep).

    Each active offload costs SETTINGS.heavy_threshold score points; an
    instance with no (or a non-positive count of) active offloads costs 0.
    """
    offloads = inst.active_p_offloads
    return offloads * SETTINGS.heavy_threshold if offloads > 0 else 0
def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Capture per-worker state at route-decision time.

    Returns one dict per instance, in instance order, carrying the raw
    shadow counters plus the score each policy would assign to a request
    of `input_length` tokens dispatched right now. Meant to be cheap
    enough to call per request; B3 hot-spot analysis relies on one
    snapshot per decision.

    Note that the cache-hit estimate refreshes the LRU position of every
    block it matches (see InstanceState.estimate_cache_hit).
    """

    def _describe(idx: int, inst: InstanceState) -> dict:
        hit = inst.estimate_cache_hit(token_ids) if token_ids else 0
        uncached = max(0, input_length - hit)
        feed_live = inst.real_state is not None
        # Scores reflect the ACTUAL decision basis (eff = real-or-shadow).
        linear_score = (inst.eff_ongoing_tokens()
                        + _p_offload_penalty(inst)
                        - CACHE_HIT_ALPHA * hit)
        lmetric_score = ((inst.eff_pending_prefill() + uncached)
                         * inst.eff_num_requests())
        return {
            "idx": idx,
            "url": inst.url,
            "ongoing_tokens": inst.ongoing_tokens,
            "ongoing_decode_tokens": inst.ongoing_decode_tokens,
            "pending_prefill_tokens": inst.pending_prefill_tokens,
            "num_requests": inst.num_requests,
            "active_p_offloads": inst.active_p_offloads,
            "cached_blocks": len(inst.cached_blocks),
            "cache_hit": hit,
            "new_prefill": uncached,
            "score_linear": linear_score,
            "score_lmetric": lmetric_score,
            # P2: populated only while the feed is fresh (None otherwise).
            # NOTE(review): "real_num_requests" carries eff_num_requests(),
            # i.e. max(shadow, real), not the raw feed value.
            "real_num_requests": inst.eff_num_requests() if feed_live else None,
            "real_max_prefill_remaining": (
                inst.r_max_prefill_remaining() if feed_live else None),
            "real_kv_used_frac": inst.r_kv_used_frac() if feed_live else None,
        }

    return [_describe(idx, inst) for idx, inst in enumerate(instances)]
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Linear policy: session-sticky routing with a load-aware override.

    Known session: return the pinned instance unless it carries more than
    SETTINGS.overload_factor times the average effective token load or is
    busy with P-role offloads.

    Otherwise: choose the lowest
    `eff_ongoing_tokens + offload penalty - CACHE_HIT_ALPHA * cache_hit`
    (first instance wins ties) and, when a session id was given, pin the
    session to it. The offload penalty steers WARM/MEDIUM traffic away
    from instances doing P-role work.

    `input_length` is accepted for signature parity with the other
    pickers and is not used. Returns (instance, index).
    """
    mean_load = max(
        sum(inst.eff_ongoing_tokens() for inst in instances) / len(instances),
        1.0,
    )

    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            pinned = instances[pinned_idx]
            if (pinned.eff_ongoing_tokens() <= mean_load * SETTINGS.overload_factor
                    and pinned.active_p_offloads == 0):
                return pinned, pinned_idx

    scores = []
    for inst in instances:
        hit = inst.estimate_cache_hit(token_ids)
        scores.append(inst.eff_ongoing_tokens()
                      + _p_offload_penalty(inst)
                      - CACHE_HIT_ALPHA * hit)
    winner_idx = min(range(len(scores)), key=scores.__getitem__)

    if session_id:
        affinity[session_id] = winner_idx
    return instances[winner_idx], winner_idx
def pick_instance_load_only(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int]:
    """Pure load balancing: route to the fewest effective in-flight requests.

    Cache hits and session affinity are deliberately ignored (the extra
    parameters exist only for signature parity). Serves as a B3 control
    that isolates how much locality the cache-aware policies contribute.
    The first instance wins ties. Returns (instance, index).
    """
    winner_idx, winner = min(
        enumerate(instances),
        key=lambda pair: pair[1].eff_num_requests(),
    )
    return winner, winner_idx
def pick_instance_sticky(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int]:
    """Hard session affinity: a pinned session never moves.

    The first turn of a session goes to the instance with the fewest
    effective requests (first instance wins ties) and the session is
    pinned there; later turns return to that instance regardless of load.
    A pinned index that no longer fits the instance list is re-picked.
    Serves as a B3 control isolating the hot-spot cost of perfect
    locality. Returns (instance, index).
    """
    session_known = bool(session_id) and session_id in affinity
    if session_known:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            return instances[pinned_idx], pinned_idx

    loads = [inst.eff_num_requests() for inst in instances]
    winner_idx = loads.index(min(loads))
    if session_id:
        affinity[session_id] = winner_idx
    return instances[winner_idx], winner_idx
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing (OSDI'26): minimise P_tokens × BS.

        P_tokens = effective pending prefill + (input_length - cache_hit)
        BS       = effective request count (current batch size)

    Purely per-request and load-based: `session_id` and `affinity` are
    accepted only for signature compatibility with the other pickers and
    are ignored. The first instance wins ties. Returns (instance, index).
    """
    winner_idx = 0
    lowest = float("inf")
    for idx, inst in enumerate(instances):
        uncached = max(0, input_length - inst.estimate_cache_hit(token_ids))
        candidate = (inst.eff_pending_prefill() + uncached) * inst.eff_num_requests()
        if candidate < lowest:
            winner_idx, lowest = idx, candidate
    return instances[winner_idx], winner_idx
# Round-robin cursor shared by all fallback decisions that end in a full tie.
_unified_fallback_rr_counter = 0


def pick_instance_unified_hybrid(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict]:
    """Hybrid routing: keep high-cache affinity, otherwise fall back to LMetric.

    The pinned instance is kept only when BOTH hold:
      - its cache_hit / input_length is > 0.5
      - its effective request count is <= average × SETTINGS.overload_factor

    Fallback ranking (lower is better, compared in this order):
      1. score = (P_tokens + decode penalty) × max(BS, 1)
      2. new uncached tokens (prefer the instance with the most cache)
      3. effective request count (prefer the least loaded)
      4. round-robin among full ties (avoids pinning everything on
         instance 0 when BS = 0 across the board)

    This function only reads `affinity`; it never updates it.

    Returns (chosen, idx, decision). `decision` carries the review #7
    breakdown fields so the caller can merge them verbatim.
    """
    global _unified_fallback_rr_counter

    count = len(instances)
    mean_reqs = max(sum(inst.eff_num_requests() for inst in instances) / count, 1.0)
    decision: dict = {
        "decision": "lmetric_fallback",
        "affinity_idx": None,
        "chosen_idx": None,
        "affinity_cache_hit": None,
        "affinity_cache_ratio": None,
        "affinity_num_requests": None,
        "avg_num_requests": mean_reqs,
        "fallback_score": None,
        "tie_break_used": False,
    }

    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < count:
            pinned = instances[pinned_idx]
            pinned_hit = pinned.estimate_cache_hit(token_ids)
            pinned_ratio = pinned_hit / max(input_length, 1)
            decision.update(
                affinity_idx=pinned_idx,
                affinity_cache_hit=pinned_hit,
                affinity_cache_ratio=pinned_ratio,
                affinity_num_requests=pinned.eff_num_requests(),
            )
            if (pinned_ratio > 0.5
                    and pinned.eff_num_requests() <= mean_reqs * SETTINGS.overload_factor):
                decision["decision"] = "affinity"
                decision["chosen_idx"] = pinned_idx
                return pinned, pinned_idx, decision

    # Direction B: LMetric extended with decode-load awareness. The plain
    # score (pending_prefill + new_uncached) * num_requests ignores ongoing
    # decode work, so a host with 200k decode tokens looks "ideal"
    # (P_tokens=0) although its decode iterations are slow.
    #
    # The batch factor is clamped with max(bs, 1): multiplying by a raw
    # num_req of 0 zeroed the decode penalty, so idle-but-decoding hosts
    # still looked free and collected cold prefills (8007 hotspot in the
    # A+B v1 run).
    ranked: list[tuple[float, int, int, int]] = []
    for idx, inst in enumerate(instances):
        uncached = max(0, input_length - inst.estimate_cache_hit(token_ids))
        # P2: effective (real-or-shadow) load signals.
        prefill_load = inst.eff_pending_prefill() + uncached
        decode_penalty = SETTINGS.lmetric_decode_weight * inst.eff_ongoing_decode()
        batch = inst.eff_num_requests()
        ranked.append(((prefill_load + decode_penalty) * max(batch, 1),
                       uncached, batch, idx))

    best_key = min(entry[:3] for entry in ranked)
    contenders = [entry for entry in ranked if entry[:3] == best_key]
    if len(contenders) > 1:
        decision["tie_break_used"] = True
        _unified_fallback_rr_counter += 1
        winner = contenders[_unified_fallback_rr_counter % len(contenders)]
    else:
        winner = contenders[0]

    winner_idx = winner[3]
    decision["fallback_score"] = winner[0]
    decision["chosen_idx"] = winner_idx
    return instances[winner_idx], winner_idx, decision
def pick_instance_unified_v2(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
    """unified_v2 = unified hybrid + selective per-request PD-sep trigger.

    Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`.
    Stage 2 asks: is there another instance with materially more cache
    for this request? If yes, would doing prefill on that instance and
    transferring KV to `chosen` for decode be cheaper than just doing
    everything on `chosen`?

    The cost model compares two scenarios in seconds-of-decode-disruption:
      local:  same-worker prefill on chosen of (input - chosen.cache_hit)
              tokens interferes with the requests already on chosen.
      pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens
              (smaller, because src has more cache) interferes with the
              requests already on src, plus the RDMA transfer of
              src.cache_hit tokens' worth of KV to chosen.

    Migration happens only when local cost > pd-sep cost + safety margin
    AND every hard gate (prefill size, live decode on chosen, source cache
    size, source cache advantage) passes. The load inputs used here are the
    proxy shadow counters (`num_requests`, `ongoing_decode_tokens`), not
    the eff_* accessors.

    Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None
    the handler should do local routing on `chosen`. When pd_sep is
    (src_inst, src_idx) the handler should do prefill-on-src,
    decode-on-chosen via Mooncake. `decision` always gains the keys
    v2_pd_sep / v2_decision / v2_reason; the v2_cost_* fields are added
    only once all gates have passed.
    """
    chosen, chosen_idx, decision = pick_instance_unified_hybrid(
        instances, token_ids, session_id, input_length, affinity)
    decision["v2_pd_sep"] = False
    decision["v2_decision"] = "local"
    decision["v2_reason"] = None

    def _stay_local(reason: str):
        # Record why PD-sep was rejected and fall back to local routing.
        decision["v2_reason"] = reason
        return chosen, chosen_idx, decision, None

    if not token_ids:
        return _stay_local("no_token_ids")

    chosen_cache_hit = chosen.estimate_cache_hit(token_ids)
    new_local = max(0, input_length - chosen_cache_hit)

    # Hard gate 1: prefill must be large enough that interference
    # outweighs the fixed RDMA setup cost.
    if new_local < SETTINGS.pd_sep_min_new_tokens:
        return _stay_local(
            f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})")

    # Hard gate 2: chosen must have live decoding work to protect.
    # v2.1 simplification: pure ongoing_decode_tokens check. num_requests
    # also counts requests still in prefill, and adding a prefill to a
    # chosen that is only prefilling disrupts no decode. The intended
    # semantic is "skip iff no decode is currently happening on chosen".
    if chosen.ongoing_decode_tokens == 0:
        return _stay_local(
            f"chosen_no_active_decode "
            f"(num_req={chosen.num_requests} decode_tok={chosen.ongoing_decode_tokens})"
        )

    # Find the best alternative cache source (strictly positive hit only;
    # the first instance wins ties). best_src_idx stays -1 if none exists.
    best_src_idx, best_src_hit = -1, 0
    for i, inst in enumerate(instances):
        if i == chosen_idx:
            continue
        h = inst.estimate_cache_hit(token_ids)
        if h > best_src_hit:
            best_src_idx, best_src_hit = i, h

    # Hard gate 3: src must hold meaningful cache.
    if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens:
        return _stay_local(
            f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})")

    # Hard gate 4: src must hold materially more cache than chosen.
    if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens:
        return _stay_local(
            f"src_not_meaningfully_more_cache "
            f"(src={best_src_hit} chosen={chosen_cache_hit})"
        )

    # Safety net: with both cache thresholds configured <= 0 the gates above
    # can pass without any source having been found. Indexing instances[-1]
    # would then silently pick the LAST instance (possibly `chosen` itself),
    # so stay local instead. Unreachable with the default thresholds.
    if best_src_idx < 0:
        return _stay_local("no_alternative_cache_source")

    src = instances[best_src_idx]
    new_src = max(0, input_length - best_src_hit)

    # Cost-benefit in seconds-of-decode-disruption.
    cost_local = estimate_same_worker_interference_s(
        new_local, chosen.num_requests)
    cost_src_interf = estimate_same_worker_interference_s(
        new_src, src.num_requests)
    transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token
    cost_xfer = estimate_transfer_cost(transfer_bytes)
    cost_migrate = cost_src_interf + cost_xfer

    decision["v2_chosen_cache_hit"] = chosen_cache_hit
    decision["v2_src_idx"] = best_src_idx
    decision["v2_src_cache_hit"] = best_src_hit
    decision["v2_new_local"] = new_local
    decision["v2_new_src"] = new_src
    decision["v2_cost_local_s"] = cost_local
    decision["v2_cost_src_interf_s"] = cost_src_interf
    decision["v2_cost_xfer_s"] = cost_xfer
    decision["v2_cost_migrate_s"] = cost_migrate

    if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s:
        decision["v2_pd_sep"] = True
        decision["v2_decision"] = "pd_sep"
        decision["v2_reason"] = (
            f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s "
            f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
        )
        return chosen, chosen_idx, decision, (src, best_src_idx)

    return _stay_local(
        f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s "
        f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
    )
def pick_instance_unified_v3(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
    """unified_v3 = unified hybrid + selective DECODE migration.

    Direction-reversed from unified_v2:
      - prefill stays on the host picked by the hybrid policy
        (`prefill_host`, normally the session-affinity host) so the
        intra-session prefix-cache reuse stays intact.
      - decode is migrated to a lower-load `decode_target` when the
        prefill host is busy with concurrent decodes.
      - KV transfer flows prefill_host → decode_target (the opposite of
        v2's src → chosen).
      - the target does NOT need pre-existing cache — the post-prefill KV
        is shipped over anyway.
      - this function never touches `affinity` beyond what the hybrid
        picker reads; rotating `affinity[session_id]` to the decode target
        after a successful migration is the *caller's* job.

    Gates, evaluated in order:
      1. new_local ≥ v3_min_new_tokens (don't pay RDMA for tiny prefills)
      2. prefill_host effective decode tokens ≥ v3_min_prefill_decode_busy
         (the host is actually busy decoding; migration buys decode-bw)
      3. ∃ target ≠ prefill_host whose effective load (effective request
         count + recent-migration penalty) satisfies both
           load < max(1, int(prefill_host load × v3_target_load_ratio))
           load ≤ prefill_host load − v3_min_load_gap

    Side effects: expired entries are dropped from every inspected
    instance's `recent_mig_targeted_at`, and on a migrate decision the
    current monotonic time is appended to the target's deque.

    Returns (prefill_host, prefill_idx, decision, migrate). When migrate
    is None the request is fully local on prefill_host. When migrate is
    (decode_target_inst, decode_target_idx), the handler should run
    prefill on prefill_host and ship KV to decode_target for decode.
    """
    prefill_host, prefill_idx, decision = pick_instance_unified_hybrid(
        instances, token_ids, session_id, input_length, affinity)
    decision["v3_migrate"] = False
    decision["v3_decision"] = "local"
    decision["v3_reason"] = None
    if not token_ids:
        decision["v3_reason"] = "no_token_ids"
        return prefill_host, prefill_idx, decision, None

    prefill_cache_hit = prefill_host.estimate_cache_hit(token_ids)
    new_local = max(0, input_length - prefill_cache_hit)
    decision["v3_prefill_cache_hit"] = prefill_cache_hit
    decision["v3_new_local"] = new_local

    # Gate 1: prefill must be large enough to amortise RDMA setup.
    if new_local < SETTINGS.v3_min_new_tokens:
        decision["v3_reason"] = (
            f"new_local_below_threshold ({new_local} < {SETTINGS.v3_min_new_tokens})"
        )
        return prefill_host, prefill_idx, decision, None

    # Gate 2: affinity host must be busy with concurrent decodes — that's
    # what migrating decode-traffic-away buys us. If the host is idle
    # there's no point.
    if prefill_host.eff_ongoing_decode() < SETTINGS.v3_min_prefill_decode_busy:
        decision["v3_reason"] = (
            f"prefill_host_not_busy "
            f"(ongoing_decode_tokens={prefill_host.eff_ongoing_decode()} < "
            f"{SETTINGS.v3_min_prefill_decode_busy})"
        )
        return prefill_host, prefill_idx, decision, None

    # Gate 3: pick the lowest-effective-load target. effective_load adds a
    # penalty for recent migrations the instance has received (anti-hotspot).
    now_mono = _time.monotonic()
    cutoff = now_mono - SETTINGS.v3_recent_mig_window_s

    def _real_load(inst):
        # P2: effective (real-or-shadow) request load; see eff_num_requests.
        return inst.eff_num_requests()

    def effective_load(inst):
        # Drop expired entries lazily.
        while inst.recent_mig_targeted_at and inst.recent_mig_targeted_at[0] < cutoff:
            inst.recent_mig_targeted_at.popleft()
        recent = len(inst.recent_mig_targeted_at)
        return _real_load(inst) + recent * SETTINGS.v3_recent_mig_weight

    ph_load = _real_load(prefill_host)
    threshold_loaded = max(1,
                           int(ph_load * SETTINGS.v3_target_load_ratio))
    candidates = [
        (i, inst) for i, inst in enumerate(instances)
        if i != prefill_idx
        and effective_load(inst) < threshold_loaded
        and effective_load(inst) <= ph_load - SETTINGS.v3_min_load_gap
    ]
    if not candidates:
        decision["v3_reason"] = (
            f"no_low_load_target "
            f"(prefill_host.num_req={prefill_host.num_requests} "
            f"threshold={threshold_loaded} "
            f"eff_loads=[{','.join(f'{int(effective_load(i))}' for i in instances)}])"
        )
        return prefill_host, prefill_idx, decision, None

    # Mechanism B (v3_prefer_cache_target=True): rank candidates first by
    # cache_hit DESC (more cache = less KV to transfer), then by effective_load
    # (which includes recent-migration penalty), then by ongoing_tokens.
    if SETTINGS.v3_prefer_cache_target:
        def _tgt_key(x):
            # P2: avoid a target that is mid-large-prefill (holds the GIL,
            # stalls the mooncake receiver_loop = the ~45% control-plane
            # residual layer-wise can't fix) or near the KV capacity wall,
            # before ranking by cache-richness and real load.
            inst = x[1]
            ch = inst.estimate_cache_hit(token_ids)
            # getattr: tolerate instance objects without a real_state attribute.
            rs = getattr(inst, "real_state", None)
            stalls = near_wall = 0
            if rs is not None:
                # A missing OR null field counts as 0, mirroring the null
                # guard on gpu_kv_used_frac below, so a partially populated
                # feed record cannot raise TypeError in the routing path.
                if int(rs.get("max_prefill_remaining") or 0) >= SETTINGS.es_big_prefill_threshold:
                    stalls = 1
                f = rs.get("gpu_kv_used_frac", 0.0) or 0.0
                if float(f) >= SETTINGS.es_kv_wall_frac:
                    near_wall = 1
            return (stalls, near_wall, -ch, effective_load(inst), inst.ongoing_tokens)
        decode_target_idx, decode_target = min(candidates, key=_tgt_key)
    else:
        decode_target_idx, decode_target = min(
            candidates, key=lambda x: (effective_load(x[1]), x[1].ongoing_tokens))

    target_cache_hit = decode_target.estimate_cache_hit(token_ids)
    # Count BEFORE this decision is recorded, so the log shows prior load.
    target_recent_received = len(decode_target.recent_mig_targeted_at)
    # Record this decision for the anti-hotspot accounting.
    decode_target.recent_mig_targeted_at.append(now_mono)

    decision["v3_migrate"] = True
    decision["v3_decision"] = "migrate_decode"
    decision["v3_src_idx"] = prefill_idx
    decision["v3_target_idx"] = decode_target_idx
    decision["v3_target_num_req"] = decode_target.num_requests
    decision["v3_target_cache_hit"] = target_cache_hit
    decision["v3_target_recent_received"] = target_recent_received
    decision["v3_prefill_num_req"] = prefill_host.num_requests
    # Snapshot of src state at the moment of decision (for postmortem).
    # These are the shadow counters, not the eff_* values used for gating.
    decision["v3_src_state"] = {
        "num_requests": prefill_host.num_requests,
        "ongoing_tokens": prefill_host.ongoing_tokens,
        "ongoing_decode_tokens": prefill_host.ongoing_decode_tokens,
        "pending_prefill_tokens": prefill_host.pending_prefill_tokens,
    }
    decision["v3_target_state"] = {
        "num_requests": decode_target.num_requests,
        "ongoing_tokens": decode_target.ongoing_tokens,
        "ongoing_decode_tokens": decode_target.ongoing_decode_tokens,
        "pending_prefill_tokens": decode_target.pending_prefill_tokens,
        "cache_hit_estimate": target_cache_hit,
        "recent_mig_received_in_window": target_recent_received,
    }
    decision["v3_reason"] = (
        f"prefill_host.num_req={prefill_host.num_requests} busy; "
        f"target.num_req={decode_target.num_requests} cache_hit={target_cache_hit} "
        f"recent_received={target_recent_received}, "
        f"transferring KV after prefill"
    )
    return prefill_host, prefill_idx, decision, (decode_target, decode_target_idx)
def _extract_output_token_ids_from_sse(
buffer: str,
chunk: bytes,
) -> tuple[str, list[int]]:
"""Extract vLLM streaming token_ids while preserving the raw stream."""
buffer += chunk.decode("utf-8", errors="ignore")
complete = buffer.endswith("\n") or buffer.endswith("\r")
lines = buffer.splitlines()
if complete:
buffer = ""
elif lines:
buffer = lines.pop()
else:
return buffer, []
output_ids: list[int] = []
for line in lines:
line = line.strip()
if not line.startswith("data:"):
continue
data = line[5:].strip()
if not data or data == "[DONE]":
continue
try:
payload = json.loads(data)
except json.JSONDecodeError:
continue
choices = payload.get("choices", [])
for choice in choices:
token_ids = choice.get("token_ids")
if isinstance(token_ids, list):
output_ids.extend(
int(t) for t in token_ids if isinstance(t, int)
)
return buffer, output_ids
def _realized_tokens(
prompt_token_ids: list[int] | None,
output_token_ids: list[int],
) -> list[int] | None:
if prompt_token_ids is None:
return None
if not output_token_ids:
return prompt_token_ids
return prompt_token_ids + output_token_ids
# Parsed CLI arguments; lifespan() and the handlers read options from it.
# Starts as None (presumably assigned by the entry point before serving — confirm).
global_args = None
# Worker pools. lifespan() fills `combined_instances` in combined mode, or
# `prefill_instances` + `decode_instances` in PD-sep mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
# Both map a session id to an index into the corresponding instance list.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias used by /stats etc.
# NOTE: this is the same dict object as session_affinity_combined, not a copy.
session_affinity = session_affinity_combined
# Set by lifespan(): True when running PD-disaggregated, False when combined.
is_pd_sep = False
# Per-request timing/routing records, appended by the request handlers and
# served by /breakdown. Grows for the lifetime of the process (never trimmed here).
_breakdown_log: list[dict] = []
# One entry per routing decision, served by /worker_state. Also never trimmed here.
_worker_state_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Discover engine ids for every instance that exposes a bootstrap port.

    For each such instance: poll ``/health`` once per second until the request
    stops raising, then query ``http://<host>:<bootstrap_port>/query`` and
    record one engine id per data-parallel rank plus the rank count. Instances
    without a bootstrap port are skipped. ``ready`` is set once all instances
    have been processed. A failing bootstrap query is not retried; its
    exception propagates to the caller.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue

        # Block until the health endpoint can be reached at all.
        reachable = False
        while not reachable:
            try:
                await inst.client.get("/health")
            except Exception:
                await asyncio.sleep(1)
            else:
                reachable = True

        host = urllib.parse.urlparse(str(inst.client.base_url)).hostname
        query_resp = await inst.client.get(
            f"http://{host}:{inst.bootstrap_port}/query")
        query_resp.raise_for_status()
        ranks = query_resp.json()
        for rank_key, rank_info in ranks.items():
            inst.engine_id[int(rank_key)] = rank_info["engine_id"]
        inst.dp_size = len(ranks)
        print(f"Inited {inst.url} engine_ids={inst.engine_id}")
    ready.set()
async def _fetch_vllm_inflight(inst: "InstanceState") -> tuple[int, int] | None:
"""Read vLLM's truth: (num_running, num_waiting). Returns None on failure."""
try:
resp = await asyncio.wait_for(inst.client.get("/metrics"), timeout=5.0)
if resp.status_code != 200:
return None
text = resp.text
except Exception:
return None
running = 0
waiting = 0
for line in text.splitlines():
if line.startswith("vllm:num_requests_running"):
try:
running = int(float(line.split()[-1]))
except (ValueError, IndexError):
pass
elif line.startswith("vllm:num_requests_waiting"):
try:
waiting = int(float(line.split()[-1]))
except (ValueError, IndexError):
pass
return running, waiting
def _engine_state_read_all(uri: str, max_age_s: float = 2.0) -> dict:
"""P2 reader (inlined; mirrors engine_state.StateReader). Returns
{engine_id: state}, dropping records older than max_age_s."""
now = _time.time()
out: dict = {}
try:
if uri.startswith("file://"):
import glob
d = uri[len("file://"):]
for p in glob.glob(os.path.join(d, "*.json")):
try:
s = json.load(open(p))
except Exception:
continue
if now - s.get("ts", 0) <= max_age_s:
out[s.get("engine_id", os.path.basename(p)[:-5])] = s
elif uri.startswith("redis://"):
import redis
r = redis.Redis.from_url(uri)
for k in r.scan_iter("engine_state:*"):
v = r.get(k)
if not v:
continue
s = json.loads(v)
if now - s.get("ts", 0) <= max_age_s:
out[s.get("engine_id")] = s
except Exception:
pass
return out
async def _engine_state_poll_loop():
    """P2: poll the engine-state feed and attach real_state to each instance.

    Does nothing when ``SETTINGS.engine_state_uri`` is empty. Otherwise, every
    ``SETTINGS.engine_state_period_ms`` (floored at 10 ms) the feed is read in
    a worker thread and instance ``i`` receives the record keyed
    ``engine_{i}``, or ``None`` when the feed has no fresh record for it.
    Per the original note, that key is meant to match AGENTIC_WORKER_ID in the
    launcher. The instance list is captured once, when the loop starts.
    """
    feed_uri = SETTINGS.engine_state_uri
    if not feed_uri:
        return
    interval_s = max(0.01, SETTINGS.engine_state_period_ms / 1000.0)
    targets = combined_instances or (prefill_instances + decode_instances)
    print(f"[engine-state] polling {feed_uri} every {interval_s*1000:.0f}ms for {len(targets)} instances")
    while True:
        try:
            await asyncio.sleep(interval_s)
        except asyncio.CancelledError:
            # Cancelled while idle: exit quietly.
            return
        fresh = await asyncio.to_thread(_engine_state_read_all, feed_uri)
        for idx, target in enumerate(targets):
            target.real_state = fresh.get(f"engine_{idx}")
async def _reconcile_loop():
    """Periodic shadow-state reconciliation against vLLM /metrics truth.

    The proxy maintains shadow counters (num_requests, ongoing_tokens,
    pending_prefill_tokens, ongoing_decode_tokens) by incrementing in
    `_handle_local_request` and decrementing in the generator's finally
    block. When the generator never enters (client disconnect between
    StreamingResponse construction and Starlette starting iteration, or
    Starlette failing before iteration), the decrement never fires and
    the counter stays elevated forever. Over a long run the shadow
    accumulates "phantom" load that biases routing decisions away from
    the affected instance.

    Two-pass fix:
      1. Clamp negatives (defensive; rare in practice).
      2. Sample vLLM's actual num_running + num_waiting via /metrics. If
         the proxy's num_requests has been *higher* than vLLM's truth for
         two consecutive cycles, reconcile downward to vLLM's count.
         Two-cycle persistence avoids correcting transient mismatches
         (e.g., proxy just incremented but vLLM hasn't scheduled the
         request yet).

    "Consecutive" is enforced strictly: a cycle whose metrics sample fails,
    and a cycle that has just applied a correction, both reset the remembered
    drift, so the next correction again needs two back-to-back observations.

    Cycle period: 30 s. Two-cycle persistence threshold: 60 s of stable
    drift before correction.
    """
    # Drift (shadow num_requests minus vLLM in-flight) seen in the previous
    # cycle, keyed by instance URL.
    prev_phantom: dict[str, int] = {}
    clamped_counters = (
        "ongoing_tokens",
        "ongoing_decode_tokens",
        "pending_prefill_tokens",
        "num_requests",
        "active_p_offloads",
    )
    while True:
        try:
            await asyncio.sleep(30)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            # Pass 1: clamp negatives (cheap, always do).
            for counter in clamped_counters:
                if getattr(inst, counter) < 0:
                    setattr(inst, counter, 0)
            # Pass 2: detect phantom positives by polling vLLM truth.
            metrics = await _fetch_vllm_inflight(inst)
            if metrics is None:
                # No sample this cycle: forget the old drift so that two
                # observations separated by a gap are not treated as
                # consecutive.
                prev_phantom.pop(inst.url, None)
                continue
            running, waiting = metrics
            actual_inflight = running + waiting
            phantom = inst.num_requests - actual_inflight
            prev = prev_phantom.get(inst.url, 0)
            if phantom > 0 and prev > 0:
                # Drift held across two consecutive cycles (~60 s).
                # Reconcile shadow to vLLM's truth.
                old_num = inst.num_requests
                inst.num_requests = actual_inflight
                if actual_inflight == 0:
                    # No requests in flight; zero all per-request counters.
                    inst.ongoing_tokens = 0
                    inst.ongoing_decode_tokens = 0
                    inst.pending_prefill_tokens = 0
                print(
                    f"[reconcile] {inst.url}: phantom drift "
                    f"num_requests {old_num} -> {actual_inflight} "
                    f"(vllm running={running} waiting={waiting})"
                )
                # The drift is gone now. Remembering the stale positive value
                # would let a single transient mismatch in the next cycle
                # trigger an immediate (one-cycle) correction.
                phantom = 0
            prev_phantom[inst.url] = phantom
def _verify_vllm_patch():
"""Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.
The patch turns an `assert req_id in self.requests` into a soft warn so
that engines do not crash on the KV-transfer abort race (see REPORT
§3.x). If somebody upgrades vLLM without re-applying the patch, the
assert returns and elastic mode dies under load. Print a loud warning
so we catch the regression before the first HEAVY request.
"""
try:
import inspect
from vllm.v1.core.sched.scheduler import Scheduler
src = inspect.getsource(Scheduler)
if "assert req_id in self.requests" in src:
print("WARNING: vLLM scheduler still contains the unpatched "
"`assert req_id in self.requests` line; expect engine "
"death on KV-transfer abort race. Apply "
"patches/0001-fix-kv-transfer-abort-race.patch.")
else:
print("vLLM patch self-check: kv-transfer-abort assert is patched.")
except Exception as exc:
print(f"vLLM patch self-check skipped: {exc!r}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: build the instance pools, run background loops, clean up.

    Startup: runs the vLLM patch self-check, starts the reconcile and
    engine-state polling tasks, then creates the `InstanceState` objects for
    combined mode (``--combined``) or PD-sep mode and, where KV transfer needs
    it, waits for bootstrap discovery. ``app.state.ready`` is set once the
    proxy may accept requests.

    Shutdown: cancels and awaits *both* background tasks, then closes every
    instance's HTTP client.

    Raises:
        RuntimeError: a combined-mode policy needs bootstrap discovery but no
            ``--bootstrap-ports`` were given.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    _verify_vllm_patch()
    background_tasks = [
        asyncio.create_task(_reconcile_loop()),
        asyncio.create_task(_engine_state_poll_loop()),
    ]
    if global_args.combined:
        is_pd_sep = False
        bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
        for i, url in enumerate(global_args.combined):
            bp = bp_list[i] if i < len(bp_list) else None
            combined_instances.append(InstanceState(url, bp))
        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        policy = getattr(global_args, 'policy', 'linear')
        # Mooncake-based modes still need bootstrap discovery; NIXL uses
        # its own UCX side-channel and doesn't go through our proxy
        # bootstrap path. With --connector-type=nixl, v3 also skips bootstrap.
        needs_bootstrap = (
            global_args.offload
            or (policy in ("unified_v2", "unified_v3", "unified_kv_both")
                and getattr(global_args, 'connector_type', 'mooncake') == 'mooncake')
        )
        if needs_bootstrap and bp_list:
            await init_prefill_bootstrap(combined_instances, app.state.ready)
        elif needs_bootstrap and not bp_list:
            raise RuntimeError(
                f"--policy {policy} requires --bootstrap-ports for KV transfer; "
                "got empty bootstrap list."
            )
        else:
            app.state.ready.set()
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")
    else:
        is_pd_sep = True
        for url, bp in global_args.prefill:
            prefill_instances.append(InstanceState(url, bp))
        for url in global_args.decode:
            decode_instances.append(InstanceState(url))
        await init_prefill_bootstrap(prefill_instances, app.state.ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
    yield
    # Stop every background loop, not just the reconciler; otherwise the
    # engine-state poller is still pending when the event loop shuts down.
    for task in background_tasks:
        task.cancel()
    for task in background_tasks:
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            # A loop that died earlier must not prevent closing the clients.
            print(f"[lifespan] background task ended with error: {exc!r}")
    for inst in combined_instances + prefill_instances + decode_instances:
        await inst.client.aclose()
# ASGI application; startup and shutdown are driven by lifespan().
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Completions endpoint; delegates to `_handle` with the same upstream path."""
    return await _handle(request, "/v1/completions")
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Chat-completions endpoint; delegates to `_handle` with the same upstream path."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Common entry point for both OpenAI-style endpoints.

    Parses the JSON body, derives the routing inputs and dispatches to the
    PD-sep or combined handler depending on the active mode.

    Args:
        request: Incoming request. ``X-Request-Id`` (generated when absent)
            and ``X-Session-Id`` headers are honoured.
        api: Upstream path to forward to, e.g. ``/v1/completions``.

    Raises:
        HTTPException: 503 until startup has completed; 400 when the body is
            not valid JSON or is not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")
    # A malformed body is a client error; without this it would surface as an
    # unhandled exception (HTTP 500).
    try:
        req_data = await request.json()
    except ValueError as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON body: {exc}") from exc
    if not isinstance(req_data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object")
    incoming_rid = request.headers.get("X-Request-Id")
    request_id = incoming_rid or str(uuid.uuid4())
    # Only a list-valued `prompt` is used for cache-aware routing; anything
    # else (string prompt, chat messages) routes with input_length 0.
    prompt = req_data.get("prompt")
    token_ids = prompt if isinstance(prompt, list) else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    else:
        return await _handle_combined(api, req_data, token_ids,
                                      input_length, session_id, headers)
async def _handle_local_request(api, req_data, headers, token_ids, input_length,
                                chosen: InstanceState, estimated_new: int,
                                breakdown: dict, *, _pre_reserved: bool = False):
    """Stream a request from a single instance and keep its shadow counters in sync.

    Args:
        api: Upstream path to POST to.
        req_data: JSON body forwarded unchanged.
        headers: Headers forwarded upstream.
        token_ids: Prompt token ids, or None when the prompt was not a list.
        input_length: Amount added to / removed from the token counters.
        chosen: Instance that serves the request.
        estimated_new: Estimated uncached prefill tokens; held in
            `pending_prefill_tokens` until the first chunk arrives.
        breakdown: Per-request record; timing keys are added here and the
            record is appended to `_breakdown_log` when the stream ends.
        _pre_reserved: True when the caller has already added this request's
            load to `chosen`.

    Returns:
        A StreamingResponse that relays the upstream bytes unchanged.
    """
    breakdown.setdefault("route_class", "LOCAL")
    breakdown.setdefault("routed_to", chosen.url)
    # Skip reservation when called from _handle_combined (it already reserved
    # synchronously to close the picker→await race). When called directly
    # from non-combined paths (PD-Sep, offload), reserve here for safety.
    if not _pre_reserved:
        chosen.ongoing_tokens += input_length
        chosen.pending_prefill_tokens += estimated_new
        chosen.num_requests += 1

    async def generate():
        # Flips to True on the first upstream chunk; from then on the request
        # is accounted as decoding and is no longer retried.
        prefill_done = False
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                                sse_buffer, chunk)
                            output_token_ids.extend(new_output_ids)
                            if not prefill_done:
                                # First chunk: move the request from the
                                # prefill counter to the decode counter.
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                breakdown["t_first_token_unix"] = _time.time()
                                prefill_done = True
                            yield chunk
                    # Stream finished: record prompt + generated tokens as
                    # cached on this instance.
                    chosen.record_prefix(
                        _realized_tokens(token_ids, output_token_ids))
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    # Retry only while nothing has been sent to the client
                    # and attempts remain.
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            # Undo whichever phase counter is currently held, then the
            # per-request totals.
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Route a /v1/* request among combined (PD-colocated) instances.

    --policy options:
      linear:  cache_hit-aware load score + sticky session affinity.
      lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity.
      unified: hybrid — stick to affinity instance when cache_ratio > 0.5
               and it is not overloaded; otherwise fall back to LMetric
               with a multi-key tie-breaker.

    The remaining policies handled below (load_only, sticky, unified_kv_both,
    unified_nixl_both, unified_v2, unified_v3) are described at their branch.

    PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and
    commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants
    both regressed E2E tail). Re-enabling requires a new transfer mechanism.

    Returns the StreamingResponse produced by `_handle_local_request`, or by
    `_handle_combined_pd_sep_v2` when the picker asked for a per-request
    prefill/decode split.
    """
    policy = getattr(global_args, 'policy', 'linear')
    t_decision_unix = _time.time()
    request_id = headers.get("X-Request-Id", "")
    breakdown: dict = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": policy,
    }
    # Snapshot every worker before picking, so the logged state is the state
    # the decision was based on.
    pre_decision_workers = snapshot_workers(
        combined_instances, token_ids, input_length)
    # Second endpoint of a per-request PD split; stays None for local routing.
    pd_sep_v2: tuple[InstanceState, int] | None = None
    if policy == "lmetric":
        chosen, best_idx = pick_instance_lmetric(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy == "load_only":
        chosen, best_idx = pick_instance_load_only(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy == "sticky":
        chosen, best_idx = pick_instance_sticky(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy in ("unified", "unified_kv_both", "unified_nixl_both"):
        # unified_kv_both: same picker as `unified`, but the vLLMs are
        # launched in kv_role=kv_both with MooncakeConnector. Use this
        # as an isolation control for `unified_v2` so the v2-vs-v1 gap
        # reflects only the PD-sep branch, not the kv_both always-on
        # overhead.
        # unified_nixl_both: identical to unified_kv_both but with
        # NixlConnector at the vLLM layer. Used to attribute the
        # kv_both overhead to either Mooncake-specific code or a
        # generic v1-connector cost.
        chosen, best_idx, decision = pick_instance_unified_hybrid(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            session_affinity_combined[session_id] = best_idx
    elif policy == "unified_v2":
        chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            session_affinity_combined[session_id] = best_idx
    elif policy == "unified_v3":
        # v3: prefill on affinity (cache reuse), decode migrated to a
        # low-load target. KV flows prefill_host → decode_target.
        # Reuses _handle_combined_pd_sep_v2 with src=prefill_host,
        # dst=decode_target (the handler is direction-agnostic).
        chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v3(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            if pd_sep_v2 is not None and SETTINGS.v3_rotate_affinity:
                # Migration + rotation: redirect next turn to decode_target,
                # assuming KV will live there. (Empirically wrong — see
                # cache_miss_audit. Keep behind a flag.)
                _decode_target_inst, decode_target_idx = pd_sep_v2
                session_affinity_combined[session_id] = decode_target_idx
            else:
                # No rotation: keep affinity on prefill_host (where the prefix
                # cache lives). This is the empirically correct choice.
                session_affinity_combined[session_id] = best_idx
    else:  # linear (default)
        chosen, best_idx = pick_instance(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    # Cache-hit and new-prefill estimates come from the pre-decision snapshot
    # of the chosen worker.
    chosen_snap = pre_decision_workers[best_idx]
    cache_hit = chosen_snap["cache_hit"]
    estimated_new = chosen_snap["new_prefill"]
    breakdown.update({
        "cache_hit": cache_hit,
        "estimated_new_tokens": estimated_new,
        "route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2",
        "routed_to": chosen.url,
        "chosen_idx": best_idx,
        "candidate_scores": pre_decision_workers,
        "chosen_score_linear": chosen_snap["score_linear"],
        "chosen_score_lmetric": chosen_snap["score_lmetric"],
    })
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": policy,
        "chosen_idx": best_idx,
        "v2_pd_sep": pd_sep_v2 is not None,
        "workers": pre_decision_workers,
    })
    if pd_sep_v2 is not None:
        # Handler contract: first arg = prefill source (does same-worker
        # prefill with do_remote_decode=True, max_tokens=1), second arg =
        # decode target (does do_remote_prefill=True, pulls KV via
        # Mooncake, decodes).
        #
        # v2 contract: pd_sep_v2 = (src_inst, src_idx); chosen = decode
        #   → src does prefill (it has more cache), chosen decodes.
        # v3 contract: chosen = prefill_host (affinity, has cache);
        #   pd_sep_v2 = (decode_target_inst, decode_target_idx)
        #   → chosen does prefill (cache reuse), decode_target decodes.
        if policy == "unified_v3":
            decode_target_inst, decode_target_idx = pd_sep_v2
            prefill_inst = chosen
            breakdown["v2_src_url"] = prefill_inst.url
            breakdown["v2_src_idx"] = best_idx
            breakdown["v3_decode_target_url"] = decode_target_inst.url
            breakdown["v3_decode_target_idx"] = decode_target_idx
            return await _handle_combined_pd_sep_v2(
                api, req_data, headers, token_ids, input_length,
                prefill_inst, decode_target_inst, breakdown,
                request_id=request_id)
        else:
            src_inst, src_idx = pd_sep_v2
            breakdown["v2_src_url"] = src_inst.url
            breakdown["v2_src_idx"] = src_idx
            return await _handle_combined_pd_sep_v2(
                api, req_data, headers, token_ids, input_length,
                src_inst, chosen, breakdown,
                request_id=request_id)
    # Race fix: reserve load on `chosen` BEFORE the `await` so concurrent
    # picker calls in the same asyncio event-loop tick see the updated
    # counters. Without this, two requests arriving back-to-back can both
    # pick the same "free" instance and both end up running there
    # simultaneously (observed as 8007 hotspot in A+B run).
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1
    # Both keys were already written by the update() above, so these
    # setdefault calls leave them unchanged.
    breakdown.setdefault("route_class", "LOCAL")
    breakdown.setdefault("routed_to", chosen.url)
    return await _handle_local_request(
        api, req_data, headers, token_ids, input_length,
        chosen, estimated_new, breakdown, _pre_reserved=True)
async def _handle_combined_pd_sep_v2(
    api, req_data, headers, token_ids, input_length,
    src: InstanceState, dst: InstanceState, breakdown: dict,
    *, request_id: str,
):
    """Per-request PD-sep among combined instances (unified_v2 path).

    src does cached prefill (max_tokens=1) and ships KV to dst via
    Mooncake; dst pulls KV and decodes. Both instances must run in
    kv_role=kv_both with bootstrap server enabled.

    With the nixl connector, the `kv_transfer_params` returned in src's
    response body are forwarded to dst instead.

    Patch 6.6: the dst streaming call uses a per-chunk read timeout
    of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails
    the request instead of hanging for 600 s.

    Raises:
        HTTPException: 500 when the Mooncake path is selected but src has no
            bootstrap port; 502 when the prefill call fails (including a nixl
            response without remote_block_ids).
    """
    connector = SETTINGS.connector_type
    if connector == "mooncake" and src.bootstrap_port is None:
        raise HTTPException(
            status_code=500,
            detail=(
                "Mooncake PD-sep triggered but src instance "
                f"{src.url} has no bootstrap_port; launch with "
                "kv_role=kv_both and pass --bootstrap-ports"
            ),
        )
    # Reserve load on both endpoints.
    src.ongoing_tokens += input_length
    src.num_requests += 1
    dst.ongoing_tokens += input_length
    dst.num_requests += 1
    # Track which reservations are still outstanding so each one is released
    # exactly once.
    src_load_held = True
    dst_load_held = True
    # ---- LAYERWISE write-mode (opt-in EAR_WRITE_MODE=1, mooncake only) ------
    # Dispatch src prefill and dst decode CONCURRENTLY so the dst handshake
    # reaches src during its prefill, letting the layer-wise connector push KV
    # per-step (overlapped with prefill compute) instead of post-hoc.
    # The write-mode handler takes over the reservations made above.
    if os.environ.get("EAR_WRITE_MODE", "0") == "1" and connector == "mooncake":
        return await _handle_pdsep_v2_write_mode(
            api, req_data, headers, token_ids, input_length,
            src, dst, breakdown, request_id=request_id)
    # Build prefill kv_transfer_params per connector.
    prefill_data = req_data.copy()
    if connector == "mooncake":
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": f"xfer-{request_id}",
        }
    else:  # nixl: src just signals it'll produce KV for remote decode
        prefill_data["kv_transfer_params"] = {"do_remote_decode": True}
    # The prefill leg is a non-streaming, single-token request.
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    forwarded_params: dict | None = None
    try:
        resp = await src.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        if connector == "nixl":
            # Nixl populates kv_transfer_params in the response body with
            # remote_block_ids / remote_engine_id / remote_host / remote_port.
            # We must read the body BEFORE aclose.
            src_resp_json = resp.json()
            forwarded_params = src_resp_json.get("kv_transfer_params")
            if not forwarded_params or not forwarded_params.get("remote_block_ids"):
                raise HTTPException(
                    status_code=502,
                    detail=f"Nixl src returned no remote_block_ids: {forwarded_params}",
                )
        await resp.aclose()
        src.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        breakdown["error_detail"] = repr(e)[:300]
        _breakdown_log.append(breakdown)
        # Release reservations on failure. Clear load_held flags so the
        # finally block below does not double-decrement (CRITICAL audit #1).
        if src_load_held:
            src.ongoing_tokens -= input_length
            src.num_requests -= 1
            src_load_held = False
        if dst_load_held:
            dst.ongoing_tokens -= input_length
            dst.num_requests -= 1
            dst_load_held = False
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
    finally:
        # Success path: src's part is over once prefill has returned.
        if src_load_held:
            src.ongoing_tokens -= input_length
            src.num_requests -= 1
            src_load_held = False
    decode_data = req_data.copy()
    if connector == "mooncake":
        # dst is pointed at src's bootstrap server to pull the KV.
        parsed = urllib.parse.urlparse(str(src.client.base_url))
        bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
        decode_data["kv_transfer_params"] = {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "remote_bootstrap_addr": bootstrap_addr,
            "remote_engine_id": src.engine_id.get(0, ""),
            "transfer_id": f"xfer-{request_id}",
        }
    else:  # nixl: forward what src returned
        decode_data["kv_transfer_params"] = forwarded_params
    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()
    xfer_timeout = httpx.Timeout(
        connect=10.0,
        read=SETTINGS.pd_sep_xfer_timeout_s,
        write=10.0,
        pool=10.0,
    )

    async def generate():
        nonlocal dst_load_held
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with dst.client.stream(
                "POST", api, json=decode_data, headers=headers,
                timeout=xfer_timeout,
            ) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
            # Stream finished: record prompt + generated tokens on dst.
            dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            # dst's reservation is held for the whole decode stream.
            if dst_load_held:
                dst.ongoing_tokens -= input_length
                dst.num_requests -= 1
                dst_load_held = False
            _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pdsep_v2_write_mode(
    api, req_data, headers, token_ids, input_length,
    src: InstanceState, dst: InstanceState, breakdown: dict,
    *, request_id: str,
):
    """Write-mode v3 (mooncake + MOONCAKE_LAYERWISE): dispatch src prefill and
    dst decode CONCURRENTLY so the dst handshake reaches src during prefill and
    KV is pushed per-layer (overlapped) instead of post-hoc. Caller has already
    reserved load on both src and dst; we release it in generate()'s finally.

    The prefill request is started as a background task before the response
    is returned; the decode stream is opened when the response body is
    iterated. When the stream ends, for any reason, the prefill task is
    awaited, its outcome is recorded in ``breakdown`` and the load reserved on
    both instances is released. The release is guaranteed even if awaiting the
    prefill task is itself interrupted by cancellation.

    Returns:
        A StreamingResponse relaying dst's bytes unchanged.
    """
    parsed = urllib.parse.urlparse(str(src.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
    tid = f"xfer-{request_id}"
    # Prefill leg: non-streaming, single token, KV kept for remote decode.
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False, "transfer_id": tid}
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    # Decode leg: pulls KV from src's bootstrap server under the same id.
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": src.engine_id.get(0, ""),
        "transfer_id": tid}
    xfer_timeout = httpx.Timeout(
        connect=10.0, read=SETTINGS.pd_sep_xfer_timeout_s, write=10.0, pool=10.0)
    breakdown["write_mode"] = True
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    prefill_task = asyncio.create_task(
        src.client.post(api, json=prefill_data, headers=p_headers))
    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    async def generate():
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with dst.client.stream(
                "POST", api, json=decode_data, headers=headers,
                timeout=xfer_timeout,
            ) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
            dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            try:
                try:
                    presp = await prefill_task
                    breakdown["t_prefill_done"] = _time.monotonic()
                    breakdown["t_prefill_done_unix"] = _time.time()
                    presp.raise_for_status()
                    await presp.aclose()
                    src.record_prefix(token_ids)
                except Exception as e:
                    # Keep the timestamps present on failure too, as the other
                    # PD-sep handlers do; do not overwrite ones already set.
                    breakdown.setdefault("t_prefill_done", _time.monotonic())
                    breakdown.setdefault("t_prefill_done_unix", _time.time())
                    breakdown["prefill_error"] = True
                    breakdown["error_detail"] = repr(e)[:300]
            finally:
                # CancelledError is not an Exception: without this inner
                # finally, a cancellation while awaiting the prefill task
                # would skip the release below and leave load on both
                # instances.
                src.ongoing_tokens -= input_length
                src.num_requests -= 1
                dst.ongoing_tokens -= input_length
                dst.num_requests -= 1
                _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Picks a prefill instance with `pick_instance` and the decode instance with
    the fewest `ongoing_tokens`, runs a single-token non-streaming prefill on
    the former, then streams the decode from the latter with
    `kv_transfer_params` pointing at the prefill instance's bootstrap address.
    Timing for each stage is written into the breakdown record.

    Raises:
        HTTPException: 502 when the prefill call fails.
    """
    t_decision_unix = _time.time()
    breakdown = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": "pd_sep",
    }
    # Snapshot both pools before choosing, for the decision log.
    pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
    pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)
    p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
                                  input_length, session_affinity_prefill)
    # Decode side: least ongoing tokens wins.
    d_idx = min(range(len(decode_instances)),
                key=lambda i: decode_instances[i].ongoing_tokens)
    d_inst = decode_instances[d_idx]
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    breakdown["candidate_scores_prefill"] = pre_decision_p
    breakdown["candidate_scores_decode"] = pre_decision_d
    breakdown["chosen_p_idx"] = p_idx
    breakdown["chosen_d_idx"] = d_idx
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": "pd_sep",
        "chosen_p_idx": p_idx,
        "chosen_d_idx": d_idx,
        "workers_prefill": pre_decision_p,
        "workers_decode": pre_decision_d,
    })
    # Prefill leg: non-streaming, single token, KV kept for remote decode.
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    p_inst.ongoing_tokens += input_length
    p_inst.num_requests += 1
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
    finally:
        # The prefill instance's load is released on success and on failure.
        p_inst.ongoing_tokens -= input_length
        p_inst.num_requests -= 1
    # Send decode
    d_inst.ongoing_tokens += input_length
    d_inst.num_requests += 1
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    async def generate():
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
            # Stream finished: record prompt + generated tokens on the decode
            # instance.
            d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="text/event-stream")
@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis.

    The result is the whole in-memory `_breakdown_log`, i.e. every record the
    request handlers have appended so far.
    """
    return _breakdown_log
@app.get("/worker_state")
async def get_worker_state():
    """Return per-decision worker-state snapshot log (one entry per route decision).

    The result is the whole in-memory `_worker_state_log`.
    """
    return _worker_state_log
@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Return current per-worker state snapshot without recording it.

    Combined mode reports one ``workers`` list; otherwise the prefill and
    decode pools are reported separately under mode ``pd_sep``.
    """
    latest = {"t_unix": _time.time()}
    if combined_instances:
        latest["mode"] = "combined"
        latest["workers"] = snapshot_workers(combined_instances)
    else:
        latest["mode"] = "pd_sep"
        latest["workers_prefill"] = snapshot_workers(prefill_instances)
        latest["workers_decode"] = snapshot_workers(decode_instances)
    return latest
@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    One entry per instance with its URL, role, shadow load counters and the
    number of cached blocks. ``role`` is ``"combined"`` in combined mode and
    ``"prefill"`` / ``"decode"`` in PD-sep mode (prefill instances first).
    """
    # Tag each instance with the pool it came from, instead of labelling
    # everything "combined" regardless of mode.
    if combined_instances:
        tagged = [("combined", inst) for inst in combined_instances]
    else:
        tagged = (
            [("prefill", inst) for inst in prefill_instances]
            + [("decode", inst) for inst in decode_instances]
        )
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in tagged]
def parse_args(argv=None):
    """Parse the scheduler's command line and normalise the PD-sep fields.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what callers that pass no
            argument have always got.

    Returns:
        argparse.Namespace holding every declared flag plus two derived
        attributes:

        * ``prefill`` -- list of ``(url, bootstrap_port)`` tuples built from
          the repeated ``--prefill`` flags; the port is ``None`` when it is
          omitted or given as ``none`` (case-insensitive).
        * ``decode`` -- list of URL strings from the repeated ``--decode``
          flags.

    Raises:
        SystemExit: via ``parser.error`` when neither ``--combined`` nor
            ``--prefill`` is given, or when a ``--prefill`` bootstrap port is
            neither an integer nor ``none``.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear",
                   choices=["linear", "lmetric", "load_only", "sticky",
                            "unified", "unified_kv_both",
                            "unified_nixl_both", "unified_v2",
                            "unified_v3"],
                   help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
                        "load_only (B3 control: pure min-num_requests), "
                        "sticky (B3 control: hard session affinity), "
                        "unified (hybrid affinity + LMetric fallback), "
                        "unified_kv_both (unified picker on kv_both Mooncake "
                        "vLLMs; isolation control for unified_v2), "
                        "unified_nixl_both (same as unified_kv_both but using "
                        "NixlConnector instead of MooncakeConnector; isolates "
                        "connector implementation from policy effect), "
                        "or unified_v2 (unified + selective per-request PD-sep "
                        "via Mooncake; requires --bootstrap-ports and "
                        "kv_role=kv_both vLLM launch)")
    p.add_argument("--v3-rotate-affinity", type=int, default=1, choices=[0,1],
                   help="unified_v3 only: 1 = rotate session affinity to decode_target "
                        "after migration (original behavior, empirically loses prefix cache); "
                        "0 = keep affinity on prefill_host so next turn hits its cache.")
    p.add_argument("--connector-type", type=str, default="mooncake",
                   choices=["mooncake", "nixl"],
                   help="PD-sep handshake protocol. 'mooncake' uses pre-baked engine_id"
                        " + bootstrap_addr (requires --bootstrap-ports). 'nixl' uses"
                        " response-forward (src returns kv_transfer_params, proxy"
                        " relays to dst; Nixl/UCX auto-picks NVLink intra-node).")
    p.add_argument("--v3-prefer-cache-target", type=int, default=1, choices=[0,1],
                   help="Mechanism B: unified_v3 picks decode_target with the most"
                        " prefix cache among low-load candidates (default 1). Set 0"
                        " to fall back to pure-load tie-break (cache-blind).")
    p.add_argument("--lmetric-decode-weight", type=float, default=0.0,
                   help="Direction B: LMetric fallback adds this × ongoing_decode_tokens"
                        " to the queue-depth score, so hosts with heavy decode load get"
                        " penalised. 0 = original behavior; 0.01 is a reasonable start.")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    p.add_argument("--engine-state-uri", type=str, default="",
                   help="P2: real engine-state feed for migration target "
                        "selection (file:///dev/shm/... or redis://...). "
                        "Empty=disabled (shadow counters only).")
    # The four flags below are accepted for bench.sh backward compatibility but
    # have no effect after the PD-sep offload path was retired (REPORT §3.9,
    # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
    # scripts/legacy/*.sh which still pass them through.
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--offload-mode", type=str, default="cached_prefill",
                   choices=["direct_read", "cached_prefill"],
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--cache-gate-ratio", type=float, default=0.0,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--decode-iteration-s", type=float, default=0.05,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    args = p.parse_args(argv)

    # Each --prefill occurrence is [URL] or [URL, PORT]; values beyond the
    # second are ignored, as before.
    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report a bad port as a normal usage error rather than
                # letting the ValueError escape as a traceback.
                p.error(f"--prefill bootstrap port must be an integer or "
                        f"'none', got {entry[1]!r}")
        args.prefill.append((url, bp))

    # --decode uses nargs=1, so every occurrence is a one-element list.
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    return args
if __name__ == "__main__":
    # Script entry point: parse the CLI, copy the tunables onto SETTINGS,
    # print a one-line summary, then start the server.
    global_args = parse_args()
    cli = global_args

    # parse_args() declares a default for every flag, so plain attribute
    # access is sufficient here.
    SETTINGS.heavy_threshold = cli.heavy_threshold
    SETTINGS.overload_factor = cli.overload_factor
    SETTINGS.engine_state_uri = cli.engine_state_uri or ''
    SETTINGS.max_offload_inflight = cli.max_offload_inflight
    SETTINGS.cache_gate_ratio = cli.cache_gate_ratio
    SETTINGS.decode_iteration_s = cli.decode_iteration_s
    SETTINGS.v3_rotate_affinity = bool(cli.v3_rotate_affinity)
    SETTINGS.connector_type = cli.connector_type
    SETTINGS.v3_prefer_cache_target = bool(cli.v3_prefer_cache_target)
    SETTINGS.lmetric_decode_weight = float(cli.lmetric_decode_weight)

    banner_values = (
        SETTINGS.prefill_throughput,
        SETTINGS.rdma_overhead_s,
        cli.offload,
        SETTINGS.v3_rotate_affinity,
        SETTINGS.connector_type,
        SETTINGS.v3_prefer_cache_target,
        SETTINGS.lmetric_decode_weight,
    )
    banner = ("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s v3_rotate_affinity=%s "
              "connector_type=%s v3_prefer_cache_target=%s lmetric_decode_weight=%.3f" % banner_values)
    print(banner)

    uvicorn.run(app, host=cli.host, port=cli.port)