scripts/b3_isolated_policy.sh:
Recognize unified_v3 as a kv_both-requiring policy; respect explicit
KV_CONNECTOR=Nixl override (so unified_v2 / unified_v3 / unified_kv_both
can run against either Mooncake or Nixl back-end). When Nixl is
selected, skip the bootstrap-ports plumbing — Nixl uses its own UCX
side-channel and the proxy forwards kv_transfer_params from the src
response body instead of pre-baking engine_id/bootstrap_addr.
scripts/cache_aware_proxy.py:
- New unified_v3 policy (~250 lines): prefill stays on session-affinity
host (preserves intra-session prefix-cache reuse), decode is migrated
to a lower-load target when the affinity host is busy with concurrent
decodes. KV transfer flows prefill_host → decode_target, opposite of
v2. Knobs: v3_min_new_tokens, v3_min_prefill_decode_busy,
v3_target_load_ratio, v3_min_load_gap, v3_rotate_affinity,
v3_prefer_cache_target. cache_miss_audit found rotation hurts cross-
turn locality (9.5% hit with vs ~80% without) so default
v3_rotate_affinity=False.
- New connector_type setting ("mooncake" | "nixl") gating the PD-sep
handshake form: mooncake uses pre-baked kv_transfer_params,
nixl forwards them from the response body.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
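The unified_v3 gating described in the commit message can be illustrated with a small standalone sketch. It is not the proxy's code: `Worker` and `v3_decode_target` are hypothetical names, the defaults mirror the knob values listed in the Settings dataclass, and the cache-preference ranking (v3_prefer_cache_target) is left out for brevity.

```python
from dataclasses import dataclass


@dataclass
class Worker:
    # Hypothetical stand-in for the proxy's per-instance shadow counters.
    num_requests: int
    ongoing_decode_tokens: int


def v3_decode_target(workers, host_idx, new_tokens, min_new_tokens=8000,
                     min_decode_busy=1, load_ratio=0.7, min_load_gap=1):
    """Return the index of a decode target, or None to stay local."""
    host = workers[host_idx]
    # Gate 1: tiny prefills are not worth a KV transfer.
    if new_tokens < min_new_tokens:
        return None
    # Gate 2: the affinity host must actually be busy decoding.
    if host.ongoing_decode_tokens < min_decode_busy:
        return None
    # Gate 3: the target must be materially less loaded than the host.
    threshold = max(1, int(host.num_requests * load_ratio))
    candidates = [
        i for i, w in enumerate(workers)
        if i != host_idx
        and w.num_requests < threshold
        and w.num_requests <= host.num_requests - min_load_gap
    ]
    if not candidates:
        return None
    return min(candidates, key=lambda i: workers[i].num_requests)


workers = [Worker(10, 500), Worker(2, 0), Worker(8, 100)]
target = v3_decode_target(workers, host_idx=0, new_tokens=12000)
print(target)
```

With these numbers only worker 1 passes the load gate, so prefill would stay on worker 0 and decode would move to worker 1.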
1596 lines
68 KiB
Python
"""Unified cache-aware + token-level load-balanced global scheduler.

Supports two modes:
  --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
  --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)

Routing policies (--policy):
  linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
  lmetric: score = P_tokens * BS  (LMetric, OSDI'26)
    P_tokens = pending_prefill_tokens + new_uncached_tokens
    BS = num_requests (waiting + running)

Session affinity: multi-turn sessions stick to same instance (all policies).
"""

import argparse
import asyncio
import json
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from dataclasses import dataclass

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse

MAX_STREAM_RETRIES = 3
RETRY_DELAY_S = 0.5

BLOCK_SIZE = 512
CACHE_HIT_ALPHA = 1.0


@dataclass
class Settings:
    """Runtime-tunable knobs. Populated from argparse in __main__.

    All routing/offload code reads from the SETTINGS singleton so that
    CLI overrides survive even when the module is imported as a library
    (e.g. by tests/) and __main__ does not run.
    """
    prefill_throughput: float = 7000.0  # tokens/s per GPU (measured on H20)
    rdma_overhead_s: float = 0.1  # legacy floor; v2 uses estimate_transfer_cost
    cache_capacity_blocks: int = 200000  # per-instance LRU cap on shadow cached_blocks
    heavy_threshold: int = 20000
    overload_factor: float = 2.0
    max_offload_inflight: int = 4
    cache_gate_ratio: float = 0.0
    decode_iteration_s: float = 0.05  # per-request decode iteration cost (H20)

    # --- Patch 6.9: cost-model calibration for unified_v2 ---
    # Throughput when the engine runs in kv_both mode. Lower than the
    # pure-decode 7000 tok/s because kv_both adds always-on overhead
    # (REPORT §3.8 documents ~+16% TPOT vs plain).
    prefill_throughput_kv_both: float = 4000.0
    # Calibrated RDMA transfer cost: base + bandwidth term.
    # Floor from isolated test ≈ 0.3 s (handshake + scheduler step).
    # Bandwidth term reflects realized effective throughput, not
    # theoretical 25 GB/s: production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s
    # effective on the contended kv_both path. v2 uses this lookup
    # rather than the constant rdma_overhead_s.
    rdma_base_overhead_s: float = 0.3
    rdma_effective_gb_per_s: float = 2.7

    # Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2):
    # 2 × 48 × 4 × 128 × 2 = 98304 bytes per token.
    kv_bytes_per_token: int = 98304

    # --- unified_v2 gating knobs (relaxed in v2.1 after the v1 0.2% trigger rate) ---
    # B2 microbench shows TPOT idx 1.9x already at new_tokens=8k and TTFT
    # idx ~12x; the previous 16k threshold was too conservative and
    # rejected 88.7% of candidates (window_1_results/v2_breakdown).
    pd_sep_min_new_tokens: int = 8000
    pd_sep_min_decodes_protected: int = 1  # any in-flight work on chosen counts
    pd_sep_min_src_cache_tokens: int = 4000  # half a block; was 8000
    pd_sep_min_extra_cache_tokens: int = 2000  # half a block; was 4000
    pd_sep_margin_s: float = 0.2  # require cost gap > 0.2 s before migrating
    # Patch 6.6: per-request KV-xfer wall-clock timeout (proxy side).
    pd_sep_xfer_timeout_s: float = 60.0

    # --- unified_v3 (offload-decode) gating knobs -----------------------
    # v3 differs from v2 in *direction*: prefill stays on the session-
    # affinity host (which holds the prefix cache); decode is migrated to
    # a less-loaded target. KV transfer flows prefill_host → decode_target.
    # The target doesn't need cache: we're shipping the post-prefill KV
    # over anyway. After successful migration the session affinity table
    # rotates to decode_target so the *next* turn lands where the KV now
    # lives.
    v3_min_new_tokens: int = 8000  # same as v2: don't migrate tiny prefills
    # prefill_host must have ≥ this many concurrent decode tokens to justify migrating
    v3_min_prefill_decode_busy: int = 1
    # target.num_requests must be < prefill_host.num_requests × this
    v3_target_load_ratio: float = 0.7
    # target.num_requests must also be ≤ prefill_host - this (absolute slack)
    v3_min_load_gap: int = 1
    # After migration, set session affinity to decode_target.
    # Empirically False is better; see cache_miss_audit (next turn hits 9.5%
    # with rotation vs ~80% without), because delay_free_blocks doesn't
    # actually preserve cross-turn KV on decode_target.
    v3_rotate_affinity: bool = True
    # Mechanism B: among low-load candidates, prefer the one
    # with the most prefix cache for this prompt; vLLM's connector
    # auto-transfers only the missing portion (verified via
    # smoke_partial_transfer: cache-rich dst is 77% faster than
    # cold dst at 33k tokens, +512 ext).
    v3_prefer_cache_target: bool = True

    # --- KV connector selection (governs PD-sep handshake) -------------
    # "mooncake": pre-baked kv_transfer_params (bootstrap_addr+engine_id+transfer_id).
    #             Requires --bootstrap-ports and vLLMs launched with MooncakeConnector.
    # "nixl"    : response-forward handshake. src returns kv_transfer_params via
    #             response body, proxy forwards to dst. Nixl auto-selects transport
    #             via UCX (CUDA IPC / NVLink on intra-node, RDMA across nodes).
    connector_type: str = "mooncake"


SETTINGS = Settings()


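The two handshake forms selected by connector_type can be sketched as a single helper. This is an illustration only: `decode_kv_transfer_params` is a hypothetical name, the address and id values are made up, and the key inside the Nixl response (`remote_engine_id`) is an assumption; only the top-level `kv_transfer_params` field and the Mooncake field names come from the comments above.

```python
def decode_kv_transfer_params(connector_type, prebaked, src_response):
    """Choose the kv_transfer_params to attach to the decode-side request."""
    if connector_type == "mooncake":
        # Mooncake: the proxy builds the params itself before prefill runs.
        return dict(prebaked)
    if connector_type == "nixl":
        # Nixl: the prefill response carries the params; forward them as-is.
        params = src_response.get("kv_transfer_params")
        if params is None:
            raise ValueError("prefill response carried no kv_transfer_params")
        return params
    raise ValueError(f"unknown connector_type: {connector_type!r}")


# Illustrative values only.
prebaked = {
    "bootstrap_addr": "10.0.0.1:8998",
    "engine_id": "engine-0",
    "transfer_id": "xfer-1",
}
src_response = {"kv_transfer_params": {"remote_engine_id": "engine-7"}}

print(decode_kv_transfer_params("mooncake", prebaked, src_response))
print(decode_kv_transfer_params("nixl", prebaked, src_response))
```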
def estimate_transfer_cost(transfer_bytes: int) -> float:
    """Calibrated RDMA transfer cost as a function of bytes.

    Replaces the legacy constant rdma_overhead_s. Calibration sources:
      - Floor: isolated-test ~0.3 s for a few-block PUSH (scripts/test_direct_read.py)
      - Bandwidth term: outputs/contention_16s_elastic/breakdown.json shows
        decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, giving
        ~2.7 GB/s effective on the contended kv_both path.

    The p90 in that same run is 6.7 s (D-side block reservation +
    scheduler step delays). v2's cost model uses the *median*: being
    too pessimistic would suppress all PD-sep triggers. The risk of
    underestimation is mitigated by the pd_sep_margin_s safety factor.
    """
    base = SETTINGS.rdma_base_overhead_s
    bw_term = transfer_bytes / (SETTINGS.rdma_effective_gb_per_s * 1024 ** 3)
    return base + bw_term


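A worked instance of the transfer-cost formula may help. The sketch below re-states the formula with the default constants passed as arguments (the function name `transfer_cost_s` is illustrative, not part of the module) and evaluates it for a 32768-token KV transfer, which at 98304 bytes per token is exactly 3 GiB.

```python
GIB = 1024 ** 3


def transfer_cost_s(transfer_bytes, base_s=0.3, gib_per_s=2.7):
    """Fixed handshake floor plus a bandwidth-proportional term."""
    return base_s + transfer_bytes / (gib_per_s * GIB)


kv_bytes_per_token = 98304            # default from Settings
tokens = 32768
transfer_bytes = tokens * kv_bytes_per_token   # 3 * 2**30 bytes
cost = transfer_cost_s(transfer_bytes)
print(f"{transfer_bytes / GIB:.1f} GiB -> {cost:.2f} s")
```

The bandwidth term alone is 3 / 2.7 ≈ 1.11 s, in line with the p50 quoted in the docstring; the 0.3 s floor is added on top of it.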
def estimate_same_worker_interference_s(
    new_tokens: int,
    num_decodes: int,
) -> float:
    """Estimated additional latency on `num_decodes` co-located decodes
    when a `new_tokens`-token prefill runs on the same worker.

    Derived from B2 microbench (analysis/characterization/window_1_results.md):
    same-worker prefill of size N steals decode capacity for the
    prefill's duration. The penalty factor is the fraction of decode
    steps stolen during the prefill window.

    Penalty tiers as implemented below (calibration points from B2):
      new_tokens < 4k:         ~0.2  (chunked prefill leaves room)
      4k <= new_tokens < 16k:  ~0.5  (mid-regime, B2 TPOT idx 3.4× at 16k)
      16k <= new_tokens < 32k: ~0.8  (B2 peak TPOT idx 7.9× at 32k)
      new_tokens >= 32k:       ~0.95 (B2 TTFT regime; decodes are nearly fully blocked)

    The cost in seconds is roughly: prefill_duration × penalty × n_decodes,
    because each affected decode loses ~penalty fraction of its capacity
    during the prefill window.
    """
    if num_decodes <= 0:
        return 0.0
    prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both
    if new_tokens < 4000:
        penalty = 0.2
    elif new_tokens < 16000:
        penalty = 0.5
    elif new_tokens < 32000:
        penalty = 0.8
    else:
        penalty = 0.95
    return prefill_dur_s * penalty * num_decodes


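To make the interference estimate concrete, the following standalone sketch restates the tiered penalty with the kv_both throughput passed in as an argument and evaluates one case. The name `interference_s` is illustrative; the tiers and the 4000 tok/s default are taken from the code above.

```python
def interference_s(new_tokens, num_decodes, throughput_tok_s=4000.0):
    """prefill_duration × penalty × number of co-located decodes."""
    if num_decodes <= 0:
        return 0.0
    if new_tokens < 4000:
        penalty = 0.2
    elif new_tokens < 16000:
        penalty = 0.5
    elif new_tokens < 32000:
        penalty = 0.8
    else:
        penalty = 0.95
    return (new_tokens / throughput_tok_s) * penalty * num_decodes


# A 20k-token prefill takes 5 s at 4000 tok/s; with penalty 0.8 and
# four co-located decodes the estimate is 5 × 0.8 × 4 seconds.
heavy = interference_s(20000, 4)
light = interference_s(2000, 2)
print(heavy, light)
```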
class InstanceState:
    def __init__(self, url: str, bootstrap_port: int | None = None):
        self.url = url
        self.bootstrap_port = bootstrap_port
        self.client = httpx.AsyncClient(
            timeout=None, base_url=url,
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
        )
        self.ongoing_tokens = 0
        self.ongoing_decode_tokens = 0  # subset: tokens in decode phase
        self.pending_prefill_tokens = 0  # tokens for requests still in prefill
        self.num_requests = 0  # total in-flight requests (waiting + running)
        self.active_p_offloads = 0  # number of HEAVY prefills this instance is doing for others
        self.engine_id: dict[int, str] = {}
        self.dp_size = 1
        # OrderedDict acts as an LRU keyed by block hash; value is unused.
        self.cached_blocks: OrderedDict[int, None] = OrderedDict()

    def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
        """Count leading tokens covered by consecutive cached full blocks."""
        if not token_ids or len(token_ids) < BLOCK_SIZE:
            return 0
        hit = 0
        for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
            if bh in self.cached_blocks:
                self.cached_blocks.move_to_end(bh)  # LRU touch on hit
                hit += BLOCK_SIZE
            else:
                break
        return hit

    def record_prefix(self, token_ids: list[int] | None):
        """Insert every full block of token_ids into the shadow LRU."""
        if not token_ids:
            return
        for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
            if bh in self.cached_blocks:
                self.cached_blocks.move_to_end(bh)
            else:
                self.cached_blocks[bh] = None
                if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
                    self.cached_blocks.popitem(last=False)


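The shadow prefix cache is easier to follow at toy scale. The sketch below is a reduced model of the same idea with a block size of 4 and a capacity of 2 blocks; `ShadowCache` is a hypothetical class written for this example and has no HTTP client or load counters.

```python
from collections import OrderedDict


class ShadowCache:
    """Toy block-hash LRU: full blocks only, prefix match stops at first miss."""

    def __init__(self, block_size, capacity_blocks):
        self.block_size = block_size
        self.capacity_blocks = capacity_blocks
        self.blocks = OrderedDict()

    def _block_hashes(self, token_ids):
        bs = self.block_size
        for i in range(0, len(token_ids) - bs + 1, bs):
            yield hash(tuple(token_ids[i:i + bs]))

    def record_prefix(self, token_ids):
        for bh in self._block_hashes(token_ids):
            if bh in self.blocks:
                self.blocks.move_to_end(bh)
            else:
                self.blocks[bh] = None
                if len(self.blocks) > self.capacity_blocks:
                    self.blocks.popitem(last=False)  # evict least recently used

    def estimate_cache_hit(self, token_ids):
        hit = 0
        for bh in self._block_hashes(token_ids):
            if bh not in self.blocks:
                break
            self.blocks.move_to_end(bh)  # a lookup also refreshes recency
            hit += self.block_size
        return hit


cache = ShadowCache(block_size=4, capacity_blocks=2)
turn1 = [1, 2, 3, 4, 5, 6, 7, 8]
cache.record_prefix(turn1)
full_hit = cache.estimate_cache_hit(turn1 + [9, 10])
partial_hit = cache.estimate_cache_hit([1, 2, 3, 4, 0, 0, 0, 0])
cache.record_prefix([20, 21, 22, 23])  # third distinct block forces an eviction
after_evict = cache.estimate_cache_hit(turn1)
print(full_hit, partial_hit, after_evict)
```

Because a lookup refreshes recency, the partial-hit query above makes the second block of `turn1` the least recently used one, so it is the block evicted when the third block arrives.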
def _p_offload_penalty(inst: InstanceState) -> int:
    """Penalty for PD-sep mode routing (legacy)."""
    if inst.active_p_offloads <= 0:
        return 0
    return inst.active_p_offloads * SETTINGS.heavy_threshold


def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Per-worker state at route-decision time.

    All routing-relevant counters plus the score each policy would
    have produced for `input_length` if it were dispatched now. Cheap
    enough to call on every request; B3 hot-spot analysis depends on
    this being captured per decision.
    """
    snap: list[dict] = []
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids) if token_ids else 0
        new_prefill = max(0, input_length - cache_hit)
        snap.append({
            "idx": i,
            "url": inst.url,
            "ongoing_tokens": inst.ongoing_tokens,
            "ongoing_decode_tokens": inst.ongoing_decode_tokens,
            "pending_prefill_tokens": inst.pending_prefill_tokens,
            "num_requests": inst.num_requests,
            "active_p_offloads": inst.active_p_offloads,
            "cached_blocks": len(inst.cached_blocks),
            "cache_hit": cache_hit,
            "new_prefill": new_prefill,
            "score_linear": (inst.ongoing_tokens
                             + _p_offload_penalty(inst)
                             - CACHE_HIT_ALPHA * cache_hit),
            "score_lmetric": (inst.pending_prefill_tokens + new_prefill)
                             * inst.num_requests,
        })
    return snap


def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Session-sticky with load-aware override.

    Turn 2+: use session affinity UNLESS pinned instance is overloaded
             or busy with P-role offloads, in which case pick least-loaded.
    Turn 1:  pick instance with best score (load + cache combined).
             Instances doing P-role offloads get a large penalty to steer
             WARM/MEDIUM traffic away.
    """
    avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)

    if session_id and session_id in affinity:
        idx = affinity[session_id]
        if idx < len(instances):
            inst = instances[idx]
            if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor
                    and inst.active_p_offloads == 0):
                return inst, idx

    best_idx, best_score = 0, float("inf")
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        score = (inst.ongoing_tokens + _p_offload_penalty(inst)
                 - CACHE_HIT_ALPHA * cache_hit)
        if score < best_score:
            best_score = score
            best_idx = i

    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx


def pick_instance_load_only(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int]:
    """Pure load balancing: pick instance with fewest in-flight requests.

    Ignores cache hits and session affinity. Used as a B3 control to
    isolate the locality contribution of cache-aware policies.
    """
    best_idx = min(range(len(instances)),
                   key=lambda i: instances[i].num_requests)
    return instances[best_idx], best_idx


def pick_instance_sticky(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int]:
    """Hard session affinity: once assigned, never break.

    First turn of a session picks the instance with the lowest
    num_requests; subsequent turns always return to the same instance
    regardless of load. Used as a B3 control to isolate the hot-spot
    cost of perfect locality.
    """
    if session_id and session_id in affinity:
        idx = affinity[session_id]
        if idx < len(instances):
            return instances[idx], idx
    best_idx = min(range(len(instances)),
                   key=lambda i: instances[i].num_requests)
    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx


def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing: score = P_tokens × BS (OSDI'26).

    Pure per-request load-based routing, no session affinity (the
    session_id/affinity args are accepted for signature compatibility
    with pick_instance/pick_instance_unified_hybrid but ignored).
      P  = pending_prefill_tokens + (input_length - cache_hit)
      BS = num_requests (current batch size)
    """
    best_idx, best_score = 0, float("inf")
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        new_prefill = max(0, input_length - cache_hit)
        p_tokens = inst.pending_prefill_tokens + new_prefill
        bs = inst.num_requests
        score = p_tokens * bs
        if score < best_score:
            best_score = score
            best_idx = i

    return instances[best_idx], best_idx


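A small numeric example of the LMetric score follows. It works on plain tuples rather than InstanceState objects; the function name `lmetric_scores` and the worker numbers are made up for illustration.

```python
def lmetric_scores(workers, input_length):
    """workers: list of (pending_prefill_tokens, num_requests, cache_hit)."""
    scores = []
    for pending, batch_size, cache_hit in workers:
        new_prefill = max(0, input_length - cache_hit)
        scores.append((pending + new_prefill) * batch_size)
    return scores


workers = [
    (0, 4, 8192),   # warm cache, moderately busy
    (2000, 1, 0),   # cold, one request already queued for prefill
    (0, 0, 0),      # cold and idle
]
scores = lmetric_scores(workers, input_length=10000)
best = min(range(len(scores)), key=scores.__getitem__)
print(scores, best)
```

An idle instance always scores 0 because the batch-size factor is 0, whatever its cache state. When several instances are idle the plain minimum picks the lowest index, which is the degenerate case the hybrid policy's round-robin tie-breaker is meant to avoid.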
_unified_fallback_rr_counter = 0


def pick_instance_unified_hybrid(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict]:
    """Hybrid routing: high-cache affinity, else LMetric with tie-breaker.

    Affinity gate (both must hold to stick):
      - affinity instance cache_hit / input_length > 0.5
      - affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor

    Fallback ordering (when affinity not used):
      primary:    score = P_tokens * BS (LMetric)
      secondary:  new_uncached_tokens (prefer instance with most cache)
      tertiary:   num_requests (prefer least-loaded)
      quaternary: round-robin (avoid degenerate inst-0 pinning
                  when BS=0 across the board)

    Returns (chosen, idx, decision_dict). decision_dict carries the
    review #7 breakdown fields so the caller can merge them verbatim.
    """
    global _unified_fallback_rr_counter
    n = len(instances)
    avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0)

    decision: dict = {
        "decision": "lmetric_fallback",
        "affinity_idx": None,
        "chosen_idx": None,
        "affinity_cache_hit": None,
        "affinity_cache_ratio": None,
        "affinity_num_requests": None,
        "avg_num_requests": avg_reqs,
        "fallback_score": None,
        "tie_break_used": False,
    }

    if session_id and session_id in affinity:
        a_idx = affinity[session_id]
        if a_idx < n:
            a_inst = instances[a_idx]
            a_hit = a_inst.estimate_cache_hit(token_ids)
            a_ratio = a_hit / max(input_length, 1)
            decision["affinity_idx"] = a_idx
            decision["affinity_cache_hit"] = a_hit
            decision["affinity_cache_ratio"] = a_ratio
            decision["affinity_num_requests"] = a_inst.num_requests
            if (a_ratio > 0.5
                    and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
                decision["decision"] = "affinity"
                decision["chosen_idx"] = a_idx
                return a_inst, a_idx, decision

    keys: list[tuple[int, int, int, int]] = []
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        new_prefill = max(0, input_length - cache_hit)
        p_tokens = inst.pending_prefill_tokens + new_prefill
        bs = inst.num_requests
        score = p_tokens * bs
        keys.append((score, new_prefill, bs, i))

    best_triple = min(k[:3] for k in keys)
    tied = [k for k in keys if k[:3] == best_triple]
    if len(tied) > 1:
        decision["tie_break_used"] = True
        _unified_fallback_rr_counter += 1
        winner = tied[_unified_fallback_rr_counter % len(tied)]
    else:
        winner = tied[0]
    chosen_idx = winner[3]
    decision["fallback_score"] = winner[0]
    decision["chosen_idx"] = chosen_idx
    return instances[chosen_idx], chosen_idx, decision


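The fallback ordering and its round-robin tie-breaker can be isolated in a few lines. The sketch below takes the counter as an explicit mutable argument instead of a module global; `pick_with_tiebreak` and `rr_state` are names invented for this example.

```python
def pick_with_tiebreak(keys, rr_state):
    """keys: list of (score, new_prefill, batch_size, idx) tuples."""
    best = min(k[:3] for k in keys)
    tied = [k for k in keys if k[:3] == best]
    if len(tied) > 1:
        # Rotate among equally good instances instead of always taking the first.
        rr_state["counter"] += 1
        return tied[rr_state["counter"] % len(tied)][3]
    return tied[0][3]


rr_state = {"counter": 0}

# Three idle, cold instances: every key is (0, 1000, 0, idx), so all tie.
idle_keys = [(0, 1000, 0, 0), (0, 1000, 0, 1), (0, 1000, 0, 2)]
picks = [pick_with_tiebreak(idle_keys, rr_state) for _ in range(4)]

# One instance holds more cache (smaller new_prefill): it wins outright.
warm_keys = [(0, 500, 0, 0), (0, 1000, 0, 1)]
warm_pick = pick_with_tiebreak(warm_keys, rr_state)
print(picks, warm_pick)
```

Since the counter is incremented before it is used, the first tie goes to the second tied entry and the rotation continues from there; a unique winner leaves the counter untouched.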
def pick_instance_unified_v2(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
    """unified_v2 = unified hybrid + selective per-request PD-sep trigger.

    Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`.

    Stage 2 asks: is there another instance with materially more cache
    for this request? If yes, would doing prefill on that instance and
    transferring KV to `chosen` for decode be cheaper than just doing
    everything on `chosen`?

    The cost model compares two scenarios in seconds-of-decode-disruption:

      local:  same-worker prefill on chosen of (input - chosen.cache_hit)
              tokens interferes with chosen.num_decodes co-located decodes.

      pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens
              (smaller, because src has more cache) interferes with
              src.num_decodes co-located decodes, plus we pay RDMA
              transfer of src.cache_hit blocks to chosen.

    We migrate only when local cost > pd-sep cost + safety margin AND
    a set of hard gates (size, cache, decodes) are met.

    Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None
    the handler should do local routing on `chosen`. When pd_sep is
    (src_inst, src_idx) the handler should do prefill-on-src,
    decode-on-chosen via Mooncake.
    """
    chosen, chosen_idx, decision = pick_instance_unified_hybrid(
        instances, token_ids, session_id, input_length, affinity)

    decision["v2_pd_sep"] = False
    decision["v2_decision"] = "local"
    decision["v2_reason"] = None

    if not token_ids:
        decision["v2_reason"] = "no_token_ids"
        return chosen, chosen_idx, decision, None

    chosen_cache_hit = chosen.estimate_cache_hit(token_ids)
    new_local = max(0, input_length - chosen_cache_hit)

    # Hard gate 1: prefill must be large enough that interference
    # outweighs the fixed RDMA setup cost.
    if new_local < SETTINGS.pd_sep_min_new_tokens:
        decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})"
        return chosen, chosen_idx, decision, None

    # Hard gate 2: chosen must have live decoding work to protect.
    # v2.1 simplification: pure ongoing_decode_tokens check. The previous
    # gate combined num_requests and decode_tokens with AND, but
    # num_requests includes requests still in prefill; adding a prefill
    # to a chosen that has only its own prefill running doesn't disrupt
    # any decode, so skipping makes sense. The right semantic is "skip
    # iff no decode is currently happening on chosen".
    if chosen.ongoing_decode_tokens == 0:
        decision["v2_reason"] = (
            f"chosen_no_active_decode "
            f"(num_req={chosen.num_requests} decode_tok={chosen.ongoing_decode_tokens})"
        )
        return chosen, chosen_idx, decision, None

    # Find best alternative cache source.
    best_src_idx, best_src_hit = -1, 0
    for i, inst in enumerate(instances):
        if i == chosen_idx:
            continue
        h = inst.estimate_cache_hit(token_ids)
        if h > best_src_hit:
            best_src_idx, best_src_hit = i, h

    # Hard gate 3: src must hold meaningful cache.
    if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens:
        decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})"
        return chosen, chosen_idx, decision, None

    # Hard gate 4: src must hold materially more cache than chosen.
    if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens:
        decision["v2_reason"] = (
            f"src_not_meaningfully_more_cache "
            f"(src={best_src_hit} chosen={chosen_cache_hit})"
        )
        return chosen, chosen_idx, decision, None

    src = instances[best_src_idx]
    new_src = max(0, input_length - best_src_hit)

    # Cost-benefit in seconds-of-decode-disruption.
    cost_local = estimate_same_worker_interference_s(
        new_local, chosen.num_requests)
    cost_src_interf = estimate_same_worker_interference_s(
        new_src, src.num_requests)
    transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token
    cost_xfer = estimate_transfer_cost(transfer_bytes)
    cost_migrate = cost_src_interf + cost_xfer

    decision["v2_chosen_cache_hit"] = chosen_cache_hit
    decision["v2_src_idx"] = best_src_idx
    decision["v2_src_cache_hit"] = best_src_hit
    decision["v2_new_local"] = new_local
    decision["v2_new_src"] = new_src
    decision["v2_cost_local_s"] = cost_local
    decision["v2_cost_src_interf_s"] = cost_src_interf
    decision["v2_cost_xfer_s"] = cost_xfer
    decision["v2_cost_migrate_s"] = cost_migrate

    if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s:
        decision["v2_pd_sep"] = True
        decision["v2_decision"] = "pd_sep"
        decision["v2_reason"] = (
            f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s "
            f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
        )
        return chosen, chosen_idx, decision, (src, best_src_idx)

    decision["v2_reason"] = (
        f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s "
        f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
    )
    return chosen, chosen_idx, decision, None


def pick_instance_unified_v3(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
    """unified_v3 = unified hybrid + selective DECODE migration.

    Direction-reversed from unified_v2:
      - prefill stays on session-affinity host (`prefill_host`) so we keep
        the 93%-intra-session prefix-cache reuse intact.
      - decode is migrated to a lower-load `decode_target` when the
        affinity host is busy with concurrent decodes.
      - KV transfer flows prefill_host → decode_target (the opposite of
        v2's src → chosen).
      - target does NOT need pre-existing cache: we're shipping the
        post-prefill KV over anyway.
      - When SETTINGS.v3_rotate_affinity is set, the *caller* rotates
        `affinity[session_id] = decode_target_idx` after a successful
        migration so the next turn lands where the KV now lives. The
        intent was that decode_target retains the blocks after completion
        (mooncake defaults to delay_free_blocks=True); cache_miss_audit
        found this does not hold in practice, see the note on
        Settings.v3_rotate_affinity.

    Decision is purely load-based on the target side:
      1. new_local ≥ v3_min_new_tokens (don't pay RDMA for tiny prefills)
      2. prefill_host.ongoing_decode_tokens ≥ v3_min_prefill_decode_busy
         (the host is actually busy decoding; migration buys decode-bw)
      3. ∃ target with both num_requests < prefill_host.num_requests × ratio
         and num_requests ≤ prefill_host.num_requests − v3_min_load_gap

    Returns (prefill_host, prefill_idx, decision, migrate). When migrate
    is None the request is fully local on prefill_host. When migrate is
    (decode_target_inst, decode_target_idx), the handler should run
    prefill on prefill_host and ship KV to decode_target for decode.
    """
    prefill_host, prefill_idx, decision = pick_instance_unified_hybrid(
        instances, token_ids, session_id, input_length, affinity)

    decision["v3_migrate"] = False
    decision["v3_decision"] = "local"
    decision["v3_reason"] = None

    if not token_ids:
        decision["v3_reason"] = "no_token_ids"
        return prefill_host, prefill_idx, decision, None

    prefill_cache_hit = prefill_host.estimate_cache_hit(token_ids)
    new_local = max(0, input_length - prefill_cache_hit)
    decision["v3_prefill_cache_hit"] = prefill_cache_hit
    decision["v3_new_local"] = new_local

    # Gate 1: prefill must be large enough to amortise RDMA setup.
    if new_local < SETTINGS.v3_min_new_tokens:
        decision["v3_reason"] = (
            f"new_local_below_threshold ({new_local} < {SETTINGS.v3_min_new_tokens})"
        )
        return prefill_host, prefill_idx, decision, None

    # Gate 2: affinity host must be busy with concurrent decodes; that's
    # what migrating decode-traffic-away buys us. If the host is idle
    # there's no point.
    if prefill_host.ongoing_decode_tokens < SETTINGS.v3_min_prefill_decode_busy:
        decision["v3_reason"] = (
            f"prefill_host_not_busy "
            f"(ongoing_decode_tokens={prefill_host.ongoing_decode_tokens} < "
            f"{SETTINGS.v3_min_prefill_decode_busy})"
        )
        return prefill_host, prefill_idx, decision, None

    # Gate 3: pick the lowest-load target that is materially less loaded
    # than the prefill_host. Cache content irrelevant: KV ships over.
    threshold_loaded = max(1,
                           int(prefill_host.num_requests * SETTINGS.v3_target_load_ratio))
    candidates = [
        (i, inst) for i, inst in enumerate(instances)
        if i != prefill_idx
        and inst.num_requests < threshold_loaded
        and inst.num_requests <= prefill_host.num_requests - SETTINGS.v3_min_load_gap
    ]
    if not candidates:
        decision["v3_reason"] = (
            f"no_low_load_target "
            f"(prefill_host.num_req={prefill_host.num_requests} "
            f"threshold={threshold_loaded})"
        )
        return prefill_host, prefill_idx, decision, None

    # Mechanism B (v3_prefer_cache_target=True): rank candidates first by
    # cache_hit DESC (more cache = less KV to transfer), then by load. vLLM
    # auto-skips transferring overlapping prefix when dst's local cache
    # matches; verified in smoke_partial_transfer: 77% faster on a 33k
    # prompt when dst has the prefix already.
    if SETTINGS.v3_prefer_cache_target:
        decode_target_idx, decode_target = min(
            candidates,
            key=lambda x: (-x[1].estimate_cache_hit(token_ids),
                           x[1].num_requests, x[1].ongoing_tokens))
    else:
        decode_target_idx, decode_target = min(
            candidates, key=lambda x: (x[1].num_requests, x[1].ongoing_tokens))

    target_cache_hit = decode_target.estimate_cache_hit(token_ids)
    decision["v3_migrate"] = True
    decision["v3_decision"] = "migrate_decode"
    decision["v3_target_idx"] = decode_target_idx
    decision["v3_target_num_req"] = decode_target.num_requests
    decision["v3_target_cache_hit"] = target_cache_hit
    decision["v3_prefill_num_req"] = prefill_host.num_requests
    decision["v3_reason"] = (
        f"prefill_host.num_req={prefill_host.num_requests} busy; "
        f"target.num_req={decode_target.num_requests} cache_hit={target_cache_hit}, "
        f"transferring KV after prefill"
    )
    return prefill_host, prefill_idx, decision, (decode_target, decode_target_idx)


def _extract_output_token_ids_from_sse(
    buffer: str,
    chunk: bytes,
) -> tuple[str, list[int]]:
    """Extract vLLM streaming token_ids while preserving the raw stream."""
    buffer += chunk.decode("utf-8", errors="ignore")
    complete = buffer.endswith("\n") or buffer.endswith("\r")
    lines = buffer.splitlines()
    if complete:
        buffer = ""
    elif lines:
        buffer = lines.pop()
    else:
        return buffer, []

    output_ids: list[int] = []
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if not data or data == "[DONE]":
            continue
        try:
            payload = json.loads(data)
        except json.JSONDecodeError:
            continue
        choices = payload.get("choices", [])
        for choice in choices:
            token_ids = choice.get("token_ids")
            if isinstance(token_ids, list):
                output_ids.extend(
                    int(t) for t in token_ids if isinstance(t, int)
                )
    return buffer, output_ids


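The buffering behaviour of the SSE parser is easiest to see when an event is split across two network chunks. The sketch below is a compact restatement of the same logic under the name `feed_sse` (a name chosen for this example), fed with a hand-written stream; the payload shape with a per-choice `token_ids` list follows what the function above reads.

```python
import json


def feed_sse(buffer, chunk):
    """Return (leftover_buffer, token_ids parsed from complete lines)."""
    buffer += chunk.decode("utf-8", errors="ignore")
    complete = buffer.endswith("\n") or buffer.endswith("\r")
    lines = buffer.splitlines()
    if complete:
        buffer = ""
    elif lines:
        buffer = lines.pop()  # keep the unfinished last line for the next chunk
    else:
        return buffer, []

    ids = []
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if not data or data == "[DONE]":
            continue
        try:
            payload = json.loads(data)
        except json.JSONDecodeError:
            continue
        for choice in payload.get("choices", []):
            token_ids = choice.get("token_ids")
            if isinstance(token_ids, list):
                ids.extend(t for t in token_ids if isinstance(t, int))
    return buffer, ids


# The second event is cut in the middle of its JSON body.
chunk1 = b'data: {"choices": [{"token_ids": [1, 2]}]}\n\ndata: {"choi'
chunk2 = b'ces": [{"token_ids": [3]}]}\n\ndata: [DONE]\n\n'

buf, ids1 = feed_sse("", chunk1)
buf_after_first = buf
buf, ids2 = feed_sse(buf, chunk2)
print(ids1, repr(buf_after_first), ids2, repr(buf))
```

The partial line is carried over in the buffer and only parsed once its terminating newline arrives, so no token id is lost or double-counted.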
def _realized_tokens(
    prompt_token_ids: list[int] | None,
    output_token_ids: list[int],
) -> list[int] | None:
    if prompt_token_ids is None:
        return None
    if not output_token_ids:
        return prompt_token_ids
    return prompt_token_ids + output_token_ids


global_args = None
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias used by /stats etc.
session_affinity = session_affinity_combined
is_pd_sep = False
_breakdown_log: list[dict] = []
_worker_state_log: list[dict] = []


async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    for inst in instances:
        if inst.bootstrap_port is None:
            continue
        while True:
            try:
                await inst.client.get("/health")
            except Exception:
                await asyncio.sleep(1)
                continue
            parsed = urllib.parse.urlparse(str(inst.client.base_url))
            url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
            resp = await inst.client.get(url)
            resp.raise_for_status()
            data = resp.json()
            for dp_rank, dp_entry in data.items():
                inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
            inst.dp_size = len(data)
            print(f"Inited {inst.url} engine_ids={inst.engine_id}")
            break
    ready.set()


async def _fetch_vllm_inflight(inst: "InstanceState") -> tuple[int, int] | None:
    """Read vLLM's truth: (num_running, num_waiting). Returns None on failure."""
    try:
        resp = await asyncio.wait_for(inst.client.get("/metrics"), timeout=5.0)
        if resp.status_code != 200:
            return None
        text = resp.text
    except Exception:
        return None
    running = 0
    waiting = 0
    for line in text.splitlines():
        if line.startswith("vllm:num_requests_running"):
            try:
                running = int(float(line.split()[-1]))
            except (ValueError, IndexError):
                pass
        elif line.startswith("vllm:num_requests_waiting"):
            try:
                waiting = int(float(line.split()[-1]))
            except (ValueError, IndexError):
                pass
    return running, waiting


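# Illustrative sketch only; nothing in the proxy calls this helper and its name
# is hypothetical. _fetch_vllm_inflight keeps the last sample it sees for each
# gauge. If a /metrics page carried several samples of one gauge (assumed here:
# one per label set, in the Prometheus text format "name{labels} value" with no
# trailing timestamp), summing them might be closer to the intent.
def _sum_gauge_samples_sketch(metrics_text: str, name: str) -> int:
    total = 0.0
    for line in metrics_text.splitlines():
        # Skip "# HELP" / "# TYPE" comments and unrelated metrics.
        if line.startswith("#") or not line.startswith(name):
            continue
        # The name must end here, so "foo" does not also match "foo_bar".
        rest = line[len(name):]
        if rest and rest[0] not in "{ \t":
            continue
        try:
            total += float(line.split()[-1])
        except (ValueError, IndexError):
            continue
    return int(total)

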
async def _reconcile_loop():
    """Periodic shadow-state reconciliation against vLLM /metrics truth.

    The proxy maintains shadow counters (num_requests, ongoing_tokens,
    pending_prefill_tokens, ongoing_decode_tokens) by incrementing in
    `_handle_local_request` and decrementing in the generator's finally
    block. When the generator never enters (client disconnect between
    StreamingResponse construction and Starlette starting iteration, or
    Starlette failing before iteration), the decrement never fires and
    the counter stays elevated forever. Over a long run the shadow
    accumulates "phantom" load that biases routing decisions away from
    the affected instance.

    Two-pass fix:

      1. Clamp negatives (defensive; rare in practice).
      2. Sample vLLM's actual num_running + num_waiting via /metrics. If
         the proxy's num_requests has been *higher* than vLLM's truth for
         two consecutive cycles, reconcile downward to vLLM's count.
         Two-cycle persistence avoids correcting transient mismatches
         (e.g., proxy just incremented but vLLM hasn't scheduled the
         request yet).

    Cycle period: 30 s. Two-cycle persistence threshold: 60 s of stable
    drift before correction.
    """
    prev_phantom: dict[str, int] = {}
    while True:
        try:
            await asyncio.sleep(30)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            # Pass 1: clamp negatives (cheap, always do).
            if inst.ongoing_tokens < 0:
                inst.ongoing_tokens = 0
            if inst.ongoing_decode_tokens < 0:
                inst.ongoing_decode_tokens = 0
            if inst.pending_prefill_tokens < 0:
                inst.pending_prefill_tokens = 0
            if inst.num_requests < 0:
                inst.num_requests = 0
            if inst.active_p_offloads < 0:
                inst.active_p_offloads = 0

            # Pass 2: detect phantom positives by polling vLLM truth.
            metrics = await _fetch_vllm_inflight(inst)
            if metrics is None:
                continue
            running, waiting = metrics
            actual_inflight = running + waiting
            phantom = inst.num_requests - actual_inflight
            prev = prev_phantom.get(inst.url, 0)
            if phantom > 0 and prev > 0:
                # Drift held across two consecutive cycles (~60 s).
                # Reconcile shadow to vLLM's truth.
                old_num = inst.num_requests
                inst.num_requests = actual_inflight
                if actual_inflight == 0:
                    # No requests in flight; zero all per-request counters.
                    inst.ongoing_tokens = 0
                    inst.ongoing_decode_tokens = 0
                    inst.pending_prefill_tokens = 0
                print(
                    f"[reconcile] {inst.url}: phantom drift "
                    f"num_requests {old_num} -> {actual_inflight} "
                    f"(vllm running={running} waiting={waiting})"
                )
                # The shadow now matches vLLM, so restart the streak. Storing
                # the pre-correction drift would let a single transient
                # mismatch trigger a correction on the very next cycle.
                phantom = 0
            prev_phantom[inst.url] = phantom


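# Illustrative sketch only; the helper name is hypothetical and the loop above
# does not call it. It restates the two-cycle persistence rule from the
# _reconcile_loop docstring as a pure function so the rule can be unit-tested
# without a running vLLM. One choice made here: the remembered drift restarts
# at 0 after a correction.
def _reconcile_step_sketch(shadow: int, actual: int, prev_drift: int) -> tuple[int, int]:
    """Return (new_shadow, drift_to_remember) for one reconcile cycle."""
    drift = shadow - actual
    if drift > 0 and prev_drift > 0:
        # Positive drift seen on two consecutive cycles: trust vLLM's count.
        return actual, 0
    # First sighting, no drift, or negative drift: leave the shadow alone.
    return shadow, drift

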
def _verify_vllm_patch():
    """Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.

    The patch turns an `assert req_id in self.requests` into a soft warn so
    that engines do not crash on the KV-transfer abort race (see REPORT
    §3.x). If somebody upgrades vLLM without re-applying the patch, the
    assert returns and elastic mode dies under load. Print a loud warning
    so we catch the regression before the first HEAVY request.
    """
    try:
        import inspect
        from vllm.v1.core.sched.scheduler import Scheduler
        src = inspect.getsource(Scheduler)
        if "assert req_id in self.requests" in src:
            print("WARNING: vLLM scheduler still contains the unpatched "
                  "`assert req_id in self.requests` line; expect engine "
                  "death on KV-transfer abort race. Apply "
                  "patches/0001-fix-kv-transfer-abort-race.patch.")
        else:
            print("vLLM patch self-check: kv-transfer-abort assert is patched.")
    except Exception as exc:
        print(f"vLLM patch self-check skipped: {exc!r}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    global is_pd_sep
    app.state.ready = asyncio.Event()

    _verify_vllm_patch()

    reconcile_task = asyncio.create_task(_reconcile_loop())

    if global_args.combined:
        is_pd_sep = False
        bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
        for i, url in enumerate(global_args.combined):
            bp = bp_list[i] if i < len(bp_list) else None
            combined_instances.append(InstanceState(url, bp))

        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        policy = getattr(global_args, 'policy', 'linear')
        # Mooncake-based modes still need bootstrap discovery; NIXL uses
        # its own UCX side-channel and doesn't go through our proxy
        # bootstrap path. With --connector-type=nixl, v3 also skips bootstrap.
        needs_bootstrap = (
            global_args.offload
            or (policy in ("unified_v2", "unified_v3", "unified_kv_both")
                and getattr(global_args, 'connector_type', 'mooncake') == 'mooncake')
        )
        if needs_bootstrap and bp_list:
            await init_prefill_bootstrap(combined_instances, app.state.ready)
        elif needs_bootstrap and not bp_list:
            raise RuntimeError(
                f"--policy {policy} requires --bootstrap-ports for KV transfer; "
                "got empty bootstrap list."
            )
        else:
            app.state.ready.set()

        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")
    else:
        is_pd_sep = True
        for url, bp in global_args.prefill:
            prefill_instances.append(InstanceState(url, bp))
        for url in global_args.decode:
            decode_instances.append(InstanceState(url))
        await init_prefill_bootstrap(prefill_instances, app.state.ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")

    yield
    reconcile_task.cancel()
    try:
        await reconcile_task
    except asyncio.CancelledError:
        pass
    for inst in combined_instances + prefill_instances + decode_instances:
        await inst.client.aclose()


app = FastAPI(lifespan=lifespan)


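# Illustrative sketch only; the helper name is hypothetical and lifespan() does
# not call it. It restates the combined-mode bootstrap gate as a pure function:
# bootstrap discovery is needed for --offload, or for a kv_both policy when the
# connector is Mooncake. With Nixl the same policies skip it.
def _needs_bootstrap_sketch(policy: str, connector_type: str = "mooncake",
                            offload: bool = False) -> bool:
    kv_transfer_policies = ("unified_v2", "unified_v3", "unified_kv_both")
    return offload or (policy in kv_transfer_policies
                       and connector_type == "mooncake")

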
@app.post("/v1/completions")
async def handle_completions(request: Request):
    return await _handle(request, "/v1/completions")


@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    return await _handle(request, "/v1/chat/completions")


async def _handle(request: Request, api: str):
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    req_data = await request.json()
    incoming_rid = request.headers.get("X-Request-Id")
    request_id = incoming_rid or str(uuid.uuid4())
    prompt = req_data.get("prompt")
    token_ids = prompt if isinstance(prompt, list) else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    else:
        return await _handle_combined(api, req_data, token_ids,
                                      input_length, session_id, headers)


async def _handle_local_request(api, req_data, headers, token_ids, input_length,
                                chosen: InstanceState, estimated_new: int,
                                breakdown: dict):
    breakdown.setdefault("route_class", "LOCAL")
    breakdown.setdefault("routed_to", chosen.url)
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        prefill_done = False
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                                sse_buffer, chunk)
                            output_token_ids.extend(new_output_ids)
                            if not prefill_done:
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                breakdown["t_first_token_unix"] = _time.time()
                                prefill_done = True
                            yield chunk
                        chosen.record_prefix(
                            _realized_tokens(token_ids, output_token_ids))
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")


async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Route a /v1/* request among combined (PD-colocated) instances.

    --policy options:
      linear:  cache_hit-aware load score + sticky session affinity.
      lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity.
      unified: hybrid: stick to affinity instance when cache_ratio > 0.5
               and it is not overloaded; otherwise fall back to LMetric
               with a multi-key tie-breaker.

    PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and
    commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants
    both regressed E2E tail). Re-enabling requires a new transfer mechanism.
    """
    policy = getattr(global_args, 'policy', 'linear')
    t_decision_unix = _time.time()
    request_id = headers.get("X-Request-Id", "")
    breakdown: dict = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": policy,
    }

    pre_decision_workers = snapshot_workers(
        combined_instances, token_ids, input_length)

    pd_sep_v2: tuple[InstanceState, int] | None = None
    if policy == "lmetric":
        chosen, best_idx = pick_instance_lmetric(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy == "load_only":
        chosen, best_idx = pick_instance_load_only(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy == "sticky":
        chosen, best_idx = pick_instance_sticky(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy in ("unified", "unified_kv_both", "unified_nixl_both"):
        # unified_kv_both: same picker as `unified`, but the vLLMs are
        # launched in kv_role=kv_both with MooncakeConnector. Use this
        # as an isolation control for `unified_v2` so the v2-vs-v1 gap
        # reflects only the PD-sep branch, not the kv_both always-on
        # overhead.
        # unified_nixl_both: identical to unified_kv_both but with
        # NixlConnector at the vLLM layer. Used to attribute the
        # kv_both overhead to either Mooncake-specific code or a
        # generic v1-connector cost.
        chosen, best_idx, decision = pick_instance_unified_hybrid(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            session_affinity_combined[session_id] = best_idx
    elif policy == "unified_v2":
        chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            session_affinity_combined[session_id] = best_idx
    elif policy == "unified_v3":
        # v3: prefill on affinity (cache reuse), decode migrated to a
        # low-load target. KV flows prefill_host → decode_target.
        # Reuses _handle_combined_pd_sep_v2 with src=prefill_host,
        # dst=decode_target (the handler is direction-agnostic).
        chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v3(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            if pd_sep_v2 is not None and SETTINGS.v3_rotate_affinity:
                # Migration + rotation: redirect next turn to decode_target,
                # assuming KV will live there. (Empirically wrong; see
                # cache_miss_audit. Keep behind a flag.)
                _decode_target_inst, decode_target_idx = pd_sep_v2
                session_affinity_combined[session_id] = decode_target_idx
            else:
                # No rotation: keep affinity on prefill_host (where the prefix
                # cache lives). This is the empirically correct choice.
                session_affinity_combined[session_id] = best_idx
    else:  # linear (default)
        chosen, best_idx = pick_instance(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)

    chosen_snap = pre_decision_workers[best_idx]
    cache_hit = chosen_snap["cache_hit"]
    estimated_new = chosen_snap["new_prefill"]
    breakdown.update({
        "cache_hit": cache_hit,
        "estimated_new_tokens": estimated_new,
        "route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2",
        "routed_to": chosen.url,
        "chosen_idx": best_idx,
        "candidate_scores": pre_decision_workers,
        "chosen_score_linear": chosen_snap["score_linear"],
        "chosen_score_lmetric": chosen_snap["score_lmetric"],
    })

    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": policy,
        "chosen_idx": best_idx,
        "v2_pd_sep": pd_sep_v2 is not None,
        "workers": pre_decision_workers,
    })

    if pd_sep_v2 is not None:
        # Handler contract: first arg = prefill source (does same-worker
        # prefill with do_remote_decode=True, max_tokens=1), second arg =
        # decode target (does do_remote_prefill=True, pulls KV via
        # Mooncake, decodes).
        #
        # v2 contract: pd_sep_v2 = (src_inst, src_idx); chosen = decode
        #   → src does prefill (it has more cache), chosen decodes.
        # v3 contract: chosen = prefill_host (affinity, has cache);
        #   pd_sep_v2 = (decode_target_inst, decode_target_idx)
        #   → chosen does prefill (cache reuse), decode_target decodes.
        if policy == "unified_v3":
            decode_target_inst, decode_target_idx = pd_sep_v2
            prefill_inst = chosen
            breakdown["v2_src_url"] = prefill_inst.url
            breakdown["v2_src_idx"] = best_idx
            breakdown["v3_decode_target_url"] = decode_target_inst.url
            breakdown["v3_decode_target_idx"] = decode_target_idx
            return await _handle_combined_pd_sep_v2(
                api, req_data, headers, token_ids, input_length,
                prefill_inst, decode_target_inst, breakdown,
                request_id=request_id)
        else:
            src_inst, src_idx = pd_sep_v2
            breakdown["v2_src_url"] = src_inst.url
            breakdown["v2_src_idx"] = src_idx
            return await _handle_combined_pd_sep_v2(
                api, req_data, headers, token_ids, input_length,
                src_inst, chosen, breakdown,
                request_id=request_id)

    return await _handle_local_request(
        api, req_data, headers, token_ids, input_length,
        chosen, estimated_new, breakdown)


async def _handle_combined_pd_sep_v2(
    api, req_data, headers, token_ids, input_length,
    src: InstanceState, dst: InstanceState, breakdown: dict,
    *, request_id: str,
):
    """Per-request PD-sep among combined instances (unified_v2 / unified_v3 path).

    src does cached prefill (max_tokens=1) and hands its KV to dst through the
    connector selected by SETTINGS.connector_type; dst pulls the KV and
    decodes. Both instances must run in kv_role=kv_both. With Mooncake the
    bootstrap server must also be enabled, because the decode request is
    pre-baked with src's bootstrap address and engine_id. With Nixl the proxy
    forwards the kv_transfer_params found in src's response body instead.

    Patch 6.6: the dst streaming call uses a per-chunk read timeout
    of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails
    the request instead of hanging for 600 s.
    """
    connector = SETTINGS.connector_type
    if connector == "mooncake" and src.bootstrap_port is None:
        raise HTTPException(
            status_code=500,
            detail=(
                "Mooncake PD-sep triggered but src instance "
                f"{src.url} has no bootstrap_port; launch with "
                "kv_role=kv_both and pass --bootstrap-ports"
            ),
        )

    # Reserve load on both endpoints.
    src.ongoing_tokens += input_length
    src.num_requests += 1
    dst.ongoing_tokens += input_length
    dst.num_requests += 1
    src_load_held = True
    dst_load_held = True

    # Build prefill kv_transfer_params per connector.
    prefill_data = req_data.copy()
    if connector == "mooncake":
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": f"xfer-{request_id}",
        }
    else:  # nixl: src just signals it'll produce KV for remote decode
        prefill_data["kv_transfer_params"] = {"do_remote_decode": True}
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    forwarded_params: dict | None = None
    try:
        resp = await src.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        if connector == "nixl":
            # Nixl populates kv_transfer_params in the response body with
            # remote_block_ids / remote_engine_id / remote_host / remote_port.
            # We must read the body BEFORE aclose.
            src_resp_json = resp.json()
            forwarded_params = src_resp_json.get("kv_transfer_params")
            if not forwarded_params or not forwarded_params.get("remote_block_ids"):
                raise HTTPException(
                    status_code=502,
                    detail=f"Nixl src returned no remote_block_ids: {forwarded_params}",
                )
        await resp.aclose()
        src.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        breakdown["error_detail"] = repr(e)[:300]
        _breakdown_log.append(breakdown)
        # Release reservations on failure. Clear load_held flags so the
        # finally block below does not double-decrement (CRITICAL audit #1).
        if src_load_held:
            src.ongoing_tokens -= input_length
            src.num_requests -= 1
            src_load_held = False
        if dst_load_held:
            dst.ongoing_tokens -= input_length
            dst.num_requests -= 1
            dst_load_held = False
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
    finally:
        if src_load_held:
            src.ongoing_tokens -= input_length
            src.num_requests -= 1
            src_load_held = False

    decode_data = req_data.copy()
    if connector == "mooncake":
        parsed = urllib.parse.urlparse(str(src.client.base_url))
        bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
        decode_data["kv_transfer_params"] = {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "remote_bootstrap_addr": bootstrap_addr,
            "remote_engine_id": src.engine_id.get(0, ""),
            "transfer_id": f"xfer-{request_id}",
        }
    else:  # nixl: forward what src returned
        decode_data["kv_transfer_params"] = forwarded_params

    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    xfer_timeout = httpx.Timeout(
        connect=10.0,
        read=SETTINGS.pd_sep_xfer_timeout_s,
        write=10.0,
        pool=10.0,
    )

    async def generate():
        nonlocal dst_load_held
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with dst.client.stream(
                "POST", api, json=decode_data, headers=headers,
                timeout=xfer_timeout,
            ) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
                dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            if dst_load_held:
                dst.ongoing_tokens -= input_length
                dst.num_requests -= 1
                dst_load_held = False
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")


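# Illustrative sketch only; the helper name is hypothetical and the handlers do
# not call it. It isolates the decode-side handshake for the two connector
# types so the difference is easy to see and test: Mooncake gets pre-baked
# parameters built from src's bootstrap address and engine_id, while Nixl gets
# a copy of whatever src returned. Field names are the ones used in this file.
def _decode_kv_params_sketch(connector: str, request_id: str, *,
                             bootstrap_addr: str | None = None,
                             engine_id: str | None = None,
                             forwarded: dict | None = None) -> dict:
    if connector == "mooncake":
        if bootstrap_addr is None:
            raise ValueError("mooncake needs the src bootstrap address")
        return {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "remote_bootstrap_addr": bootstrap_addr,
            "remote_engine_id": engine_id or "",
            "transfer_id": f"xfer-{request_id}",
        }
    if connector == "nixl":
        if not forwarded or not forwarded.get("remote_block_ids"):
            raise ValueError("nixl needs kv_transfer_params from the src response")
        return dict(forwarded)  # shallow copy; the caller's dict is left alone
    raise ValueError(f"unknown connector: {connector!r}")

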
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling."""
    t_decision_unix = _time.time()
    breakdown = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": "pd_sep",
    }

    pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
    pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)

    p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
                                  input_length, session_affinity_prefill)
    d_idx = min(range(len(decode_instances)),
                key=lambda i: decode_instances[i].ongoing_tokens)
    d_inst = decode_instances[d_idx]
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    breakdown["candidate_scores_prefill"] = pre_decision_p
    breakdown["candidate_scores_decode"] = pre_decision_d
    breakdown["chosen_p_idx"] = p_idx
    breakdown["chosen_d_idx"] = d_idx

    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": "pd_sep",
        "chosen_p_idx": p_idx,
        "chosen_d_idx": d_idx,
        "workers_prefill": pre_decision_p,
        "workers_decode": pre_decision_d,
    })

    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    p_inst.num_requests += 1
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()

    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
    finally:
        p_inst.ongoing_tokens -= input_length
        p_inst.num_requests -= 1

    # Send decode
    d_inst.ongoing_tokens += input_length
    d_inst.num_requests += 1
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }

    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    async def generate():
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
                d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")


@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis."""
    return _breakdown_log


@app.get("/worker_state")
async def get_worker_state():
    """Return per-decision worker-state snapshot log (one entry per route decision)."""
    return _worker_state_log


@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Return current per-worker state snapshot without recording it."""
    if combined_instances:
        return {
            "t_unix": _time.time(),
            "mode": "combined",
            "workers": snapshot_workers(combined_instances),
        }
    return {
        "t_unix": _time.time(),
        "mode": "pd_sep",
        "workers_prefill": snapshot_workers(prefill_instances),
        "workers_decode": snapshot_workers(decode_instances),
    }


@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging."""
    if combined_instances:
        tagged = [("combined", inst) for inst in combined_instances]
    else:
        # PD-sep mode: report each instance under its actual role instead of
        # labelling everything "combined".
        tagged = ([("prefill", inst) for inst in prefill_instances]
                  + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in tagged]


def parse_args():
|
||
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
|
||
p.add_argument("--port", type=int, default=8000)
|
||
p.add_argument("--host", type=str, default="0.0.0.0")
|
||
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
|
||
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
|
||
help="PD-Sep prefill: URL [bootstrap_port]")
|
||
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
|
||
help="PD-Sep decode: URL")
|
||
p.add_argument("--heavy-threshold", type=int, default=20000,
|
||
help="New tokens threshold for HEAVY classification (adaptive offload)")
|
||
p.add_argument("--offload", action="store_true",
|
||
help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
|
||
p.add_argument("--bootstrap-ports", type=str, default="",
|
||
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
|
||
p.add_argument("--policy", type=str, default="linear",
|
||
choices=["linear", "lmetric", "load_only", "sticky",
|
||
"unified", "unified_kv_both",
|
||
"unified_nixl_both", "unified_v2",
|
||
"unified_v3"],
|
||
help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
|
||
"load_only (B3 control: pure min-num_requests), "
|
||
"sticky (B3 control: hard session affinity), "
|
||
"unified (hybrid affinity + LMetric fallback), "
|
||
"unified_kv_both (unified picker on kv_both Mooncake "
|
||
"vLLMs; isolation control for unified_v2), "
|
||
"unified_nixl_both (same as unified_kv_both but using "
|
||
"NixlConnector instead of MooncakeConnector; isolates "
|
||
"connector implementation from policy effect), "
|
||
"or unified_v2 (unified + selective per-request PD-sep "
|
||
"via Mooncake; requires --bootstrap-ports and "
|
||
"kv_role=kv_both vLLM launch)")
|
||
p.add_argument("--v3-rotate-affinity", type=int, default=1, choices=[0,1],
|
||
help="unified_v3 only: 1 = rotate session affinity to decode_target "
|
||
"after migration (original behavior, empirically loses prefix cache); "
|
||
"0 = keep affinity on prefill_host so next turn hits its cache.")
|
||
p.add_argument("--connector-type", type=str, default="mooncake",
|
||
choices=["mooncake", "nixl"],
|
||
help="PD-sep handshake protocol. 'mooncake' uses pre-baked engine_id"
|
||
" + bootstrap_addr (requires --bootstrap-ports). 'nixl' uses"
|
||
" response-forward (src returns kv_transfer_params, proxy"
|
||
" relays to dst; Nixl/UCX auto-picks NVLink intra-node).")
|
||
p.add_argument("--v3-prefer-cache-target", type=int, default=1, choices=[0,1],
|
||
help="Mechanism B: unified_v3 picks decode_target with the most"
|
||
" prefix cache among low-load candidates (default 1). Set 0"
|
||
" to fall back to pure-load tie-break (cache-blind).")
|
||
p.add_argument("--overload-factor", type=float, default=2.0,
|
||
help="Break session affinity when instance load > factor * avg")
|
||
    # The four flags below are accepted for bench.sh backward compatibility but
    # have no effect after the PD-sep offload path was retired (REPORT §3.9,
    # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
    # scripts/legacy/*.sh which still pass them through.
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--offload-mode", type=str, default="cached_prefill",
                   choices=["direct_read", "cached_prefill"],
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--cache-gate-ratio", type=float, default=0.0,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--decode-iteration-s", type=float, default=0.05,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    args = p.parse_args()

    # Normalize --prefill entries to (url, bootstrap_port) tuples. The port is
    # optional; a missing port or the literal "none" becomes None.
    args.prefill = []
    if args.prefill_raw:
        for entry in args.prefill_raw:
            url = entry[0]
            bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
            args.prefill.append((url, bp))
    # Only the URL of each --decode entry is used.
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    return args


if __name__ == "__main__":
    global_args = parse_args()
    SETTINGS.heavy_threshold = global_args.heavy_threshold
    SETTINGS.overload_factor = global_args.overload_factor
    # The next three settings back deprecated flags and no longer affect routing.
    SETTINGS.max_offload_inflight = global_args.max_offload_inflight
    SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
    SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
    SETTINGS.v3_rotate_affinity = bool(getattr(global_args, 'v3_rotate_affinity', 0))
    SETTINGS.connector_type = getattr(global_args, 'connector_type', 'mooncake')
    SETTINGS.v3_prefer_cache_target = bool(getattr(global_args, 'v3_prefer_cache_target', 1))
    print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s v3_rotate_affinity=%s "
          "connector_type=%s v3_prefer_cache_target=%s" % (
              SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
              getattr(global_args, 'offload', False),
              SETTINGS.v3_rotate_affinity,
              SETTINGS.connector_type,
              SETTINGS.v3_prefer_cache_target))
    uvicorn.run(app, host=global_args.host, port=global_args.port)
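The `--prefill` / `--decode` post-processing in `parse_args` can be exercised on its own. The sketch below is not part of the file. It assumes the two flags are declared with `nargs="+"` and `action="append"` (their real declarations are not shown in this part of the listing), and `parse_prefill_args` is a hypothetical helper name used only for illustration.

```python
import argparse


def parse_prefill_args(argv):
    """Hypothetical stand-alone copy of the --prefill/--decode normalization."""
    p = argparse.ArgumentParser()
    # Assumption: each flag may repeat and takes one or more values per use,
    # so argparse collects a list of lists under prefill_raw / decode_raw.
    p.add_argument("--prefill", dest="prefill_raw", nargs="+", action="append")
    p.add_argument("--decode", dest="decode_raw", nargs="+", action="append")
    args = p.parse_args(argv)

    prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        # A missing second value or the literal "none" means no bootstrap port.
        bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
        prefill.append((url, bp))
    # Only the first value of each --decode entry (the URL) is kept.
    decode = [e[0] for e in (args.decode_raw or [])]
    return prefill, decode


if __name__ == "__main__":
    prefill, decode = parse_prefill_args([
        "--prefill", "http://a:8100", "8998",
        "--prefill", "http://b:8100", "none",
        "--decode", "http://c:8200",
    ])
    print(prefill)
    print(decode)
```

With the Nixl connector the bootstrap port is not needed, so passing `none` (or omitting the second value) yields `None` for that entry, which matches the "skip the bootstrap-ports plumbing" behavior described in the commit notes.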