diff --git a/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py index 8f4638b..eda34a4 100644 --- a/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py +++ b/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py @@ -297,34 +297,26 @@ class MooncakeBootstrapServer: return {"status": "ok"} async def estimate_hit(self, req: EstimateHitRequest): - """Read-only probe: how many prefix-contiguous tokens are cached?""" + """Read-only probe: how many prefix-contiguous tokens are cached? + + Reuses _lookup_by_tokens (proven to work with push_blocks) instead + of reimplementing hash computation. + """ if self._kv_info is None: raise HTTPException(503, "Worker KV info not registered yet") - block_size = req.block_size or self._kv_info.get("block_size", 512) - n_tokens = len(req.token_ids) - num_blocks = n_tokens // block_size - if num_blocks == 0 or not self._hash_table: + if not self._hash_table: return EstimateHitResponse(hit_tokens=0) - import vllm.v1.core.kv_cache_utils as kv_utils - from vllm.utils.hashing import sha256 + block_ids, _ = self._lookup_by_tokens(req.token_ids, None) + hit_blocks = sum(1 for b in block_ids if b is not None) + block_size = self._kv_info.get("block_size", 512) + hit_tokens = hit_blocks * block_size - prev_hash = kv_utils.NONE_HASH - hit_blocks = 0 - for i in range(num_blocks): - block_tokens = tuple( - req.token_ids[i * block_size:(i + 1) * block_size]) - block_hash = kv_utils.hash_block_tokens( - sha256, prev_hash, block_tokens, None) - prev_hash = block_hash - - if self._hash_table.get(block_hash.hex()) is not None: - hit_blocks += 1 - else: - break - - return EstimateHitResponse(hit_tokens=hit_blocks * block_size) + logger.info("estimate_hit: %d/%d blocks hit (%d tokens, tbl=%d)", + hit_blocks, len(req.token_ids) // block_size, + hit_tokens, len(self._hash_table)) + return EstimateHitResponse(hit_tokens=hit_tokens) async def push_blocks(self, req: PushBlocksRequest): """Query matching blocks by token_ids, then PUSH them to D via RDMA write."""