Direct RDMA read: D reads cached KV from C's GPU without C's scheduler

Complete implementation of direct RDMA read for KV cache migration:

vLLM Mooncake connector (mooncake_connector.py):
- PullReqMeta: add direct_read flag + block_hashes
- MooncakeConnectorMetadata: add hash_table_updates/removals for
  scheduler->worker block hash sync
- MooncakeConnectorScheduler: set_block_pool() to access BlockPool,
  build_connector_meta() computes hash table deltas each step,
  update_state_after_alloc() captures request block hashes for direct_read
- MooncakeConnectorWorker: _start_direct_read() + _direct_read_single()
  implements D-side RDMA read via batch_transfer_sync_read, with
  HTTP query/unpin to C's bootstrap server

Bootstrap server (mooncake_utils.py):
- POST /query_blocks: look up block hashes, return block_ids + GPU layout
- POST /unpin_blocks: release pin tracking
- set_worker_kv_info(): register GPU addresses at init
- update_hash_table(): receive scheduler deltas each step

Scheduler (scheduler.py):
- One-line hookup: pass block_pool to connector after KVCacheManager init

Proxy (cache_aware_proxy.py):
- _handle_direct_read_offload: sends request ONLY to D with
  direct_read=True + remote_bootstrap_addr. No request to C at all.
- C's scheduler is completely uninvolved (0 GPU time on C)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-23 21:02:13 +08:00
parent 020be9f444
commit a7df84bd3b
4 changed files with 271 additions and 123 deletions

View File

@@ -108,6 +108,9 @@ class PullReqMeta:
expire_time: float = float("inf")
# Designed for one D pairing to multiple P
pull_tasks_count: int = 0
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
direct_read: bool = False
block_hashes: list[bytes] = field(default_factory=list)
@dataclass
@@ -124,11 +127,12 @@ class SendBlockMeta:
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
# Use (engine_id, dp_rank) to group reqs with same dp.
# See comments in MooncakeBootstrapServer.
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
self.reqs_not_processed: set[TransferId] = set()
# Hash table sync: scheduler → worker (for direct RDMA read)
self.hash_table_updates: dict[str, int] = {} # hex hash → block_id
self.hash_table_removals: set[str] = set()
def add_new_req(
self,
@@ -136,6 +140,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
block_hashes: list[bytes] | None = None,
):
transfer_id = kv_transfer_params["transfer_id"]
if load_remote_cache:
@@ -146,6 +151,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
remote_engine_id=remote_engine_id,
remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"],
transfer_id=transfer_id,
direct_read=bool(kv_transfer_params.get("direct_read")),
block_hashes=block_hashes or [],
)
else:
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
@@ -173,6 +180,10 @@ class MooncakeConnector(KVConnectorBase_V1):
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
def set_block_pool(self, block_pool):
if self.connector_scheduler is not None:
self.connector_scheduler.set_block_pool(block_pool)
############################################################
# Scheduler Side Methods
############################################################
@@ -250,6 +261,8 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self._block_pool = None
self._known_hash_keys: set = set()
assert vllm_config.kv_transfer_config
self.is_kv_producer: bool = (
@@ -260,14 +273,14 @@ class MooncakeConnectorScheduler:
)
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to remove from processed set because they're not to send after
# remote prefill or aborted.
self._reqs_not_processed: set[TransferId] = set()
# Per-request block hashes for direct_read passthrough
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
def set_block_pool(self, block_pool):
self._block_pool = block_pool
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
@@ -348,13 +361,21 @@ class MooncakeConnectorScheduler:
else:
local_block_ids = []
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
# Capture block hashes for direct_read
if params.get("direct_read") and hasattr(request, "block_hashes"):
block_size = self.vllm_config.cache_config.block_size
num_remote_blocks = (
(num_external_tokens + block_size - 1) // block_size
)
self._req_block_hashes[request.request_id] = [
bytes(h) for h in request.block_hashes[:num_remote_blocks]
]
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
if params.get("do_remote_decode"):
@@ -371,7 +392,6 @@ class MooncakeConnectorScheduler:
) -> KVConnectorMetadata:
meta = MooncakeConnectorMetadata()
# Loop through scheduled reqs and convert to PullReqMeta.
if not self.is_kv_producer:
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
@@ -379,9 +399,30 @@ class MooncakeConnectorScheduler:
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
block_hashes=self._req_block_hashes.pop(req_id, None),
)
self._reqs_need_recv.clear()
# Sync hash table to worker for direct RDMA read block lookups
if self._block_pool is not None:
cache = self._block_pool.cached_block_hash_to_block._cache
current_keys = set(cache.keys())
new_keys = current_keys - self._known_hash_keys
removed_keys = self._known_hash_keys - current_keys
if new_keys or removed_keys:
from vllm.v1.core.kv_cache_utils import get_block_hash
for k in new_keys:
block = cache[k]
if isinstance(block, dict):
bid = next(iter(block.values())).block_id
else:
bid = block.block_id
meta.hash_table_updates[get_block_hash(k).hex()] = bid
meta.hash_table_removals = {
get_block_hash(k).hex() for k in removed_keys
}
self._known_hash_keys = current_keys.copy()
if not self.is_kv_consumer:
for req_id, (req, block_ids) in self._reqs_need_send.items():
assert req.kv_transfer_params is not None
@@ -971,7 +1012,6 @@ class MooncakeConnectorWorker:
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
)
# No need to launch server for D node.
if self.is_kv_consumer:
return
@@ -979,7 +1019,13 @@ class MooncakeConnectorWorker:
asyncio.run_coroutine_threadsafe(
self._mooncake_sender_listener(ready_event), self.sender_loop
)
ready_event.wait() # Wait for listener ZMQ socket to be ready.
ready_event.wait()
if self.bootstrap_server is not None:
self.bootstrap_server.set_worker_kv_info(
self.kv_caches_base_addr, self.block_len,
self.hostname, self.rpc_port,
)
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
finished_recving_reqs = self.finished_recving_reqs
@@ -1197,6 +1243,93 @@ class MooncakeConnectorWorker:
self.receive_kv(remote_engine_id, pull_metas)
async def _start_direct_read(
self, reqs_by_engine: dict[EngineId, dict[ReqId, PullReqMeta]]
):
"""Direct RDMA read: D reads cached KV blocks from C's GPU memory
without involving C's scheduler.
"""
for _engine_id, pull_metas in reqs_by_engine.items():
for req_id, pm in pull_metas.items():
asyncio.create_task(
self._direct_read_single(req_id, pm)
)
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
pin_token = f"dr-{req_id}-{self.tp_rank}"
bootstrap_url = pm.remote_bootstrap_addr
try:
# 1. Query C's bootstrap for block mapping
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{bootstrap_url}/query_blocks",
json={
"block_hashes": [h.hex() for h in pm.block_hashes],
"pin_token": pin_token,
},
)
resp.raise_for_status()
mapping = resp.json()
remote_block_ids = [b for b in mapping["block_ids"] if b is not None]
num_matched = len(remote_block_ids)
if num_matched == 0:
logger.debug("direct_read %s: no cache hit on remote", req_id)
self.finished_recving_reqs.add(req_id)
return
# 2. Compute RDMA addresses
local_block_ids = pm.local_block_ids[:num_matched]
remote_base_addrs = mapping["kv_caches_base_addr"]
remote_block_len = mapping["block_len"]
remote_session = f"{mapping['hostname']}:{mapping['rpc_port']}"
src_ptrs: list[int] = []
dst_ptrs: list[int] = []
lengths: list[int] = []
for local_layer_addr, remote_layer_addr in zip(
self.kv_caches_base_addr, remote_base_addrs
):
grp_local, grp_remote = group_concurrent_contiguous(
local_block_ids, remote_block_ids
)
for gl, gr in zip(grp_local, grp_remote):
src_ptrs.append(remote_layer_addr + gr[0] * remote_block_len)
dst_ptrs.append(local_layer_addr + gl[0] * self.block_len)
lengths.append(remote_block_len * len(gr))
# 3. RDMA READ (D pulls from C's GPU memory)
logger.debug(
"direct_read %s: reading %d blocks from %s",
req_id, num_matched, remote_session,
)
ret = self.engine.batch_transfer_sync_read(
remote_session, src_ptrs, dst_ptrs, lengths,
)
if ret != 0:
logger.error("direct_read %s: RDMA read failed ret=%d", req_id, ret)
return
logger.debug("direct_read %s: success (%d blocks)", req_id, num_matched)
self.finished_recving_reqs.add(req_id)
except Exception as e:
logger.error("direct_read %s failed: %s", req_id, e)
finally:
# 4. Unpin blocks on C
try:
async with httpx.AsyncClient(timeout=10) as client:
await client.post(
f"{bootstrap_url}/unpin_blocks",
json={"pin_token": pin_token},
)
except Exception:
pass
async def _start_load_kv(
self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
):
@@ -1237,11 +1370,34 @@ class MooncakeConnectorWorker:
assert not send_meta.ready.is_set()
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
if not self.is_kv_producer and metadata.reqs_to_recv:
asyncio.run_coroutine_threadsafe(
self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop
# Sync hash table to bootstrap server (for direct RDMA read queries)
if self.bootstrap_server is not None and (
metadata.hash_table_updates or metadata.hash_table_removals
):
self.bootstrap_server.update_hash_table(
metadata.hash_table_updates, metadata.hash_table_removals
)
if not self.is_kv_producer and metadata.reqs_to_recv:
# Split direct_read vs normal pull requests
direct_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
normal_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
for engine_id, pull_metas in metadata.reqs_to_recv.items():
for req_id, pm in pull_metas.items():
if pm.direct_read:
direct_reqs[engine_id][req_id] = pm
else:
normal_reqs[engine_id][req_id] = pm
if normal_reqs:
asyncio.run_coroutine_threadsafe(
self._start_load_kv(normal_reqs), self.receiver_loop
)
if direct_reqs:
asyncio.run_coroutine_threadsafe(
self._start_direct_read(direct_reqs), self.receiver_loop
)
if not self.is_kv_consumer and (
metadata.reqs_to_send or metadata.reqs_not_processed
):

View File

@@ -32,10 +32,28 @@ class EngineEntry:
worker_addr: dict[int, dict[int, WorkerAddr]]
class QueryBlocksRequest(BaseModel):
block_hashes: list[str] # hex-encoded BlockHash values
pin_token: str
class QueryBlocksResponse(BaseModel):
block_ids: list[int | None] # None = cache miss (prefix match stops)
kv_caches_base_addr: list[int]
block_len: int
hostname: str
rpc_port: int
class UnpinBlocksRequest(BaseModel):
pin_token: str
class MooncakeBootstrapServer:
"""
A centralized server running on the global rank 0 prefiller worker.
Prefiller workers register their connection info (IP, port, ranks) here.
Also serves block mapping queries for direct RDMA read.
"""
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
@@ -48,13 +66,19 @@ class MooncakeBootstrapServer:
self.server_thread: threading.Thread | None = None
self.server: uvicorn.Server | None = None
# Direct RDMA read support: block hash → block_id mapping
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id
self._kv_info: dict | None = None # set by worker at register_kv_caches
self._pinned: dict[str, list[int]] = {} # pin_token → block_ids
def __del__(self):
self.shutdown()
def _register_routes(self):
# All methods are async. No need to use lock to protect data.
self.app.post("/register")(self.register_worker)
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
self.app.post("/query_blocks")(self.query_blocks)
self.app.post("/unpin_blocks")(self.unpin_blocks)
def start(self):
if self.server_thread:
@@ -125,3 +149,54 @@ class MooncakeBootstrapServer:
async def query(self) -> dict[int, EngineEntry]:
return self.workers
def set_worker_kv_info(
self,
kv_caches_base_addr: list[int],
block_len: int,
hostname: str,
rpc_port: int,
):
self._kv_info = {
"kv_caches_base_addr": kv_caches_base_addr,
"block_len": block_len,
"hostname": hostname,
"rpc_port": rpc_port,
}
def update_hash_table(
self,
updates: dict[str, int],
removals: set[str],
):
for k in removals:
self._hash_table.pop(k, None)
self._hash_table.update(updates)
async def query_blocks(self, req: QueryBlocksRequest):
if self._kv_info is None:
raise HTTPException(503, "Worker KV info not registered yet")
block_ids: list[int | None] = []
pinned_ids: list[int] = []
for h in req.block_hashes:
bid = self._hash_table.get(h)
if bid is not None:
block_ids.append(bid)
pinned_ids.append(bid)
else:
block_ids.append(None)
break # prefix match: stop at first miss
self._pinned[req.pin_token] = pinned_ids
return QueryBlocksResponse(
block_ids=block_ids,
kv_caches_base_addr=self._kv_info["kv_caches_base_addr"],
block_len=self._kv_info["block_len"],
hostname=self._kv_info["hostname"],
rpc_port=self._kv_info["rpc_port"],
)
async def unpin_blocks(self, req: UnpinBlocksRequest):
self._pinned.pop(req.pin_token, None)
return {"status": "ok"}

View File

@@ -234,6 +234,8 @@ class Scheduler(SchedulerInterface):
hash_block_size=self.block_size,
metrics_collector=self.kv_metrics_collector,
)
if self.connector is not None and hasattr(self.connector, "set_block_pool"):
self.connector.set_block_pool(self.kv_cache_manager.block_pool)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER