Direct RDMA read: D reads cached KV from C's GPU without C's scheduler

Complete implementation of direct RDMA read for KV cache migration:

vLLM Mooncake connector (mooncake_connector.py):
- PullReqMeta: add direct_read flag + block_hashes
- MooncakeConnectorMetadata: add hash_table_updates/removals for
  scheduler->worker block hash sync
- MooncakeConnectorScheduler: set_block_pool() to access BlockPool,
  build_connector_meta() computes hash table deltas each step,
  update_state_after_alloc() captures request block hashes for direct_read
- MooncakeConnectorWorker: _start_direct_read() + _direct_read_single()
  implements D-side RDMA read via batch_transfer_sync_read, with
  HTTP query/unpin to C's bootstrap server

Bootstrap server (mooncake_utils.py):
- POST /query_blocks: look up block hashes, return block_ids + GPU layout
- POST /unpin_blocks: release pin tracking
- set_worker_kv_info(): register GPU addresses at init
- update_hash_table(): receive scheduler deltas each step

Scheduler (scheduler.py):
- One-line hookup: pass block_pool to connector after KVCacheManager init

Proxy (cache_aware_proxy.py):
- _handle_direct_read_offload: sends request ONLY to D with
  direct_read=True + remote_bootstrap_addr. No request to C at all.
- C's scheduler is completely uninvolved (0 GPU time on C)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-23 21:02:13 +08:00
parent 020be9f444
commit a7df84bd3b
4 changed files with 271 additions and 123 deletions

View File

@@ -351,33 +351,28 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
offload_reason = "colocated_cheaper_%.1fvs%.1f" % (colocated_cost, offload_cost)
if use_offload:
p_inst = p_candidate
# Direct RDMA read: D reads cached KV from C_s's GPU, no request to C_s
c_inst = best_inst # has cache (not doing any work)
d_inst = d_candidate
d_idx = combined_instances.index(d_inst)
# Accounting: reserve both P and D immediately so router sees the load
p_new = max(0, input_length - p_inst.estimate_cache_hit(token_ids)) if token_ids else input_length
p_inst.ongoing_tokens += input_length
p_inst.pending_prefill_tokens += p_new
p_inst.num_requests += 1
p_inst.active_p_offloads += 1
breakdown["p_new_tokens"] = p_new
d_inst.ongoing_tokens += input_length
d_inst.pending_prefill_tokens += estimated_new
d_inst.num_requests += 1
c_inst.active_p_offloads += 1
breakdown["route_class"] = "HEAVY_OFFLOAD"
breakdown["offload_reason"] = offload_reason
breakdown["p_inst"] = p_inst.url
breakdown["c_inst"] = c_inst.url
breakdown["d_inst"] = d_inst.url
breakdown["p_load"] = p_inst.ongoing_tokens
breakdown["d_load"] = d_inst.ongoing_tokens
breakdown["cache_hit_tokens"] = cache_hit
if session_id:
session_affinity[session_id] = d_idx
return await _handle_heavy_offload(api, req_data, headers, token_ids,
input_length, p_inst, d_inst, breakdown)
return await _handle_direct_read_offload(
api, req_data, headers, token_ids, input_length,
c_inst, d_inst, cache_hit, estimated_new, breakdown)
else:
if estimated_new >= HEAVY_THRESHOLD:
breakdown["route_class"] = "HEAVY_COLO"
@@ -423,112 +418,31 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
p_inst, d_inst, breakdown):
"""HEAVY request with cache-aware KV migration.
C_s (p_inst, has cache) exports cached KV blocks via Mooncake.
D (d_inst, idle) pulls cached blocks + does local prefill for new tokens + decodes.
C_s's blocks are auto-freed by Mooncake after D confirms receipt.
On export failure, falls back to co-located prefill+decode on d_inst.
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
input_length, c_inst, d_inst,
cache_hit, estimated_new, breakdown):
"""HEAVY request: D direct-RDMA-reads cached KV from C_s, then does
local prefill for new tokens + decode. C_s's scheduler is NOT involved.
"""
request_id = headers.get("X-Request-Id", "")
estimated_new = breakdown.get("estimated_new_tokens", 0)
cache_hit = breakdown.get("cache_hit", 0)
p_prefill_release = breakdown.get("p_new_tokens", estimated_new)
# Step 1: C_s exports cached KV blocks
# Send TRUNCATED prompt (only cached portion) so C_s does 0 compute
# (full prefix cache hit), then pushes cached blocks to Mooncake.
breakdown["t_export_sent"] = _time.monotonic()
export_ok = False
# Truncate prompt to cached portion (aligned to BLOCK_SIZE)
# Align cache_hit to block boundary for remote_num_tokens
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
if cached_tokens > 0 and token_ids:
export_prompt = token_ids[:cached_tokens]
else:
export_prompt = token_ids
breakdown["t_offload_sent"] = _time.monotonic()
try:
export_data = {
"model": req_data.get("model", ""),
"prompt": export_prompt,
"max_tokens": 1,
"temperature": 0,
"stream": False,
"kv_transfer_params": {
"do_remote_decode": True,
"do_remote_prefill": False,
"transfer_id": "xfer-" + request_id,
},
}
p_headers = {**headers, "X-data-parallel-rank": "0"}
resp = await asyncio.wait_for(
p_inst.client.post(api, json=export_data, headers=p_headers),
timeout=PREFILL_TIMEOUT_S,
)
resp.raise_for_status()
await resp.aclose()
breakdown["t_export_done"] = _time.monotonic()
breakdown["exported_tokens"] = cached_tokens if cached_tokens > 0 else len(export_prompt)
export_ok = True
except Exception as e:
breakdown["t_export_done"] = _time.monotonic()
breakdown["export_error"] = str(e)
finally:
p_inst.ongoing_tokens -= input_length
p_inst.pending_prefill_tokens -= p_prefill_release
p_inst.num_requests -= 1
p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)
if not export_ok:
breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
d_inst.pending_prefill_tokens += estimated_new
async def generate_fallback():
prefill_done = False
try:
async with d_inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
if not prefill_done:
d_inst.pending_prefill_tokens -= estimated_new
d_inst.ongoing_decode_tokens += input_length
breakdown["t_first_token"] = _time.monotonic()
prefill_done = True
yield chunk
d_inst.record_prefix(token_ids)
finally:
if not prefill_done:
d_inst.pending_prefill_tokens -= estimated_new
else:
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.num_requests -= 1
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)
return StreamingResponse(generate_fallback(), media_type="text/event-stream")
# Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
exported_tokens = breakdown.get("exported_tokens", 0)
d_inst.pending_prefill_tokens += estimated_new
breakdown["t_decode_sent"] = _time.monotonic()
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
bootstrap_addr = "http://%s:%s" % (parsed.hostname, c_inst.bootstrap_port)
# Send full prompt to D with direct_read flag
decode_data = req_data.copy()
decode_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"direct_read": True,
"remote_bootstrap_addr": bootstrap_addr,
"remote_engine_id": p_inst.engine_id.get(0, ""),
"remote_engine_id": c_inst.engine_id.get(0, ""),
"transfer_id": "xfer-" + request_id,
"remote_num_tokens": exported_tokens,
"remote_num_tokens": cached_tokens,
}
async def generate():
@@ -551,6 +465,7 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
d_inst.ongoing_decode_tokens -= input_length
d_inst.ongoing_tokens -= input_length
d_inst.num_requests -= 1
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
breakdown["t_done"] = _time.monotonic()
_breakdown_log.append(breakdown)