"""Minimal test: verify direct RDMA read hash matching. 1. Send a multi-turn session to inst_0 (builds cache) 2. Query inst_0's bootstrap /query_blocks with computed block hashes 3. Check if hashes match (the core question) Usage: # Start 2 elastic instances first, then: python scripts/test_direct_read.py --port0 8000 --bp0 8998 --port1 8001 --bp1 8999 """ import argparse import json import random import time import httpx BLOCK_SIZE = 512 VOCAB_SIZE = 151936 TOKEN_RANGE_START = 100 TOKEN_RANGE_END = VOCAB_SIZE - 100 def make_prompt(seed: int, n_blocks: int) -> list[int]: """Deterministic prompt from seed, like the replayer does.""" rng = random.Random(seed) return [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE * n_blocks)] def main(): p = argparse.ArgumentParser() p.add_argument("--port0", type=int, default=8000) p.add_argument("--bp0", type=int, default=8998) p.add_argument("--port1", type=int, default=8001) p.add_argument("--bp1", type=int, default=8999) p.add_argument("--model", type=str, default="/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct") args = p.parse_args() client = httpx.Client(timeout=120) base0 = f"http://127.0.0.1:{args.port0}" base1 = f"http://127.0.0.1:{args.port1}" bp0 = f"http://127.0.0.1:{args.bp0}" bp1 = f"http://127.0.0.1:{args.bp1}" # Step 1: Send request to inst_0 to build cache prompt = make_prompt(seed=42, n_blocks=20) # 10240 tokens print(f"[1] Sending {len(prompt)} tokens to inst_0...") resp = client.post(f"{base0}/v1/completions", json={ "model": args.model, "prompt": prompt, "max_tokens": 1, "temperature": 0, }) resp.raise_for_status() print(f" OK: {resp.json()['choices'][0]['text'][:20]}...") # Wait for hash table sync (happens in scheduler step) time.sleep(3) # Step 2: Query inst_0's bootstrap for its hash table size print(f"\n[2] Querying inst_0 bootstrap /query endpoint...") resp = client.get(f"{bp0}/query") resp.raise_for_status() query_data = resp.json() print(f" Bootstrap has {len(query_data)} dp_rank entries") # Step 3: Compute block hashes the way D would # D's scheduler uses request.block_hashes which is computed by # vLLM's block hasher. We can't easily replicate that here. # Instead, let's send the SAME prompt to inst_1 with direct_read=True # and see what happens. # First, let's directly test the /query_blocks endpoint # with some known hashes. We need to know what hashes inst_0 has. # Try querying with dummy hashes to see the response format print(f"\n[3] Testing /query_blocks with dummy hashes...") resp = client.post(f"{bp0}/query_blocks", json={ "block_hashes": ["0000000000000000"], "pin_token": "test-1", }) resp.raise_for_status() result = resp.json() print(f" Response: {json.dumps(result, indent=2)}") # Unpin client.post(f"{bp0}/unpin_blocks", json={"pin_token": "test-1"}) # Step 4: Send same prompt to inst_1 with do_remote_prefill + direct_read # This triggers D's scheduler to compute block_hashes and the worker # to query C's bootstrap print(f"\n[4] Sending same prompt to inst_1 with direct_read...") # Get inst_0's engine_id from bootstrap engine_id = query_data.get("0", {}).get("engine_id", "") print(f" inst_0 engine_id: {engine_id}") resp = client.post(f"{base1}/v1/completions", json={ "model": args.model, "prompt": prompt, "max_tokens": 1, "temperature": 0, "kv_transfer_params": { "do_remote_decode": False, "do_remote_prefill": True, "direct_read": True, "remote_bootstrap_addr": bp0, "remote_engine_id": engine_id, "transfer_id": "test-xfer-001", "remote_num_tokens": len(prompt), }, }) print(f" Status: {resp.status_code}") if resp.status_code == 200: print(f" Output: {resp.json()['choices'][0]['text'][:50]}...") else: print(f" Error: {resp.text[:200]}") # Step 5: Check logs for hash matching print(f"\n[5] Check vLLM logs for direct_read activity:") print(f" grep 'direct_read\\|query_blocks\\|hash_table_sync\\|no cache hit' inst_*.log") # Step 6: Send turn 2 (extended prompt) to verify prefix caching prompt2 = prompt + make_prompt(seed=43, n_blocks=5) # extend by 2560 tokens print(f"\n[6] Sending turn 2 ({len(prompt2)} tokens) to inst_0...") t0 = time.time() resp = client.post(f"{base0}/v1/completions", json={ "model": args.model, "prompt": prompt2, "max_tokens": 1, "temperature": 0, }) resp.raise_for_status() ttft = time.time() - t0 print(f" TTFT: {ttft:.3f}s (should be fast if prefix cached)") # Now send turn 2 to inst_1 with direct_read for turn 1's cache print(f"\n[7] Sending turn 2 to inst_1 with direct_read (remote_num_tokens={len(prompt)})...") t0 = time.time() resp = client.post(f"{base1}/v1/completions", json={ "model": args.model, "prompt": prompt2, "max_tokens": 1, "temperature": 0, "kv_transfer_params": { "do_remote_decode": False, "do_remote_prefill": True, "direct_read": True, "remote_bootstrap_addr": bp0, "remote_engine_id": engine_id, "transfer_id": "test-xfer-002", "remote_num_tokens": len(prompt), # only first 10240 from remote }, }) ttft1 = time.time() - t0 print(f" Status: {resp.status_code}") if resp.status_code == 200: print(f" TTFT: {ttft1:.3f}s") print(f" Output: {resp.json()['choices'][0]['text'][:50]}...") else: print(f" Error: {resp.text[:200]}") print(f"\n=== Summary ===") print(f"Turn 1 on inst_0: OK") print(f"Turn 2 on inst_0 (cached): TTFT={ttft:.3f}s") print(f"Turn 2 on inst_1 (direct_read): TTFT={ttft1:.3f}s") print(f"If direct_read works: inst_1 TTFT ≈ inst_0 TTFT (both have cache)") print(f"If direct_read broken: inst_1 TTFT >> inst_0 TTFT (cold prefill)") if __name__ == "__main__": main()