A: Add /estimate_hit endpoint to bootstrap server for real-time cache probing. Proxy queries this before committing to PUSH, eliminating 24% zero-match PUSH requests (shadow cache divergence). C: Add _handle_cached_prefill_offload: C (cache source) does fast cached prefill → KV to Mooncake → D pulls and decodes. Replaces broken direct_read PUSH where D waited for RDMA transfer while occupying KV blocks without doing compute. Also: update §3.9 baseline to plain vLLM with full mean/p50/p90/p99. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
802 lines
31 KiB
Python
802 lines
31 KiB
Python
"""Unified cache-aware + token-level load-balanced global scheduler.
|
||
|
||
Supports two modes:
|
||
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
|
||
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
|
||
|
||
Routing policies (--policy):
|
||
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
|
||
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
|
||
P_tokens = pending_prefill_tokens + new_uncached_tokens
|
||
BS = num_requests (waiting + running)
|
||
Session affinity: multi-turn sessions stick to same instance (all policies).
|
||
"""
|
||
|
||
import argparse
|
||
import asyncio
|
||
import os
|
||
import time as _time
|
||
import urllib.parse
|
||
import uuid
|
||
from collections import OrderedDict
|
||
from contextlib import asynccontextmanager
|
||
from dataclasses import dataclass
|
||
|
||
import httpx
|
||
|
||
# Maximum number of attempts the combined-mode streaming generator makes to
# open the upstream stream before a connection error is propagated.
MAX_STREAM_RETRIES = 3
# Seconds slept between two consecutive stream attempts.
RETRY_DELAY_S = 0.5
|
||
import uvicorn
|
||
from fastapi import FastAPI, HTTPException, Request
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
# Number of tokens per shadow-cache block: prompts are hashed in chunks of this
# size, and the same value is sent as "block_size" to /estimate_hit.
BLOCK_SIZE = 512
# Weight of cached tokens when they are subtracted from the load score in
# pick_instance().
CACHE_HIT_ALPHA = 1.0
|
||
|
||
|
||
@dataclass
class Settings:
    """Tunable runtime parameters shared by the routing and offload code.

    Routing code reads these values through the module-level ``SETTINGS``
    instance instead of the parsed CLI namespace, so the defaults below are
    in effect even when the module is imported as a library (for example by
    tests) and the ``__main__`` block never runs.
    """

    # Assumed prefill speed of one GPU in tokens/s (measured on H20).
    prefill_throughput: float = 7_000.0
    # Fixed cost in seconds charged for an RDMA PUSH (~10-50ms measured).
    rdma_overhead_s: float = 0.1
    # Per-instance LRU cap on the number of shadow ``cached_blocks`` entries.
    cache_capacity_blocks: int = 200_000


# Process-wide settings instance read by the routing code.
SETTINGS = Settings()
|
||
|
||
|
||
class InstanceState:
|
||
def __init__(self, url: str, bootstrap_port: int | None = None):
|
||
self.url = url
|
||
self.bootstrap_port = bootstrap_port
|
||
self.client = httpx.AsyncClient(
|
||
timeout=None, base_url=url,
|
||
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
|
||
)
|
||
self.ongoing_tokens = 0
|
||
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
|
||
self.pending_prefill_tokens = 0 # tokens for requests still in prefill
|
||
self.num_requests = 0 # total in-flight requests (waiting + running)
|
||
self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others
|
||
self.engine_id: dict[int, str] = {}
|
||
self.dp_size = 1
|
||
# OrderedDict acts as an LRU keyed by block hash; value is unused.
|
||
self.cached_blocks: OrderedDict[int, None] = OrderedDict()
|
||
|
||
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
|
||
if not token_ids or len(token_ids) < BLOCK_SIZE:
|
||
return 0
|
||
hit = 0
|
||
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
||
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
|
||
if bh in self.cached_blocks:
|
||
self.cached_blocks.move_to_end(bh) # LRU touch on hit
|
||
hit += BLOCK_SIZE
|
||
else:
|
||
break
|
||
return hit
|
||
|
||
def record_prefix(self, token_ids: list[int] | None):
|
||
if not token_ids:
|
||
return
|
||
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
||
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
|
||
if bh in self.cached_blocks:
|
||
self.cached_blocks.move_to_end(bh)
|
||
else:
|
||
self.cached_blocks[bh] = None
|
||
if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
|
||
self.cached_blocks.popitem(last=False)
|
||
|
||
|
||
def _p_offload_penalty(inst: InstanceState) -> int:
    """Score penalty for an instance busy with P-role offloads (legacy PD-sep routing).

    Returns 20000 for each active offload, or 0 when the counter is zero or
    negative.
    """
    active = inst.active_p_offloads
    return active * 20000 if active > 0 else 0
|
||
|
||
|
||
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Choose an instance: sticky per session, with a load-aware override.

    A session that already has an entry in *affinity* keeps its instance as
    long as that instance carries at most twice the average ``ongoing_tokens``
    (average floored at 1.0) and is not serving any P-role offload.
    Otherwise the instance with the lowest score wins, where
    ``score = ongoing_tokens + offload penalty - CACHE_HIT_ALPHA * cache_hit``;
    ties go to the lowest index. The winner is stored in *affinity* when a
    session id is given.

    Args:
        instances: Candidate instances; must not be empty.
        token_ids: Prompt token IDs used for the cache-hit estimate, or None.
        session_id: Session key for affinity, or None.
        input_length: Not used by this policy.
        affinity: Session id -> instance index map; updated in place.

    Returns:
        ``(instance, index)`` of the chosen instance.
    """
    mean_load = sum(inst.ongoing_tokens for inst in instances) / len(instances)
    overload_limit = max(mean_load, 1.0) * 2.0

    pinned_idx = affinity.get(session_id) if session_id else None
    if pinned_idx is not None and pinned_idx < len(instances):
        pinned = instances[pinned_idx]
        if pinned.active_p_offloads == 0 and pinned.ongoing_tokens <= overload_limit:
            return pinned, pinned_idx

    # One score per instance, in list order (the cache probe touches the LRU).
    scores = []
    for candidate in instances:
        hit_tokens = candidate.estimate_cache_hit(token_ids)
        scores.append(candidate.ongoing_tokens + _p_offload_penalty(candidate)
                      - CACHE_HIT_ALPHA * hit_tokens)
    best_idx = min(range(len(scores)), key=scores.__getitem__, default=0)

    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx
|
||
|
||
|
||
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing: pick the instance minimising ``P_tokens * BS`` (OSDI'26).

    ``P_tokens`` is the instance's ``pending_prefill_tokens`` plus the part of
    this prompt its shadow cache does not cover; ``BS`` is its ``num_requests``.
    The decision is made per request: *session_id* and *affinity* are accepted
    for signature parity with ``pick_instance`` but are neither read nor
    updated. Ties go to the lowest index.

    Returns:
        ``(instance, index)`` of the chosen instance.
    """
    scores = []
    for candidate in instances:
        uncached = max(0, input_length - candidate.estimate_cache_hit(token_ids))
        prefill_backlog = candidate.pending_prefill_tokens + uncached
        scores.append(prefill_backlog * candidate.num_requests)
    best_idx = min(range(len(scores)), key=scores.__getitem__, default=0)
    return instances[best_idx], best_idx
|
||
|
||
|
||
# Shared client for bootstrap-server probes; created on first use.
_bootstrap_client: httpx.AsyncClient | None = None

BOOTSTRAP_TIMEOUT_S = 1.0  # timeout for /estimate_hit calls


async def _get_bootstrap_client() -> httpx.AsyncClient:
    """Return the shared bootstrap probe client, creating it on the first call."""
    global _bootstrap_client
    if _bootstrap_client is not None:
        return _bootstrap_client
    pool_limits = httpx.Limits(max_connections=32, max_keepalive_connections=16)
    _bootstrap_client = httpx.AsyncClient(
        timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
        limits=pool_limits,
    )
    return _bootstrap_client
|
||
|
||
|
||
async def _query_bootstrap_hit(
    inst: InstanceState, token_ids: list[int],
) -> int | None:
    """Ask the instance's bootstrap server how many prompt tokens are cached.

    POSTs the token IDs and BLOCK_SIZE to ``/estimate_hit`` on the instance's
    host at its bootstrap port.

    Returns:
        The ``hit_tokens`` value from the JSON reply, or None when the
        instance has no bootstrap port or the probe fails for any reason
        (the caller is expected to fall back to the shadow estimate).
    """
    port = inst.bootstrap_port
    if port is None:
        return None
    host = urllib.parse.urlparse(str(inst.client.base_url)).hostname
    endpoint = f"http://{host}:{port}/estimate_hit"
    payload = {"token_ids": token_ids, "block_size": BLOCK_SIZE}
    try:
        probe_client = await _get_bootstrap_client()
        reply = await probe_client.post(endpoint, json=payload)
        reply.raise_for_status()
        return reply.json()["hit_tokens"]
    except Exception:
        # Best-effort probe: any failure means "unknown".
        return None
|
||
|
||
|
||
# Parsed CLI namespace. Assigned in the __main__ block and read by lifespan()
# and the request handlers; stays None when the module is only imported.
global_args = None
# Backend instance lists, filled by lifespan() according to the selected mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
# Both map a session id to an index into the corresponding instance list.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias for the combined-mode map.
# NOTE(review): no reader of this alias is visible in the code reviewed here;
# confirm external users before removing it.
session_affinity = session_affinity_combined
# True when running in PD-sep mode; set by lifespan().
is_pd_sep = False
# Per-request routing/timing records appended by the handlers and returned by
# /breakdown. Nothing in the visible code ever trims this list.
_breakdown_log: list[dict] = []
|
||
|
||
|
||
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each bootstrap-enabled instance and record its engine IDs.

    Instances without a bootstrap port are skipped. For the others, ``/health``
    is polled once per second until the request stops raising; then the
    bootstrap server's ``/query`` endpoint is read and, for every data-parallel
    rank in the reply, the ``engine_id`` is stored on the instance together
    with the number of ranks. A failing ``/query`` is not retried and
    propagates to the caller. *ready* is set once all instances are done.
    """
    for inst in instances:
        port = inst.bootstrap_port
        if port is None:
            continue
        while True:
            try:
                await inst.client.get("/health")
            except Exception:
                # Instance not reachable yet; try again shortly.
                await asyncio.sleep(1)
            else:
                host = urllib.parse.urlparse(str(inst.client.base_url)).hostname
                reply = await inst.client.get(f"http://{host}:{port}/query")
                reply.raise_for_status()
                ranks = reply.json()
                for rank, entry in ranks.items():
                    inst.engine_id[int(rank)] = entry["engine_id"]
                inst.dp_size = len(ranks)
                print(f"Inited {inst.url} engine_ids={inst.engine_id}")
                break
    ready.set()
|
||
|
||
|
||
async def _reconcile_loop():
    """Background safety net that keeps the shadow load counters non-negative.

    Every 60 seconds each load counter of every known instance is reset to 0
    if it has drifted below zero, so router scores stay sane. Counters that
    drifted upwards are left alone; this does not replace exact state syncing
    with vLLM (see TODO.md item 6). The loop ends quietly when the task is
    cancelled during the sleep.
    """
    counter_names = (
        "ongoing_tokens",
        "ongoing_decode_tokens",
        "pending_prefill_tokens",
        "num_requests",
        "active_p_offloads",
    )
    while True:
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            for name in counter_names:
                if getattr(inst, name) < 0:
                    setattr(inst, name, 0)
|
||
|
||
|
||
def _verify_vllm_patch():
|
||
"""Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.
|
||
|
||
The patch turns an `assert req_id in self.requests` into a soft warn so
|
||
that engines do not crash on the KV-transfer abort race (see REPORT
|
||
§3.x). If somebody upgrades vLLM without re-applying the patch, the
|
||
assert returns and elastic mode dies under load. Print a loud warning
|
||
so we catch the regression before the first HEAVY request.
|
||
"""
|
||
try:
|
||
import inspect
|
||
from vllm.v1.core.sched.scheduler import Scheduler
|
||
src = inspect.getsource(Scheduler)
|
||
if "assert req_id in self.requests" in src:
|
||
print("WARNING: vLLM scheduler still contains the unpatched "
|
||
"`assert req_id in self.requests` line; expect engine "
|
||
"death on KV-transfer abort race. Apply "
|
||
"patches/0001-fix-kv-transfer-abort-race.patch.")
|
||
else:
|
||
print("vLLM patch self-check: kv-transfer-abort assert is patched.")
|
||
except Exception as exc:
|
||
print(f"vLLM patch self-check skipped: {exc!r}")
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the instance lists, then clean up on exit.

    Startup creates the readiness event, runs the vLLM patch self-check and
    starts the reconcile task. With ``--combined`` the combined instances are
    created (each paired with its bootstrap port when one is listed) and are
    bootstrapped only when offload is enabled and ports were given; otherwise
    the service is marked ready immediately. Without ``--combined`` the
    prefill and decode instances are created and the prefill side is
    bootstrapped. Shutdown cancels the reconcile task and closes every HTTP
    client.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()

    _verify_vllm_patch()

    reconcile_task = asyncio.create_task(_reconcile_loop())

    if not global_args.combined:
        is_pd_sep = True
        prefill_instances.extend(
            InstanceState(url, bp) for url, bp in global_args.prefill)
        decode_instances.extend(InstanceState(url) for url in global_args.decode)
        await init_prefill_bootstrap(prefill_instances, app.state.ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
    else:
        is_pd_sep = False
        raw_ports = global_args.bootstrap_ports
        bp_list = [int(tok) for tok in raw_ports.split(",") if tok.strip()] if raw_ports else []
        # Instances beyond the listed ports get no bootstrap port.
        missing = len(global_args.combined) - len(bp_list)
        padded_ports = bp_list + [None] * max(0, missing)
        for url, port in zip(global_args.combined, padded_ports):
            combined_instances.append(InstanceState(url, port))

        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        if global_args.offload and bp_list:
            await init_prefill_bootstrap(combined_instances, app.state.ready)
        else:
            app.state.ready.set()

        policy = getattr(global_args, 'policy', 'linear')
        offload_label = 'ON' if global_args.offload else 'OFF'
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={offload_label}")

    yield

    # Shutdown: stop the background task first, then release the clients.
    reconcile_task.cancel()
    try:
        await reconcile_task
    except asyncio.CancelledError:
        pass
    if _bootstrap_client is not None:
        await _bootstrap_client.aclose()
    for inst in combined_instances + prefill_instances + decode_instances:
        await inst.client.aclose()


app = FastAPI(lifespan=lifespan)
|
||
|
||
|
||
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Serve ``/v1/completions`` by delegating to the shared request handler."""
    response = await _handle(request, "/v1/completions")
    return response
|
||
|
||
|
||
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Serve ``/v1/chat/completions`` by delegating to the shared request handler."""
    response = await _handle(request, "/v1/chat/completions")
    return response
|
||
|
||
|
||
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    Args:
        request: The incoming FastAPI request; its body must be a JSON object.
        api: Upstream API path the request is forwarded to.

    Returns:
        The response produced by the PD-sep or combined handler.

    Raises:
        HTTPException: 503 while startup has not finished; 400 when the body
            is not valid JSON or not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    try:
        req_data = await request.json()
    except ValueError as exc:
        # JSONDecodeError / UnicodeDecodeError: a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Request body is not valid JSON") from exc
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    request_id = str(uuid.uuid4())
    prompt = req_data.get("prompt")
    # Only a flat list of ints is a token-ID prompt. A list of strings or a
    # batch of token lists must not be fed to the block hasher: nested lists
    # are unhashable, and a count of strings is not a token count.
    if isinstance(prompt, list) and all(isinstance(tok, int) for tok in prompt):
        token_ids = prompt
    else:
        token_ids = None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
|
||
|
||
|
||
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Combined-mode routing: pick the instance with the lowest estimated cost.

    For each instance the cost is estimated as
        queue_time + prefill_time (+ RDMA overhead for a PUSH)
    where queue_time is the instance's pending prefill tokens divided by the
    configured throughput, and prefill_time covers the tokens not found in
    the relevant shadow cache (the instance's own cache, or the best cache
    among all instances when a PUSH is considered).

    A session with a recorded affinity keeps its instance while that
    instance's cost is within 2x of the global best. When the winning option
    is a PUSH, the real hit count is probed on the cache-holding instance and
    the request is handed to the cached-prefill offload path; if the probe
    reports 0 the request is served locally on the chosen instance instead.

    Args:
        api: Upstream API path to forward to.
        req_data: Parsed JSON request body, forwarded unchanged on the LOCAL path.
        token_ids: Prompt token IDs, or None when the prompt is not tokenized.
        input_length: Number of prompt tokens (0 when unknown).
        session_id: Session key for affinity, or None.
        headers: Headers forwarded upstream; "X-Request-Id" is read for logging.

    Returns:
        A StreamingResponse relaying the upstream bytes, or whatever the
        offload handler returns.
    """
    # Offload needs the CLI flag and at least two instances (a source and a target).
    offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
    throughput = SETTINGS.prefill_throughput

    # Compute cache hits for all instances
    cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
    best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
    best_cache_hit = cache_hits[best_cache_idx]

    def _instance_cost(i: int) -> tuple[float, bool]:
        """Return (estimated latency, needs_push) if this request goes to instance i."""
        inst = combined_instances[i]
        queue = inst.pending_prefill_tokens / throughput
        local_hit = cache_hits[i]
        local_new = max(0, input_length - local_hit)
        local_cost = queue + local_new / throughput

        # A PUSH is only considered when another instance holds a strictly
        # larger cached prefix than this one.
        if offload_enabled and best_cache_hit > 0 and i != best_cache_idx and local_hit < best_cache_hit:
            push_new = max(0, input_length - best_cache_hit)
            push_cost = queue + push_new / throughput + SETTINGS.rdma_overhead_s
            if push_cost < local_cost:
                return push_cost, True
        return local_cost, False

    # Session affinity: prefer the last-used instance if its cost is reasonable
    affinity_idx = session_affinity_combined.get(session_id) if session_id else None
    if affinity_idx is not None and affinity_idx < len(combined_instances):
        affinity_cost, affinity_push = _instance_cost(affinity_idx)
        # Compare with the globally best option
        all_costs = [_instance_cost(i) for i in range(len(combined_instances))]
        global_best_cost = min(c for c, _ in all_costs)
        # Use affinity if it's within 2x of the best option
        if affinity_cost <= global_best_cost * 2.0:
            best_idx = affinity_idx
            best_cost = affinity_cost
            best_needs_push = affinity_push
        else:
            best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0])
            best_cost, best_needs_push = all_costs[best_idx]
    else:
        # No usable affinity entry: take the cheapest instance outright.
        all_costs = [_instance_cost(i) for i in range(len(combined_instances))]
        best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0])
        best_cost, best_needs_push = all_costs[best_idx]

    chosen = combined_instances[best_idx]
    cache_hit = cache_hits[best_idx]
    estimated_new = max(0, input_length - cache_hit)

    # Per-request record; completed in the streaming generator and appended
    # to _breakdown_log when the request ends.
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "cache_hit": cache_hit,
        "estimated_new_tokens": estimated_new,
        "t_proxy_recv": _time.monotonic(),
        "chosen_cost": round(best_cost, 2),
    }

    if session_id:
        session_affinity_combined[session_id] = best_idx

    if best_needs_push:
        # c_inst holds the cached prefix; d_inst is the instance chosen to decode.
        c_inst = combined_instances[best_cache_idx]
        d_inst = chosen

        # Query real cache hit from bootstrap (shadow cache is inaccurate)
        real_hit = await _query_bootstrap_hit(c_inst, token_ids)
        breakdown["shadow_cache_hit"] = best_cache_hit
        breakdown["real_cache_hit"] = real_hit

        if real_hit is not None:
            push_cache_hit = real_hit
        else:
            push_cache_hit = best_cache_hit  # fallback to shadow estimate

        # If real hit > 0, proceed with cached prefill on C → decode on D
        if push_cache_hit > 0:
            push_new = max(0, input_length - push_cache_hit)

            # Charge the prefill to C here; the offload handler releases
            # these counters when the prefill finishes or fails.
            c_inst.ongoing_tokens += input_length
            c_inst.pending_prefill_tokens += push_new
            c_inst.num_requests += 1
            c_inst.active_p_offloads += 1

            breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
            breakdown["c_inst"] = c_inst.url
            breakdown["d_inst"] = d_inst.url
            breakdown["push_cache_hit"] = push_cache_hit

            return await _handle_cached_prefill_offload(
                api, req_data, headers, token_ids, input_length,
                c_inst, d_inst, push_cache_hit, push_new, breakdown)

        # Real hit is 0 — downgrade to LOCAL
        breakdown["push_downgraded"] = True

    # LOCAL path (also handles downgraded PUSH)
    breakdown["route_class"] = "LOCAL"
    breakdown["routed_to"] = chosen.url

    # Charge the request to the chosen instance; generate() undoes this in
    # its finally block.
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        """Relay upstream bytes and keep the instance's counters in step."""
        prefill_done = False
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            if not prefill_done:
                                # First bytes received: treat prefill as
                                # finished and move the tokens to decode.
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                prefill_done = True
                            yield chunk
                    chosen.record_prefix(token_ids)
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    # Retry only while nothing has been sent downstream and
                    # attempts remain.
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            # Release whichever phase counter is still held, then the totals.
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
|
||
|
||
|
||
PREFILL_TIMEOUT_S = 120  # max seconds to wait for P-instance prefill


async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
                                         input_length, c_inst, d_inst,
                                         cache_hit, estimated_new, breakdown):
    """C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.

    Unlike direct_read (D pulls blocks from C), here C's scheduler IS
    involved: C prefills (fast, because prefix is cached), then D is sent the
    decode request with transfer parameters pointing at C's bootstrap server.

    The caller has already charged the prefill to C's load counters; this
    function releases them once the prefill request has finished, failed or
    been cancelled, and charges/releases D's counters around the decode stream.

    Args:
        api: Upstream API path used for both the prefill and the decode call.
        req_data: Original request body; copied, never mutated.
        headers: Headers forwarded upstream; "X-Request-Id" names the transfer.
        token_ids: Prompt token IDs recorded into both shadow caches.
        input_length: Number of prompt tokens.
        c_inst: Instance holding the cached prefix (runs the prefill).
        d_inst: Instance that decodes.
        cache_hit: Cached token count on C (not used by this function).
        estimated_new: Prefill tokens the caller charged to C.
        breakdown: Per-request record, updated in place and logged at the end.

    Returns:
        A StreamingResponse relaying D's output.

    Raises:
        HTTPException: 502 when the prefill on C fails or exceeds
            PREFILL_TIMEOUT_S.
    """
    request_id = headers.get("X-Request-Id", "")

    # Step 1: send blocking prefill to C
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    # The prefill call only has to produce KV: one token, no streaming.
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    breakdown["t_prefill_sent"] = _time.monotonic()

    try:
        # The instance client has no default timeout, so bound the blocking
        # prefill explicitly; a stuck C must not pin this request forever.
        resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers,
                                        timeout=PREFILL_TIMEOUT_S)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        c_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}") from e
    finally:
        # Release C's counters on every exit path, including task
        # cancellation, which `except Exception` does not see.
        c_inst.ongoing_tokens -= input_length
        c_inst.pending_prefill_tokens -= estimated_new
        c_inst.num_requests -= 1
        c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)

    # Step 2: send decode to D (pull KV from C via Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.num_requests += 1

    parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": c_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }

    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        """Relay D's bytes and release D's counters when the stream ends."""
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
            d_inst.record_prefix(token_ids)
        finally:
            # Decode tokens were only added once the first chunk arrived.
            if not first_token:
                d_inst.ongoing_decode_tokens -= input_length
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
|
||
|
||
|
||
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
                                      input_length, c_inst, d_inst,
                                      cache_hit, estimated_new, breakdown):
    """HEAVY request: D direct-reads the cached KV from C, then prefills and decodes.

    The full prompt goes to D with ``direct_read`` transfer parameters that
    point at C's bootstrap server and state how many tokens (the cache hit
    rounded down to a whole number of blocks) can be fetched remotely. No
    request is sent to C. When the stream ends, D's load counters and C's
    ``active_p_offloads`` are decremented; the matching increments are
    expected to have been made by the caller.

    Returns:
        A StreamingResponse relaying D's output.
    """
    request_id = headers.get("X-Request-Id", "")

    breakdown["t_offload_sent"] = _time.monotonic()

    source_host = urllib.parse.urlparse(str(c_inst.client.base_url)).hostname
    kv_params = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "direct_read": True,
        "remote_bootstrap_addr": "http://%s:%s" % (source_host, c_inst.bootstrap_port),
        "remote_engine_id": c_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
        # Align cache_hit to a block boundary.
        "remote_num_tokens": cache_hit - cache_hit % BLOCK_SIZE,
    }

    # Send full prompt to D with direct_read flag
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = kv_params

    async def generate():
        """Relay D's bytes and settle the counters when the stream ends."""
        got_first_chunk = False
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if not got_first_chunk:
                        got_first_chunk = True
                        # Prefill is over: move the tokens to the decode phase.
                        d_inst.pending_prefill_tokens -= estimated_new
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                    yield chunk
            d_inst.record_prefix(token_ids)
        finally:
            if got_first_chunk:
                d_inst.ongoing_decode_tokens -= input_length
            else:
                d_inst.pending_prefill_tokens -= estimated_new
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
|
||
|
||
|
||
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    A prefill instance is chosen with ``pick_instance`` (session-sticky, cache
    and load aware) and the decode instance is the one with the fewest
    ongoing tokens. The prefill request is sent first and awaited; the decode
    request then carries transfer parameters pointing at the prefill
    instance's bootstrap server, and its output is streamed back.

    Args:
        api: Upstream API path used for both calls.
        req_data: Original request body; copied, never mutated.
        request_id: ID used in the breakdown record and the transfer ID.
        token_ids: Prompt token IDs, or None.
        input_length: Number of prompt tokens.
        session_id: Session key for prefill affinity, or None.
        headers: Headers forwarded upstream.

    Returns:
        A StreamingResponse relaying the decode instance's output.

    Raises:
        HTTPException: 503 when no prefill or decode instance is configured;
            502 when the prefill fails or exceeds PREFILL_TIMEOUT_S.
    """
    if not prefill_instances or not decode_instances:
        # Without this guard an empty pool surfaces as an unhandled
        # ZeroDivisionError / ValueError from the selection code below.
        raise HTTPException(status_code=503,
                            detail="No prefill/decode instances configured")

    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }

    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity_prefill)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    # The prefill call only has to produce KV: one token, no streaming.
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()

    try:
        # The instance client has no default timeout, so bound the blocking
        # prefill explicitly.
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers,
                                        timeout=PREFILL_TIMEOUT_S)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
    finally:
        p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }

    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        """Relay decode bytes and release the decode instance's load at the end."""
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="application/json")
|
||
|
||
|
||
@app.get("/breakdown")
async def get_breakdown():
    """Expose the accumulated per-request breakdown records for analysis."""
    records = _breakdown_log
    return records
|
||
|
||
|
||
@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    In combined mode every entry has role "combined". In PD-sep mode the
    prefill instances are listed first with role "prefill", followed by the
    decode instances with role "decode".
    """
    if combined_instances:
        tagged = [("combined", inst) for inst in combined_instances]
    else:
        # Label each instance with the pool it actually belongs to.
        tagged = ([("prefill", inst) for inst in prefill_instances]
                  + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in tagged]
|
||
|
||
|
||
def parse_args():
    """Parse and validate the scheduler's command-line arguments.

    Besides the raw argparse attributes, the returned namespace carries:
      * ``prefill``: list of ``(url, bootstrap_port_or_None)`` tuples built
        from the repeated ``--prefill URL [PORT]`` options (PORT may be the
        literal "none"),
      * ``decode``: list of URLs from the repeated ``--decode URL`` options.

    Exits through ``ArgumentParser.error`` (status 2) when neither mode is
    selected, when a prefill bootstrap port is not an integer, or when PD-sep
    mode is requested without any ``--decode`` instance.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"],
                   help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="Global cap on concurrent P-role offloads (M3)")
    p.add_argument("--cache-gate-ratio", type=float, default=0.3,
                   help="Min cache_hit/input ratio to allow offload "
                        "(0.0 disables gate, 1.0 disables offload entirely)")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report a usage error instead of a raw traceback.
                p.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    if not args.combined and not args.decode:
        # PD-sep mode cannot serve any request without a decode instance.
        p.error("PD-Sep mode requires at least one --decode URL")
    return args
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse the CLI, report the effective settings, serve.
    global_args = parse_args()
    offload_flag = getattr(global_args, 'offload', False)
    settings_line = "SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
        SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s, offload_flag)
    print(settings_line)
    uvicorn.run(app, host=global_args.host, port=global_args.port)
|