Switch from RDMA READ to bootstrap-triggered PUSH
RDMA READ (batch_transfer_sync_read) fails on GPU memory because batch_register_memory only sets IBV_ACCESS_REMOTE_WRITE. New approach: D sends /push_blocks to C's bootstrap with token_ids + D's GPU addresses. C's bootstrap: 1. Looks up matching blocks in synced hash table (640/640 verified) 2. Uses C's TransferEngine.batch_transfer_sync_write to PUSH blocks directly into D's GPU memory 3. Returns match count + push status C's scheduler is still NOT involved (0 GPU compute on C). The push uses C's worker thread + existing RDMA WRITE path (proven reliable). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1053,6 +1053,7 @@ class MooncakeConnectorWorker:
|
||||
self.bootstrap_server.set_worker_kv_info(
|
||||
self.kv_caches_base_addr, self.block_len,
|
||||
self.block_size, self.hostname, self.rpc_port,
|
||||
transfer_engine=self.engine,
|
||||
)
|
||||
if _shared_block_pool is not None:
|
||||
self.bootstrap_server.set_block_pool(_shared_block_pool)
|
||||
@@ -1286,83 +1287,47 @@ class MooncakeConnectorWorker:
|
||||
)
|
||||
|
||||
async def _direct_read_single(self, req_id: ReqId, pm: PullReqMeta):
|
||||
pin_token = f"dr-{req_id}-{self.tp_rank}"
|
||||
"""Bootstrap-triggered PUSH: D asks C's bootstrap to push matched blocks.
|
||||
|
||||
C's bootstrap looks up cached blocks by token_ids, then uses C's
|
||||
TransferEngine to RDMA WRITE (push) them directly into D's GPU memory.
|
||||
C's scheduler is NOT involved.
|
||||
"""
|
||||
bootstrap_url = pm.remote_bootstrap_addr
|
||||
num_remote_tokens = pm.remote_num_tokens or len(pm.prompt_token_ids)
|
||||
|
||||
try:
|
||||
# 1. Query C's bootstrap for block mapping (token-based lookup)
|
||||
query_payload = {"pin_token": pin_token}
|
||||
if pm.prompt_token_ids:
|
||||
query_payload["token_ids"] = pm.prompt_token_ids
|
||||
query_payload["num_tokens"] = pm.remote_num_tokens or len(pm.prompt_token_ids)
|
||||
else:
|
||||
query_payload["block_hashes"] = [h.hex() for h in pm.block_hashes]
|
||||
local_block_ids = pm.local_block_ids
|
||||
d_session = f"{self.hostname}:{self.rpc_port}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
resp = await client.post(
|
||||
f"{bootstrap_url}/query_blocks",
|
||||
json=query_payload,
|
||||
f"{bootstrap_url}/push_blocks",
|
||||
json={
|
||||
"token_ids": pm.prompt_token_ids,
|
||||
"num_tokens": num_remote_tokens,
|
||||
"dst_block_ids": local_block_ids,
|
||||
"dst_base_addrs": self.kv_caches_base_addr,
|
||||
"dst_block_len": self.block_len,
|
||||
"dst_session": d_session,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
mapping = resp.json()
|
||||
result = resp.json()
|
||||
|
||||
remote_block_ids = [b for b in mapping["block_ids"] if b is not None]
|
||||
num_matched = len(remote_block_ids)
|
||||
matched = result.get("matched", 0)
|
||||
pushed = result.get("pushed", False)
|
||||
|
||||
if num_matched == 0:
|
||||
logger.debug("direct_read %s: no cache hit on remote", req_id)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
return
|
||||
if matched > 0 and pushed:
|
||||
logger.info("direct_push %s: %d blocks pushed from C", req_id, matched)
|
||||
else:
|
||||
logger.debug("direct_push %s: %d matched, pushed=%s", req_id, matched, pushed)
|
||||
|
||||
# 2. Compute RDMA addresses
|
||||
local_block_ids = pm.local_block_ids[:num_matched]
|
||||
remote_base_addrs = mapping["kv_caches_base_addr"]
|
||||
remote_block_len = mapping["block_len"]
|
||||
remote_session = f"{mapping['hostname']}:{mapping['rpc_port']}"
|
||||
|
||||
src_ptrs: list[int] = []
|
||||
dst_ptrs: list[int] = []
|
||||
lengths: list[int] = []
|
||||
|
||||
for local_layer_addr, remote_layer_addr in zip(
|
||||
self.kv_caches_base_addr, remote_base_addrs
|
||||
):
|
||||
grp_local, grp_remote = group_concurrent_contiguous(
|
||||
local_block_ids, remote_block_ids
|
||||
)
|
||||
for gl, gr in zip(grp_local, grp_remote):
|
||||
src_ptrs.append(remote_layer_addr + gr[0] * remote_block_len)
|
||||
dst_ptrs.append(local_layer_addr + gl[0] * self.block_len)
|
||||
lengths.append(remote_block_len * len(gr))
|
||||
|
||||
# 3. RDMA READ (D pulls from C's GPU memory)
|
||||
logger.debug(
|
||||
"direct_read %s: reading %d blocks from %s",
|
||||
req_id, num_matched, remote_session,
|
||||
)
|
||||
ret = self.engine.batch_transfer_sync_read(
|
||||
remote_session, src_ptrs, dst_ptrs, lengths,
|
||||
)
|
||||
|
||||
if ret != 0:
|
||||
logger.error("direct_read %s: RDMA read failed ret=%d", req_id, ret)
|
||||
return
|
||||
|
||||
logger.debug("direct_read %s: success (%d blocks)", req_id, num_matched)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("direct_read %s failed: %s", req_id, e)
|
||||
finally:
|
||||
# 4. Unpin blocks on C
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
await client.post(
|
||||
f"{bootstrap_url}/unpin_blocks",
|
||||
json={"pin_token": pin_token},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error("direct_push %s failed: %s", req_id, e)
|
||||
self.finished_recving_reqs.add(req_id)
|
||||
|
||||
async def _start_load_kv(
|
||||
self, reqs_to_recv: dict[EngineId, dict[ReqId, PullReqMeta]]
|
||||
|
||||
@@ -47,6 +47,20 @@ class QueryBlocksResponse(BaseModel):
|
||||
rpc_port: int
|
||||
|
||||
|
||||
class PushBlocksRequest(BaseModel):
|
||||
token_ids: list[int]
|
||||
num_tokens: int
|
||||
dst_block_ids: list[int] # D's allocated block IDs for receiving
|
||||
dst_base_addrs: list[int] # D's kv_caches_base_addr
|
||||
dst_block_len: int # D's block_len
|
||||
dst_session: str # D's "hostname:rpc_port" for RDMA write
|
||||
|
||||
|
||||
class PushBlocksResponse(BaseModel):
|
||||
matched: int
|
||||
pushed: bool
|
||||
|
||||
|
||||
class UnpinBlocksRequest(BaseModel):
|
||||
pin_token: str
|
||||
|
||||
@@ -83,6 +97,7 @@ class MooncakeBootstrapServer:
|
||||
self.app.get("/query", response_model=dict[int, EngineEntry])(self.query)
|
||||
self.app.post("/query_blocks")(self.query_blocks)
|
||||
self.app.post("/unpin_blocks")(self.unpin_blocks)
|
||||
self.app.post("/push_blocks")(self.push_blocks)
|
||||
|
||||
def start(self):
|
||||
if self.server_thread:
|
||||
@@ -161,6 +176,7 @@ class MooncakeBootstrapServer:
|
||||
block_size: int,
|
||||
hostname: str,
|
||||
rpc_port: int,
|
||||
transfer_engine=None,
|
||||
):
|
||||
self._kv_info = {
|
||||
"kv_caches_base_addr": kv_caches_base_addr,
|
||||
@@ -169,6 +185,7 @@ class MooncakeBootstrapServer:
|
||||
"hostname": hostname,
|
||||
"rpc_port": rpc_port,
|
||||
}
|
||||
self._transfer_engine = transfer_engine
|
||||
|
||||
def update_hash_table(
|
||||
self,
|
||||
@@ -268,3 +285,41 @@ class MooncakeBootstrapServer:
|
||||
async def unpin_blocks(self, req: UnpinBlocksRequest):
|
||||
self._pinned.pop(req.pin_token, None)
|
||||
return {"status": "ok"}
|
||||
|
||||
async def push_blocks(self, req: PushBlocksRequest):
|
||||
"""Query matching blocks by token_ids, then PUSH them to D via RDMA write."""
|
||||
if self._kv_info is None or self._transfer_engine is None:
|
||||
raise HTTPException(503, "Worker not ready")
|
||||
|
||||
block_ids, _ = self._lookup_by_tokens(req.token_ids, req.num_tokens)
|
||||
matched_src = [b for b in block_ids if b is not None]
|
||||
num_matched = len(matched_src)
|
||||
|
||||
if num_matched == 0:
|
||||
logger.info("push_blocks: 0 matched")
|
||||
return PushBlocksResponse(matched=0, pushed=False)
|
||||
|
||||
matched_dst = req.dst_block_ids[:num_matched]
|
||||
src_base = self._kv_info["kv_caches_base_addr"]
|
||||
src_block_len = self._kv_info["block_len"]
|
||||
|
||||
src_ptrs: list[int] = []
|
||||
dst_ptrs: list[int] = []
|
||||
lengths: list[int] = []
|
||||
|
||||
for src_layer, dst_layer in zip(src_base, req.dst_base_addrs):
|
||||
for s_bid, d_bid in zip(matched_src, matched_dst):
|
||||
src_ptrs.append(src_layer + s_bid * src_block_len)
|
||||
dst_ptrs.append(dst_layer + d_bid * req.dst_block_len)
|
||||
lengths.append(src_block_len)
|
||||
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
ret = await loop.run_in_executor(
|
||||
None,
|
||||
self._transfer_engine.batch_transfer_sync_write,
|
||||
req.dst_session, src_ptrs, dst_ptrs, lengths,
|
||||
)
|
||||
|
||||
logger.info("push_blocks: %d matched, push ret=%d", num_matched, ret)
|
||||
return PushBlocksResponse(matched=num_matched, pushed=(ret == 0))
|
||||
|
||||
Reference in New Issue
Block a user