Files
agentic-kvc/microbench/connector_tax/cache_sweep/smoke_nixl_migrate.py
Gahow Wang 41a0c1c48f Migration correctness smoke tests: direct-read, partial-transfer, NIXL
Standalone smoke tests validating KV-migration correctness paths before
trace replay: full migrate-cache, partial-prefill transfer, and a
NIXL-connector variant, each with a runner.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-29 11:53:13 +08:00

133 lines
5.6 KiB
Python

#!/usr/bin/env python3
"""Smoke test for Nixl PD-sep migration.

Nixl handshake (vs Mooncake's pre-baked engine_id):
  1. POST to src with kv_transfer_params={"do_remote_decode": True},
     max_tokens=1, stream=False.
  2. src returns kv_transfer_params in the response body containing
     remote_block_ids, remote_engine_id, remote_host, remote_port,
     remote_request_id, tp_size.
  3. POST to dst with the SAME kv_transfer_params dict.
  4. dst pulls KV via UCX (NVLink intra-node) and decodes.

Verifies migration correctness (a follow-up request sent to dst without
kv_transfer_params must report the migrated prefix as cached) and measures
KV transfer latency on Nixl, so we can ablate vs Mooncake/RDMA on the same
workload.
"""
import asyncio, argparse, json, sys, uuid, time
import httpx
import random as _r
# Value sent in the "model" field of every completion request. It is an
# absolute filesystem path on the benchmark host; presumably it has to match
# the model the src/dst servers were launched with (TODO confirm).
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
def _parse_sse_event(raw_event):
    """Return the JSON object carried by one SSE event, or None to skip it.

    Args:
        raw_event: Bytes of a single event, i.e. the text between two blank
            lines of the stream, without the terminating blank line.

    Returns:
        The decoded JSON payload when the event starts with ``data: `` and
        the payload is a JSON object; None for the ``[DONE]`` sentinel,
        events without a ``data: `` prefix, malformed JSON, and non-object
        JSON values.
    """
    # Decode only once the whole event has been received, so a multi-byte
    # UTF-8 character that was split across network chunks stays intact.
    text = raw_event.decode("utf-8", errors="replace")
    prefix = "data: "
    if not text.startswith(prefix):
        return None
    body = text[len(prefix):].strip()
    if body == "[DONE]":
        return None
    try:
        parsed = json.loads(body)
    except json.JSONDecodeError:
        # Best effort: one malformed event must not abort the whole stream.
        return None
    return parsed if isinstance(parsed, dict) else None


async def send(client, port, prompt, max_tokens, kv_xfer, stream):
    """POST one completion request to a local server and return its JSON.

    Args:
        client: Async HTTP client exposing ``post`` and ``stream`` (main()
            passes an ``httpx.AsyncClient``).
        port: TCP port of the server on 127.0.0.1.
        prompt: Value for the request's "prompt" field.
        max_tokens: Value for the request's "max_tokens" field.
        kv_xfer: Dict to send as "kv_transfer_params", or None to omit it.
        stream: When False, issue a plain POST and return the response body.
            When True, consume the SSE stream and return a single event.

    Returns:
        Non-streaming: the parsed JSON response body.
        Streaming: the last event that carried a truthy "usage" field,
        otherwise the last parsed event, otherwise an empty dict.

    Raises:
        Whatever ``raise_for_status()`` raises on an HTTP error status, plus
        any transport error raised by the client.
    """
    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        # Always request at least one generated token.
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if kv_xfer is not None:
        payload["kv_transfer_params"] = kv_xfer
    if stream:
        # Ask the server to attach token usage to the stream.
        payload["stream_options"] = {"include_usage": True}
    url = f"http://127.0.0.1:{port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    last_with_usage = None
    last_any = None
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        # Buffer raw bytes and split on the blank-line event separator; the
        # text is decoded per complete event inside _parse_sse_event.
        pending = b""
        async for chunk in resp.aiter_bytes():
            pending += chunk
            while b"\n\n" in pending:
                raw_event, pending = pending.split(b"\n\n", 1)
                event = _parse_sse_event(raw_event)
                if event is None:
                    continue
                last_any = event
                if event.get("usage"):
                    last_with_usage = event
    return last_with_usage or last_any or {}
def short(d):
    """Summarize the token usage of a completion response in one line.

    Args:
        d: Parsed response body (a dict), or any falsy value when there is
            no response to report.

    Returns:
        "no_resp" when ``d`` is falsy; otherwise
        "cached=<cached>/<prompt_tokens> completion=<completion_tokens>".
        The cached count is read from ``usage.prompt_tokens_details`` first
        and falls back to ``usage.cached_tokens``. Missing or null fields
        are reported as 0.
    """
    if not d:
        return "no_resp"
    usage = d.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    # `or 0` also covers keys that are present but explicitly null, which
    # would otherwise be printed as "None".
    cached = details.get("cached_tokens") or usage.get("cached_tokens") or 0
    prompt_tokens = usage.get("prompt_tokens") or 0
    completion_tokens = usage.get("completion_tokens") or 0
    return f"cached={cached}/{prompt_tokens} completion={completion_tokens}"
def _count_remote_blocks(kv_params):
    """Return the number of block ids advertised in ``kv_params``.

    Handles both a nested layout (a list of per-group lists, in which case
    the first group is counted) and a flat list of ids. NOTE(review): which
    layout the server returns presumably depends on the connector version;
    confirm against the deployed build.
    """
    block_ids = kv_params.get("remote_block_ids")
    if not block_ids:
        return 0
    first = block_ids[0]
    if isinstance(first, (list, tuple)):
        return len(first)
    return len(block_ids)


def _cached_tokens(resp):
    """Return the cached prompt-token count reported by ``resp``, or 0."""
    usage = resp.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    # `or 0` keeps the result an int even when a field is explicitly null,
    # so the threshold comparison in main() cannot raise TypeError.
    return details.get("cached_tokens") or usage.get("cached_tokens") or 0


async def main():
    """Run the three-step migration smoke test and exit with the verdict.

    Steps:
        1. Send the prompt to src with ``do_remote_decode`` and max_tokens=1
           and read ``kv_transfer_params`` from the response.
        2. Send the same prompt to dst with those params (streaming) and
           time the request.
        3. Send prompt + extension to dst without transfer params and check
           how many prompt tokens it reports as cached.

    Exits the process via ``sys.exit``: 0 when the follow-up reports at
    least 95% of the prefix tokens as cached, 1 otherwise or when src
    returned no ``kv_transfer_params``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src-port", type=int, default=8000,
                        help="port of the source (prefill) server")
    parser.add_argument("--dst-port", type=int, default=8001,
                        help="port of the destination (decode) server")
    parser.add_argument("--n-prefix-tokens", type=int, default=8192,
                        help="number of token ids in the generated prompt")
    parser.add_argument("--n-extension", type=int, default=32,
                        help="extra token ids appended for the follow-up")
    parser.add_argument("--decode-tokens", type=int, default=16,
                        help="max_tokens for the dst request in step 2")
    parser.add_argument("--prefix-base", type=int, default=100,
                        help="seed component for the prompt/extension RNGs")
    args = parser.parse_args()

    # Deterministic prompt of token ids, reproducible per --prefix-base.
    rng = _r.Random(f"prefix-{args.prefix_base}")
    prompt = [rng.randint(1024, 99_999) for _ in range(args.n_prefix_tokens)]

    async with httpx.AsyncClient() as client:
        # ----- Step 1: src prefills with do_remote_decode=True -----
        t_src_start = time.monotonic()
        src_resp = await send(
            client, args.src_port, prompt, max_tokens=1,
            kv_xfer={"do_remote_decode": True}, stream=False,
        )
        src_ms = (time.monotonic() - t_src_start) * 1000
        src_kv = src_resp.get("kv_transfer_params")
        print(f"[1] src prefill ({src_ms:.0f}ms): {short(src_resp)}")
        if not src_kv:
            print(" FAIL: src returned no kv_transfer_params")
            print(f" response keys: {list(src_resp.keys())}")
            sys.exit(1)
        print(f" src kv_transfer_params keys: {list(src_kv.keys())}")
        print(f" remote_block_ids: {_count_remote_blocks(src_kv)} blocks")

        # ----- Step 2: dst pulls KV using forwarded kv_transfer_params -----
        t_dst_start = time.monotonic()
        dst_resp = await send(
            client, args.dst_port, prompt, max_tokens=args.decode_tokens,
            kv_xfer=src_kv, stream=True,
        )
        dst_total_ms = (time.monotonic() - t_dst_start) * 1000
        n_completion = (dst_resp.get("usage") or {}).get("completion_tokens", 0)
        print(f"[2] dst decode ({dst_total_ms:.0f}ms, {n_completion} completion tokens): {short(dst_resp)}")
        print(f" [TIMING] proto=nixl src_prefill={int(src_ms)}ms "
              f"dst_total={int(dst_total_ms)}ms (KV xfer + {n_completion}-token decode)")
        print()

        # ----- Step 3: follow-up on dst (no kv_transfer_params) -----
        # Seed the extension RNG once; re-creating it per element would
        # yield the same token id n_extension times.
        ext_rng = _r.Random(f"ext-{args.prefix_base}")
        ext = [ext_rng.randint(1024, 99_999) for _ in range(args.n_extension)]
        follow_prompt = prompt + ext
        fu = await send(client, args.dst_port, follow_prompt, max_tokens=4,
                        kv_xfer=None, stream=False)
        print(f"[3] follow-up dst (cache hit test): {short(fu)}")

        cached_fu = _cached_tokens(fu)
        # Allow a 5% shortfall against the prefix length.
        expected_min = int(args.n_prefix_tokens * 0.95)
        verdict = "PASS" if cached_fu >= expected_min else "FAIL"
        print(f"\n=== verdict: {verdict} (follow-up cached={cached_fu}, "
              f"expected >= {expected_min} of {args.n_prefix_tokens}) ===")
        sys.exit(0 if verdict == "PASS" else 1)
if __name__ == "__main__":
    # Script entry point: run the async smoke test to completion. main()
    # terminates the process itself via sys.exit (0 on PASS, 1 on FAIL).
    asyncio.run(main())