Switch from RDMA READ to bootstrap-triggered PUSH

RDMA READ (batch_transfer_sync_read) fails on GPU memory because
batch_register_memory only sets IBV_ACCESS_REMOTE_WRITE.

New approach: D sends /push_blocks to C's bootstrap with token_ids
+ D's GPU addresses. C's bootstrap:
1. Looks up matching blocks in synced hash table (640/640 verified)
2. Uses C's TransferEngine.batch_transfer_sync_write to PUSH blocks
   directly into D's GPU memory
3. Returns match count + push status

C's scheduler is still NOT involved (0 GPU compute on C).
The push uses C's worker thread + existing RDMA WRITE path (proven reliable).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 01:47:49 +08:00
parent 6716a3401a
commit e3a1d70cf2
2 changed files with 84 additions and 64 deletions

View File

@@ -1053,6 +1053,7 @@ class MooncakeConnectorWorker:
self.bootstrap_server.set_worker_kv_info(
self.kv_caches_base_addr, self.block_len,
self.block_size, self.hostname, self.rpc_port,
transfer_engine=self.engine,
)
if _shared_block_pool is not None:
self.bootstrap_server.set_block_pool(_shared_block_pool)
@@ -1286,83 +1287,47 @@ class MooncakeConnectorWorker:
)
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
pin_token = f"dr-{req_id}-{self.tp_rank}"
"""Bootstrap-triggered PUSH: D asks C's bootstrap to push matched blocks.
C's bootstrap looks up cached blocks by token_ids, then uses C's
TransferEngine to RDMA WRITE (push) them directly into D's GPU memory.
C's scheduler is NOT involved.
"""
bootstrap_url = pm.remote_bootstrap_addr
num_remote_tokens = pm.remote_num_tokens or len(pm.prompt_token_ids)
try:
# 1. Query C's bootstrap for block mapping (token-based lookup)
query_payload = {"pin_token": pin_token}
if pm.prompt_token_ids:
query_payload["token_ids"] = pm.prompt_token_ids
query_payload["num_tokens"] = pm.remote_num_tokens or len(pm.prompt_token_ids)
else:
query_payload["block_hashes"] = [h.hex() for h in pm.block_hashes]
local_block_ids = pm.local_block_ids
d_session = f"{self.hostname}:{self.rpc_port}"
async with httpx.AsyncClient(timeout=30) as client:
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(
f"{bootstrap_url}/query_blocks",
json=query_payload,
f"{bootstrap_url}/push_blocks",
json={
"token_ids": pm.prompt_token_ids,
"num_tokens": num_remote_tokens,
"dst_block_ids": local_block_ids,
"dst_base_addrs": self.kv_caches_base_addr,
"dst_block_len": self.block_len,
"dst_session": d_session,
},
)
resp.raise_for_status()
mapping = resp.json()
result = resp.json()
remote_block_ids = [b for b in mapping["block_ids"] if b is not None]
num_matched = len(remote_block_ids)
matched = result.get("matched", 0)
pushed = result.get("pushed", False)
if num_matched == 0:
logger.debug("direct_read %s: no cache hit on remote", req_id)
self.finished_recving_reqs.add(req_id)
return
if matched > 0 and pushed:
logger.info("direct_push %s: %d blocks pushed from C", req_id, matched)
else:
logger.debug("direct_push %s: %d matched, pushed=%s", req_id, matched, pushed)
# 2. Compute RDMA addresses
local_block_ids = pm.local_block_ids[:num_matched]
remote_base_addrs = mapping["kv_caches_base_addr"]
remote_block_len = mapping["block_len"]
remote_session = f"{mapping['hostname']}:{mapping['rpc_port']}"
src_ptrs: list[int] = []
dst_ptrs: list[int] = []
lengths: list[int] = []
for local_layer_addr, remote_layer_addr in zip(
self.kv_caches_base_addr, remote_base_addrs
):
grp_local, grp_remote = group_concurrent_contiguous(
local_block_ids, remote_block_ids
)
for gl, gr in zip(grp_local, grp_remote):
src_ptrs.append(remote_layer_addr + gr[0] * remote_block_len)
dst_ptrs.append(local_layer_addr + gl[0] * self.block_len)
lengths.append(remote_block_len * len(gr))
# 3. RDMA READ (D pulls from C's GPU memory)
logger.debug(
"direct_read %s: reading %d blocks from %s",
req_id, num_matched, remote_session,
)
ret = self.engine.batch_transfer_sync_read(
remote_session, src_ptrs, dst_ptrs, lengths,
)
if ret != 0:
logger.error("direct_read %s: RDMA read failed ret=%d", req_id, ret)
return
logger.debug("direct_read %s: success (%d blocks)", req_id, num_matched)
self.finished_recving_reqs.add(req_id)
except Exception as e:
logger.error("direct_read %s failed: %s", req_id, e)
finally:
# 4. Unpin blocks on C
try:
async with httpx.AsyncClient(timeout=10) as client:
await client.post(
f"{bootstrap_url}/unpin_blocks",
json={"pin_token": pin_token},
)
except Exception:
pass
logger.error("direct_push %s failed: %s", req_id, e)
self.finished_recving_reqs.add(req_id)
async def _start_load_kv(
self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]

View File

@@ -47,6 +47,20 @@ class QueryBlocksResponse(BaseModel):
rpc_port: int
class PushBlocksRequest(BaseModel):
token_ids: list[int]
num_tokens: int
dst_block_ids: list[int] # D's allocated block IDs for receiving
dst_base_addrs: list[int] # D's kv_caches_base_addr
dst_block_len: int # D's block_len
dst_session: str # D's "hostname:rpc_port" for RDMA write
class PushBlocksResponse(BaseModel):
matched: int
pushed: bool
class UnpinBlocksRequest(BaseModel):
pin_token: str
@@ -83,6 +97,7 @@ class MooncakeBootstrapServer:
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
self.app.post("/query_blocks")(self.query_blocks)
self.app.post("/unpin_blocks")(self.unpin_blocks)
self.app.post("/push_blocks")(self.push_blocks)
def start(self):
if self.server_thread:
@@ -161,6 +176,7 @@ class MooncakeBootstrapServer:
block_size: int,
hostname: str,
rpc_port: int,
transfer_engine=None,
):
self._kv_info = {
"kv_caches_base_addr": kv_caches_base_addr,
@@ -169,6 +185,7 @@ class MooncakeBootstrapServer:
"hostname": hostname,
"rpc_port": rpc_port,
}
self._transfer_engine = transfer_engine
def update_hash_table(
self,
@@ -268,3 +285,41 @@ class MooncakeBootstrapServer:
async def unpin_blocks(self, req: UnpinBlocksRequest):
self._pinned.pop(req.pin_token, None)
return {"status": "ok"}
async def push_blocks(self, req: PushBlocksRequest):
"""Query matching blocks by token_ids, then PUSH them to D via RDMA write."""
if self._kv_info is None or self._transfer_engine is None:
raise HTTPException(503, "Worker not ready")
block_ids, _ = self._lookup_by_tokens(req.token_ids, req.num_tokens)
matched_src = [b for b in block_ids if b is not None]
num_matched = len(matched_src)
if num_matched == 0:
logger.info("push_blocks: 0 matched")
return PushBlocksResponse(matched=0, pushed=False)
matched_dst = req.dst_block_ids[:num_matched]
src_base = self._kv_info["kv_caches_base_addr"]
src_block_len = self._kv_info["block_len"]
src_ptrs: list[int] = []
dst_ptrs: list[int] = []
lengths: list[int] = []
for src_layer, dst_layer in zip(src_base, req.dst_base_addrs):
for s_bid, d_bid in zip(matched_src, matched_dst):
src_ptrs.append(src_layer + s_bid * src_block_len)
dst_ptrs.append(dst_layer + d_bid * req.dst_block_len)
lengths.append(src_block_len)
import asyncio
loop = asyncio.get_event_loop()
ret = await loop.run_in_executor(
None,
self._transfer_engine.batch_transfer_sync_write,
req.dst_session, src_ptrs, dst_ptrs, lengths,
)
logger.info("push_blocks: %d matched, push ret=%d", num_matched, ret)
return PushBlocksResponse(matched=num_matched, pushed=(ret == 0))