Adds the pieces needed to run the producer on dash1 and the consumer on dash2 with the same shared cpfs venv.

start_vllm_single.sh: configured through the INSTANCE / GPU / PORT / BP / MASTER / ROLE env vars; brings up ONE vLLM instance and applies the mooncake instrumentation patch (idempotent: since the venv is cpfs-shared, the first invocation applies the patch and the second is a no-op). A per-instance MB2_LOG_DIR keeps producer and consumer events separate even though both directories live on the same cpfs path visible to both hosts.

mb2_kv_transfer.py: new --src-host / --dst-host args. Defaults stay 127.0.0.1 for backward compatibility with the intra-node sweep. The /v1/completions URLs and /query URLs now use the supplied hosts. remote_bootstrap_addr is built as http://<src_host>:<src_bp> so the consumer's do_remote_prefill request carries a routable address.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
262 lines
9.6 KiB
Python
Executable File
262 lines
9.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""MB2: measure KV transfer time between two vLLM instances over Mooncake.
|
|
|
|
Pattern (adapted from microbench/connector_tax/cache_sweep/smoke_test_migrate_cache.py):
|
|
|
|
1. Prefill on A: do_remote_decode with max_tokens=1 (A computes & caches KV)
|
|
2. Pull to B: do_remote_prefill on B with kv_transfer_params from step 1
|
|
(this is the operation that performs the KV transfer)
|
|
3. Verify: send a follow-up to B; cached_tokens should equal the
|
|
prompt length (confirms the KV landed on B)
|
|
|
|
We time step 2 — that gives us E2E "transfer + B's prefill check" latency.
|
|
By sweeping input_length we trace T_transfer(KV_size).
|
|
|
|
The follow-up step gives us a sanity check (correctness) but isn't timed.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import statistics
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
|
|
# Model identifier sent as the "model" field of every completion request and
# recorded under "model" in the output JSON.
MODEL_PATH = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
|
|
|
|
|
async def get_engine_id(client: httpx.AsyncClient, host: str, bootstrap_port: int) -> str:
    """Return the engine id reported by a Mooncake bootstrap server.

    The ``/query`` endpoint lives on the mooncake bootstrap port, not the
    vLLM HTTP serving port.

    Args:
        client: Async HTTP client used to issue the GET request.
        host: Hostname or IP of the instance to query.
        bootstrap_port: Mooncake bootstrap port of that instance.

    Returns:
        The ``engine_id`` field found under key ``"0"`` of the JSON reply.

    Raises:
        Whatever ``raise_for_status`` raises on an error status, plus
        ``KeyError`` if the reply lacks the expected keys.
    """
    query_url = f"http://{host}:{bootstrap_port}/query"
    response = await client.get(query_url)
    response.raise_for_status()
    # Only the entry stored under the string key "0" is consulted
    # (presumably the first rank/worker — verify against the bootstrap server).
    first_entry = response.json()["0"]
    return first_entry["engine_id"]
|
|
|
|
|
|
async def completion(
    client: httpx.AsyncClient,
    host: str,
    port: int,
    prompt_token_ids: list[int],
    max_tokens: int,
    kv_transfer_params: dict | None = None,
) -> tuple[float, dict]:
    """POST one non-streaming ``/v1/completions`` request and time it.

    Args:
        client: Async HTTP client used to issue the request.
        host: Hostname or IP of the target server.
        port: HTTP serving port of the target server.
        prompt_token_ids: Prompt, sent as a list of token ids.
        max_tokens: Value placed in the request's ``max_tokens`` field.
        kv_transfer_params: Optional connector parameters; attached to the
            request body only when truthy.

    Returns:
        ``(elapsed_s, body)``: the ``time.perf_counter`` interval spanning the
        POST call, and the decoded JSON response body.

    Raises:
        Whatever ``raise_for_status`` raises when the server answers with an
        error status.
    """
    # NOTE(review): for an int ``max_tokens`` this expression always evaluates
    # to 1; it is kept verbatim so the request body is unchanged.
    min_tokens = max_tokens if max_tokens == 1 else 1

    request_body = {
        "model": MODEL_PATH,
        "prompt": prompt_token_ids,
        "max_tokens": max_tokens,
        "min_tokens": min_tokens,
        "temperature": 0.0,
        "stream": False,
    }
    if kv_transfer_params:
        request_body["kv_transfer_params"] = kv_transfer_params

    started = time.perf_counter()
    response = await client.post(
        f"http://{host}:{port}/v1/completions",
        json=request_body,
        timeout=600.0,
    )
    # The clock stops before the status check and JSON decoding, so neither
    # is included in the reported latency.
    elapsed_s = time.perf_counter() - started

    response.raise_for_status()
    return elapsed_s, response.json()
|
|
|
|
|
|
def synth_prompt(rng_seed: int, n_tokens: int) -> list[int]:
    """Build a reproducible pseudo-random prompt of token ids.

    The ids are drawn from the inclusive range [100, 150000], which keeps them
    away from the low-numbered ids (far enough from special tokens).

    Args:
        rng_seed: Seed for the private random generator; the same seed always
            produces the same sequence.
        n_tokens: Number of token ids to generate.

    Returns:
        A list of ``n_tokens`` integers.
    """
    import random

    # A dedicated generator keeps the sequence independent of global state.
    draw = random.Random(rng_seed).randint
    token_ids = []
    for _ in range(n_tokens):
        token_ids.append(draw(100, 150000))
    return token_ids
|
|
|
|
|
|
async def measure_one(
    client: httpx.AsyncClient,
    src_host: str, src_port: int,
    dst_host: str, dst_port: int,
    src_eid: str, dst_eid: str,
    src_bootstrap_addr: str,
    input_tokens: int,
    rng_seed: int,
    skip_verify: bool = False,
) -> dict:
    """Run one prefill -> pull -> (optional) verify cycle and report timings.

    Steps:
      1. Send the prompt to the source with ``do_remote_decode`` and
         ``max_tokens=1``.
      2. Send the same prompt to the destination with ``do_remote_prefill``;
         the elapsed time of this request is reported as ``t_transfer_s``.
      3. Unless ``skip_verify`` is set, send a plain follow-up to the
         destination and read how many prompt tokens it reports as cached.
         Step 3 is a sanity check only; skip it for strict PD-disagg where
         the dst is a kv_consumer-only instance that cannot serve a plain
         request.

    vLLM-shipped MooncakeConnector kv_transfer_params schema
    (vllm/distributed/.../v1/mooncake/mooncake_connector.py:385):
      do_remote_decode  : {transfer_id}
      do_remote_prefill : {transfer_id, remote_engine_id, remote_bootstrap_addr}

    Args:
        client: Async HTTP client shared by all requests.
        src_host, src_port: HTTP endpoint of the source instance.
        dst_host, dst_port: HTTP endpoint of the destination instance.
        src_eid: Engine id of the source, sent as ``remote_engine_id``.
        dst_eid: Accepted for symmetry; not used by this function.
        src_bootstrap_addr: Sent as ``remote_bootstrap_addr`` in step 2.
        input_tokens: Prompt length in tokens.
        rng_seed: Seed for the synthetic prompt.
        skip_verify: When true, step 3 is not executed.

    Returns:
        A dict with the wall-clock timestamps around steps 1 and 2, the
        per-step latencies, the cached-token count from step 3 (``None`` when
        skipped) and an ``ok`` flag.
    """
    prompt = synth_prompt(rng_seed, input_tokens)
    session = uuid.uuid4().hex
    transfer_id = uuid.uuid4().hex

    decode_params = {
        "do_remote_decode": True,
        "transfer_id": transfer_id,
    }
    prefill_params = {
        "do_remote_prefill": True,
        "transfer_id": transfer_id,
        "remote_engine_id": src_eid,
        "remote_bootstrap_addr": src_bootstrap_addr,
    }

    # Step 1: prefill on the source.
    t_step1_client = time.time()
    t_prefill_s, prefill_resp = await completion(
        client, src_host, src_port, prompt, max_tokens=1,
        kv_transfer_params=decode_params,
    )

    # Step 2: the timed pull on the destination.
    t_step2_client = time.time()
    t_transfer_s, pull_resp = await completion(
        client, dst_host, dst_port, prompt, max_tokens=1,
        kv_transfer_params=prefill_params,
    )
    t_step2_end_client = time.time()

    # Step 3: optional, untimed-for-the-summary sanity check.
    if skip_verify:
        t_followup_s = None
        cached_followup = None
    else:
        t_followup_s, follow_resp = await completion(
            client, dst_host, dst_port, prompt, max_tokens=1,
        )
        usage = follow_resp.get("usage") or {}
        details = usage.get("prompt_tokens_details") or {}
        # Prefer the nested counter; fall back to a top-level one.
        cached_followup = details.get("cached_tokens", 0) or usage.get("cached_tokens", 0)

    pull_usage = pull_resp.get("usage") or {}
    pull_completion_tokens = pull_usage.get("completion_tokens", 0)

    # The pull must have produced a token; when verification yielded a count,
    # at least 90% of the prompt must also be reported as cached.
    pull_produced_token = pull_completion_tokens >= 1
    if skip_verify or cached_followup is None:
        ok = pull_produced_token
    else:
        ok = pull_produced_token and (cached_followup >= input_tokens * 0.9)

    return {
        "input_tokens": input_tokens,
        "session": session,
        "t_step1_client_unix": t_step1_client,
        "t_step2_client_unix": t_step2_client,
        "t_step2_end_unix": t_step2_end_client,
        "t_prefill_s": t_prefill_s,
        "t_transfer_s": t_transfer_s,
        "t_followup_s": t_followup_s,
        "cached_followup": cached_followup,
        "pull_completion_tokens": pull_completion_tokens,
        "ok": ok,
    }
|
|
|
|
|
|
async def main_async(args: argparse.Namespace) -> None:
    """Run the size sweep, print progress, and write the JSON report.

    For every size in ``args.sizes`` the measurement is repeated
    ``args.repeats`` times. Rows flagged ``ok`` are aggregated per size
    (mean / p50 / p90 / min / max of ``t_transfer_s``) and the raw rows plus
    the summary are written to ``args.out``.

    Args:
        args: Parsed command-line options (``sizes``, ``repeats``,
            ``src_host``/``dst_host``, ``src_port``/``dst_port``,
            ``src_bp``/``dst_bp``, ``label``, ``out``, ``skip_verify``).

    Raises:
        ValueError: If ``args.sizes`` contains a non-integer entry or no
            entries at all.
    """
    # Single source of truth for the per-token KV footprint, used both in the
    # report and in the MiB figures printed below.
    # NOTE(review): presumably 48 layers x 4 KV heads x 128 head_dim x 2 (K+V)
    # x 2 bytes for the model in MODEL_PATH — confirm before changing models.
    kv_bytes_per_token = 98304

    # Blank entries (e.g. a trailing comma) are skipped instead of crashing
    # inside int('').
    sizes = [int(tok) for tok in args.sizes.split(",") if tok.strip()]
    if not sizes:
        raise ValueError("--sizes must contain at least one integer")
    repeats = args.repeats
    src_host, dst_host = args.src_host, args.dst_host
    src_port, dst_port = args.src_port, args.dst_port
    src_bp, dst_bp = args.src_bp, args.dst_bp

    # Create the report directory up front: discovering a missing directory
    # only after the whole sweep would throw every measurement away.
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    limits = httpx.Limits(max_connections=10, max_keepalive_connections=10)
    async with httpx.AsyncClient(limits=limits, trust_env=False) as client:
        src_eid = await get_engine_id(client, src_host, src_bp)
        try:
            dst_eid = await get_engine_id(client, dst_host, dst_bp)
        except Exception as exc:
            # Deliberately best-effort: any failure to query the destination
            # is treated as "no bootstrap server" and the sweep continues.
            print(f"[mb2] dst bootstrap ({dst_host}:{dst_bp}) unreachable ({exc.__class__.__name__}); "
                  f"running in strict-PD mode (dst is kv_consumer, no bootstrap).")
            dst_eid = None
        print(f"[mb2] src={src_host}:{src_port} bp={src_bp} eid={src_eid[:16]}...")
        print(f"[mb2] dst={dst_host}:{dst_port} bp={dst_bp} eid="
              f"{(dst_eid[:16] + '...') if dst_eid else 'N/A (kv_consumer)'}")

        src_bootstrap_addr = f"http://{src_host}:{src_bp}"
        results = []
        for sz in sizes:
            for r in range(repeats):
                row = await measure_one(
                    client,
                    src_host, src_port, dst_host, dst_port,
                    src_eid, dst_eid,
                    src_bootstrap_addr=src_bootstrap_addr,
                    # Distinct seed per (size, repeat) pair.
                    input_tokens=sz, rng_seed=sz * 1000 + r,
                    skip_verify=args.skip_verify,
                )
                cached = row.get("cached_followup")
                cached_str = f"{cached}/{sz}" if cached is not None else "skip"
                print(f" size={sz:>6} rep={r} "
                      f"transfer={row['t_transfer_s']*1000:7.1f}ms "
                      f"followup_cached={cached_str} "
                      f"ok={row['ok']}")
                results.append(row)

    # Group the successful transfer times by size in a single pass.
    ok_times_by_size = {}
    for row in results:
        if row["ok"]:
            ok_times_by_size.setdefault(row["input_tokens"], []).append(row["t_transfer_s"])

    # Summarise per-size; sizes without any successful row are omitted.
    summary = []
    for sz in sizes:
        ts = ok_times_by_size.get(sz)
        if not ts:
            continue
        summary.append({
            "input_tokens": sz,
            "n_ok": len(ts),
            "transfer_s_mean": statistics.mean(ts),
            "transfer_s_p50": statistics.median(ts),
            # With fewer than 10 samples the maximum stands in for p90.
            "transfer_s_p90": statistics.quantiles(ts, n=10)[-1] if len(ts) >= 10 else max(ts),
            "transfer_s_min": min(ts),
            "transfer_s_max": max(ts),
        })

    out = {
        "model": MODEL_PATH,
        "kv_bytes_per_token": kv_bytes_per_token,
        "src_host": src_host,
        "src_port": src_port,
        "dst_host": dst_host,
        "dst_port": dst_port,
        "config_label": args.label,
        "raw": results,
        "summary": summary,
    }
    out_path.write_text(json.dumps(out, indent=2))
    print(f"[mb2] wrote {args.out}")
    for s in summary:
        sz = s["input_tokens"]
        kv_mib = sz * kv_bytes_per_token / 1024 / 1024
        print(f" {sz:>6} tok ({kv_mib:>7.1f} MiB KV): "
              f"mean {s['transfer_s_mean']*1000:7.1f} ms · "
              f"p50 {s['transfer_s_p50']*1000:7.1f} · "
              f"p90 {s['transfer_s_p90']*1000:7.1f} "
              f"(n_ok={s['n_ok']})")
|
|
|
|
|
|
def main() -> None:
    """Parse the command-line options and run the sweep to completion."""
    parser = argparse.ArgumentParser()

    # Endpoint hosts: both default to loopback.
    parser.add_argument(
        "--src-host",
        default="127.0.0.1",
        help="hostname/IP of the producer (A); use a routable address for inter-node tests",
    )
    parser.add_argument(
        "--dst-host",
        default="127.0.0.1",
        help="hostname/IP of the consumer (B)",
    )

    # The four integer port options share one shape, so register them from a
    # table (order matters for the generated --help text).
    port_options = (
        ("--src-port", 8000, "vLLM HTTP port on the producer side (A)"),
        ("--dst-port", 8001, "vLLM HTTP port on the consumer side (B)"),
        ("--src-bp", 8998, "Mooncake bootstrap port on A (serves /query)"),
        ("--dst-bp", 8999, "Mooncake bootstrap port on B (serves /query)"),
    )
    for flag, default_port, help_text in port_options:
        parser.add_argument(flag, type=int, default=default_port, help=help_text)

    # Sweep shape and output options.
    parser.add_argument(
        "--sizes",
        default="512,1024,2048,4096,8192,16384,32768,65536",
        help="Comma-separated input_token sizes to sweep",
    )
    parser.add_argument("--repeats", type=int, default=5)
    parser.add_argument(
        "--label",
        default="intra-node",
        help="Label written into the output (e.g. intra-node / inter-node)",
    )
    parser.add_argument("--out", default="mb2_result.json")
    parser.add_argument(
        "--skip-verify",
        action="store_true",
        help="Skip the step-3 verify completion (required for strict PD-disagg where dst is kv_consumer-only).",
    )

    options = parser.parse_args()
    asyncio.run(main_async(options))


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|