Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang 645b067dd4 Fix review bugs: PD-sep counter leaks, hardcoded paths, missing deps
Critical:
- cache_aware_proxy: _handle_pd_sep leaked p_inst.num_requests (never
  decremented) and never managed d_inst.num_requests; fix media_type
  from application/json to text/event-stream for SSE stream

High:
- b3_sweep/b3_isolated_policy/b3_analyze: replace hardcoded
  /home/admin/cpfs/wjh/ ROOT with script-relative $(dirname "$0")/..
- b3_analyze: replace hardcoded 8-port WORKER_MAP with dynamic
  generation from BASE_PORT and N_INSTANCES

Medium:
- analyze_breakdown: warn on stderr when records are skipped (was silent)
- deploy_vllm_patches: fail-fast on SSH/SCP errors instead of
  continuing with empty VENV_SITE
- pyproject.toml: declare fastapi and uvicorn as runtime dependencies
- launch_elastic_p2p: kill EngineCore and proxy in trap handler to
  prevent GPU memory leaks on exit
2026-05-26 15:54:55 +08:00

1361 lines
54 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler (proxy).

Supports two deployment modes:
  --combined URL [URL ...]
      PD co-located instances (normal vLLM). Bootstrap discovery for KV
      transfer is only performed for --offload and the unified_v2 /
      unified_kv_both policies.
  --prefill URL BP --decode URL
      PD disaggregated instances (Mooncake KV transfer).

Routing policies for combined mode (--policy):
  linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens,
                    with sticky session affinity.
  lmetric:          score = P_tokens * BS (LMetric, OSDI'26)
                    P_tokens = pending_prefill_tokens + new_uncached_tokens
                    BS       = num_requests (waiting + running)
  load_only:        fewest in-flight requests; ignores cache and sessions.
  sticky:           hard session affinity; a session's first turn goes to
                    the instance with the fewest in-flight requests.
  unified / unified_kv_both / unified_nixl_both:
                    stick to the session's instance when its cache-hit
                    ratio > 0.5 and it is not overloaded; otherwise
                    LMetric with a multi-key tie-breaker.
  unified_v2:       unified + a selective per-request PD-sep trigger
                    (prefill on a cache-rich peer, decode on the chosen
                    instance).

Session affinity is policy-dependent: lmetric and load_only ignore it.
"""
import argparse
import asyncio
import json
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from dataclasses import dataclass
import httpx
# Maximum attempts to open the upstream stream in _handle_local_request.
# Retries happen only on connection errors raised before the first chunk
# has been forwarded to the client.
MAX_STREAM_RETRIES = 3
# Delay between those attempts, in seconds.
RETRY_DELAY_S = 0.5
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Tokens per block in the proxy's shadow prefix-cache model (InstanceState).
BLOCK_SIZE = 512
# Weight of one cache-hit token against one in-flight token in the linear
# routing score (ongoing_tokens + penalty - CACHE_HIT_ALPHA * cache_hit).
CACHE_HIT_ALPHA = 1.0
@dataclass
class Settings:
    """Runtime-tunable routing and cost-model knobs.

    Populated from argparse in __main__. All routing/offload code reads
    from the SETTINGS singleton so that CLI overrides survive even when
    the module is imported as a library (e.g. by tests/) and __main__
    does not run.
    """
    prefill_throughput: float = 7000.0  # tokens/s per GPU (measured on H20)
    rdma_overhead_s: float = 0.1  # legacy floor; v2 uses estimate_transfer_cost
    cache_capacity_blocks: int = 200000  # per-instance LRU cap on shadow cached_blocks
    # Virtual load (tokens) added per active P-role offload by _p_offload_penalty.
    heavy_threshold: int = 20000
    # An instance counts as overloaded above overload_factor x the mean load
    # (ongoing tokens in pick_instance, num_requests in the unified picker).
    overload_factor: float = 2.0
    max_offload_inflight: int = 4
    cache_gate_ratio: float = 0.0
    decode_iteration_s: float = 0.05  # per-request decode iteration cost (H20)
    # --- Patch 6.9: cost-model calibration for unified_v2 ---
    # Throughput when the engine runs in kv_both mode. Lower than the
    # pure-decode 7000 tok/s because kv_both adds always-on overhead
    # (REPORT §3.8 documents ~+16% TPOT vs plain).
    prefill_throughput_kv_both: float = 4000.0
    # Calibrated RDMA transfer cost: base + bandwidth term.
    # Floor from isolated test ≈ 0.3 s (handshake + scheduler step).
    # Bandwidth term reflects realized effective throughput, not
    # theoretical 25 GB/s — production p50 = 1.1 s for ~3 GB ≈ 2.7 GB/s
    # effective on the contended kv_both path. v2 uses this lookup
    # rather than the constant rdma_overhead_s.
    rdma_base_overhead_s: float = 0.3
    rdma_effective_gb_per_s: float = 2.7
    # Qwen3-Coder-30B-A3B (bf16, 48 layers × 4 KV heads × 128 head_dim × 2):
    # 2 × 48 × 4 × 128 × 2 = 98304 bytes per token.
    kv_bytes_per_token: int = 98304
    # --- unified_v2 gating knobs (relaxed in v2.1 after the v1 0.2% trigger rate) ---
    # B2 microbench shows TPOT idx 1.9x already at new_tokens=8k and TTFT
    # idx ~12x; the previous 16k threshold was too conservative and
    # rejected 88.7% of candidates (window_1_results/v2_breakdown).
    pd_sep_min_new_tokens: int = 8000
    # Intended as "any in-flight work on chosen counts". NOTE(review): gate 2
    # in pick_instance_unified_v2 checks ongoing_decode_tokens == 0 and does
    # not read this knob.
    pd_sep_min_decodes_protected: int = 1
    # Minimum src cache hit (tokens) before PD-sep is considered; about 8
    # shadow blocks at BLOCK_SIZE=512 (was 8000).
    pd_sep_min_src_cache_tokens: int = 4000
    # Minimum extra cache (tokens) src must hold beyond chosen; about 4
    # shadow blocks at BLOCK_SIZE=512 (was 4000).
    pd_sep_min_extra_cache_tokens: int = 2000
    pd_sep_margin_s: float = 0.2  # require cost gap > 0.2 s before migrating
    # Patch 6.6: per-request KV-xfer wall-clock timeout (proxy side).
    pd_sep_xfer_timeout_s: float = 60.0


# Process-wide settings singleton read by all routing / cost-model code.
SETTINGS = Settings()
def estimate_transfer_cost(transfer_bytes: int) -> float:
    """Return the modeled RDMA KV-transfer time, in seconds, for `transfer_bytes`.

    cost = rdma_base_overhead_s + transfer_bytes / (rdma_effective_gb_per_s * 2**30)

    Note that the bandwidth knob is applied in GiB/s (1024**3 bytes).
    This replaces the legacy constant rdma_overhead_s. Calibration sources:
      - Floor: isolated-test ~0.3 s for a few-block PUSH
        (scripts/test_direct_read.py).
      - Bandwidth: outputs/contention_16s_elastic/breakdown.json shows
        decode_sent->first_token p50 = 1.1 s for ~3 GB transfers, i.e.
        ~2.7 GB/s effective on the contended kv_both path.
    The p90 of that run is 6.7 s (D-side block reservation + scheduler step
    delays). The median is used deliberately: a pessimistic model would
    suppress every PD-sep trigger, and underestimation is absorbed by
    pd_sep_margin_s.
    """
    bytes_per_second = SETTINGS.rdma_effective_gb_per_s * 1024 ** 3
    return SETTINGS.rdma_base_overhead_s + transfer_bytes / bytes_per_second
def estimate_same_worker_interference_s(
    new_tokens: int,
    num_decodes: int,
) -> float:
    """Estimate extra decode latency (seconds) caused by a co-located prefill.

    A prefill of `new_tokens` tokens steals a `penalty` fraction of each of
    `num_decodes` co-located decodes' capacity for the prefill's duration:

        cost = (new_tokens / prefill_throughput_kv_both) * penalty * num_decodes

    Penalty step function (half-open buckets, derived from the B2 microbench
    in analysis/characterization/window_1_results.md):
        new_tokens <  4000          -> 0.20  (chunked prefill leaves room)
        4000  <= new_tokens < 16000 -> 0.50
        16000 <= new_tokens < 32000 -> 0.80
        new_tokens >= 32000         -> 0.95  (decodes nearly fully blocked)

    Returns 0.0 when `num_decodes <= 0`.
    """
    if num_decodes <= 0:
        return 0.0
    penalty = 0.95
    for upper_bound, bucket_penalty in ((4000, 0.2), (16000, 0.5), (32000, 0.8)):
        if new_tokens < upper_bound:
            penalty = bucket_penalty
            break
    prefill_dur_s = new_tokens / SETTINGS.prefill_throughput_kv_both
    return prefill_dur_s * penalty * num_decodes
class InstanceState:
    """Proxy-side shadow of one vLLM backend instance.

    Bundles an httpx client bound to the instance URL, the in-flight load
    counters the routing policies read and update, Mooncake bootstrap info,
    and an LRU approximation of the instance's prefix cache.

    The shadow prefix cache (`cached_blocks`) is keyed by *chained* hashes
    of consecutive BLOCK_SIZE-token blocks: each block's key also covers
    every block before it. A block therefore only counts as cached when the
    whole prefix leading up to it was recorded, matching prefix-cache
    semantics. Hashing blocks independently would over-report hits whenever
    an identical block had been seen behind a different prefix.
    """

    def __init__(self, url: str, bootstrap_port: int | None = None):
        self.url = url
        # Mooncake bootstrap server port; None when this instance is not
        # bootstrapped (init_prefill_bootstrap skips it).
        self.bootstrap_port = bootstrap_port
        self.client = httpx.AsyncClient(
            timeout=None, base_url=url,
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
        )
        self.ongoing_tokens = 0
        self.ongoing_decode_tokens = 0  # subset: tokens in decode phase
        self.pending_prefill_tokens = 0  # tokens for requests still in prefill
        self.num_requests = 0  # total in-flight requests (waiting + running)
        self.active_p_offloads = 0  # number of HEAVY prefills this instance is doing for others
        # dp_rank -> engine_id, filled by init_prefill_bootstrap.
        self.engine_id: dict[int, str] = {}
        self.dp_size = 1
        # OrderedDict acts as an LRU keyed by chained block hash; value is unused.
        self.cached_blocks: OrderedDict[int, None] = OrderedDict()

    @staticmethod
    def _block_hashes(token_ids: list[int]):
        """Yield the chained hash of each full BLOCK_SIZE block, first block first.

        A trailing partial block is ignored.
        """
        parent = None
        for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            parent = hash((parent, tuple(token_ids[start:start + BLOCK_SIZE])))
            yield parent

    def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
        """Return how many leading tokens of `token_ids` the shadow cache holds.

        Walks full blocks from the start and stops at the first block that
        is not cached. Each hit refreshes that block's LRU position. The
        result is a multiple of BLOCK_SIZE; it is 0 for None, empty or
        sub-block inputs.
        """
        if not token_ids or len(token_ids) < BLOCK_SIZE:
            return 0
        hit = 0
        for bh in self._block_hashes(token_ids):
            if bh not in self.cached_blocks:
                break
            self.cached_blocks.move_to_end(bh)  # LRU touch on hit
            hit += BLOCK_SIZE
        return hit

    def record_prefix(self, token_ids: list[int] | None):
        """Mark every full block of `token_ids` as cached on this instance.

        Known blocks are refreshed in LRU order. New blocks are inserted,
        evicting the least recently used block whenever the shadow cache
        exceeds SETTINGS.cache_capacity_blocks.
        """
        if not token_ids:
            return
        for bh in self._block_hashes(token_ids):
            if bh in self.cached_blocks:
                self.cached_blocks.move_to_end(bh)
            else:
                self.cached_blocks[bh] = None
                if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
                    self.cached_blocks.popitem(last=False)
def _p_offload_penalty(inst: InstanceState) -> int:
    """Return the legacy PD-sep routing penalty (virtual tokens) for `inst`.

    Each active P-role offload the instance is serving for others counts as
    SETTINGS.heavy_threshold tokens of extra load. A non-positive offload
    counter means no penalty.
    """
    offloads = inst.active_p_offloads
    return offloads * SETTINGS.heavy_threshold if offloads > 0 else 0
def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Capture per-worker routing state at decision time.

    For each instance, records the raw routing counters plus the score the
    `linear` and `lmetric` formulas would give a request of `input_length`
    tokens (prompt `token_ids`) dispatched right now. Cheap enough to call
    on every request; B3 hot-spot analysis depends on it per decision.

    Note: estimate_cache_hit refreshes LRU recency on hits, so a snapshot
    also touches each instance's shadow cache ordering.
    """
    def _row(idx: int, inst: InstanceState) -> dict:
        hit = inst.estimate_cache_hit(token_ids) if token_ids else 0
        uncached = max(0, input_length - hit)
        linear_score = (inst.ongoing_tokens
                        + _p_offload_penalty(inst)
                        - CACHE_HIT_ALPHA * hit)
        lmetric_score = (inst.pending_prefill_tokens + uncached) * inst.num_requests
        return {
            "idx": idx,
            "url": inst.url,
            "ongoing_tokens": inst.ongoing_tokens,
            "ongoing_decode_tokens": inst.ongoing_decode_tokens,
            "pending_prefill_tokens": inst.pending_prefill_tokens,
            "num_requests": inst.num_requests,
            "active_p_offloads": inst.active_p_offloads,
            "cached_blocks": len(inst.cached_blocks),
            "cache_hit": hit,
            "new_prefill": uncached,
            "score_linear": linear_score,
            "score_lmetric": lmetric_score,
        }

    return [_row(idx, inst) for idx, inst in enumerate(instances)]
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Linear policy: session-sticky routing with a load-aware override.

    Later turns of a session go back to the pinned instance unless it is
    overloaded (ongoing_tokens > overload_factor x mean load, with the mean
    floored at 1.0) or is doing P-role offloads. Otherwise, and on a
    session's first turn, pick the instance with the lowest
    ongoing_tokens + offload penalty - CACHE_HIT_ALPHA * cache_hit (lowest
    index wins ties) and re-pin the session to it. P-role offloaders get a
    large penalty to steer WARM/MEDIUM traffic away. `input_length` is
    accepted for signature compatibility and unused.
    """
    avg_load = max(sum(inst.ongoing_tokens for inst in instances) / len(instances), 1.0)

    pinned = affinity.get(session_id) if session_id else None
    if pinned is not None and pinned < len(instances):
        sticky = instances[pinned]
        if (sticky.ongoing_tokens <= avg_load * SETTINGS.overload_factor
                and sticky.active_p_offloads == 0):
            return sticky, pinned

    def _linear_score(inst: InstanceState):
        hit = inst.estimate_cache_hit(token_ids)
        return inst.ongoing_tokens + _p_offload_penalty(inst) - CACHE_HIT_ALPHA * hit

    scores = [_linear_score(inst) for inst in instances]
    best_idx = min(range(len(scores)), key=scores.__getitem__)
    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx
def pick_instance_load_only(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int]:
    """Pure load balancing: route to the instance with the fewest in-flight requests.

    Cache hits and session affinity are ignored (the extra arguments exist
    only for signature compatibility); ties go to the lowest index. Used as
    a B3 control to isolate the locality contribution of cache-aware
    policies.
    """
    best_idx, least_loaded = min(enumerate(instances),
                                 key=lambda pair: pair[1].num_requests)
    return least_loaded, best_idx
def pick_instance_sticky(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int]:
    """Hard session affinity: once a session is pinned, it never moves.

    A session's first turn (or a stale pin beyond the instance list) goes
    to the instance with the fewest in-flight requests (lowest index on
    ties), and the session is pinned there. Later turns return to that
    instance regardless of load. Used as a B3 control to isolate the
    hot-spot cost of perfect locality.
    """
    pinned = affinity.get(session_id) if session_id else None
    if pinned is not None and pinned < len(instances):
        return instances[pinned], pinned
    best_idx, _ = min(enumerate(instances), key=lambda pair: pair[1].num_requests)
    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing (OSDI'26): minimize P_tokens x BS.

        P_tokens = pending_prefill_tokens + (input_length - cache_hit)
        BS       = num_requests (current batch size)

    Pure per-request load-based routing. `session_id` and `affinity` are
    accepted for signature compatibility with pick_instance /
    pick_instance_unified_hybrid and ignored. Ties go to the lowest index.
    """
    def _score(inst: InstanceState) -> int:
        uncached = max(0, input_length - inst.estimate_cache_hit(token_ids))
        return (inst.pending_prefill_tokens + uncached) * inst.num_requests

    # (score, index) pairs: equal scores resolve to the lowest index.
    _, best_idx = min(
        ((_score(inst), idx) for idx, inst in enumerate(instances)),
        default=(0, 0),
    )
    return instances[best_idx], best_idx
# Advanced on every LMetric-fallback tie; selects among the tied instances.
_unified_fallback_rr_counter = 0


def pick_instance_unified_hybrid(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict]:
    """Hybrid routing: high-cache session affinity, else LMetric with tie-breaks.

    The session's affinity instance is used when both hold:
      - its cache_hit / input_length > 0.5
      - its num_requests <= avg_num_requests * SETTINGS.overload_factor
        (average floored at 1.0)
    Otherwise the fallback ranks instances by:
      primary:    P_tokens * BS (LMetric)
      secondary:  new uncached tokens (prefer the instance with most cache)
      tertiary:   num_requests (prefer least loaded)
      quaternary: rotating counter among exact ties (avoids degenerate
                  inst-0 pinning when BS=0 everywhere)
    This function does not write to `affinity`.

    Returns (chosen, idx, decision_dict). decision_dict carries the
    review #7 breakdown fields so the caller can merge them verbatim.
    """
    global _unified_fallback_rr_counter
    n = len(instances)
    avg_reqs = max(sum(inst.num_requests for inst in instances) / n, 1.0)
    decision: dict = {
        "decision": "lmetric_fallback",
        "affinity_idx": None,
        "chosen_idx": None,
        "affinity_cache_hit": None,
        "affinity_cache_ratio": None,
        "affinity_num_requests": None,
        "avg_num_requests": avg_reqs,
        "fallback_score": None,
        "tie_break_used": False,
    }

    pinned = affinity.get(session_id) if session_id else None
    if pinned is not None and pinned < n:
        pinned_inst = instances[pinned]
        pinned_hit = pinned_inst.estimate_cache_hit(token_ids)
        pinned_ratio = pinned_hit / max(input_length, 1)
        decision.update(
            affinity_idx=pinned,
            affinity_cache_hit=pinned_hit,
            affinity_cache_ratio=pinned_ratio,
            affinity_num_requests=pinned_inst.num_requests,
        )
        if (pinned_ratio > 0.5
                and pinned_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
            decision["decision"] = "affinity"
            decision["chosen_idx"] = pinned
            return pinned_inst, pinned, decision

    # (lmetric_score, uncached, batch_size, index) per instance.
    ranked: list[tuple[int, int, int, int]] = []
    for idx, inst in enumerate(instances):
        uncached = max(0, input_length - inst.estimate_cache_hit(token_ids))
        batch = inst.num_requests
        ranked.append(((inst.pending_prefill_tokens + uncached) * batch,
                       uncached, batch, idx))

    best_key = min(entry[:3] for entry in ranked)
    contenders = [entry for entry in ranked if entry[:3] == best_key]
    if len(contenders) == 1:
        winner = contenders[0]
    else:
        decision["tie_break_used"] = True
        _unified_fallback_rr_counter += 1
        winner = contenders[_unified_fallback_rr_counter % len(contenders)]

    chosen_idx = winner[3]
    decision["fallback_score"] = winner[0]
    decision["chosen_idx"] = chosen_idx
    return instances[chosen_idx], chosen_idx, decision
def pick_instance_unified_v2(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict, tuple[InstanceState, int] | None]:
    """unified_v2 = unified hybrid + selective per-request PD-sep trigger.

    Stage 1 picks `chosen` exactly as `pick_instance_unified_hybrid`.
    Stage 2 asks whether another instance has materially more cache for
    this request. If so, would prefilling on that instance and transferring
    the KV to `chosen` for decode be cheaper than doing everything on
    `chosen`?

    The cost model compares two scenarios in seconds of decode disruption:
      local:  same-worker prefill on chosen of (input - chosen.cache_hit)
              tokens interferes with chosen's in-flight requests.
      pd-sep: same-worker prefill on src of (input - src.cache_hit) tokens
              (smaller, because src has more cache) interferes with src's
              in-flight requests, plus the RDMA transfer of src.cache_hit
              tokens' KV to chosen.
    `num_requests` (waiting + running, including requests still in
    prefill) stands in for "co-located decodes", because the shadow state
    keeps no separate decode count. This overestimates the decode
    population.

    PD-sep is chosen only when cost_local > cost_migrate + pd_sep_margin_s
    and every hard gate passes:
      1. new_local >= pd_sep_min_new_tokens
      2. chosen has live decode work (ongoing_decode_tokens != 0)
      3. best alternative source holds >= pd_sep_min_src_cache_tokens
      4. it holds >= pd_sep_min_extra_cache_tokens more than chosen
      5. an alternative source with a non-zero cache hit actually exists.
         With gates 3/4 configured <= 0, index -1 would otherwise select
         the last instance, possibly `chosen` itself.
    Every exit records a human-readable `v2_reason` in the decision.

    Returns (chosen, chosen_idx, decision, pd_sep). When pd_sep is None
    the handler should route locally on `chosen`. When pd_sep is
    (src_inst, src_idx) the handler should prefill on src and decode on
    chosen via Mooncake.
    """
    chosen, chosen_idx, decision = pick_instance_unified_hybrid(
        instances, token_ids, session_id, input_length, affinity)
    decision["v2_pd_sep"] = False
    decision["v2_decision"] = "local"
    decision["v2_reason"] = None
    if not token_ids:
        decision["v2_reason"] = "no_token_ids"
        return chosen, chosen_idx, decision, None
    chosen_cache_hit = chosen.estimate_cache_hit(token_ids)
    new_local = max(0, input_length - chosen_cache_hit)
    # Hard gate 1: prefill must be large enough that interference
    # outweighs the fixed RDMA setup cost.
    if new_local < SETTINGS.pd_sep_min_new_tokens:
        decision["v2_reason"] = f"new_local_below_threshold ({new_local} < {SETTINGS.pd_sep_min_new_tokens})"
        return chosen, chosen_idx, decision, None
    # Hard gate 2: chosen must have live decoding work to protect.
    # num_requests also counts requests still in prefill, so adding a
    # prefill to a chosen that is only prefilling disrupts no decode. Skip
    # iff no decode is currently happening on chosen.
    if chosen.ongoing_decode_tokens == 0:
        decision["v2_reason"] = (
            f"chosen_no_active_decode "
            f"(num_req={chosen.num_requests} decode_tok={chosen.ongoing_decode_tokens})"
        )
        return chosen, chosen_idx, decision, None
    # Find the best alternative cache source (strictly positive hits only).
    best_src_idx, best_src_hit = -1, 0
    for i, inst in enumerate(instances):
        if i == chosen_idx:
            continue
        h = inst.estimate_cache_hit(token_ids)
        if h > best_src_hit:
            best_src_idx, best_src_hit = i, h
    # Hard gate 3: src must hold meaningful cache.
    if best_src_hit < SETTINGS.pd_sep_min_src_cache_tokens:
        decision["v2_reason"] = f"src_cache_below_threshold ({best_src_hit} < {SETTINGS.pd_sep_min_src_cache_tokens})"
        return chosen, chosen_idx, decision, None
    # Hard gate 4: src must hold materially more cache than chosen.
    if best_src_hit - chosen_cache_hit < SETTINGS.pd_sep_min_extra_cache_tokens:
        decision["v2_reason"] = (
            f"src_not_meaningfully_more_cache "
            f"(src={best_src_hit} chosen={chosen_cache_hit})"
        )
        return chosen, chosen_idx, decision, None
    # Hard gate 5: gates 3/4 pass with no real candidate when their
    # thresholds are <= 0. best_src_idx == -1 would then index the last
    # instance (possibly chosen itself), so refuse explicitly.
    if best_src_idx < 0:
        decision["v2_reason"] = "no_alternative_cache_source"
        return chosen, chosen_idx, decision, None
    src = instances[best_src_idx]
    new_src = max(0, input_length - best_src_hit)
    # Cost-benefit in seconds of decode disruption.
    cost_local = estimate_same_worker_interference_s(
        new_local, chosen.num_requests)
    cost_src_interf = estimate_same_worker_interference_s(
        new_src, src.num_requests)
    transfer_bytes = best_src_hit * SETTINGS.kv_bytes_per_token
    cost_xfer = estimate_transfer_cost(transfer_bytes)
    cost_migrate = cost_src_interf + cost_xfer
    decision["v2_chosen_cache_hit"] = chosen_cache_hit
    decision["v2_src_idx"] = best_src_idx
    decision["v2_src_cache_hit"] = best_src_hit
    decision["v2_new_local"] = new_local
    decision["v2_new_src"] = new_src
    decision["v2_cost_local_s"] = cost_local
    decision["v2_cost_src_interf_s"] = cost_src_interf
    decision["v2_cost_xfer_s"] = cost_xfer
    decision["v2_cost_migrate_s"] = cost_migrate
    if cost_local > cost_migrate + SETTINGS.pd_sep_margin_s:
        decision["v2_pd_sep"] = True
        decision["v2_decision"] = "pd_sep"
        decision["v2_reason"] = (
            f"local_cost {cost_local:.2f}s > migrate_cost {cost_migrate:.2f}s "
            f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
        )
        return chosen, chosen_idx, decision, (src, best_src_idx)
    decision["v2_reason"] = (
        f"local_cost {cost_local:.2f}s <= migrate_cost {cost_migrate:.2f}s "
        f"+ margin {SETTINGS.pd_sep_margin_s:.2f}s"
    )
    return chosen, chosen_idx, decision, None
def _extract_output_token_ids_from_sse(
buffer: str,
chunk: bytes,
) -> tuple[str, list[int]]:
"""Extract vLLM streaming token_ids while preserving the raw stream."""
buffer += chunk.decode("utf-8", errors="ignore")
complete = buffer.endswith("\n") or buffer.endswith("\r")
lines = buffer.splitlines()
if complete:
buffer = ""
elif lines:
buffer = lines.pop()
else:
return buffer, []
output_ids: list[int] = []
for line in lines:
line = line.strip()
if not line.startswith("data:"):
continue
data = line[5:].strip()
if not data or data == "[DONE]":
continue
try:
payload = json.loads(data)
except json.JSONDecodeError:
continue
choices = payload.get("choices", [])
for choice in choices:
token_ids = choice.get("token_ids")
if isinstance(token_ids, list):
output_ids.extend(
int(t) for t in token_ids if isinstance(t, int)
)
return buffer, output_ids
def _realized_tokens(
prompt_token_ids: list[int] | None,
output_token_ids: list[int],
) -> list[int] | None:
if prompt_token_ids is None:
return None
if not output_token_ids:
return prompt_token_ids
return prompt_token_ids + output_token_ids
# Parsed CLI namespace read by lifespan() and the request handlers. It is
# None here; presumably assigned in __main__ (not visible in this view —
# TODO confirm).
global_args = None
# Backend shadows built by lifespan(): combined mode fills
# combined_instances; PD-sep mode fills prefill_instances and
# decode_instances.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias used by /stats etc.
session_affinity = session_affinity_combined
# True when started with --prefill/--decode (PD-sep); set by lifespan().
is_pd_sep = False
# Per-request breakdown dicts appended by _handle_local_request when a
# request finishes. Nothing in this view trims it.
_breakdown_log: list[dict] = []
# Per-decision worker-state records. Not written in this view —
# presumably filled from snapshot_workers() output (TODO confirm).
_worker_state_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Discover engine ids for every instance that has a bootstrap port.

    For each such instance, wait until its /health endpoint returns a 2xx
    status and its bootstrap server's /query endpoint answers successfully.
    HTTP and transport errors on either call are retried every second,
    with a progress line on the first failure and every 30th attempt. The
    /query JSON maps dp_rank -> {"engine_id": ...}; it fills
    `inst.engine_id` and sets `inst.dp_size` to the number of ranks.
    Instances without a bootstrap port are skipped. `ready` is set once
    every instance is done.

    Malformed /query payloads (non-JSON, missing "engine_id", non-integer
    ranks) are not retried and propagate to the caller.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue
        host = urllib.parse.urlparse(str(inst.client.base_url)).hostname
        query_url = f"http://{host}:{inst.bootstrap_port}/query"
        attempt = 0
        while True:
            attempt += 1
            try:
                # A non-2xx /health (e.g. engine error) is not "ready".
                health = await inst.client.get("/health")
                health.raise_for_status()
                # The bootstrap server may come up slightly after /health.
                resp = await inst.client.get(query_url)
                resp.raise_for_status()
                break
            except httpx.HTTPError as exc:
                if attempt == 1 or attempt % 30 == 0:
                    print(f"Waiting for {inst.url} (bootstrap {query_url}): {exc!r}")
                await asyncio.sleep(1)
        data = resp.json()
        for dp_rank, dp_entry in data.items():
            inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
        inst.dp_size = len(data)
        print(f"Inited {inst.url} engine_ids={inst.engine_id}")
    ready.set()
async def _fetch_vllm_inflight(inst: "InstanceState") -> tuple[int, int] | None:
"""Read vLLM's truth: (num_running, num_waiting). Returns None on failure."""
try:
resp = await asyncio.wait_for(inst.client.get("/metrics"), timeout=5.0)
if resp.status_code != 200:
return None
text = resp.text
except Exception:
return None
running = 0
waiting = 0
for line in text.splitlines():
if line.startswith("vllm:num_requests_running"):
try:
running = int(float(line.split()[-1]))
except (ValueError, IndexError):
pass
elif line.startswith("vllm:num_requests_waiting"):
try:
waiting = int(float(line.split()[-1]))
except (ValueError, IndexError):
pass
return running, waiting
async def _reconcile_loop():
    """Periodic shadow-state reconciliation against vLLM /metrics truth.

    The proxy keeps shadow counters (num_requests, ongoing_tokens,
    pending_prefill_tokens, ongoing_decode_tokens). _handle_local_request
    increments them, and the streaming generator's finally block
    decrements them. If the generator never starts (client disconnect
    between StreamingResponse construction and Starlette starting
    iteration, or Starlette failing before iteration), the decrement never
    fires and the counters stay elevated. Over a long run this "phantom"
    load biases routing away from the affected instance.

    Every 30 s, for every known instance:
      1. Clamp negative counters to 0 (defensive; rare in practice).
      2. Sample vLLM's num_running + num_waiting via /metrics. If the
         proxy's num_requests exceeded vLLM's count on two *consecutive*
         successful samples (~60 s of stable drift), lower num_requests to
         vLLM's count. If vLLM reports nothing in flight, also zero the
         token counters. Requiring persistence avoids correcting transient
         mismatches (e.g. the proxy just incremented but vLLM has not
         scheduled the request yet).

    Consecutiveness is strict. A failed /metrics sample clears the
    remembered drift, and so does a correction, so each correction needs
    two fresh back-to-back observations.

    Runs until cancelled.
    """
    prev_phantom: dict[str, int] = {}
    while True:
        try:
            await asyncio.sleep(30)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            # Pass 1: clamp negatives (cheap, always do).
            if inst.ongoing_tokens < 0:
                inst.ongoing_tokens = 0
            if inst.ongoing_decode_tokens < 0:
                inst.ongoing_decode_tokens = 0
            if inst.pending_prefill_tokens < 0:
                inst.pending_prefill_tokens = 0
            if inst.num_requests < 0:
                inst.num_requests = 0
            if inst.active_p_offloads < 0:
                inst.active_p_offloads = 0
            # Pass 2: detect phantom positives by polling vLLM truth.
            metrics = await _fetch_vllm_inflight(inst)
            if metrics is None:
                # No sample this cycle: a stale observation must not count
                # as the "previous cycle" of a later one.
                prev_phantom.pop(inst.url, None)
                continue
            running, waiting = metrics
            actual_inflight = running + waiting
            phantom = inst.num_requests - actual_inflight
            prev = prev_phantom.get(inst.url, 0)
            if phantom > 0 and prev > 0:
                # Drift held across two consecutive cycles (~60 s).
                # Reconcile shadow to vLLM's truth.
                old_num = inst.num_requests
                inst.num_requests = actual_inflight
                if actual_inflight == 0:
                    # No requests in flight; zero all per-request counters.
                    inst.ongoing_tokens = 0
                    inst.ongoing_decode_tokens = 0
                    inst.pending_prefill_tokens = 0
                print(
                    f"[reconcile] {inst.url}: phantom drift "
                    f"num_requests {old_num} -> {actual_inflight} "
                    f"(vllm running={running} waiting={waiting})"
                )
                # Drift is now corrected; require two fresh observations
                # before the next correction.
                phantom = 0
            prev_phantom[inst.url] = phantom
def _verify_vllm_patch():
"""Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.
The patch turns an `assert req_id in self.requests` into a soft warn so
that engines do not crash on the KV-transfer abort race (see REPORT
§3.x). If somebody upgrades vLLM without re-applying the patch, the
assert returns and elastic mode dies under load. Print a loud warning
so we catch the regression before the first HEAVY request.
"""
try:
import inspect
from vllm.v1.core.sched.scheduler import Scheduler
src = inspect.getsource(Scheduler)
if "assert req_id in self.requests" in src:
print("WARNING: vLLM scheduler still contains the unpatched "
"`assert req_id in self.requests` line; expect engine "
"death on KV-transfer abort race. Apply "
"patches/0001-fix-kv-transfer-abort-race.patch.")
else:
print("vLLM patch self-check: kv-transfer-abort assert is patched.")
except Exception as exc:
print(f"vLLM patch self-check skipped: {exc!r}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: build backend shadows at startup, clean up afterwards.

    Startup runs the vLLM patch self-check, starts the shadow-state
    reconcile task, and builds InstanceState objects from `global_args`:
      - combined mode: one instance per --combined URL, paired by position
        with --bootstrap-ports. Bootstrap discovery runs when --offload or
        a Mooncake policy (unified_v2 / unified_kv_both) is selected.
      - PD-sep mode: --prefill (URL, port) pairs and --decode URLs;
        prefill instances are always bootstrapped.
    `app.state.ready` is set once routing can proceed.

    Cleanup (cancel the reconcile task, close every instance's HTTP client)
    runs both at shutdown and when startup fails, so a failed start leaks
    neither the background task nor open connections.

    Raises:
        RuntimeError: a policy or option that needs KV transfer was given
            no --bootstrap-ports.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    _verify_vllm_patch()
    reconcile_task = asyncio.create_task(_reconcile_loop())
    try:
        if global_args.combined:
            is_pd_sep = False
            bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
            for i, url in enumerate(global_args.combined):
                bp = bp_list[i] if i < len(bp_list) else None
                combined_instances.append(InstanceState(url, bp))
            policy = getattr(global_args, 'policy', 'linear')
            # Mooncake-based modes still need bootstrap discovery; NIXL uses
            # its own UCX side-channel and doesn't go through our proxy
            # bootstrap path (and unified_nixl_both never PD-seps anyway).
            needs_bootstrap = (
                global_args.offload
                or policy in ("unified_v2", "unified_kv_both")
            )
            if needs_bootstrap and bp_list:
                await init_prefill_bootstrap(combined_instances, app.state.ready)
            elif needs_bootstrap:
                raise RuntimeError(
                    f"--policy {policy} requires --bootstrap-ports for KV transfer; "
                    "got empty bootstrap list."
                )
            else:
                app.state.ready.set()
            print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")
        else:
            is_pd_sep = True
            for url, bp in global_args.prefill:
                prefill_instances.append(InstanceState(url, bp))
            for url in global_args.decode:
                decode_instances.append(InstanceState(url))
            await init_prefill_bootstrap(prefill_instances, app.state.ready)
            print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
        yield
    finally:
        reconcile_task.cancel()
        try:
            await reconcile_task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            # Do not let a crashed reconcile task prevent client cleanup.
            print(f"reconcile task exited with error: {exc!r}")
        for inst in combined_instances + prefill_instances + decode_instances:
            await inst.client.aclose()
# ASGI application. `lifespan` builds the backend shadows on startup and
# tears them down on shutdown.
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Route an OpenAI-style text completion request to a backend instance."""
    endpoint = "/v1/completions"
    response = await _handle(request, endpoint)
    return response
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Route an OpenAI-style chat completion request to a backend instance."""
    endpoint = "/v1/chat/completions"
    response = await _handle(request, endpoint)
    return response
async def _handle(request: Request, api: str):
    """Shared entry point for the OpenAI-compatible POST routes.

    Parses the JSON body and derives request metadata: the X-Request-Id
    header (a fresh UUID4 if absent), the X-Session-Id header, and the
    prompt token ids. Forwards OPENAI_API_KEY as a bearer token when that
    variable is set. Dispatches to the PD-sep or combined handler.

    `prompt` is treated as token ids only when it is a flat list of ints.
    Batched prompts (list[str] / list[list[int]]) and text prompts route
    without token ids, so they get no cache-hit estimate.

    Raises:
        HTTPException(503): startup/bootstrap has not finished.
        HTTPException(400): the body is not valid JSON or not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")
    try:
        req_data = await request.json()
    except ValueError:
        # json.JSONDecodeError / UnicodeDecodeError are ValueError subclasses.
        raise HTTPException(status_code=400, detail="Request body is not valid JSON") from None
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    incoming_rid = request.headers.get("X-Request-Id")
    request_id = incoming_rid or str(uuid.uuid4())
    prompt = req_data.get("prompt")
    is_token_list = isinstance(prompt, list) and all(
        isinstance(t, int) and not isinstance(t, bool) for t in prompt
    )
    token_ids = prompt if is_token_list else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
async def _handle_local_request(api, req_data, headers, token_ids, input_length,
                                chosen: InstanceState, estimated_new: int,
                                breakdown: dict):
    """Proxy one request to `chosen` and stream the upstream body back unchanged.

    Shadow accounting on `chosen`:
      * at dispatch: ongoing_tokens += input_length,
        pending_prefill_tokens += estimated_new, num_requests += 1;
      * on the first upstream chunk: the prefill estimate is released and
        ongoing_decode_tokens += input_length;
      * in the generator's finally: all remaining increments are released.
    If the generator is never iterated (client gone before streaming
    starts), the finally never runs; _reconcile_loop corrects that drift.

    httpx ConnectError / RemoteProtocolError are retried, up to
    MAX_STREAM_RETRIES attempts in total with RETRY_DELAY_S between them,
    but only before the first chunk has been forwarded. After a complete
    stream, the realized prompt+output tokens are recorded in chosen's
    shadow prefix cache. Timing fields are written into `breakdown`, which
    is appended to _breakdown_log when the request ends.

    The response media type follows the request's `stream` flag:
    text/event-stream for streaming requests, application/json otherwise.
    Non-streaming OpenAI-compatible responses are a single JSON body.
    """
    breakdown.setdefault("route_class", "LOCAL")
    breakdown.setdefault("routed_to", chosen.url)
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        prefill_done = False
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                                sse_buffer, chunk)
                            output_token_ids.extend(new_output_ids)
                            if not prefill_done:
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                breakdown["t_first_token_unix"] = _time.time()
                                prefill_done = True
                            yield chunk
                        chosen.record_prefix(
                            _realized_tokens(token_ids, output_token_ids))
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    # Retrying after bytes reached the client would duplicate output.
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            _breakdown_log.append(breakdown)

    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Pick a combined (PD-colocated) instance for a /v1/* request and proxy it.

    The picker is selected by ``global_args.policy`` (default ``linear``):

      * ``lmetric`` / ``load_only`` / ``sticky``: the matching
        ``pick_instance_*`` helper.
      * ``unified`` / ``unified_kv_both`` / ``unified_nixl_both``:
        ``pick_instance_unified_hybrid``. The two ``*_both`` variants use the
        same picker; they exist as launch-configuration isolation controls
        (kv_both with MooncakeConnector vs. NixlConnector) for ``unified_v2``.
      * ``unified_v2``: ``pick_instance_unified_v2``, which may also return
        a ``(src_instance, src_idx)`` pair requesting a per-request PD-sep
        hand-off from that source to the chosen instance.
      * anything else: ``pick_instance`` (linear cache-aware score).

    For the unified* pickers, the returned decision dict is merged into the
    breakdown record and the session (if any) is pinned to the chosen index
    here. The other pickers receive ``session_affinity_combined`` directly.
    A pre-decision snapshot of all workers is stored both in the breakdown
    and in ``_worker_state_log``.

    The older PD-sep offload / PUSH migration path is retired (see REPORT.md
    §3.9 and commits 4c583f2 / cc6e562: relaxed-gate and forced-migration
    variants both regressed E2E tail). Re-enabling it requires a new transfer
    mechanism.
    """
    policy = getattr(global_args, 'policy', 'linear')
    t_decision_unix = _time.time()
    request_id = headers.get("X-Request-Id", "")
    breakdown: dict = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": policy,
    }
    workers = snapshot_workers(combined_instances, token_ids, input_length)
    # Every picker shares this positional signature.
    picker_args = (combined_instances, token_ids, session_id, input_length,
                   session_affinity_combined)

    pd_sep_v2: tuple[InstanceState, int] | None = None
    hybrid = True
    if policy == "unified_v2":
        chosen, best_idx, decision, pd_sep_v2 = pick_instance_unified_v2(*picker_args)
    elif policy in ("unified", "unified_kv_both", "unified_nixl_both"):
        chosen, best_idx, decision = pick_instance_unified_hybrid(*picker_args)
    else:
        hybrid = False
        simple_pickers = {
            "lmetric": pick_instance_lmetric,
            "load_only": pick_instance_load_only,
            "sticky": pick_instance_sticky,
        }
        chosen, best_idx = simple_pickers.get(policy, pick_instance)(*picker_args)

    if hybrid:
        # Hybrid pickers report their reasoning and leave affinity to us.
        breakdown.update(decision)
        if session_id:
            session_affinity_combined[session_id] = best_idx

    snap = workers[best_idx]
    estimated_new = snap["new_prefill"]
    breakdown.update({
        "cache_hit": snap["cache_hit"],
        "estimated_new_tokens": estimated_new,
        "route_class": "LOCAL" if pd_sep_v2 is None else "PD_SEP_V2",
        "routed_to": chosen.url,
        "chosen_idx": best_idx,
        "candidate_scores": workers,
        "chosen_score_linear": snap["score_linear"],
        "chosen_score_lmetric": snap["score_lmetric"],
    })
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": policy,
        "chosen_idx": best_idx,
        "v2_pd_sep": pd_sep_v2 is not None,
        "workers": workers,
    })

    if pd_sep_v2 is None:
        return await _handle_local_request(
            api, req_data, headers, token_ids, input_length,
            chosen, estimated_new, breakdown)

    src_inst, src_idx = pd_sep_v2
    breakdown["v2_src_url"] = src_inst.url
    breakdown["v2_src_idx"] = src_idx
    return await _handle_combined_pd_sep_v2(
        api, req_data, headers, token_ids, input_length,
        src_inst, chosen, breakdown,
        request_id=request_id)
async def _handle_combined_pd_sep_v2(
    api, req_data, headers, token_ids, input_length,
    src: InstanceState, dst: InstanceState, breakdown: dict,
    *, request_id: str,
):
    """Per-request PD-sep among combined instances (unified_v2 path).

    src does cached prefill (max_tokens=1) and ships KV to dst via
    Mooncake; dst pulls KV and decodes. Both instances must run in
    kv_role=kv_both with bootstrap server enabled.

    Load accounting: ``ongoing_tokens`` / ``num_requests`` are reserved on
    both instances up front. src's reservation is released exactly once when
    the prefill leg ends (success or failure). dst's reservation is released
    when the decode stream ends, or immediately if prefill fails.

    Patch 6.6: the dst streaming call uses a per-chunk read timeout
    of SETTINGS.pd_sep_xfer_timeout_s, so a stuck KV transfer fails
    the request instead of hanging for 600 s.

    Raises:
        HTTPException(500): src has no bootstrap port configured.
        HTTPException(502): the prefill leg on src failed.
    """
    if src.bootstrap_port is None:
        raise HTTPException(
            status_code=500,
            detail=(
                "unified_v2 PD-sep triggered but src instance "
                f"{src.url} has no bootstrap_port; launch with "
                "kv_role=kv_both and pass --bootstrap-ports"
            ),
        )
    # Reserve load on both endpoints so concurrent routing sees this request.
    src.ongoing_tokens += input_length
    src.num_requests += 1
    dst.ongoing_tokens += input_length
    dst.num_requests += 1
    src_load_held = True
    dst_load_held = True
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    try:
        resp = await src.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        await resp.aclose()
        src.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        breakdown["error_detail"] = repr(e)[:300]
        _breakdown_log.append(breakdown)
        # dst will never decode this request: drop its reservation now.
        # src's reservation is released (once) by the finally below; releasing
        # it here as well would double-decrement src's counters.
        dst.ongoing_tokens -= input_length
        dst.num_requests -= 1
        dst_load_held = False
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
    finally:
        if src_load_held:
            src.ongoing_tokens -= input_length
            src.num_requests -= 1
            src_load_held = False
    parsed = urllib.parse.urlparse(str(src.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{src.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": src.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()
    xfer_timeout = httpx.Timeout(
        connect=10.0,
        read=SETTINGS.pd_sep_xfer_timeout_s,
        write=10.0,
        pool=10.0,
    )

    async def generate():
        nonlocal dst_load_held
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with dst.client.stream(
                "POST", api, json=decode_data, headers=headers,
                timeout=xfer_timeout,
            ) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
            dst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        except Exception as e:
            # Keep the failure visible in /breakdown (e.g. an xfer timeout),
            # then let it propagate as before.
            breakdown["decode_error"] = True
            breakdown["error_detail"] = repr(e)[:300]
            raise
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            if dst_load_held:
                dst.ongoing_tokens -= input_length
                dst.num_requests -= 1
                dst_load_held = False
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """Route one request through disaggregated prefill -> decode instances.

    The prefill instance is chosen with the cache-aware ``pick_instance``
    scorer (using ``session_affinity_prefill``). The decode instance is the
    one with the fewest ``ongoing_tokens``. The prefill leg is sent
    non-streaming with ``max_tokens=1`` and ``do_remote_decode=True``. The
    decode leg is streamed back to the client with ``do_remote_prefill=True``
    and the prefill instance's bootstrap address so it can pull the KV cache.

    Per-stage timestamps go into a breakdown record appended to
    ``_breakdown_log``; the pre-decision worker snapshot goes to
    ``_worker_state_log``.

    Raises:
        HTTPException(503): no decode instances are configured.
        HTTPException(502): the prefill leg failed.
    """
    # Fail fast (and before touching session affinity) instead of letting
    # min() below raise a bare ValueError on an empty decode pool.
    if not decode_instances:
        raise HTTPException(
            status_code=503,
            detail="PD-Sep mode has no decode instances configured")
    t_decision_unix = _time.time()
    breakdown = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": "pd_sep",
    }
    pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
    pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)
    p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
                                  input_length, session_affinity_prefill)
    # Decode placement is purely load-based: fewest in-flight tokens.
    d_idx = min(range(len(decode_instances)),
                key=lambda i: decode_instances[i].ongoing_tokens)
    d_inst = decode_instances[d_idx]
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    breakdown["candidate_scores_prefill"] = pre_decision_p
    breakdown["candidate_scores_decode"] = pre_decision_d
    breakdown["chosen_p_idx"] = p_idx
    breakdown["chosen_d_idx"] = d_idx
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": "pd_sep",
        "chosen_p_idx": p_idx,
        "chosen_d_idx": d_idx,
        "workers_prefill": pre_decision_p,
        "workers_decode": pre_decision_d,
    })
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    p_inst.ongoing_tokens += input_length
    p_inst.num_requests += 1
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        breakdown["error_detail"] = repr(e)[:300]
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
    finally:
        # The prefill leg is over (success or failure): release its load.
        p_inst.ongoing_tokens -= input_length
        p_inst.num_requests -= 1
    # Send decode; its reservation is released when the stream ends.
    d_inst.ongoing_tokens += input_length
    d_inst.num_requests += 1
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    async def generate():
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
            d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        except Exception as e:
            # Surface decode failures in /breakdown, then propagate unchanged.
            breakdown["decode_error"] = True
            breakdown["error_detail"] = repr(e)[:300]
            raise
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
@app.get("/breakdown")
async def get_breakdown():
    """Serve every per-request timing breakdown record collected so far.

    The records are the ``breakdown`` dicts the request handlers append to
    ``_breakdown_log``. A shallow snapshot of the list is returned.
    """
    records = list(_breakdown_log)
    return records
@app.get("/worker_state")
async def get_worker_state():
    """Serve the routing-decision log: one worker-state snapshot per decision.

    A shallow snapshot of ``_worker_state_log`` is returned.
    """
    decisions = list(_worker_state_log)
    return decisions
@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Report the current per-worker state without appending to any log.

    In combined mode a single ``workers`` list is returned; otherwise the
    prefill and decode pools are reported separately.
    """
    now = _time.time()
    if not combined_instances:
        return {
            "t_unix": now,
            "mode": "pd_sep",
            "workers_prefill": snapshot_workers(prefill_instances),
            "workers_decode": snapshot_workers(decode_instances),
        }
    return {
        "t_unix": now,
        "mode": "combined",
        "workers": snapshot_workers(combined_instances),
    }
@app.get("/stats")
async def get_stats():
    """Return per-instance live load counters for debugging.

    In combined mode every instance is reported with role ``"combined"``.
    In PD-Sep mode, prefill instances are tagged ``"prefill"`` and decode
    instances ``"decode"``, so the two pools can be told apart.
    """
    if combined_instances:
        tagged = [(inst, "combined") for inst in combined_instances]
    else:
        tagged = ([(inst, "prefill") for inst in prefill_instances]
                  + [(inst, "decode") for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for inst, role in tagged]
def parse_args():
    """Parse the scheduler's command-line arguments.

    Besides the raw argparse fields, two normalized fields are attached:
      * ``prefill``: list of ``(url, bootstrap_port | None)`` tuples, one per
        ``--prefill URL [BP]`` occurrence (``BP`` may be omitted or ``none``).
      * ``decode``: list of decode URLs, one per ``--decode URL``.

    Exits through ``ArgumentParser.error`` (``SystemExit``) when neither
    ``--combined`` nor ``--prefill`` is given, when PD-Sep mode has no
    ``--decode`` instance, or when a bootstrap port is not an integer.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear",
                   choices=["linear", "lmetric", "load_only", "sticky",
                            "unified", "unified_kv_both",
                            "unified_nixl_both", "unified_v2"],
                   help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
                        "load_only (B3 control: pure min-num_requests), "
                        "sticky (B3 control: hard session affinity), "
                        "unified (hybrid affinity + LMetric fallback), "
                        "unified_kv_both (unified picker on kv_both Mooncake "
                        "vLLMs; isolation control for unified_v2), "
                        "unified_nixl_both (same as unified_kv_both but using "
                        "NixlConnector instead of MooncakeConnector; isolates "
                        "connector implementation from policy effect), "
                        "or unified_v2 (unified + selective per-request PD-sep "
                        "via Mooncake; requires --bootstrap-ports and "
                        "kv_role=kv_both vLLM launch)")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    # The four flags below are accepted for bench.sh backward compatibility but
    # have no effect after the PD-sep offload path was retired (REPORT §3.9,
    # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
    # scripts/legacy/*.sh which still pass them through.
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--offload-mode", type=str, default="cached_prefill",
                   choices=["direct_read", "cached_prefill"],
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--cache-gate-ratio", type=float, default=0.0,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--decode-iteration-s", type=float, default=0.05,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    args = p.parse_args()
    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report as a usage error instead of an uncaught traceback.
                p.error(f"--prefill {url}: bootstrap port must be an integer "
                        f"or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]
    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    if args.prefill and not args.combined and not args.decode:
        # Without a decode pool every PD-Sep request would fail at routing time.
        p.error("--prefill requires at least one --decode URL")
    return args
if __name__ == "__main__":
    global_args = parse_args()
    # Mirror CLI tunables onto the shared SETTINGS object before serving.
    # The last three come from flags that parse_args marks as deprecated
    # no-ops, but they are still copied over unchanged.
    SETTINGS.heavy_threshold = global_args.heavy_threshold
    SETTINGS.overload_factor = global_args.overload_factor
    SETTINGS.max_offload_inflight = global_args.max_offload_inflight
    SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
    SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
    print(f"SETTINGS: throughput={SETTINGS.prefill_throughput:.0f} "
          f"rdma_overhead={SETTINGS.rdma_overhead_s:.2f} "
          f"offload={getattr(global_args, 'offload', False)}")
    uvicorn.run(app, host=global_args.host, port=global_args.port)