Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang fe556b5d98 A2: proxy worker-state snapshot and request-id passthrough
Honor incoming X-Request-Id so replayer metrics and proxy breakdown
share a join key. Each route decision now captures session_id, the
full per-worker candidate-score snapshot (ongoing/pending/num_requests
/cached_blocks plus both linear and lmetric scores), the chosen score,
and unix timestamps for first-token and done events. A separate
_worker_state_log records one row per decision and is exposed via
GET /worker_state; GET /worker_state/latest returns a live snapshot
without recording it.

Required by Batch 3 (session hot-spot proof) and Batch 5 (failure
attribution); existing breakdown.json had no per-worker state at
decision time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 16:19:01 +08:00

846 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests (waiting + running)
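Session affinity (keyed by the X-Session-Id header) depends on the policy:
linear: a session stays on its pinned instance unless that instance is
overloaded or busy with P-role offloads.
unified: a session stays pinned only while the pinned instance holds more
than half of the prompt in its cache and is not overloaded; otherwise
the request falls back to LMetric scoring with tie-breakers.
lmetric: no session affinity; every request is scored independently.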
"""
import argparse
import asyncio
import json
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from dataclasses import dataclass
import httpx
# Maximum number of attempts _handle_local_request makes to open the upstream stream.
MAX_STREAM_RETRIES = 3
# Seconds slept between two failed attempts.
RETRY_DELAY_S = 0.5
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of token ids hashed together as one block of the shadow prefix cache.
BLOCK_SIZE = 512
# Weight applied to the estimated cache-hit tokens in the linear routing score.
CACHE_HIT_ALPHA = 1.0
@dataclass
class Settings:
    """Runtime-tunable knobs shared through the SETTINGS singleton.

    The __main__ block copies CLI values into SETTINGS. When the module is
    imported as a library and __main__ does not run, the defaults declared
    here are what the routing code sees.
    """
    # In this file only the startup banner reads the next two values.
    prefill_throughput: float = 7000.0  # tokens/s per GPU (measured on H20)
    rdma_overhead_s: float = 0.1  # RDMA PUSH overhead (~10-50ms measured)
    cache_capacity_blocks: int = 200000  # per-instance LRU cap on shadow cached_blocks
    # Multiplied by active_p_offloads to form the P-offload routing penalty.
    heavy_threshold: int = 20000
    # Affinity is broken when the pinned instance's load exceeds factor * average.
    overload_factor: float = 2.0
    # The three knobs below back CLI flags whose help text marks them deprecated.
    max_offload_inflight: int = 4
    cache_gate_ratio: float = 0.0
    decode_iteration_s: float = 0.05  # per-request decode iteration cost (H20)


# Module-wide singleton read by the routing helpers.
SETTINGS = Settings()
class InstanceState:
    """Proxy-side shadow of one backend instance.

    Holds the HTTP client for the instance, the load counters that the
    routing policies read, and an LRU of block hashes approximating which
    prompt prefixes the instance has cached.
    """

    def __init__(self, url: str, bootstrap_port: int | None = None):
        self.url = url
        self.bootstrap_port = bootstrap_port
        # No timeout and no connection caps on the upstream client.
        unlimited = httpx.Limits(max_connections=None, max_keepalive_connections=None)
        self.client = httpx.AsyncClient(timeout=None, base_url=url, limits=unlimited)

        # Load counters maintained by the request handlers.
        self.ongoing_tokens = 0
        self.ongoing_decode_tokens = 0  # subset: tokens in decode phase
        self.pending_prefill_tokens = 0  # tokens for requests still in prefill
        self.num_requests = 0  # total in-flight requests (waiting + running)
        self.active_p_offloads = 0  # number of HEAVY prefills this instance is doing for others

        # Filled by init_prefill_bootstrap: data-parallel rank -> engine id.
        self.engine_id: dict[int, str] = {}
        self.dp_size = 1

        # OrderedDict acts as an LRU keyed by block hash; value is unused.
        self.cached_blocks: OrderedDict[int, None] = OrderedDict()

    def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
        """Return how many leading tokens fall in consecutively cached blocks.

        Walks the prompt in BLOCK_SIZE strides and stops at the first block
        whose hash is unknown. Every matched block is moved to the
        most-recent end of the LRU, so calling this has a side effect.
        """
        if not token_ids or len(token_ids) < BLOCK_SIZE:
            return 0
        matched = 0
        for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            key = hash(tuple(token_ids[start:start + BLOCK_SIZE]))
            if key not in self.cached_blocks:
                break
            self.cached_blocks.move_to_end(key)  # LRU touch on hit
            matched += BLOCK_SIZE
        return matched

    def record_prefix(self, token_ids: list[int] | None):
        """Insert or refresh every full block of `token_ids` in the LRU.

        When an insertion pushes the LRU above
        SETTINGS.cache_capacity_blocks, the oldest entry is evicted.
        """
        if not token_ids:
            return
        blocks = self.cached_blocks
        for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            key = hash(tuple(token_ids[start:start + BLOCK_SIZE]))
            if key in blocks:
                blocks.move_to_end(key)
                continue
            blocks[key] = None
            if len(blocks) > SETTINGS.cache_capacity_blocks:
                blocks.popitem(last=False)
def _p_offload_penalty(inst: InstanceState) -> int:
    """Score penalty for an instance running P-role offloads (legacy PD-sep routing).

    Each active offload costs SETTINGS.heavy_threshold; an instance with
    none (or a negative drifted counter) costs nothing.
    """
    offloads = inst.active_p_offloads
    return offloads * SETTINGS.heavy_threshold if offloads > 0 else 0
def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Describe every worker as seen at route-decision time.

    Each entry carries the worker's routing counters together with the
    linear and LMetric scores a request of `input_length` tokens would
    receive if dispatched to it now. When `token_ids` is given, the cache
    estimate goes through InstanceState.estimate_cache_hit, which touches
    the LRU entries it matches.
    """

    def describe(idx: int, inst: InstanceState) -> dict:
        hit = inst.estimate_cache_hit(token_ids) if token_ids else 0
        uncached = max(0, input_length - hit)
        linear = (inst.ongoing_tokens
                  + _p_offload_penalty(inst)
                  - CACHE_HIT_ALPHA * hit)
        lmetric = (inst.pending_prefill_tokens + uncached) * inst.num_requests
        return {
            "idx": idx,
            "url": inst.url,
            "ongoing_tokens": inst.ongoing_tokens,
            "ongoing_decode_tokens": inst.ongoing_decode_tokens,
            "pending_prefill_tokens": inst.pending_prefill_tokens,
            "num_requests": inst.num_requests,
            "active_p_offloads": inst.active_p_offloads,
            "cached_blocks": len(inst.cached_blocks),
            "cache_hit": hit,
            "new_prefill": uncached,
            "score_linear": linear,
            "score_lmetric": lmetric,
        }

    return [describe(idx, inst) for idx, inst in enumerate(instances)]
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Linear policy: sticky sessions with a load-aware override.

    A session that already has a pinned instance keeps it unless that
    instance carries more than SETTINGS.overload_factor times the average
    ongoing tokens or is running P-role offloads. In every other case the
    instance with the lowest score
    (ongoing tokens + offload penalty - ALPHA * cache hit) wins, ties going
    to the lowest index, and the session is re-pinned to the winner.
    `input_length` is accepted for signature parity and is not used.
    """
    mean_tokens = sum(inst.ongoing_tokens for inst in instances) / len(instances)
    avg_load = max(mean_tokens, 1.0)

    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            pinned = instances[pinned_idx]
            within_budget = pinned.ongoing_tokens <= avg_load * SETTINGS.overload_factor
            if within_budget and pinned.active_p_offloads == 0:
                return pinned, pinned_idx

    # Score every instance in order; min() keeps the first of equal scores.
    scores = [
        inst.ongoing_tokens + _p_offload_penalty(inst)
        - CACHE_HIT_ALPHA * inst.estimate_cache_hit(token_ids)
        for inst in instances
    ]
    best_idx = min(range(len(scores)), key=scores.__getitem__)

    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric policy: choose the instance minimising P_tokens × BS (OSDI'26).

    P_tokens = pending_prefill_tokens + (input_length - cache_hit), floored
    at zero for the uncached part; BS = num_requests. Sessions play no
    role: `session_id` and `affinity` exist only so the signature matches
    the other pickers. Ties resolve to the lowest index.
    """
    winner, lowest = 0, float("inf")
    for idx, inst in enumerate(instances):
        uncached = max(0, input_length - inst.estimate_cache_hit(token_ids))
        load = (inst.pending_prefill_tokens + uncached) * inst.num_requests
        if load < lowest:
            winner, lowest = idx, load
    return instances[winner], winner
# Advanced on every tie in the unified fallback so tied instances rotate.
_unified_fallback_rr_counter = 0


def pick_instance_unified_hybrid(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict]:
    """Unified policy: keep affinity when it pays off, else LMetric with tie-breaks.

    The pinned instance is kept only if both hold:
      - its cache_hit / input_length is above 0.5
      - its num_requests <= avg_num_requests * SETTINGS.overload_factor
    Otherwise instances are ranked by (P_tokens * BS, uncached tokens,
    num_requests); if several share the best triple, a module-level
    round-robin counter picks among them.

    Returns (chosen, idx, decision). `decision` holds the breakdown fields
    describing how the choice was made. This function never writes to
    `affinity`.
    """
    global _unified_fallback_rr_counter

    count = len(instances)
    avg_reqs = max(sum(inst.num_requests for inst in instances) / count, 1.0)
    decision: dict = {
        "decision": "lmetric_fallback",
        "affinity_idx": None,
        "chosen_idx": None,
        "affinity_cache_hit": None,
        "affinity_cache_ratio": None,
        "affinity_num_requests": None,
        "avg_num_requests": avg_reqs,
        "fallback_score": None,
        "tie_break_used": False,
    }

    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < count:
            pinned = instances[pinned_idx]
            pinned_hit = pinned.estimate_cache_hit(token_ids)
            pinned_ratio = pinned_hit / max(input_length, 1)
            decision.update({
                "affinity_idx": pinned_idx,
                "affinity_cache_hit": pinned_hit,
                "affinity_cache_ratio": pinned_ratio,
                "affinity_num_requests": pinned.num_requests,
            })
            not_overloaded = pinned.num_requests <= avg_reqs * SETTINGS.overload_factor
            if pinned_ratio > 0.5 and not_overloaded:
                decision["decision"] = "affinity"
                decision["chosen_idx"] = pinned_idx
                return pinned, pinned_idx, decision

    # One (score, uncached, batch, idx) entry per instance.
    ranked: list[tuple[int, int, int, int]] = []
    for idx, inst in enumerate(instances):
        uncached = max(0, input_length - inst.estimate_cache_hit(token_ids))
        batch = inst.num_requests
        ranked.append(((inst.pending_prefill_tokens + uncached) * batch,
                       uncached, batch, idx))

    best = min(entry[:3] for entry in ranked)
    contenders = [entry for entry in ranked if entry[:3] == best]
    if len(contenders) == 1:
        winner = contenders[0]
    else:
        decision["tie_break_used"] = True
        # The counter is bumped before indexing, as the rotation relies on it.
        _unified_fallback_rr_counter += 1
        winner = contenders[_unified_fallback_rr_counter % len(contenders)]

    chosen_idx = winner[3]
    decision["fallback_score"] = winner[0]
    decision["chosen_idx"] = chosen_idx
    return instances[chosen_idx], chosen_idx, decision
def _extract_output_token_ids_from_sse(
buffer: str,
chunk: bytes,
) -> tuple[str, list[int]]:
"""Extract vLLM streaming token_ids while preserving the raw stream."""
buffer += chunk.decode("utf-8", errors="ignore")
complete = buffer.endswith("\n") or buffer.endswith("\r")
lines = buffer.splitlines()
if complete:
buffer = ""
elif lines:
buffer = lines.pop()
else:
return buffer, []
output_ids: list[int] = []
for line in lines:
line = line.strip()
if not line.startswith("data:"):
continue
data = line[5:].strip()
if not data or data == "[DONE]":
continue
try:
payload = json.loads(data)
except json.JSONDecodeError:
continue
choices = payload.get("choices", [])
for choice in choices:
token_ids = choice.get("token_ids")
if isinstance(token_ids, list):
output_ids.extend(
int(t) for t in token_ids if isinstance(t, int)
)
return buffer, output_ids
def _realized_tokens(
prompt_token_ids: list[int] | None,
output_token_ids: list[int],
) -> list[int] | None:
if prompt_token_ids is None:
return None
if not output_token_ids:
return prompt_token_ids
return prompt_token_ids + output_token_ids
# Parsed CLI namespace; assigned in the __main__ block and read by lifespan
# and the request handlers.
global_args = None
# Instance pools. lifespan fills either the combined pool or the
# prefill/decode pools, depending on the CLI mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias used by /stats etc.
# NOTE(review): nothing in the visible part of this file reads this alias.
session_affinity = session_affinity_combined
# True when the proxy runs in PD-disaggregated mode; set by lifespan.
is_pd_sep = False
# Append-only in-memory logs served by GET /breakdown and GET /worker_state.
# Nothing in this file trims them.
_breakdown_log: list[dict] = []
_worker_state_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Fetch engine ids for every instance that has a bootstrap port, then signal readiness.

    For each such instance this polls GET /health once per second until a
    request succeeds, then queries `/query` on the instance host at the
    bootstrap port and stores the returned engine id per data-parallel
    rank. Instances without a bootstrap port are skipped. `ready` is set
    after all instances are processed.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue

        # Block until the instance answers a health probe.
        healthy = False
        while not healthy:
            try:
                await inst.client.get("/health")
                healthy = True
            except Exception:
                await asyncio.sleep(1)

        parsed = urllib.parse.urlparse(str(inst.client.base_url))
        reply = await inst.client.get(
            f"http://{parsed.hostname}:{inst.bootstrap_port}/query")
        reply.raise_for_status()
        ranks = reply.json()
        for rank, entry in ranks.items():
            inst.engine_id[int(rank)] = entry["engine_id"]
        inst.dp_size = len(ranks)
        print(f"Inited {inst.url} engine_ids={inst.engine_id}")

    ready.set()
async def _reconcile_loop():
    """Once a minute, clamp negative shadow counters back to zero.

    The streaming generators adjust load counters as they run, so the
    bookkeeping can drift when a stream is torn down unexpectedly. This
    loop is only a guard that keeps router scores sane; it does not
    resynchronise with the real engine state (see TODO.md item 6). The loop
    ends quietly when the task is cancelled during its sleep.
    """
    counters = (
        "ongoing_tokens",
        "ongoing_decode_tokens",
        "pending_prefill_tokens",
        "num_requests",
        "active_p_offloads",
    )
    while True:
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            for name in counters:
                if getattr(inst, name) < 0:
                    setattr(inst, name, 0)
def _verify_vllm_patch():
"""Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.
The patch turns an `assert req_id in self.requests` into a soft warn so
that engines do not crash on the KV-transfer abort race (see REPORT
§3.x). If somebody upgrades vLLM without re-applying the patch, the
assert returns and elastic mode dies under load. Print a loud warning
so we catch the regression before the first HEAVY request.
"""
try:
import inspect
from vllm.v1.core.sched.scheduler import Scheduler
src = inspect.getsource(Scheduler)
if "assert req_id in self.requests" in src:
print("WARNING: vLLM scheduler still contains the unpatched "
"`assert req_id in self.requests` line; expect engine "
"death on KV-transfer abort race. Apply "
"patches/0001-fix-kv-transfer-abort-race.patch.")
else:
print("vLLM patch self-check: kv-transfer-abort assert is patched.")
except Exception as exc:
print(f"vLLM patch self-check skipped: {exc!r}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the instance pools at startup and release them at shutdown.

    Startup: create the readiness event, run the vLLM patch self-check,
    start the reconcile task, then populate either the combined pool or
    the prefill/decode pools from `global_args`. Readiness is set at once
    in combined mode unless offload is enabled with bootstrap ports; in
    every other case it is set by init_prefill_bootstrap.
    Shutdown: cancel the reconcile task and close every instance client.
    """
    global is_pd_sep

    ready = asyncio.Event()
    app.state.ready = ready
    _verify_vllm_patch()
    reconcile_task = asyncio.create_task(_reconcile_loop())

    if not global_args.combined:
        is_pd_sep = True
        for url, bp in global_args.prefill:
            prefill_instances.append(InstanceState(url, bp))
        for url in global_args.decode:
            decode_instances.append(InstanceState(url))
        await init_prefill_bootstrap(prefill_instances, ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
    else:
        is_pd_sep = False
        raw_ports = global_args.bootstrap_ports
        bp_list = []
        if raw_ports:
            bp_list = [int(part) for part in raw_ports.split(",") if part.strip()]
        # Ports are matched to URLs by position; surplus URLs get no port.
        for pos, url in enumerate(global_args.combined):
            port = bp_list[pos] if pos < len(bp_list) else None
            combined_instances.append(InstanceState(url, port))
        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        if global_args.offload and bp_list:
            await init_prefill_bootstrap(combined_instances, ready)
        else:
            ready.set()
        policy = getattr(global_args, 'policy', 'linear')
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")

    yield

    reconcile_task.cancel()
    try:
        await reconcile_task
    except asyncio.CancelledError:
        pass
    for inst in combined_instances + prefill_instances + decode_instances:
        await inst.client.aclose()
# ASGI application; `lifespan` builds and tears down the instance pools.
app = FastAPI(lifespan=lifespan)


@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Forward POST /v1/completions to the shared routing handler."""
    return await _handle(request, "/v1/completions")


@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Forward POST /v1/chat/completions to the shared routing handler."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    An incoming X-Request-Id header is reused as the request id (a UUID is
    generated otherwise) and forwarded upstream. X-Session-Id, when
    present, is the session key for affinity. If OPENAI_API_KEY is set in
    the environment it is forwarded as a Bearer token.

    The prompt is treated as token ids only when it is a list made
    entirely of integers. The completions API also accepts a list of
    strings or a list of token lists; treating those as token ids would
    count prompts instead of tokens in the load counters, pollute the
    shadow prefix cache, and can raise TypeError when hashing nested
    lists. Such prompts are routed as if no token ids were known.

    Raises:
        HTTPException: 503 while the proxy is not ready; 400 when the JSON
            body is not an object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    req_data = await request.json()
    if not isinstance(req_data, dict):
        # Everything downstream expects a JSON object; fail with a client
        # error rather than an AttributeError on .get().
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    incoming_rid = request.headers.get("X-Request-Id")
    request_id = incoming_rid or str(uuid.uuid4())

    prompt = req_data.get("prompt")
    is_token_prompt = (isinstance(prompt, list)
                       and all(isinstance(tok, int) for tok in prompt))
    token_ids = prompt if is_token_prompt else None
    input_length = len(token_ids) if token_ids else 0

    session_id = request.headers.get("X-Session-Id")
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
async def _handle_local_request(api, req_data, headers, token_ids, input_length,
                                chosen: InstanceState, estimated_new: int,
                                breakdown: dict):
    """Stream one request from `chosen` back to the client, keeping load counters in step.

    The instance is charged immediately (ongoing tokens, pending prefill,
    request count) so that concurrent routing decisions see the load. The
    returned StreamingResponse wraps a generator that relays upstream
    bytes, moves the request from prefill to decode accounting on the
    first chunk, and releases every counter in its `finally` block. The
    finished `breakdown` is appended to `_breakdown_log` there as well.
    """
    breakdown.setdefault("route_class", "LOCAL")
    breakdown.setdefault("routed_to", chosen.url)
    # Charge the instance up front, before any byte is streamed.
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        # Flips to True on the first upstream chunk.
        prefill_done = False
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                                sse_buffer, chunk)
                            output_token_ids.extend(new_output_ids)
                            if not prefill_done:
                                # First chunk: the request leaves the prefill
                                # accounting and enters decode accounting.
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                breakdown["t_first_token_unix"] = _time.time()
                                prefill_done = True
                            yield chunk
                    # Stream finished cleanly: remember prompt + output as cached.
                    chosen.record_prefix(
                        _realized_tokens(token_ids, output_token_ids))
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    # Retry only if nothing reached the client yet and
                    # attempts remain; otherwise propagate the error.
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            # Undo whichever phase-specific charge is still outstanding.
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Pick a combined (PD-colocated) instance for a /v1/* request and stream from it.

    Policy (--policy):
      linear:  cache_hit-aware load score + sticky session affinity.
      lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity.
      unified: hybrid — stick to the affinity instance when cache_ratio > 0.5
               and it is not overloaded; otherwise fall back to LMetric
               with a multi-key tie-breaker.

    A per-worker snapshot is taken before the decision. It is stored in the
    request breakdown and appended to `_worker_state_log`, and it supplies
    the cache-hit and new-prefill estimates for the chosen instance.

    PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and
    commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants
    both regressed E2E tail). Re-enabling requires a new transfer mechanism.
    """
    policy = getattr(global_args, 'policy', 'linear')
    t_decision_unix = _time.time()
    request_id = headers.get("X-Request-Id", "")
    breakdown: dict = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": policy,
    }

    pre_decision_workers = snapshot_workers(
        combined_instances, token_ids, input_length)

    # All three pickers take the same positional arguments.
    routing_args = (combined_instances, token_ids, session_id, input_length,
                    session_affinity_combined)
    if policy == "unified":
        chosen, best_idx, decision = pick_instance_unified_hybrid(*routing_args)
        breakdown.update(decision)
        # The hybrid picker never writes affinity, so pin the session here.
        if session_id:
            session_affinity_combined[session_id] = best_idx
    elif policy == "lmetric":
        chosen, best_idx = pick_instance_lmetric(*routing_args)
    else:  # linear (default)
        chosen, best_idx = pick_instance(*routing_args)

    chosen_snap = pre_decision_workers[best_idx]
    estimated_new = chosen_snap["new_prefill"]

    breakdown["cache_hit"] = chosen_snap["cache_hit"]
    breakdown["estimated_new_tokens"] = estimated_new
    breakdown["route_class"] = "LOCAL"
    breakdown["routed_to"] = chosen.url
    breakdown["chosen_idx"] = best_idx
    breakdown["candidate_scores"] = pre_decision_workers
    breakdown["chosen_score_linear"] = chosen_snap["score_linear"]
    breakdown["chosen_score_lmetric"] = chosen_snap["score_lmetric"]

    state_row = {
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": policy,
        "chosen_idx": best_idx,
        "workers": pre_decision_workers,
    }
    _worker_state_log.append(state_row)

    return await _handle_local_request(
        api, req_data, headers, token_ids, input_length,
        chosen, estimated_new, breakdown)
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Flow:
      1. Snapshot both pools, pick a prefill instance with the linear
         picker and the decode instance with the fewest ongoing tokens.
      2. Send a non-streaming, single-token request to the prefill
         instance with KV-transfer params marking it for remote decode.
         A failure is logged in the breakdown and surfaces as HTTP 502.
      3. Stream the original request from the decode instance with
         KV-transfer params pointing at the prefill instance.

    Raises:
        HTTPException: 502 when the prefill request fails.
    """
    t_decision_unix = _time.time()
    breakdown = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": "pd_sep",
    }
    # Capture the state of both pools before either choice is made.
    pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
    pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)
    p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
                                  input_length, session_affinity_prefill)
    # Decode side: least ongoing tokens wins; no cache or session term.
    d_idx = min(range(len(decode_instances)),
                key=lambda i: decode_instances[i].ongoing_tokens)
    d_inst = decode_instances[d_idx]
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    breakdown["candidate_scores_prefill"] = pre_decision_p
    breakdown["candidate_scores_decode"] = pre_decision_d
    breakdown["chosen_p_idx"] = p_idx
    breakdown["chosen_d_idx"] = d_idx
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": "pd_sep",
        "chosen_p_idx": p_idx,
        "chosen_d_idx": d_idx,
        "workers_prefill": pre_decision_p,
        "workers_decode": pre_decision_d,
    })
    # Prefill request: shallow copy of the client payload, forced to be
    # non-streaming and to produce exactly one token.
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    # The prefill request always targets data-parallel rank 0.
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        # Record the failure and log the breakdown before reporting 502.
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
    finally:
        # The prefill charge is released on success and on failure alike.
        p_inst.ongoing_tokens -= input_length
    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        # Engine id of rank 0; empty string when bootstrap recorded none.
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    async def generate():
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
            # Stream finished cleanly: remember prompt + output as cached.
            d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            # Release the decode charge and log the breakdown however the
            # stream ended.
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    # NOTE(review): the combined path answers with "text/event-stream";
    # confirm that "application/json" is intended here.
    return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis.

    This is the whole in-memory `_breakdown_log`: one dict per request
    that finished or failed, in completion order.
    """
    return _breakdown_log
@app.get("/worker_state")
async def get_worker_state():
    """Return per-decision worker-state snapshot log (one entry per route decision).

    This is the whole in-memory `_worker_state_log`, in decision order.
    """
    return _worker_state_log
@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Return a live per-worker snapshot; nothing is appended to the logs.

    The combined pool is reported when it is non-empty, otherwise the
    prefill and decode pools are. No prompt is supplied, so the cache-hit
    fields are zero.
    """
    if not combined_instances:
        return {
            "t_unix": _time.time(),
            "mode": "pd_sep",
            "workers_prefill": snapshot_workers(prefill_instances),
            "workers_decode": snapshot_workers(decode_instances),
        }
    return {
        "t_unix": _time.time(),
        "mode": "combined",
        "workers": snapshot_workers(combined_instances),
    }
@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    The combined pool is reported when it is non-empty; otherwise the
    prefill instances are listed first, then the decode instances. Each
    entry's "role" names the pool the instance belongs to ("combined",
    "prefill" or "decode"), so PD-sep instances are no longer mislabelled
    as "combined".
    """
    if combined_instances:
        labelled = [("combined", inst) for inst in combined_instances]
    else:
        labelled = ([("prefill", inst) for inst in prefill_instances]
                    + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in labelled]
def parse_args():
    """Parse and validate the proxy command line.

    Returns the argparse namespace with two derived attributes:
      prefill: list of (url, bootstrap_port or None) built from each
               --prefill occurrence; the port may be given as "none".
      decode:  list of decode URLs built from each --decode occurrence.

    Exits through argparse's error path (status 2) when neither mode is
    selected, when a bootstrap port is not an integer, or when PD-sep mode
    is requested without a decode instance. In the last case every request
    would otherwise fail at routing time, because no decode instance could
    be picked.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear",
                   choices=["linear", "lmetric", "unified"],
                   help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
                        "or unified (hybrid affinity + LMetric fallback)")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    # The four flags below are accepted for bench.sh backward compatibility but
    # have no effect after the PD-sep offload path was retired (REPORT §3.9,
    # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
    # scripts/legacy/*.sh which still pass them through.
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--offload-mode", type=str, default="cached_prefill",
                   choices=["direct_read", "cached_prefill"],
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--cache-gate-ratio", type=float, default=0.0,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--decode-iteration-s", type=float, default=0.05,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report a bad port as a CLI error, not a traceback.
                p.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    # Combined mode takes precedence when both are given, so a decode
    # instance is only mandatory when PD-sep mode will actually be used.
    if not args.combined and not args.decode:
        p.error("--prefill requires at least one --decode instance")
    return args
if __name__ == "__main__":
    # Script entry point: parse the CLI, copy the tunables into SETTINGS,
    # print a startup banner and serve the app with uvicorn.
    global_args = parse_args()
    for knob in ("heavy_threshold", "overload_factor",
                 "max_offload_inflight", "cache_gate_ratio"):
        setattr(SETTINGS, knob, getattr(global_args, knob))
    SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
    offload_enabled = getattr(global_args, 'offload', False)
    print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
        SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
        offload_enabled))
    uvicorn.run(app, host=global_args.host, port=global_args.port)