Partial remote prefill: C_s exports cache, D computes new tokens locally

vLLM Mooncake patch:
- get_num_new_matched_tokens: support remote_num_tokens parameter for
  partial remote prefill (pull N tokens from remote, compute rest locally)
- update_state_after_alloc: only allocate receive blocks for external portion

Proxy _handle_heavy_offload rewrite:
- Step 1: C_s exports ONLY cached blocks (truncated prompt, 0 compute)
- Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
- C_s's blocks auto-freed by Mooncake delay_free after D confirms receipt

This enables true session migration: C_s releases cache, D takes over.
C_s's GPU is freed immediately (no compute), vs old approach where C_s
had to do full prefill (1-15s GPU occupancy).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-23 20:04:13 +08:00
parent be273f7f27
commit ea5149726c
2 changed files with 69 additions and 39 deletions

View File

@@ -392,52 +392,66 @@ PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
p_inst, d_inst, breakdown):
"""HEAVY request: prefill on p_inst, KV via Mooncake, decode on d_inst.
"""HEAVY request with cache-aware KV migration.
On prefill timeout/failure, falls back to co-located decode on d_inst.
C_s (p_inst, has cache) exports cached KV blocks via Mooncake.
D (d_inst, idle) pulls cached blocks + does local prefill for new tokens + decodes.
C_s's blocks are auto-freed by Mooncake after D confirms receipt.
On export failure, falls back to co-located prefill+decode on d_inst.
"""
request_id = headers.get("X-Request-Id", "")
estimated_new = breakdown.get("estimated_new_tokens", 0)
cache_hit = breakdown.get("cache_hit", 0)
p_prefill_release = breakdown.get("p_new_tokens", estimated_new)
# Step 1: Await prefill on p_inst (ongoing_tokens already reserved by caller)
breakdown["t_prefill_sent"] = _time.monotonic()
prefill_ok = False
# Step 1: C_s exports cached KV blocks
# Send TRUNCATED prompt (only cached portion) so C_s does 0 compute
# (full prefix cache hit), then pushes cached blocks to Mooncake.
breakdown["t_export_sent"] = _time.monotonic()
export_ok = False
# Truncate prompt to cached portion (aligned to BLOCK_SIZE)
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
if cached_tokens > 0 and token_ids:
export_prompt = token_ids[:cached_tokens]
else:
export_prompt = token_ids
try:
prefill_data = req_data.copy()
prefill_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": "xfer-" + request_id,
export_data = {
"model": req_data.get("model", ""),
"prompt": export_prompt,
"max_tokens": 1,
"temperature": 0,
"stream": False,
"kv_transfer_params": {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": "xfer-" + request_id,
},
}
prefill_data["stream"] = False
prefill_data["max_tokens"] = 1
prefill_data.pop("max_completion_tokens", None)
prefill_data.pop("stream_options", None)
p_headers = {**headers, "X-data-parallel-rank": "0"}
resp = await asyncio.wait_for(
p_inst.client.post(api, json=prefill_data, headers=p_headers),
p_inst.client.post(api, json=export_data, headers=p_headers),
timeout=PREFILL_TIMEOUT_S,
)
resp.raise_for_status()
await resp.aclose()
p_inst.record_prefix(token_ids)
breakdown["t_prefill_done"] = _time.monotonic()
prefill_ok = True
breakdown["t_export_done"] = _time.monotonic()
breakdown["exported_tokens"] = cached_tokens if cached_tokens > 0 else len(export_prompt)
export_ok = True
except Exception as e:
breakdown["t_prefill_done"] = _time.monotonic()
breakdown["prefill_error"] = str(e)
breakdown["t_export_done"] = _time.monotonic()
breakdown["export_error"] = str(e)
finally:
# Always release P-instance resources exactly once
p_inst.ongoing_tokens -= input_length
p_inst.pending_prefill_tokens -= p_prefill_release
p_inst.num_requests -= 1
p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)
if not prefill_ok:
# Fallback: co-located prefill+decode on d_inst (no KV transfer)
# D already has ongoing_tokens and num_requests reserved by caller
if not export_ok:
breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
d_inst.pending_prefill_tokens += estimated_new
@@ -466,9 +480,9 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
return StreamingResponse(generate_fallback(), media_type="text/event-stream")
# Step 2: Stream decode on d_inst (pulls KV from Mooncake)
# D already has ongoing_tokens and num_requests reserved by caller
d_inst.ongoing_decode_tokens += input_length
# Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
exported_tokens = breakdown.get("exported_tokens", 0)
d_inst.pending_prefill_tokens += estimated_new
breakdown["t_decode_sent"] = _time.monotonic()
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
@@ -481,6 +495,7 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": p_inst.engine_id.get(0, ""),
"transfer_id": "xfer-" + request_id,
"remote_num_tokens": exported_tokens,
}
async def generate():
@@ -490,18 +505,23 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if first_token:
d_inst.pending_prefill_tokens -= estimated_new
d_inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
first_token = False
yield chunk
d_inst.record_prefix(token_ids)
finally:
if first_token:
d_inst.pending_prefill_tokens -= estimated_new
else:
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.ongoing_decode_tokens -= input_length
d_inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate(), media_type="application/json")
return StreamingResponse(generate(), media_type="text/event-stream")
async def _send_prefill_async(p_inst, api, prefill_data, p_headers, token_ids,

View File

@@ -299,14 +299,17 @@ class MooncakeConnectorScheduler:
return 0, False
if params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
assert not self.is_kv_producer
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
# Partial remote prefill: only pull remote_num_tokens from remote,
# compute the rest locally. Falls back to full remote prefill
# when remote_num_tokens is not set.
remote_total = params.get("remote_num_tokens", len(token_ids))
remote_total = min(remote_total, len(token_ids))
count = max(0, remote_total - num_computed_tokens)
if count > 0:
return count, True
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(
@@ -330,13 +333,20 @@ class MooncakeConnectorScheduler:
p in params
for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id")
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
)
# Get unhashed blocks to pull from remote.
if num_external_tokens > 0:
all_unhashed = blocks.get_unhashed_block_ids()
# Partial remote prefill: only receive blocks for the
# external portion, leave the rest for local compute.
if params.get("remote_num_tokens") is not None:
block_size = self.vllm_config.cache_config.block_size
num_remote_blocks = (
(num_external_tokens + block_size - 1) // block_size
)
local_block_ids = all_unhashed[:num_remote_blocks]
else:
local_block_ids = all_unhashed
else:
local_block_ids = []
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else:
logger.warning(