Fix hash mismatch: token-based lookup instead of cross-instance hash matching
Root cause: each vLLM instance has a random NONE_HASH (os.urandom(32)) when PYTHONHASHSEED is not set. All block hashes are chained from NONE_HASH, so D's hashes never match C's hashes. Fix: C's bootstrap server now accepts token_ids and does the prefix cache lookup locally using C's own hash function and block pool. No cross-instance hash matching needed. New flow: D sends prompt token_ids → C computes hashes on C's side → C looks up in C's own BlockPool → returns block_ids. Also: module-level _shared_block_pool for scheduler→bootstrap bridge, prompt_token_ids passed through PullReqMeta, test script added. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
176
scripts/test_direct_read.py
Normal file
176
scripts/test_direct_read.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""Minimal test: verify direct RDMA read hash matching.
|
||||||
|
|
||||||
|
1. Send a multi-turn session to inst_0 (builds cache)
|
||||||
|
2. Query inst_0's bootstrap /query_blocks with computed block hashes
|
||||||
|
3. Check if hashes match (the core question)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Start 2 elastic instances first, then:
|
||||||
|
python scripts/test_direct_read.py --port0 8000 --bp0 8998 --port1 8001 --bp1 8999
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
VOCAB_SIZE = 151936
|
||||||
|
TOKEN_RANGE_START = 100
|
||||||
|
TOKEN_RANGE_END = VOCAB_SIZE - 100
|
||||||
|
|
||||||
|
|
||||||
|
def make_prompt(seed: int, n_blocks: int) -> list[int]:
|
||||||
|
"""Deterministic prompt from seed, like the replayer does."""
|
||||||
|
rng = random.Random(seed)
|
||||||
|
return [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)
|
||||||
|
for _ in range(BLOCK_SIZE * n_blocks)]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--port0", type=int, default=8000)
|
||||||
|
p.add_argument("--bp0", type=int, default=8998)
|
||||||
|
p.add_argument("--port1", type=int, default=8001)
|
||||||
|
p.add_argument("--bp1", type=int, default=8999)
|
||||||
|
p.add_argument("--model", type=str,
|
||||||
|
default="/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
client = httpx.Client(timeout=120)
|
||||||
|
base0 = f"http://127.0.0.1:{args.port0}"
|
||||||
|
base1 = f"http://127.0.0.1:{args.port1}"
|
||||||
|
bp0 = f"http://127.0.0.1:{args.bp0}"
|
||||||
|
bp1 = f"http://127.0.0.1:{args.bp1}"
|
||||||
|
|
||||||
|
# Step 1: Send request to inst_0 to build cache
|
||||||
|
prompt = make_prompt(seed=42, n_blocks=20) # 10240 tokens
|
||||||
|
print(f"[1] Sending {len(prompt)} tokens to inst_0...")
|
||||||
|
|
||||||
|
resp = client.post(f"{base0}/v1/completions", json={
|
||||||
|
"model": args.model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": 1,
|
||||||
|
"temperature": 0,
|
||||||
|
})
|
||||||
|
resp.raise_for_status()
|
||||||
|
print(f" OK: {resp.json()['choices'][0]['text'][:20]}...")
|
||||||
|
|
||||||
|
# Wait for hash table sync (happens in scheduler step)
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
|
# Step 2: Query inst_0's bootstrap for its hash table size
|
||||||
|
print(f"\n[2] Querying inst_0 bootstrap /query endpoint...")
|
||||||
|
resp = client.get(f"{bp0}/query")
|
||||||
|
resp.raise_for_status()
|
||||||
|
query_data = resp.json()
|
||||||
|
print(f" Bootstrap has {len(query_data)} dp_rank entries")
|
||||||
|
|
||||||
|
# Step 3: Compute block hashes the way D would
|
||||||
|
# D's scheduler uses request.block_hashes which is computed by
|
||||||
|
# vLLM's block hasher. We can't easily replicate that here.
|
||||||
|
# Instead, let's send the SAME prompt to inst_1 with direct_read=True
|
||||||
|
# and see what happens.
|
||||||
|
|
||||||
|
# First, let's directly test the /query_blocks endpoint
|
||||||
|
# with some known hashes. We need to know what hashes inst_0 has.
|
||||||
|
|
||||||
|
# Try querying with dummy hashes to see the response format
|
||||||
|
print(f"\n[3] Testing /query_blocks with dummy hashes...")
|
||||||
|
resp = client.post(f"{bp0}/query_blocks", json={
|
||||||
|
"block_hashes": ["0000000000000000"],
|
||||||
|
"pin_token": "test-1",
|
||||||
|
})
|
||||||
|
resp.raise_for_status()
|
||||||
|
result = resp.json()
|
||||||
|
print(f" Response: {json.dumps(result, indent=2)}")
|
||||||
|
|
||||||
|
# Unpin
|
||||||
|
client.post(f"{bp0}/unpin_blocks", json={"pin_token": "test-1"})
|
||||||
|
|
||||||
|
# Step 4: Send same prompt to inst_1 with do_remote_prefill + direct_read
|
||||||
|
# This triggers D's scheduler to compute block_hashes and the worker
|
||||||
|
# to query C's bootstrap
|
||||||
|
print(f"\n[4] Sending same prompt to inst_1 with direct_read...")
|
||||||
|
|
||||||
|
# Get inst_0's engine_id from bootstrap
|
||||||
|
engine_id = query_data.get("0", {}).get("engine_id", "")
|
||||||
|
print(f" inst_0 engine_id: {engine_id}")
|
||||||
|
|
||||||
|
resp = client.post(f"{base1}/v1/completions", json={
|
||||||
|
"model": args.model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": 1,
|
||||||
|
"temperature": 0,
|
||||||
|
"kv_transfer_params": {
|
||||||
|
"do_remote_decode": False,
|
||||||
|
"do_remote_prefill": True,
|
||||||
|
"direct_read": True,
|
||||||
|
"remote_bootstrap_addr": bp0,
|
||||||
|
"remote_engine_id": engine_id,
|
||||||
|
"transfer_id": "test-xfer-001",
|
||||||
|
"remote_num_tokens": len(prompt),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
print(f" Status: {resp.status_code}")
|
||||||
|
if resp.status_code == 200:
|
||||||
|
print(f" Output: {resp.json()['choices'][0]['text'][:50]}...")
|
||||||
|
else:
|
||||||
|
print(f" Error: {resp.text[:200]}")
|
||||||
|
|
||||||
|
# Step 5: Check logs for hash matching
|
||||||
|
print(f"\n[5] Check vLLM logs for direct_read activity:")
|
||||||
|
print(f" grep 'direct_read\\|query_blocks\\|hash_table_sync\\|no cache hit' inst_*.log")
|
||||||
|
|
||||||
|
# Step 6: Send turn 2 (extended prompt) to verify prefix caching
|
||||||
|
prompt2 = prompt + make_prompt(seed=43, n_blocks=5) # extend by 2560 tokens
|
||||||
|
print(f"\n[6] Sending turn 2 ({len(prompt2)} tokens) to inst_0...")
|
||||||
|
t0 = time.time()
|
||||||
|
resp = client.post(f"{base0}/v1/completions", json={
|
||||||
|
"model": args.model,
|
||||||
|
"prompt": prompt2,
|
||||||
|
"max_tokens": 1,
|
||||||
|
"temperature": 0,
|
||||||
|
})
|
||||||
|
resp.raise_for_status()
|
||||||
|
ttft = time.time() - t0
|
||||||
|
print(f" TTFT: {ttft:.3f}s (should be fast if prefix cached)")
|
||||||
|
|
||||||
|
# Now send turn 2 to inst_1 with direct_read for turn 1's cache
|
||||||
|
print(f"\n[7] Sending turn 2 to inst_1 with direct_read (remote_num_tokens={len(prompt)})...")
|
||||||
|
t0 = time.time()
|
||||||
|
resp = client.post(f"{base1}/v1/completions", json={
|
||||||
|
"model": args.model,
|
||||||
|
"prompt": prompt2,
|
||||||
|
"max_tokens": 1,
|
||||||
|
"temperature": 0,
|
||||||
|
"kv_transfer_params": {
|
||||||
|
"do_remote_decode": False,
|
||||||
|
"do_remote_prefill": True,
|
||||||
|
"direct_read": True,
|
||||||
|
"remote_bootstrap_addr": bp0,
|
||||||
|
"remote_engine_id": engine_id,
|
||||||
|
"transfer_id": "test-xfer-002",
|
||||||
|
"remote_num_tokens": len(prompt), # only first 10240 from remote
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ttft1 = time.time() - t0
|
||||||
|
print(f" Status: {resp.status_code}")
|
||||||
|
if resp.status_code == 200:
|
||||||
|
print(f" TTFT: {ttft1:.3f}s")
|
||||||
|
print(f" Output: {resp.json()['choices'][0]['text'][:50]}...")
|
||||||
|
else:
|
||||||
|
print(f" Error: {resp.text[:200]}")
|
||||||
|
|
||||||
|
print(f"\n=== Summary ===")
|
||||||
|
print(f"Turn 1 on inst_0: OK")
|
||||||
|
print(f"Turn 2 on inst_0 (cached): TTFT={ttft:.3f}s")
|
||||||
|
print(f"Turn 2 on inst_1 (direct_read): TTFT={ttft1:.3f}s")
|
||||||
|
print(f"If direct_read works: inst_1 TTFT ≈ inst_0 TTFT (both have cache)")
|
||||||
|
print(f"If direct_read broken: inst_1 TTFT >> inst_0 TTFT (cold prefill)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -65,6 +65,13 @@ TransferId = str # KV transfer coordination ID (shared by P/D)
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Module-level block pool for bootstrap server access (kv_both same-process only)
|
||||||
|
_shared_block_pool = None
|
||||||
|
|
||||||
|
def _set_shared_block_pool(bp):
|
||||||
|
global _shared_block_pool
|
||||||
|
_shared_block_pool = bp
|
||||||
|
|
||||||
|
|
||||||
class MooncakeXferMetadata(
|
class MooncakeXferMetadata(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
@@ -111,6 +118,8 @@ class PullReqMeta:
|
|||||||
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
|
# Direct RDMA read: D reads from C's GPU memory without C's scheduler
|
||||||
direct_read: bool = False
|
direct_read: bool = False
|
||||||
block_hashes: list[bytes] = field(default_factory=list)
|
block_hashes: list[bytes] = field(default_factory=list)
|
||||||
|
prompt_token_ids: list[int] = field(default_factory=list)
|
||||||
|
remote_num_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -141,10 +150,12 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
|||||||
kv_transfer_params: dict[str, Any],
|
kv_transfer_params: dict[str, Any],
|
||||||
load_remote_cache: bool = True,
|
load_remote_cache: bool = True,
|
||||||
block_hashes: list[bytes] | None = None,
|
block_hashes: list[bytes] | None = None,
|
||||||
|
prompt_token_ids: list[int] | None = None,
|
||||||
):
|
):
|
||||||
transfer_id = kv_transfer_params["transfer_id"]
|
transfer_id = kv_transfer_params["transfer_id"]
|
||||||
if load_remote_cache:
|
if load_remote_cache:
|
||||||
remote_engine_id = kv_transfer_params["remote_engine_id"]
|
remote_engine_id = kv_transfer_params["remote_engine_id"]
|
||||||
|
remote_num = kv_transfer_params.get("remote_num_tokens", 0)
|
||||||
self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta(
|
self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta(
|
||||||
d_req_id=request_id,
|
d_req_id=request_id,
|
||||||
local_block_ids=local_block_ids,
|
local_block_ids=local_block_ids,
|
||||||
@@ -153,6 +164,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
|||||||
transfer_id=transfer_id,
|
transfer_id=transfer_id,
|
||||||
direct_read=bool(kv_transfer_params.get("direct_read")),
|
direct_read=bool(kv_transfer_params.get("direct_read")),
|
||||||
block_hashes=block_hashes or [],
|
block_hashes=block_hashes or [],
|
||||||
|
prompt_token_ids=prompt_token_ids or [],
|
||||||
|
remote_num_tokens=remote_num,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
|
self.reqs_to_send[request_id] = (transfer_id, local_block_ids)
|
||||||
@@ -183,6 +196,8 @@ class MooncakeConnector(KVConnectorBase_V1):
|
|||||||
def set_block_pool(self, block_pool):
|
def set_block_pool(self, block_pool):
|
||||||
if self.connector_scheduler is not None:
|
if self.connector_scheduler is not None:
|
||||||
self.connector_scheduler.set_block_pool(block_pool)
|
self.connector_scheduler.set_block_pool(block_pool)
|
||||||
|
# Also store module-level for bootstrap server access (same process for kv_both TP=1)
|
||||||
|
_set_shared_block_pool(block_pool)
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Scheduler Side Methods
|
# Scheduler Side Methods
|
||||||
@@ -276,8 +291,8 @@ class MooncakeConnectorScheduler:
|
|||||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||||
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
|
self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||||
self._reqs_not_processed: set[TransferId] = set()
|
self._reqs_not_processed: set[TransferId] = set()
|
||||||
# Per-request block hashes for direct_read passthrough
|
|
||||||
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
|
self._req_block_hashes: dict[ReqId, list[bytes]] = {}
|
||||||
|
self._req_token_ids: dict[ReqId, list[int]] = {}
|
||||||
|
|
||||||
def set_block_pool(self, block_pool):
|
def set_block_pool(self, block_pool):
|
||||||
self._block_pool = block_pool
|
self._block_pool = block_pool
|
||||||
@@ -361,15 +376,20 @@ class MooncakeConnectorScheduler:
|
|||||||
else:
|
else:
|
||||||
local_block_ids = []
|
local_block_ids = []
|
||||||
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
|
||||||
# Capture block hashes for direct_read
|
if params.get("direct_read"):
|
||||||
if params.get("direct_read") and hasattr(request, "block_hashes"):
|
|
||||||
block_size = self.vllm_config.cache_config.block_size
|
block_size = self.vllm_config.cache_config.block_size
|
||||||
num_remote_blocks = (
|
num_remote_blocks = (
|
||||||
(num_external_tokens + block_size - 1) // block_size
|
(num_external_tokens + block_size - 1) // block_size
|
||||||
)
|
)
|
||||||
self._req_block_hashes[request.request_id] = [
|
if hasattr(request, "block_hashes"):
|
||||||
bytes(h) for h in request.block_hashes[:num_remote_blocks]
|
self._req_block_hashes[request.request_id] = [
|
||||||
]
|
bytes(h) for h in request.block_hashes[:num_remote_blocks]
|
||||||
|
]
|
||||||
|
# Store prompt token_ids for token-based lookup on C
|
||||||
|
if hasattr(request, "prompt_token_ids") and request.prompt_token_ids:
|
||||||
|
self._req_token_ids[request.request_id] = list(
|
||||||
|
request.prompt_token_ids[:num_remote_blocks * block_size]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Got invalid KVTransferParams: %s. This "
|
"Got invalid KVTransferParams: %s. This "
|
||||||
@@ -400,6 +420,7 @@ class MooncakeConnectorScheduler:
|
|||||||
local_block_ids=block_ids,
|
local_block_ids=block_ids,
|
||||||
kv_transfer_params=req.kv_transfer_params,
|
kv_transfer_params=req.kv_transfer_params,
|
||||||
block_hashes=self._req_block_hashes.pop(req_id, None),
|
block_hashes=self._req_block_hashes.pop(req_id, None),
|
||||||
|
prompt_token_ids=self._req_token_ids.pop(req_id, None),
|
||||||
)
|
)
|
||||||
self._reqs_need_recv.clear()
|
self._reqs_need_recv.clear()
|
||||||
|
|
||||||
@@ -1030,8 +1051,10 @@ class MooncakeConnectorWorker:
|
|||||||
if self.bootstrap_server is not None:
|
if self.bootstrap_server is not None:
|
||||||
self.bootstrap_server.set_worker_kv_info(
|
self.bootstrap_server.set_worker_kv_info(
|
||||||
self.kv_caches_base_addr, self.block_len,
|
self.kv_caches_base_addr, self.block_len,
|
||||||
self.hostname, self.rpc_port,
|
self.block_size, self.hostname, self.rpc_port,
|
||||||
)
|
)
|
||||||
|
if _shared_block_pool is not None:
|
||||||
|
self.bootstrap_server.set_block_pool(_shared_block_pool)
|
||||||
|
|
||||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
||||||
finished_recving_reqs = self.finished_recving_reqs
|
finished_recving_reqs = self.finished_recving_reqs
|
||||||
@@ -1266,14 +1289,18 @@ class MooncakeConnectorWorker:
|
|||||||
bootstrap_url = pm.remote_bootstrap_addr
|
bootstrap_url = pm.remote_bootstrap_addr
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Query C's bootstrap for block mapping
|
# 1. Query C's bootstrap for block mapping (token-based lookup)
|
||||||
|
query_payload = {"pin_token": pin_token}
|
||||||
|
if pm.prompt_token_ids:
|
||||||
|
query_payload["token_ids"] = pm.prompt_token_ids
|
||||||
|
query_payload["num_tokens"] = pm.remote_num_tokens or len(pm.prompt_token_ids)
|
||||||
|
else:
|
||||||
|
query_payload["block_hashes"] = [h.hex() for h in pm.block_hashes]
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{bootstrap_url}/query_blocks",
|
f"{bootstrap_url}/query_blocks",
|
||||||
json={
|
json=query_payload,
|
||||||
"block_hashes": [h.hex() for h in pm.block_hashes],
|
|
||||||
"pin_token": pin_token,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
mapping = resp.json()
|
mapping = resp.json()
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ class EngineEntry:
|
|||||||
|
|
||||||
|
|
||||||
class QueryBlocksRequest(BaseModel):
|
class QueryBlocksRequest(BaseModel):
|
||||||
block_hashes: list[str] # hex-encoded BlockHash values
|
block_hashes: list[str] | None = None # hex-encoded BlockHash values (legacy)
|
||||||
|
token_ids: list[int] | None = None # raw token IDs for C-side hash lookup
|
||||||
|
num_tokens: int | None = None # number of tokens to match prefix for
|
||||||
pin_token: str
|
pin_token: str
|
||||||
|
|
||||||
|
|
||||||
@@ -66,10 +68,11 @@ class MooncakeBootstrapServer:
|
|||||||
self.server_thread: threading.Thread | None = None
|
self.server_thread: threading.Thread | None = None
|
||||||
self.server: uvicorn.Server | None = None
|
self.server: uvicorn.Server | None = None
|
||||||
|
|
||||||
# Direct RDMA read support: block hash → block_id mapping
|
# Direct RDMA read support
|
||||||
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id
|
self._hash_table: dict[str, int] = {} # hex BlockHash → block_id (legacy)
|
||||||
self._kv_info: dict | None = None # set by worker at register_kv_caches
|
self._kv_info: dict | None = None # set by worker at register_kv_caches
|
||||||
self._pinned: dict[str, list[int]] = {} # pin_token → block_ids
|
self._pinned: dict[str, list[int]] = {} # pin_token → block_ids
|
||||||
|
self._block_pool = None # set by scheduler for token-based lookup
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
@@ -154,12 +157,14 @@ class MooncakeBootstrapServer:
|
|||||||
self,
|
self,
|
||||||
kv_caches_base_addr: list[int],
|
kv_caches_base_addr: list[int],
|
||||||
block_len: int,
|
block_len: int,
|
||||||
|
block_size: int,
|
||||||
hostname: str,
|
hostname: str,
|
||||||
rpc_port: int,
|
rpc_port: int,
|
||||||
):
|
):
|
||||||
self._kv_info = {
|
self._kv_info = {
|
||||||
"kv_caches_base_addr": kv_caches_base_addr,
|
"kv_caches_base_addr": kv_caches_base_addr,
|
||||||
"block_len": block_len,
|
"block_len": block_len,
|
||||||
|
"block_size": block_size,
|
||||||
"hostname": hostname,
|
"hostname": hostname,
|
||||||
"rpc_port": rpc_port,
|
"rpc_port": rpc_port,
|
||||||
}
|
}
|
||||||
@@ -173,20 +178,40 @@ class MooncakeBootstrapServer:
|
|||||||
self._hash_table.pop(k, None)
|
self._hash_table.pop(k, None)
|
||||||
self._hash_table.update(updates)
|
self._hash_table.update(updates)
|
||||||
|
|
||||||
|
def set_block_pool(self, block_pool):
|
||||||
|
"""Store reference to scheduler's block pool for token-based lookup."""
|
||||||
|
self._block_pool = block_pool
|
||||||
|
|
||||||
async def query_blocks(self, req: QueryBlocksRequest):
|
async def query_blocks(self, req: QueryBlocksRequest):
|
||||||
if self._kv_info is None:
|
if self._kv_info is None:
|
||||||
raise HTTPException(503, "Worker KV info not registered yet")
|
raise HTTPException(503, "Worker KV info not registered yet")
|
||||||
|
|
||||||
block_ids: list[int | None] = []
|
block_ids: list[int | None] = []
|
||||||
pinned_ids: list[int] = []
|
pinned_ids: list[int] = []
|
||||||
for h in req.block_hashes:
|
|
||||||
bid = self._hash_table.get(h)
|
if req.token_ids is not None and self._block_pool is not None:
|
||||||
if bid is not None:
|
# Token-based lookup: compute hashes on C's side using C's hash function
|
||||||
block_ids.append(bid)
|
block_ids, pinned_ids = self._lookup_by_tokens(
|
||||||
pinned_ids.append(bid)
|
req.token_ids, req.num_tokens)
|
||||||
else:
|
elif req.block_hashes is not None:
|
||||||
block_ids.append(None)
|
# Hash-based lookup (legacy)
|
||||||
break # prefix match: stop at first miss
|
for h in req.block_hashes:
|
||||||
|
bid = self._hash_table.get(h)
|
||||||
|
if bid is not None:
|
||||||
|
block_ids.append(bid)
|
||||||
|
pinned_ids.append(bid)
|
||||||
|
else:
|
||||||
|
block_ids.append(None)
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"query_blocks: %d/%d matched (token_mode=%s, pool=%s)",
|
||||||
|
len(pinned_ids),
|
||||||
|
req.num_tokens // self._kv_info.get("block_size", 512)
|
||||||
|
if req.num_tokens else len(req.block_hashes or []),
|
||||||
|
req.token_ids is not None,
|
||||||
|
self._block_pool is not None,
|
||||||
|
)
|
||||||
|
|
||||||
self._pinned[req.pin_token] = pinned_ids
|
self._pinned[req.pin_token] = pinned_ids
|
||||||
return QueryBlocksResponse(
|
return QueryBlocksResponse(
|
||||||
@@ -197,6 +222,55 @@ class MooncakeBootstrapServer:
|
|||||||
rpc_port=self._kv_info["rpc_port"],
|
rpc_port=self._kv_info["rpc_port"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _lookup_by_tokens(
|
||||||
|
self, token_ids: list[int], num_tokens: int | None
|
||||||
|
) -> tuple[list[int | None], list[int]]:
|
||||||
|
"""Look up cached blocks by computing hashes on C's side."""
|
||||||
|
from vllm.v1.core.kv_cache_utils import (
|
||||||
|
hash_block_tokens,
|
||||||
|
make_block_hash_with_group_id,
|
||||||
|
NONE_HASH,
|
||||||
|
)
|
||||||
|
|
||||||
|
bp = self._block_pool
|
||||||
|
block_size = bp.block_size if hasattr(bp, 'block_size') else 512
|
||||||
|
# Use C's own hash function
|
||||||
|
hash_fn = bp._hash_fn if hasattr(bp, '_hash_fn') else None
|
||||||
|
if hash_fn is None:
|
||||||
|
# Fallback: try to get from coordinator
|
||||||
|
coord = bp.coordinator if hasattr(bp, 'coordinator') else None
|
||||||
|
if coord and hasattr(coord, '_caching_hash_fn'):
|
||||||
|
hash_fn = coord._caching_hash_fn
|
||||||
|
|
||||||
|
if hash_fn is None:
|
||||||
|
logger.warning("Cannot find hash function on block pool")
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
n = num_tokens or len(token_ids)
|
||||||
|
n = min(n, len(token_ids))
|
||||||
|
num_blocks = n // block_size
|
||||||
|
|
||||||
|
block_ids: list[int | None] = []
|
||||||
|
pinned_ids: list[int] = []
|
||||||
|
prev_hash = NONE_HASH
|
||||||
|
|
||||||
|
for i in range(num_blocks):
|
||||||
|
block_tokens = tuple(token_ids[i * block_size:(i + 1) * block_size])
|
||||||
|
block_hash = hash_block_tokens(hash_fn, prev_hash, block_tokens, None)
|
||||||
|
prev_hash = block_hash
|
||||||
|
|
||||||
|
# Look up in C's block pool
|
||||||
|
key = make_block_hash_with_group_id(block_hash, 0)
|
||||||
|
block = bp.cached_block_hash_to_block.get_one_block(key)
|
||||||
|
if block is not None:
|
||||||
|
block_ids.append(block.block_id)
|
||||||
|
pinned_ids.append(block.block_id)
|
||||||
|
else:
|
||||||
|
block_ids.append(None)
|
||||||
|
break # prefix match stops
|
||||||
|
|
||||||
|
return block_ids, pinned_ids
|
||||||
|
|
||||||
async def unpin_blocks(self, req: UnpinBlocksRequest):
|
async def unpin_blocks(self, req: UnpinBlocksRequest):
|
||||||
self._pinned.pop(req.pin_token, None)
|
self._pinned.pop(req.pin_token, None)
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|||||||
Reference in New Issue
Block a user