Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang 2b0ac70ee7 Phase 1 milestone: system-level analysis + reproducible report
- REPORT.md: self-contained milestone report covering baseline vs elastic
  setup, exact launch commands, benchmark params, results, log locations,
  and repo structure — sufficient for anyone to reproduce
- analysis/pd_separation_analysis.md §5: elastic P2P system-level breakdown
  (KV cache hit ratio, per-class TTFT, GPU util paradox explanation)
- scripts/cache_aware_proxy.py: round-robin P-instance selection replacing
  argmin(ongoing_tokens) to fix GPU load imbalance (3.0x → expected ~2x)
- scripts/launch_elastic_p2p.sh: one-command launch for elastic P2P config

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 16:17:41 +08:00

524 lines
20 KiB
Python

"""Unified cache-aware + token-level load-balanced global scheduler.

Supports two modes:
  --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
  --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)

Routing policy (pick_instance; applied to combined instances, and to prefill
instances in PD-disaggregated mode):
  score = ongoing_tokens - CACHE_HIT_ALPHA * estimated_cache_hit_tokens
The instance with the lowest score wins: lower load is better, and the cache
bonus gives affinity. Decode instances in PD-disaggregated mode are chosen by
lowest ongoing_tokens.

Session affinity: multi-turn sessions stick to the same instance unless its
ongoing_tokens exceeds OVERLOAD_FACTOR times the average load.
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of tokens per block when hashing prompts for cache-hit estimation.
BLOCK_SIZE = 512
# Weight applied to the estimated cache-hit token count in the routing score.
CACHE_HIT_ALPHA = 1.0
# Minimum estimated new (uncached) tokens for a request to be classed as HEAVY.
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
# An instance counts as overloaded above this multiple of the average load.
OVERLOAD_FACTOR = 2.0
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.engine_id: dict[int, str] = {}
self.dp_size = 1
self.cached_blocks: set[int] = set()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
if len(self.cached_blocks) > 200000:
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
# Cumulative token load per instance (for balanced session placement)
_inst_cumulative_tokens: list[int] = []
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]:
"""Session-sticky with load-aware override.
Turn 2+: use session affinity UNLESS pinned instance is overloaded
(ongoing_tokens > 2x average), in which case pick least-loaded.
Turn 1: pick instance with best score (load + cache combined).
"""
global _inst_cumulative_tokens
if not _inst_cumulative_tokens:
_inst_cumulative_tokens = [0] * len(instances)
avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
# Session affinity for turn 2+ (with load override)
if session_id and session_id in affinity:
idx = affinity[session_id]
if idx < len(instances):
inst = instances[idx]
# Stick if not overloaded
if inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR:
return inst, idx
# Overloaded: fall through to score-based selection
# Score = ongoing_tokens - ALPHA * cache_hit_tokens
# Balances load (lower is better) with cache affinity (higher hit is better)
best_idx, best_score = 0, float("inf")
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids)
score = inst.ongoing_tokens - CACHE_HIT_ALPHA * cache_hit
if score < best_score:
best_score = score
best_idx = i
_inst_cumulative_tokens[best_idx] += input_length
if session_id:
affinity[session_id] = best_idx
return instances[best_idx], best_idx
# Parsed CLI arguments; assigned in the __main__ block before the server starts.
global_args = None
# Backend instance pools; populated by lifespan() according to the selected mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Maps X-Session-Id to the index of the instance chosen by pick_instance().
session_affinity: dict[str, int] = {}
# True when running in PD-disaggregated mode (set by lifespan()).
is_pd_sep = False
# Per-request timing/routing records served by GET /breakdown.
_breakdown_log: list[dict] = []
_offload_inflight = 0 # number of currently in-flight offloaded HEAVY requests
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent offloads to prevent P overload
_p_round_robin_idx = 0 # round-robin counter for P-instance selection
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each bootstrap-enabled instance and fetch its engine ids.

    Instances without a ``bootstrap_port`` are skipped. For the others,
    ``/health`` is polled once per second until it answers with a success
    status; then ``http://<host>:<bootstrap_port>/query`` is fetched and the
    ``engine_id`` of every data-parallel rank is stored on the instance.
    ``ready`` is set after all instances have been processed.

    Raises:
        httpx.HTTPStatusError: if the bootstrap ``/query`` call returns an
            error status.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue
        while True:
            try:
                health = await inst.client.get("/health")
                # httpx does not raise on 4xx/5xx by itself; without this an
                # unhealthy instance would be treated as ready.
                health.raise_for_status()
            except Exception:
                await asyncio.sleep(1)
                continue
            parsed = urllib.parse.urlparse(str(inst.client.base_url))
            url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
            resp = await inst.client.get(url)
            resp.raise_for_status()
            data = resp.json()
            for dp_rank, dp_entry in data.items():
                inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
            inst.dp_size = len(data)
            print(f"Inited {inst.url} engine_ids={inst.engine_id}")
            break
    ready.set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the instance pools on startup and close their clients on shutdown.

    Combined mode (``--combined``): one InstanceState per URL, paired
    positionally with ``--bootstrap-ports``; engine ids are fetched only when
    ``--offload`` is set and at least one port was given. Otherwise the app is
    marked ready immediately.

    PD-Sep mode: builds the prefill and decode pools and bootstraps the
    prefill instances.

    The HTTP clients are closed in a ``finally`` block so they are released
    even when startup fails or an exception is raised at the ``yield`` point.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    try:
        if global_args.combined:
            is_pd_sep = False
            bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
            for i, url in enumerate(global_args.combined):
                bp = bp_list[i] if i < len(bp_list) else None
                combined_instances.append(InstanceState(url, bp))
            if global_args.offload and bp_list:
                await init_prefill_bootstrap(combined_instances, app.state.ready)
            else:
                app.state.ready.set()
            print(f"Combined mode: {len(combined_instances)} instances, offload={'ON' if global_args.offload else 'OFF'}")
        else:
            is_pd_sep = True
            for url, bp in global_args.prefill:
                prefill_instances.append(InstanceState(url, bp))
            for url in global_args.decode:
                decode_instances.append(InstanceState(url))
            await init_prefill_bootstrap(prefill_instances, app.state.ready)
            print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
        yield
    finally:
        for inst in combined_instances + prefill_instances + decode_instances:
            await inst.client.aclose()
# FastAPI application; lifespan() builds the instance pools on startup.
app = FastAPI(lifespan=lifespan)


@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Route a completions request through the shared handler."""
    return await _handle(request, "/v1/completions")


@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Route a chat-completions request through the shared handler."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    ``prompt`` is used as token ids (for cache-hit estimation and load
    accounting) only when it is a list of integers; for any other prompt
    shape, or when there is no ``prompt`` field, the request is routed with an
    input length of 0. The ``X-Session-Id`` request header, when present,
    enables session affinity.

    Raises:
        HTTPException: 503 while the backends are not ready; 400 when the body
            is not valid JSON or is not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")
    try:
        req_data = await request.json()
    except ValueError as e:
        # JSON decode errors are ValueError subclasses; report a client error
        # instead of an unhandled 500.
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    request_id = str(uuid.uuid4())
    prompt = req_data.get("prompt")
    # A list prompt may also hold strings or nested lists (batched prompts);
    # only a flat list of ints is a token-id sequence.
    is_token_ids = isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    token_ids = prompt if is_token_ids else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    else:
        return await _handle_combined(api, req_data, token_ids,
                                      input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Combined mode with adaptive prefill offload (v2).

    WARM/MEDIUM: route to the instance chosen by pick_instance, co-located
    P+D (no KV transfer).
    HEAVY (estimated new tokens >= HEAVY_THRESHOLD) with offload enabled,
    at least one bootstrap port configured and at least two instances:
    prefill runs on a round-robin-selected other instance (overloaded
    candidates are skipped; the least-loaded candidate is the fallback) and
    decode runs on the chosen instance via _handle_heavy_offload. The offload
    is declined, and the request served co-located as HEAVY_COLO, when
    MAX_OFFLOAD_INFLIGHT offloads are already running or the selected P
    instance already holds >= 2 * HEAVY_THRESHOLD ongoing tokens.

    Returns:
        A StreamingResponse relaying the backend's byte stream.
    """
    global _offload_inflight, _p_round_robin_idx

    best_inst, best_idx = pick_instance(combined_instances, token_ids, session_id,
                                        input_length, session_affinity)
    cache_hit = best_inst.estimate_cache_hit(token_ids)
    estimated_new = max(0, input_length - cache_hit)
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "estimated_new_tokens": estimated_new,
        "cache_hit": cache_hit,
        "t_proxy_recv": _time.monotonic(),
    }
    offload_enabled = getattr(global_args, 'offload', False) if global_args else False
    has_bootstrap = any(inst.bootstrap_port for inst in combined_instances)

    # Elastic offload decision: offload only when it helps
    use_offload = False
    offload_reason = "disabled"
    if estimated_new >= HEAVY_THRESHOLD and offload_enabled and has_bootstrap and len(combined_instances) >= 2:
        d_inst = best_inst
        p_candidates = [(i, inst) for i, inst in enumerate(combined_instances) if inst is not d_inst]
        avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0)
        # Round-robin P selection with overload skip (spreads P-role evenly)
        p_inst = None
        for _ in range(len(p_candidates)):
            _p_round_robin_idx = (_p_round_robin_idx + 1) % len(p_candidates)
            candidate = p_candidates[_p_round_robin_idx][1]
            if candidate.ongoing_tokens < avg_load * OVERLOAD_FACTOR:
                p_inst = candidate
                break
        if p_inst is None:
            # Every candidate is overloaded: fall back to the least loaded.
            p_inst = min(p_candidates, key=lambda x: x[1].ongoing_tokens)[1]
        if _offload_inflight >= MAX_OFFLOAD_INFLIGHT:
            offload_reason = "max_concurrent_reached"
        elif p_inst.ongoing_tokens >= HEAVY_THRESHOLD * 2:
            offload_reason = "p_saturated"
        else:
            use_offload = True
            offload_reason = "offload_accepted"
            _offload_inflight += 1

    if use_offload:
        d_idx = best_idx
        p_inst.ongoing_tokens += input_length  # reserve immediately
        breakdown["route_class"] = "HEAVY_P2P"
        breakdown["offload_reason"] = offload_reason
        breakdown["p_inst"] = p_inst.url
        breakdown["d_inst"] = d_inst.url
        breakdown["p_load"] = p_inst.ongoing_tokens
        breakdown["d_load"] = d_inst.ongoing_tokens
        if session_id:
            session_affinity[session_id] = d_idx
        return await _handle_heavy_offload(api, req_data, headers, token_ids,
                                           input_length, p_inst, d_inst, breakdown)

    if estimated_new >= HEAVY_THRESHOLD:
        breakdown["route_class"] = "HEAVY_COLO"
        breakdown["offload_reason"] = offload_reason
    else:
        breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM"
    inst = best_inst
    breakdown["routed_to"] = inst.url
    inst.ongoing_tokens += input_length

    async def generate():
        first_token = True
        # Tracks whether the decode counter was incremented, so the cleanup
        # below only undoes what actually happened. Decrementing
        # unconditionally would drive the counter negative whenever the
        # request fails before the stream is established.
        decode_counted = False
        try:
            async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
                resp.raise_for_status()
                inst.ongoing_decode_tokens += input_length
                decode_counted = True
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
            inst.record_prefix(token_ids)
        finally:
            inst.ongoing_tokens -= input_length
            if decode_counted:
                inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
                                p_inst, d_inst, breakdown):
    """HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.

    The caller has already added ``input_length`` to ``p_inst.ongoing_tokens``
    and taken one ``_offload_inflight`` slot. Both are released exactly once
    when the prefill step finishes, whether it succeeded or failed.

    Raises:
        HTTPException: 502 when the prefill request fails.
    """
    global _offload_inflight
    request_id = headers.get("X-Request-Id", "")

    # Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        prefill_data = req_data.copy()
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": "xfer-" + request_id,
        }
        # Prefill only: no streaming, a single generated token.
        prefill_data["stream"] = False
        prefill_data["max_tokens"] = 1
        prefill_data.pop("max_completion_tokens", None)
        prefill_data.pop("stream_options", None)
        p_headers = {**headers, "X-data-parallel-rank": "0"}
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
        breakdown["t_prefill_done"] = _time.monotonic()
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["error"] = str(e)
        _breakdown_log.append(breakdown)
        # The slot is released in ``finally``. Releasing it here as well would
        # free two slots for one request and let more than
        # MAX_OFFLOAD_INFLIGHT offloads run concurrently.
        raise HTTPException(status_code=502, detail="Prefill failed: %s" % e) from e
    finally:
        p_inst.ongoing_tokens -= input_length
        _offload_inflight = max(0, _offload_inflight - 1)

    # Step 2: Stream decode on d_inst (pulls KV from Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.ongoing_decode_tokens += input_length
    breakdown["t_decode_sent"] = _time.monotonic()
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
    }

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            d_inst.ongoing_tokens -= input_length
            d_inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="application/json")
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
input_length, breakdown):
"""Fire-and-forget prefill: send and don't block caller."""
try:
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
except Exception:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = True
finally:
p_inst.ongoing_tokens -= input_length
# Strong references to in-flight fire-and-forget prefill tasks. The event loop
# keeps only weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_pending_prefill_tasks: set = set()


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    The prefill instance is chosen by pick_instance and the decode instance is
    the one with the fewest ongoing tokens. The prefill request (non-streaming,
    ``max_tokens=1``) is awaited before decode starts, or launched as a
    background task when ``--fire-and-forget`` is set. The decode response is
    streamed back to the client.

    Raises:
        HTTPException: 502 when an awaited prefill request fails.
    """
    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }
    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    if global_args.fire_and_forget:
        # The background task releases p_inst.ongoing_tokens itself.
        task = asyncio.create_task(_send_prefill_async(
            p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
        _pending_prefill_tasks.add(task)
        task.add_done_callback(_pending_prefill_tasks.discard)
    else:
        try:
            resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
            breakdown["t_prefill_done"] = _time.monotonic()
            resp.raise_for_status()
            await resp.aclose()
            p_inst.record_prefix(token_ids)
        except Exception as e:
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
            _breakdown_log.append(breakdown)
            raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
        finally:
            p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis.

    The returned list is the module-level ``_breakdown_log``, which the
    request handlers append to when a request finishes or its prefill fails.
    """
    return _breakdown_log
def parse_args():
    """Parse and validate the command-line arguments.

    Besides the raw argparse fields, the returned namespace carries:
        prefill: list of ``(url, bootstrap_port_or_None)`` tuples built from
            the repeated ``--prefill URL [bootstrap_port]`` options; the
            literal ``none`` (any case) means no bootstrap port.
        decode: list of URLs from the repeated ``--decode URL`` options.

    Exits with a usage error when neither ``--combined`` nor ``--prefill`` is
    given, when a bootstrap port is not an integer, or when PD-Sep mode is
    selected without any ``--decode`` instance (every request would fail for
    lack of a decode target).
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--fire-and-forget", action="store_true",
                   help="Send prefill async, don't await before decode")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                p.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    # Combined mode takes precedence when both are given, so --decode is only
    # mandatory when PD-Sep mode will actually be used.
    if not args.combined and not args.decode:
        p.error("PD-Sep mode requires at least one --decode URL")
    return args
if __name__ == "__main__":
    # lifespan() and the request handlers read this module-level value.
    global_args = parse_args()
    # Replace the module default with the --heavy-threshold value.
    HEAVY_THRESHOLD = global_args.heavy_threshold
    uvicorn.run(app, host=global_args.host, port=global_args.port)