Direct RDMA read: D reads cached KV from C's GPU without C's scheduler
Complete implementation of direct RDMA read for KV cache migration: vLLM Mooncake connector (mooncake_connector.py): - PullReqMeta: add direct_read flag + block_hashes - MooncakeConnectorMetadata: add hash_table_updates/removals for scheduler->worker block hash sync - MooncakeConnectorScheduler: set_block_pool() to access BlockPool, build_connector_meta() computes hash table deltas each step, update_state_after_alloc() captures request block hashes for direct_read - MooncakeConnectorWorker: _start_direct_read() + _direct_read_single() implements D-side RDMA read via batch_transfer_sync_read, with HTTP query/unpin to C's bootstrap server Bootstrap server (mooncake_utils.py): - POST /query_blocks: look up block hashes, return block_ids + GPU layout - POST /unpin_blocks: release pin tracking - set_worker_kv_info(): register GPU addresses at init - update_hash_table(): receive scheduler deltas each step Scheduler (scheduler.py): - One-line hookup: pass block_pool to connector after KVCacheManager init Proxy (cache_aware_proxy.py): - _handle_direct_read_offload: sends request ONLY to D with direct_read=True + remote_bootstrap_addr. No request to C at all. - C's scheduler is completely uninvolved (0 GPU time on C) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -108,6 +108,9 @@ class PullReqMeta:
|
||||
expire_time: float = float("inf")
|
||||
# Designed for one D pairing to multiple P
|
||||
pull_tasks_count: int = 0
|
||||
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
|
||||
direct_read: bool = False
|
||||
block_hashes: list[bytes] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -124,11 +127,12 @@ class SendBlockMeta:
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
# Use (engine_id, dp_rank) to group reqs with same dp.
|
||||
# See comments in MooncakeBootstrapServer.
|
||||
self.reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
self.reqs_to_send: dict[ReqId, tuple[TransferId, list[int]]] = {}
|
||||
self.reqs_not_processed: set[TransferId] = set()
|
||||
# Hash table sync: scheduler → worker (for direct RDMA read)
|
||||
self.hash_table_updates: dict[str, int] = {} # hex hash → block_id
|
||||
self.hash_table_removals: set[str] = set()
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
@@ -136,6 +140,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: dict[str, Any],
|
||||
load_remote_cache: bool = True,
|
||||
block_hashes: list[bytes] | None = None,
|
||||
):
|
||||
transfer_id = kv_transfer_params["transfer_id"]
|
||||
if load_remote_cache:
|
||||
@@ -146,6 +151,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
remote_engine_id=remote_engine_id,
|
||||
remote_bootstrap_addr=kv_transfer_params["remote_bootstrap_addr"],
|
||||
transfer_id=transfer_id,
|
||||
direct_read=bool(kv_transfer_params.get("direct_read")),
|
||||
block_hashes=block_hashes or [],
|
||||
)
|
||||
else:
|
||||
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
|
||||
@@ -173,6 +180,10 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
|
||||
|
||||
def set_block_pool(self, block_pool):
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.set_block_pool(block_pool)
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
@@ -250,6 +261,8 @@ class MooncakeConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
self.vllm_config = vllm_config
|
||||
self._block_pool = None
|
||||
self._known_hash_keys: set = set()
|
||||
|
||||
assert vllm_config.kv_transfer_config
|
||||
self.is_kv_producer: bool = (
|
||||
@@ -260,14 +273,14 @@ class MooncakeConnectorScheduler:
|
||||
)
|
||||
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
|
||||
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
# Reqs to remove from processed set because they're not to send after
|
||||
# remote prefill or aborted.
|
||||
self._reqs_not_processed: set[TransferId] = set()
|
||||
# Per-request block hashes for direct_read passthrough
|
||||
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
|
||||
|
||||
def set_block_pool(self, block_pool):
|
||||
self._block_pool = block_pool
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
@@ -348,13 +361,21 @@ class MooncakeConnectorScheduler:
|
||||
else:
|
||||
local_block_ids = []
|
||||
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
||||
# Capture block hashes for direct_read
|
||||
if params.get("direct_read") and hasattr(request, "block_hashes"):
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
num_remote_blocks = (
|
||||
(num_external_tokens + block_size - 1) // block_size
|
||||
)
|
||||
self._req_block_hashes[request.request_id] = [
|
||||
bytes(h) for h in request.block_hashes[:num_remote_blocks]
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer",
|
||||
params,
|
||||
)
|
||||
# Only trigger 1 KV transfer per request.
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
@@ -371,7 +392,6 @@ class MooncakeConnectorScheduler:
|
||||
) -> KVConnectorMetadata:
|
||||
meta = MooncakeConnectorMetadata()
|
||||
|
||||
# Loop through scheduled reqs and convert to PullReqMeta.
|
||||
if not self.is_kv_producer:
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
@@ -379,9 +399,30 @@ class MooncakeConnectorScheduler:
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
block_hashes=self._req_block_hashes.pop(req_id, None),
|
||||
)
|
||||
self._reqs_need_recv.clear()
|
||||
|
||||
# Sync hash table to worker for direct RDMA read block lookups
|
||||
if self._block_pool is not None:
|
||||
cache = self._block_pool.cached_block_hash_to_block._cache
|
||||
current_keys = set(cache.keys())
|
||||
new_keys = current_keys - self._known_hash_keys
|
||||
removed_keys = self._known_hash_keys - current_keys
|
||||
if new_keys or removed_keys:
|
||||
from vllm.v1.core.kv_cache_utils import get_block_hash
|
||||
for k in new_keys:
|
||||
block = cache[k]
|
||||
if isinstance(block, dict):
|
||||
bid = next(iter(block.values())).block_id
|
||||
else:
|
||||
bid = block.block_id
|
||||
meta.hash_table_updates[get_block_hash(k).hex()] = bid
|
||||
meta.hash_table_removals = {
|
||||
get_block_hash(k).hex() for k in removed_keys
|
||||
}
|
||||
self._known_hash_keys = current_keys.copy()
|
||||
|
||||
if not self.is_kv_consumer:
|
||||
for req_id, (req, block_ids) in self._reqs_need_send.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
@@ -971,7 +1012,6 @@ class MooncakeConnectorWorker:
|
||||
"registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
|
||||
)
|
||||
|
||||
# No need to launch server for D node.
|
||||
if self.is_kv_consumer:
|
||||
return
|
||||
|
||||
@@ -979,7 +1019,13 @@ class MooncakeConnectorWorker:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._mooncake_sender_listener(ready_event), self.sender_loop
|
||||
)
|
||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||
ready_event.wait()
|
||||
|
||||
if self.bootstrap_server is not None:
|
||||
self.bootstrap_server.set_worker_kv_info(
|
||||
self.kv_caches_base_addr, self.block_len,
|
||||
self.hostname, self.rpc_port,
|
||||
)
|
||||
|
||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
||||
finished_recving_reqs = self.finished_recving_reqs
|
||||
@@ -1197,6 +1243,93 @@ class MooncakeConnectorWorker:
|
||||
|
||||
self.receive_kv(remote_engine_id, pull_metas)
|
||||
|
||||
async def _start_direct_read(
|
||||
self, reqs_by_engine: dict[EngineId, dict[ReqId, PullReqMeta]]
|
||||
):
|
||||
"""Direct RDMA read: D reads cached KV blocks from C's GPU memory
|
||||
without involving C's scheduler.
|
||||
"""
|
||||
for _engine_id, pull_metas in reqs_by_engine.items():
|
||||
for req_id, pm in pull_metas.items():
|
||||
asyncio.create_task(
|
||||
self._direct_read_single(req_id, pm)
|
||||
)
|
||||
|
||||
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
|
||||
pin_token = f"dr-{req_id}-{self.tp_rank}"
|
||||
bootstrap_url = pm.remote_bootstrap_addr
|
||||
|
||||
try:
|
||||
# 1. Query C's bootstrap for block mapping
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.post(
|
||||
f"{bootstrap_url}/query_blocks",
|
||||
json={
|
||||
"block_hashes": [h.hex() for h in pm.block_hashes],
|
||||
"pin_token": pin_token,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
mapping = resp.json()
|
||||
|
||||
remote_block_ids = [b for b in mapping["block_ids"] if b is not None]
|
||||
num_matched = len(remote_block_ids)
|
||||
|
||||
if num_matched == 0:
|
||||
logger.debug("direct_read %s: no cache hit on remote", req_id)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
return
|
||||
|
||||
# 2. Compute RDMA addresses
|
||||
local_block_ids = pm.local_block_ids[:num_matched]
|
||||
remote_base_addrs = mapping["kv_caches_base_addr"]
|
||||
remote_block_len = mapping["block_len"]
|
||||
remote_session = f"{mapping['hostname']}:{mapping['rpc_port']}"
|
||||
|
||||
src_ptrs: list[int] = []
|
||||
dst_ptrs: list[int] = []
|
||||
lengths: list[int] = []
|
||||
|
||||
for local_layer_addr, remote_layer_addr in zip(
|
||||
self.kv_caches_base_addr, remote_base_addrs
|
||||
):
|
||||
grp_local, grp_remote = group_concurrent_contiguous(
|
||||
local_block_ids, remote_block_ids
|
||||
)
|
||||
for gl, gr in zip(grp_local, grp_remote):
|
||||
src_ptrs.append(remote_layer_addr + gr[0] * remote_block_len)
|
||||
dst_ptrs.append(local_layer_addr + gl[0] * self.block_len)
|
||||
lengths.append(remote_block_len * len(gr))
|
||||
|
||||
# 3. RDMA READ (D pulls from C's GPU memory)
|
||||
logger.debug(
|
||||
"direct_read %s: reading %d blocks from %s",
|
||||
req_id, num_matched, remote_session,
|
||||
)
|
||||
ret = self.engine.batch_transfer_sync_read(
|
||||
remote_session, src_ptrs, dst_ptrs, lengths,
|
||||
)
|
||||
|
||||
if ret != 0:
|
||||
logger.error("direct_read %s: RDMA read failed ret=%d", req_id, ret)
|
||||
return
|
||||
|
||||
logger.debug("direct_read %s: success (%d blocks)", req_id, num_matched)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("direct_read %s failed: %s", req_id, e)
|
||||
finally:
|
||||
# 4. Unpin blocks on C
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
await client.post(
|
||||
f"{bootstrap_url}/unpin_blocks",
|
||||
json={"pin_token": pin_token},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_load_kv(
|
||||
self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
|
||||
):
|
||||
@@ -1237,11 +1370,34 @@ class MooncakeConnectorWorker:
|
||||
assert not send_meta.ready.is_set()
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
if not self.is_kv_producer and metadata.reqs_to_recv:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._start_load_kv(metadata.reqs_to_recv), self.receiver_loop
|
||||
# Sync hash table to bootstrap server (for direct RDMA read queries)
|
||||
if self.bootstrap_server is not None and (
|
||||
metadata.hash_table_updates or metadata.hash_table_removals
|
||||
):
|
||||
self.bootstrap_server.update_hash_table(
|
||||
metadata.hash_table_updates, metadata.hash_table_removals
|
||||
)
|
||||
|
||||
if not self.is_kv_producer and metadata.reqs_to_recv:
|
||||
# Split direct_read vs normal pull requests
|
||||
direct_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
normal_reqs: dict[EngineId, dict[ReqId, PullReqMeta]] = defaultdict(dict)
|
||||
for engine_id, pull_metas in metadata.reqs_to_recv.items():
|
||||
for req_id, pm in pull_metas.items():
|
||||
if pm.direct_read:
|
||||
direct_reqs[engine_id][req_id] = pm
|
||||
else:
|
||||
normal_reqs[engine_id][req_id] = pm
|
||||
|
||||
if normal_reqs:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._start_load_kv(normal_reqs), self.receiver_loop
|
||||
)
|
||||
if direct_reqs:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._start_direct_read(direct_reqs), self.receiver_loop
|
||||
)
|
||||
|
||||
if not self.is_kv_consumer and (
|
||||
metadata.reqs_to_send or metadata.reqs_not_processed
|
||||
):
|
||||
|
||||
@@ -32,10 +32,28 @@ class EngineEntry:
|
||||
worker_addr: dict[int, dict[int, WorkerAddr]]
|
||||
|
||||
|
||||
class QueryBlocksRequest(BaseModel):
|
||||
block_hashes: list[str] # hex-encoded BlockHash values
|
||||
pin_token: str
|
||||
|
||||
|
||||
class QueryBlocksResponse(BaseModel):
|
||||
block_ids: list[int | None] # None = cache miss (prefix match stops)
|
||||
kv_caches_base_addr: list[int]
|
||||
block_len: int
|
||||
hostname: str
|
||||
rpc_port: int
|
||||
|
||||
|
||||
class UnpinBlocksRequest(BaseModel):
|
||||
pin_token: str
|
||||
|
||||
|
||||
class MooncakeBootstrapServer:
|
||||
"""
|
||||
A centralized server running on the global rank 0 prefiller worker.
|
||||
Prefiller workers register their connection info (IP, port, ranks) here.
|
||||
Also serves block mapping queries for direct RDMA read.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, host: str, port: int):
|
||||
@@ -48,13 +66,19 @@ class MooncakeBootstrapServer:
|
||||
self.server_thread: threading.Thread | None = None
|
||||
self.server: uvicorn.Server | None = None
|
||||
|
||||
# Direct RDMA read support: block hash → block_id mapping
|
||||
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id
|
||||
self._kv_info: dict | None = None # set by worker at register_kv_caches
|
||||
self._pinned: dict[str, list[int]] = {} # pin_token → block_ids
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def _register_routes(self):
|
||||
# All methods are async. No need to use lock to protect data.
|
||||
self.app.post("/register")(self.register_worker)
|
||||
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
|
||||
self.app.post("/query_blocks")(self.query_blocks)
|
||||
self.app.post("/unpin_blocks")(self.unpin_blocks)
|
||||
|
||||
def start(self):
|
||||
if self.server_thread:
|
||||
@@ -125,3 +149,54 @@ class MooncakeBootstrapServer:
|
||||
|
||||
async def query(self) -> dict[int, EngineEntry]:
|
||||
return self.workers
|
||||
|
||||
def set_worker_kv_info(
|
||||
self,
|
||||
kv_caches_base_addr: list[int],
|
||||
block_len: int,
|
||||
hostname: str,
|
||||
rpc_port: int,
|
||||
):
|
||||
self._kv_info = {
|
||||
"kv_caches_base_addr": kv_caches_base_addr,
|
||||
"block_len": block_len,
|
||||
"hostname": hostname,
|
||||
"rpc_port": rpc_port,
|
||||
}
|
||||
|
||||
def update_hash_table(
|
||||
self,
|
||||
updates: dict[str, int],
|
||||
removals: set[str],
|
||||
):
|
||||
for k in removals:
|
||||
self._hash_table.pop(k, None)
|
||||
self._hash_table.update(updates)
|
||||
|
||||
async def query_blocks(self, req: QueryBlocksRequest):
|
||||
if self._kv_info is None:
|
||||
raise HTTPException(503, "Worker KV info not registered yet")
|
||||
|
||||
block_ids: list[int | None] = []
|
||||
pinned_ids: list[int] = []
|
||||
for h in req.block_hashes:
|
||||
bid = self._hash_table.get(h)
|
||||
if bid is not None:
|
||||
block_ids.append(bid)
|
||||
pinned_ids.append(bid)
|
||||
else:
|
||||
block_ids.append(None)
|
||||
break # prefix match: stop at first miss
|
||||
|
||||
self._pinned[req.pin_token] = pinned_ids
|
||||
return QueryBlocksResponse(
|
||||
block_ids=block_ids,
|
||||
kv_caches_base_addr=self._kv_info["kv_caches_base_addr"],
|
||||
block_len=self._kv_info["block_len"],
|
||||
hostname=self._kv_info["hostname"],
|
||||
rpc_port=self._kv_info["rpc_port"],
|
||||
)
|
||||
|
||||
async def unpin_blocks(self, req: UnpinBlocksRequest):
|
||||
self._pinned.pop(req.pin_token, None)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -234,6 +234,8 @@ class Scheduler(SchedulerInterface):
|
||||
hash_block_size=self.block_size,
|
||||
metrics_collector=self.kv_metrics_collector,
|
||||
)
|
||||
if self.connector is not None and hasattr(self.connector, "set_block_pool"):
|
||||
self.connector.set_block_pool(self.kv_cache_manager.block_pool)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
|
||||
Reference in New Issue
Block a user