Files
agentic-kvc/scripts/cache_aware_proxy.py
Gahow Wang ea5149726c Partial remote prefill: C_s exports cache, D computes new tokens locally
vLLM Mooncake patch:
- get_num_new_matched_tokens: support remote_num_tokens parameter for
  partial remote prefill (pull N tokens from remote, compute rest locally)
- update_state_after_alloc: only allocate receive blocks for external portion

Proxy _handle_heavy_offload rewrite:
- Step 1: C_s exports ONLY cached blocks (truncated prompt, 0 compute)
- Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
- C_s's blocks auto-freed by Mooncake delay_free after D confirms receipt

This enables true session migration: C_s releases cache, D takes over.
C_s's GPU is freed immediately (no compute), vs old approach where C_s
had to do full prefill (1-15s GPU occupancy).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-23 20:04:13 +08:00

686 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Unified cache-aware + token-level load-balanced global scheduler.
Supports two modes:
--combined URL [URL ...]: PD co-located instances (normal vLLM, no KV transfer)
--prefill URL BP --decode URL: PD disaggregated instances (Mooncake KV transfer)
Routing policies (--policy):
linear (default): score = ongoing_tokens - ALPHA * cache_hit_tokens
lmetric: score = P_tokens * BS (LMetric, OSDI'26)
P_tokens = pending_prefill_tokens + new_uncached_tokens
BS = num_requests (waiting + running)
Session affinity: multi-turn sessions stick to same instance (all policies).
"""
import argparse
import asyncio
import os
import time as _time
import urllib.parse
import uuid
from contextlib import asynccontextmanager
import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
# --- Routing / cost-model tunables -------------------------------------
# Number of tokens per block when prompts are hashed block-by-block for the
# proxy-side prefix-cache estimate (InstanceState.estimate_cache_hit /
# record_prefix).
BLOCK_SIZE = 512
# Weight applied to estimated cache-hit tokens in the "linear" routing score
# (score = ongoing_tokens + offload penalty - CACHE_HIT_ALPHA * cache_hit).
CACHE_HIT_ALPHA = 1.0
# Minimum number of estimated *new* (uncached) tokens for a request to be
# classified HEAVY; also the per-offload penalty unit in _p_offload_penalty.
HEAVY_THRESHOLD = 20000 # default; overridden by --heavy-threshold
# A session-pinned instance is bypassed when its ongoing_tokens exceed
# OVERLOAD_FACTOR * average load (see pick_instance).
OVERLOAD_FACTOR = 2.0 # default; overridden by --overload-factor
MAX_OFFLOAD_INFLIGHT = 4 # cap concurrent P-role offloads
# Inputs to the offload cost model in _handle_combined (seconds = tokens / throughput).
PREFILL_THROUGHPUT = 7000 # tokens/s per GPU (from H20 measurements)
RDMA_OVERHEAD_S = 2.0 # seconds of RDMA transfer + decode start overhead
class InstanceState:
def __init__(self, url: str, bootstrap_port: int | None = None):
self.url = url
self.bootstrap_port = bootstrap_port
self.client = httpx.AsyncClient(
timeout=None, base_url=url,
limits=httpx.Limits(max_connections=None, max_keepalive_connections=None),
)
self.ongoing_tokens = 0
self.ongoing_decode_tokens = 0 # subset: tokens in decode phase
self.pending_prefill_tokens = 0 # tokens for requests still in prefill
self.num_requests = 0 # total in-flight requests (waiting + running)
self.active_p_offloads = 0 # number of HEAVY prefills this instance is doing for others
self.engine_id: dict[int, str] = {}
self.dp_size = 1
self.cached_blocks: set[int] = set()
def estimate_cache_hit(self, token_ids: list[int] | None) -> int:
if not token_ids or len(token_ids) < BLOCK_SIZE:
return 0
hit = 0
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
bh = hash(tuple(token_ids[i:i + BLOCK_SIZE]))
if bh in self.cached_blocks:
hit += BLOCK_SIZE
else:
break
return hit
def record_prefix(self, token_ids: list[int] | None):
if not token_ids:
return
for i in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
self.cached_blocks.add(hash(tuple(token_ids[i:i + BLOCK_SIZE])))
if len(self.cached_blocks) > 200000:
self.cached_blocks = set(list(self.cached_blocks)[-100000:])
# Cumulative token load per instance (for balanced session placement).
# Sized lazily by pick_instance on its first call (one slot per instance) and
# incremented there by input_length whenever a request is newly placed.
# NOTE(review): nothing in this file reads the values back — confirm whether
# it is still needed.
_inst_cumulative_tokens: list[int] = []
def _p_offload_penalty(inst: InstanceState) -> int:
    """Return the routing-score surcharge for an instance serving offloads.

    Each offloaded HEAVY prefill the instance is currently running for
    another instance adds HEAVY_THRESHOLD to its score, which pushes
    WARM/MEDIUM traffic towards other instances (soft PD separation).
    An instance with no active offloads gets no surcharge.
    """
    # Clamp so that a zero or negative counter yields 0.
    busy_offloads = max(inst.active_p_offloads, 0)
    return busy_offloads * HEAVY_THRESHOLD
def pick_instance(instances: list[InstanceState], token_ids: list[int] | None,
                  session_id: str | None, input_length: int,
                  affinity: dict[str, int]) -> tuple[InstanceState, int]:
    """Choose an instance: sticky per session, with a load-aware escape hatch.

    A session that already has an affinity entry stays on its pinned
    instance unless that instance carries more than OVERLOAD_FACTOR times the
    average load or is busy with P-role offloads. In every other case the
    instance with the lowest score wins, where
    score = ongoing_tokens + offload penalty - CACHE_HIT_ALPHA * cache_hit,
    and the session (if any) is re-pinned to the winner.

    Returns:
        ``(instance, index)`` of the chosen entry in ``instances``.
    """
    global _inst_cumulative_tokens
    if not _inst_cumulative_tokens:
        _inst_cumulative_tokens = [0] * len(instances)

    total_load = sum(candidate.ongoing_tokens for candidate in instances)
    avg_load = max(total_load / len(instances), 1.0)

    # Honour an existing pin while the pinned instance is healthy.
    pinned = affinity.get(session_id) if session_id else None
    if pinned is not None and pinned < len(instances):
        sticky = instances[pinned]
        overloaded = sticky.ongoing_tokens > avg_load * OVERLOAD_FACTOR
        if not overloaded and sticky.active_p_offloads == 0:
            return sticky, pinned

    def routing_score(pos):
        candidate = instances[pos]
        cache_hit = candidate.estimate_cache_hit(token_ids)
        return (candidate.ongoing_tokens + _p_offload_penalty(candidate)
                - CACHE_HIT_ALPHA * cache_hit)

    # min() keeps the first index on ties, like a strict "<" scan would.
    best_idx = min(range(len(instances)), key=routing_score)

    _inst_cumulative_tokens[best_idx] += input_length
    if session_id:
        affinity[session_id] = best_idx
    return instances[best_idx], best_idx
def pick_instance_lmetric(instances: list[InstanceState], token_ids: list[int] | None,
session_id: str | None, input_length: int,
affinity: dict[str, int]) -> tuple[InstanceState, int]:
"""LMetric routing: score = P_tokens × BS (OSDI'26).
Pure per-request load-based routing, no session affinity.
P = pending_prefill_tokens + (input_length - cache_hit)
BS = num_requests (current batch size)
"""
best_idx, best_score = 0, float("inf")
for i, inst in enumerate(instances):
cache_hit = inst.estimate_cache_hit(token_ids)
new_prefill = max(0, input_length - cache_hit)
p_tokens = inst.pending_prefill_tokens + new_prefill
bs = inst.num_requests
score = p_tokens * bs
if score < best_score:
best_score = score
best_idx = i
return instances[best_idx], best_idx
# --- Process-wide state ---------------------------------------------------
# Parsed CLI namespace; assigned in the ``__main__`` guard before the server
# starts and read by lifespan and the request handlers.
global_args = None
# Instance lists, filled by lifespan at startup according to the chosen mode.
combined_instances: list[InstanceState] = []
prefill_instances: list[InstanceState] = []
decode_instances: list[InstanceState] = []
# session id (X-Session-Id header) -> index into the instance list that the
# session is currently pinned to.
session_affinity: dict[str, int] = {}
# True when running in PD-disaggregated mode; set by lifespan.
is_pd_sep = False
# Per-request timing/routing records appended by the handlers and served by
# GET /breakdown. Nothing in this file trims it, so it grows for the life of
# the process.
_breakdown_log: list[dict] = []
async def init_prefill_bootstrap(instances: list[InstanceState], ready: asyncio.Event):
    """Wait for each bootstrap-capable instance and learn its engine ids.

    Instances without a ``bootstrap_port`` are skipped. For the others this
    polls ``/health`` once per second until it answers, then queries the
    bootstrap ``/query`` endpoint and stores the per-DP-rank ``engine_id``
    values and the DP size on the instance. ``ready`` is set once every
    instance has been processed. A failure of the ``/query`` request itself is
    not retried and propagates to the caller.
    """
    for inst in instances:
        if inst.bootstrap_port is None:
            continue

        # Block until the health probe stops raising.
        while True:
            try:
                await inst.client.get("/health")
                break
            except Exception:
                await asyncio.sleep(1)

        parsed = urllib.parse.urlparse(str(inst.client.base_url))
        url = f"http://{parsed.hostname}:{inst.bootstrap_port}/query"
        resp = await inst.client.get(url)
        resp.raise_for_status()
        data = resp.json()
        for rank_key in data:
            inst.engine_id[int(rank_key)] = data[rank_key]["engine_id"]
        inst.dp_size = len(data)
        print(f"Inited {inst.url} engine_ids={inst.engine_id}")
    ready.set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the instance lists at startup and close their clients at shutdown.

    Combined mode (``--combined``): one InstanceState per URL, paired
    positionally with ``--bootstrap-ports``; the bootstrap handshake runs only
    when offload is enabled and at least one port was given, otherwise the
    service is marked ready immediately.
    PD-Sep mode: prefill and decode lists are built and the prefill instances
    are bootstrapped before the service is marked ready.
    """
    global is_pd_sep
    app.state.ready = asyncio.Event()

    if not global_args.combined:
        is_pd_sep = True
        prefill_instances.extend(InstanceState(url, bp) for url, bp in global_args.prefill)
        decode_instances.extend(InstanceState(url) for url in global_args.decode)
        await init_prefill_bootstrap(prefill_instances, app.state.ready)
        print(f"PD-Sep mode: {len(prefill_instances)}P + {len(decode_instances)}D")
    else:
        is_pd_sep = False
        bp_list = []
        if global_args.bootstrap_ports:
            bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()]
        for pos, url in enumerate(global_args.combined):
            # URLs beyond the end of the port list get no bootstrap port.
            port = bp_list[pos] if pos < len(bp_list) else None
            combined_instances.append(InstanceState(url, port))
        # Bootstrap combined instances for offload (need engine_ids for KV transfer)
        needs_bootstrap = global_args.offload and bp_list
        if needs_bootstrap:
            await init_prefill_bootstrap(combined_instances, app.state.ready)
        else:
            app.state.ready.set()
        policy = getattr(global_args, 'policy', 'linear')
        print(f"Combined mode: {len(combined_instances)} instances, policy={policy}, offload={'ON' if global_args.offload else 'OFF'}")

    yield

    # Shutdown: release every HTTP client that was created above.
    for inst in combined_instances + prefill_instances + decode_instances:
        await inst.client.aclose()
# ASGI application; ``lifespan`` builds the instance lists on startup and
# closes their HTTP clients on shutdown.
app = FastAPI(lifespan=lifespan)
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Entry point for POST /v1/completions; delegates to ``_handle``."""
    endpoint = "/v1/completions"
    return await _handle(request, endpoint)
@app.post("/v1/chat/completions")
async def handle_chat(request: Request):
    """Entry point for POST /v1/chat/completions; delegates to ``_handle``."""
    endpoint = "/v1/chat/completions"
    return await _handle(request, endpoint)
async def _handle(request: Request, api: str):
    """Parse an incoming request and dispatch it to the active routing mode.

    Responds 503 until startup has finished. A fresh request id is generated
    and forwarded as ``X-Request-Id``; if ``OPENAI_API_KEY`` is set in the
    environment it is forwarded as a bearer token. ``prompt`` is treated as a
    token-id list only when it is a JSON list; otherwise routing runs with no
    token information and an input length of 0.
    """
    if not app.state.ready.is_set():
        raise HTTPException(status_code=503, detail="Service Unavailable")

    req_data = await request.json()
    request_id = str(uuid.uuid4())

    raw_prompt = req_data.get("prompt")
    if isinstance(raw_prompt, list):
        token_ids, input_length = raw_prompt, len(raw_prompt)
    else:
        token_ids, input_length = None, 0

    session_id = request.headers.get("X-Session-Id")

    headers = {"X-Request-Id": request_id}
    if (api_key := os.environ.get("OPENAI_API_KEY")):
        headers["Authorization"] = f"Bearer {api_key}"

    if not is_pd_sep:
        return await _handle_combined(api, req_data, token_ids,
                                      input_length, session_id, headers)
    return await _handle_pd_sep(api, req_data, request_id, token_ids,
                                input_length, session_id, headers)
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
    """Route one request in combined (co-located P+D) mode.

    The configured policy picks a home instance. Requests whose estimated
    uncached token count reaches HEAVY_THRESHOLD are run through a latency
    cost model (when ``--offload`` is on and at least two instances exist):
    if offloading looks cheaper and the in-flight cap is not hit, the request
    is handed to ``_handle_heavy_offload`` with a P and a D instance chosen
    among the *other* instances. Everything else is streamed straight from
    the home instance, and load counters are adjusted around the stream.
    """
    policy = getattr(global_args, 'policy', 'linear') if global_args else 'linear'
    if policy == 'lmetric':
        choose = pick_instance_lmetric
    else:
        choose = pick_instance
    home, _ = choose(combined_instances, token_ids, session_id,
                     input_length, session_affinity)
    cache_hit = home.estimate_cache_hit(token_ids)
    estimated_new = max(0, input_length - cache_hit)
    breakdown = {
        "request_id": headers.get("X-Request-Id", ""),
        "input_length": input_length,
        "estimated_new_tokens": estimated_new,
        "cache_hit": cache_hit,
        "t_proxy_recv": _time.monotonic(),
    }

    # Offload gate: estimated latency staying on the home instance
    # (its prefill queue + the new tokens) versus going through P
    # (P's queue + P's uncached tokens + a fixed RDMA overhead).
    offload_enabled = getattr(global_args, 'offload', False) and len(combined_instances) >= 2
    use_offload = False
    offload_reason = "offload_disabled"
    p_candidate = d_candidate = None
    p_new_tokens = 0
    if estimated_new >= HEAVY_THRESHOLD and offload_enabled:
        cache_ratio = cache_hit / max(input_length, 1)
        current_offloads = sum(c.active_p_offloads for c in combined_instances)

        # P: least-loaded instance other than home.
        others = [c for c in combined_instances if c is not home]
        p_candidate = min(others, key=lambda c: c.ongoing_tokens)
        # D: least-loaded of what is left; with only two instances D is P.
        spare = [c for c in others if c is not p_candidate]
        if spare:
            d_candidate = min(spare, key=lambda c: c.ongoing_tokens)
        else:
            d_candidate = p_candidate

        home_queue_s = home.pending_prefill_tokens / PREFILL_THROUGHPUT
        colocated_cost = home_queue_s + estimated_new / PREFILL_THROUGHPUT

        p_queue_s = p_candidate.pending_prefill_tokens / PREFILL_THROUGHPUT
        p_cache_hit = p_candidate.estimate_cache_hit(token_ids) if token_ids else 0
        p_new_tokens = max(0, input_length - p_cache_hit)
        offload_cost = p_queue_s + p_new_tokens / PREFILL_THROUGHPUT + RDMA_OVERHEAD_S

        breakdown["cache_ratio"] = cache_ratio
        breakdown["colocated_cost"] = round(colocated_cost, 2)
        breakdown["offload_cost"] = round(offload_cost, 2)

        if current_offloads >= MAX_OFFLOAD_INFLIGHT:
            offload_reason = "cap_reached_%d" % current_offloads
        elif offload_cost < colocated_cost:
            use_offload = True
            offload_reason = "cost_model_%.1fvs%.1f" % (offload_cost, colocated_cost)
        else:
            offload_reason = "colocated_cheaper_%.1fvs%.1f" % (colocated_cost, offload_cost)

    if use_offload:
        p_inst, d_inst = p_candidate, d_candidate
        d_idx = combined_instances.index(d_inst)
        # Reserve load on both P and D right away so that concurrent routing
        # decisions already see it. P's uncached-token count is the value the
        # cost model computed above (no state changed in between).
        p_inst.ongoing_tokens += input_length
        p_inst.pending_prefill_tokens += p_new_tokens
        p_inst.num_requests += 1
        p_inst.active_p_offloads += 1
        breakdown["p_new_tokens"] = p_new_tokens
        d_inst.ongoing_tokens += input_length
        d_inst.num_requests += 1
        breakdown["route_class"] = "HEAVY_OFFLOAD"
        breakdown["offload_reason"] = offload_reason
        breakdown["p_inst"] = p_inst.url
        breakdown["d_inst"] = d_inst.url
        breakdown["p_load"] = p_inst.ongoing_tokens
        breakdown["d_load"] = d_inst.ongoing_tokens
        if session_id:
            # Later turns of this session follow the KV to D.
            session_affinity[session_id] = d_idx
        return await _handle_heavy_offload(api, req_data, headers, token_ids,
                                           input_length, p_inst, d_inst, breakdown)

    # Co-located path: label the request for the breakdown log.
    if estimated_new >= HEAVY_THRESHOLD:
        breakdown["route_class"] = "HEAVY_COLO"
        breakdown["offload_reason"] = offload_reason
    elif estimated_new < 5000:
        breakdown["route_class"] = "WARM"
    else:
        breakdown["route_class"] = "MEDIUM"

    target = home
    breakdown["routed_to"] = target.url
    breakdown["policy"] = policy
    target.ongoing_tokens += input_length
    target.pending_prefill_tokens += estimated_new
    target.num_requests += 1

    async def relay():
        # The first chunk received marks the end of prefill for accounting.
        saw_first_chunk = False
        try:
            async with target.client.stream("POST", api, json=req_data, headers=headers) as upstream:
                upstream.raise_for_status()
                async for piece in upstream.aiter_bytes():
                    if not saw_first_chunk:
                        target.pending_prefill_tokens -= estimated_new
                        target.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        saw_first_chunk = True
                    yield piece
            target.record_prefix(token_ids)
        finally:
            # Undo whichever phase counter is still held, then the totals.
            if saw_first_chunk:
                target.ongoing_decode_tokens -= input_length
            else:
                target.pending_prefill_tokens -= estimated_new
            target.ongoing_tokens -= input_length
            target.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    return StreamingResponse(relay(), media_type="text/event-stream")
# Passed to asyncio.wait_for around the export request that
# _handle_heavy_offload sends to the P instance.
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
                                p_inst, d_inst, breakdown):
    """HEAVY request with cache-aware KV migration.

    Step 1: ``p_inst`` receives the prompt truncated to the block-aligned
    cached prefix with ``do_remote_decode`` set, so that it exports those KV
    blocks. Step 2: ``d_inst`` receives the full request with
    ``do_remote_prefill`` and ``remote_num_tokens`` so that it pulls the
    exported blocks, prefills the remaining tokens locally and decodes.
    If the export fails or times out, the request falls back to ordinary
    prefill+decode on ``d_inst``.

    The caller must already have reserved load on both instances; this
    function releases P's reservation after the export attempt and D's when
    the stream ends.

    NOTE(review): ``breakdown["cache_hit"]`` is estimated by the caller on the
    session's home instance, while ``p_inst`` is chosen among the *other*
    instances. The "zero compute on P" assumption therefore only holds if
    ``p_inst`` also has that prefix cached — confirm the intended P selection.

    Args:
        api: Backend path to POST to.
        req_data: Original request body.
        headers: Headers forwarded to the backends (carries ``X-Request-Id``).
        token_ids: Prompt token ids, or None.
        input_length: Number of prompt tokens used for load accounting.
        p_inst: Instance that exports the KV blocks.
        d_inst: Instance that produces the response.
        breakdown: Per-request record; updated in place and appended to the
            breakdown log when the stream finishes.

    Returns:
        A StreamingResponse relaying ``d_inst``'s output.
    """
    request_id = headers.get("X-Request-Id", "")
    estimated_new = breakdown.get("estimated_new_tokens", 0)
    cache_hit = breakdown.get("cache_hit", 0)
    p_prefill_release = breakdown.get("p_new_tokens", estimated_new)

    async def relay_from_d(payload):
        # Stream d_inst's answer to ``payload``; shared by the migration and
        # the fallback path. The first chunk marks the end of prefill.
        got_first = False
        try:
            async with d_inst.client.stream("POST", api, json=payload, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if not got_first:
                        d_inst.pending_prefill_tokens -= estimated_new
                        d_inst.ongoing_decode_tokens += input_length
                        breakdown["t_first_token"] = _time.monotonic()
                        got_first = True
                    yield chunk
            d_inst.record_prefix(token_ids)
        finally:
            if got_first:
                d_inst.ongoing_decode_tokens -= input_length
            else:
                d_inst.pending_prefill_tokens -= estimated_new
            d_inst.ongoing_tokens -= input_length
            d_inst.num_requests -= 1
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    # Step 1: export the cached KV blocks from P.
    breakdown["t_export_sent"] = _time.monotonic()
    export_ok = False
    # Only whole blocks can be exported, so round the hit down to BLOCK_SIZE.
    cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
    if cached_tokens > 0 and token_ids:
        export_prompt = token_ids[:cached_tokens]
    else:
        # No cached prefix known: P is sent the whole prompt instead.
        export_prompt = token_ids
    try:
        export_data = {
            "model": req_data.get("model", ""),
            "prompt": export_prompt,
            "max_tokens": 1,
            "temperature": 0,
            "stream": False,
            "kv_transfer_params": {
                "do_remote_decode": True,
                "do_remote_prefill": False,
                "transfer_id": "xfer-" + request_id,
            },
        }
        p_headers = {**headers, "X-data-parallel-rank": "0"}
        resp = await asyncio.wait_for(
            p_inst.client.post(api, json=export_data, headers=p_headers),
            timeout=PREFILL_TIMEOUT_S,
        )
        resp.raise_for_status()
        await resp.aclose()
        breakdown["t_export_done"] = _time.monotonic()
        breakdown["exported_tokens"] = cached_tokens if cached_tokens > 0 else len(export_prompt)
        export_ok = True
    except Exception as e:
        breakdown["t_export_done"] = _time.monotonic()
        # Include the exception type: timeouts stringify to "" and would
        # otherwise be indistinguishable from "no error" in the log.
        breakdown["export_error"] = f"{type(e).__name__}: {e}"
    finally:
        # P's part is over whether or not the export worked.
        p_inst.ongoing_tokens -= input_length
        p_inst.pending_prefill_tokens -= p_prefill_release
        p_inst.num_requests -= 1
        p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)

    if not export_ok:
        # Fallback: D handles the untouched request entirely on its own.
        breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
        d_inst.pending_prefill_tokens += estimated_new
        return StreamingResponse(relay_from_d(req_data), media_type="text/event-stream")

    # Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
    exported_tokens = breakdown.get("exported_tokens", 0)
    d_inst.pending_prefill_tokens += estimated_new
    breakdown["t_decode_sent"] = _time.monotonic()
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": "xfer-" + request_id,
        "remote_num_tokens": exported_tokens,
    }
    return StreamingResponse(relay_from_d(decode_data), media_type="text/event-stream")
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
                              input_length, breakdown):
    """Background prefill for fire-and-forget mode.

    Posts ``prefill_data`` to ``p_inst`` and stamps ``t_prefill_done`` in
    ``breakdown``. On success the prompt prefix is recorded on the instance;
    on any exception ``prefill_error`` is set instead and nothing is raised.
    The ``input_length`` reserved on ``p_inst.ongoing_tokens`` by the caller
    is released in every case.
    """
    try:
        reply = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        breakdown["t_prefill_done"] = _time.monotonic()
        reply.raise_for_status()
        await reply.aclose()
        p_inst.record_prefix(token_ids)
    except Exception:
        # Best effort: the failure is only flagged in the breakdown record.
        breakdown.update({
            "t_prefill_done": _time.monotonic(),
            "prefill_error": True,
        })
    finally:
        p_inst.ongoing_tokens = p_inst.ongoing_tokens - input_length
# Strong references to in-flight fire-and-forget prefill tasks. The event loop
# keeps only weak references to tasks, so a task whose handle is dropped can be
# garbage-collected before it finishes; each task removes itself when done.
_background_prefill_tasks: set[asyncio.Task] = set()


async def _handle_pd_sep(api, req_data, request_id, token_ids, input_length,
                         session_id, headers):
    """PD-Sep mode with per-stage breakdown profiling.

    Picks a prefill instance (session-sticky, via ``pick_instance``) and the
    decode instance with the fewest ongoing tokens. The prefill instance gets
    a one-token, non-streaming copy of the request flagged
    ``do_remote_decode``; the decode instance then gets the original request
    flagged ``do_remote_prefill`` and its output is streamed to the client.

    With ``--fire-and-forget`` the prefill request runs as a background task
    and the decode request is sent without waiting for it. Otherwise the
    prefill is awaited and a failure is reported as HTTP 502.

    Returns:
        A StreamingResponse relaying the decode instance's output.

    Raises:
        HTTPException: 502 when the awaited prefill request fails.
    """
    breakdown = {
        "request_id": request_id,
        "input_length": input_length,
        "t_proxy_recv": _time.monotonic(),
    }
    p_inst, _ = pick_instance(prefill_instances, token_ids, session_id,
                              input_length, session_affinity)
    d_inst = min(decode_instances, key=lambda x: x.ongoing_tokens)
    breakdown["p_inst"] = p_inst.url
    breakdown["d_inst"] = d_inst.url

    # Prefill request: same body, but generate a single token, not streamed.
    prefill_data = req_data.copy()
    prefill_data["kv_transfer_params"] = {
        "do_remote_decode": True, "do_remote_prefill": False,
        "transfer_id": f"xfer-{request_id}",
    }
    prefill_data["stream"] = False
    prefill_data["max_tokens"] = 1
    prefill_data.pop("max_completion_tokens", None)
    prefill_data.pop("stream_options", None)
    p_headers = {**headers, "X-data-parallel-rank": "0"}
    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    if global_args.fire_and_forget:
        task = asyncio.create_task(_send_prefill_async(
            p_inst, api, prefill_data, p_headers, token_ids, input_length, breakdown))
        # Hold a reference until completion (see _background_prefill_tasks).
        _background_prefill_tasks.add(task)
        task.add_done_callback(_background_prefill_tasks.discard)
    else:
        try:
            resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
            breakdown["t_prefill_done"] = _time.monotonic()
            resp.raise_for_status()
            await resp.aclose()
            p_inst.record_prefix(token_ids)
        except Exception as e:
            breakdown["t_prefill_done"] = _time.monotonic()
            breakdown["prefill_error"] = True
            _breakdown_log.append(breakdown)
            raise HTTPException(status_code=502, detail=f"Prefill failed: {e}")
        finally:
            p_inst.ongoing_tokens -= input_length

    # Send decode
    d_inst.ongoing_tokens += input_length
    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = f"http://{parsed.hostname}:{p_inst.bootstrap_port}"
    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False, "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": f"xfer-{request_id}",
    }
    breakdown["t_decode_sent"] = _time.monotonic()

    async def generate():
        first_token = True
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            # Release D's reservation and log, however the stream ended.
            breakdown["t_done"] = _time.monotonic()
            d_inst.ongoing_tokens -= input_length
            _breakdown_log.append(breakdown)

    return StreamingResponse(generate(), media_type="application/json")
@app.get("/breakdown")
async def get_breakdown():
    """Serve every per-request breakdown record collected so far."""
    records = _breakdown_log
    return records
@app.get("/stats")
async def get_stats():
    """Return per-instance live state for debugging.

    Combined instances are reported when any exist; otherwise the prefill
    instances followed by the decode instances. ``role`` reflects the list an
    instance belongs to ("combined", "prefill" or "decode") — it was
    previously hard-coded to "combined" even in PD-Sep mode.
    """
    if combined_instances:
        labelled = [("combined", inst) for inst in combined_instances]
    else:
        labelled = ([("prefill", inst) for inst in prefill_instances]
                    + [("decode", inst) for inst in decode_instances])
    return [{
        "url": inst.url,
        "role": role,
        "ongoing_tokens": inst.ongoing_tokens,
        "pending_prefill_tokens": inst.pending_prefill_tokens,
        "ongoing_decode_tokens": inst.ongoing_decode_tokens,
        "num_requests": inst.num_requests,
        "active_p_offloads": inst.active_p_offloads,
        "cached_blocks": len(inst.cached_blocks),
    } for role, inst in labelled]
def parse_args():
    """Parse and validate the command line.

    Besides the raw argparse attributes, the returned namespace carries:
      * ``prefill``: list of ``(url, bootstrap_port_or_None)`` built from the
        repeated ``--prefill URL [PORT]`` options (``PORT`` may be "none");
      * ``decode``: list of decode URLs from the repeated ``--decode`` options.

    Exits with a usage error when no mode is selected, when PD-Sep mode has
    no decode instance (requests could never be served), or when a prefill
    bootstrap port is not an integer.
    """
    p = argparse.ArgumentParser(description="Unified cache-aware global scheduler")
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--host", type=str, default="0.0.0.0")
    p.add_argument("--combined", nargs="+", help="Combined mode: list of instance URLs")
    p.add_argument("--prefill", nargs="+", action="append", dest="prefill_raw",
                   help="PD-Sep prefill: URL [bootstrap_port]")
    p.add_argument("--decode", nargs=1, action="append", dest="decode_raw",
                   help="PD-Sep decode: URL")
    p.add_argument("--fire-and-forget", action="store_true",
                   help="Send prefill async, don't await before decode")
    p.add_argument("--heavy-threshold", type=int, default=20000,
                   help="New tokens threshold for HEAVY classification (adaptive offload)")
    p.add_argument("--offload", action="store_true",
                   help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
    p.add_argument("--bootstrap-ports", type=str, default="",
                   help="Comma-separated bootstrap ports for combined instances (for offload mode)")
    p.add_argument("--policy", type=str, default="linear", choices=["linear", "lmetric"],
                   help="Routing policy: linear (default) or lmetric (P_tokens × BS, OSDI'26)")
    p.add_argument("--overload-factor", type=float, default=2.0,
                   help="Break session affinity when instance load > factor * avg")
    args = p.parse_args()

    args.prefill = []
    for entry in args.prefill_raw or []:
        url = entry[0]
        bp = None
        if len(entry) > 1 and entry[1].lower() != "none":
            try:
                bp = int(entry[1])
            except ValueError:
                # Report a usage error instead of an unhandled traceback.
                p.error(f"--prefill bootstrap port must be an integer or 'none', got {entry[1]!r}")
        args.prefill.append((url, bp))
    args.decode = [e[0] for e in (args.decode_raw or [])]

    if not args.combined and not args.prefill:
        p.error("Must specify either --combined or --prefill/--decode")
    if not args.combined and not args.decode:
        # Without a decode instance every PD-Sep request would fail when the
        # handler selects the least-loaded decode instance.
        p.error("PD-Sep mode requires at least one --decode instance")
    return args
if __name__ == "__main__":
    # Script entry point: parse the CLI, let it override the module-level
    # routing defaults, then serve the app.
    global_args = parse_args()
    HEAVY_THRESHOLD, OVERLOAD_FACTOR = (
        global_args.heavy_threshold,
        global_args.overload_factor,
    )
    uvicorn.run(app, port=global_args.port, host=global_args.host)