Adaptive v2 (selective Mooncake offload): worse than baseline
Implemented --offload mode: HEAVY requests (>=20k new tokens) run prefill (P) on the least-loaded instance, transfer KV via Mooncake RDMA, and decode (D) on the session-sticky instance. WARM/MEDIUM requests stay co-located (no KV transfer). All 8 instances run kv_both.

Result (200 requests, same instances, fresh restart):

Baseline (no offload): TTFT=1.073 TPOT90=0.074 E2E=5.086
Offload HEAVY: TTFT=1.462 TPOT90=0.077 E2E=6.847
Delta: TTFT +36%, TPOT90 +4%, E2E +35%

Conclusion: even selective KV transfer (only 44% of requests) adds more overhead than the isolation benefit provides. On a single machine with 8 GPUs, PD-combined with hybrid routing is strictly optimal. No form of KV transfer — full PD-sep, selective offload, or otherwise — improves over co-located serving for this workload.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -148,10 +148,15 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
if global_args.combined:
|
||||
is_pd_sep = False
|
||||
for url in global_args.combined:
|
||||
combined_instances.append(InstanceState(url))
|
||||
bp_list = [int(p) for p in global_args.bootstrap_ports.split(",") if p.strip()] if global_args.bootstrap_ports else []
|
||||
for i, url in enumerate(global_args.combined):
|
||||
bp = bp_list[i] if i < len(bp_list) else None
|
||||
combined_instances.append(InstanceState(url, bp))
|
||||
if global_args.offload and bp_list:
|
||||
await init_prefill_bootstrap(combined_instances, app.state.ready)
|
||||
else:
|
||||
app.state.ready.set()
|
||||
print(f"Combined mode: {len(combined_instances)} instances")
|
||||
print(f"Combined mode: {len(combined_instances)} instances, offload={'ON' if global_args.offload else 'OFF'}")
|
||||
else:
|
||||
is_pd_sep = True
|
||||
for url, bp in global_args.prefill:
|
||||
@@ -204,12 +209,13 @@ async def _handle(request: Request, api: str):
|
||||
|
||||
|
||||
async def _handle_combined(api, req_data, token_ids, input_length, session_id, headers):
|
||||
"""Combined mode with adaptive prefill offload.
|
||||
"""Combined mode with adaptive prefill offload (v2).
|
||||
|
||||
WARM/MEDIUM: route by cache-hit + load balance (co-located P+D).
|
||||
HEAVY: route to instance with least decode load, avoiding decode disruption.
|
||||
WARM/MEDIUM: route to best instance, co-located P+D (no KV transfer).
|
||||
HEAVY (kv_both mode): P on least-loaded instance, KV via Mooncake, D on
|
||||
session-sticky instance. Only works if instances have kv_role=kv_both.
|
||||
Falls back to co-located unless --offload is set and at least two instances exist.
|
||||
"""
|
||||
# Estimate new tokens after cache
|
||||
best_inst, best_idx = pick_instance(combined_instances, token_ids, session_id,
|
||||
input_length, session_affinity)
|
||||
cache_hit = best_inst.estimate_cache_hit(token_ids)
|
||||
@@ -223,19 +229,33 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
"t_proxy_recv": _time.monotonic(),
|
||||
}
|
||||
|
||||
if estimated_new >= HEAVY_THRESHOLD:
|
||||
# HEAVY: pick instance with least ongoing_decode_tokens
|
||||
# This avoids sending heavy prefill to an instance busy decoding
|
||||
inst = min(combined_instances, key=lambda x: x.ongoing_decode_tokens)
|
||||
idx = combined_instances.index(inst)
|
||||
breakdown["route_class"] = "HEAVY"
|
||||
use_offload = (estimated_new >= HEAVY_THRESHOLD and global_args.offload
|
||||
and len(combined_instances) >= 2)
|
||||
|
||||
if use_offload:
|
||||
# HEAVY with offload: P on least-loaded, D on session-sticky (best_inst)
|
||||
p_inst = min(combined_instances, key=lambda x: x.ongoing_tokens)
|
||||
d_inst = best_inst
|
||||
if p_inst is d_inst:
|
||||
# Pick second-least-loaded for P
|
||||
sorted_by_load = sorted(combined_instances, key=lambda x: x.ongoing_tokens)
|
||||
p_inst = sorted_by_load[0] if sorted_by_load[0] is not d_inst else sorted_by_load[1]
|
||||
|
||||
breakdown["route_class"] = "HEAVY_OFFLOAD"
|
||||
breakdown["p_inst"] = p_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
if session_id:
|
||||
session_affinity[session_id] = idx
|
||||
session_affinity[session_id] = combined_instances.index(d_inst)
|
||||
|
||||
return await _handle_heavy_offload(api, req_data, headers, token_ids,
|
||||
input_length, p_inst, d_inst, breakdown)
|
||||
else:
|
||||
if estimated_new >= HEAVY_THRESHOLD:
|
||||
breakdown["route_class"] = "HEAVY_COLO"
|
||||
else:
|
||||
inst = best_inst
|
||||
idx = best_idx
|
||||
breakdown["route_class"] = "WARM" if estimated_new < 5000 else "MEDIUM"
|
||||
|
||||
inst = best_inst
|
||||
breakdown["routed_to"] = inst.url
|
||||
inst.ongoing_tokens += input_length
|
||||
|
||||
@@ -244,7 +264,6 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
try:
|
||||
async with inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
# Once streaming starts, this instance is in "decode phase"
|
||||
inst.ongoing_decode_tokens += input_length
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if first_token:
|
||||
@@ -261,6 +280,76 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
                                p_inst, d_inst, breakdown):
    """HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.

    Step 1 sends a non-streaming, single-token copy of the request to
    ``p_inst`` and awaits it. Step 2 streams the original request from
    ``d_inst``, passing it ``p_inst``'s bootstrap address and engine id so it
    can fetch the KV produced by step 1 (presumably via the Mooncake
    connector; the transfer itself happens inside the instances, not here).

    Args:
        api: Request path forwarded to both instances.
        req_data: Parsed request body; copied, never mutated.
        headers: Headers forwarded upstream; ``X-Request-Id`` seeds the
            transfer id shared by the prefill and decode requests.
        token_ids: Prompt tokens, recorded on ``p_inst`` after a successful
            prefill.
        input_length: Token count added to / removed from the instances'
            load counters while the request is in flight.
        p_inst: Instance that runs the prefill.
        d_inst: Instance that runs the decode.
        breakdown: Per-request timing dict; timestamps and errors are added
            here and the dict is appended to ``_breakdown_log`` when done.

    Returns:
        A ``StreamingResponse`` relaying the decode instance's bytes.

    Raises:
        HTTPException: 500 if ``p_inst`` has no bootstrap port configured,
            502 if the prefill request fails.
    """
    # Function-scope import: only needed for the transfer-id fallback below.
    import uuid

    # Combined instances only get a bootstrap port when --bootstrap-ports
    # lists one for them. Fail before spending a full prefill rather than
    # handing the decode side an unusable "http://host:None" address.
    if p_inst.bootstrap_port is None:
        breakdown["error"] = "no bootstrap port for prefill instance"
        _breakdown_log.append(breakdown)
        raise HTTPException(
            status_code=500,
            detail="Offload unavailable: no bootstrap port for %s" % p_inst.url)

    # An empty request id would make every request share the id "xfer-";
    # fall back to a random one so concurrent transfers stay distinct.
    request_id = headers.get("X-Request-Id", "") or uuid.uuid4().hex
    transfer_id = "xfer-" + request_id

    # Step 1: Await prefill on p_inst
    p_inst.ongoing_tokens += input_length
    breakdown["t_prefill_sent"] = _time.monotonic()
    try:
        prefill_data = req_data.copy()
        prefill_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "transfer_id": transfer_id,
        }
        # The prefill pass only needs to populate KV: one token, no stream.
        prefill_data["stream"] = False
        prefill_data["max_tokens"] = 1
        prefill_data.pop("max_completion_tokens", None)
        prefill_data.pop("stream_options", None)

        p_headers = {**headers, "X-data-parallel-rank": "0"}
        resp = await p_inst.client.post(api, json=prefill_data, headers=p_headers)
        resp.raise_for_status()
        await resp.aclose()
        p_inst.record_prefix(token_ids)
        breakdown["t_prefill_done"] = _time.monotonic()
    except Exception as e:
        breakdown["t_prefill_done"] = _time.monotonic()
        breakdown["error"] = str(e)
        _breakdown_log.append(breakdown)
        # Chain the cause so the upstream failure stays in the traceback.
        raise HTTPException(status_code=502, detail="Prefill failed: %s" % e) from e
    finally:
        p_inst.ongoing_tokens -= input_length

    # Step 2: Stream decode on d_inst (pulls KV from Mooncake)
    breakdown["t_decode_sent"] = _time.monotonic()

    parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
    bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)

    decode_data = req_data.copy()
    decode_data["kv_transfer_params"] = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": p_inst.engine_id.get(0, ""),
        "transfer_id": transfer_id,
    }

    async def generate():
        first_token = True
        # Count the load inside the generator so the increment and the
        # decrement in ``finally`` are always paired; if the response body is
        # never iterated, nothing is added and nothing leaks.
        d_inst.ongoing_tokens += input_length
        d_inst.ongoing_decode_tokens += input_length
        try:
            async with d_inst.client.stream("POST", api, json=decode_data, headers=headers) as resp:
                resp.raise_for_status()
                async for chunk in resp.aiter_bytes():
                    if first_token:
                        breakdown["t_first_token"] = _time.monotonic()
                        first_token = False
                    yield chunk
        finally:
            d_inst.ongoing_tokens -= input_length
            d_inst.ongoing_decode_tokens -= input_length
            breakdown["t_done"] = _time.monotonic()
            _breakdown_log.append(breakdown)

    # NOTE(review): _handle_combined streams with media_type
    # "text/event-stream"; confirm "application/json" is intended here.
    return StreamingResponse(generate(), media_type="application/json")
|
||||
|
||||
|
||||
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
|
||||
input_length, breakdown):
|
||||
"""Fire-and-forget prefill: send and don't block caller."""
|
||||
@@ -376,6 +465,10 @@ def parse_args():
|
||||
help="Send prefill async, don't await before decode")
|
||||
p.add_argument("--heavy-threshold", type=int, default=20000,
|
||||
help="New tokens threshold for HEAVY classification (adaptive offload)")
|
||||
p.add_argument("--offload", action="store_true",
|
||||
help="Enable Mooncake KV offload for HEAVY requests (requires kv_both instances)")
|
||||
p.add_argument("--bootstrap-ports", type=str, default="",
|
||||
help="Comma-separated bootstrap ports for combined instances (for offload mode)")
|
||||
args = p.parse_args()
|
||||
|
||||
args.prefill = []
|
||||
|
||||
Reference in New Issue
Block a user