Root cause: each vLLM instance has a random NONE_HASH (os.urandom(32)) when PYTHONHASHSEED is not set. All block hashes are chained from NONE_HASH, so D's hashes never match C's hashes. Fix: C's bootstrap server now accepts token_ids and does the prefix cache lookup locally using C's own hash function and block pool. No cross-instance hash matching needed. New flow: D sends prompt token_ids → C computes hashes on C's side → C looks up in C's own BlockPool → returns block_ids. Also: module-level _shared_block_pool for scheduler→bootstrap bridge, prompt_token_ids passed through PullReqMeta, test script added. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
177 lines
6.2 KiB
Python
177 lines
6.2 KiB
Python
"""Minimal test: verify direct RDMA read hash matching.
|
|
|
|
1. Send a multi-turn session to inst_0 (builds cache)
2. Query inst_0's bootstrap /query_blocks with computed block hashes
3. Check if hashes match (the core question)

Usage:
    # Start 2 elastic instances first, then:
    python scripts/test_direct_read.py --port0 8000 --bp0 8998 --port1 8001 --bp1 8999
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import random
|
|
import time
|
|
|
|
import httpx
|
|
|
|
# Tokens per block: make_prompt() emits BLOCK_SIZE * n_blocks token ids.
BLOCK_SIZE = 512
# Presumably the vocabulary size of the served model -- TODO confirm.
VOCAB_SIZE = 151936
# Bounds handed to randint() in make_prompt(), so both ends are inclusive.
# The range starts at id 100 and stops 100 below VOCAB_SIZE.
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100
|
|
|
|
|
|
def make_prompt(seed: int, n_blocks: int) -> list[int]:
    """Build a reproducible token-id prompt spanning ``n_blocks`` blocks.

    A private ``random.Random`` seeded with ``seed`` draws
    ``BLOCK_SIZE * n_blocks`` ids, each from the inclusive range
    ``[TOKEN_RANGE_START, TOKEN_RANGE_END]``, so identical arguments always
    yield the same list (like the replayer does).

    Args:
        seed: Seed for the private random generator.
        n_blocks: Number of ``BLOCK_SIZE``-token blocks to generate.

    Returns:
        The generated token ids, in draw order.
    """
    generator = random.Random(seed)
    total_tokens = BLOCK_SIZE * n_blocks
    token_ids: list[int] = []
    for _ in range(total_tokens):
        token_ids.append(generator.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
    return token_ids
|
|
|
|
|
|
def _direct_read_params(bootstrap_addr: str, engine_id: str,
                        transfer_id: str, remote_num_tokens: int) -> dict:
    """Build the ``kv_transfer_params`` payload for a direct-read remote prefill."""
    return {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "direct_read": True,
        "remote_bootstrap_addr": bootstrap_addr,
        "remote_engine_id": engine_id,
        "transfer_id": transfer_id,
        "remote_num_tokens": remote_num_tokens,
    }


def main() -> None:
    """Run the manual direct-read check against two already-running instances.

    Progress is printed with a ``[n]`` prefix per step:

    1. POST a 20-block prompt to inst_0's ``/v1/completions``.
    2. GET inst_0's bootstrap ``/query`` and report how many entries it has.
    3. POST a dummy hash to inst_0's bootstrap ``/query_blocks`` to show the
       response format, then POST ``/unpin_blocks`` for the same pin token.
    4. POST the same prompt to inst_1 with direct-read ``kv_transfer_params``
       pointing at inst_0's bootstrap.
    5. Print a grep command for inspecting the instance logs.
    6. POST an extended prompt (turn 2) to inst_0 and time the request.
    7. POST the extended prompt to inst_1 with direct read and time it.

    A summary comparing the two timings is printed at the end.

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status()`` when the response
            of step 1, 2, 3 (``/query_blocks``) or 6 is an error status.
            Steps 4 and 7 print the error text instead of raising.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--port0", type=int, default=8000)
    p.add_argument("--bp0", type=int, default=8998)
    p.add_argument("--port1", type=int, default=8001)
    # --bp1 is accepted so the documented command line keeps working, but
    # nothing below talks to inst_1's bootstrap server.
    p.add_argument("--bp1", type=int, default=8999)
    p.add_argument("--model", type=str,
                   default="/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct")
    args = p.parse_args()

    base0 = f"http://127.0.0.1:{args.port0}"
    base1 = f"http://127.0.0.1:{args.port1}"
    bp0 = f"http://127.0.0.1:{args.bp0}"

    # The context manager closes the client's connections on every exit
    # path, including when raise_for_status() aborts the run.
    with httpx.Client(timeout=120) as client:
        # Step 1: Send request to inst_0 to build cache
        prompt = make_prompt(seed=42, n_blocks=20)  # 10240 tokens
        print(f"[1] Sending {len(prompt)} tokens to inst_0...")

        resp = client.post(f"{base0}/v1/completions", json={
            "model": args.model,
            "prompt": prompt,
            "max_tokens": 1,
            "temperature": 0,
        })
        resp.raise_for_status()
        print(f" OK: {resp.json()['choices'][0]['text'][:20]}...")

        # Wait for hash table sync (happens in scheduler step)
        time.sleep(3)

        # Step 2: Query inst_0's bootstrap for its hash table size
        print("\n[2] Querying inst_0 bootstrap /query endpoint...")
        resp = client.get(f"{bp0}/query")
        resp.raise_for_status()
        query_data = resp.json()
        print(f" Bootstrap has {len(query_data)} dp_rank entries")

        # Step 3: D's scheduler uses request.block_hashes, computed by vLLM's
        # block hasher, which is not easily replicated here. So only probe
        # /query_blocks with a dummy hash to see the response format; the
        # real hash path is exercised through inst_1 in step 4.
        print("\n[3] Testing /query_blocks with dummy hashes...")
        resp = client.post(f"{bp0}/query_blocks", json={
            "block_hashes": ["0000000000000000"],
            "pin_token": "test-1",
        })
        resp.raise_for_status()
        result = resp.json()
        print(f" Response: {json.dumps(result, indent=2)}")

        # Unpin (the response of this call is not checked)
        client.post(f"{bp0}/unpin_blocks", json={"pin_token": "test-1"})

        # Step 4: Send same prompt to inst_1 with do_remote_prefill +
        # direct_read. This triggers D's scheduler to compute block_hashes
        # and the worker to query C's bootstrap.
        print("\n[4] Sending same prompt to inst_1 with direct_read...")

        # Get inst_0's engine_id from bootstrap; falls back to "" when the
        # "0" entry or its engine_id is missing.
        engine_id = query_data.get("0", {}).get("engine_id", "")
        print(f" inst_0 engine_id: {engine_id}")

        resp = client.post(f"{base1}/v1/completions", json={
            "model": args.model,
            "prompt": prompt,
            "max_tokens": 1,
            "temperature": 0,
            "kv_transfer_params": _direct_read_params(
                bp0, engine_id, "test-xfer-001", len(prompt)),
        })
        print(f" Status: {resp.status_code}")
        if resp.status_code == 200:
            print(f" Output: {resp.json()['choices'][0]['text'][:50]}...")
        else:
            print(f" Error: {resp.text[:200]}")

        # Step 5: Check logs for hash matching
        print("\n[5] Check vLLM logs for direct_read activity:")
        print(" grep 'direct_read\\|query_blocks\\|hash_table_sync\\|no cache hit' inst_*.log")

        # Step 6: Send turn 2 (extended prompt) to verify prefix caching
        prompt2 = prompt + make_prompt(seed=43, n_blocks=5)  # extend by 2560 tokens
        print(f"\n[6] Sending turn 2 ({len(prompt2)} tokens) to inst_0...")
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments during the request.
        t0 = time.perf_counter()
        resp = client.post(f"{base0}/v1/completions", json={
            "model": args.model,
            "prompt": prompt2,
            "max_tokens": 1,
            "temperature": 0,
        })
        resp.raise_for_status()
        ttft = time.perf_counter() - t0
        print(f" TTFT: {ttft:.3f}s (should be fast if prefix cached)")

        # Now send turn 2 to inst_1 with direct_read for turn 1's cache
        print(f"\n[7] Sending turn 2 to inst_1 with direct_read (remote_num_tokens={len(prompt)})...")
        t0 = time.perf_counter()
        resp = client.post(f"{base1}/v1/completions", json={
            "model": args.model,
            "prompt": prompt2,
            "max_tokens": 1,
            "temperature": 0,
            # only the first 10240 tokens (turn 1) come from remote
            "kv_transfer_params": _direct_read_params(
                bp0, engine_id, "test-xfer-002", len(prompt)),
        })
        ttft1 = time.perf_counter() - t0
        print(f" Status: {resp.status_code}")
        if resp.status_code == 200:
            print(f" TTFT: {ttft1:.3f}s")
            print(f" Output: {resp.json()['choices'][0]['text'][:50]}...")
        else:
            print(f" Error: {resp.text[:200]}")

    print("\n=== Summary ===")
    print("Turn 1 on inst_0: OK")
    print(f"Turn 2 on inst_0 (cached): TTFT={ttft:.3f}s")
    print(f"Turn 2 on inst_1 (direct_read): TTFT={ttft1:.3f}s")
    print("If direct_read works: inst_1 TTFT ≈ inst_0 TTFT (both have cache)")
    print("If direct_read broken: inst_1 TTFT >> inst_0 TTFT (cold prefill)")
|
|
|
|
|
|
# Script entry point: run the check only when executed directly, not on import.
if __name__ == "__main__":
    main()
|