Partial remote prefill: C_s exports cache, D computes new tokens locally

vLLM Mooncake patch:
- get_num_new_matched_tokens: support remote_num_tokens parameter for
  partial remote prefill (pull N tokens from remote, compute rest locally)
- update_state_after_alloc: only allocate receive blocks for external portion

Proxy _handle_heavy_offload rewrite:
- Step 1: C_s exports ONLY cached blocks (truncated prompt, 0 compute)
- Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
- C_s's blocks auto-freed by Mooncake delay_free after D confirms receipt

This enables true session migration: C_s releases cache, D takes over.
C_s's GPU is freed immediately (no compute), vs old approach where C_s
had to do full prefill (1-15s GPU occupancy).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-23 20:04:13 +08:00
parent be273f7f27
commit ea5149726c
2 changed files with 69 additions and 39 deletions

View File

@@ -299,14 +299,17 @@ class MooncakeConnectorScheduler:
return 0, False
if params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
assert not self.is_kv_producer
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
# Partial remote prefill: only pull remote_num_tokens from remote,
# compute the rest locally. Falls back to full remote prefill
# when remote_num_tokens is not set.
remote_total = params.get("remote_num_tokens", len(token_ids))
remote_total = min(remote_total, len(token_ids))
count = max(0, remote_total - num_computed_tokens)
if count > 0:
return count, True
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(
@@ -330,13 +333,20 @@ class MooncakeConnectorScheduler:
p in params
for p in ("remote_engine_id", "remote_bootstrap_addr", "transfer_id")
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
)
# Get unhashed blocks to pull from remote.
if num_external_tokens > 0:
all_unhashed = blocks.get_unhashed_block_ids()
# Partial remote prefill: only receive blocks for the
# external portion, leave the rest for local compute.
if params.get("remote_num_tokens") is not None:
block_size = self.vllm_config.cache_config.block_size
num_remote_blocks = (
(num_external_tokens + block_size - 1) // block_size
)
local_block_ids = all_unhashed[:num_remote_blocks]
else:
local_block_ids = all_unhashed
else:
local_block_ids = []
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else:
logger.warning(