Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang e4fa56cb1e LMetric routing policy (OSDI'26) + A/B results vs linear baseline
Implement LMetric (P_tokens × BS multiplication score) from "Simple is
Better" (Zhang et al., OSDI'26) as alternative routing policy for
combined mode. Key changes:

- cache_aware_proxy.py: add --policy {linear,lmetric} flag, track
  pending_prefill_tokens and num_requests per instance, /stats endpoint
- run_lmetric_ab.sh: automated A/B script for fair comparison

Results (200 req, fresh restart, same trace):
  Linear:  TTFT50=1.086  TPOT90=0.077  E2E50=5.423
  LMetric: TTFT50=1.099  TPOT90=0.073  E2E50=5.205
  Delta:   TTFT +1.2%    TPOT -5.9%    E2E -4.0%

LMetric improves TPOT/E2E modestly through better load balancing, but
routing policy headroom is limited vs elastic P2P offload (-44% E2E).

TODO: vLLM → Redis → router pipeline for exact state ablation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 16:57:32 +08:00

596 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests + 1 (in-flight requests, waiting + running, plus the new request)
Session affinity: multi-turn sessions stick to same instance (all policies).
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# Number of tokens per block when InstanceState hashes prompt prefixes.
BLOCK_SIZE = 512
# Weight of the estimated cache hit in the linear routing score
# (score = ongoing_tokens - CACHE_HIT_ALPHA * cache_hit_tokens).
CACHE_HIT_ALPHA = 1.0
# Estimated uncached-token count at or above which a request is treated as HEAVY.
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
# An instance is considered overloaded when its ongoing_tokens exceed
# OVERLOAD_FACTOR times the average ongoing_tokens across instances.
OVERLOAD_FACTOR = 2.0
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.pending_prefill_tokens = 0 # tokens for requests still in prefill
self.num_requests = 0 # total in-flight requests (waiting + running)
self.engine_id: dict[int, str] = {}
self.dp_size = 1
self.cached_blocks: set[int] = set()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
if len(self.cached_blocks) > 200000:
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
# Running total of input tokens assigned to each instance by score-based
# placement. Sized lazily on the first pick_instance() call.
_inst_cumulative_tokens: list[int] = []


def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Choose an instance with the linear policy, honouring session affinity.

    A session that is already pinned stays on its instance as long as that
    instance's ongoing_tokens do not exceed OVERLOAD_FACTOR times the average
    (the average is floored at 1.0). Otherwise every instance is scored as
    ``ongoing_tokens - CACHE_HIT_ALPHA * estimated_cache_hit`` and the lowest
    score wins (first instance on ties). The winner is recorded in
    ``affinity`` when a session id is given.

    Returns the chosen instance and its index in ``instances``.
    """
    global _inst_cumulative_tokens
    if not _inst_cumulative_tokens:
        _inst_cumulative_tokens = [0] * len(instances)

    total_load = sum(candidate.ongoing_tokens for candidate in instances)
    overload_limit = max(total_load / len(instances), 1.0) * OVERLOAD_FACTOR

    # Pinned session: keep it where it is unless that instance is overloaded.
    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            pinned = instances[pinned_idx]
            if pinned.ongoing_tokens <= overload_limit:
                return pinned, pinned_idx

    # Lower score is better: load pushes it up, an expected cache hit pulls it down.
    scores = [
        candidate.ongoing_tokens - CACHE_HIT_ALPHA * candidate.estimate_cache_hit(token_ids)
        for candidate in instances
    ]
    winner = 0
    lowest = float("inf")
    for position, value in enumerate(scores):
        if value < lowest:
            winner, lowest = position, value

    _inst_cumulative_tokens[winner] += input_length
    if session_id:
        affinity[session_id] = winner
    return instances[winner], winner
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Choose an instance with the LMetric policy (score = P_tokens × BS).

    For each instance:
      P_tokens = pending_prefill_tokens + max(0, input_length - estimated_cache_hit)
      BS       = num_requests + 1   (the +1 counts the request being routed)
    The lowest product wins, first instance on ties.

    Session affinity is applied first, exactly as in the linear policy: a
    pinned session stays put unless its instance's ongoing_tokens exceed
    OVERLOAD_FACTOR times the average (floored at 1.0). The winner of a scored
    pick is stored in ``affinity`` when a session id is given.

    Returns the chosen instance and its index in ``instances``.
    """
    mean_load = max(sum(node.ongoing_tokens for node in instances) / len(instances), 1.0)
    overload_limit = mean_load * OVERLOAD_FACTOR

    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            pinned = instances[pinned_idx]
            if pinned.ongoing_tokens <= overload_limit:
                return pinned, pinned_idx

    def lmetric(node):
        # Tokens this instance would still have to prefill, times its batch size.
        uncached = max(0, input_length - node.estimate_cache_hit(token_ids))
        return (node.pending_prefill_tokens + uncached) * (node.num_requests + 1)

    winner = 0
    lowest = float("inf")
    for position, node in enumerate(instances):
        value = lmetric(node)
        if value < lowest:
            winner, lowest = position, value

    if session_id:
        affinity[session_id] = winner
    return instances[winner], winner
# Parsed command-line arguments; assigned in the __main__ block before the server starts.
global_args = None
# Backend pools, populated by lifespan() according to the selected mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session id -> index of the instance the session is pinned to
# (an index into whichever pool was passed to the picker).
session_affinity: dict[str, int] = {}
# True in PD-disaggregated mode, False in combined mode; set by lifespan().
is_pd_sep = False
# Per-request routing/timing records, returned as-is by GET /breakdown.
_breakdown_log: list[dict] = []
_offload_inflight = 0 # number of currently in-flight offloaded HEAVY requests
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent offloads to prevent P overload
_p_round_robin_idx = 0 # round-robin counter for P-instance selection
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each instance to come up and learn its engine ids, then set ``ready``.

    Instances without a bootstrap port are skipped. For every other instance
    the function polls ``/health`` once per second until it answers with a
    success status, then fetches ``/query`` from the bootstrap port on the same
    host and stores the reported ``engine_id`` per data-parallel rank, plus
    the number of ranks in ``dp_size``.

    A failing ``/health`` or ``/query`` request is retried rather than aborting
    startup. A malformed ``/query`` payload still raises.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue
        parsed = urllib.parse.urlparse(str(inst.client.base_url))
        url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
        while True:
            try:
                health = await inst.client.get("/health")
                # httpx does not raise on 4xx/5xx by itself; without this a
                # 503 from a still-loading instance would count as healthy.
                health.raise_for_status()
            except Exception:
                await asyncio.sleep(1)
                continue
            try:
                resp = await inst.client.get(url)
                resp.raise_for_status()
            except httpx.HTTPError as exc:
                # The bootstrap endpoint may lag behind /health; keep waiting
                # instead of failing the whole application startup.
                print(f"Waiting for bootstrap endpoint {url}: {exc!r}")
                await asyncio.sleep(1)
                continue
            data = resp.json()
            for dp_rank, dp_entry in data.items():
                inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
            inst.dp_size = len(data)
            print(f"Inited {inst.url} engine_ids={inst.engine_id}")
            break
    ready.set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the backend pools on startup and close their HTTP clients on shutdown.

    Combined mode (``--combined`` given): one InstanceState per URL, paired
    positionally with ``--bootstrap-ports`` (missing ports become None). The
    bootstrap handshake only runs when offload is enabled and at least one
    port was supplied; otherwise the ready event is set immediately.

    PD-Sep mode: one InstanceState per prefill/decode entry, followed by the
    bootstrap handshake for the prefill pool.
    """
    global is_pd_sep
    ready = asyncio.Event()
    app.state.ready = ready

    combined_urls = global_args.combined
    if combined_urls:
        is_pd_sep = False
        raw_ports = global_args.bootstrap_ports
        ports = [int(tok) for tok in raw_ports.split(",") if tok.strip()] if raw_ports else []
        # Pad with None so every URL has a (possibly absent) bootstrap port.
        padded = ports + [None] * max(0, len(combined_urls) - len(ports))
        combined_instances.extend(
            InstanceState(url, port) for url, port in zip(combined_urls, padded)
        )
        if global_args.offload and ports:
            await init_prefill_bootstrap(combined_instances, ready)
        else:
            ready.set()
        policy = getattr(global_args, 'policy', 'linear')
        offload_label = 'ON' if global_args.offload else 'OFF'
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={offload_label}")
    else:
        is_pd_sep = True
        prefill_instances.extend(InstanceState(url, port) for url, port in global_args.prefill)
        decode_instances.extend(InstanceState(url) for url in global_args.decode)
        await init_prefill_bootstrap(prefill_instances, ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")

    yield

    # Shutdown: release every backend HTTP client.
    for pool in (combined_instances, prefill_instances, decode_instances):
        for inst in pool:
            await inst.client.aclose()
# ASGI application; lifespan() builds the backend pools before requests are served.
app = FastAPI(lifespan=lifespan)


@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Handle POST /v1/completions by delegating to the shared _handle()."""
    return await _handle(request, "/v1/completions")


@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Handle POST /v1/chat/completions by delegating to the shared _handle()."""
    return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active mode's handler.

    Responds 503 until the ready event is set. A fresh UUID is generated as
    the request id and forwarded upstream in ``X-Request-Id``; when the
    ``OPENAI_API_KEY`` environment variable is set it is forwarded as a bearer
    token. Only a list-valued ``prompt`` is treated as token ids; any other
    prompt form yields ``token_ids=None`` and ``input_length=0``.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    payload = await request.json()
    rid = str(uuid.uuid4())

    raw_prompt = payload.get("prompt")
    if isinstance(raw_prompt, list):
        token_ids = raw_prompt
        input_length = len(raw_prompt)
    else:
        token_ids = None
        input_length = 0

    sid = request.headers.get("X-Session-Id")
    upstream_headers = {"X-Request-Id": rid}
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        upstream_headers["Authorization"] = f"Bearer {key}"

    if not is_pd_sep:
        return await _handle_combined(api, payload, token_ids,
                                      input_length, sid, upstream_headers)
    return await _handle_pd_sep(api, payload, rid, token_ids,
                                input_length, sid, upstream_headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Combined mode with adaptive prefill offload (v2).

    The request is first placed with the configured policy (linear or
    lmetric), and the number of uncached prompt tokens on that instance is
    estimated from the router-side prefix index.

    WARM (< 5000 new tokens) / MEDIUM: served entirely on the chosen instance.
    HEAVY (>= HEAVY_THRESHOLD new tokens): if --offload is set, at least one
    instance has a bootstrap port and there are >= 2 instances, prefill is
    offloaded to another instance (round-robin, skipping overloaded ones,
    falling back to the least-loaded) and decode stays on the chosen instance.
    The offload is declined when MAX_OFFLOAD_INFLIGHT is reached or the
    prefill candidate already holds >= 2 * HEAVY_THRESHOLD ongoing tokens; the
    request is then served co-located as HEAVY_COLO.

    Returns a StreamingResponse relaying the backend's bytes.
    """
    policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear'
    picker = pick_instance_lmetric if policy == 'lmetric' else pick_instance
    best_inst, best_idx = picker(combined_instances, token_ids, session_id,
                                 input_length, session_affinity)
    cache_hit = best_inst.estimate_cache_hit(token_ids)
    estimated_new = max(0, input_length - cache_hit)
    # Per-request record; appended to _breakdown_log when the request finishes.
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "estimated_new_tokens": estimated_new,
        "cache_hit": cache_hit,
        "t_proxy_recv": _time.monotonic(),
    }
    offload_enabled = getattr(global_args, 'offload', False) if global_args else False
    has_bootstrap = any(inst.bootstrap_port for inst in combined_instances)
    # Elastic offload decision: offload only when it helps
    use_offload = False
    offload_reason = "disabled"
    if estimated_new >= HEAVY_THRESHOLD and offload_enabled and has_bootstrap and len(combined_instances) >= 2:
        d_inst = best_inst
        # Every instance except the decode target may take the prefill.
        p_candidates = [(i, inst) for i, inst in enumerate(combined_instances) if inst is not d_inst]
        avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0)
        # Round-robin P selection with overload skip (spreads P-role evenly)
        global _offload_inflight, _p_round_robin_idx
        p_inst = None
        for _ in range(len(p_candidates)):
            _p_round_robin_idx = (_p_round_robin_idx + 1) % len(p_candidates)
            candidate = p_candidates[_p_round_robin_idx][1]
            if candidate.ongoing_tokens < avg_load * OVERLOAD_FACTOR:
                p_inst = candidate
                break
        # Every candidate is overloaded: fall back to the least-loaded one.
        if p_inst is None:
            p_inst = min(p_candidates, key=lambda x: x[1].ongoing_tokens)[1]
        if _offload_inflight >= MAX_OFFLOAD_INFLIGHT:
            offload_reason = "max_concurrent_reached"
        elif p_inst.ongoing_tokens >= HEAVY_THRESHOLD * 2:
            offload_reason = "p_saturated"
        else:
            use_offload = True
            offload_reason = "offload_accepted"
            # Released by _handle_heavy_offload once the prefill step ends.
            _offload_inflight += 1
    if use_offload:
        d_idx = best_idx
        # Reserve load on the prefill instance before the first await so
        # concurrent routing decisions already see it.
        p_inst.ongoing_tokens += input_length # reserve immediately
        p_inst.pending_prefill_tokens += estimated_new
        p_inst.num_requests += 1
        breakdown["route_class"] = "HEAVY_P2P"
        breakdown["offload_reason"] = offload_reason
        breakdown["p_inst"] = p_inst.url
        breakdown["d_inst"] = d_inst.url
        breakdown["p_load"] = p_inst.ongoing_tokens
        breakdown["d_load"] = d_inst.ongoing_tokens
        if session_id:
            session_affinity[session_id] = d_idx
        return await _handle_heavy_offload(api, req_data, headers, token_ids,
                                           input_length, p_inst, d_inst, breakdown)
    else:
        if estimated_new >= HEAVY_THRESHOLD:
            breakdown["route_class"] = "HEAVY_COLO"
            breakdown["offload_reason"] = offload_reason
        else:
            breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM"
    # Co-located path: prefill and decode both run on the chosen instance.
    inst = best_inst
    breakdown["routed_to"] = inst.url
    breakdown["policy"] = policy
    inst.ongoing_tokens += input_length
    inst.pending_prefill_tokens += estimated_new
    inst.num_requests += 1

    async def generate():
        """Relay the backend stream and keep the instance counters in sync."""
        prefill_done = False
        try:
            async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    # The first chunk is taken as the end of prefill: move the
                    # request from the prefill counter to the decode counter.
                    if not prefill_done:
                        inst.pending_prefill_tokens -= estimated_new
                        inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        prefill_done = True
                    yield chunk
            # Only reached when the stream completed without an exception.
            inst.record_prefix(token_ids)
        finally:
            # Undo whichever phase counter is still held, then the totals.
            if not prefill_done:
                inst.pending_prefill_tokens -= estimated_new
            else:
                inst.ongoing_decode_tokens -= input_length
            inst.ongoing_tokens -= input_length
            inst.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="text/event-stream")
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
                                p_inst, d_inst, breakdown):
    """HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.

    The caller has already reserved ``ongoing_tokens``,
    ``pending_prefill_tokens`` and ``num_requests`` on ``p_inst`` and has
    incremented ``_offload_inflight``. This function releases each of those
    exactly once when the prefill step ends, whether it succeeded or failed.

    Step 1 sends a non-streaming, ``max_tokens=1`` copy of the request to
    ``p_inst`` with ``do_remote_decode`` set and waits for it.
    Step 2 streams the original request from ``d_inst`` with
    ``do_remote_prefill`` set, pointing at p_inst's bootstrap address.

    Raises:
        HTTPException: 502 when the prefill request fails.
    """
    global _offload_inflight
    request_id = headers.get("X-Request-Id", "")
    # Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        prefill_data = req_data.copy()
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": "xfer-" + request_id,
        }
        prefill_data["stream"] = False
        prefill_data["max_tokens"] = 1
        prefill_data.pop("max_completion_tokens", None)
        prefill_data.pop("stream_options", None)
        p_headers = {**headers, "X-data-parallel-rank": "0"}
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
        breakdown["t_prefill_done"] = _time.monotonic()
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["error"] = str(e)
        _breakdown_log.append(breakdown)
        # The reservations are released in ``finally`` below, which also runs
        # on this path; releasing them here too would decrement them twice.
        raise HTTPException(status_code=502, detail="Prefill failed: %s" % e) from e
    finally:
        p_inst.ongoing_tokens -= input_length
        p_inst.pending_prefill_tokens -= breakdown.get("estimated_new_tokens", 0)
        _offload_inflight = max(0, _offload_inflight - 1)
        p_inst.num_requests -= 1
    # Step 2: Stream decode on d_inst (pulls KV from Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.ongoing_decode_tokens += input_length
    d_inst.num_requests += 1
    breakdown["t_decode_sent"] = _time.monotonic()
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
    }

    async def generate():
        """Relay the decode stream and release d_inst's counters at the end."""
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            d_inst.ongoing_tokens -= input_length
            d_inst.ongoing_decode_tokens -= input_length
            d_inst.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="application/json")
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
                              input_length, breakdown):
    """Run a prefill request in the background on behalf of the caller.

    On success the prompt prefix is recorded on ``p_inst``. Any exception is
    swallowed and flagged in ``breakdown["prefill_error"]``. In both cases
    ``breakdown["t_prefill_done"]`` is stamped and the ``ongoing_tokens``
    reserved by the caller are released.
    """
    failed = False
    try:
        response = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        response.raise_for_status()
        await response.aclose()
        p_inst.record_prefix(token_ids)
    except Exception:
        failed = True
    finally:
        if failed:
            # Re-stamp: the failure may have happened after the first stamp.
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
        p_inst.ongoing_tokens -= input_length
# Strong references to in-flight fire-and-forget prefill tasks. The event loop
# only keeps weak references to tasks, so an otherwise unreferenced task can
# be garbage-collected before it finishes.
_background_prefill_tasks: set = set()


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Picks a prefill instance with the linear policy and the decode instance
    with the fewest ongoing tokens. The prefill request is a non-streaming,
    ``max_tokens=1`` copy of the original with ``do_remote_decode`` set; it is
    either awaited, or launched in the background when ``--fire-and-forget``
    is given. The decode request is then streamed back to the client.

    Raises:
        HTTPException: 503 when no decode instance is configured, 502 when an
            awaited prefill fails.
    """
    if not decode_instances:
        # min() over an empty pool would raise ValueError and surface as a 500.
        raise HTTPException(status_code=503, detail="No decode instances configured")
    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }
    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    # Reserve load on the prefill instance; released when the prefill ends.
    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    if global_args.fire_and_forget:
        task = asyncio.create_task(_send_prefill_async(
            p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
        # Hold the task until it completes, then drop the reference.
        _background_prefill_tasks.add(task)
        task.add_done_callback(_background_prefill_tasks.discard)
    else:
        try:
            resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
            breakdown["t_prefill_done"] = _time.monotonic()
            resp.raise_for_status()
            await resp.aclose()
            p_inst.record_prefix(token_ids)
        except Exception as e:
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
            _breakdown_log.append(breakdown)
            raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
        finally:
            p_inst.ongoing_tokens -= input_length
    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        """Relay the decode stream and release d_inst's load at the end."""
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)
    return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis.

    The whole in-memory _breakdown_log list is returned; this function does
    not truncate or clear it.
    """
    return _breakdown_log
@app.get("/stats")
async def get_stats():
    """Report the live counters of every instance, one dict per instance.

    Uses the combined pool when it is non-empty, otherwise the prefill pool
    followed by the decode pool.
    """
    if combined_instances:
        pool = combined_instances
    else:
        pool = prefill_instances + decode_instances

    snapshot = []
    for inst in pool:
        entry = {
            "url": inst.url,
            "ongoing_tokens": inst.ongoing_tokens,
            "pending_prefill_tokens": inst.pending_prefill_tokens,
            "ongoing_decode_tokens": inst.ongoing_decode_tokens,
            "num_requests": inst.num_requests,
            "cached_blocks": len(inst.cached_blocks),
        }
        snapshot.append(entry)
    return snapshot
def parse_args():
    """Parse and validate the command line.

    Besides the raw argparse fields, the returned namespace carries:
      prefill: list of ``(url, bootstrap_port_or_None)`` built from ``--prefill``
      decode:  list of URLs built from ``--decode``

    Exits with a usage error when neither mode is selected, when a
    ``--prefill`` bootstrap port is neither an integer nor ``none``, or when
    PD-Sep mode is requested without any ``--decode`` instance.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--fire-and-forget", action="store_true",
                   help="Send prefill async, don't await before decode")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"],
                   help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report a usage error instead of an uncaught traceback.
                p.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    # PD-Sep mode picks a decode instance for every request, so an empty
    # decode pool would make every request fail at runtime.
    if not args.combined and not args.decode:
        p.error("--prefill requires at least one --decode instance")
    return args
if __name__ == "__main__":
    # Publish the parsed arguments as the module-level global_args, which
    # lifespan() and the request handlers read.
    global_args = parse_args()
    # Apply the --heavy-threshold value to the module-level constant.
    HEAVY_THRESHOLD = global_args.heavy_threshold
    uvicorn.run(app, host=global_args.host, port=global_args.port)