Standalone smoke tests validating KV-migration correctness paths before trace replay: full migrate-cache, partial-prefill transfer, and a NIXL-connector variant, each with a runner. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
133 lines
5.6 KiB
Python
133 lines
5.6 KiB
Python
#!/usr/bin/env python3
|
|
"""Smoke test for Nixl PD-sep migration.
|
|
|
|
Nixl handshake (vs. Mooncake's pre-baked engine_id):
  1. POST to src with kv_transfer_params={"do_remote_decode": True},
     max_tokens=1, stream=False.
  2. src returns kv_transfer_params in the response body containing
     remote_block_ids, remote_engine_id, remote_host, remote_port,
     remote_request_id, tp_size.
  3. POST to dst with the SAME kv_transfer_params dict.
  4. dst pulls KV via UCX (NVLink intra-node) and decodes.

Verifies migration correctness and measures KV transfer latency on Nixl
so we can ablate vs. Mooncake/RDMA on the same workload.
|
|
"""
|
|
import asyncio, argparse, json, sys, uuid, time
|
|
import httpx
|
|
import random as _r
|
|
|
|
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
|
|
|
|
|
def _parse_sse_event(event):
    """Return the JSON object carried by one ``data: ...`` SSE event, or None.

    ``None`` is returned for events that do not start with ``data: ``, for the
    ``[DONE]`` sentinel, for payloads that are not valid JSON, and for JSON
    values that are not objects (callers rely on ``dict.get``).
    """
    if not event.startswith("data: "):
        return None
    data = event[6:].strip()
    if data == "[DONE]":
        return None
    try:
        obj = json.loads(data)
    except json.JSONDecodeError:
        # Best-effort parsing: a malformed frame is skipped, not fatal.
        return None
    return obj if isinstance(obj, dict) else None


async def send(client, port, prompt, max_tokens, kv_xfer, stream):
    """POST one completion request to a server on 127.0.0.1 and return its JSON.

    Args:
        client: Async HTTP client exposing ``post`` and ``stream``
            (``main`` passes an ``httpx.AsyncClient``).
        port: TCP port of the target server.
        prompt: Prompt value placed verbatim in the request payload.
        max_tokens: Value of the ``max_tokens`` request field.
        kv_xfer: Value for the ``kv_transfer_params`` field; omitted from the
            payload when ``None``.
        stream: When false, issue a plain POST and return the decoded body.
            When true, consume the server-sent-event stream (with
            ``stream_options.include_usage`` enabled).

    Returns:
        Non-streaming: the decoded JSON response body.
        Streaming: the last event that carried a truthy ``usage`` field, else
        the last JSON-object event seen, else ``{}``.

    Raises:
        Whatever ``raise_for_status()`` raises for a non-success HTTP status,
        plus any transport error raised by the client.
    """
    import codecs  # function-local so the block adds no file-level dependency

    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        # The previous expression `max_tokens if max_tokens == 1 else 1`
        # evaluated to 1 for every input; spell that out.
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": stream,
    }
    if kv_xfer is not None:
        payload["kv_transfer_params"] = kv_xfer
    if stream:
        payload["stream_options"] = {"include_usage": True}
    url = f"http://127.0.0.1:{port}/v1/completions"

    if not stream:
        r = await client.post(url, json=payload, timeout=300.0)
        r.raise_for_status()
        return r.json()

    last_with_usage = None
    last_any = None

    def _record(event):
        """Remember *event*'s JSON object (and whether it carried usage)."""
        nonlocal last_with_usage, last_any
        obj = _parse_sse_event(event)
        if obj is None:
            return
        last_any = obj
        if obj.get("usage"):
            last_with_usage = obj

    # An incremental decoder keeps a multi-byte UTF-8 sequence intact when the
    # transport splits it across two chunks; decoding each chunk on its own
    # would turn both halves into replacement characters.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    buf = ""
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        async for chunk in resp.aiter_bytes():
            buf += decoder.decode(chunk)
            # Accept CRLF-delimited streams as well as LF-delimited ones.
            buf = buf.replace("\r\n", "\n")
            while "\n\n" in buf:
                event, buf = buf.split("\n\n", 1)
                _record(event)

    # Flush any bytes still held by the decoder and handle a final event that
    # was not terminated by a blank line.
    tail = (buf + decoder.decode(b"", final=True)).strip()
    if tail:
        _record(tail)

    return last_with_usage or last_any or {}
|
|
|
|
|
|
def short(d):
    """Summarise a completion response's token usage as a one-line string.

    Args:
        d: Parsed response body (a dict), or any falsy value when no response
            was obtained.

    Returns:
        ``"no_resp"`` when ``d`` is falsy; otherwise
        ``"cached=<cached>/<prompt_tokens> completion=<completion_tokens>"``.
        The cached count is read from ``usage.prompt_tokens_details.cached_tokens``
        first and from ``usage.cached_tokens`` as a fallback. Fields that are
        missing or explicitly ``None`` are reported as 0 rather than being
        rendered as the text ``None``.
    """
    if not d:
        return "no_resp"
    usage = d.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    # `or 0` also covers keys that are present but null in the JSON body.
    cached = details.get("cached_tokens") or usage.get("cached_tokens") or 0
    prompt_tokens = usage.get("prompt_tokens") or 0
    completion_tokens = usage.get("completion_tokens") or 0
    return f"cached={cached}/{prompt_tokens} completion={completion_tokens}"
|
|
|
|
|
|
def _cached_prompt_tokens(resp):
    """Return the cached-prompt-token count reported in *resp*, or 0.

    Reads ``usage.prompt_tokens_details.cached_tokens`` first and falls back
    to ``usage.cached_tokens``; missing or null fields count as 0.
    """
    usage = resp.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    return details.get("cached_tokens") or usage.get("cached_tokens") or 0


def _count_remote_blocks(kv_params):
    """Return the number of block ids listed under ``remote_block_ids``.

    Handles both a nested layout (a sequence of id lists, where the first list
    is counted, as the original code did) and a flat list of ids.
    """
    block_ids = kv_params.get("remote_block_ids")
    if not block_ids:
        return 0
    first = block_ids[0]
    if isinstance(first, (list, tuple)):
        return len(first)
    return len(block_ids)


async def main():
    """Run the three-step migration smoke test and exit with its verdict.

    Steps:
        1. Send the prompt to the src port with
           ``kv_transfer_params={"do_remote_decode": True}`` and
           ``max_tokens=1``; exit with status 1 if the response carries no
           ``kv_transfer_params``.
        2. Send the same prompt to the dst port, forwarding src's
           ``kv_transfer_params`` unchanged, streaming ``--decode-tokens``
           tokens, and print wall-clock timings.
        3. Send prompt + extension to the dst port without
           ``kv_transfer_params`` and read the reported cached-token count.

    The process exits 0 (PASS) when step 3 reports at least 95% of
    ``--n-prefix-tokens`` as cached, and 1 (FAIL) otherwise.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--src-port", type=int, default=8000)
    p.add_argument("--dst-port", type=int, default=8001)
    p.add_argument("--n-prefix-tokens", type=int, default=8192)
    p.add_argument("--n-extension", type=int, default=32)
    p.add_argument("--decode-tokens", type=int, default=16)
    p.add_argument("--prefix-base", type=int, default=100)
    args = p.parse_args()

    # Deterministic prompt: seeded by --prefix-base so reruns send the same
    # integers (presumably token ids; the server decides how to interpret them).
    rng = _r.Random(f"prefix-{args.prefix_base}")
    prompt = [rng.randint(1024, 99_999) for _ in range(args.n_prefix_tokens)]

    async with httpx.AsyncClient() as client:
        # ----- Step 1: src prefills with do_remote_decode=True -----
        t_src_start = time.monotonic()
        src_resp = await send(
            client, args.src_port, prompt, max_tokens=1,
            kv_xfer={"do_remote_decode": True}, stream=False,
        )
        t_src_done = time.monotonic()
        src_ms = (t_src_done - t_src_start) * 1000
        src_kv = src_resp.get("kv_transfer_params")
        print(f"[1] src prefill ({src_ms:.0f}ms): {short(src_resp)}")
        if not src_kv:
            print(" FAIL: src returned no kv_transfer_params")
            print(f" response keys: {list(src_resp.keys())}")
            sys.exit(1)
        print(f" src kv_transfer_params keys: {list(src_kv.keys())}")
        print(f" remote_block_ids: {_count_remote_blocks(src_kv)} blocks")

        # ----- Step 2: dst pulls KV using forwarded kv_transfer_params -----
        t_dst_start = time.monotonic()
        dst_resp = await send(
            client, args.dst_port, prompt, max_tokens=args.decode_tokens,
            kv_xfer=src_kv, stream=True,
        )
        t_dst_done = time.monotonic()
        dst_total_ms = (t_dst_done - t_dst_start) * 1000
        n_completion = (dst_resp.get("usage") or {}).get("completion_tokens", 0)
        print(f"[2] dst decode ({dst_total_ms:.0f}ms, {n_completion} completion tokens): {short(dst_resp)}")
        print(f" [TIMING] proto=nixl src_prefill={int(src_ms)}ms "
              f"dst_total={int(dst_total_ms)}ms (KV xfer + {n_completion}-token decode)")
        print()

        # ----- Step 3: follow-up on dst (no kv_transfer_params) -----
        # One seeded generator for the whole extension. Re-creating the
        # generator inside the comprehension re-seeds it every iteration and
        # yields the same value n_extension times.
        ext_rng = _r.Random(f"ext-{args.prefix_base}")
        ext = [ext_rng.randint(1024, 99_999) for _ in range(args.n_extension)]
        follow_prompt = prompt + ext
        fu = await send(client, args.dst_port, follow_prompt, max_tokens=4,
                        kv_xfer=None, stream=False)
        print(f"[3] follow-up dst (cache hit test): {short(fu)}")

        # Null-safe read so the comparison below cannot hit `None >= int`.
        cached_fu = _cached_prompt_tokens(fu)
        expected_min = int(args.n_prefix_tokens * 0.95)
        verdict = "PASS" if cached_fu >= expected_min else "FAIL"
        print(f"\n=== verdict: {verdict} (follow-up cached={cached_fu}, "
              f"expected >= {expected_min} of {args.n_prefix_tokens}) ===")
        sys.exit(0 if verdict == "PASS" else 1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: drive the async smoke test to completion. main()
    # ends the process itself via sys.exit() with the PASS/FAIL status.
    asyncio.run(main())
|