Fixed offload decision: removed p>=d gate (was blocking all offloads), added MAX_OFFLOAD_INFLIGHT=4 cap and p_saturated threshold. Result (200 req, fresh restart): Baseline: 99% success, TTFT=1.080/9.410, TPOT90=0.076, E2E=5.306 Elastic: 96% success, TTFT=0.946/15.843, TPOT90=0.077, E2E=5.717 Architectural tradeoff confirmed: - Median (p50) improves: D instances not disrupted by heavy prefill - Tail (p90) worsens: offloaded HEAVY requests pay KV transfer cost - TPOT unchanged: decode isolation is not the bottleneck To improve p90: need layerwise pipelined KV transfer (overlap with prefill compute) or smarter offload gating that avoids offloading the very largest requests (which have the longest prefill time and generate the most KV). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
520 lines
20 KiB
Python
520 lines
20 KiB
Python
"""Unified cache-aware + token-level load-balanced global scheduler.
|
|
|
|
Supports two modes:
|
|
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
|
|
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
|
|
|
|
Routing policy (same for both modes):
|
|
score = ongoing_tokens - ALPHA * cache_hit_tokens (lowest score wins)
|
|
The load term prevents "rich get richer"; the cache bonus gives affinity.
|
|
Session affinity: multi-turn sessions stick to same instance.
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import os
|
|
import time as _time
|
|
import urllib.parse
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
|
|
import httpx
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
BLOCK_SIZE = 512
|
|
CACHE_HIT_ALPHA = 1.0
|
|
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
|
|
|
|
|
|
class InstanceState:
|
|
def __init__(self, url: str, bootstrap_port: int | None = None):
|
|
self.url = url
|
|
self.bootstrap_port = bootstrap_port
|
|
self.client = httpx.AsyncClient(
|
|
timeout=None, base_url=url,
|
|
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
|
|
)
|
|
self.ongoing_tokens = 0
|
|
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
|
|
self.engine_id: dict[int, str] = {}
|
|
self.dp_size = 1
|
|
self.cached_blocks: set[int] = set()
|
|
|
|
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
|
|
if not token_ids or len(token_ids) < BLOCK_SIZE:
|
|
return 0
|
|
hit = 0
|
|
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
|
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
|
|
if bh in self.cached_blocks:
|
|
hit += BLOCK_SIZE
|
|
else:
|
|
break
|
|
return hit
|
|
|
|
def record_prefix(self, token_ids: list[int] | None):
|
|
if not token_ids:
|
|
return
|
|
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
|
|
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
|
|
if len(self.cached_blocks) > 200000:
|
|
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
|
|
|
|
|
|
# Running total of prompt tokens placed on each instance by score-based
# selection. It is only written in this module; nothing here reads it back.
_inst_cumulative_tokens: list[int] = []


def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Choose the instance that should serve a request.

    A session that already has an entry in ``affinity`` stays on its pinned
    instance as long as that instance's ``ongoing_tokens`` does not exceed
    twice the average across ``instances``. Otherwise (no session, no pin,
    stale index, or overloaded pin) the instance with the lowest
    ``ongoing_tokens - CACHE_HIT_ALPHA * estimated_cache_hit`` wins, ties
    going to the earliest instance; the session is then (re)pinned to it.

    Returns the chosen instance together with its index in ``instances``.
    """
    global _inst_cumulative_tokens
    if not _inst_cumulative_tokens:
        _inst_cumulative_tokens = [0] * len(instances)

    overload_factor = 2.0
    mean_ongoing = sum(inst.ongoing_tokens for inst in instances) / len(instances)
    avg_load = max(mean_ongoing, 1.0)

    # Sticky path: honour the existing pin unless it is overloaded.
    if session_id and session_id in affinity:
        pinned_idx = affinity[session_id]
        if pinned_idx < len(instances):
            pinned = instances[pinned_idx]
            if pinned.ongoing_tokens <= avg_load * overload_factor:
                return pinned, pinned_idx

    # Score path: lower is better; load pushes the score up, cache hits pull it down.
    scores = [
        inst.ongoing_tokens - CACHE_HIT_ALPHA * inst.estimate_cache_hit(token_ids)
        for inst in instances
    ]
    winner = min(range(len(scores)), key=scores.__getitem__)

    _inst_cumulative_tokens[winner] += input_length
    if session_id:
        affinity[session_id] = winner
    return instances[winner], winner
|
|
|
|
|
|
# Parsed command-line namespace; assigned in the ``__main__`` block.
global_args = None

# Backend pools, filled by ``lifespan`` according to the selected mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []

# session id -> index of the instance the session is pinned to
session_affinity: dict[str, int] = {}
is_pd_sep = False

# Per-request timing/routing records served by the /breakdown endpoint.
_breakdown_log: list[dict] = []

MAX_OFFLOAD_INFLIGHT = 4  # cap concurrent offloads to prevent P overload
_offload_inflight = 0  # number of currently in-flight offloaded HEAVY requests
|
|
|
|
|
|
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each bootstrap-enabled instance and load its engine ids.

    For every instance that has a ``bootstrap_port``: poll ``/health`` until it
    answers successfully, then fetch ``/query`` from the bootstrap port on the
    same host and store the returned per-DP-rank ``engine_id`` values and the
    DP size on the instance. Instances without a bootstrap port are skipped.
    ``ready`` is set once all instances have been processed.

    Both requests are retried once per second on failure, so a bootstrap
    server that comes up slightly after the API server does not abort startup.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue

        # The bootstrap server is addressed by the instance's host plus its
        # bootstrap port; the URL does not change between retries.
        parsed = urllib.parse.urlparse(str(inst.client.base_url))
        url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"

        while True:
            try:
                health = await inst.client.get("/health")
                # A non-2xx health reply means "not ready", not "ready".
                health.raise_for_status()
            except Exception:
                await asyncio.sleep(1)
                continue

            try:
                resp = await inst.client.get(url)
                resp.raise_for_status()
            except httpx.HTTPError as exc:
                print(f"Bootstrap query to {url} failed ({exc!r}); retrying")
                await asyncio.sleep(1)
                continue

            data = resp.json()
            for dp_rank, dp_entry in data.items():
                inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
            inst.dp_size = len(data)
            print(f"Inited {inst.url} engine_ids={inst.engine_id}")
            break
    ready.set()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the backend pools at startup and close their clients at shutdown.

    Combined mode (``--combined``): one ``InstanceState`` per URL, paired
    positionally with ``--bootstrap-ports``; engine ids are fetched only when
    ``--offload`` is set and at least one port was given, otherwise the app is
    marked ready immediately. PD-sep mode: builds the prefill and decode pools
    and always runs the bootstrap initialisation for the prefill pool.

    Cleanup runs in ``finally`` so the HTTP clients are closed even when
    startup or serving raises.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()

    try:
        if global_args.combined:
            is_pd_sep = False
            raw_ports = global_args.bootstrap_ports
            bp_list = [int(p) for p in raw_ports.split(",") if p.strip()] if raw_ports else []
            for i, url in enumerate(global_args.combined):
                bp = bp_list[i] if i < len(bp_list) else None
                combined_instances.append(InstanceState(url, bp))
            if global_args.offload and bp_list:
                await init_prefill_bootstrap(combined_instances, app.state.ready)
            else:
                app.state.ready.set()
            print(f"Combined mode: {len(combined_instances)} instances, offload={'ON' if global_args.offload else 'OFF'}")
        else:
            is_pd_sep = True
            for url, bp in global_args.prefill:
                prefill_instances.append(InstanceState(url, bp))
            for url in global_args.decode:
                decode_instances.append(InstanceState(url))
            await init_prefill_bootstrap(prefill_instances, app.state.ready)
            print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")

        yield
    finally:
        for inst in combined_instances + prefill_instances + decode_instances:
            await inst.client.aclose()


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Serve the text-completions route via the shared ``_handle`` entry point."""
    api_path = "/v1/completions"
    return await _handle(request, api_path)
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Serve the chat-completions route via the shared ``_handle`` entry point."""
    api_path = "/v1/chat/completions"
    return await _handle(request, api_path)
|
|
|
|
|
|
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    Args:
        request: the incoming FastAPI request; its JSON body is forwarded to
            the backend unchanged.
        api: backend path to call ("/v1/completions" or "/v1/chat/completions").

    Raises:
        HTTPException: 503 while the backends are not initialised, 400 when
            the body is not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    try:
        req_data = await request.json()
    except ValueError as exc:  # malformed JSON or undecodable bytes
        raise HTTPException(status_code=400, detail="Request body is not valid JSON") from exc
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    request_id = str(uuid.uuid4())

    # Only a flat list of integers is a pre-tokenised prompt. A list of strings
    # or of token lists is a batch; treating it as token ids would make
    # input_length the batch size and feed non-token data into block hashing.
    prompt = req_data.get("prompt")
    is_token_list = isinstance(prompt, list) and all(isinstance(t, int) for t in prompt)
    token_ids = prompt if is_token_list else None
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    return await _handle_combined(api, req_data, token_ids,
                                  input_length, session_id, headers)
|
|
|
|
|
|
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Combined mode with adaptive prefill offload (v2).

    WARM/MEDIUM: route to best instance, co-located P+D (no KV transfer).
    HEAVY (estimated new tokens >= HEAVY_THRESHOLD): when offload is enabled,
    prefill runs on the least-loaded *other* instance that has a bootstrap
    port and decode runs on the instance chosen by ``pick_instance``.
    Falls back to co-located serving when ``--offload`` is not set, when no
    other instance has a bootstrap port, or when an offload gate rejects.

    Returns a ``StreamingResponse`` relaying the backend's bytes.
    """
    global _offload_inflight

    best_inst, best_idx = pick_instance(combined_instances, token_ids, session_id,
                                        input_length, session_affinity)
    cache_hit = best_inst.estimate_cache_hit(token_ids)
    estimated_new = max(0, input_length - cache_hit)

    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "estimated_new_tokens": estimated_new,
        "cache_hit": cache_hit,
        "t_proxy_recv": _time.monotonic(),
    }

    offload_enabled = getattr(global_args, 'offload', False) if global_args else False

    # Elastic offload decision: offload only when it helps
    use_offload = False
    offload_reason = "disabled"
    d_inst = best_inst
    p_inst = None
    if estimated_new >= HEAVY_THRESHOLD and offload_enabled and len(combined_instances) >= 2:
        # The prefill side must expose a bootstrap port, otherwise the decode
        # side would be handed an unusable bootstrap address.
        p_candidates = [inst for inst in combined_instances
                        if inst is not d_inst and inst.bootstrap_port]
        if p_candidates:
            p_inst = min(p_candidates, key=lambda x: x.ongoing_tokens)

            # Gates, in order:
            #   1. Global cap: at most MAX_OFFLOAD_INFLIGHT concurrent offloads.
            #   2. P must not already hold >= 2x HEAVY_THRESHOLD ongoing tokens.
            # NOTE: We do NOT require P < D. P can be busier than D - the point
            # is to keep heavy prefill OFF the session-sticky D instance.
            # D's own load is not consulted here.
            if _offload_inflight >= MAX_OFFLOAD_INFLIGHT:
                offload_reason = "max_concurrent_reached"
            elif p_inst.ongoing_tokens >= HEAVY_THRESHOLD * 2:
                offload_reason = "p_saturated"
            else:
                use_offload = True
                offload_reason = "offload_accepted"
                _offload_inflight += 1

    if use_offload:
        p_inst.ongoing_tokens += input_length  # reserve immediately

        breakdown["route_class"] = "HEAVY_P2P"
        breakdown["offload_reason"] = offload_reason
        breakdown["p_inst"] = p_inst.url
        breakdown["d_inst"] = d_inst.url
        breakdown["p_load"] = p_inst.ongoing_tokens
        breakdown["d_load"] = d_inst.ongoing_tokens
        if session_id:
            session_affinity[session_id] = best_idx

        return await _handle_heavy_offload(api, req_data, headers, token_ids,
                                           input_length, p_inst, d_inst, breakdown)

    if estimated_new >= HEAVY_THRESHOLD:
        breakdown["route_class"] = "HEAVY_COLO"
        breakdown["offload_reason"] = offload_reason
    else:
        breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM"

    inst = best_inst
    breakdown["routed_to"] = inst.url
    inst.ongoing_tokens += input_length

    async def generate():
        first_token = True
        # Tracks whether the decode counter was actually bumped, so a failure
        # before that point does not drive the counter negative in `finally`.
        decode_counted = False
        try:
            async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
                resp.raise_for_status()
                inst.ongoing_decode_tokens += input_length
                decode_counted = True
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
            # Only a fully relayed response marks the prompt as cached here.
            inst.record_prefix(token_ids)
        finally:
            inst.ongoing_tokens -= input_length
            if decode_counted:
                inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    # Mirror the backend: SSE only when the client asked for streaming.
    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)
|
|
|
|
|
|
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
                                p_inst, d_inst, breakdown):
    """HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.

    The caller has already reserved ``input_length`` on
    ``p_inst.ongoing_tokens`` and incremented ``_offload_inflight``; both are
    released exactly once here, when the prefill step finishes (successfully
    or not).

    Raises:
        HTTPException: 502 when the prefill request fails.
    """
    global _offload_inflight
    request_id = headers.get("X-Request-Id", "")

    # Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        prefill_data = req_data.copy()
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": "xfer-" + request_id,
        }
        # Prefill only: one non-streamed token is enough to build the KV.
        prefill_data["stream"] = False
        prefill_data["max_tokens"] = 1
        prefill_data.pop("max_completion_tokens", None)
        prefill_data.pop("stream_options", None)

        p_headers = {**headers, "X-data-parallel-rank": "0"}
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
        breakdown["t_prefill_done"] = _time.monotonic()
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["error"] = str(e)
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502, detail="Prefill failed: %s" % e) from e
    finally:
        # Single release point: decrementing in `except` as well would free
        # two slots per failure and let the in-flight cap be exceeded.
        p_inst.ongoing_tokens -= input_length
        _offload_inflight = max(0, _offload_inflight - 1)

    # Step 2: Stream decode on d_inst (pulls KV from Mooncake)
    d_inst.ongoing_tokens += input_length
    d_inst.ongoing_decode_tokens += input_length
    breakdown["t_decode_sent"] = _time.monotonic()

    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
    }

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
            # NOTE(review): the prefix is recorded on p_inst only. If d_inst
            # keeps the transferred KV in its prefix cache, recording it for
            # d_inst too would stop later turns of the same session from
            # being classified HEAVY again - confirm backend behaviour.
        finally:
            d_inst.ongoing_tokens -= input_length
            d_inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    # Mirror the backend: SSE only when the client asked for streaming.
    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)
|
|
|
|
|
|
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
                              input_length, breakdown):
    """Run a prefill request in the background without propagating failures.

    Posts ``prefill_data`` to ``p_inst``, stamps ``t_prefill_done`` on
    ``breakdown`` and, on success, records the prompt prefix for ``p_inst``.
    Any exception is swallowed and flagged as ``prefill_error`` instead.
    The ``input_length`` reserved on ``p_inst.ongoing_tokens`` is always
    released.
    """
    try:
        prefill_resp = await p_inst.client.post(
            api,
            json=prefill_data,
            headers=p_headers,
        )
        breakdown["t_prefill_done"] = _time.monotonic()
        prefill_resp.raise_for_status()
        await prefill_resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception:
        # Best effort: nobody awaits this coroutine's result, so only mark it.
        breakdown.update(t_prefill_done=_time.monotonic(), prefill_error=True)
    finally:
        p_inst.ongoing_tokens -= input_length
|
|
|
|
|
|
# Strong references to in-flight fire-and-forget prefill tasks. The event loop
# keeps only weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_background_prefill_tasks: set = set()


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Picks a prefill instance via ``pick_instance`` and the decode instance
    with the fewest ongoing tokens, runs a one-token non-streamed prefill
    (awaited, or as a background task with ``--fire-and-forget``), then
    streams the decode response back to the client.

    Raises:
        HTTPException: 503 when no decode instance is configured, 502 when an
            awaited prefill fails.
    """
    if not decode_instances:
        raise HTTPException(status_code=503, detail="No decode instances configured")

    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }

    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    # Prefill only: one non-streamed token is enough to build the KV.
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}

    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()

    if global_args.fire_and_forget:
        # _send_prefill_async releases the reservation on p_inst itself.
        task = asyncio.create_task(_send_prefill_async(
            p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
        _background_prefill_tasks.add(task)
        task.add_done_callback(_background_prefill_tasks.discard)
    else:
        try:
            resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
            breakdown["t_prefill_done"] = _time.monotonic()
            resp.raise_for_status()
            await resp.aclose()
            p_inst.record_prefix(token_ids)
        except Exception as e:
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
            _breakdown_log.append(breakdown)
            raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") from e
        finally:
            p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }

    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    # Mirror the backend: SSE only when the client asked for streaming.
    media_type = "text/event-stream" if req_data.get("stream") else "application/json"
    return StreamingResponse(generate(), media_type=media_type)
|
|
|
|
|
|
@app.get("/breakdown")
async def get_breakdown():
    """Expose the accumulated per-request breakdown records for analysis."""
    records = _breakdown_log
    return records
|
|
|
|
|
|
def parse_args():
    """Parse and validate the scheduler's command-line arguments.

    Besides the raw argparse fields, the returned namespace carries:
      * ``prefill``: list of ``(url, bootstrap_port_or_None)`` tuples built
        from each ``--prefill URL [PORT]`` occurrence (``none`` means no port).
      * ``decode``: list of decode URLs, one per ``--decode`` occurrence.

    Exits through ``ArgumentParser.error`` when neither ``--combined`` nor
    ``--prefill`` is given, when a bootstrap port is not an integer, or when
    PD-sep mode is selected without any ``--decode`` instance.
    """
    parser = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    parser.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                        help="PD-Sep prefill: URL [bootstrap_port]")
    parser.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                        help="PD-Sep decode: URL")
    parser.add_argument("--fire-and-forget", action="store_true",
                        help="Send prefill async, don't await before decode")
    parser.add_argument("--heavy-threshold", type=int, default=20000,
                        help="New tokens threshold for HEAVY classification (adaptive offload)")
    parser.add_argument("--offload", action="store_true",
                        help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    parser.add_argument("--bootstrap-ports", type=str, default="",
                        help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    args = parser.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report as a usage error instead of an uncaught traceback.
                parser.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        parser.error("Must specify either --combined or --prefill/--decode")
    # --combined takes precedence at startup, so decode instances are only
    # mandatory when PD-sep mode will actually be used.
    if not args.combined and not args.decode:
        parser.error("PD-Sep mode requires at least one --decode instance")
    return args
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, apply the HEAVY cut-off override at
    # module level, then serve the app on the requested host and port.
    global_args = parse_args()
    HEAVY_THRESHOLD = global_args.heavy_threshold
    uvicorn.run(
        app,
        host=global_args.host,
        port=global_args.port,
    )
|