"""Unified cache-aware + token-level load-balanced global scheduler. Supports two modes: --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer) --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer) Routing policies (--policy): linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens lmetric: score = P_tokens * BS (LMetric, OSDI'26) P_tokens = pending_prefill_tokens + new_uncached_tokens BS = num_requests (waiting + running) Session affinity: multi-turn sessions stick to same instance (all policies). """ import argparse import asyncio import os import time as _time import urllib.parse import uuid from contextlib import asynccontextmanager import httpx import uvicorn from fastapi import FastAPI, HTTPException, Request from fastapi.responses import StreamingResponse BLOCK_SIZE = 512 CACHE_HIT_ALPHA = 1.0 HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold OVERLOAD_FACTOR = 2.0 class InstanceState: def __init__(self, url: str, bootstrap_port: int | None = None): self.url = url self.bootstrap_port = bootstrap_port self.client = httpx.AsyncClient( timeout=None, base_url=url, limits=httpx.Limits(max_connections=None, max_keepalive_connections=None), ) self.ongoing_tokens = 0 self.ongoing_decode_tokens = 0 # subset: tokens in decode phase self.pending_prefill_tokens = 0 # tokens for requests still in prefill self.num_requests = 0 # total in-flight requests (waiting + running) self.engine_id: dict[int, str] = {} self.dp_size = 1 self.cached_blocks: set[int] = set() def estimate_cache_hit(self, token_ids: list[int] | None) -> int: if not token_ids or len(token_ids) < BLOCK_SIZE: return 0 hit = 0 for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): bh = hash(tuple(token_ids[i:i + BLOCK_SIZE])) if bh in self.cached_blocks: hit += BLOCK_SIZE else: break return hit def record_prefix(self, token_ids: list[int] | None): if not token_ids: return for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE): self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE]))) if len(self.cached_blocks) > 200000: self.cached_blocks = set(list(self.cached_blocks)[-100000:]) # Cumulative token load per instance (for balanced session placement) _inst_cumulative_tokens: list[int] = [] def pick_instance(instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int]) -> tuple[InstanceState, int]: """Session-sticky with load-aware override. Turn 2+: use session affinity UNLESS pinned instance is overloaded (ongoing_tokens > 2x average), in which case pick least-loaded. Turn 1: pick instance with best score (load + cache combined). """ global _inst_cumulative_tokens if not _inst_cumulative_tokens: _inst_cumulative_tokens = [0] * len(instances) avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0) # Session affinity for turn 2+ (with load override) if session_id and session_id in affinity: idx = affinity[session_id] if idx < len(instances): inst = instances[idx] # Stick if not overloaded if inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR: return inst, idx # Overloaded: fall through to score-based selection # Score = ongoing_tokens - ALPHA * cache_hit_tokens # Balances load (lower is better) with cache affinity (higher hit is better) best_idx, best_score = 0, float("inf") for i, inst in enumerate(instances): cache_hit = inst.estimate_cache_hit(token_ids) score = inst.ongoing_tokens - CACHE_HIT_ALPHA * cache_hit if score < best_score: best_score = score best_idx = i _inst_cumulative_tokens[best_idx] += input_length if session_id: affinity[session_id] = best_idx return instances[best_idx], best_idx def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None, session_id: str | None, input_length: int, affinity: dict[str, int]) -> tuple[InstanceState, int]: """LMetric routing: score = P_tokens × BS (OSDI'26). P_tokens = pending_prefill_tokens on instance + new request's uncached tokens. BS = num_requests on instance + 1 (counting the new request). """ avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0) if session_id and session_id in affinity: idx = affinity[session_id] if idx < len(instances): inst = instances[idx] if inst.ongoing_tokens <= avg_load * OVERLOAD_FACTOR: return inst, idx best_idx, best_score = 0, float("inf") for i, inst in enumerate(instances): cache_hit = inst.estimate_cache_hit(token_ids) new_prefill = max(0, input_length - cache_hit) p_tokens = inst.pending_prefill_tokens + new_prefill bs = inst.num_requests + 1 score = p_tokens * bs if score < best_score: best_score = score best_idx = i if session_id: affinity[session_id] = best_idx return instances[best_idx], best_idx global_args = None combined_instances: list[InstanceState] = [] prefill_instances: list[InstanceState] = [] decode_instances: list[InstanceState] = [] session_affinity: dict[str, int] = {} is_pd_sep = False _breakdown_log: list[dict] = [] _offload_inflight = 0 # number of currently in-flight offloaded HEAVY requests MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent offloads to prevent P overload _p_round_robin_idx = 0 # round-robin counter for P-instance selection async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event): for inst in instances: if inst.bootstrap_port is None: continue while True: try: await inst.client.get("/health") except Exception: await asyncio.sleep(1) continue parsed = urllib.parse.urlparse(str(inst.client.base_url)) url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query" resp = await inst.client.get(url) resp.raise_for_status() data = resp.json() for dp_rank, dp_entry in data.items(): inst.engine_id[int(dp_rank)] = dp_entry["engine_id"] inst.dp_size = len(data) print(f"Inited {inst.url} engine_ids={inst.engine_id}") break ready.set() @asynccontextmanager async def lifespan(app: FastAPI): global is_pd_sep app.state.ready = asyncio.Event() if global_args.combined: is_pd_sep = False bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else [] for i, url in enumerate(global_args.combined): bp = bp_list[i] if i < len(bp_list) else None combined_instances.append(InstanceState(url, bp)) if global_args.offload and bp_list: await init_prefill_bootstrap(combined_instances, app.state.ready) else: app.state.ready.set() policy = getattr(global_args, 'policy', 'linear') print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}") else: is_pd_sep = True for url, bp in global_args.prefill: prefill_instances.append(InstanceState(url, bp)) for url in global_args.decode: decode_instances.append(InstanceState(url)) await init_prefill_bootstrap(prefill_instances, app.state.ready) print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D") yield for inst in combined_instances + prefill_instances + decode_instances: await inst.client.aclose() app = FastAPI(lifespan=lifespan) @app.post("/v1/completions") async def handle_completions(request: Request): return await _handle(request, "/v1/completions") @app.post("/v1/chat/completions") async def handle_chat(request: Request): return await _handle(request, "/v1/chat/completions") async def _handle(request: Request, api: str): if not app.state.ready.is_set(): raise HTTPException(status_code=503, detail="Service Unavailable") req_data = await request.json() request_id = str(uuid.uuid4()) prompt = req_data.get("prompt") token_ids = prompt if isinstance(prompt, list) else None input_length = len(token_ids) if token_ids else 0 session_id = request.headers.get("X-Session-Id") headers = {"X-Request-Id": request_id} api_key = os.environ.get("OPENAI_API_KEY") if api_key: headers["Authorization"] = f"Bearer {api_key}" if is_pd_sep: return await _handle_pd_sep(api, req_data, request_id, token_ids, input_length, session_id, headers) else: return await _handle_combined(api, req_data, token_ids, input_length, session_id, headers) async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers): """Combined mode with adaptive prefill offload (v2). WARM/MEDIUM: route to best instance, co-located P+D (no KV transfer). HEAVY (kv_both mode): P on least-loaded instance, KV via Mooncake, D on session-sticky instance. Only works if instances have kv_role=kv_both. Falls back to co-located if --no-offload or instances lack Mooncake. """ policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear' picker = pick_instance_lmetric if policy == 'lmetric' else pick_instance best_inst, best_idx = picker(combined_instances, token_ids, session_id, input_length, session_affinity) cache_hit = best_inst.estimate_cache_hit(token_ids) estimated_new = max(0, input_length - cache_hit) breakdown = { "request_id": headers.get("X-Request-Id", ""), "input_length": input_length, "estimated_new_tokens": estimated_new, "cache_hit": cache_hit, "t_proxy_recv": _time.monotonic(), } offload_enabled = getattr(global_args, 'offload', False) if global_args else False has_bootstrap = any(inst.bootstrap_port for inst in combined_instances) # Elastic offload decision: offload only when it helps use_offload = False offload_reason = "disabled" if estimated_new >= HEAVY_THRESHOLD and offload_enabled and has_bootstrap and len(combined_instances) >= 2: d_inst = best_inst p_candidates = [(i, inst) for i, inst in enumerate(combined_instances) if inst is not d_inst] avg_load = max(sum(i.ongoing_tokens for i in combined_instances) / len(combined_instances), 1.0) # Round-robin P selection with overload skip (spreads P-role evenly) global _offload_inflight, _p_round_robin_idx p_inst = None for _ in range(len(p_candidates)): _p_round_robin_idx = (_p_round_robin_idx + 1) % len(p_candidates) candidate = p_candidates[_p_round_robin_idx][1] if candidate.ongoing_tokens < avg_load * OVERLOAD_FACTOR: p_inst = candidate break if p_inst is None: p_inst = min(p_candidates, key=lambda x: x[1].ongoing_tokens)[1] if _offload_inflight >= MAX_OFFLOAD_INFLIGHT: offload_reason = "max_concurrent_reached" elif p_inst.ongoing_tokens >= HEAVY_THRESHOLD * 2: offload_reason = "p_saturated" else: use_offload = True offload_reason = "offload_accepted" _offload_inflight += 1 if use_offload: d_idx = best_idx p_inst.ongoing_tokens += input_length # reserve immediately p_inst.pending_prefill_tokens += estimated_new p_inst.num_requests += 1 breakdown["route_class"] = "HEAVY_P2P" breakdown["offload_reason"] = offload_reason breakdown["p_inst"] = p_inst.url breakdown["d_inst"] = d_inst.url breakdown["p_load"] = p_inst.ongoing_tokens breakdown["d_load"] = d_inst.ongoing_tokens if session_id: session_affinity[session_id] = d_idx return await _handle_heavy_offload(api, req_data, headers, token_ids, input_length, p_inst, d_inst, breakdown) else: if estimated_new >= HEAVY_THRESHOLD: breakdown["route_class"] = "HEAVY_COLO" breakdown["offload_reason"] = offload_reason else: breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM" inst = best_inst breakdown["routed_to"] = inst.url breakdown["policy"] = policy inst.ongoing_tokens += input_length inst.pending_prefill_tokens += estimated_new inst.num_requests += 1 async def generate(): prefill_done = False try: async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(): if not prefill_done: inst.pending_prefill_tokens -= estimated_new inst.ongoing_decode_tokens += input_length breakdown["t_first_token"] = _time.monotonic() prefill_done = True yield chunk inst.record_prefix(token_ids) finally: if not prefill_done: inst.pending_prefill_tokens -= estimated_new else: inst.ongoing_decode_tokens -= input_length inst.ongoing_tokens -= input_length inst.num_requests -= 1 breakdown["t_done"] = _time.monotonic() _breakdown_log.append(breakdown) return StreamingResponse(generate(), media_type="text/event-stream") async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length, p_inst, d_inst, breakdown): """HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.""" request_id = headers.get("X-Request-Id", "") # Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller) breakdown["t_prefill_sent"] = _time.monotonic() try: prefill_data = req_data.copy() prefill_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, "transfer_id": "xfer-" + request_id, } prefill_data["stream"] = False prefill_data["max_tokens"] = 1 prefill_data.pop("max_completion_tokens", None) prefill_data.pop("stream_options", None) p_headers = {**headers, "X-data-parallel-rank": "0"} resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) resp.raise_for_status() await resp.aclose() p_inst.record_prefix(token_ids) breakdown["t_prefill_done"] = _time.monotonic() except Exception as e: breakdown["t_prefill_done"] = _time.monotonic() breakdown["error"] = str(e) _breakdown_log.append(breakdown) global _offload_inflight _offload_inflight = max(0, _offload_inflight - 1) p_inst.num_requests -= 1 raise HTTPException(status_code=502, detail="Prefill failed: %s" % e) finally: p_inst.ongoing_tokens -= input_length p_inst.pending_prefill_tokens -= breakdown.get("estimated_new_tokens", 0) _offload_inflight = max(0, _offload_inflight - 1) p_inst.num_requests -= 1 # Step 2: Stream decode on d_inst (pulls KV from Mooncake) d_inst.ongoing_tokens += input_length d_inst.ongoing_decode_tokens += input_length d_inst.num_requests += 1 breakdown["t_decode_sent"] = _time.monotonic() parsed = urllib.parse.urlparse(str(p_inst.client.base_url)) bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port) decode_data = req_data.copy() decode_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, "remote_bootstrap_addr": bootstrap_addr, "remote_engine_id": p_inst.engine_id.get(0, ""), "transfer_id": "xfer-" + request_id, } async def generate(): first_token = True try: async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(): if first_token: breakdown["t_first_token"] = _time.monotonic() first_token = False yield chunk finally: d_inst.ongoing_tokens -= input_length d_inst.ongoing_decode_tokens -= input_length d_inst.num_requests -= 1 breakdown["t_done"] = _time.monotonic() _breakdown_log.append(breakdown) return StreamingResponse(generate(), media_type="application/json") async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown): """Fire-and-forget prefill: send and don't block caller.""" try: resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) breakdown["t_prefill_done"] = _time.monotonic() resp.raise_for_status() await resp.aclose() p_inst.record_prefix(token_ids) except Exception: breakdown["t_prefill_done"] = _time.monotonic() breakdown["prefill_error"] = True finally: p_inst.ongoing_tokens -= input_length async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length, session_id, headers): """PD-Sep mode with per-stage breakdown profiling.""" breakdown = { "request_id": request_id, "input_length": input_length, "t_proxy_recv": _time.monotonic(), } p_inst, _ = pick_instance(prefill_instances, token_ids, session_id, input_length, session_affinity) d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens) breakdown["p_inst"] = p_inst.url breakdown["d_inst"] = d_inst.url prefill_data = req_data.copy() prefill_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, "transfer_id": f"xfer-{request_id}", } prefill_data["stream"] = False prefill_data["max_tokens"] = 1 prefill_data.pop("max_completion_tokens", None) prefill_data.pop("stream_options", None) p_headers = {**headers, "X-data-parallel-rank": "0"} p_inst.ongoing_tokens += input_length breakdown["t_prefill_sent"] = _time.monotonic() if global_args.fire_and_forget: asyncio.create_task(_send_prefill_async( p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown)) else: try: resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers) breakdown["t_prefill_done"] = _time.monotonic() resp.raise_for_status() await resp.aclose() p_inst.record_prefix(token_ids) except Exception as e: breakdown["t_prefill_done"] = _time.monotonic() breakdown["prefill_error"] = True _breakdown_log.append(breakdown) raise HTTPException(status_code=502, detail=f"Prefill failed: {e}") finally: p_inst.ongoing_tokens -= input_length # Send decode d_inst.ongoing_tokens += input_length parsed = urllib.parse.urlparse(str(p_inst.client.base_url)) bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}" decode_data = req_data.copy() decode_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, "remote_bootstrap_addr": bootstrap_addr, "remote_engine_id": p_inst.engine_id.get(0, ""), "transfer_id": f"xfer-{request_id}", } breakdown["t_decode_sent"] = _time.monotonic() async def generate(): first_token = True try: async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp: resp.raise_for_status() async for chunk in resp.aiter_bytes(): if first_token: breakdown["t_first_token"] = _time.monotonic() first_token = False yield chunk finally: breakdown["t_done"] = _time.monotonic() d_inst.ongoing_tokens -= input_length _breakdown_log.append(breakdown) return StreamingResponse(generate(), media_type="application/json") @app.get("/breakdown") async def get_breakdown(): """Return per-request breakdown data for analysis.""" return _breakdown_log @app.get("/stats") async def get_stats(): """Return per-instance live state for debugging.""" instances = combined_instances or prefill_instances + decode_instances return [{ "url": inst.url, "ongoing_tokens": inst.ongoing_tokens, "pending_prefill_tokens": inst.pending_prefill_tokens, "ongoing_decode_tokens": inst.ongoing_decode_tokens, "num_requests": inst.num_requests, "cached_blocks": len(inst.cached_blocks), } for inst in instances] def parse_args(): p = argparse.ArgumentParser(description="Unified cache-aware global scheduler") p.add_argument("--port", type=int, default=8000) p.add_argument("--host", type=str, default="0.0.0.0") p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs") p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw", help="PD-Sep prefill: URL [bootstrap_port]") p.add_argument("--decode", nargs=1, action="append", dest="decode_raw", help="PD-Sep decode: URL") p.add_argument("--fire-and-forget", action="store_true", help="Send prefill async, don't await before decode") p.add_argument("--heavy-threshold", type=int, default=20000, help="New tokens threshold for HEAVY classification (adaptive offload)") p.add_argument("--offload", action="store_true", help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)") p.add_argument("--bootstrap-ports", type=str, default="", help="Comma-separated bootstrap ports for combined instances (for offload mode)") p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"], help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)") args = p.parse_args() args.prefill = [] if args.prefill_raw: for entry in args.prefill_raw: url = entry[0] bp = int(entry[1]) if len(entry) > 1 and entry[1].lower() != "none" else None args.prefill.append((url, bp)) args.decode = [e[0] for e in (args.decode_raw or [])] if not args.combined and not args.prefill: p.error("Must specify either --combined or --prefill/--decode") return args if __name__ == "__main__": global_args = parse_args() HEAVY_THRESHOLD = global_args.heavy_threshold uvicorn.run(app, host=global_args.host, port=global_args.port)