Files
agentic-kvc/microbench/connector_tax/cache_sweep/smoke_partial_transfer.py
Gahow Wang 41a0c1c48f Migration correctness smoke tests: direct-read, partial-transfer, NIXL
Standalone smoke tests validating KV-migration correctness paths before
trace replay: full migrate-cache, partial-prefill transfer, and a
NIXL-connector variant, each with a runner.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-29 11:53:13 +08:00

167 lines
7.5 KiB
Python

#!/usr/bin/env python3
"""Smoke test for partial KV transfer (Mechanism B).
Test if vLLM's Mooncake connector actually transfers only the
NON-OVERLAPPING portion when dst already has prefix cache.
Sequence:
  step 0: warm dst with prompt P (cold prefill) — dst now has cache for [0, P]
  step 1: cold prefill on src with prompt P+ext (src now has [0, P+ext])
  step 2: migrate src→dst with prompt P+ext
    - dst local cache should hit [0, P]
    - only [P, P+ext] needs to come from src (~ext tokens)
    - dst decode should be fast
  step 3: control: another migrate src→dst_cold with prompt P+ext
    - dst_cold has no cache, must pull all P+ext tokens
    - compare with step 2
If step 2 is dramatically faster than step 3, partial transfer works
and Mechanism B is viable. If step 2 ~= step 3, partial transfer isn't
being exploited and we need to dig deeper.
"""
import asyncio, argparse, json, sys, uuid, time
import httpx
import random as _r
# Value placed in the "model" field of every completion request built by send().
# NOTE(review): this looks like a machine-local checkpoint path — presumably it
# has to match the model the target servers were launched with; confirm before
# running on another host.
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
async def send(client, port, prompt, max_tokens, kv_xfer, stream):
    """POST one ``/v1/completions`` request to the server on ``127.0.0.1:port``.

    Args:
        client: HTTP client exposing ``post`` and ``stream`` the way
            ``httpx.AsyncClient`` does.
        port: TCP port of the target server on localhost.
        prompt: Value for the ``prompt`` field (callers in this file pass a
            list of token ids).
        max_tokens: Value for the ``max_tokens`` field.
        kv_xfer: Dict sent as ``kv_transfer_params``; ``None`` omits the field.
        stream: When true, request a streamed response with
            ``stream_options.include_usage`` enabled.

    Returns:
        Non-streaming: the decoded JSON response body.
        Streaming: the last JSON object event that carried a non-empty
        ``usage`` field, else the last JSON object event seen, else ``{}``.

    Raises:
        Whatever ``raise_for_status`` raises for an error status
        (``httpx.HTTPStatusError`` when ``client`` is an httpx client).
    """
    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        # NOTE: this expression evaluates to 1 for every max_tokens value.
        "min_tokens": max_tokens if max_tokens == 1 else 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if kv_xfer is not None:
        payload["kv_transfer_params"] = kv_xfer
    if stream:
        payload["stream_options"] = {"include_usage": True}
    url = f"http://127.0.0.1:{port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    last_with_usage = None
    last_any = None

    def record(raw_event):
        """Parse one complete SSE event and remember it if it is a JSON object."""
        nonlocal last_with_usage, last_any
        text = raw_event.decode("utf-8", errors="replace")
        if not text.startswith("data: "):
            return
        body = text[6:].strip()
        if body == "[DONE]":
            return
        try:
            parsed = json.loads(body)
        except json.JSONDecodeError:
            # Best effort: a malformed event must not abort the whole stream.
            return
        if not isinstance(parsed, dict):
            # Callers index the result like a dict; ignore anything else.
            return
        last_any = parsed
        if parsed.get("usage"):
            last_with_usage = parsed

    # Buffer raw bytes and decode only whole events: decoding each network
    # chunk on its own would mangle a multi-byte UTF-8 character that happens
    # to straddle a chunk boundary.
    pending = b""
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        async for chunk in resp.aiter_bytes():
            # Normalise CRLF so "\r\n\r\n"-delimited events are split as well.
            pending = (pending + chunk).replace(b"\r\n", b"\n")
            while b"\n\n" in pending:
                raw_event, pending = pending.split(b"\n\n", 1)
                record(raw_event)
    # A final event that was not followed by a blank line is still usable.
    if pending.strip():
        record(pending)
    return last_with_usage or last_any or {}
async def get_engine_id(client, bp):
    """Return the engine id reported by the service on local port ``bp``.

    Issues ``GET http://127.0.0.1:{bp}/query`` and reads
    ``["0"]["engine_id"]`` from the JSON body.

    Args:
        client: HTTP client exposing an awaitable ``get`` the way
            ``httpx.AsyncClient`` does.
        bp: Port number interpolated into the URL.

    Raises:
        Whatever ``raise_for_status`` raises for an error status;
        ``KeyError`` if the body lacks the ``"0"`` / ``"engine_id"`` keys.
    """
    r = await client.get(f"http://127.0.0.1:{bp}/query")
    # Surface an HTTP failure as such, rather than as a confusing JSON decode
    # error or KeyError on the lines below (same convention as send()).
    r.raise_for_status()
    return r.json()["0"]["engine_id"]
def cached_of(d):
    """Return the cached-token count reported in response dict ``d``.

    Looks at ``usage.prompt_tokens_details.cached_tokens`` first and falls
    back to ``usage.cached_tokens``. Missing, null, or zero values at either
    level fall through, and the result is ``0`` when neither location holds a
    truthy count (so the function never returns ``None``).
    """
    usage = d.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    # Trailing ``or 0`` covers a key that is present but explicitly null.
    return details.get("cached_tokens") or usage.get("cached_tokens") or 0
async def do_migration(client, src_port, dst_port, src_bp, prompt, max_tokens, label):
    """Run one two-phase src→dst request pair and time each phase.

    Phase 1 sends ``prompt`` to ``src_port`` (non-streaming, ``max_tokens=1``)
    with ``do_remote_decode`` set. Phase 2 streams the same prompt to
    ``dst_port`` with ``do_remote_prefill`` set, passing the source's
    bootstrap address, its engine id (looked up via ``src_bp``) and the same
    transfer id as phase 1.

    Prints one summary line tagged with ``label``.

    Returns:
        ``(src_ms, dst_ms, dst_resp)``: the two phase durations truncated to
        whole milliseconds, and the destination response dict.
    """
    remote_engine = await get_engine_id(client, src_bp)
    xfer_id = f"smoke-xfer-{uuid.uuid4().hex[:8]}"

    src_kv_params = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "transfer_id": xfer_id,
    }
    dst_kv_params = {
        "do_remote_decode": False,
        "do_remote_prefill": True,
        "remote_bootstrap_addr": f"http://127.0.0.1:{src_bp}",
        "remote_engine_id": remote_engine,
        "transfer_id": xfer_id,
    }

    started = time.monotonic()
    # The source response body is not needed; only the elapsed time is used.
    await send(client, src_port, prompt, max_tokens=1,
               kv_xfer=src_kv_params, stream=False)
    src_done = time.monotonic()
    dst_resp = await send(client, dst_port, prompt, max_tokens=max_tokens,
                          kv_xfer=dst_kv_params, stream=True)
    dst_done = time.monotonic()

    src_ms = int((src_done - started) * 1000)
    dst_ms = int((dst_done - src_done) * 1000)
    cached = cached_of(dst_resp)
    print(f" [{label}] src_prefill={src_ms}ms dst_total={dst_ms}ms cached={cached}/{len(prompt)}")
    return src_ms, dst_ms, dst_resp
async def main():
    """Run the partial-KV-transfer smoke test and print a verdict.

    Builds a seeded random token-id prompt (``prompt_base``) plus an
    extension (``prompt_ext = prompt_base + ext_tokens``), then:

    * step 0 / sanity: sends ``prompt_base`` twice to the warm destination;
    * step 1: sends ``prompt_ext`` to the source;
    * step 2: migrates ``prompt_ext`` from the source to the warm destination;
    * step 3: migrates ``prompt_ext`` from the source to the cold destination.

    The verdict compares the destination-side time of step 2 against step 3
    and reports "WORKING" when the warm destination is more than 30% faster.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--src-port", type=int, default=8000)
    p.add_argument("--src-bp", type=int, default=8998)
    p.add_argument("--dst-warm-port", type=int, default=8001)  # will be pre-warmed
    p.add_argument("--dst-warm-bp", type=int, default=8999)
    p.add_argument("--dst-cold-port", type=int, default=8002)  # cold control
    p.add_argument("--dst-cold-bp", type=int, default=9000)
    p.add_argument("--prefix-tokens", type=int, default=32768)
    p.add_argument("--ext-tokens", type=int, default=512)
    args = p.parse_args()

    # Fixed seeds make both prompts identical from run to run.
    rng = _r.Random("partial-1")
    prompt_base = [rng.randint(1024, 99_999) for _ in range(args.prefix_tokens)]
    # Create the extension generator once. Constructing a freshly seeded
    # Random for every element would return the same first draw each time and
    # produce an extension consisting of one repeated token id.
    ext_rng = _r.Random("ext-partial")
    ext_tokens = [ext_rng.randint(1024, 99_999) for _ in range(args.ext_tokens)]
    prompt_ext = prompt_base + ext_tokens

    async with httpx.AsyncClient() as client:
        print("\n=== Setup ===")
        print(f"Prompt prefix: {args.prefix_tokens} tokens, extension: {args.ext_tokens} tokens")

        # Step 0: warm dst_warm with prompt_base (normal request, no kv_transfer)
        print(f"\n=== Step 0: warm dst_warm (port {args.dst_warm_port}) with prompt_base ===")
        t0 = time.monotonic()
        r = await send(client, args.dst_warm_port, prompt_base, max_tokens=1,
                       kv_xfer=None, stream=False)
        t1 = time.monotonic()
        print(f" cold prefill on dst_warm: {int((t1-t0)*1000)}ms, cached={cached_of(r)}/{args.prefix_tokens}")

        # Sanity: 2nd request to dst_warm hits local cache
        print("\n=== Sanity: 2nd request to dst_warm with same prompt — should hit local cache ===")
        t0 = time.monotonic()
        r = await send(client, args.dst_warm_port, prompt_base, max_tokens=1,
                       kv_xfer=None, stream=False)
        t1 = time.monotonic()
        print(f" warm request on dst_warm: {int((t1-t0)*1000)}ms, cached={cached_of(r)}/{args.prefix_tokens}")

        # Step 1: cold prefill on src with prompt_ext (also caches at src)
        print(f"\n=== Step 1: cold prefill on src (port {args.src_port}) with prompt_ext ===")
        t0 = time.monotonic()
        r = await send(client, args.src_port, prompt_ext, max_tokens=1,
                       kv_xfer=None, stream=False)
        t1 = time.monotonic()
        print(f" cold prefill on src: {int((t1-t0)*1000)}ms, cached={cached_of(r)}/{len(prompt_ext)}")

        # Step 2: MIGRATE src -> dst_warm (which has cache for prompt_base)
        print("\n=== Step 2: MIGRATE src -> dst_warm (cache-rich) with prompt_ext ===")
        _, d_ms_warm, _ = await do_migration(
            client, args.src_port, args.dst_warm_port, args.src_bp,
            prompt_ext, max_tokens=4, label="cache-rich dst")

        # Step 3: MIGRATE src -> dst_cold (no cache)
        print(f"\n=== Step 3: MIGRATE src -> dst_cold (port {args.dst_cold_port}, cold) with prompt_ext ===")
        _, d_ms_cold, _ = await do_migration(
            client, args.src_port, args.dst_cold_port, args.src_bp,
            prompt_ext, max_tokens=4, label="cold dst (control)")

        # Verdict: relative saving of the warm destination versus the cold one.
        print("\n=== VERDICT ===")
        print(f"Cache-rich dst (Mechanism B): dst_total={d_ms_warm}ms")
        print(f"Cold dst (full transfer): dst_total={d_ms_cold}ms")
        # Guard the division: a 0 ms cold run would otherwise raise.
        speedup = (d_ms_cold - d_ms_warm) / d_ms_cold * 100 if d_ms_cold > 0 else 0
        print(f"Δ = {d_ms_cold - d_ms_warm}ms ({speedup:+.1f}% faster with cache-rich dst)")
        print(f"Partial transfer {'WORKING' if speedup > 30 else 'NOT working / not exploited'}.")
if __name__ == "__main__":
    # Drive the async smoke test only when executed as a script, not on import.
    asyncio.run(main())