Direct RDMA read: D reads cached KV from C's GPU without C's scheduler
Complete implementation of direct RDMA read for KV cache migration: vLLM Mooncake connector (mooncake_connector.py): - PullReqMeta: add direct_read flag + block_hashes - MooncakeConnectorMetadata: add hash_table_updates/removals for scheduler->worker block hash sync - MooncakeConnectorScheduler: set_block_pool() to access BlockPool, build_connector_meta() computes hash table deltas each step, update_state_after_alloc() captures request block hashes for direct_read - MooncakeConnectorWorker: _start_direct_read() + _direct_read_single() implements D-side RDMA read via batch_transfer_sync_read, with HTTP query/unpin to C's bootstrap server Bootstrap server (mooncake_utils.py): - POST /query_blocks: look up block hashes, return block_ids + GPU layout - POST /unpin_blocks: release pin tracking - set_worker_kv_info(): register GPU addresses at init - update_hash_table(): receive scheduler deltas each step Scheduler (scheduler.py): - One-line hookup: pass block_pool to connector after KVCacheManager init Proxy (cache_aware_proxy.py): - _handle_direct_read_offload: sends request ONLY to D with direct_read=True + remote_bootstrap_addr. No request to C at all. - C's scheduler is completely uninvolved (0 GPU time on C) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -351,33 +351,28 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
offload_reason = "colocated_cheaper_%.1fvs%.1f" % (colocated_cost, offload_cost)
|
||||
|
||||
if use_offload:
|
||||
p_inst = p_candidate
|
||||
# Direct RDMA read: D reads cached KV from C_s's GPU, no request to C_s
|
||||
c_inst = best_inst # has cache (not doing any work)
|
||||
d_inst = d_candidate
|
||||
d_idx = combined_instances.index(d_inst)
|
||||
|
||||
# Accounting: reserve both P and D immediately so router sees the load
|
||||
p_new = max(0, input_length - p_inst.estimate_cache_hit(token_ids)) if token_ids else input_length
|
||||
p_inst.ongoing_tokens += input_length
|
||||
p_inst.pending_prefill_tokens += p_new
|
||||
p_inst.num_requests += 1
|
||||
p_inst.active_p_offloads += 1
|
||||
breakdown["p_new_tokens"] = p_new
|
||||
|
||||
d_inst.ongoing_tokens += input_length
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
d_inst.num_requests += 1
|
||||
c_inst.active_p_offloads += 1
|
||||
|
||||
breakdown["route_class"] = "HEAVY_OFFLOAD"
|
||||
breakdown["offload_reason"] = offload_reason
|
||||
breakdown["p_inst"] = p_inst.url
|
||||
breakdown["c_inst"] = c_inst.url
|
||||
breakdown["d_inst"] = d_inst.url
|
||||
breakdown["p_load"] = p_inst.ongoing_tokens
|
||||
breakdown["d_load"] = d_inst.ongoing_tokens
|
||||
breakdown["cache_hit_tokens"] = cache_hit
|
||||
|
||||
if session_id:
|
||||
session_affinity[session_id] = d_idx
|
||||
|
||||
return await _handle_heavy_offload(api, req_data, headers, token_ids,
|
||||
input_length, p_inst, d_inst, breakdown)
|
||||
return await _handle_direct_read_offload(
|
||||
api, req_data, headers, token_ids, input_length,
|
||||
c_inst, d_inst, cache_hit, estimated_new, breakdown)
|
||||
else:
|
||||
if estimated_new >= HEAVY_THRESHOLD:
|
||||
breakdown["route_class"] = "HEAVY_COLO"
|
||||
@@ -423,112 +418,31 @@ async def _handle_combined(api, req_data, token_ids, input_length, session_id, h
|
||||
PREFILL_TIMEOUT_S = 120 # max seconds to wait for P-instance prefill
|
||||
|
||||
|
||||
async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
p_inst, d_inst, breakdown):
|
||||
"""HEAVY request with cache-aware KV migration.
|
||||
|
||||
C_s (p_inst, has cache) exports cached KV blocks via Mooncake.
|
||||
D (d_inst, idle) pulls cached blocks + does local prefill for new tokens + decodes.
|
||||
C_s's blocks are auto-freed by Mooncake after D confirms receipt.
|
||||
|
||||
On export failure, falls back to co-located prefill+decode on d_inst.
|
||||
async def _handle_direct_read_offload(api, req_data, headers, token_ids,
|
||||
input_length, c_inst, d_inst,
|
||||
cache_hit, estimated_new, breakdown):
|
||||
"""HEAVY request: D direct-RDMA-reads cached KV from C_s, then does
|
||||
local prefill for new tokens + decode. C_s's scheduler is NOT involved.
|
||||
"""
|
||||
request_id = headers.get("X-Request-Id", "")
|
||||
estimated_new = breakdown.get("estimated_new_tokens", 0)
|
||||
cache_hit = breakdown.get("cache_hit", 0)
|
||||
p_prefill_release = breakdown.get("p_new_tokens", estimated_new)
|
||||
|
||||
# Step 1: C_s exports cached KV blocks
|
||||
# Send TRUNCATED prompt (only cached portion) so C_s does 0 compute
|
||||
# (full prefix cache hit), then pushes cached blocks to Mooncake.
|
||||
breakdown["t_export_sent"] = _time.monotonic()
|
||||
export_ok = False
|
||||
|
||||
# Truncate prompt to cached portion (aligned to BLOCK_SIZE)
|
||||
# Align cache_hit to block boundary for remote_num_tokens
|
||||
cached_tokens = (cache_hit // BLOCK_SIZE) * BLOCK_SIZE
|
||||
if cached_tokens > 0 and token_ids:
|
||||
export_prompt = token_ids[:cached_tokens]
|
||||
else:
|
||||
export_prompt = token_ids
|
||||
breakdown["t_offload_sent"] = _time.monotonic()
|
||||
|
||||
try:
|
||||
export_data = {
|
||||
"model": req_data.get("model", ""),
|
||||
"prompt": export_prompt,
|
||||
"max_tokens": 1,
|
||||
"temperature": 0,
|
||||
"stream": False,
|
||||
"kv_transfer_params": {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
},
|
||||
}
|
||||
|
||||
p_headers = {**headers, "X-data-parallel-rank": "0"}
|
||||
resp = await asyncio.wait_for(
|
||||
p_inst.client.post(api, json=export_data, headers=p_headers),
|
||||
timeout=PREFILL_TIMEOUT_S,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
await resp.aclose()
|
||||
breakdown["t_export_done"] = _time.monotonic()
|
||||
breakdown["exported_tokens"] = cached_tokens if cached_tokens > 0 else len(export_prompt)
|
||||
export_ok = True
|
||||
except Exception as e:
|
||||
breakdown["t_export_done"] = _time.monotonic()
|
||||
breakdown["export_error"] = str(e)
|
||||
finally:
|
||||
p_inst.ongoing_tokens -= input_length
|
||||
p_inst.pending_prefill_tokens -= p_prefill_release
|
||||
p_inst.num_requests -= 1
|
||||
p_inst.active_p_offloads = max(0, p_inst.active_p_offloads - 1)
|
||||
|
||||
if not export_ok:
|
||||
breakdown["route_class"] = "HEAVY_COLO_FALLBACK"
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
|
||||
async def generate_fallback():
|
||||
prefill_done = False
|
||||
try:
|
||||
async with d_inst.client.stream("POST", api, json=req_data, headers=headers) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
if not prefill_done:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
d_inst.ongoing_decode_tokens += input_length
|
||||
breakdown["t_first_token"] = _time.monotonic()
|
||||
prefill_done = True
|
||||
yield chunk
|
||||
d_inst.record_prefix(token_ids)
|
||||
finally:
|
||||
if not prefill_done:
|
||||
d_inst.pending_prefill_tokens -= estimated_new
|
||||
else:
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
return StreamingResponse(generate_fallback(), media_type="text/event-stream")
|
||||
|
||||
# Step 2: D pulls cached blocks + does local prefill for new tokens + decodes
|
||||
exported_tokens = breakdown.get("exported_tokens", 0)
|
||||
d_inst.pending_prefill_tokens += estimated_new
|
||||
breakdown["t_decode_sent"] = _time.monotonic()
|
||||
|
||||
parsed = urllib.parse.urlparse(str(p_inst.client.base_url))
|
||||
bootstrap_addr = "http://%s:%s" % (parsed.hostname, p_inst.bootstrap_port)
|
||||
parsed = urllib.parse.urlparse(str(c_inst.client.base_url))
|
||||
bootstrap_addr = "http://%s:%s" % (parsed.hostname, c_inst.bootstrap_port)
|
||||
|
||||
# Send full prompt to D with direct_read flag
|
||||
decode_data = req_data.copy()
|
||||
decode_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"direct_read": True,
|
||||
"remote_bootstrap_addr": bootstrap_addr,
|
||||
"remote_engine_id": p_inst.engine_id.get(0, ""),
|
||||
"remote_engine_id": c_inst.engine_id.get(0, ""),
|
||||
"transfer_id": "xfer-" + request_id,
|
||||
"remote_num_tokens": exported_tokens,
|
||||
"remote_num_tokens": cached_tokens,
|
||||
}
|
||||
|
||||
async def generate():
|
||||
@@ -551,6 +465,7 @@ async def _handle_heavy_offload(api, req_data, headers, token_ids, input_length,
|
||||
d_inst.ongoing_decode_tokens -= input_length
|
||||
d_inst.ongoing_tokens -= input_length
|
||||
d_inst.num_requests -= 1
|
||||
c_inst.active_p_offloads = max(0, c_inst.active_p_offloads - 1)
|
||||
breakdown["t_done"] = _time.monotonic()
|
||||
_breakdown_log.append(breakdown)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user