Fix hash mismatch: token-based lookup instead of cross-instance hash matching

Root cause: each vLLM instance has a random NONE_HASH (os.urandom(32))
when PYTHONHASHSEED is not set. All block hashes are chained from
NONE_HASH, so D's hashes never match C's hashes.

Fix: C's bootstrap server now accepts token_ids and does the prefix
cache lookup locally using C's own hash function and block pool.
No cross-instance hash matching needed.

New flow: D sends prompt token_ids → C computes hashes on C's side →
C looks up in C's own BlockPool → returns block_ids.

Also: module-level _shared_block_pool for scheduler→bootstrap bridge,
prompt_token_ids passed through PullReqMeta, test script added.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 01:14:33 +08:00
parent a1f30e5fce
commit 0500350849
3 changed files with 300 additions and 23 deletions

View File

@@ -65,6 +65,13 @@ TransferId = str # KV transfer coordination ID (shared by P/D)
logger = init_logger(__name__)
# Module-level block pool for bootstrap server access (kv_both same-process only)
_shared_block_pool = None
def _set_shared_block_pool(bp):
global _shared_block_pool
_shared_block_pool = bp
class MooncakeXferMetadata(
msgspec.Struct,
@@ -111,6 +118,8 @@ class PullReqMeta:
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
direct_read: bool = False
block_hashes: list[bytes] = field(default_factory=list)
prompt_token_ids: list[int] = field(default_factory=list)
remote_num_tokens: int = 0
@dataclass
@@ -141,10 +150,12 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
kv_transfer_params: dict[str, Any],
load_remote_cache: bool = True,
block_hashes: list[bytes] | None = None,
prompt_token_ids: list[int] | None = None,
):
transfer_id = kv_transfer_params["transfer_id"]
if load_remote_cache:
remote_engine_id = kv_transfer_params["remote_engine_id"]
remote_num = kv_transfer_params.get("remote_num_tokens", 0)
self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta(
d_req_id=request_id,
local_block_ids=local_block_ids,
@@ -153,6 +164,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
transfer_id=transfer_id,
direct_read=bool(kv_transfer_params.get("direct_read")),
block_hashes=block_hashes or [],
prompt_token_ids=prompt_token_ids or [],
remote_num_tokens=remote_num,
)
else:
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
@@ -183,6 +196,8 @@ class MooncakeConnector(KVConnectorBase_V1):
def set_block_pool(self, block_pool):
if self.connector_scheduler is not None:
self.connector_scheduler.set_block_pool(block_pool)
# Also store module-level for bootstrap server access (same process for kv_both TP=1)
_set_shared_block_pool(block_pool)
############################################################
# Scheduler Side Methods
@@ -276,8 +291,8 @@ class MooncakeConnectorScheduler:
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_not_processed: set[TransferId] = set()
# Per-request block hashes for direct_read passthrough
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
self._req_token_ids: dict[ReqId, list[int]] = {}
def set_block_pool(self, block_pool):
self._block_pool = block_pool
@@ -361,15 +376,20 @@ class MooncakeConnectorScheduler:
else:
local_block_ids = []
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
# Capture block hashes for direct_read
if params.get("direct_read") and hasattr(request, "block_hashes"):
if params.get("direct_read"):
block_size = self.vllm_config.cache_config.block_size
num_remote_blocks = (
(num_external_tokens + block_size - 1) // block_size
)
self._req_block_hashes[request.request_id] = [
bytes(h) for h in request.block_hashes[:num_remote_blocks]
]
if hasattr(request, "block_hashes"):
self._req_block_hashes[request.request_id] = [
bytes(h) for h in request.block_hashes[:num_remote_blocks]
]
# Store prompt token_ids for token-based lookup on C
if hasattr(request, "prompt_token_ids") and request.prompt_token_ids:
self._req_token_ids[request.request_id] = list(
request.prompt_token_ids[:num_remote_blocks * block_size]
)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
@@ -400,6 +420,7 @@ class MooncakeConnectorScheduler:
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
block_hashes=self._req_block_hashes.pop(req_id, None),
prompt_token_ids=self._req_token_ids.pop(req_id, None),
)
self._reqs_need_recv.clear()
@@ -1030,8 +1051,10 @@ class MooncakeConnectorWorker:
if self.bootstrap_server is not None:
self.bootstrap_server.set_worker_kv_info(
self.kv_caches_base_addr, self.block_len,
self.hostname, self.rpc_port,
self.block_size, self.hostname, self.rpc_port,
)
if _shared_block_pool is not None:
self.bootstrap_server.set_block_pool(_shared_block_pool)
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
finished_recving_reqs = self.finished_recving_reqs
@@ -1266,14 +1289,18 @@ class MooncakeConnectorWorker:
bootstrap_url = pm.remote_bootstrap_addr
try:
# 1. Query C's bootstrap for block mapping
# 1. Query C's bootstrap for block mapping (token-based lookup)
query_payload = {"pin_token": pin_token}
if pm.prompt_token_ids:
query_payload["token_ids"] = pm.prompt_token_ids
query_payload["num_tokens"] = pm.remote_num_tokens or len(pm.prompt_token_ids)
else:
query_payload["block_hashes"] = [h.hex() for h in pm.block_hashes]
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{bootstrap_url}/query_blocks",
json={
"block_hashes": [h.hex() for h in pm.block_hashes],
"pin_token": pin_token,
},
json=query_payload,
)
resp.raise_for_status()
mapping = resp.json()

View File

@@ -33,7 +33,9 @@ class EngineEntry:
class QueryBlocksRequest(BaseModel):
block_hashes: list[str] # hex-encoded BlockHash values
block_hashes: list[str] | None = None # hex-encoded BlockHash values (legacy)
token_ids: list[int] | None = None # raw token IDs for C-side hash lookup
num_tokens: int | None = None # number of tokens to match prefix for
pin_token: str
@@ -66,10 +68,11 @@ class MooncakeBootstrapServer:
self.server_thread: threading.Thread | None = None
self.server: uvicorn.Server | None = None
# Direct RDMA read support: block hash → block_id mapping
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id
# Direct RDMA read support
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id (legacy)
self._kv_info: dict | None = None # set by worker at register_kv_caches
self._pinned: dict[str, list[int]] = {} # pin_token → block_ids
self._block_pool = None # set by scheduler for token-based lookup
def __del__(self):
self.shutdown()
@@ -154,12 +157,14 @@ class MooncakeBootstrapServer:
self,
kv_caches_base_addr: list[int],
block_len: int,
block_size: int,
hostname: str,
rpc_port: int,
):
self._kv_info = {
"kv_caches_base_addr": kv_caches_base_addr,
"block_len": block_len,
"block_size": block_size,
"hostname": hostname,
"rpc_port": rpc_port,
}
@@ -173,20 +178,40 @@ class MooncakeBootstrapServer:
self._hash_table.pop(k, None)
self._hash_table.update(updates)
def set_block_pool(self, block_pool):
"""Store reference to scheduler's block pool for token-based lookup."""
self._block_pool = block_pool
async def query_blocks(self, req: QueryBlocksRequest):
if self._kv_info is None:
raise HTTPException(503, "Worker KV info not registered yet")
block_ids: list[int | None] = []
pinned_ids: list[int] = []
for h in req.block_hashes:
bid = self._hash_table.get(h)
if bid is not None:
block_ids.append(bid)
pinned_ids.append(bid)
else:
block_ids.append(None)
break # prefix match: stop at first miss
if req.token_ids is not None and self._block_pool is not None:
# Token-based lookup: compute hashes on C's side using C's hash function
block_ids, pinned_ids = self._lookup_by_tokens(
req.token_ids, req.num_tokens)
elif req.block_hashes is not None:
# Hash-based lookup (legacy)
for h in req.block_hashes:
bid = self._hash_table.get(h)
if bid is not None:
block_ids.append(bid)
pinned_ids.append(bid)
else:
block_ids.append(None)
break
logger.info(
"query_blocks: %d/%d matched (token_mode=%s, pool=%s)",
len(pinned_ids),
req.num_tokens // self._kv_info.get("block_size", 512)
if req.num_tokens else len(req.block_hashes or []),
req.token_ids is not None,
self._block_pool is not None,
)
self._pinned[req.pin_token] = pinned_ids
return QueryBlocksResponse(
@@ -197,6 +222,55 @@ class MooncakeBootstrapServer:
rpc_port=self._kv_info["rpc_port"],
)
def _lookup_by_tokens(
self, token_ids: list[int], num_tokens: int | None
) -> tuple[list[int | None], list[int]]:
"""Look up cached blocks by computing hashes on C's side."""
from vllm.v1.core.kv_cache_utils import (
hash_block_tokens,
make_block_hash_with_group_id,
NONE_HASH,
)
bp = self._block_pool
block_size = bp.block_size if hasattr(bp, 'block_size') else 512
# Use C's own hash function
hash_fn = bp._hash_fn if hasattr(bp, '_hash_fn') else None
if hash_fn is None:
# Fallback: try to get from coordinator
coord = bp.coordinator if hasattr(bp, 'coordinator') else None
if coord and hasattr(coord, '_caching_hash_fn'):
hash_fn = coord._caching_hash_fn
if hash_fn is None:
logger.warning("Cannot find hash function on block pool")
return [], []
n = num_tokens or len(token_ids)
n = min(n, len(token_ids))
num_blocks = n // block_size
block_ids: list[int | None] = []
pinned_ids: list[int] = []
prev_hash = NONE_HASH
for i in range(num_blocks):
block_tokens = tuple(token_ids[i * block_size:(i + 1) * block_size])
block_hash = hash_block_tokens(hash_fn, prev_hash, block_tokens, None)
prev_hash = block_hash
# Look up in C's block pool
key = make_block_hash_with_group_id(block_hash, 0)
block = bp.cached_block_hash_to_block.get_one_block(key)
if block is not None:
block_ids.append(block.block_id)
pinned_ids.append(block.block_id)
else:
block_ids.append(None)
break # prefix match stops
return block_ids, pinned_ids
async def unpin_blocks(self, req: UnpinBlocksRequest):
self._pinned.pop(req.pin_token, None)
return {"status": "ok"}