Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang 3594f7dce0 Fix LMetric routing: remove session affinity, align with OSDI'26 spec
LMetric was incorrectly sharing session-sticky logic with Linear policy.
Fixed to pure per-request routing: score = P_tokens × BS where
P = pending_prefill + (input - cache_hit), BS = num_requests.

Experiment result (200 req, fresh restart): Linear vs corrected LMetric
show <2% difference on all metrics — LMetric's cache-hit estimation
provides implicit soft affinity that preserves locality without explicit
session stickiness.

Also fix bench.sh missing cd (replayer module not found from non-project
cwd) and rewrite run_lmetric_ab.sh as thin wrapper around bench.sh to
eliminate duplicated launch/cleanup logic that broke under set -euo.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-23 11:56:58 +08:00

642 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests (waiting + running)
Session affinity: multi-turn sessions stick to same instance (all policies).
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
BLOCK_SIZE = 512
CACHE_HIT_ALPHA = 1.0
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
OVERLOAD_FACTOR = 2.0 # default; overridden by --overload-factor
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.pending_prefill_tokens = 0 # tokens for requests still in prefill
self.num_requests = 0 # total in-flight requests (waiting + running)
self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others
self.engine_id: dict[int, str] = {}
self.dp_size = 1
self.cached_blocks: set[int] = set()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
if len(self.cached_blocks) > 200000:
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
# Cumulative token load per instance (for balanced session placement)
_inst_cumulative_tokens: list[int] = []
def _p_offload_penalty(inst: InstanceState) -> int:
"""Penalty for instances currently doing P-role offloaded prefills.
When an instance is busy with offloaded HEAVY prefills for other
instances, we want to steer WARM/MEDIUM requests away from it so
its GPU is dedicated to prefill (soft PD separation).
"""
if inst.active_p_offloads <= 0:
return 0
return inst.active_p_offloads * HEAVY_THRESHOLD
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]:
"""Session-sticky with load-aware override.
Turn 2+: use session affinity UNLESS pinned instance is overloaded
or busy with P-role offloads, in which case pick least-loaded.
Turn 1: pick instance with best score (load + cache combined).
Instances doing P-role offloads get a large penalty to steer
WARM/MEDIUM traffic away.
"""
global _inst_cumulative_tokens
if not _inst_cumulative_tokens:
_inst_cumulative_tokens = [0] * len(instances)
avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
if session_id and session_id in affinity:
idx = affinity[session_id]
if idx < len(instances):
inst = instances[idx]
if (inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR
and inst.active_p_offloads == 0):
return inst, idx
best_idx, best_score = 0, float("inf")
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids)
score = (inst.ongoing_tokens + _p_offload_penalty(inst)
- CACHE_HIT_ALPHA * cache_hit)
if score < best_score:
best_score = score
best_idx = i
_inst_cumulative_tokens[best_idx] += input_length
if session_id:
affinity[session_id] = best_idx
return instances[best_idx], best_idx
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]:
"""LMetric routing: score = P_tokens × BS (OSDI'26).
Pure per-request load-based routing, no session affinity.
P = pending_prefill_tokens + (input_length - cache_hit)
BS = num_requests (current batch size)
"""
best_idx, best_score = 0, float("inf")
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids)
new_prefill = max(0, input_length - cache_hit)
p_tokens = inst.pending_prefill_tokens + new_prefill
bs = inst.num_requests
score = p_tokens * bs
if score < best_score:
best_score = score
best_idx = i
return instances[best_idx], best_idx
global_args = None
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
session_affinity: dict[str, int] = {}
is_pd_sep = False
_breakdown_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
for inst in instances:
if inst.bootstrap_port is None:
continue
while True:
try:
await inst.client.get("/health")
except Exception:
await asyncio.sleep(1)
continue
parsed = urllib.parse.urlparse(str(inst.client.base_url))
url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
resp = await inst.client.get(url)
resp.raise_for_status()
data = resp.json()
for dp_rank, dp_entry in data.items():
inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
inst.dp_size = len(data)
print(f"Inited {inst.url} engine_ids={inst.engine_id}")
break
ready.set()
@asynccontextmanager
async def lifespan(app: FastAPI):
global is_pd_sep
app.state.ready = asyncio.Event()
if global_args.combined:
is_pd_sep = False
bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
for i, url in enumerate(global_args.combined):
bp = bp_list[i] if i < len(bp_list) else None
combined_instances.append(InstanceState(url, bp))
# Bootstrap combined instances for offload (need engine_ids for KV transfer)
if global_args.offload and bp_list:
await init_prefill_bootstrap(combined_instances, app.state.ready)
else:
app.state.ready.set()
policy = getattr(global_args, 'policy', 'linear')
print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")
else:
is_pd_sep = True
for url, bp in global_args.prefill:
prefill_instances.append(InstanceState(url, bp))
for url in global_args.decode:
decode_instances.append(InstanceState(url))
await init_prefill_bootstrap(prefill_instances, app.state.ready)
print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
yield
for inst in combined_instances + prefill_instances + decode_instances:
await inst.client.aclose()
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle(request, "/v1/completions")
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
return await _handle(request, "/v1/chat/completions")
async def _handle(request: Request, api: str):
if not app.state.ready.is_set():
raise HTTPException(status_code=503, detail="Service Unavailable")
req_data = await request.json()
request_id = str(uuid.uuid4())
prompt = req_data.get("prompt")
token_ids = prompt if isinstance(prompt, list) else None
input_length = len(token_ids) if token_ids else 0
session_id = request.headers.get("X-Session-Id")
headers = {"X-Request-Id": request_id}
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if is_pd_sep:
return await _handle_pd_sep(api, req_data, request_id, token_ids,
input_length, session_id, headers)
else:
return await _handle_combined(api, req_data, token_ids,
input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
"""Combined mode with V2 P2P offload.
WARM/MEDIUM: route to best instance, co-located P+D (no KV transfer).
HEAVY: C_s (session-sticky, has cache) does FAST prefill,
D (least-loaded C, D != C_s) pulls KV via Mooncake and decodes.
Offload only when D is meaningfully less loaded than C_s.
"""
policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear'
picker = pick_instance_lmetric if policy == 'lmetric' else pick_instance
best_inst, best_idx = picker(combined_instances, token_ids, session_id,
input_length, session_affinity)
cache_hit = best_inst.estimate_cache_hit(token_ids)
estimated_new = max(0, input_length - cache_hit)
breakdown = {
"request_id": headers.get("X-Request-Id", ""),
"input_length": input_length,
"estimated_new_tokens": estimated_new,
"cache_hit": cache_hit,
"t_proxy_recv": _time.monotonic(),
}
# H4 cache-aware offload gate: only offload when C_s has significant cache
# Cold turn-1 HEAVY: stay co-located (no RDMA overhead)
# Cached turn-2+ HEAVY: offload to flexible D (C_s fast prefill + D decode)
offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
use_offload = False
offload_reason = "offload_disabled"
if estimated_new >= HEAVY_THRESHOLD and offload_enabled:
cache_ratio = cache_hit / max(input_length, 1)
d_candidate = min((c for c in combined_instances if c is not best_inst),
key=lambda c: c.ongoing_tokens)
breakdown["cache_ratio"] = cache_ratio
if cache_ratio >= 0.3: # at least 30% cache hit to justify RDMA offload
use_offload = True
offload_reason = "cached_offload_%.0f%%" % (cache_ratio * 100)
else:
offload_reason = "cold_colocated_%.0f%%" % (cache_ratio * 100)
if use_offload:
# C_s does fast cached prefill, D does decode
p_inst = best_inst # session-sticky, has prefix cache
d_inst = d_candidate
d_idx = combined_instances.index(d_inst)
# Accounting: C_s only prefills estimated_new tokens (cached prefix is free)
p_inst.ongoing_tokens += input_length
p_inst.pending_prefill_tokens += estimated_new
p_inst.num_requests += 1
p_inst.active_p_offloads += 1
breakdown["route_class"] = "HEAVY_OFFLOAD"
breakdown["offload_reason"] = offload_reason
breakdown["p_inst"] = p_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["p_load"] = p_inst.ongoing_tokens
breakdown["d_load"] = d_inst.ongoing_tokens
# Update session affinity to D (D will have KV after this request)
if session_id:
session_affinity[session_id] = d_idx
return await _handle_heavy_offload(api, req_data, headers, token_ids,
input_length, p_inst, d_inst, breakdown)
else:
if estimated_new >= HEAVY_THRESHOLD:
breakdown["route_class"] = "HEAVY_COLO"
breakdown["offload_reason"] = offload_reason
elif estimated_new < 5000:
breakdown["route_class"] = "WARM"
else:
breakdown["route_class"] = "MEDIUM"
inst = best_inst
breakdown["routed_to"] = inst.url
breakdown["policy"] = policy
inst.ongoing_tokens += input_length
inst.pending_prefill_tokens += estimated_new
inst.num_requests += 1
async def generate():
prefill_done = False
try:
async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
inst.pending_prefill_tokens -= estimated_new
inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
inst.record_prefix(token_ids)
finally:
if not prefill_done:
inst.pending_prefill_tokens -= estimated_new
else:
inst.ongoing_decode_tokens -= input_length
inst.ongoing_tokens -= input_length
inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="text/event-stream")
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
p_inst, d_inst, breakdown):
"""HEAVY request: prefill on p_inst (C_s), KV via Mooncake, decode on d_inst (D).
On prefill timeout/failure, falls back to co-located decode on d_inst.
"""
request_id = headers.get("X-Request-Id", "")
estimated_new = breakdown.get("estimated_new_tokens", 0)
# V2: p_inst is C_s with cache, so pending_prefill_tokens was incremented
# by estimated_new (only new tokens), not full input_length.
p_prefill_release = estimated_new
# Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
breakdown["t_prefill_sent"] = _time.monotonic()
prefill_ok = False
try:
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": "xfer-" + request_id,
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
resp = await asyncio.wait_for(
p_inst.client.post(api, json=prefill_data, headers=p_headers),
timeout=PREFILL_TIMEOUT_S,
)
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
breakdown["t_prefill_done"] = _time.monotonic()
prefill_ok = True
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = str(e)
finally:
# Always release P-instance resources exactly once
p_inst.ongoing_tokens -= input_length
p_inst.pending_prefill_tokens -= p_prefill_release
p_inst.num_requests -= 1
p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)
if not prefill_ok:
# Fallback: co-located prefill+decode on d_inst (no KV transfer)
breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
d_inst.ongoing_tokens += input_length
d_inst.pending_prefill_tokens += estimated_new
d_inst.num_requests += 1
async def generate_fallback():
prefill_done = False
try:
async with d_inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
d_inst.pending_prefill_tokens -= estimated_new
d_inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
d_inst.record_prefix(token_ids)
finally:
if not prefill_done:
d_inst.pending_prefill_tokens -= estimated_new
else:
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate_fallback(), media_type="text/event-stream")
# Step 2: Stream decode on d_inst (pulls KV from Mooncake)
d_inst.ongoing_tokens += input_length
d_inst.ongoing_decode_tokens += input_length
d_inst.num_requests += 1
breakdown["t_decode_sent"] = _time.monotonic()
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": p_inst.engine_id.get(0, ""),
"transfer_id": "xfer-" + request_id,
}
async def generate():
first_token = True
try:
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if first_token:
breakdown["t_first_token"] = _time.monotonic()
first_token = False
yield chunk
finally:
d_inst.ongoing_tokens -= input_length
d_inst.ongoing_decode_tokens -= input_length
d_inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="application/json")
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
input_length, breakdown):
"""Fire-and-forget prefill: send and don't block caller."""
try:
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
except Exception:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = True
finally:
p_inst.ongoing_tokens -= input_length
async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
session_id, headers):
"""PD-Sep mode with per-stage breakdown profiling."""
breakdown = {
"request_id": request_id,
"input_length": input_length,
"t_proxy_recv": _time.monotonic(),
}
p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
input_length, session_affinity)
d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
breakdown["p_inst"] = p_inst.url
breakdown["d_inst"] = d_inst.url
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True, "do_remote_prefill": False,
"transfer_id": f"xfer-{request_id}",
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
p_inst.ongoing_tokens += input_length
breakdown["t_prefill_sent"] = _time.monotonic()
if global_args.fire_and_forget:
asyncio.create_task(_send_prefill_async(
p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
else:
try:
resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
breakdown["t_prefill_done"] = _time.monotonic()
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = True
_breakdown_log.append(breakdown)
raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
finally:
p_inst.ongoing_tokens -= input_length
# Send decode
d_inst.ongoing_tokens += input_length
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False, "do_remote_prefill": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": p_inst.engine_id.get(0, ""),
"transfer_id": f"xfer-{request_id}",
}
breakdown["t_decode_sent"] = _time.monotonic()
async def generate():
first_token = True
try:
async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if first_token:
breakdown["t_first_token"] = _time.monotonic()
first_token = False
yield chunk
finally:
breakdown["t_done"] = _time.monotonic()
d_inst.ongoing_tokens -= input_length
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
"""Return per-request breakdown data for analysis."""
return _breakdown_log
@app.get("/stats")
async def get_stats():
"""Return per-instance live state for debugging."""
instances = combined_instances or prefill_instances + decode_instances
return [{
"url": inst.url,
"role": "combined",
"ongoing_tokens": inst.ongoing_tokens,
"pending_prefill_tokens": inst.pending_prefill_tokens,
"ongoing_decode_tokens": inst.ongoing_decode_tokens,
"num_requests": inst.num_requests,
"active_p_offloads": inst.active_p_offloads,
"cached_blocks": len(inst.cached_blocks),
} for inst in instances]
def parse_args():
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
p.add_argument("--port", type=int, default=8000)
p.add_argument("--host", type=str, default="0.0.0.0")
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
help="PD-Sep prefill: URL [bootstrap_port]")
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
help="PD-Sep decode: URL")
p.add_argument("--fire-and-forget", action="store_true",
help="Send prefill async, don't await before decode")
p.add_argument("--heavy-threshold", type=int, default=20000,
help="New tokens threshold for HEAVY classification (adaptive offload)")
p.add_argument("--offload", action="store_true",
help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
p.add_argument("--bootstrap-ports", type=str, default="",
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"],
help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
p.add_argument("--overload-factor", type=float, default=2.0,
help="Break session affinity when instance load > factor * avg")
args = p.parse_args()
args.prefill = []
if args.prefill_raw:
for entry in args.prefill_raw:
url = entry[0]
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
args.prefill.append((url, bp))
args.decode = [e[0] for e in (args.decode_raw or [])]
if not args.combined and not args.prefill:
p.error("Must specify either --combined or --prefill/--decode")
return args
if __name__ == "__main__":
global_args = parse_args()
HEAVY_THRESHOLD = global_args.heavy_threshold
OVERLOAD_FACTOR = global_args.overload_factor
uvicorn.run(app, host=global_args.host, port=global_args.port)