Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang 6b255fad91 Unified routing: single argmin(expected_latency) over all instances
Replace two-phase routing (pick_instance → offload gate) with a single
cost function evaluated per instance:

  latency(D) = queue(D) + prefill_time(D) + transfer_cost(D)

  - If D has local cache: prefill = (input - local_hit) / throughput
  - If D can receive PUSH from cache source: prefill = (input - push_hit) / throughput + rdma
  - Otherwise: prefill = input / throughput (cold)

Choose argmin(latency). If the winner needs PUSH → trigger migration.

Removed:
- WARM/MEDIUM/HEAVY classification (no routing purpose)
- heavy_threshold, overload_factor, max_offload_inflight, cache_gate_ratio
- Interference penalty magic number (0.3)
- Separate pick_instance + offload gate stages

Only 2 measured parameters remain:
- prefill_throughput = 7000 tokens/s (H20 measured)
- rdma_overhead_s = 0.1s (RDMA PUSH measured)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-24 02:21:34 +08:00

658 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests (waiting + running)
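Session affinity: pick_instance (linear scoring) keeps multi-turn sessions on
the same instance unless that instance is overloaded or serving offloads;
pick_instance_lmetric and the unified combined-mode router score every request
independently and do not read the affinity table.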
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager
from dataclasses import dataclass
import httpx
# Upper bound on the attempts the combined-mode streaming generator makes to
# open the upstream stream (it loops over range(MAX_STREAM_RETRIES)).
MAX_STREAM_RETRIES = 3
# Seconds slept between attempts after a connect/protocol error.
RETRY_DELAY_S = 0.5
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of consecutive token ids hashed together as one shadow-cache block.
BLOCK_SIZE = 512
# Weight applied to the estimated cache-hit tokens in pick_instance's score.
CACHE_HIT_ALPHA = 1.0
@dataclass
class Settings:
    """Runtime-tunable knobs, read through the module-level SETTINGS singleton.

    Routing and cache code reads these values from SETTINGS, so the defaults
    below apply even when the module is imported as a library (e.g. by
    tests/) and __main__ does not run.

    NOTE(review): this docstring used to say the fields are populated from
    argparse in __main__, but no CLI flag visible in this file writes to
    SETTINGS -- confirm whether that wiring lives elsewhere or is missing.
    """
    prefill_throughput: float = 7000.0  # tokens/s per GPU (measured on H20)
    # Seconds added to the cost of a PUSH route.
    # NOTE(review): the original comment cites ~10-50ms measured, while the
    # default is 0.1 s (100 ms) -- confirm which figure is intended.
    rdma_overhead_s: float = 0.1  # RDMA PUSH overhead (~10-50ms measured)
    cache_capacity_blocks: int = 200000  # per-instance LRU cap on shadow cached_blocks


# Process-wide settings instance shared by all routing code.
SETTINGS = Settings()
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.pending_prefill_tokens = 0 # tokens for requests still in prefill
self.num_requests = 0 # total in-flight requests (waiting + running)
self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others
self.engine_id: dict[int, str] = {}
self.dp_size = 1
# OrderedDict acts as an LRU keyed by block hash; value is unused.
self.cached_blocks: OrderedDict[int, None] = OrderedDict()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
self.cached_blocks.move_to_end(bh) # LRU touch on hit
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
self.cached_blocks.move_to_end(bh)
else:
self.cached_blocks[bh] = None
if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
self.cached_blocks.popitem(last=False)
def _p_offload_penalty(inst: InstanceState) -> int:
"""Penalty for PD-sep mode routing (legacy)."""
if inst.active_p_offloads <= 0:
return 0
return inst.active_p_offloads * 20000
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Pick an instance: sticky per session, with a load-aware override.

    A session that already has an entry in *affinity* stays on its pinned
    instance as long as that instance carries at most twice the average
    ongoing tokens and is not serving any offload. Otherwise (and for new
    sessions) the instance with the lowest
    ``ongoing_tokens + offload penalty - CACHE_HIT_ALPHA * cache_hit`` wins,
    the first one on ties, and the session is (re)pinned to it.

    Returns:
        The chosen instance and its index in *instances*.
    """
    mean_load = max(sum(node.ongoing_tokens for node in instances) / len(instances), 1.0)

    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            pinned = instances[pinned_idx]
            overloaded = pinned.ongoing_tokens > mean_load * 2.0
            if not overloaded and pinned.active_p_offloads == 0:
                return pinned, pinned_idx

    def linear_score(node):
        # Estimate first: it also refreshes the LRU order of matching blocks.
        hit_tokens = node.estimate_cache_hit(token_ids)
        return node.ongoing_tokens + _p_offload_penalty(node) - CACHE_HIT_ALPHA * hit_tokens

    scores = [linear_score(node) for node in instances]
    # min() keeps the first of equal scores, like a strict "<" scan does.
    winner = min(enumerate(scores), key=lambda pair: pair[1])[0]

    if session_id:
        affinity[session_id] = winner
    return instances[winner], winner
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing: choose the instance minimizing P_tokens × BS (OSDI'26).

    P_tokens is the instance's pending prefill tokens plus the uncached part
    of this request; BS is its current number of in-flight requests. The
    decision is made per request: *session_id* and *affinity* are accepted
    for signature parity with pick_instance but are neither read nor updated.

    Returns:
        The chosen instance and its index (the first one on ties).
    """
    def lmetric(node):
        uncached = max(0, input_length - node.estimate_cache_hit(token_ids))
        return (node.pending_prefill_tokens + uncached) * node.num_requests

    scores = [lmetric(node) for node in instances]
    # Default index 0 mirrors the scan's initial pick when nothing is scored.
    winner = min(enumerate(scores), key=lambda pair: pair[1], default=(0, None))[0]
    return instances[winner], winner
# Parsed CLI namespace; assigned in the __main__ block and read by lifespan()
# and the request handlers.
global_args = None
# Backend registries; lifespan() fills the ones that belong to the chosen mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias used by /stats etc.
# NOTE(review): no reader of this alias is visible in this file -- confirm it
# is still needed by external code.
session_affinity = session_affinity_combined
# True once lifespan() has started the proxy in PD-disaggregated mode.
is_pd_sep = False
# Per-request routing/timing records appended by the handlers and served by
# the /breakdown endpoint.
_breakdown_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Learn the engine ids of every instance that has a bootstrap port.

    For each such instance: poll ``/health`` once per second until a request
    goes through, then query ``/query`` on the bootstrap port and store the
    per-DP-rank engine ids and the DP size on the instance. Instances without
    a bootstrap port are skipped. A failing bootstrap query is not retried;
    its exception propagates to the caller. *ready* is set once all
    instances have been processed.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue

        # Wait until the instance accepts HTTP requests at all.
        reachable = False
        while not reachable:
            try:
                await inst.client.get("/health")
                reachable = True
            except Exception:
                await asyncio.sleep(1)

        host = urllib.parse.urlparse(str(inst.client.base_url)).hostname
        url = f"http://{host}:{inst.bootstrap_port}/query"
        resp = await inst.client.get(url)
        resp.raise_for_status()
        ranks = resp.json()
        for dp_rank, dp_entry in ranks.items():
            inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
        inst.dp_size = len(ranks)
        print(f"Inited {inst.url} engine_ids={inst.engine_id}")

    ready.set()
async def _reconcile_loop():
    """Periodic safety net for the shadow load counters.

    Every 60 seconds, reset any load counter that has gone negative on any
    known instance back to zero so routing scores stay sane. Only negative
    values are touched; counters that drifted upwards are left as they are.
    The loop ends quietly when the task is cancelled while sleeping.
    """
    counters = (
        "ongoing_tokens",
        "ongoing_decode_tokens",
        "pending_prefill_tokens",
        "num_requests",
        "active_p_offloads",
    )
    while True:
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            for name in counters:
                if getattr(inst, name) < 0:
                    setattr(inst, name, 0)
def _verify_vllm_patch():
"""Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.
The patch turns an `assert req_id in self.requests` into a soft warn so
that engines do not crash on the KV-transfer abort race (see REPORT
§3.x). If somebody upgrades vLLM without re-applying the patch, the
assert returns and elastic mode dies under load. Print a loud warning
so we catch the regression before the first HEAVY request.
"""
try:
import inspect
from vllm.v1.core.sched.scheduler import Scheduler
src = inspect.getsource(Scheduler)
if "assert req_id in self.requests" in src:
print("WARNING: vLLM scheduler still contains the unpatched "
"`assert req_id in self.requests` line; expect engine "
"death on KV-transfer abort race. Apply "
"patches/0001-fix-kv-transfer-abort-race.patch.")
else:
print("vLLM patch self-check: kv-transfer-abort assert is patched.")
except Exception as exc:
print(f"vLLM patch self-check skipped: {exc!r}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the backend registries on startup and tear them down on shutdown.

    Startup: creates the readiness event, runs the vLLM patch self-check,
    starts the reconcile task and registers the instances for either
    combined mode (``--combined``) or PD-sep mode. The readiness event is set
    directly in combined mode without offload bootstrap, and by
    init_prefill_bootstrap otherwise.

    Shutdown: cancels the reconcile task and closes every instance client.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    _verify_vllm_patch()
    reconcile_task = asyncio.create_task(_reconcile_loop())

    if not global_args.combined:
        is_pd_sep = True
        prefill_instances.extend(InstanceState(url, bp) for url, bp in global_args.prefill)
        decode_instances.extend(InstanceState(url) for url in global_args.decode)
        await init_prefill_bootstrap(prefill_instances, app.state.ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
    else:
        is_pd_sep = False
        raw_ports = global_args.bootstrap_ports
        bp_list = [int(p) for p in raw_ports.split(",") if p.strip()] if raw_ports else []
        # Instances beyond the end of the port list get no bootstrap port.
        port_at = dict(enumerate(bp_list))
        for i, url in enumerate(global_args.combined):
            combined_instances.append(InstanceState(url, port_at.get(i)))
        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        if global_args.offload and bp_list:
            await init_prefill_bootstrap(combined_instances, app.state.ready)
        else:
            app.state.ready.set()
        policy = getattr(global_args, 'policy', 'linear')
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")

    yield

    reconcile_task.cancel()
    try:
        await reconcile_task
    except asyncio.CancelledError:
        pass
    for inst in [*combined_instances, *prefill_instances, *decode_instances]:
        await inst.client.aclose()
# ASGI application; lifespan() registers the backends on startup and closes
# their HTTP clients on shutdown.
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Route a POST to /v1/completions through the shared _handle() path."""
    return await _handle(request, "/v1/completions")
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Route a POST to /v1/chat/completions through the shared _handle() path."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    Args:
        request: The incoming FastAPI request; its body must be a JSON object.
        api: Upstream path to forward to ("/v1/completions" or
            "/v1/chat/completions").

    Returns:
        The StreamingResponse produced by the combined or PD-sep handler.

    Raises:
        HTTPException: 503 while the backends are not bootstrapped yet, 400
            when the body is not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")
    try:
        req_data = await request.json()
    except ValueError as exc:  # JSON decode errors are ValueError subclasses
        raise HTTPException(status_code=400,
                            detail="Request body is not valid JSON") from exc
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400,
                            detail="Request body must be a JSON object")

    request_id = str(uuid.uuid4())
    prompt = req_data.get("prompt")
    # Only a flat list of ints is a tokenized prompt. Text prompts, lists of
    # strings and batched token lists are forwarded untouched but get no
    # cache-aware scoring; a list of lists would otherwise be unhashable when
    # the shadow cache hashes its blocks.
    is_token_list = isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    token_ids = prompt if is_token_list else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Unified routing: pick the instance with the lowest expected latency.

    For each instance, estimate
        latency = queue_time + prefill_time + transfer_cost
    where prefill_time depends on whether the instance has the prefix cached
    locally, can receive it via PUSH from the best cache source, or must
    prefill cold. The cheapest instance wins (the first one on ties). If the
    winning option relies on PUSH, the request is handed to
    _handle_direct_read_offload; otherwise it is streamed from the chosen
    instance directly.

    Args:
        api: Upstream path to forward to.
        req_data: Parsed request body, forwarded as JSON.
        token_ids: Tokenized prompt, or None when the prompt is not tokenized.
        input_length: Number of prompt tokens (0 when unknown).
        session_id: Value of the X-Session-Id header, if any.
        headers: Headers to send upstream (carries X-Request-Id).

    Returns:
        A StreamingResponse relaying the upstream body.
    """
    offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
    throughput = SETTINGS.prefill_throughput

    # Find the best cache source (instance with highest prefix cache hit)
    cache_hits = [inst.estimate_cache_hit(token_ids) for inst in combined_instances]
    best_cache_idx = max(range(len(combined_instances)), key=lambda i: cache_hits[i])
    best_cache_hit = cache_hits[best_cache_idx]
    # PUSH needs a source that actually holds blocks and whose bootstrap
    # endpoint is known; without a bootstrap port the offload request would
    # carry an unusable remote address.
    push_possible = (offload_enabled and best_cache_hit > 0
                     and combined_instances[best_cache_idx].bootstrap_port is not None)
    push_new = max(0, input_length - best_cache_hit)

    # Score each instance by expected latency
    best_idx = 0
    best_cost = float("inf")
    best_needs_push = False
    for i, inst in enumerate(combined_instances):
        queue = inst.pending_prefill_tokens / throughput
        cost = queue + max(0, input_length - cache_hits[i]) / throughput
        needs_push = False
        if push_possible and i != best_cache_idx:
            # This instance could receive cached blocks via PUSH; use it only
            # when that beats prefilling from its own cache.
            push_cost = queue + push_new / throughput + SETTINGS.rdma_overhead_s
            if push_cost < cost:
                cost = push_cost
                needs_push = True
        if cost < best_cost:
            best_cost = cost
            best_idx = i
            best_needs_push = needs_push

    chosen = combined_instances[best_idx]
    cache_hit = cache_hits[best_idx]
    estimated_new = max(0, input_length - cache_hit)
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "cache_hit": cache_hit,
        "estimated_new_tokens": estimated_new,
        "t_proxy_recv": _time.monotonic(),
        "chosen_cost": round(best_cost, 2),
    }
    if session_id:
        session_affinity_combined[session_id] = best_idx

    if best_needs_push:
        c_inst = combined_instances[best_cache_idx]
        d_inst = chosen
        push_cache_hit = best_cache_hit
        d_inst.ongoing_tokens += input_length
        d_inst.pending_prefill_tokens += push_new
        d_inst.num_requests += 1
        c_inst.active_p_offloads += 1
        breakdown["route_class"] = "PUSH_MIGRATE"
        breakdown["c_inst"] = c_inst.url
        breakdown["d_inst"] = d_inst.url
        breakdown["push_cache_hit"] = push_cache_hit
        return await _handle_direct_read_offload(
            api, req_data, headers, token_ids, input_length,
            c_inst, d_inst, push_cache_hit, push_new, breakdown)

    breakdown["route_class"] = "LOCAL"
    breakdown["routed_to"] = chosen.url
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        # Everything below must act on `chosen`: it is the instance that was
        # scored, charged above and reported in the breakdown. The scoring
        # loop's variable still points at the last instance in the list.
        prefill_done = False
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data,
                                                    headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            if not prefill_done:
                                # First bytes mark the end of prefill.
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                prefill_done = True
                            yield chunk
                    chosen.record_prefix(token_ids)
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    # Retrying is only safe before any byte reached the client.
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            # Undo exactly what was charged for the phase the request reached.
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
# NOTE(review): not referenced anywhere in the visible part of this file, so
# no request currently enforces this limit -- confirm whether it is still used.
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
                                      input_length, c_inst, d_inst,
                                      cache_hit, estimated_new, breakdown):
    """Stream a request from *d_inst* that pulls cached KV from *c_inst*.

    The full prompt is sent to *d_inst* together with ``kv_transfer_params``
    that flag a direct read and point at *c_inst*'s bootstrap address and
    engine id. The caller has already charged the load counters; the
    streaming generator moves the request from prefill to decode accounting
    on the first chunk and releases everything, including *c_inst*'s offload
    slot, when the stream ends.

    Returns:
        A StreamingResponse relaying *d_inst*'s body.
    """
    request_id = headers.get("X-Request-Id", "")
    # Align cache_hit to block boundary for remote_num_tokens
    cached_tokens = cache_hit - cache_hit % BLOCK_SIZE
    breakdown["t_offload_sent"] = _time.monotonic()

    source_url = urllib.parse.urlparse(str(c_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (source_url.hostname, c_inst.bootstrap_port)

    # Send full prompt to D with direct_read flag
    kv_params = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "direct_read": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": c_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
        "remote_num_tokens": cached_tokens,
    }
    decode_data = {**req_data, "kv_transfer_params": kv_params}

    async def generate():
        awaiting_first_chunk = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data,
                                            headers=headers) as upstream:
                upstream.raise_for_status()
                async for chunk in upstream.aiter_bytes():
                    if awaiting_first_chunk:
                        d_inst.pending_prefill_tokens -= estimated_new
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        awaiting_first_chunk = False
                    yield chunk
            d_inst.record_prefix(token_ids)
        finally:
            if awaiting_first_chunk:
                d_inst.pending_prefill_tokens -= estimated_new
            else:
                d_inst.ongoing_decode_tokens -= input_length
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Runs a non-streaming, single-token prefill on a prefill instance chosen
    by pick_instance, then streams the decode from the decode instance with
    the fewest ongoing tokens, pointing it at the prefill instance for the KV
    transfer.

    Returns:
        A StreamingResponse relaying the decode instance's body.

    Raises:
        HTTPException: 503 when no decode instance is registered, 502 when
            the prefill request fails.
    """
    if not decode_instances:
        # min() over an empty list would surface as an opaque 500.
        raise HTTPException(status_code=503, detail="No decode instance available")

    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }
    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity_prefill)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    # Prefill request: one token, no streaming, KV handed to a remote decoder.
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        # Keep the original failure attached as the cause for debugging.
        raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
    finally:
        p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data,
                                            headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    # NOTE(review): the combined-mode handlers answer with
    # "text/event-stream"; confirm "application/json" is intended here.
    return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis.

    The whole in-memory log is returned on every call; the code visible in
    this file only ever appends to it.
    """
    return _breakdown_log
@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    Reports the combined instances when any are registered, otherwise the
    prefill instances followed by the decode instances. Each entry carries
    the role of the list the instance belongs to ("combined", "prefill" or
    "decode") along with its load counters and shadow-cache size.
    """
    if combined_instances:
        tagged = [("combined", inst) for inst in combined_instances]
    else:
        tagged = ([("prefill", inst) for inst in prefill_instances]
                  + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in tagged]
def parse_args():
    """Parse the command line and normalize the PD-sep instance lists.

    Besides the raw argparse attributes, the returned namespace carries
    ``prefill`` (a list of ``(url, bootstrap_port_or_None)`` tuples) and
    ``decode`` (a list of URLs). Exits with a usage error when neither mode
    is selected, when PD-sep mode has no decode instance, or when a prefill
    bootstrap port is not an integer.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    # NOTE(review): --heavy-threshold, --overload-factor,
    # --max-offload-inflight and --cache-gate-ratio are not read anywhere in
    # the visible code; they are kept so existing launch scripts still parse.
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"],
                   help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="Global cap on concurrent P-role offloads (M3)")
    p.add_argument("--cache-gate-ratio", type=float, default=0.3,
                   help="Min cache_hit/input ratio to allow offload "
                        "(0.0 disables gate, 1.0 disables offload entirely)")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                p.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    if not args.combined and not args.decode:
        # PD-sep mode picks a decode instance for every request.
        p.error("PD-Sep mode requires at least one --decode instance")
    return args
if __name__ == "__main__":
    # Script entry point: parse the CLI into the module-level global_args
    # (read later by lifespan() and the handlers), echo the effective
    # settings, then serve the app with uvicorn.
    global_args = parse_args()
    print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
        SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
        getattr(global_args, 'offload', False)))
    uvicorn.run(app, host=global_args.host, port=global_args.port)