diff --git a/scripts/test_direct_read.py b/scripts/test_direct_read.py new file mode 100644 index 0000000..a7b39d1 --- /dev/null +++ b/scripts/test_direct_read.py @@ -0,0 +1,176 @@ +"""Minimal test: verify direct RDMA read hash matching. + +1. Send a multi-turn session to inst_0 (builds cache) +2. Query inst_0's bootstrap /query_blocks with computed block hashes +3. Check if hashes match (the core question) + +Usage: + # Start 2 elastic instances first, then: + python scripts/test_direct_read.py --port0 8000 --bp0 8998 --port1 8001 --bp1 8999 +""" + +import argparse +import json +import random +import time + +import httpx + +BLOCK_SIZE = 512 +VOCAB_SIZE = 151936 +TOKEN_RANGE_START = 100 +TOKEN_RANGE_END = VOCAB_SIZE - 100 + + +def make_prompt(seed: int, n_blocks: int) -> list[int]: + """Deterministic prompt from seed, like the replayer does.""" + rng = random.Random(seed) + return [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) + for _ in range(BLOCK_SIZE * n_blocks)] + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--port0", type=int, default=8000) + p.add_argument("--bp0", type=int, default=8998) + p.add_argument("--port1", type=int, default=8001) + p.add_argument("--bp1", type=int, default=8999) + p.add_argument("--model", type=str, + default="/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct") + args = p.parse_args() + + client = httpx.Client(timeout=120) + base0 = f"http://127.0.0.1:{args.port0}" + base1 = f"http://127.0.0.1:{args.port1}" + bp0 = f"http://127.0.0.1:{args.bp0}" + bp1 = f"http://127.0.0.1:{args.bp1}" + + # Step 1: Send request to inst_0 to build cache + prompt = make_prompt(seed=42, n_blocks=20) # 10240 tokens + print(f"[1] Sending {len(prompt)} tokens to inst_0...") + + resp = client.post(f"{base0}/v1/completions", json={ + "model": args.model, + "prompt": prompt, + "max_tokens": 1, + "temperature": 0, + }) + resp.raise_for_status() + print(f" OK: {resp.json()['choices'][0]['text'][:20]}...") + + # Wait for hash table sync (happens in scheduler step) + time.sleep(3) + + # Step 2: Query inst_0's bootstrap for its hash table size + print(f"\n[2] Querying inst_0 bootstrap /query endpoint...") + resp = client.get(f"{bp0}/query") + resp.raise_for_status() + query_data = resp.json() + print(f" Bootstrap has {len(query_data)} dp_rank entries") + + # Step 3: Compute block hashes the way D would + # D's scheduler uses request.block_hashes which is computed by + # vLLM's block hasher. We can't easily replicate that here. + # Instead, let's send the SAME prompt to inst_1 with direct_read=True + # and see what happens. + + # First, let's directly test the /query_blocks endpoint + # with some known hashes. We need to know what hashes inst_0 has. + + # Try querying with dummy hashes to see the response format + print(f"\n[3] Testing /query_blocks with dummy hashes...") + resp = client.post(f"{bp0}/query_blocks", json={ + "block_hashes": ["0000000000000000"], + "pin_token": "test-1", + }) + resp.raise_for_status() + result = resp.json() + print(f" Response: {json.dumps(result, indent=2)}") + + # Unpin + client.post(f"{bp0}/unpin_blocks", json={"pin_token": "test-1"}) + + # Step 4: Send same prompt to inst_1 with do_remote_prefill + direct_read + # This triggers D's scheduler to compute block_hashes and the worker + # to query C's bootstrap + print(f"\n[4] Sending same prompt to inst_1 with direct_read...") + + # Get inst_0's engine_id from bootstrap + engine_id = query_data.get("0", {}).get("engine_id", "") + print(f" inst_0 engine_id: {engine_id}") + + resp = client.post(f"{base1}/v1/completions", json={ + "model": args.model, + "prompt": prompt, + "max_tokens": 1, + "temperature": 0, + "kv_transfer_params": { + "do_remote_decode": False, + "do_remote_prefill": True, + "direct_read": True, + "remote_bootstrap_addr": bp0, + "remote_engine_id": engine_id, + "transfer_id": "test-xfer-001", + "remote_num_tokens": len(prompt), + }, + }) + print(f" Status: {resp.status_code}") + if resp.status_code == 200: + print(f" Output: {resp.json()['choices'][0]['text'][:50]}...") + else: + print(f" Error: {resp.text[:200]}") + + # Step 5: Check logs for hash matching + print(f"\n[5] Check vLLM logs for direct_read activity:") + print(f" grep 'direct_read\\|query_blocks\\|hash_table_sync\\|no cache hit' inst_*.log") + + # Step 6: Send turn 2 (extended prompt) to verify prefix caching + prompt2 = prompt + make_prompt(seed=43, n_blocks=5) # extend by 2560 tokens + print(f"\n[6] Sending turn 2 ({len(prompt2)} tokens) to inst_0...") + t0 = time.time() + resp = client.post(f"{base0}/v1/completions", json={ + "model": args.model, + "prompt": prompt2, + "max_tokens": 1, + "temperature": 0, + }) + resp.raise_for_status() + ttft = time.time() - t0 + print(f" TTFT: {ttft:.3f}s (should be fast if prefix cached)") + + # Now send turn 2 to inst_1 with direct_read for turn 1's cache + print(f"\n[7] Sending turn 2 to inst_1 with direct_read (remote_num_tokens={len(prompt)})...") + t0 = time.time() + resp = client.post(f"{base1}/v1/completions", json={ + "model": args.model, + "prompt": prompt2, + "max_tokens": 1, + "temperature": 0, + "kv_transfer_params": { + "do_remote_decode": False, + "do_remote_prefill": True, + "direct_read": True, + "remote_bootstrap_addr": bp0, + "remote_engine_id": engine_id, + "transfer_id": "test-xfer-002", + "remote_num_tokens": len(prompt), # only first 10240 from remote + }, + }) + ttft1 = time.time() - t0 + print(f" Status: {resp.status_code}") + if resp.status_code == 200: + print(f" TTFT: {ttft1:.3f}s") + print(f" Output: {resp.json()['choices'][0]['text'][:50]}...") + else: + print(f" Error: {resp.text[:200]}") + + print(f"\n=== Summary ===") + print(f"Turn 1 on inst_0: OK") + print(f"Turn 2 on inst_0 (cached): TTFT={ttft:.3f}s") + print(f"Turn 2 on inst_1 (direct_read): TTFT={ttft1:.3f}s") + print(f"If direct_read works: inst_1 TTFT ≈ inst_0 TTFT (both have cache)") + print(f"If direct_read broken: inst_1 TTFT >> inst_0 TTFT (cold prefill)") + + +if __name__ == "__main__": + main() diff --git a/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index a7abde2..d4b6d30 100644 --- a/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -65,6 +65,13 @@ TransferId = str # KV transfer coordination ID (shared by P/D) logger = init_logger(__name__) +# Module-level block pool for bootstrap server access (kv_both same-process only) +_shared_block_pool = None + +def _set_shared_block_pool(bp): + global _shared_block_pool + _shared_block_pool = bp + class MooncakeXferMetadata( msgspec.Struct, @@ -111,6 +118,8 @@ class PullReqMeta: # Direct RDMA read: D reads from C's GPU memory without C's scheduler direct_read: bool = False block_hashes: list[bytes] = field(default_factory=list) + prompt_token_ids: list[int] = field(default_factory=list) + remote_num_tokens: int = 0 @dataclass @@ -141,10 +150,12 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): kv_transfer_params: dict[str, Any], load_remote_cache: bool = True, block_hashes: list[bytes] | None = None, + prompt_token_ids: list[int] | None = None, ): transfer_id = kv_transfer_params["transfer_id"] if load_remote_cache: remote_engine_id = kv_transfer_params["remote_engine_id"] + remote_num = kv_transfer_params.get("remote_num_tokens", 0) self.reqs_to_recv[remote_engine_id][request_id] = PullReqMeta( d_req_id=request_id, local_block_ids=local_block_ids, @@ -153,6 +164,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): transfer_id=transfer_id, direct_read=bool(kv_transfer_params.get("direct_read")), block_hashes=block_hashes or [], + prompt_token_ids=prompt_token_ids or [], + remote_num_tokens=remote_num, ) else: self.reqs_to_send[request_id] = (transfer_id, local_block_ids) @@ -183,6 +196,8 @@ class MooncakeConnector(KVConnectorBase_V1): def set_block_pool(self, block_pool): if self.connector_scheduler is not None: self.connector_scheduler.set_block_pool(block_pool) + # Also store module-level for bootstrap server access (same process for kv_both TP=1) + _set_shared_block_pool(block_pool) ############################################################ # Scheduler Side Methods @@ -276,8 +291,8 @@ class MooncakeConnectorScheduler: self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_not_processed: set[TransferId] = set() - # Per-request block hashes for direct_read passthrough self._req_block_hashes: dict[ReqId, list[bytes]] = {} + self._req_token_ids: dict[ReqId, list[int]] = {} def set_block_pool(self, block_pool): self._block_pool = block_pool @@ -361,15 +376,20 @@ class MooncakeConnectorScheduler: else: local_block_ids = [] self._reqs_need_recv[request.request_id] = (request, local_block_ids) - # Capture block hashes for direct_read - if params.get("direct_read") and hasattr(request, "block_hashes"): + if params.get("direct_read"): block_size = self.vllm_config.cache_config.block_size num_remote_blocks = ( (num_external_tokens + block_size - 1) // block_size ) - self._req_block_hashes[request.request_id] = [ - bytes(h) for h in request.block_hashes[:num_remote_blocks] - ] + if hasattr(request, "block_hashes"): + self._req_block_hashes[request.request_id] = [ + bytes(h) for h in request.block_hashes[:num_remote_blocks] + ] + # Store prompt token_ids for token-based lookup on C + if hasattr(request, "prompt_token_ids") and request.prompt_token_ids: + self._req_token_ids[request.request_id] = list( + request.prompt_token_ids[:num_remote_blocks * block_size] + ) else: logger.warning( "Got invalid KVTransferParams: %s. This " @@ -400,6 +420,7 @@ class MooncakeConnectorScheduler: local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, block_hashes=self._req_block_hashes.pop(req_id, None), + prompt_token_ids=self._req_token_ids.pop(req_id, None), ) self._reqs_need_recv.clear() @@ -1030,8 +1051,10 @@ class MooncakeConnectorWorker: if self.bootstrap_server is not None: self.bootstrap_server.set_worker_kv_info( self.kv_caches_base_addr, self.block_len, - self.hostname, self.rpc_port, + self.block_size, self.hostname, self.rpc_port, ) + if _shared_block_pool is not None: + self.bootstrap_server.set_block_pool(_shared_block_pool) async def fetch_finished_recving_reqs(self) -> set[ReqId]: finished_recving_reqs = self.finished_recving_reqs @@ -1266,14 +1289,18 @@ class MooncakeConnectorWorker: bootstrap_url = pm.remote_bootstrap_addr try: - # 1. Query C's bootstrap for block mapping + # 1. Query C's bootstrap for block mapping (token-based lookup) + query_payload = {"pin_token": pin_token} + if pm.prompt_token_ids: + query_payload["token_ids"] = pm.prompt_token_ids + query_payload["num_tokens"] = pm.remote_num_tokens or len(pm.prompt_token_ids) + else: + query_payload["block_hashes"] = [h.hex() for h in pm.block_hashes] + async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{bootstrap_url}/query_blocks", - json={ - "block_hashes": [h.hex() for h in pm.block_hashes], - "pin_token": pin_token, - }, + json=query_payload, ) resp.raise_for_status() mapping = resp.json() diff --git a/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py b/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py index 231cd3f..595ec75 100644 --- a/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py +++ b/third_party/vllm/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py @@ -33,7 +33,9 @@ class EngineEntry: class QueryBlocksRequest(BaseModel): - block_hashes: list[str] # hex-encoded BlockHash values + block_hashes: list[str] | None = None # hex-encoded BlockHash values (legacy) + token_ids: list[int] | None = None # raw token IDs for C-side hash lookup + num_tokens: int | None = None # number of tokens to match prefix for pin_token: str @@ -66,10 +68,11 @@ class MooncakeBootstrapServer: self.server_thread: threading.Thread | None = None self.server: uvicorn.Server | None = None - # Direct RDMA read support: block hash → block_id mapping - self._hash_table: dict[str, int] = {} # hex BlockHash → block_id + # Direct RDMA read support + self._hash_table: dict[str, int] = {} # hex BlockHash → block_id (legacy) self._kv_info: dict | None = None # set by worker at register_kv_caches self._pinned: dict[str, list[int]] = {} # pin_token → block_ids + self._block_pool = None # set by scheduler for token-based lookup def __del__(self): self.shutdown() @@ -154,12 +157,14 @@ class MooncakeBootstrapServer: self, kv_caches_base_addr: list[int], block_len: int, + block_size: int, hostname: str, rpc_port: int, ): self._kv_info = { "kv_caches_base_addr": kv_caches_base_addr, "block_len": block_len, + "block_size": block_size, "hostname": hostname, "rpc_port": rpc_port, } @@ -173,20 +178,40 @@ class MooncakeBootstrapServer: self._hash_table.pop(k, None) self._hash_table.update(updates) + def set_block_pool(self, block_pool): + """Store reference to scheduler's block pool for token-based lookup.""" + self._block_pool = block_pool + async def query_blocks(self, req: QueryBlocksRequest): if self._kv_info is None: raise HTTPException(503, "Worker KV info not registered yet") block_ids: list[int | None] = [] pinned_ids: list[int] = [] - for h in req.block_hashes: - bid = self._hash_table.get(h) - if bid is not None: - block_ids.append(bid) - pinned_ids.append(bid) - else: - block_ids.append(None) - break # prefix match: stop at first miss + + if req.token_ids is not None and self._block_pool is not None: + # Token-based lookup: compute hashes on C's side using C's hash function + block_ids, pinned_ids = self._lookup_by_tokens( + req.token_ids, req.num_tokens) + elif req.block_hashes is not None: + # Hash-based lookup (legacy) + for h in req.block_hashes: + bid = self._hash_table.get(h) + if bid is not None: + block_ids.append(bid) + pinned_ids.append(bid) + else: + block_ids.append(None) + break + + logger.info( + "query_blocks: %d/%d matched (token_mode=%s, pool=%s)", + len(pinned_ids), + req.num_tokens // self._kv_info.get("block_size", 512) + if req.num_tokens else len(req.block_hashes or []), + req.token_ids is not None, + self._block_pool is not None, + ) self._pinned[req.pin_token] = pinned_ids return QueryBlocksResponse( @@ -197,6 +222,55 @@ class MooncakeBootstrapServer: rpc_port=self._kv_info["rpc_port"], ) + def _lookup_by_tokens( + self, token_ids: list[int], num_tokens: int | None + ) -> tuple[list[int | None], list[int]]: + """Look up cached blocks by computing hashes on C's side.""" + from vllm.v1.core.kv_cache_utils import ( + hash_block_tokens, + make_block_hash_with_group_id, + NONE_HASH, + ) + + bp = self._block_pool + block_size = bp.block_size if hasattr(bp, 'block_size') else 512 + # Use C's own hash function + hash_fn = bp._hash_fn if hasattr(bp, '_hash_fn') else None + if hash_fn is None: + # Fallback: try to get from coordinator + coord = bp.coordinator if hasattr(bp, 'coordinator') else None + if coord and hasattr(coord, '_caching_hash_fn'): + hash_fn = coord._caching_hash_fn + + if hash_fn is None: + logger.warning("Cannot find hash function on block pool") + return [], [] + + n = num_tokens or len(token_ids) + n = min(n, len(token_ids)) + num_blocks = n // block_size + + block_ids: list[int | None] = [] + pinned_ids: list[int] = [] + prev_hash = NONE_HASH + + for i in range(num_blocks): + block_tokens = tuple(token_ids[i * block_size:(i + 1) * block_size]) + block_hash = hash_block_tokens(hash_fn, prev_hash, block_tokens, None) + prev_hash = block_hash + + # Look up in C's block pool + key = make_block_hash_with_group_id(block_hash, 0) + block = bp.cached_block_hash_to_block.get_one_block(key) + if block is not None: + block_ids.append(block.block_id) + pinned_ids.append(block.block_id) + else: + block_ids.append(None) + break # prefix match stops + + return block_ids, pinned_ids + async def unpin_blocks(self, req: UnpinBlocksRequest): self._pinned.pop(req.pin_token, None) return {"status": "ok"}