# Commit note:
#   Routing fix: new sessions are placed by cumulative token load (greedy bin
#   packing) with a cache-hit tiebreak; turn 2+ of a session uses session
#   affinity. The replayer now sends an X-Session-Id header for proper session
#   tracking.
#
#   Agentic workload core patterns (GLM-5.1 trace):
#   - 91% of reusable KV is intra-session (not cross-session).
#   - Session-sticky routing is THE critical optimization.
#   - 36% warm requests (1.3k new tokens), 64% cold (17k+).
#   - After cache: effective prefill/decode ratio drops from 61.5x to 28.7x.
#   - Cross-session sharing (system prompt) is only 4.8% of tokens.
#
#   Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
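"""Unified cache-aware, token-load-balanced global scheduler.

Supports two modes:

    --combined URL [URL ...]
        PD co-located instances (normal vLLM, no KV transfer).

    --prefill URL [BOOTSTRAP_PORT] ... --decode URL ...
        PD-disaggregated instances (Mooncake KV transfer). Both flags may be
        repeated, once per instance.

Routing policy for the instance that runs prefill (same for both modes):

    * Turn 2+ of a session (same X-Session-Id header): stick to the instance
      the session was first placed on, so its KV cache is reused.
    * New session: consider every instance whose cumulative token load is
      within max(10% of the minimum load, 10000 tokens) of the least-loaded
      instance, and among those prefer the longest estimated prefix-cache hit.

Combined mode additionally sends HEAVY requests (estimated uncached tokens at
or above --heavy-threshold) to the instance with the fewest tokens currently
in decode, and re-pins the session there. In PD-disaggregated mode, decode
goes to the instance with the fewest ongoing tokens.
"""
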
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from contextlib import asynccontextmanager

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse


# Granularity of the proxy-side prefix-cache estimate: prompts are hashed in
# blocks of this many tokens.
BLOCK_SIZE: int = 512

# Cache-hit weight from the earlier score-based routing policy; not referenced
# by the current placement code in this file.
CACHE_HIT_ALPHA: float = 1.0

# Estimated uncached-token count at which a combined-mode request is classed
# HEAVY. This is only the default: the __main__ block overwrites it with the
# value of --heavy-threshold.
HEAVY_THRESHOLD: int = 20_000


class InstanceState:
    """Proxy-side view of one upstream vLLM server.

    Holds the shared HTTP client for the instance, the token-load counters the
    router reads, the prefill engine ids discovered from the bootstrap endpoint
    (filled only by init_prefill_bootstrap, i.e. PD-disaggregated mode), and an
    approximate, proxy-side record of prompt prefixes recently sent here.
    """

    # Bound on remembered block hashes. Once the record grows past
    # _MAX_CACHED_BLOCKS, the least recently recorded entries are dropped until
    # _KEEP_CACHED_BLOCKS remain (the gap keeps trims infrequent).
    _MAX_CACHED_BLOCKS = 200_000
    _KEEP_CACHED_BLOCKS = 100_000

    def __init__(self, url: str, bootstrap_port: int | None = None):
        self.url = url
        # Port of the prefill-side bootstrap server; None when there is none.
        self.bootstrap_port = bootstrap_port
        self.client = httpx.AsyncClient(
            timeout=None, base_url=url,
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
        )
        # Input tokens of requests currently routed to this instance.
        self.ongoing_tokens = 0
        # Subset of ongoing_tokens whose response has started streaming
        # (maintained by combined mode only).
        self.ongoing_decode_tokens = 0
        # DP rank -> engine id, as reported by the bootstrap /query endpoint.
        self.engine_id: dict[int, str] = {}
        # Number of DP ranks reported by the bootstrap query (1 until queried).
        self.dp_size = 1
        # Chained block hashes of recorded prompts, ordered from least to most
        # recently recorded. Used as an ordered set: values are always None.
        self.cached_blocks: OrderedDict[int, None] = OrderedDict()

    @staticmethod
    def _block_hashes(token_ids: list[int]):
        """Yield one hash per full BLOCK_SIZE block of ``token_ids``.

        Each hash folds in the previous block's hash, so it identifies the
        block *together with everything before it* (similar to how vLLM's
        automatic prefix caching keys blocks by their full prefix). A match at
        block i therefore implies the whole prefix through block i was
        recorded. A trailing partial block yields nothing.
        """
        prev = 0
        for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
            prev = hash((prev, tuple(token_ids[start:start + BLOCK_SIZE])))
            yield prev

    def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
        """Estimate how many leading tokens of ``token_ids`` are cached here.

        Returns a multiple of BLOCK_SIZE: the length of the longest run of
        leading blocks whose prefix hashes were recorded by record_prefix().
        Returns 0 for None, empty, or shorter-than-one-block prompts. This is
        derived from the proxy's own bookkeeping, not from the engine.
        """
        if not token_ids:
            return 0
        hit = 0
        for block_hash in self._block_hashes(token_ids):
            if block_hash not in self.cached_blocks:
                break
            hit += BLOCK_SIZE
        return hit

    def record_prefix(self, token_ids: list[int] | None) -> None:
        """Record every full-block prefix of ``token_ids`` as served here.

        Re-recorded blocks move to the most-recent end, so when the record is
        trimmed the least recently recorded blocks are dropped first.
        """
        if not token_ids:
            return
        cached = self.cached_blocks
        for block_hash in self._block_hashes(token_ids):
            cached[block_hash] = None
            cached.move_to_end(block_hash)
        if len(cached) > self._MAX_CACHED_BLOCKS:
            while len(cached) > self._KEEP_CACHED_BLOCKS:
                cached.popitem(last=False)


# Cumulative token load per instance (for balanced session placement)
|
|
_inst_cumulative_tokens: list[int] = []
|
|
|
|
|
|
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
|
|
session_id: str | None, input_length: int,
|
|
affinity: dict[str, int]) -> tuple[InstanceState, int]:
|
|
"""Session-sticky + KV-size balanced placement.
|
|
|
|
Turn 2+: session affinity (sticky to same instance for KV reuse).
|
|
Turn 1 (new session): place on instance with least cumulative token load
|
|
(greedy bin packing), with cache-hit tiebreak.
|
|
"""
|
|
global _inst_cumulative_tokens
|
|
if not _inst_cumulative_tokens:
|
|
_inst_cumulative_tokens = [0] * len(instances)
|
|
|
|
# Session affinity for turn 2+
|
|
if session_id and session_id in affinity:
|
|
idx = affinity[session_id]
|
|
if idx < len(instances):
|
|
return instances[idx], idx
|
|
|
|
# New session: balanced placement
|
|
# Primary: least cumulative tokens (long-term balance)
|
|
# Secondary: cache hit (tiebreak for prefix reuse)
|
|
min_load = min(_inst_cumulative_tokens)
|
|
# Candidates within 10% of min load
|
|
threshold = min_load + max(min_load * 0.1, 10000)
|
|
candidates = [i for i in range(len(instances))
|
|
if _inst_cumulative_tokens[i] <= threshold]
|
|
|
|
if not candidates:
|
|
candidates = list(range(len(instances)))
|
|
|
|
# Among candidates, pick best cache hit
|
|
best_idx = candidates[0]
|
|
best_hit = 0
|
|
for i in candidates:
|
|
hit = instances[i].estimate_cache_hit(token_ids)
|
|
if hit > best_hit:
|
|
best_hit = hit
|
|
best_idx = i
|
|
|
|
_inst_cumulative_tokens[best_idx] += input_length
|
|
if session_id:
|
|
affinity[session_id] = best_idx
|
|
return instances[best_idx], best_idx
|
|
|
|
|
|
# ---- Process-wide routing state ---------------------------------------------
# Parsed CLI options; assigned by the __main__ block before the server starts.
global_args: "argparse.Namespace | None" = None
# Upstream instances, populated by lifespan() according to the chosen mode.
combined_instances: list["InstanceState"] = []
prefill_instances: list["InstanceState"] = []
decode_instances: list["InstanceState"] = []
# X-Session-Id header value -> index of the instance the session is pinned to.
session_affinity: dict[str, int] = {}
# True in PD-disaggregated mode; set by lifespan().
is_pd_sep: bool = False
# One timing record per finished request, served as-is by GET /breakdown.
_breakdown_log: list[dict] = []


async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each prefill instance, fetch its engine ids, then set ``ready``.

    Instances without a bootstrap port are skipped. For the others:

    1. Poll ``GET /health`` once per second until it answers with a 2xx status
       (connection errors and non-2xx answers both mean "not ready yet").
    2. Query ``http://<host>:<bootstrap_port>/query``. The response is read as
       a JSON object mapping DP rank to an entry with an ``"engine_id"`` field.
       Each engine id is stored in ``inst.engine_id``, and ``inst.dp_size`` is
       set to the number of entries.

    Errors from the bootstrap query itself propagate to the caller.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue

        while True:
            try:
                health = await inst.client.get("/health")
                # A server that is up but not healthy (e.g. still loading)
                # must not be treated as ready.
                health.raise_for_status()
            except httpx.HTTPError:
                await asyncio.sleep(1)
                continue
            break

        parsed = urllib.parse.urlparse(str(inst.client.base_url))
        query_url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
        resp = await inst.client.get(query_url)
        resp.raise_for_status()
        data = resp.json()
        for dp_rank, dp_entry in data.items():
            inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
        inst.dp_size = len(data)
        print(f"Inited {inst.url} engine_ids={inst.engine_id}")
    ready.set()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the instance lists from ``global_args``; close clients on shutdown.

    Combined mode marks the app ready immediately. PD-disaggregated mode keeps
    startup waiting until init_prefill_bootstrap() has queried every prefill
    instance. The HTTP clients are closed on exit even if startup fails part
    way or shutdown is entered with an exception.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()

    try:
        if global_args.combined:
            is_pd_sep = False
            for url in global_args.combined:
                combined_instances.append(InstanceState(url))
            app.state.ready.set()
            print(f"Combined mode: {len(combined_instances)} instances")
        else:
            is_pd_sep = True
            for url, bp in global_args.prefill:
                prefill_instances.append(InstanceState(url, bp))
            for url in global_args.decode:
                decode_instances.append(InstanceState(url))
            await init_prefill_bootstrap(prefill_instances, app.state.ready)
            print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")

        yield
    finally:
        for inst in combined_instances + prefill_instances + decode_instances:
            await inst.client.aclose()


app = FastAPI(lifespan=lifespan)


@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Forward a POST /v1/completions request through the scheduler."""
    endpoint = "/v1/completions"
    return await _handle(request, endpoint)


@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Forward a POST /v1/chat/completions request through the scheduler."""
    endpoint = "/v1/chat/completions"
    return await _handle(request, endpoint)


async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    Token ids are taken from ``prompt`` only when it is a flat list of ints
    (a pre-tokenized prompt). For anything else (text prompts, chat messages,
    batched prompts), routing runs with ``token_ids=None`` and
    ``input_length=0``, so the request does not contribute to the load
    counters or cache estimates.

    Raises:
        HTTPException(503): before startup has marked the scheduler ready.
        HTTPException(400): if the body is not valid JSON or not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    try:
        req_data = await request.json()
    except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    request_id = str(uuid.uuid4())
    prompt = req_data.get("prompt")
    # Batched prompts (lists of lists or strings) are not token ids; a list of
    # lists would also make the block hashing raise TypeError.
    is_token_list = isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    token_ids = prompt if is_token_list else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)


async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Route one request in combined (co-located P+D) mode and stream the reply.

    WARM/MEDIUM (estimated uncached tokens below HEAVY_THRESHOLD): use the
    instance chosen by pick_instance() (session affinity, else balanced
    placement with a cache-hit tiebreak).

    HEAVY: send the request to the instance with the fewest tokens currently
    streaming a response, so a large prefill does not disrupt an instance busy
    decoding, and re-pin the session there.

    The response media type follows the request's ``stream`` flag. A timing
    record is appended to _breakdown_log when the stream ends.
    """
    # pick_instance() charges cumulative load only for new placements (not for
    # sticky turn 2+), so remember which case this is before calling it.
    pinned = session_affinity.get(session_id) if session_id else None
    was_sticky = pinned is not None and pinned < len(combined_instances)

    best_inst, best_idx = pick_instance(combined_instances, token_ids, session_id,
                                        input_length, session_affinity)
    # Estimate new (uncached) tokens on the picked instance.
    cache_hit = best_inst.estimate_cache_hit(token_ids)
    estimated_new = max(0, input_length - cache_hit)

    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "estimated_new_tokens": estimated_new,
        "cache_hit": cache_hit,
        "t_proxy_recv": _time.monotonic(),
    }

    if estimated_new >= HEAVY_THRESHOLD:
        idx = min(range(len(combined_instances)),
                  key=lambda i: combined_instances[i].ongoing_decode_tokens)
        inst = combined_instances[idx]
        breakdown["route_class"] = "HEAVY"
        if not was_sticky and idx != best_idx:
            # Move the placement charge to where the request actually went, so
            # later bin-packing decisions see the real distribution.
            _inst_cumulative_tokens[best_idx] -= input_length
            _inst_cumulative_tokens[idx] += input_length
        if session_id:
            session_affinity[session_id] = idx
    else:
        inst = best_inst
        breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM"

    breakdown["routed_to"] = inst.url
    inst.ongoing_tokens += input_length

    async def generate():
        first_token = True
        decoding = False  # True once ongoing_decode_tokens was incremented
        try:
            async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
                resp.raise_for_status()
                # Once streaming starts, this instance is in "decode phase".
                inst.ongoing_decode_tokens += input_length
                decoding = True
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
                inst.record_prefix(token_ids)
        finally:
            inst.ongoing_tokens -= input_length
            # Only undo what was done: on connect/HTTP errors the decode
            # counter was never incremented and must not go negative.
            if decoding:
                inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)


async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
                              input_length, breakdown):
    """Run one prefill request in the background (fire-and-forget mode).

    Sets ``breakdown["t_prefill_done"]`` whether or not the request succeeds.
    On success, records the prompt's prefix blocks on ``p_inst``. On failure,
    sets ``prefill_error`` and ``prefill_error_msg`` and prints the error,
    since nothing awaits this coroutine to see an exception. Always subtracts
    ``input_length`` from ``p_inst.ongoing_tokens``; the caller is expected to
    have added it.
    """
    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        # Best effort by design: report the failure instead of raising into a
        # task nobody awaits.
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["prefill_error"] = True
        breakdown["prefill_error_msg"] = repr(e)
        print(f"Prefill to {p_inst.url} failed "
              f"(request {breakdown.get('request_id', '?')}): {e!r}")
    finally:
        p_inst.ongoing_tokens -= input_length


# Strong references to fire-and-forget prefill tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task can be garbage
# collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """Route one request in PD-disaggregated mode and stream the decode reply.

    1. Prefill: pick_instance() chooses a prefill instance. It receives a copy
       of the request limited to one non-streamed token and tagged with
       ``do_remote_decode``. By default the proxy waits for it, and a failure
       becomes HTTP 502. With --fire-and-forget it runs as a background task.
    2. Decode: the decode instance with the fewest ongoing tokens receives the
       original request tagged with ``do_remote_prefill``, the prefill
       instance's bootstrap address, and its DP-rank-0 engine id. That
       response is streamed back.

    Per-stage timestamps are collected in a breakdown dict that is appended to
    _breakdown_log.
    """
    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }

    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    # Reserve decode load at selection time, not after the prefill round-trip,
    # so requests arriving meanwhile don't all pick the same decode instance.
    d_inst.ongoing_tokens += input_length
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    transfer_id = f"xfer-{request_id}"
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": transfer_id,
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()

    if global_args.fire_and_forget:
        task = asyncio.create_task(_send_prefill_async(
            p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
    else:
        try:
            resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
            breakdown["t_prefill_done"] = _time.monotonic()
            resp.raise_for_status()
            await resp.aclose()
            p_inst.record_prefix(token_ids)
        except Exception as e:
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
            breakdown["prefill_error_msg"] = repr(e)
            _breakdown_log.append(breakdown)
            # No decode will be sent: release the reservation made above.
            d_inst.ongoing_tokens -= input_length
            raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
        finally:
            p_inst.ongoing_tokens -= input_length

    # Send decode, pointing it at the prefill instance's bootstrap server.
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": transfer_id,
    }

    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)


@app.get("/breakdown")
async def get_breakdown():
    """Return every per-request timing record collected so far.

    NOTE(review): nothing in this file ever trims _breakdown_log, so this
    payload (and the process's memory) grows with every request served.
    """
    records = _breakdown_log
    return records


def parse_args():
|
|
p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
|
|
p.add_argument("--port", type=int, default=8000)
|
|
p.add_argument("--host", type=str, default="0.0.0.0")
|
|
p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
|
|
p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
|
|
help="PD-Sep prefill: URL [bootstrap_port]")
|
|
p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
|
|
help="PD-Sep decode: URL")
|
|
p.add_argument("--fire-and-forget", action="store_true",
|
|
help="Send prefill async, don't await before decode")
|
|
p.add_argument("--heavy-threshold", type=int, default=20000,
|
|
help="New tokens threshold for HEAVY classification (adaptive offload)")
|
|
args = p.parse_args()
|
|
|
|
args.prefill = []
|
|
if args.prefill_raw:
|
|
for entry in args.prefill_raw:
|
|
url = entry[0]
|
|
bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None
|
|
args.prefill.append((url, bp))
|
|
args.decode = [e[0] for e in (args.decode_raw or [])]
|
|
|
|
if not args.combined and not args.prefill:
|
|
p.error("Must specify either --combined or --prefill/--decode")
|
|
return args
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI options are published through module globals: lifespan() reads
    # global_args, and _handle_combined() reads HEAVY_THRESHOLD per request.
    global_args = parse_args()
    HEAVY_THRESHOLD = global_args.heavy_threshold
    listen_host, listen_port = global_args.host, global_args.port
    uvicorn.run(app, host=listen_host, port=listen_port)