"""Unified cache-aware + token-level load-balanced global scheduler.

Supports two modes:
  --combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
  --prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)

Routing policies (--policy, combined mode):
  linear (default): score = ongoing_tokens + offload_penalty - ALPHA * cache_hit_tokens
  lmetric: score = P_tokens * BS (LMetric, OSDI'26)
      P_tokens = pending_prefill_tokens + new_uncached_tokens
      BS = num_requests (waiting + running)
  unified: stay on the session's previous instance when its estimated cache
      hit covers more than half of the prompt and it is not overloaded;
      otherwise fall back to LMetric with a multi-key tie-breaker.

In PD-disaggregated mode the prefill instance is chosen with the linear
policy and the decode instance is the one with the fewest ongoing tokens.

Session affinity (X-Session-Id header): linear and unified keep multi-turn
sessions on the same instance unless it is overloaded (linear also leaves an
instance that is busy with P-role offloads); lmetric ignores affinity.
"""

import argparse
import asyncio
import json
import os
import time as _time
import urllib.parse
import uuid
from collections import OrderedDict
from collections.abc import Iterator
from contextlib import asynccontextmanager
from dataclasses import dataclass

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse

# Attempts to open an upstream stream; retries only happen before the first
# chunk has been forwarded to the client.
MAX_STREAM_RETRIES = 3
RETRY_DELAY_S = 0.5

# Token granularity of the shadow prefix cache.
BLOCK_SIZE = 512
# Weight of estimated cache-hit tokens in the linear routing score.
CACHE_HIT_ALPHA = 1.0


@dataclass
class Settings:
    """Runtime-tunable knobs. Populated from argparse in __main__.

    All routing/offload code reads from the SETTINGS singleton so that CLI
    overrides survive even when the module is imported as a library
    (e.g. by tests/) and __main__ does not run.
    """
    prefill_throughput: float = 7000.0  # tokens/s per GPU (measured on H20)
    rdma_overhead_s: float = 0.1  # RDMA PUSH overhead (~10-50ms measured)
    cache_capacity_blocks: int = 200000  # per-instance LRU cap on shadow cached_blocks
    heavy_threshold: int = 20000
    overload_factor: float = 2.0
    max_offload_inflight: int = 4
    cache_gate_ratio: float = 0.0
    decode_iteration_s: float = 0.05  # per-request decode iteration cost (H20)


SETTINGS = Settings()


def _iter_block_hashes(token_ids: list[int]) -> Iterator[int]:
    """Yield one chained hash per full BLOCK_SIZE block of *token_ids*.

    Each hash covers the block's tokens AND the hash of the preceding block,
    so a block only matches when its entire prefix matches too (the same
    requirement a real prefix cache has). A trailing partial block is ignored.
    Hashes are only meaningful within one process.
    """
    parent: int | None = None
    for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
        parent = hash((parent, tuple(token_ids[start:start + BLOCK_SIZE])))
        yield parent


class InstanceState:
    """Router-side shadow state for one upstream vLLM instance.

    Holds the HTTP client plus load counters that the request handlers
    increment/decrement around each request, and an LRU approximation of the
    instance's prefix cache.
    """

    def __init__(self, url: str, bootstrap_port: int | None = None):
        self.url = url
        self.bootstrap_port = bootstrap_port
        self.client = httpx.AsyncClient(
            timeout=None,
            base_url=url,
            limits=httpx.Limits(max_connections=None,
                                max_keepalive_connections=None),
        )
        self.ongoing_tokens = 0
        self.ongoing_decode_tokens = 0  # subset: tokens in decode phase
        self.pending_prefill_tokens = 0  # tokens for requests still in prefill
        self.num_requests = 0  # total in-flight requests (waiting + running)
        self.active_p_offloads = 0  # number of HEAVY prefills this instance is doing for others
        self.engine_id: dict[int, str] = {}
        self.dp_size = 1
        # OrderedDict acts as an LRU keyed by chained block hash; value is unused.
        self.cached_blocks: OrderedDict[int, None] = OrderedDict()

    def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
        """Return the number of leading prompt tokens believed to be cached.

        Walks the prompt's full blocks in order and stops at the first block
        not in the shadow cache; matched blocks are moved to the LRU tail.
        Returns 0 for missing prompts or prompts shorter than one block.
        """
        if not token_ids or len(token_ids) < BLOCK_SIZE:
            return 0
        hit = 0
        for bh in _iter_block_hashes(token_ids):
            if bh not in self.cached_blocks:
                break
            self.cached_blocks.move_to_end(bh)  # LRU touch on hit
            hit += BLOCK_SIZE
        return hit

    def record_prefix(self, token_ids: list[int] | None):
        """Insert (or refresh) every full block of *token_ids* in the shadow cache.

        Evicts least-recently-used blocks once the cache exceeds
        SETTINGS.cache_capacity_blocks.
        """
        if not token_ids:
            return
        for bh in _iter_block_hashes(token_ids):
            if bh in self.cached_blocks:
                self.cached_blocks.move_to_end(bh)
            else:
                self.cached_blocks[bh] = None
                if len(self.cached_blocks) > SETTINGS.cache_capacity_blocks:
                    self.cached_blocks.popitem(last=False)


def _p_offload_penalty(inst: InstanceState) -> int:
    """Penalty for PD-sep mode routing (legacy)."""
    if inst.active_p_offloads <= 0:
        return 0
    return inst.active_p_offloads * SETTINGS.heavy_threshold


def snapshot_workers(
    instances: list[InstanceState],
    token_ids: list[int] | None = None,
    input_length: int = 0,
) -> list[dict]:
    """Per-worker state at route-decision time.

    All routing-relevant counters plus the score each policy would have
    produced for `input_length` if it were dispatched now. Cheap enough to
    call on every request; B3 hot-spot analysis depends on this being
    captured per decision.
    """
    snap: list[dict] = []
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids) if token_ids else 0
        new_prefill = max(0, input_length - cache_hit)
        snap.append({
            "idx": i,
            "url": inst.url,
            "ongoing_tokens": inst.ongoing_tokens,
            "ongoing_decode_tokens": inst.ongoing_decode_tokens,
            "pending_prefill_tokens": inst.pending_prefill_tokens,
            "num_requests": inst.num_requests,
            "active_p_offloads": inst.active_p_offloads,
            "cached_blocks": len(inst.cached_blocks),
            "cache_hit": cache_hit,
            "new_prefill": new_prefill,
            "score_linear": (inst.ongoing_tokens + _p_offload_penalty(inst)
                             - CACHE_HIT_ALPHA * cache_hit),
            "score_lmetric": (inst.pending_prefill_tokens + new_prefill)
                             * inst.num_requests,
        })
    return snap


def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Session-sticky with load-aware override.

    Turn 2+: use session affinity UNLESS pinned instance is overloaded
    (ongoing_tokens > avg * SETTINGS.overload_factor) or busy with P-role
    offloads, in which case pick the best-scoring instance.
    Turn 1: pick instance with best score (load + cache combined).
    Instances doing P-role offloads get a large penalty to steer
    WARM/MEDIUM traffic away. `input_length` is unused (signature parity
    with the other pickers). Records the choice in *affinity*.
    """
    avg_load = max(sum(i.ongoing_tokens for i in instances) / len(instances), 1.0)
    if session_id and session_id in affinity:
        idx = affinity[session_id]
        if idx < len(instances):
            inst = instances[idx]
            if (inst.ongoing_tokens <= avg_load * SETTINGS.overload_factor
                    and inst.active_p_offloads == 0):
                return inst, idx
    best_idx, best_score = 0, float("inf")
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        score = (inst.ongoing_tokens + _p_offload_penalty(inst)
                 - CACHE_HIT_ALPHA * cache_hit)
        if score < best_score:
            best_score = score
            best_idx = i
    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx


def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
                          session_id: str | None, input_length: int,
                          affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """LMetric routing: score = P_tokens × BS (OSDI'26).

    Pure per-request load-based routing, no session affinity (the
    session_id/affinity args are accepted for signature compatibility with
    pick_instance/pick_instance_unified_hybrid but ignored).

    P = pending_prefill_tokens + (input_length - cache_hit)
    BS = num_requests (current batch size)
    Ties keep the lowest index.
    """
    best_idx, best_score = 0, float("inf")
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        new_prefill = max(0, input_length - cache_hit)
        p_tokens = inst.pending_prefill_tokens + new_prefill
        bs = inst.num_requests
        score = p_tokens * bs
        if score < best_score:
            best_score = score
            best_idx = i
    return instances[best_idx], best_idx


# Rotates among tied candidates in pick_instance_unified_hybrid.
_unified_fallback_rr_counter = 0


def pick_instance_unified_hybrid(
    instances: list[InstanceState],
    token_ids: list[int] | None,
    session_id: str | None,
    input_length: int,
    affinity: dict[str, int],
) -> tuple[InstanceState, int, dict]:
    """Hybrid routing: high-cache affinity, else LMetric with tie-breaker.

    Affinity gate (both must hold to stick):
      - affinity instance cache_hit / input_length > 0.5
      - affinity.num_requests <= avg_num_requests * SETTINGS.overload_factor

    Fallback ordering (when affinity not used):
      primary:    score = P_tokens * BS (LMetric)
      secondary:  new_uncached_tokens (prefer instance with most cache)
      tertiary:   num_requests (prefer least-loaded)
      quaternary: round-robin (avoid degenerate inst-0 pinning when BS=0
                  across the board)

    Does not write *affinity*; the caller records the chosen index.
    Returns (chosen, idx, decision_dict). decision_dict carries the review #7
    breakdown fields so the caller can merge them verbatim.
    """
    global _unified_fallback_rr_counter
    n = len(instances)
    avg_reqs = max(sum(i.num_requests for i in instances) / n, 1.0)
    decision: dict = {
        "decision": "lmetric_fallback",
        "affinity_idx": None,
        "chosen_idx": None,
        "affinity_cache_hit": None,
        "affinity_cache_ratio": None,
        "affinity_num_requests": None,
        "avg_num_requests": avg_reqs,
        "fallback_score": None,
        "tie_break_used": False,
    }

    if session_id and session_id in affinity:
        a_idx = affinity[session_id]
        if a_idx < n:
            a_inst = instances[a_idx]
            a_hit = a_inst.estimate_cache_hit(token_ids)
            a_ratio = a_hit / max(input_length, 1)
            decision["affinity_idx"] = a_idx
            decision["affinity_cache_hit"] = a_hit
            decision["affinity_cache_ratio"] = a_ratio
            decision["affinity_num_requests"] = a_inst.num_requests
            if (a_ratio > 0.5
                    and a_inst.num_requests <= avg_reqs * SETTINGS.overload_factor):
                decision["decision"] = "affinity"
                decision["chosen_idx"] = a_idx
                return a_inst, a_idx, decision

    # (score, new_prefill, num_requests, idx): compared lexicographically on
    # the first three fields; idx only identifies the candidate.
    keys: list[tuple[int, int, int, int]] = []
    for i, inst in enumerate(instances):
        cache_hit = inst.estimate_cache_hit(token_ids)
        new_prefill = max(0, input_length - cache_hit)
        p_tokens = inst.pending_prefill_tokens + new_prefill
        bs = inst.num_requests
        score = p_tokens * bs
        keys.append((score, new_prefill, bs, i))

    best_triple = min(k[:3] for k in keys)
    tied = [k for k in keys if k[:3] == best_triple]
    if len(tied) > 1:
        decision["tie_break_used"] = True
        _unified_fallback_rr_counter += 1
        winner = tied[_unified_fallback_rr_counter % len(tied)]
    else:
        winner = tied[0]
    chosen_idx = winner[3]
    decision["fallback_score"] = winner[0]
    decision["chosen_idx"] = chosen_idx
    return instances[chosen_idx], chosen_idx, decision


def _extract_output_token_ids_from_sse(
    buffer: str,
    chunk: bytes,
) -> tuple[str, list[int]]:
    """Extract vLLM streaming token_ids while preserving the raw stream.

    *buffer* carries an incomplete trailing line between calls. Returns the
    new carry-over buffer and the integer `choices[*].token_ids` found in the
    complete `data:` lines; `[DONE]` and undecodable JSON lines are skipped.
    """
    buffer += chunk.decode("utf-8", errors="ignore")
    complete = buffer.endswith("\n") or buffer.endswith("\r")
    lines = buffer.splitlines()
    if complete:
        buffer = ""
    elif lines:
        # Last line is partial; keep it for the next chunk.
        buffer = lines.pop()
    else:
        return buffer, []

    output_ids: list[int] = []
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if not data or data == "[DONE]":
            continue
        try:
            payload = json.loads(data)
        except json.JSONDecodeError:
            continue
        choices = payload.get("choices", [])
        for choice in choices:
            token_ids = choice.get("token_ids")
            if isinstance(token_ids, list):
                output_ids.extend(
                    int(t) for t in token_ids if isinstance(t, int)
                )
    return buffer, output_ids


def _realized_tokens(
    prompt_token_ids: list[int] | None,
    output_token_ids: list[int],
) -> list[int] | None:
    """Return prompt + generated tokens (None when the prompt ids are unknown)."""
    if prompt_token_ids is None:
        return None
    if not output_token_ids:
        return prompt_token_ids
    return prompt_token_ids + output_token_ids


global_args = None
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# Session affinity is namespace-isolated: combined-mode and pd-sep mode index
# different instance lists, so a shared dict could mis-route after a mode switch.
session_affinity_combined: dict[str, int] = {}
session_affinity_prefill: dict[str, int] = {}
# Backwards-compat alias for session_affinity_combined (next statement).
# NOTE(review): /stats does not read it; presumably kept for external callers.
session_affinity = session_affinity_combined
# Set by lifespan(): True when running with --prefill/--decode instances.
is_pd_sep = False
# One dict per finished (or failed-prefill) request; served by /breakdown.
# NOTE(review): never trimmed, so it grows for the lifetime of the process.
_breakdown_log: list[dict] = []
# One snapshot per routing decision; served by /worker_state.
_worker_state_log: list[dict] = []


async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Fetch per-DP-rank engine ids from each instance's bootstrap server.

    Instances without a bootstrap port are skipped. For the others, /health
    is polled once per second until it returns a 2xx status, then the
    bootstrap server's /query endpoint on the same host is read once (its
    errors propagate). Sets *ready* when every instance is done.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue
        while True:
            try:
                health = await inst.client.get("/health")
            except Exception:
                await asyncio.sleep(1)
                continue
            if not health.is_success:
                # Listening but not ready yet (e.g. 503 while loading).
                await asyncio.sleep(1)
                continue
            parsed = urllib.parse.urlparse(str(inst.client.base_url))
            url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
            resp = await inst.client.get(url)
            resp.raise_for_status()
            data = resp.json()
            for dp_rank, dp_entry in data.items():
                inst.engine_id[int(dp_rank)] = dp_entry["engine_id"]
            inst.dp_size = len(data)
            print(f"Inited {inst.url} engine_ids={inst.engine_id}")
            break
    ready.set()


_RECONCILED_COUNTERS = (
    "ongoing_tokens",
    "ongoing_decode_tokens",
    "pending_prefill_tokens",
    "num_requests",
    "active_p_offloads",
)


async def _reconcile_loop():
    """Periodic safety net for shadow load counters.

    Every 60 s, clamp any load counter that has gone negative back to zero
    so router scores stay sane. This only guards against over-decrement; it
    cannot repair the opposite drift. StreamingResponse generators decrement
    counters in their finally block, and if a client disconnects before the
    body is consumed the generator is never entered, so that decrement is
    lost and the counter stays too HIGH. That still needs proper exact-state
    syncing with vLLM (see TODO.md item 6).
    """
    while True:
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            return
        for inst in combined_instances + prefill_instances + decode_instances:
            for name in _RECONCILED_COUNTERS:
                if getattr(inst, name) < 0:
                    setattr(inst, name, 0)


def _verify_vllm_patch():
    """Startup self-check for patches/0001-fix-kv-transfer-abort-race.patch.

    The patch turns an `assert req_id in self.requests` into a soft warn so
    that engines do not crash on the KV-transfer abort race (see REPORT §3.x).
    If somebody upgrades vLLM without re-applying the patch, the assert
    returns and elastic mode dies under load. Print a loud warning so we
    catch the regression before the first HEAVY request. Best-effort: any
    failure (e.g. vLLM not importable here) only prints a skip message.
    """
    try:
        import inspect
        from vllm.v1.core.sched.scheduler import Scheduler
        src = inspect.getsource(Scheduler)
        if "assert req_id in self.requests" in src:
            print("WARNING: vLLM scheduler still contains the unpatched "
                  "`assert req_id in self.requests` line; expect engine "
                  "death on KV-transfer abort race. Apply "
                  "patches/0001-fix-kv-transfer-abort-race.patch.")
        else:
            print("vLLM patch self-check: kv-transfer-abort assert is patched.")
    except Exception as exc:
        print(f"vLLM patch self-check skipped: {exc!r}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build instance state from global_args, bootstrap, and clean up on exit.

    Requires global_args to have been set (done in __main__). The ready event
    gates request handling in _handle().
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()
    _verify_vllm_patch()
    reconcile_task = asyncio.create_task(_reconcile_loop())
    if global_args.combined:
        is_pd_sep = False
        bp_list = ([int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()]
                   if global_args.bootstrap_ports else [])
        for i, url in enumerate(global_args.combined):
            bp = bp_list[i] if i < len(bp_list) else None
            combined_instances.append(InstanceState(url, bp))
        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        if global_args.offload and bp_list:
            await init_prefill_bootstrap(combined_instances, app.state.ready)
        else:
            app.state.ready.set()
        policy = getattr(global_args, 'policy', 'linear')
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")
    else:
        is_pd_sep = True
        for url, bp in global_args.prefill:
            prefill_instances.append(InstanceState(url, bp))
        for url in global_args.decode:
            decode_instances.append(InstanceState(url))
        await init_prefill_bootstrap(prefill_instances, app.state.ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
    yield
    reconcile_task.cancel()
    try:
        await reconcile_task
    except asyncio.CancelledError:
        pass
    for inst in combined_instances + prefill_instances + decode_instances:
        await inst.client.aclose()


app = FastAPI(lifespan=lifespan)


@app.post("/v1/completions")
async def handle_completions(request: Request):
    return await _handle(request, "/v1/completions")


@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    return await _handle(request, "/v1/chat/completions")


def _prompt_token_ids(prompt) -> list[int] | None:
    """Return *prompt* when it is a single pre-tokenized prompt, else None.

    Only a flat list of ints is usable for cache estimation. Text prompts and
    batched prompts (list of strings / list of token lists) return None;
    a nested list would otherwise be unhashable in the block-hash code.
    """
    if isinstance(prompt, list) and all(isinstance(t, int) for t in prompt):
        return prompt
    return None


async def _handle(request: Request, api: str):
    """Parse an OpenAI-style request and dispatch it to the active mode.

    Forwards X-Request-Id (generated if absent) and, when OPENAI_API_KEY is
    set, a Bearer Authorization header. X-Session-Id drives affinity.

    Raises:
        HTTPException(503): startup has not finished.
        HTTPException(400): the body is not a JSON object.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")
    try:
        req_data = await request.json()
    except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400,
                            detail=f"Invalid JSON body: {exc}") from exc
    if not isinstance(req_data, dict):
        raise HTTPException(status_code=400,
                            detail="Request body must be a JSON object")
    incoming_rid = request.headers.get("X-Request-Id")
    request_id = incoming_rid or str(uuid.uuid4())
    token_ids = _prompt_token_ids(req_data.get("prompt"))
    input_length = len(token_ids) if token_ids else 0
    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    if is_pd_sep:
        return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                    input_length, session_id, headers)
    else:
        return await _handle_combined(api, req_data, token_ids, input_length,
                                      session_id, headers)


async def _handle_local_request(api, req_data, headers, token_ids, input_length,
                                chosen: InstanceState, estimated_new: int,
                                breakdown: dict):
    """Stream *req_data* to *chosen* while maintaining its shadow counters.

    Counters are incremented immediately (so concurrent routing decisions
    see the load) and released in the generator's finally block. Connect /
    protocol errors before the first chunk are retried up to
    MAX_STREAM_RETRIES times; after a chunk was forwarded they propagate.
    """
    breakdown.setdefault("route_class", "LOCAL")
    breakdown.setdefault("routed_to", chosen.url)
    chosen.ongoing_tokens += input_length
    chosen.pending_prefill_tokens += estimated_new
    chosen.num_requests += 1

    async def generate():
        prefill_done = False
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            for attempt in range(MAX_STREAM_RETRIES):
                try:
                    async with chosen.client.stream("POST", api, json=req_data,
                                                    headers=headers) as resp:
                        resp.raise_for_status()
                        async for chunk in resp.aiter_bytes():
                            sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                                sse_buffer, chunk)
                            output_token_ids.extend(new_output_ids)
                            if not prefill_done:
                                # First chunk: move the request from prefill
                                # to decode accounting.
                                chosen.pending_prefill_tokens -= estimated_new
                                chosen.ongoing_decode_tokens += input_length
                                breakdown["t_first_token"] = _time.monotonic()
                                breakdown["t_first_token_unix"] = _time.time()
                                prefill_done = True
                            yield chunk
                        chosen.record_prefix(
                            _realized_tokens(token_ids, output_token_ids))
                    break
                except (httpx.ConnectError, httpx.RemoteProtocolError):
                    if prefill_done or attempt >= MAX_STREAM_RETRIES - 1:
                        raise
                    await asyncio.sleep(RETRY_DELAY_S)
        finally:
            if not prefill_done:
                chosen.pending_prefill_tokens -= estimated_new
            else:
                chosen.ongoing_decode_tokens -= input_length
            chosen.ongoing_tokens -= input_length
            chosen.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="text/event-stream")


async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Route a /v1/* request among combined (PD-colocated) instances.

    --policy options:
      linear:  cache_hit-aware load score + sticky session affinity.
      lmetric: P_tokens * BS (LMetric, OSDI'26). No session affinity.
      unified: hybrid — stick to affinity instance when cache_ratio > 0.5
               and it is not overloaded; otherwise fall back to LMetric
               with a multi-key tie-breaker.

    PD-sep offload / PUSH migration is retired (see REPORT.md §3.9 and
    commits 4c583f2 / cc6e562: relaxed-gate and forced-migration variants
    both regressed E2E tail). Re-enabling requires a new transfer mechanism.
    """
    policy = getattr(global_args, 'policy', 'linear')
    t_decision_unix = _time.time()
    request_id = headers.get("X-Request-Id", "")
    breakdown: dict = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": policy,
    }

    # Snapshot before picking so the logged scores reflect pre-decision state.
    pre_decision_workers = snapshot_workers(
        combined_instances, token_ids, input_length)

    if policy == "lmetric":
        chosen, best_idx = pick_instance_lmetric(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
    elif policy == "unified":
        chosen, best_idx, decision = pick_instance_unified_hybrid(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)
        breakdown.update(decision)
        if session_id:
            session_affinity_combined[session_id] = best_idx
    else:  # linear (default)
        chosen, best_idx = pick_instance(
            combined_instances, token_ids, session_id, input_length,
            session_affinity_combined)

    chosen_snap = pre_decision_workers[best_idx]
    cache_hit = chosen_snap["cache_hit"]
    estimated_new = chosen_snap["new_prefill"]

    breakdown.update({
        "cache_hit": cache_hit,
        "estimated_new_tokens": estimated_new,
        "route_class": "LOCAL",
        "routed_to": chosen.url,
        "chosen_idx": best_idx,
        "candidate_scores": pre_decision_workers,
        "chosen_score_linear": chosen_snap["score_linear"],
        "chosen_score_lmetric": chosen_snap["score_lmetric"],
    })
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": policy,
        "chosen_idx": best_idx,
        "workers": pre_decision_workers,
    })

    return await _handle_local_request(
        api, req_data, headers, token_ids, input_length, chosen,
        estimated_new, breakdown)


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Picks a prefill instance with pick_instance (session-sticky) and the
    decode instance with the fewest ongoing tokens. Runs a 1-token,
    non-streaming prefill request with do_remote_decode, then streams the
    decode request (do_remote_prefill) pointed at the prefill instance's
    bootstrap server.

    Raises:
        HTTPException(502): the prefill request failed.
    """
    t_decision_unix = _time.time()
    breakdown = {
        "request_id": request_id,
        "session_id": session_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
        "t_decision_unix": t_decision_unix,
        "policy": "pd_sep",
    }

    pre_decision_p = snapshot_workers(prefill_instances, token_ids, input_length)
    pre_decision_d = snapshot_workers(decode_instances, token_ids, input_length)

    p_inst, p_idx = pick_instance(prefill_instances, token_ids, session_id,
                                  input_length, session_affinity_prefill)
    d_idx = min(range(len(decode_instances)),
                key=lambda i: decode_instances[i].ongoing_tokens)
    d_inst = decode_instances[d_idx]
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url
    breakdown["candidate_scores_prefill"] = pre_decision_p
    breakdown["candidate_scores_decode"] = pre_decision_d
    breakdown["chosen_p_idx"] = p_idx
    breakdown["chosen_d_idx"] = d_idx
    _worker_state_log.append({
        "t_decision_unix": t_decision_unix,
        "request_id": request_id,
        "session_id": session_id,
        "policy": "pd_sep",
        "chosen_p_idx": p_idx,
        "chosen_d_idx": d_idx,
        "workers_prefill": pre_decision_p,
        "workers_decode": pre_decision_d,
    })

    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data["min_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)

    p_headers = {**headers, "X-data-parallel-rank": "0"}
    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    breakdown["t_prefill_sent_unix"] = _time.time()
    try:
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["t_prefill_done_unix"] = _time.time()
        breakdown["prefill_error"] = True
        _breakdown_log.append(breakdown)
        raise HTTPException(status_code=502,
                            detail=f"Prefill failed: {e}") from e
    finally:
        p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }

    breakdown["t_decode_sent"] = _time.monotonic()
    breakdown["t_decode_sent_unix"] = _time.time()

    async def generate():
        first_token = True
        sse_buffer = ""
        output_token_ids: list[int] = []
        try:
            async with d_inst.client.stream("POST", api, json=decode_data,
                                            headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    sse_buffer, new_output_ids = _extract_output_token_ids_from_sse(
                        sse_buffer, chunk)
                    output_token_ids.extend(new_output_ids)
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        breakdown["t_first_token_unix"] = _time.time()
                        first_token = False
                    yield chunk
                d_inst.record_prefix(_realized_tokens(token_ids, output_token_ids))
        finally:
            breakdown["t_done"] = _time.monotonic()
            breakdown["t_done_unix"] = _time.time()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="application/json")


@app.get("/breakdown")
async def get_breakdown():
    """Return per-request breakdown data for analysis."""
    return _breakdown_log


@app.get("/worker_state")
async def get_worker_state():
    """Return per-decision worker-state snapshot log (one entry per route decision)."""
    return _worker_state_log


@app.get("/worker_state/latest")
async def get_worker_state_latest():
    """Return current per-worker state snapshot without recording it."""
    if combined_instances:
        return {
            "t_unix": _time.time(),
            "mode": "combined",
            "workers": snapshot_workers(combined_instances),
        }
    return {
        "t_unix": _time.time(),
        "mode": "pd_sep",
        "workers_prefill": snapshot_workers(prefill_instances),
        "workers_decode": snapshot_workers(decode_instances),
    }


@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    `role` is "combined" in combined mode and "prefill" / "decode" in
    PD-Sep mode.
    """
    if combined_instances:
        tagged = [("combined", inst) for inst in combined_instances]
    else:
        tagged = ([("prefill", inst) for inst in prefill_instances]
                  + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in tagged]


def parse_args():
    """Parse CLI args; normalizes --prefill/--decode into args.prefill / args.decode.

    args.prefill is a list of (url, bootstrap_port_or_None); args.decode is a
    list of URLs. Exits via argparse error on invalid combinations.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear",
                   choices=["linear", "lmetric", "unified"],
                   help="Routing policy: linear (cache-aware), lmetric (P_tokens × BS), "
                        "or unified (hybrid affinity + LMetric fallback)")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    # The four flags below are accepted for bench.sh backward compatibility but
    # have no effect after the PD-sep offload path was retired (REPORT §3.9,
    # commits 4c583f2 / cc6e562). Removing them would break scripts/bench.sh and
    # scripts/legacy/*.sh which still pass them through.
    p.add_argument("--max-offload-inflight", type=int, default=4,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--offload-mode", type=str, default="cached_prefill",
                   choices=["direct_read", "cached_prefill"],
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--cache-gate-ratio", type=float, default=0.0,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    p.add_argument("--decode-iteration-s", type=float, default=0.05,
                   help="[DEPRECATED] PUSH offload retired; no effect")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                p.error(f"--prefill bootstrap_port must be an integer or 'none', "
                        f"got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    # PD-Sep routing picks a decode instance for every request; without one
    # the first request would crash.
    if not args.combined and not args.decode:
        p.error("PD-Sep mode requires at least one --decode URL")
    return args


if __name__ == "__main__":
    global_args = parse_args()
    SETTINGS.heavy_threshold = global_args.heavy_threshold
    SETTINGS.overload_factor = global_args.overload_factor
    SETTINGS.max_offload_inflight = global_args.max_offload_inflight
    SETTINGS.cache_gate_ratio = global_args.cache_gate_ratio
    SETTINGS.decode_iteration_s = getattr(global_args, 'decode_iteration_s', 0.05)
    print("SETTINGS: throughput=%.0f rdma_overhead=%.2f offload=%s" % (
        SETTINGS.prefill_throughput, SETTINGS.rdma_overhead_s,
        getattr(global_args, 'offload', False)))
    uvicorn.run(app, host=global_args.host, port=global_args.port)