Standalone smoke tests validating KV-migration correctness paths before trace replay: full migrate-cache, partial-prefill transfer, and a NIXL-connector variant, each with a runner. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
167 lines
7.5 KiB
Python
167 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
|
"""Smoke test for partial KV transfer (Mechanism B).
|
|
|
|
Test if vLLM's Mooncake connector actually transfers only the
|
|
NON-OVERLAPPING portion when dst already has prefix cache.
|
|
|
|
Sequence:
|
|
step 0: warm dst with prompt P (cold prefill) — dst now has cache for [0, P]
|
|
step 1: cold prefill on src with prompt P+ext (src now has [0, P+ext])
|
|
step 2: migrate src→dst with prompt P+ext
|
|
- dst local cache should hit [0, P]
|
|
- only [P, P+ext] needs to come from src (~ext tokens)
|
|
- dst decode should be fast
|
|
step 3: control: another migrate src→dst_cold with prompt P+ext
|
|
- dst_cold has no cache, must pull all P+ext tokens
|
|
- compare with step 2
|
|
|
|
If step 2 is dramatically faster than step 3, partial transfer works
|
|
and Mechanism B is viable. If step 2 ~= step 3, partial transfer isn't
|
|
being exploited and we need to dig deeper.
|
|
"""
|
|
import asyncio, argparse, json, sys, uuid, time
|
|
import httpx
|
|
import random as _r
|
|
|
|
# Value placed in the "model" field of every completion request built by send().
# It is a filesystem-style path — presumably the checkpoint the target servers
# were launched with; confirm it matches your deployment before running.
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
|
|
|
|
|
def _parse_sse_event(raw_event):
    """Return the JSON object carried by one raw SSE event, or None.

    None is returned when the event does not start with ``data: ``, is the
    ``[DONE]`` sentinel, is not valid JSON, or decodes to something other
    than a dict.
    """
    text = raw_event.decode("utf-8", errors="replace")
    if not text.startswith("data: "):
        return None
    data = text[6:].strip()
    if data == "[DONE]":
        return None
    try:
        parsed = json.loads(data)
    except json.JSONDecodeError:
        # Best-effort parsing: a malformed event is skipped, not fatal.
        return None
    # Callers use dict methods (.get) on the result, so reject other JSON types.
    return parsed if isinstance(parsed, dict) else None


async def send(client, port, prompt, max_tokens, kv_xfer, stream):
    """POST a completion request to the server on 127.0.0.1:<port>.

    Args:
        client: HTTP client exposing ``post`` and ``stream`` (the script passes
            an ``httpx.AsyncClient``).
        port: Local port of the target ``/v1/completions`` endpoint.
        prompt: Prompt forwarded verbatim in the request body.
        max_tokens: Value for the ``max_tokens`` field.
        kv_xfer: Optional dict sent as ``kv_transfer_params``; omitted when None.
        stream: When true, request a streamed response with usage included.

    Returns:
        Non-streaming: the decoded JSON body. Streaming: the last event that
        carried a truthy ``usage`` field, else the last parsed event, else ``{}``.

    Raises:
        Whatever ``raise_for_status()`` raises on a non-success HTTP status.
    """
    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        # The former expression `max_tokens if max_tokens == 1 else 1`
        # evaluated to 1 on both branches.
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if kv_xfer is not None:
        payload["kv_transfer_params"] = kv_xfer
    if stream:
        payload["stream_options"] = {"include_usage": True}
    url = f"http://127.0.0.1:{port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    last_with_usage = None
    last_event = None

    def record(raw_event):
        nonlocal last_with_usage, last_event
        event = _parse_sse_event(raw_event)
        if event is None:
            return
        last_event = event
        if event.get("usage"):
            last_with_usage = event

    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        # Buffer raw bytes and decode per complete event: decoding each network
        # chunk separately would corrupt a multi-byte UTF-8 character that
        # happens to straddle a chunk boundary.
        pending = b""
        async for chunk in resp.aiter_bytes():
            pending += chunk
            *complete, pending = pending.split(b"\n\n")
            for raw_event in complete:
                record(raw_event)
        # A final event that was not followed by a blank line is still parsed.
        if pending:
            record(pending)

    return last_with_usage or last_event or {}
|
|
|
|
|
|
async def get_engine_id(client, bp):
    """Fetch the engine id advertised on local port *bp*.

    Issues ``GET http://127.0.0.1:<bp>/query`` and returns the ``engine_id``
    field stored under the ``"0"`` key of the JSON response.
    NOTE(review): the ``"0"`` key presumably selects the first rank/worker
    entry of the bootstrap server's reply — confirm against the server.

    Raises:
        Whatever ``raise_for_status()`` raises on a non-success HTTP status
        (matching ``send``), or ``KeyError`` if the expected keys are absent.
    """
    r = await client.get(f"http://127.0.0.1:{bp}/query")
    # Fail with the HTTP error itself rather than a confusing KeyError or
    # JSON decode error when the bootstrap endpoint is unhealthy.
    r.raise_for_status()
    return r.json()["0"]["engine_id"]
|
|
|
|
|
|
def cached_of(d):
    """Extract the cached-token count from a completion response dict.

    Looks first at ``usage.prompt_tokens_details.cached_tokens``; when that is
    missing or falsy, falls back to ``usage.cached_tokens`` (default 0).
    A missing or falsy ``usage`` / ``prompt_tokens_details`` is treated as empty.
    """
    usage_info = d.get("usage")
    if not usage_info:
        usage_info = {}

    details = usage_info.get("prompt_tokens_details")
    if not details:
        details = {}

    nested_count = details.get("cached_tokens", 0)
    if nested_count:
        return nested_count
    return usage_info.get("cached_tokens", 0)
|
|
|
|
|
|
async def do_migration(client, src_port, dst_port, src_bp, prompt, max_tokens, label):
    """Run one src -> dst KV migration and time its two phases.

    Phase 1 sends a one-token, non-streaming request to the source with
    ``do_remote_decode`` set. Phase 2 sends a streaming request to the
    destination with ``do_remote_prefill`` set, pointing at the source's
    bootstrap address and engine id. Both requests share a fresh transfer id.

    Prints a one-line summary tagged with *label* and returns
    ``(src_ms, dst_ms, dst_response)``.
    """
    engine_id = await get_engine_id(client, src_bp)
    xfer_id = "smoke-xfer-" + uuid.uuid4().hex[:8]

    # Phase 1: prefill on the source.
    started = time.monotonic()
    prefill_params = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": xfer_id,
    }
    await send(client, src_port, prompt, max_tokens=1,
               kv_xfer=prefill_params, stream=False)
    prefill_done = time.monotonic()

    # Phase 2: destination request that references the source's bootstrap server.
    decode_params = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": f"http://127.0.0.1:{src_bp}",
        "remote_engine_id": engine_id,
        "transfer_id": xfer_id,
    }
    dst_resp = await send(client, dst_port, prompt, max_tokens=max_tokens,
                          kv_xfer=decode_params, stream=True)
    finished = time.monotonic()

    src_ms = int((prefill_done - started) * 1000)
    dst_ms = int((finished - prefill_done) * 1000)
    cached = cached_of(dst_resp)
    print(f" [{label}] src_prefill={src_ms}ms dst_total={dst_ms}ms cached={cached}/{len(prompt)}")
    return src_ms, dst_ms, dst_resp
|
|
|
|
|
|
async def _timed_prefill(client, port, prompt):
    """Send a plain one-token, non-streaming request; return (elapsed_ms, response)."""
    started = time.monotonic()
    resp = await send(client, port, prompt, max_tokens=1, kv_xfer=None, stream=False)
    return int((time.monotonic() - started) * 1000), resp


async def main():
    """Run the partial-KV-transfer smoke test and print a verdict.

    Steps: warm the "warm" destination with the base prompt (twice, the second
    as a local-cache sanity check), prefill the extended prompt on the source,
    then migrate the extended prompt to the warm destination and to the cold
    control destination. The verdict compares the two destination timings and
    reports "WORKING" when the warm destination is more than 30% faster.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--src-port", type=int, default=8000)
    p.add_argument("--src-bp", type=int, default=8998)
    p.add_argument("--dst-warm-port", type=int, default=8001)  # will be pre-warmed
    # The two destination bootstrap ports are accepted but not read below.
    p.add_argument("--dst-warm-bp", type=int, default=8999)
    p.add_argument("--dst-cold-port", type=int, default=8002)  # cold control
    p.add_argument("--dst-cold-bp", type=int, default=9000)
    p.add_argument("--prefix-tokens", type=int, default=32768)
    p.add_argument("--ext-tokens", type=int, default=512)
    args = p.parse_args()

    # Fixed string seeds keep both prompts reproducible across runs.
    rng = _r.Random("partial-1")
    prompt_base = [rng.randint(1024, 99_999) for _ in range(args.prefix_tokens)]
    # Seed the extension generator ONCE. Constructing a freshly seeded Random
    # per element made every extension token the same value.
    ext_rng = _r.Random("ext-partial")
    ext_tokens = [ext_rng.randint(1024, 99_999) for _ in range(args.ext_tokens)]
    prompt_ext = prompt_base + ext_tokens

    async with httpx.AsyncClient() as client:
        print("\n=== Setup ===")
        print(f"Prompt prefix: {args.prefix_tokens} tokens, extension: {args.ext_tokens} tokens")

        # Step 0: warm dst_warm with prompt_base (normal request, no kv_transfer)
        print(f"\n=== Step 0: warm dst_warm (port {args.dst_warm_port}) with prompt_base ===")
        ms, r = await _timed_prefill(client, args.dst_warm_port, prompt_base)
        print(f" cold prefill on dst_warm: {ms}ms, cached={cached_of(r)}/{args.prefix_tokens}")

        # Sanity: 2nd request to dst_warm hits local cache
        print("\n=== Sanity: 2nd request to dst_warm with same prompt — should hit local cache ===")
        ms, r = await _timed_prefill(client, args.dst_warm_port, prompt_base)
        print(f" warm request on dst_warm: {ms}ms, cached={cached_of(r)}/{args.prefix_tokens}")

        # Step 1: cold prefill on src with prompt_ext (also caches at src)
        print(f"\n=== Step 1: cold prefill on src (port {args.src_port}) with prompt_ext ===")
        ms, r = await _timed_prefill(client, args.src_port, prompt_ext)
        print(f" cold prefill on src: {ms}ms, cached={cached_of(r)}/{len(prompt_ext)}")

        # Step 2: MIGRATE src -> dst_warm (which has cache for prompt_base)
        print("\n=== Step 2: MIGRATE src -> dst_warm (cache-rich) with prompt_ext ===")
        _, d_ms_warm, _ = await do_migration(
            client, args.src_port, args.dst_warm_port, args.src_bp,
            prompt_ext, max_tokens=4, label="cache-rich dst")

        # Step 3: MIGRATE src -> dst_cold (no cache)
        print(f"\n=== Step 3: MIGRATE src -> dst_cold (port {args.dst_cold_port}, cold) with prompt_ext ===")
        _, d_ms_cold, _ = await do_migration(
            client, args.src_port, args.dst_cold_port, args.src_bp,
            prompt_ext, max_tokens=4, label="cold dst (control)")

        # Verdict
        print("\n=== VERDICT ===")
        print(f"Cache-rich dst (Mechanism B): dst_total={d_ms_warm}ms")
        print(f"Cold dst (full transfer): dst_total={d_ms_cold}ms")
        # Percentage of the cold-path time saved; guarded against a zero baseline.
        speedup = (d_ms_cold - d_ms_warm) / d_ms_cold * 100 if d_ms_cold > 0 else 0.0
        print(f"Δ = {d_ms_cold - d_ms_warm}ms ({speedup:+.1f}% faster with cache-rich dst)")
        print(f"Partial transfer {'WORKING' if speedup > 30 else 'NOT working / not exploited'}.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: run the async smoke test to completion.
    asyncio.run(main())
|