Files
agentic-kvc/microbench/fresh_setup/mb2_kv_transfer.py
Gahow Wang 622e0bc04c MB2: parameterize vLLM roles (kv_producer + kv_consumer default)
start_vllm_pair.sh
  ROLE_A / ROLE_B env vars (default kv_producer / kv_consumer for strict
  PD-disagg). Override to kv_both for the kv_both control. The role is
  injected into --kv-transfer-config so vLLM imposes the role restriction.

mb2_kv_transfer.py
  --skip-verify flag drops step 3 (the plain completion sanity-check on
  the destination), required when the dst is kv_consumer-only since a
  kv_consumer instance refuses to serve a request without
  do_remote_prefill. The transfer time itself is still measured in
  step 2 (the do_remote_prefill request on the consumer).

Also: per-step client-side wall-clock timestamps (t_step1_client_unix,
t_step2_client_unix, t_step2_end_unix) are now captured so the
post-hoc breakdown analyzer can join with the per-instance JSONL logs
on absolute time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-27 18:17:42 +08:00

229 lines
7.9 KiB
Python
Executable File

#!/usr/bin/env python3
"""MB2: measure KV transfer time between two vLLM instances over Mooncake.
Pattern (adapted from microbench/connector_tax/cache_sweep/smoke_test_migrate_cache.py):
1. Prefill on A: do_remote_decode with max_tokens=1 (A computes & caches KV)
2. Pull to B: do_remote_prefill on B with kv_transfer_params from step 1
(this is the operation that performs the KV transfer)
3. Verify (optional; skipped with --skip-verify): send a plain follow-up to B;
cached_tokens should be close to the prompt length (>= 90%), confirming
that the KV landed on B
We time step 2 — that gives us the end-to-end "transfer + B's prefill check" latency.
By sweeping input_tokens (--sizes) we trace T_transfer(KV_size).
The follow-up step is a correctness sanity check; its latency is recorded per
run (t_followup_s) but is not part of the summary.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import statistics
import time
import uuid
from pathlib import Path
import httpx
# Model path/name sent as the "model" field of every completion request and
# echoed into the output JSON report.
MODEL_PATH: str = (
    "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
)
async def get_engine_id(client: httpx.AsyncClient, port: int) -> str:
    """Return the engine id advertised by the instance listening on ``port``.

    Sends ``GET /query`` to 127.0.0.1 and reads ``["0"]["engine_id"]`` from
    the JSON body. A non-2xx reply raises ``httpx.HTTPStatusError``; a body
    lacking those keys raises ``KeyError``.
    """
    query_url = f"http://127.0.0.1:{port}/query"
    reply = await client.get(query_url)
    reply.raise_for_status()
    # NOTE(review): "0" presumably selects the first engine entry of the
    # /query payload — confirm against the server exposing this endpoint.
    body = reply.json()
    return body["0"]["engine_id"]
async def completion(
    client: httpx.AsyncClient,
    port: int,
    prompt_token_ids: list[int],
    max_tokens: int,
    kv_transfer_params: dict | None = None,
    *,
    timeout: float = 600.0,
) -> tuple[float, dict]:
    """POST a non-streaming, greedy ``/v1/completions`` request to localhost.

    Args:
        client: Shared async HTTP client.
        port: Port of the target instance on 127.0.0.1.
        prompt_token_ids: Prompt given directly as token ids.
        max_tokens: Upper bound on generated tokens.
        kv_transfer_params: Optional connector parameters forwarded verbatim
            as the request's ``kv_transfer_params`` field; omitted when
            ``None`` or empty.
        timeout: Per-request timeout in seconds (defaults to the previous
            hard-coded 600 s).

    Returns:
        ``(elapsed_s, body)``: the client-side wall time of the POST measured
        with ``time.perf_counter`` (excluding JSON decoding), and the decoded
        JSON response.

    Raises:
        httpx.HTTPStatusError: if the server replies with a non-2xx status.
        Other ``httpx`` exceptions (e.g. timeouts) propagate from the POST.
    """
    payload = {
        "model": MODEL_PATH,
        "prompt": prompt_token_ids,
        "max_tokens": max_tokens,
        # Always 1: at least one token must be generated, presumably so an
        # immediate EOS cannot produce an empty completion (the caller's
        # success check needs completion_tokens >= 1).
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": False,
    }
    if kv_transfer_params:
        payload["kv_transfer_params"] = kv_transfer_params
    t0 = time.perf_counter()
    r = await client.post(
        f"http://127.0.0.1:{port}/v1/completions",
        json=payload,
        timeout=timeout,
    )
    # Timing stops before status checking/decoding so it reflects the
    # request round-trip only.
    elapsed_s = time.perf_counter() - t0
    r.raise_for_status()
    return elapsed_s, r.json()
def synth_prompt(rng_seed: int, n_tokens: int) -> list[int]:
    """Build a reproducible pseudo-random prompt of ``n_tokens`` token ids.

    Ids are drawn uniformly from the inclusive range [100, 150000] by a
    ``random.Random`` seeded with ``rng_seed``, so equal (seed, length)
    pairs always give the same prompt. The bounds are meant to stay clear of
    special-token ids; NOTE(review): assumes the model's vocabulary covers
    every id up to 150000.
    """
    import random

    draw = random.Random(rng_seed).randint
    return [draw(100, 150000) for _ in range(n_tokens)]
def _cached_prompt_tokens(resp: dict) -> int | None:
    """Return the prompt-cache hit count reported in a completion response.

    Reads ``usage.prompt_tokens_details.cached_tokens`` and falls back to
    ``usage.cached_tokens`` when the first is missing or zero (default 0).
    """
    usage = resp.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    return details.get("cached_tokens", 0) or usage.get("cached_tokens", 0)


async def measure_one(
    client: httpx.AsyncClient,
    src_port: int, dst_port: int,
    src_eid: str, dst_eid: str,
    input_tokens: int,
    rng_seed: int,
    skip_verify: bool = False,
) -> dict:
    """Measure one src -> dst KV transfer for a prompt of ``input_tokens`` ids.

    Steps:
      1. Prefill on ``src_port`` with ``do_remote_decode`` and
         ``max_tokens=1``; the response's ``kv_transfer_params`` describe
         the KV blocks computed on src.
      2. Send the same prompt to ``dst_port`` with ``do_remote_prefill`` and
         those block ids; its client-side latency is ``t_transfer_s``.
      3. Unless ``skip_verify``: send the prompt to ``dst_port`` again
         without transfer params and read the prompt-cache hit count. Skip it
         for strict PD-disagg, where dst is a kv_consumer-only instance that
         cannot serve a plain request.

    ``dst_eid`` is accepted for interface compatibility but is not used; the
    pull addresses the source via ``src_eid`` and step 1's parameters.

    Returns:
        One JSON-serialisable result row. ``ok`` is True only when:
          * step 1 returned a non-empty ``remote_block_ids`` (recorded as
            ``src_kvp_ok``). Without it, step 2 has nothing to pull and
            ``t_transfer_s`` would not be a transfer time.
          * step 2 produced at least one completion token.
          * if verification ran, at least 90% of the prompt hit the cache on
            dst.
    """
    prompt = synth_prompt(rng_seed, input_tokens)
    # Recorded in the row only; never sent to either server.
    session = uuid.uuid4().hex

    # Step 1: prefill on src, which computes and caches the prompt's KV.
    t_step1_client = time.time()
    t_prefill_s, prefill_resp = await completion(
        client, src_port, prompt, max_tokens=1,
        kv_transfer_params={
            "do_remote_decode": True,
            "remote_block_ids": None,
            "remote_engine_id": src_eid,
            "remote_host": "127.0.0.1",
            "remote_port": src_port,
        },
    )
    src_kvp = prefill_resp.get("kv_transfer_params") or {}
    remote_block_ids = src_kvp.get("remote_block_ids")
    # Without block ids from src there is nothing for dst to pull, so the
    # step-2 latency below would not measure a KV transfer.
    src_kvp_ok = bool(remote_block_ids)

    # Step 2 (the timed operation): dst pulls the KV blocks from src.
    t_step2_client = time.time()
    t_transfer_s, pull_resp = await completion(
        client, dst_port, prompt, max_tokens=1,
        kv_transfer_params={
            "do_remote_prefill": True,
            "remote_block_ids": remote_block_ids,
            "remote_engine_id": src_eid,
            "remote_host": "127.0.0.1",
            "remote_port": src_kvp.get("remote_port", src_port),
        },
    )
    t_step2_end_client = time.time()

    # Step 3 (optional sanity check): a plain request on dst should hit the
    # prefix cache for (almost) the whole prompt if the KV landed there.
    cached_followup = None
    t_followup_s = None
    if not skip_verify:
        t_followup_s, follow_resp = await completion(
            client, dst_port, prompt, max_tokens=1,
        )
        cached_followup = _cached_prompt_tokens(follow_resp)

    pull_usage = pull_resp.get("usage") or {}
    pull_completion_tokens = pull_usage.get("completion_tokens", 0)
    ok = src_kvp_ok and pull_completion_tokens >= 1
    # cached_followup stays None when verification was skipped.
    if cached_followup is not None:
        ok = ok and cached_followup >= input_tokens * 0.9
    return {
        "input_tokens": input_tokens,
        "session": session,
        "t_step1_client_unix": t_step1_client,
        "t_step2_client_unix": t_step2_client,
        "t_step2_end_unix": t_step2_end_client,
        "t_prefill_s": t_prefill_s,
        "t_transfer_s": t_transfer_s,
        "t_followup_s": t_followup_s,
        "cached_followup": cached_followup,
        "pull_completion_tokens": pull_completion_tokens,
        "src_kvp_ok": src_kvp_ok,
        "ok": ok,
    }
# Bytes of KV cache per prompt token; used to label sizes in MiB and written
# into the output JSON. NOTE(review): presumably 2 (K and V) x 48 layers x 4
# KV heads x 128 head_dim x 2 bytes (bf16) for Qwen3-Coder-30B-A3B-Instruct —
# confirm against the model config before pointing MODEL_PATH elsewhere.
_KV_BYTES_PER_TOKEN = 98304


def _summarise(results: list[dict], sizes: list[int]) -> list[dict]:
    """Aggregate the transfer times of ``ok`` rows per input size.

    Sizes with no successful run are omitted. ``transfer_s_p90`` is the 90th
    percentile from ``statistics.quantiles`` when there are at least 10
    samples, otherwise the maximum (a conservative stand-in).
    """
    summary = []
    for sz in sizes:
        ts = [r["t_transfer_s"] for r in results
              if r["input_tokens"] == sz and r["ok"]]
        if not ts:
            continue
        summary.append({
            "input_tokens": sz,
            "n_ok": len(ts),
            "transfer_s_mean": statistics.mean(ts),
            "transfer_s_p50": statistics.median(ts),
            "transfer_s_p90": (statistics.quantiles(ts, n=10)[-1]
                               if len(ts) >= 10 else max(ts)),
            "transfer_s_min": min(ts),
            "transfer_s_max": max(ts),
        })
    return summary


def _print_summary(summary: list[dict]) -> None:
    """Print one human-readable line per summarised size (times in ms)."""
    for s in summary:
        sz = s["input_tokens"]
        kv_mib = sz * _KV_BYTES_PER_TOKEN / 1024 / 1024
        print(f" {sz:>6} tok ({kv_mib:>7.1f} MiB KV): "
              f"mean {s['transfer_s_mean']*1000:7.1f} ms · "
              f"p50 {s['transfer_s_p50']*1000:7.1f} · "
              f"p90 {s['transfer_s_p90']*1000:7.1f} "
              f"(n_ok={s['n_ok']})")


async def main_async(args: argparse.Namespace) -> None:
    """Run the MB2 sweep described by ``args`` and write the JSON report.

    For each size in ``args.sizes`` (comma-separated ints; blank items such
    as a trailing comma are ignored), run ``args.repeats`` src -> dst
    measurements and print one line per run. Then write the raw rows, a
    per-size summary and any failures to ``args.out``, and print the summary.

    A measurement that raises ``httpx.HTTPError`` is recorded under the
    report's ``errors`` key instead of aborting the sweep, so earlier results
    are not lost. After the report is written, ``SystemExit`` is raised if
    any measurement failed, keeping the exit status non-zero.
    """
    sizes = [int(s) for s in args.sizes.split(",") if s.strip()]
    repeats = args.repeats
    src_port, dst_port = args.src_port, args.dst_port
    results: list[dict] = []
    errors: list[dict] = []
    limits = httpx.Limits(max_connections=10, max_keepalive_connections=10)
    async with httpx.AsyncClient(limits=limits, trust_env=False) as client:
        # Engine-id lookup failures are fatal: nothing can be measured without them.
        src_eid = await get_engine_id(client, src_port)
        dst_eid = await get_engine_id(client, dst_port)
        print(f"[mb2] src_eid={src_eid[:16]}... dst_eid={dst_eid[:16]}...")
        for sz in sizes:
            for r in range(repeats):
                try:
                    row = await measure_one(
                        client, src_port, dst_port, src_eid, dst_eid,
                        input_tokens=sz, rng_seed=sz * 1000 + r,
                        skip_verify=args.skip_verify,
                    )
                except httpx.HTTPError as exc:
                    # One rejected or timed-out request (e.g. a size beyond a
                    # server's context limit) must not discard every result
                    # gathered so far; record it and keep sweeping.
                    msg = f"{type(exc).__name__}: {exc}"
                    print(f" size={sz:>6} rep={r} FAILED {msg}")
                    errors.append({"input_tokens": sz, "rep": r, "error": msg})
                    continue
                cached = row.get("cached_followup")
                cached_str = f"{cached}/{sz}" if cached is not None else "skip"
                print(f" size={sz:>6} rep={r} "
                      f"transfer={row['t_transfer_s']*1000:7.1f}ms "
                      f"followup_cached={cached_str} "
                      f"ok={row['ok']}")
                results.append(row)

    summary = _summarise(results, sizes)
    out = {
        "model": MODEL_PATH,
        "kv_bytes_per_token": _KV_BYTES_PER_TOKEN,
        "src_port": src_port,
        "dst_port": dst_port,
        "config_label": args.label,
        "raw": results,
        "summary": summary,
        "errors": errors,
    }
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2))
    print(f"[mb2] wrote {args.out}")
    _print_summary(summary)
    if errors:
        # Keep a non-zero exit status for callers, as the uncaught error did.
        raise SystemExit(
            f"[mb2] {len(errors)} measurement(s) failed; "
            f"see 'errors' in {args.out}"
        )
def main() -> None:
    """Parse the command-line flags and run the asynchronous MB2 sweep."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # Ports of the source (prefill) and destination (pull) instances.
    add("--src-port", type=int, default=8000)
    add("--dst-port", type=int, default=8001)
    add(
        "--sizes",
        default="512,1024,2048,4096,8192,16384,32768,65536",
        help="Comma-separated input_token sizes to sweep",
    )
    add("--repeats", type=int, default=5)
    add(
        "--label",
        default="intra-node",
        help="Label written into the output (e.g. intra-node / inter-node)",
    )
    add("--out", default="mb2_result.json")
    add(
        "--skip-verify",
        action="store_true",
        help=("Skip the step-3 verify completion (required for strict "
              "PD-disagg where dst is kv_consumer-only)."),
    )
    parsed = parser.parse_args()
    asyncio.run(main_async(parsed))
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, never on import.
    main()