Fix: use synced hash table + sha256_cbor for token-based lookup (same process NONE_HASH)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-24 01:18:47 +08:00
parent 0500350849
commit 0c88609caa

View File

@@ -225,27 +225,19 @@ class MooncakeBootstrapServer:
def _lookup_by_tokens(
self, token_ids: list[int], num_tokens: int | None
) -> tuple[list[int | None], list[int]]:
"""Look up cached blocks by computing hashes on C's side."""
"""Look up cached blocks by computing hashes using the synced hash table.
Computes block hashes from token_ids using the same hash function
as the scheduler, then looks up in the synced _hash_table.
"""
from vllm.v1.core.kv_cache_utils import (
BlockHash,
hash_block_tokens,
make_block_hash_with_group_id,
NONE_HASH,
)
from vllm.utils.hashing import sha256_cbor
bp = self._block_pool
block_size = bp.block_size if hasattr(bp, 'block_size') else 512
# Use C's own hash function
hash_fn = bp._hash_fn if hasattr(bp, '_hash_fn') else None
if hash_fn is None:
# Fallback: try to get from coordinator
coord = bp.coordinator if hasattr(bp, 'coordinator') else None
if coord and hasattr(coord, '_caching_hash_fn'):
hash_fn = coord._caching_hash_fn
if hash_fn is None:
logger.warning("Cannot find hash function on block pool")
return [], []
block_size = self._kv_info.get("block_size", 512) if self._kv_info else 512
n = num_tokens or len(token_ids)
n = min(n, len(token_ids))
num_blocks = n // block_size
@@ -256,18 +248,16 @@ class MooncakeBootstrapServer:
for i in range(num_blocks):
block_tokens = tuple(token_ids[i * block_size:(i + 1) * block_size])
block_hash = hash_block_tokens(hash_fn, prev_hash, block_tokens, None)
block_hash = hash_block_tokens(sha256_cbor, prev_hash, block_tokens, None)
prev_hash = block_hash
# Look up in C's block pool
key = make_block_hash_with_group_id(block_hash, 0)
block = bp.cached_block_hash_to_block.get_one_block(key)
if block is not None:
block_ids.append(block.block_id)
pinned_ids.append(block.block_id)
bid = self._hash_table.get(block_hash.hex())
if bid is not None:
block_ids.append(bid)
pinned_ids.append(bid)
else:
block_ids.append(None)
break # prefix match stops
break
return block_ids, pinned_ids