Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang cdf83493ab Fix A+C: real cache sync + cached-prefill-on-C architecture
A: Add /estimate_hit endpoint to bootstrap server for real-time cache
   probing. Proxy queries this before committing to PUSH, eliminating
   24% zero-match PUSH requests (shadow cache divergence).

C: Add _handle_cached_prefill_offload: C (cache source) does fast
   cached prefill → KV to Mooncake → D pulls and decodes.
   Replaces broken direct_read PUSH where D waited for RDMA transfer
   while occupying KV blocks without doing compute.

Also: update §3.9 baseline to plain vLLM with full mean/p50/p90/p99.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-24 11:22:38 +08:00

802 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests (waiting + running)
Session affinity: multi-turn sessions stick to same instance (all policies).
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from dataclasses import dataclass
import httpx
# Retry policy used by the LOCAL streaming path in _handle_combined():
# number of attempts per request and the pause between attempts (seconds).
MAX_STREAM_RETRIES = 3
RETRY_DELAY_S = 0.5
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of tokens per block in the proxy's shadow prefix cache; also sent as
# "block_size" to the bootstrap /estimate_hit endpoint.
BLOCK_SIZE = 512
# Weight applied to estimated cached tokens in pick_instance()'s linear score.
CACHE_HIT_ALPHA = 1.0
@dataclass
class Settings:
    """Runtime-tunable knobs read by the routing/offload code.

    All routing/offload code reads from the SETTINGS singleton so that the
    values are available even when the module is imported as a library
    (e.g. by tests/) and __main__ does not run.

    NOTE(review): this docstring used to say the fields are populated from
    argparse in __main__, but no code visible in this file assigns them from
    CLI arguments -- only the defaults below are ever used. Confirm.
    """
    prefill_throughput: float = 7000.0 # tokens/s per GPU (measured on H20)
    rdma_overhead_s: float = 0.1 # RDMA PUSH overhead (~10-50ms measured)
    cache_capacity_blocks: int = 200000 # per-instance LRU cap on shadow cached_blocks
# Module-wide singleton read by InstanceState.record_prefix() and
# _handle_combined().
SETTINGS = Settings()
class InstanceState:
    """Shadow state the proxy keeps for one backend instance.

    Holds the HTTP client for the instance, the load counters the routers
    score against, and an LRU "shadow" of which prompt-prefix blocks the
    proxy believes the instance has cached.

    Block hashes are chained: the hash of block *k* covers the hash of block
    *k-1* as well as its own tokens, so a block only counts as cached when
    the entire prefix leading up to it was recorded. Hashing each block in
    isolation would report hits for blocks that were cached under a
    different prefix.
    """

    def __init__(self, url: str, bootstrap_port: int | None = None):
        """Create the state for the instance at *url*.

        Args:
            url: Base URL of the instance; used as the client's base_url.
            bootstrap_port: Port of the instance's bootstrap server, or None
                when the instance has none.
        """
        self.url = url
        self.bootstrap_port = bootstrap_port
        self.client = httpx.AsyncClient(
            timeout=None, base_url=url,
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
        )
        self.ongoing_tokens = 0
        self.ongoing_decode_tokens = 0  # subset: tokens in decode phase
        self.pending_prefill_tokens = 0  # tokens for requests still in prefill
        self.num_requests = 0  # total in-flight requests (waiting + running)
        self.active_p_offloads = 0  # number of HEAVY prefills this instance is doing for others
        self.engine_id: dict[int, str] = {}
        self.dp_size = 1
        # OrderedDict acts as an LRU keyed by block hash; value is unused.
        self.cached_blocks: OrderedDict[int, None] = OrderedDict()

    @staticmethod
    def _block_hashes(token_ids: list[int]):
        """Yield one prefix-chained hash per full BLOCK_SIZE block, in order.

        A trailing partial block is ignored.
        """
        parent = 0
        for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            parent = hash((parent, tuple(token_ids[start:start + BLOCK_SIZE])))
            yield parent

    def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
        """Return how many leading tokens of *token_ids* the shadow cache holds.

        Counts whole blocks from the start of the prompt and stops at the
        first block that is not recorded. Every matched block is moved to the
        most-recently-used end of the LRU.

        Returns:
            A multiple of BLOCK_SIZE; 0 for a missing or sub-block prompt.
        """
        if not token_ids or len(token_ids) < BLOCK_SIZE:
            return 0
        hit = 0
        for bh in self._block_hashes(token_ids):
            if bh not in self.cached_blocks:
                break
            self.cached_blocks.move_to_end(bh)  # LRU touch on hit
            hit += BLOCK_SIZE
        return hit

    def record_prefix(self, token_ids: list[int] | None):
        """Record every full block of *token_ids* as cached on this instance.

        Known blocks are refreshed in the LRU; new blocks are inserted and the
        least-recently-used entry is evicted whenever the map grows beyond
        SETTINGS.cache_capacity_blocks.
        """
        if not token_ids:
            return
        for bh in self._block_hashes(token_ids):
            if bh in self.cached_blocks:
                self.cached_blocks.move_to_end(bh)
                continue
            self.cached_blocks[bh] = None
            if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
                self.cached_blocks.popitem(last=False)
def _p_offload_penalty(inst: InstanceState) -> int:
    """Score penalty for an instance that is prefilling on behalf of others.

    Legacy PD-sep routing helper: each active P-role offload adds 20000 to
    the instance's score; a zero or negative count adds nothing.
    """
    return max(inst.active_p_offloads, 0) * 20000
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Session-sticky routing with a load-aware override.

    Turn 2+: reuse the instance pinned in *affinity* UNLESS it is overloaded
    (ongoing_tokens above ``overload_factor`` times the average load) or busy
    with P-role offloads, in which case fall through to scoring.
    Turn 1 / fall-through: pick the instance with the lowest score
    ``ongoing_tokens + offload_penalty - CACHE_HIT_ALPHA * cache_hit`` and,
    when a session id is given, pin the session to it.

    The overload factor comes from the ``--overload-factor`` CLI option when
    the parsed arguments are available, and defaults to 2.0 otherwise (for
    example when the module is imported without running __main__).

    Args:
        instances: Candidate instances; must be non-empty.
        token_ids: Prompt token ids, or None when the prompt is not tokenized.
        session_id: Value used as the affinity key, or None.
        input_length: Unused by this policy; kept for signature parity with
            pick_instance_lmetric.
        affinity: Session id -> instance index map; updated in place.

    Returns:
        The chosen instance and its index in *instances*.
    """
    # getattr on None falls back to the default, so this is safe before
    # global_args has been assigned.
    overload_factor = getattr(global_args, "overload_factor", 2.0)
    avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
    if session_id and session_id in affinity:
        idx = affinity[session_id]
        if idx < len(instances):
            inst = instances[idx]
            if (inst.ongoing_tokens <= avg_load * overload_factor
                    and inst.active_p_offloads == 0):
                return inst, idx
    best_idx, best_score = 0, float("inf")
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        score = (inst.ongoing_tokens + _p_offload_penalty(inst)
                 - CACHE_HIT_ALPHA * cache_hit)
        if score < best_score:
            best_score = score
            best_idx = i
    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing: score = P_tokens × BS (OSDI'26).

    Pure per-request load-based routing, no session affinity.
        P  = pending_prefill_tokens + (input_length - cache_hit)
        BS = num_requests (current batch size)

    The instance with the lowest score wins. Equal scores are broken by the
    smaller P: without this, every instance with ``num_requests == 0`` scores
    0 and the first one would always be chosen regardless of how much of the
    prompt each has cached. Remaining ties keep the lowest index.

    Args:
        instances: Candidate instances.
        token_ids: Prompt token ids, or None when the prompt is not tokenized.
        session_id: Unused; kept for signature parity with pick_instance.
        input_length: Number of prompt tokens.
        affinity: Unused; kept for signature parity with pick_instance.

    Returns:
        The chosen instance and its index in *instances*.
    """
    best_idx = 0
    best_key = (float("inf"), float("inf"))
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        new_prefill = max(0, input_length - cache_hit)
        p_tokens = inst.pending_prefill_tokens + new_prefill
        bs = inst.num_requests
        key = (p_tokens * bs, p_tokens)
        if key < best_key:
            best_key = key
            best_idx = i
    return instances[best_idx], best_idx
# Lazily created client shared by every /estimate_hit probe.
_bootstrap_client: httpx.AsyncClient | None = None
BOOTSTRAP_TIMEOUT_S = 1.0 # timeout for /estimate_hit calls


async def _get_bootstrap_client() -> httpx.AsyncClient:
    """Return the shared bootstrap HTTP client, building it on first use."""
    global _bootstrap_client
    if _bootstrap_client is not None:
        return _bootstrap_client
    pool_limits = httpx.Limits(max_connections=32, max_keepalive_connections=16)
    _bootstrap_client = httpx.AsyncClient(
        timeout=httpx.Timeout(BOOTSTRAP_TIMEOUT_S),
        limits=pool_limits,
    )
    return _bootstrap_client
async def _query_bootstrap_hit(
    inst: InstanceState, token_ids: list[int],
) -> int | None:
    """Ask the instance's bootstrap server how many prompt tokens it has cached.

    POSTs the token ids and BLOCK_SIZE to ``/estimate_hit`` on the instance's
    host at its bootstrap port.

    Returns:
        The ``hit_tokens`` value from the response, or None when the instance
        has no bootstrap port or the probe fails for any reason (the caller
        is expected to fall back to the shadow estimate).
    """
    if inst.bootstrap_port is None:
        return None
    parsed = urllib.parse.urlparse(str(inst.client.base_url))
    url = f"http://{parsed.hostname}:{inst.bootstrap_port}/estimate_hit"
    payload = {
        "token_ids": token_ids,
        "block_size": BLOCK_SIZE,
    }
    try:
        probe_client = await _get_bootstrap_client()
        reply = await probe_client.post(url, json=payload)
        reply.raise_for_status()
        body = reply.json()
        return body["hit_tokens"]
    except Exception:
        # Best-effort probe: any failure means "unknown".
        return None
# Parsed CLI namespace. Assigned in the __main__ block and read by lifespan()
# and the handlers; stays None when __main__ does not run.
global_args = None
# Per-mode backend lists, filled in by lifespan() at startup.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias for the combined-mode map.
# NOTE(review): no code visible in this file reads this name -- confirm that
# an external consumer still needs it.
session_affinity = session_affinity_combined
# True when running PD-disaggregated; set by lifespan() and read by _handle().
is_pd_sep = False
# Per-request timing/route records appended by the handlers and returned by
# GET /breakdown. NOTE(review): nothing visible here ever trims this list.
_breakdown_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each instance to come up and learn its engine ids.

    For every instance that has a bootstrap port, poll once per second until
    ``/health`` answers successfully and the bootstrap server's ``/query``
    endpoint responds, then store the per-DP-rank engine ids and the DP size
    on the instance. Instances without a bootstrap port are skipped.

    A non-2xx health response counts as "not ready", and an HTTP failure of
    the bootstrap query (for example the engine is healthy but its bootstrap
    server is not listening yet) is retried instead of aborting startup.
    A malformed bootstrap payload still raises.

    Args:
        instances: Instances to initialise.
        ready: Event that is set once every instance has been processed.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue
        parsed = urllib.parse.urlparse(str(inst.client.base_url))
        url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
        reported_wait = False
        while True:
            try:
                health = await inst.client.get("/health")
                health.raise_for_status()
            except Exception:
                await asyncio.sleep(1)
                continue
            try:
                # Absolute URL: targets the bootstrap port rather than the
                # client's base_url.
                resp = await inst.client.get(url)
                resp.raise_for_status()
            except httpx.HTTPError as exc:
                if not reported_wait:
                    print(f"Waiting for bootstrap server of {inst.url}: {exc!r}")
                    reported_wait = True
                await asyncio.sleep(1)
                continue
            data = resp.json()
            for dp_rank, dp_entry in data.items():
                inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
            inst.dp_size = len(data)
            print(f"Inited {inst.url} engine_ids={inst.engine_id}")
            break
    ready.set()
async def _reconcile_loop():
    """Periodic safety net for shadow state.

    StreamingResponse generators decrement load counters in their finally
    block, but if a client disconnects before the body is consumed the
    generator is never entered and the decrement is lost. Every 60 seconds,
    reset any load counter that has gone negative back to zero so router
    scores stay sane. This does not replace proper exact-state syncing with
    vLLM (see TODO.md item 6).

    Runs until cancelled; cancellation during the sleep ends the loop quietly.
    """
    counter_names = (
        "ongoing_tokens",
        "ongoing_decode_tokens",
        "pending_prefill_tokens",
        "num_requests",
        "active_p_offloads",
    )
    while True:
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            return
        every_instance = combined_instances + prefill_instances + decode_instances
        for state in every_instance:
            for name in counter_names:
                if getattr(state, name) < 0:
                    setattr(state, name, 0)
def _verify_vllm_patch():
    """Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.

    The patch turns an `assert req_id in self.requests` into a soft warn so
    that engines do not crash on the KV-transfer abort race (see REPORT
    §3.x). If somebody upgrades vLLM without re-applying the patch, the
    assert returns and elastic mode dies under load. Inspect the installed
    scheduler source and print a loud warning when the assert is still there.

    Never raises: any failure (for example vLLM is not importable in this
    process) is reported as a skipped check.
    """
    try:
        import inspect
        from vllm.v1.core.sched.scheduler import Scheduler
        still_asserting = "assert req_id in self.requests" in inspect.getsource(Scheduler)
        if not still_asserting:
            print("vLLM patch self-check: kv-transfer-abort assert is patched.")
        else:
            print("WARNING: vLLM scheduler still contains the unpatched "
                  "`assert req_id in self.requests` line; expect engine "
                  "death on KV-transfer abort race. Apply "
                  "patches/0001-fix-kv-transfer-abort-race.patch.")
    except Exception as exc:
        print(f"vLLM patch self-check skipped: {exc!r}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: build the instance lists, then tear everything down.

    Startup runs the vLLM patch self-check, starts the counter reconcile
    task, and populates either the combined-mode or the PD-sep instance
    lists from ``global_args``. ``app.state.ready`` is set once the instances
    that need bootstrap information have been initialised.

    Teardown runs in a ``finally`` block, so the reconcile task is cancelled
    and every HTTP client is closed even when startup fails part-way or the
    application exits with an exception.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    _verify_vllm_patch()
    reconcile_task = asyncio.create_task(_reconcile_loop())
    try:
        if global_args.combined:
            is_pd_sep = False
            bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
            for i, url in enumerate(global_args.combined):
                # Instances beyond the end of the port list get no bootstrap port.
                bp = bp_list[i] if i < len(bp_list) else None
                combined_instances.append(InstanceState(url, bp))
            # Bootstrap combined instances for offload (need engine_ids for KV transfer)
            if global_args.offload and bp_list:
                await init_prefill_bootstrap(combined_instances, app.state.ready)
            else:
                app.state.ready.set()
            policy = getattr(global_args, 'policy', 'linear')
            print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")
        else:
            is_pd_sep = True
            for url, bp in global_args.prefill:
                prefill_instances.append(InstanceState(url, bp))
            for url in global_args.decode:
                decode_instances.append(InstanceState(url))
            await init_prefill_bootstrap(prefill_instances, app.state.ready)
            print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
        yield
    finally:
        reconcile_task.cancel()
        try:
            await reconcile_task
        except asyncio.CancelledError:
            pass
        if _bootstrap_client is not None:
            await _bootstrap_client.aclose()
        for inst in combined_instances + prefill_instances + decode_instances:
            await inst.client.aclose()
# ASGI application; lifespan() builds and tears down the backend instance state.
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Pass a POST /v1/completions request to the shared _handle() entry point."""
    return await _handle(request, "/v1/completions")
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Pass a POST /v1/chat/completions request to the shared _handle() entry point."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Common entry point: parse the request and dispatch by deployment mode.

    Extracts the prompt token ids (only when ``prompt`` is a flat list of
    ints), the ``X-Session-Id`` header and a fresh request id, then hands
    over to the PD-sep or combined handler.

    Args:
        request: Incoming FastAPI request.
        api: Backend path to forward to, e.g. "/v1/completions".

    Raises:
        HTTPException: 503 while startup has not finished; 400 when the body
            is not valid JSON or not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")
    try:
        req_data = await request.json()
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {exc}") from exc
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    request_id = str(uuid.uuid4())
    prompt = req_data.get("prompt")
    # Only a flat list of ints is a token-id prompt. A list of strings or a
    # batch of token lists must not reach the block hasher (nested lists are
    # unhashable) and carries no usable prefix information.
    if isinstance(prompt, list) and all(isinstance(tok, int) for tok in prompt):
        token_ids = prompt
    else:
        token_ids = None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Unified routing: pick the instance with lowest expected latency.

    For each instance, estimate:
        latency = queue_time + prefill_time + transfer_cost
    where prefill_time depends on whether the instance has cache (local),
    can receive cache via PUSH (remote), or must do cold prefill.

    A session's previous instance is kept when its cost is within 2x of the
    cheapest option. When the chosen route needs a PUSH, the cache source's
    bootstrap server is probed for the real hit count first; a real hit of 0
    downgrades the request to a plain LOCAL request on the chosen instance.

    PUSH is only considered when the best cache holder has a bootstrap port:
    the offload path builds the KV-transfer address from that port, so
    without one the transfer parameters would be unusable.

    Args:
        api: Backend path to forward to.
        req_data: Parsed request body, forwarded unchanged on the LOCAL path.
        token_ids: Prompt token ids, or None.
        input_length: Number of prompt tokens (0 when token_ids is None).
        session_id: Session affinity key, or None.
        headers: Headers to forward; "X-Request-Id" is read for the breakdown.

    Returns:
        A StreamingResponse relaying the backend's body.
    """
    offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
    throughput = SETTINGS.prefill_throughput
    # Compute cache hits for all instances
    cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
    best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
    best_cache_hit = cache_hits[best_cache_idx]
    can_push = bool(
        offload_enabled
        and best_cache_hit > 0
        and combined_instances[best_cache_idx].bootstrap_port is not None
    )

    def _instance_cost(i: int) -> tuple[float, bool]:
        """Expected latency if this request goes to instance i, and whether
        that figure assumes a PUSH from the best cache holder."""
        inst = combined_instances[i]
        queue = inst.pending_prefill_tokens / throughput
        local_hit = cache_hits[i]
        local_new = max(0, input_length - local_hit)
        local_cost = queue + local_new / throughput
        if can_push and i != best_cache_idx and local_hit < best_cache_hit:
            push_new = max(0, input_length - best_cache_hit)
            push_cost = queue + push_new / throughput + SETTINGS.rdma_overhead_s
            if push_cost < local_cost:
                return push_cost, True
        return local_cost, False

    # The cost table is pure, so compute it once and reuse it for both the
    # global optimum and the affinity comparison.
    all_costs = [_instance_cost(i) for i in range(len(combined_instances))]
    best_idx = min(range(len(combined_instances)), key=lambda i: all_costs[i][0])
    # Session affinity: prefer the last-used instance if its cost is reasonable
    affinity_idx = session_affinity_combined.get(session_id) if session_id else None
    if affinity_idx is not None and affinity_idx < len(combined_instances):
        # Use affinity if it's within 2x of the best option
        if all_costs[affinity_idx][0] <= all_costs[best_idx][0] * 2.0:
            best_idx = affinity_idx
    best_cost, best_needs_push = all_costs[best_idx]
    chosen = combined_instances[best_idx]
    cache_hit = cache_hits[best_idx]
    estimated_new = max(0, input_length - cache_hit)
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "cache_hit": cache_hit,
        "estimated_new_tokens": estimated_new,
        "t_proxy_recv": _time.monotonic(),
        "chosen_cost": round(best_cost, 2),
    }
    if session_id:
        session_affinity_combined[session_id] = best_idx
    if best_needs_push:
        c_inst = combined_instances[best_cache_idx]
        d_inst = chosen
        # Query real cache hit from bootstrap (shadow cache is inaccurate)
        real_hit = await _query_bootstrap_hit(c_inst, token_ids)
        breakdown["shadow_cache_hit"] = best_cache_hit
        breakdown["real_cache_hit"] = real_hit
        if real_hit is not None:
            push_cache_hit = real_hit
        else:
            push_cache_hit = best_cache_hit  # fallback to shadow estimate
        # If real hit > 0, proceed with cached prefill on C → decode on D
        if push_cache_hit > 0:
            push_new = max(0, input_length - push_cache_hit)
            # C's counters are released by _handle_cached_prefill_offload.
            c_inst.ongoing_tokens += input_length
            c_inst.pending_prefill_tokens += push_new
            c_inst.num_requests += 1
            c_inst.active_p_offloads += 1
            breakdown["route_class"] = "CACHED_PREFILL_OFFLOAD"
            breakdown["c_inst"] = c_inst.url
            breakdown["d_inst"] = d_inst.url
            breakdown["push_cache_hit"] = push_cache_hit
            return await _handle_cached_prefill_offload(
                api, req_data, headers, token_ids, input_length,
                c_inst, d_inst, push_cache_hit, push_new, breakdown)
        # Real hit is 0 — downgrade to LOCAL
        breakdown["push_downgraded"] = True
    # LOCAL path (also handles downgraded PUSH)
    breakdown["route_class"] = "LOCAL"
    breakdown["routed_to"] = chosen.url
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        """Relay the backend stream and keep the shadow counters in step."""
        prefill_done = False
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data, headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            if not prefill_done:
                                # First bytes back: the request moves from
                                # the prefill phase to the decode phase.
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                prefill_done = True
                            yield chunk
                    chosen.record_prefix(token_ids)
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    # Only retry while nothing has been sent to the client.
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill


async def _handle_cached_prefill_offload(api, req_data, headers, token_ids,
                                         input_length, c_inst, d_inst,
                                         cache_hit, estimated_new, breakdown):
    """C does fast cached prefill → KV to Mooncake → D pulls KV and decodes.

    Unlike direct_read (D pulls blocks from C), here C's scheduler IS
    involved: C prefills (fast, because prefix is cached), pushes KV to
    Mooncake store, then D pulls and decodes. This avoids the broken
    PUSH path where D waits for RDMA transfer while occupying KV blocks.

    The caller has already added this request to C's load counters; they are
    released here in a ``finally`` block so that they are returned on success,
    on failure and on task cancellation alike. The blocking prefill is bounded
    by PREFILL_TIMEOUT_S (the instance client itself has no timeout).

    Args:
        api: Backend path to forward to.
        req_data: Parsed request body.
        headers: Headers to forward; "X-Request-Id" names the transfer.
        token_ids: Prompt token ids.
        input_length: Number of prompt tokens.
        c_inst: Cache source that runs the prefill.
        d_inst: Instance that pulls the KV and decodes.
        cache_hit: Cached-token count the route was chosen on (unused here).
        estimated_new: Tokens the caller added to C's pending_prefill_tokens.
        breakdown: Per-request record; timing keys are added in place.

    Returns:
        A StreamingResponse relaying D's decode stream.

    Raises:
        HTTPException: 502 when the prefill on C fails or times out.
    """
    request_id = headers.get("X-Request-Id", "")
    # Step 1: send blocking prefill to C
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        resp = await c_inst.client.post(api, json=prefill_data, headers=p_headers,
                                        timeout=PREFILL_TIMEOUT_S)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        c_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill on C failed: {e}") from e
    finally:
        # Also runs on CancelledError, which `except Exception` does not catch.
        c_inst.ongoing_tokens -= input_length
        c_inst.pending_prefill_tokens -= estimated_new
        c_inst.num_requests -= 1
        c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
    # Step 2: send decode to D (pull KV from C via Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.num_requests += 1
    parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{c_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": c_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        """Relay D's stream and release D's counters when it ends."""
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
            d_inst.record_prefix(token_ids)
        finally:
            if not first_token:
                d_inst.ongoing_decode_tokens -= input_length
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
                                      input_length, c_inst, d_inst,
                                      cache_hit, estimated_new, breakdown):
    """HEAVY request served by a direct RDMA read of C_s's cached KV.

    D is sent the full prompt with the ``direct_read`` flag: it reads the
    cached KV straight from C_s, prefills the new tokens locally and decodes.
    C_s's scheduler is NOT involved.

    Returns a StreamingResponse relaying D's stream; when the stream ends the
    shadow counters on D and the offload count on C_s are released.
    """
    request_id = headers.get("X-Request-Id", "")
    # remote_num_tokens must sit on a block boundary: drop the partial block.
    aligned_hit = cache_hit - cache_hit % BLOCK_SIZE
    breakdown["t_offload_sent"] = _time.monotonic()
    source_url = urllib.parse.urlparse(str(c_inst.client.base_url))
    source_bootstrap = "http://%s:%s" % (source_url.hostname, c_inst.bootstrap_port)
    # Full prompt goes to D, tagged for direct read from C_s.
    payload = {
        **req_data,
        "kv_transfer_params": {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "direct_read": True,
            "remote_bootstrap_addr": source_bootstrap,
            "remote_engine_id": c_inst.engine_id.get(0, ""),
            "transfer_id": "xfer-" + request_id,
            "remote_num_tokens": aligned_hit,
        },
    }

    async def generate():
        """Relay D's stream, moving the request from prefill to decode
        accounting when the first bytes arrive."""
        awaiting_first = True
        try:
            async with d_inst.client.stream("POST", api, json=payload, headers=headers) as upstream:
                upstream.raise_for_status()
                async for piece in upstream.aiter_bytes():
                    if awaiting_first:
                        d_inst.pending_prefill_tokens -= estimated_new
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        awaiting_first = False
                    yield piece
            d_inst.record_prefix(token_ids)
        finally:
            if awaiting_first:
                # Never left the prefill phase: undo the prefill reservation.
                d_inst.pending_prefill_tokens -= estimated_new
            else:
                d_inst.ongoing_decode_tokens -= input_length
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Picks a prefill instance with pick_instance() and the decode instance
    with the fewest ongoing tokens, runs a blocking one-token prefill on P
    (bounded by PREFILL_TIMEOUT_S), then streams the decode from D, which is
    told where to pull the KV from.

    The response is labelled ``text/event-stream`` when the client asked for
    ``stream`` and ``application/json`` otherwise, since the body relayed
    from D has that shape.

    Args:
        api: Backend path to forward to.
        req_data: Parsed request body.
        request_id: Id used for the transfer id and the breakdown record.
        token_ids: Prompt token ids, or None.
        input_length: Number of prompt tokens.
        session_id: Session affinity key, or None.
        headers: Headers to forward to both instances.

    Returns:
        A StreamingResponse relaying D's body.

    Raises:
        HTTPException: 502 when the prefill fails or times out.
    """
    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }
    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity_prefill)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        # The instance client has no default timeout; bound the prefill here.
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers,
                                        timeout=PREFILL_TIMEOUT_S)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
    finally:
        p_inst.ongoing_tokens -= input_length
    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        """Relay D's body and release D's load when it ends."""
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)
@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis.

    The result is the module-level _breakdown_log list itself: every record
    the handlers have appended so far.
    """
    return _breakdown_log
@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    Reports the combined-mode instances when there are any, otherwise the
    prefill instances followed by the decode instances. ``role`` names the
    list each instance came from ("combined", "prefill" or "decode").
    """
    if combined_instances:
        tagged = [("combined", inst) for inst in combined_instances]
    else:
        tagged = ([("prefill", inst) for inst in prefill_instances]
                  + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in tagged]
def parse_args():
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
p.add_argument("--port", type=int, default=8000)
p.add_argument("--host", type=str, default="0.0.0.0")
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
help="PD-Sep prefill: URL [bootstrap_port]")
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
help="PD-Sep decode: URL")
p.add_argument("--heavy-threshold", type=int, default=20000,
help="New tokens threshold for HEAVY classification (adaptive offload)")
p.add_argument("--offload", action="store_true",
help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
p.add_argument("--bootstrap-ports", type=str, default="",
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"],
help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
p.add_argument("--overload-factor", type=float, default=2.0,
help="Break session affinity when instance load > factor * avg")
p.add_argument("--max-offload-inflight", type=int, default=4,
help="Global cap on concurrent P-role offloads (M3)")
p.add_argument("--cache-gate-ratio", type=float, default=0.3,
help="Min cache_hit/input ratio to allow offload "
"(0.0 disables gate, 1.0 disables offload entirely)")
args = p.parse_args()
args.prefill = []
if args.prefill_raw:
for entry in args.prefill_raw:
url = entry[0]
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
args.prefill.append((url, bp))
args.decode = [e[0] for e in (args.decode_raw or [])]
if not args.combined and not args.prefill:
p.error("Must specify either --combined or --prefill/--decode")
return args
if __name__ == "__main__":
    # Publish the parsed CLI as the module-level global_args that lifespan()
    # and the request handlers read.
    global_args = parse_args()
    # Echo the effective tunables. SETTINGS is printed as constructed: nothing
    # in this block copies CLI values into it.
    print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
        SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
        getattr(global_args, 'offload', False)))
    uvicorn.run(app, host=global_args.host, port=global_args.port)