Partial remote prefill: C_s exports cache, D computes new tokens locally
vLLM Mooncake patch: (1) get_num_new_matched_tokens: supports a remote_num_tokens parameter for partial remote prefill (pull N tokens from the remote instance, compute the rest locally); (2) update_state_after_alloc: only allocates receive blocks for the external portion. Proxy _handle_heavy_offload rewrite: Step 1: C_s exports ONLY its cached blocks (truncated prompt, zero compute); Step 2: D pulls the cached blocks, runs local prefill for the new tokens, and decodes. C_s's blocks are freed automatically by Mooncake's delay_free after D confirms receipt. This enables true session migration: C_s releases its cache and D takes over. C_s's GPU is freed immediately (no compute), whereas in the old approach C_s had to run a full prefill (1-15 s of GPU occupancy). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -392,52 +392,66 @@ PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
|
||||
|
||||
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
p_inst, d_inst, breakdown):
|
||||
"""HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.
|
||||
"""HEAVY request with cache-aware KV migration.
|
||||
|
||||
On prefill timeout/failure, falls back to co-located decode on d_inst.
|
||||
C_s (p_inst, has cache) exports cached KV blocks via Mooncake.
|
||||
D (d_inst, idle) pulls cached blocks + does local prefill for new tokens + decodes.
|
||||
C_s's blocks are auto-freed by Mooncake after D confirms receipt.
|
||||
|
||||
On export failure, falls back to co-located prefill+decode on d_inst.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
estimated_new = breakdown.get("estimated_new_tokens", 0)
|
||||
cache_hit = breakdown.get("cache_hit", 0)
|
||||
p_prefill_release = breakdown.get("p_new_tokens", estimated_new)
|
||||
|
||||
# Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
|
||||
breakdown["t_prefill_sent"] = _time.monotonic()
|
||||
prefill_ok = False
|
||||
# Step 1: C_s exports cached KV blocks
|
||||
# Send TRUNCATED prompt (only cached portion) so C_s does 0 compute
|
||||
# (full prefix cache hit), then pushes cached blocks to Mooncake.
|
||||
breakdown["t_export_sent"] = _time.monotonic()
|
||||
export_ok = False
|
||||
|
||||
# Truncate prompt to cached portion (aligned to BLOCK_SIZE)
|
||||
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
|
||||
if cached_tokens > 0 and token_ids:
|
||||
export_prompt = token_ids[:cached_tokens]
|
||||
else:
|
||||
export_prompt = token_ids
|
||||
|
||||
try:
|
||||
prefill_data = req_data.copy()
|
||||
prefill_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
export_data = {
|
||||
"model": req_data.get("model", ""),
|
||||
"prompt": export_prompt,
|
||||
"max_tokens": 1,
|
||||
"temperature": 0,
|
||||
"stream": False,
|
||||
"kv_transfer_params": {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
},
|
||||
}
|
||||
prefill_data["stream"] = False
|
||||
prefill_data["max_tokens"] = 1
|
||||
prefill_data.pop("max_completion_tokens", None)
|
||||
prefill_data.pop("stream_options", None)
|
||||
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
resp = await asyncio.wait_for(
|
||||
p_inst.client.post(api, json=prefill_data, headers=p_headers),
|
||||
p_inst.client.post(api, json=export_data, headers=p_headers),
|
||||
timeout=PREFILL_TIMEOUT_S,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
p_inst.record_prefix(token_ids)
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
prefill_ok = True
|
||||
breakdown["t_export_done"] = _time.monotonic()
|
||||
breakdown["exported_tokens"] = cached_tokens if cached_tokens > 0 else len(export_prompt)
|
||||
export_ok = True
|
||||
except Exception as e:
|
||||
breakdown["t_prefill_done"] = _time.monotonic()
|
||||
breakdown["prefill_error"] = str(e)
|
||||
breakdown["t_export_done"] = _time.monotonic()
|
||||
breakdown["export_error"] = str(e)
|
||||
finally:
|
||||
# Always release P-instance resources exactly once
|
||||
p_inst.ongoing_tokens -= input_length
|
||||
p_inst.pending_prefill_tokens -= p_prefill_release
|
||||
p_inst.num_requests -= 1
|
||||
p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)
|
||||
|
||||
if not prefill_ok:
|
||||
# Fallback: co-located prefill+decode on d_inst (no KV transfer)
|
||||
# D already has ongoing_tokens and num_requests reserved by caller
|
||||
if not export_ok:
|
||||
breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
|
||||
@@ -466,9 +480,9 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
|
||||
return StreamingResponse(generate_fallback(), media_type="text/event-stream")
|
||||
|
||||
# Step 2: Stream decode on d_inst (pulls KV from Mooncake)
|
||||
# D already has ongoing_tokens and num_requests reserved by caller
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
# Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
|
||||
exported_tokens = breakdown.get("exported_tokens", 0)
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
|
||||
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
|
||||
@@ -481,6 +495,7 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": p_inst.engine_id.get(0, ""),
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
"remote_num_tokens": exported_tokens,
|
||||
}
|
||||
|
||||
async def generate():
|
||||
@@ -490,18 +505,23 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if first_token:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
first_token = False
|
||||
yield chunk
|
||||
d_inst.record_prefix(token_ids)
|
||||
finally:
|
||||
if first_token:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate(), media_type="application/json")
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,
|
||||
|
||||
@@ -299,14 +299,17 @@ class MooncakeConnectorScheduler:
|
||||
return 0, False
|
||||
|
||||
if params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
assert not self.is_kv_producer
|
||||
token_ids = request.prompt_token_ids or []
|
||||
count = len(token_ids) - num_computed_tokens
|
||||
# Partial remote prefill: only pull remote_num_tokens from remote,
|
||||
# compute the rest locally. Falls back to full remote prefill
|
||||
# when remote_num_tokens is not set.
|
||||
remote_total = params.get("remote_num_tokens", len(token_ids))
|
||||
remote_total = min(remote_total, len(token_ids))
|
||||
count = max(0, remote_total - num_computed_tokens)
|
||||
if count > 0:
|
||||
return count, True
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
@@ -330,13 +333,20 @@ class MooncakeConnectorScheduler:
|
||||
p in params
|
||||
for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id")
|
||||
):
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
local_block_ids = (
|
||||
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
|
||||
)
|
||||
# Get unhashed blocks to pull from remote.
|
||||
if num_external_tokens > 0:
|
||||
all_unhashed = blocks.get_unhashed_block_ids()
|
||||
# Partial remote prefill: only receive blocks for the
|
||||
# external portion, leave the rest for local compute.
|
||||
if params.get("remote_num_tokens") is not None:
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
num_remote_blocks = (
|
||||
(num_external_tokens + block_size - 1) // block_size
|
||||
)
|
||||
local_block_ids = all_unhashed[:num_remote_blocks]
|
||||
else:
|
||||
local_block_ids = all_unhashed
|
||||
else:
|
||||
local_block_ids = []
|
||||
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user