Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang 76ee28a40f Elastic P2P v4: error rate 25% -> 4%, TTFT p50 -12% (median-tail tradeoff)
Fixed offload decision: removed p>=d gate (was blocking all offloads),
added MAX_OFFLOAD_INFLIGHT=4 cap and p_saturated threshold.

Result (200 req, fresh restart):
  Baseline: 99% success, TTFT=1.080/9.410, TPOT90=0.076, E2E=5.306
  Elastic:  96% success, TTFT=0.946/15.843, TPOT90=0.077, E2E=5.717

Architectural tradeoff confirmed:
  - Median (p50) improves: D instances not disrupted by heavy prefill
  - Tail (p90) worsens: offloaded HEAVY requests pay KV transfer cost
  - TPOT unchanged: decode isolation is not the bottleneck

To improve p90: need layerwise pipelined KV transfer (overlap with prefill
compute) or smarter offload gating that avoids offloading the very largest
requests (which have the longest prefill time and generate the most KV).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 15:08:16 +08:00

520 lines
20 KiB
Python

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policy (same for both modes):
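score = ongoing_tokens - CACHE_HIT_ALPHA * cache_hit_tokens (lowest score wins)
The load term steers new work to less busy instances; the cache bonus gives affinity.
Session affinity: multi-turn sessions stick to the same instance unless it is
overloaded (ongoing_tokens > 2x the average across instances).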
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of tokens per hashed block; InstanceState records and estimates
# prefix-cache hits at this granularity.
BLOCK_SIZE = 512
# Weight applied to the estimated cache-hit token count when pick_instance
# scores an instance (score = ongoing_tokens - CACHE_HIT_ALPHA * cache_hit).
CACHE_HIT_ALPHA = 1.0
# Estimated new-token count at or above which a request is classified HEAVY.
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.engine_id: dict[int, str] = {}
self.dp_size = 1
self.cached_blocks: set[int] = set()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
if len(self.cached_blocks) > 200000:
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
# Running total of input tokens assigned to each instance by score-based
# selection, indexed like the instance list passed to pick_instance.
_inst_cumulative_tokens: list[int] = []


def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Choose the instance that should serve a request.

    A session already present in ``affinity`` keeps its instance as long as
    that instance's ``ongoing_tokens`` does not exceed twice the average over
    all instances. Otherwise every instance is scored as
    ``ongoing_tokens - CACHE_HIT_ALPHA * estimated_cache_hit`` and the lowest
    score wins (first instance on ties); the winner is stored in ``affinity``
    for the session.

    Args:
        instances: Candidate instances.
        token_ids: Prompt token ids used for the cache-hit estimate, or None.
        session_id: Session key for sticky routing, or None.
        input_length: Prompt length added to the winner's cumulative total.
        affinity: Mutable session -> instance-index map.

    Returns:
        ``(instance, index)`` of the chosen instance.
    """
    global _inst_cumulative_tokens
    overload_factor = 2.0
    if not _inst_cumulative_tokens:
        _inst_cumulative_tokens = [0] * len(instances)

    total_ongoing = sum(candidate.ongoing_tokens for candidate in instances)
    # Floor of 1.0 keeps the overload comparison meaningful on an idle cluster.
    mean_load = max(total_ongoing / len(instances), 1.0)

    # Sticky path: reuse the pinned instance unless it is overloaded.
    pinned = affinity.get(session_id) if session_id else None
    if pinned is not None and pinned < len(instances):
        sticky = instances[pinned]
        if sticky.ongoing_tokens <= mean_load * overload_factor:
            return sticky, pinned

    # Scored path: lower load is better, a larger cache hit lowers the score.
    scores = [
        candidate.ongoing_tokens - CACHE_HIT_ALPHA * candidate.estimate_cache_hit(token_ids)
        for candidate in instances
    ]
    winner = min(range(len(scores)), key=scores.__getitem__)

    _inst_cumulative_tokens[winner] += input_length
    if session_id:
        affinity[session_id] = winner
    return instances[winner], winner
# Parsed CLI arguments; assigned in the __main__ guard before the server starts.
global_args = None
# Backends used in combined (co-located prefill + decode) mode.
combined_instances: list[InstanceState] = []
# Backends used in PD-disaggregated mode; both lists are filled in lifespan().
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session id -> instance index; written by pick_instance and _handle_combined.
session_affinity: dict[str, int] = {}
# True when running in PD-disaggregated mode; set in lifespan().
is_pd_sep = False
# Per-request timing/routing records returned by GET /breakdown.
# NOTE(review): entries are appended per request and never trimmed in this file.
_breakdown_log: list[dict] = []
_offload_inflight = 0 # number of currently in-flight offloaded HEAVY requests
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent offloads to prevent P overload
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Load engine ids from each instance's bootstrap endpoint, then set ``ready``.

    Instances without a bootstrap port are skipped. For the others this polls
    ``/health`` once per second until the request succeeds, then queries
    ``http://<host>:<bootstrap_port>/query`` and stores the ``engine_id`` of
    every data-parallel rank along with the rank count. A failing query
    propagates to the caller.
    """
    for backend in instances:
        if backend.bootstrap_port is None:
            continue

        # Wait until the instance answers at all; any exception means "not yet".
        reachable = False
        while not reachable:
            try:
                await backend.client.get("/health")
                reachable = True
            except Exception:
                await asyncio.sleep(1)

        host = urllib.parse.urlparse(str(backend.client.base_url)).hostname
        query_url = f"http://{host}:{backend.bootstrap_port}/query"
        reply = await backend.client.get(query_url)
        reply.raise_for_status()
        ranks = reply.json()
        for rank, entry in ranks.items():
            backend.engine_id[int(rank)] = entry["engine_id"]
        backend.dp_size = len(ranks)
        print(f"Inited {backend.url} engine_ids={backend.engine_id}")

    ready.set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the instance lists at startup and close their clients at shutdown.

    With ``--combined`` the proxy runs in combined mode and only performs the
    bootstrap handshake when offload is enabled and bootstrap ports were
    supplied. Otherwise it runs in PD-disaggregated mode and always performs
    the handshake for the prefill instances.

    Cleanup runs in a ``finally`` block so every HTTP client created so far is
    closed even if startup fails or an exception is raised while serving.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    try:
        if global_args.combined:
            is_pd_sep = False
            bp_list = (
                [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()]
                if global_args.bootstrap_ports else []
            )
            for i, url in enumerate(global_args.combined):
                # Instances beyond the supplied port list get no bootstrap port.
                bp = bp_list[i] if i < len(bp_list) else None
                combined_instances.append(InstanceState(url, bp))
            if global_args.offload and bp_list:
                await init_prefill_bootstrap(combined_instances, app.state.ready)
            else:
                app.state.ready.set()
            print(f"Combined mode: {len(combined_instances)} instances, offload={'ON' if global_args.offload else 'OFF'}")
        else:
            is_pd_sep = True
            for url, bp in global_args.prefill:
                prefill_instances.append(InstanceState(url, bp))
            for url in global_args.decode:
                decode_instances.append(InstanceState(url))
            await init_prefill_bootstrap(prefill_instances, app.state.ready)
            print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
        yield
    finally:
        for inst in combined_instances + prefill_instances + decode_instances:
            await inst.client.aclose()
# ASGI application; lifespan() builds the instance lists on startup and closes
# their HTTP clients on shutdown.
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Route a POST to /v1/completions through the shared ``_handle`` path."""
    return await _handle(request, "/v1/completions")
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Route a POST to /v1/chat/completions through the shared ``_handle`` path."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    Args:
        request: Incoming FastAPI request.
        api: Backend path to forward to (e.g. ``/v1/completions``).

    Returns:
        The streaming response produced by the combined or PD-sep handler.

    Raises:
        HTTPException: 503 while the proxy is not ready; 400 when the body is
            not valid JSON or is not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    try:
        req_data = await request.json()
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    request_id = str(uuid.uuid4())

    # Only a flat list of ints is a pre-tokenized prompt. Strings, lists of
    # strings and lists of token lists carry no usable token ids here (and a
    # list of lists would not be hashable by the prefix-cache estimate).
    prompt = req_data.get("prompt")
    is_token_list = isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    token_ids = prompt if is_token_list else None
    input_length = len(token_ids) if token_ids else 0

    session_id = request.headers.get("X-Session-Id")
    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Serve a request in combined mode, optionally offloading heavy prefill.

    The request is first assigned to an instance by ``pick_instance``. If the
    estimated number of uncached tokens reaches ``HEAVY_THRESHOLD`` and
    offload is enabled, prefill may run on another instance (via
    ``_handle_heavy_offload``) while decode stays on the assigned instance.
    Otherwise the request is streamed from the assigned instance directly.

    Args:
        api: Backend path to forward to.
        req_data: Parsed JSON request body.
        token_ids: Prompt token ids, or None when the prompt is not tokenized.
        input_length: Number of prompt tokens (0 when unknown).
        session_id: Session key for sticky routing, or None.
        headers: Headers to forward to the backend.

    Returns:
        A ``StreamingResponse`` relaying the backend's output.
    """
    global _offload_inflight

    best_inst, best_idx = pick_instance(combined_instances, token_ids, session_id,
                                        input_length, session_affinity)
    cache_hit = best_inst.estimate_cache_hit(token_ids)
    estimated_new = max(0, input_length - cache_hit)
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "estimated_new_tokens": estimated_new,
        "cache_hit": cache_hit,
        "t_proxy_recv": _time.monotonic(),
    }
    offload_enabled = getattr(global_args, 'offload', False) if global_args else False
    is_heavy = estimated_new >= HEAVY_THRESHOLD

    # A remote-prefill target must be a different instance than the decode
    # target AND expose a bootstrap port: the decode side is later pointed at
    # http://<p host>:<p bootstrap_port>, which is invalid without a port.
    d_inst = best_inst
    p_candidates = [inst for inst in combined_instances
                    if inst is not d_inst and inst.bootstrap_port]

    # Elastic offload decision: offload only when it helps
    use_offload = False
    offload_reason = "disabled"
    p_inst = None
    if is_heavy and offload_enabled and p_candidates:
        p_inst = min(p_candidates, key=lambda x: x.ongoing_tokens)
        # Decision logic:
        # 1. Global cap: max N concurrent offloads (prevents all-offload storm)
        # 2. P must not already be saturated with heavy prefills
        # NOTE: P is NOT required to be less loaded than D. The point is to
        # keep heavy prefill off the session-sticky D instance.
        if _offload_inflight >= MAX_OFFLOAD_INFLIGHT:
            offload_reason = "max_concurrent_reached"
        elif p_inst.ongoing_tokens >= HEAVY_THRESHOLD * 2:
            offload_reason = "p_saturated"
        else:
            use_offload = True
            offload_reason = "offload_accepted"
            _offload_inflight += 1

    if use_offload:
        p_inst.ongoing_tokens += input_length  # reserve immediately
        breakdown["route_class"] = "HEAVY_P2P"
        breakdown["offload_reason"] = offload_reason
        breakdown["p_inst"] = p_inst.url
        breakdown["d_inst"] = d_inst.url
        breakdown["p_load"] = p_inst.ongoing_tokens
        breakdown["d_load"] = d_inst.ongoing_tokens
        if session_id:
            session_affinity[session_id] = best_idx
        return await _handle_heavy_offload(api, req_data, headers, token_ids,
                                           input_length, p_inst, d_inst, breakdown)

    if is_heavy:
        breakdown["route_class"] = "HEAVY_COLO"
        breakdown["offload_reason"] = offload_reason
    else:
        breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM"

    inst = best_inst
    breakdown["routed_to"] = inst.url
    inst.ongoing_tokens += input_length

    async def generate():
        first_token = True
        # Tracks whether the decode counter was actually incremented, so the
        # cleanup below never subtracts tokens that were never added (which
        # would drive the counter negative when the upstream call fails).
        decode_counted = False
        try:
            async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
                resp.raise_for_status()
                inst.ongoing_decode_tokens += input_length
                decode_counted = True
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
                # Only a fully relayed response marks the prefix as cached.
                inst.record_prefix(token_ids)
        finally:
            inst.ongoing_tokens -= input_length
            if decode_counted:
                inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
                                p_inst, d_inst, breakdown):
    """Run prefill on ``p_inst``, then stream decode from ``d_inst``.

    The caller has already added ``input_length`` to ``p_inst.ongoing_tokens``
    and incremented ``_offload_inflight``; both are released exactly once when
    the prefill step finishes, whether it succeeded or failed.

    Args:
        api: Backend path to forward to.
        req_data: Parsed JSON request body.
        headers: Headers to forward; ``X-Request-Id`` names the KV transfer.
        token_ids: Prompt token ids, or None.
        input_length: Number of prompt tokens.
        p_inst: Instance that performs the prefill.
        d_inst: Instance that performs the decode.
        breakdown: Per-request record that receives timing fields.

    Returns:
        A ``StreamingResponse`` relaying the decode instance's output.

    Raises:
        HTTPException: 502 when the prefill request fails.
    """
    global _offload_inflight
    request_id = headers.get("X-Request-Id", "")

    # Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        prefill_data = req_data.copy()
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": "xfer-" + request_id,
        }
        # Prefill only: non-streaming, a single generated token.
        prefill_data["stream"] = False
        prefill_data["max_tokens"] = 1
        prefill_data.pop("max_completion_tokens", None)
        prefill_data.pop("stream_options", None)
        p_headers = {**headers, "X-data-parallel-rank": "0"}
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
        breakdown["t_prefill_done"] = _time.monotonic()
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["error"] = str(e)
        _breakdown_log.append(breakdown)
        # The reservation and the in-flight slot are released by ``finally``
        # below; releasing them here as well would free two slots per failure.
        raise HTTPException(status_code=502, detail="Prefill failed: %s" % e) from e
    finally:
        p_inst.ongoing_tokens -= input_length
        _offload_inflight = max(0, _offload_inflight - 1)

    # Step 2: Stream decode on d_inst (pulls KV from Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.ongoing_decode_tokens += input_length
    breakdown["t_decode_sent"] = _time.monotonic()
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
    }

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            d_inst.ongoing_tokens -= input_length
            d_inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    # NOTE(review): the co-located path in _handle_combined streams with
    # media_type="text/event-stream"; confirm which type clients expect here.
    return StreamingResponse(generate(), media_type="application/json")
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
                              input_length, breakdown):
    """Send the prefill request as a background job.

    Failures are not raised: they are recorded in ``breakdown`` as
    ``prefill_error``. In every case the ``input_length`` reserved on
    ``p_inst.ongoing_tokens`` is released at the end.
    """
    try:
        reply = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        # Timestamp as soon as the response arrives, before the status check.
        breakdown["t_prefill_done"] = _time.monotonic()
        reply.raise_for_status()
        await reply.aclose()
        p_inst.record_prefix(token_ids)
    except Exception:
        breakdown.update(t_prefill_done=_time.monotonic(), prefill_error=True)
    finally:
        p_inst.ongoing_tokens -= input_length
# Strong references to fire-and-forget prefill tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes; each task removes itself once done.
_background_prefill_tasks: set[asyncio.Task] = set()


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """Serve a request in PD-disaggregated mode with per-stage timing.

    Prefill goes to the instance chosen by ``pick_instance``; decode goes to
    the decode instance with the fewest ongoing tokens. With
    ``--fire-and-forget`` the prefill is sent as a background task and decode
    starts immediately; otherwise prefill is awaited first.

    Args:
        api: Backend path to forward to.
        req_data: Parsed JSON request body.
        request_id: Id used to name the KV transfer.
        token_ids: Prompt token ids, or None.
        input_length: Number of prompt tokens.
        session_id: Session key for sticky prefill routing, or None.
        headers: Headers to forward to the backends.

    Returns:
        A ``StreamingResponse`` relaying the decode instance's output.

    Raises:
        HTTPException: 503 when no decode instance is configured; 502 when an
            awaited prefill request fails.
    """
    if not decode_instances:
        # min() over an empty list would raise and surface as an opaque 500.
        raise HTTPException(status_code=503, detail="No decode instances configured")

    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }
    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    # Prefill only: non-streaming, a single generated token.
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    if global_args.fire_and_forget:
        # _send_prefill_async releases the reservation when it finishes.
        task = asyncio.create_task(_send_prefill_async(
            p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
        _background_prefill_tasks.add(task)
        task.add_done_callback(_background_prefill_tasks.discard)
    else:
        try:
            resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
            breakdown["t_prefill_done"] = _time.monotonic()
            resp.raise_for_status()
            await resp.aclose()
            p_inst.record_prefix(token_ids)
        except Exception as e:
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
            _breakdown_log.append(breakdown)
            raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
        finally:
            p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
    """Return the per-request breakdown records collected so far.

    The module-level ``_breakdown_log`` list is returned as-is, one dict per
    recorded request.
    """
    return _breakdown_log
def parse_args():
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
p.add_argument("--port", type=int, default=8000)
p.add_argument("--host", type=str, default="0.0.0.0")
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
help="PD-Sep prefill: URL [bootstrap_port]")
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
help="PD-Sep decode: URL")
p.add_argument("--fire-and-forget", action="store_true",
help="Send prefill async, don't await before decode")
p.add_argument("--heavy-threshold", type=int, default=20000,
help="New tokens threshold for HEAVY classification (adaptive offload)")
p.add_argument("--offload", action="store_true",
help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
p.add_argument("--bootstrap-ports", type=str, default="",
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
args = p.parse_args()
args.prefill = []
if args.prefill_raw:
for entry in args.prefill_raw:
url = entry[0]
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
args.prefill.append((url, bp))
args.decode = [e[0] for e in (args.decode_raw or [])]
if not args.combined and not args.prefill:
p.error("Must specify either --combined or --prefill/--decode")
return args
if __name__ == "__main__":
    # Parse the CLI once and publish it module-wide; lifespan() and the
    # request handlers read global_args.
    global_args = parse_args()
    # Replace the module default so the HEAVY classification uses the CLI value.
    HEAVY_THRESHOLD = global_args.heavy_threshold
    uvicorn.run(app, host=global_args.host, port=global_args.port)