Files
agentic-kvc/microbench/connector_tax/layerwise/mb7_layerwise.py
Gahow Wang e77bdcac5a Layerwise under load: overlap benefit survives (bg=16)
mb7 with background decode load (8/instance). Critical-path transfer overhead
stays ~constant ~90ms for layerwise vs 158/239/749ms baseline (up to 7.9x at
32k), prefill not slowed, KV correct. Confirms the overlap holds on busy
instances. DESIGN.md updated with idle-vs-load table + the two blockers
(chunk-safety, concurrent-transfer safety) that the full 1200-req trace needs.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 16:30:14 +08:00

269 lines
11 KiB
Python

#!/usr/bin/env python3
"""MB7: correctness + perf of layer-wise KV push vs post-hoc transfer.

Two 2-instance modes against A (src/producer) and B (dst/consumer):

  baseline : prefill A (await) -> THEN B pulls (post-hoc full transfer).
             T_total = T_prefill + T_xfer (sequential)
  layerwise: dispatch B's remote-prefill (handshake) and A's prefill
             CONCURRENTLY, so A pushes each layer as it computes it.
             If overlap works, T_total ~= max(T_prefill, T_xfer) ~= T_prefill.

Reference: T_prefill_only = a plain prefill on A with no transfer, using a
different (uncached) prompt of the same length. The reported overhead per
size is median(T_total) - median(T_prefill_only), i.e. the part of the
transfer NOT hidden behind prefill.

Correctness: after the transfer, a plain follow-up to B on the same prompt
must report cached_tokens >= 90% of the prompt length (the KV actually
landed on B).

Optional background load (--bg-load N) keeps N long-decode streams running
across A and B for the whole measurement phase, so the numbers reflect busy
instances rather than idle ones.

The connector mode is selected by the launcher (run_mb7.sh): baseline uses the
stock connector; layerwise deploys mooncake_connector.LAYERWISE.py +
MOONCAKE_LAYERWISE=1. This script just drives the requests and measures.

Usage:
    python mb7_layerwise.py --mode layerwise --sizes 8192,32768,65536 --repeats 3 \
        --src-port 8000 --dst-port 8001 --src-bp 8998 --dst-bp 8999 --out mb7.json
"""
from __future__ import annotations
import argparse
import asyncio
import json
import statistics
import time
import uuid
from pathlib import Path
import httpx
# Served model path; sent verbatim as the "model" field of every request.
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"
# KV-cache bytes per token, used only to report the transferred volume (kv_gib).
# NOTE(review): presumably 48 layers x 4 KV heads x 128 head_dim x 2 (K and V)
# x 2 bytes (bf16) = 98,304 for this model -- confirm against its config.
KV_PER_TOK = 98_304
def synth_prompt(seed: int, n: int) -> list[int]:
    """Return ``n`` pseudo-random token ids, reproducible for a given ``seed``.

    Ids are drawn uniformly from the closed range [100, 150000]. Because the
    generator is freshly seeded on every call, the same ``(seed, n)`` always
    yields the same prompt, and a longer prompt built from the same seed
    extends a shorter one token-for-token.
    """
    from random import Random

    draw = Random(seed).randint
    return [draw(100, 150000) for _count in range(n)]
async def get_engine_id(client, host, bp):
    """Fetch the engine id advertised by an instance's bootstrap endpoint.

    Issues ``GET http://{host}:{bp}/query`` and returns
    ``payload["0"]["engine_id"]`` from the JSON body. Non-2xx statuses are
    raised through ``raise_for_status()``.
    """
    url = f"http://{host}:{bp}/query"
    resp = await client.get(url)
    resp.raise_for_status()
    payload = resp.json()
    first_entry = payload["0"]
    return first_entry["engine_id"]
async def completion(client, host, port, prompt, max_tokens, ktp=None,
                     timeout=600.0):
    """POST one non-streaming ``/v1/completions`` request and time it.

    Args:
        client: shared async HTTP client.
        host, port: endpoint of the serving instance.
        prompt: prompt token ids.
        max_tokens: generation budget. ``min_tokens`` is always 1, so every
            request generates at least one token.
        ktp: optional ``kv_transfer_params`` dict forwarded to the server;
            left out of the payload when falsy (``None`` or empty).
        timeout: per-request timeout in seconds (default 600s).

    Returns:
        ``(elapsed_seconds, response_json)``. The elapsed time brackets only
        the POST call; JSON decoding happens after the clock stops.

    Raises:
        The client's HTTP status error for non-2xx responses (via
        ``raise_for_status()``).
    """
    payload = {
        "model": MODEL, "prompt": prompt, "max_tokens": max_tokens,
        # Always force at least one generated token.
        "min_tokens": 1,
        "temperature": 0.0, "stream": False,
    }
    if ktp:
        payload["kv_transfer_params"] = ktp
    t0 = time.perf_counter()
    r = await client.post(f"http://{host}:{port}/v1/completions",
                          json=payload, timeout=timeout)
    dt = time.perf_counter() - t0
    r.raise_for_status()
    return dt, r.json()
def cached_of(resp) -> int:
    """Extract the ``cached_tokens`` count from a completions response.

    Looks first at ``usage.prompt_tokens_details.cached_tokens``, then at a
    top-level ``usage.cached_tokens``. A missing, null or zero value at one
    location falls through to the next, and 0 is returned when neither
    location holds a truthy value.
    """
    usage = resp.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    candidates = (details.get("cached_tokens", 0), usage.get("cached_tokens", 0))
    return next(filter(None, candidates), 0)
async def _stream_completion(client, host, port, prompt, max_tokens):
    """Run one streaming completion to the end, discarding its output.

    Used by the background-load workers. The response body is drained chunk
    by chunk, so the call returns only when the server finishes the stream.
    Non-2xx statuses raise via ``raise_for_status()``.
    """
    url = f"http://{host}:{port}/v1/completions"
    body = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": True,
    }
    async with client.stream("POST", url, json=body, timeout=600.0) as resp:
        resp.raise_for_status()
        async for _chunk in resp.aiter_bytes():
            continue
class BackgroundLoad:
    """Hold N concurrent long-decode streams across endpoints to keep them busy.

    Worker ``i`` targets ``endpoints[i % len(endpoints)]``. It issues one
    streaming completion after another (``prompt_tokens``-token prompt,
    ``out_tokens`` generation budget) until :meth:`stop` is called.

    Failures are tolerated: the worker backs off 0.5s and retries. They are
    counted in ``errors`` (with the latest exception in ``last_error``) so a
    benchmark can tell whether the requested load was actually applied.
    ``completed`` counts streams that ran to the end.
    """

    def __init__(self, client, endpoints, n, prompt_tokens=2000, out_tokens=6000):
        if n > 0 and not endpoints:
            # Workers would otherwise die on `idx % 0` inside their tasks.
            raise ValueError("BackgroundLoad needs at least one endpoint")
        self.client, self.endpoints, self.n = client, endpoints, n
        self.pt, self.ot = prompt_tokens, out_tokens
        self.completed = 0       # streams that finished without error
        self.errors = 0          # failed attempts (swallowed, then retried)
        self.last_error = None   # most recent failure, for diagnostics
        self._stop = asyncio.Event()
        self._tasks = []

    async def _w(self, idx):
        """Worker loop: stream back-to-back requests until stopped."""
        host, port = self.endpoints[idx % len(self.endpoints)]
        # Seeds idx, idx+n, idx+2n, ... are disjoint across workers, so no
        # worker ever replays another worker's prompt (which the server
        # could then serve from its prefix cache, lightening the load).
        seed = 800000 + idx
        while not self._stop.is_set():
            try:
                await _stream_completion(self.client, host, port,
                                         synth_prompt(seed, self.pt), self.ot)
            except Exception as exc:  # best-effort load: record, back off, retry
                self.errors += 1
                self.last_error = exc
                await asyncio.sleep(0.5)
            else:
                self.completed += 1
            seed += self.n

    def start(self):
        """Spawn the ``n`` worker tasks on the running event loop."""
        self._tasks = [asyncio.create_task(self._w(i)) for i in range(self.n)]

    async def stop(self):
        """Signal, cancel and await all workers; their errors are discarded."""
        self._stop.set()
        for t in self._tasks:
            t.cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)
async def num_running(client, host, port):
    """Return the number of running requests reported by an instance.

    Scrapes ``GET http://{host}:{port}/metrics`` (Prometheus text format) and
    sums every sample whose metric name is exactly
    ``vllm:num_requests_running``, across all label sets. Each sample's value
    is the first field after the name/labels, so an optional trailing
    timestamp is ignored.

    Best-effort probe: returns -1 if the request fails, the status is not
    2xx, the metric is absent, or a matching sample cannot be parsed.
    """
    metric = "vllm:num_requests_running"
    try:
        r = await client.get(f"http://{host}:{port}/metrics", timeout=5.0)
        r.raise_for_status()
        total = 0.0
        found = False
        for raw in r.text.splitlines():
            line = raw.strip()
            if not line or line.startswith("#"):
                continue  # blank line, or HELP/TYPE comment
            if "{" in line:
                name, _, rest = line.partition("{")
                # Labels close at the last '}'; what follows is value [timestamp].
                rest = rest.rpartition("}")[2]
            else:
                parts = line.split(None, 1)
                name = parts[0]
                rest = parts[1] if len(parts) > 1 else ""
            if name.strip() != metric:
                continue
            fields = rest.split()
            if not fields:
                continue
            total += float(fields[0])
            found = True
        return int(total) if found else -1
    except Exception:
        # Deliberate best-effort: any network/HTTP/parse failure reads as
        # "unknown" so callers polling for load never crash on it.
        return -1
async def prefill_only(client, host, port, prompt):
    """Reference: plain prefill cost on A, no transfer.

    Sends ``prompt`` with ``max_tokens=1`` and no ``kv_transfer_params``;
    returns just the measured latency in seconds.
    """
    timed = await completion(client, host, port, prompt, max_tokens=1)
    return timed[0]
async def measure_baseline(client, A, B, src_eid, src_bp_addr, prompt, seed):
    """Sequential baseline: prefill on A, then B pulls the KV afterwards.

    A first prefills with ``do_remote_decode``. Only once that response
    returns does B issue its ``do_remote_prefill`` request for the same
    ``transfer_id``. A final plain request to B checks the prompt is cached.

    ``seed`` is accepted for signature parity with measure_layerwise and is
    not used.

    Returns a dict with ``t_prefill_s`` (A's request), ``t_xfer_s`` (B's
    request), ``t_total_s`` (wall clock for both) and ``cached`` (the
    follow-up's cached-token count).
    """
    transfer_id = uuid.uuid4().hex
    started = time.perf_counter()
    producer_params = {"do_remote_decode": True, "transfer_id": transfer_id}
    t_pf, _ = await completion(client, *A, prompt, 1, ktp=producer_params)
    consumer_params = {
        "do_remote_prefill": True,
        "transfer_id": transfer_id,
        "remote_engine_id": src_eid,
        "remote_bootstrap_addr": src_bp_addr,
    }
    t_xfer, _ = await completion(client, *B, prompt, 1, ktp=consumer_params)
    t_total = time.perf_counter() - started
    # Correctness probe: the same prompt sent plainly to B should hit cache.
    _, followup = await completion(client, *B, prompt, 1)
    return {"t_prefill_s": t_pf, "t_xfer_s": t_xfer, "t_total_s": t_total,
            "cached": cached_of(followup)}
async def measure_layerwise(client, A, B, src_eid, src_bp_addr, prompt, seed,
                            head_start_s=0.05):
    """Dispatch B handshake + A prefill concurrently => layer-wise overlap.

    B's ``do_remote_prefill`` request is sent first. A's
    ``do_remote_decode`` prefill follows ``head_start_s`` seconds later, so
    B's handshake can reach A before A's forward pass. Both then run
    concurrently, and a final plain request to B checks the KV landed there.

    If either request fails, the other one is cancelled and awaited before
    the error is re-raised. Otherwise ``asyncio.gather`` would leave it
    running into the next measurement.

    ``seed`` is unused (kept for signature parity with measure_baseline).

    Returns a dict with:
        ``t_A_s``: A's request only, excluding the head start.
        ``t_B_s``: B's request.
        ``t_total_s``: wall clock for the pair, which includes the head start.
        ``cached``: the follow-up's cached-token count.
    """
    tid = uuid.uuid4().hex
    t0 = time.perf_counter()

    async def run_B():
        return await completion(client, *B, prompt, 1,
                                ktp={"do_remote_prefill": True, "transfer_id": tid,
                                     "remote_engine_id": src_eid,
                                     "remote_bootstrap_addr": src_bp_addr})

    async def run_A():
        # Small head start for B's handshake to reach A before A's forward.
        await asyncio.sleep(head_start_s)
        return await completion(client, *A, prompt, 1,
                                ktp={"do_remote_decode": True, "transfer_id": tid})

    b_task = asyncio.create_task(run_B())
    a_task = asyncio.create_task(run_A())
    try:
        (t_b, _), (t_a, _) = await asyncio.gather(b_task, a_task)
    except BaseException:
        # gather() propagates the first failure but does not cancel the
        # other awaitable; cancel both and let them unwind before re-raising.
        for task in (a_task, b_task):
            task.cancel()
        await asyncio.gather(a_task, b_task, return_exceptions=True)
        raise
    t_total = time.perf_counter() - t0
    _, fr = await completion(client, *B, prompt, 1)
    return {"t_A_s": t_a, "t_B_s": t_b, "t_total_s": t_total,
            "cached": cached_of(fr)}
async def _ramp_background_load(client, A, B, n, polls=40):
    """Poll A and B once per second until both report a running request.

    Gives up after ``polls`` attempts and prints a warning instead of
    failing. Returns True if load was confirmed on both instances.
    """
    na = nb = -1
    for _ in range(polls):
        await asyncio.sleep(1.0)
        na = await num_running(client, *A)
        nb = await num_running(client, *B)
        if na >= 1 and nb >= 1:
            print(f"[mb7] busy: A_run={na} B_run={nb}")
            return True
    print(f"[mb7] WARNING: background load ({n}) not confirmed after {polls} "
          f"polls (A_run={na} B_run={nb}); measuring anyway")
    return False


async def _run_measurements(client, a, sizes, A, B, src_eid, src_bp_addr):
    """Run ``a.repeats`` measurements per size in ``a.mode``; return raw rows."""
    measure = measure_baseline if a.mode == "baseline" else measure_layerwise
    results = []
    for sz in sizes:
        for rep in range(a.repeats):
            seed = sz * 100 + rep
            prompt = synth_prompt(seed, sz)
            # reference prefill-only cost (fresh prompt, different seed so no cache)
            t_pf_only = await prefill_only(client, *A, synth_prompt(seed + 555, sz))
            row = await measure(client, A, B, src_eid, src_bp_addr, prompt, seed)
            row.update({"mode": a.mode, "size": sz, "rep": rep,
                        "t_prefill_only_s": t_pf_only,
                        "kv_gib": sz * KV_PER_TOK / 2**30,
                        # Correct if B reports >= 90% of the prompt as cached.
                        "correct": row["cached"] >= int(sz * 0.9)})
            results.append(row)
            if a.mode == "baseline":
                extra = f"xfer={row.get('t_xfer_s', 0)*1000:.0f}ms"
            else:
                extra = (f"tA={row.get('t_A_s',0)*1000:.0f}ms "
                         f"tB={row.get('t_B_s',0)*1000:.0f}ms")
            print(f" sz={sz:>6} rep={rep} pf_only={t_pf_only*1000:6.0f}ms "
                  f"total={row['t_total_s']*1000:7.0f}ms {extra} "
                  f"cached={row['cached']}/{sz} correct={row['correct']}")
    return results


def _summarize(results, sizes):
    """Print one row per distinct size (medians, ms) and return the rows.

    overhead_ms = median(total) - median(prefill_only): the part of the
    transfer NOT hidden behind prefill.
    """
    summary = []
    for sz in dict.fromkeys(sizes):  # de-duplicate, keep CLI order
        rs = [r for r in results if r["size"] == sz]
        if not rs:
            continue
        pf = statistics.median(r["t_prefill_only_s"] for r in rs) * 1000
        tot = statistics.median(r["t_total_s"] for r in rs) * 1000
        allok = all(r["correct"] for r in rs)
        overhead = tot - pf
        summary.append({"size": sz, "n": len(rs), "pf_only_ms": pf,
                        "total_ms": tot, "overhead_ms": overhead,
                        "all_correct": allok})
        print(f"{sz:>7} {len(rs):>2} {pf:>11.0f} {tot:>9.0f} {overhead:>12.0f} "
              f"{str(allok):>8}")
    return summary


async def main_async(a):
    """Drive the MB7 benchmark end to end.

    Steps:
      1. Resolve A's engine id from its bootstrap port.
      2. Optionally start background load and wait for it to appear.
      3. Measure ``a.repeats`` transfers per size in ``a.mode``.
      4. Print a per-size summary and write raw rows + summary to ``a.out``.

    Background load is stopped even when a measurement raises.
    """
    sizes = [int(s) for s in a.sizes.split(",") if s.strip()]
    A = (a.src_host, a.src_port)
    B = (a.dst_host, a.dst_port)
    # Each background stream holds a pooled connection for its whole decode.
    # Size the pool so the measured requests (at most two in flight) and the
    # metrics probes never queue behind them and distort the timings.
    pool = max(64, a.bg_load + 8)
    limits = httpx.Limits(max_connections=pool, max_keepalive_connections=pool)
    async with httpx.AsyncClient(limits=limits, trust_env=False) as client:
        src_eid = await get_engine_id(client, a.src_host, a.src_bp)
        src_bp_addr = f"http://{a.src_host}:{a.src_bp}"
        print(f"[mb7] mode={a.mode} bg_load={a.bg_load} src_eid={src_eid[:16]}...")
        loader = None
        if a.bg_load > 0:
            loader = BackgroundLoad(client, [A, B], a.bg_load)
            loader.start()
        try:
            if loader is not None:
                print(f"[mb7] ramping background load ({a.bg_load}) ...")
                await _ramp_background_load(client, A, B, a.bg_load)
            results = await _run_measurements(client, a, sizes, A, B,
                                               src_eid, src_bp_addr)
        finally:
            if loader is not None:
                await loader.stop()
        print(f"\n=== {a.mode} (bg={a.bg_load}) summary ===")
        print(f"{'size':>7} {'n':>2} {'pf_only_ms':>11} {'total_ms':>9} "
              f"{'overhead_ms':>12} {'correct':>8}")
        summary = _summarize(results, sizes)
        Path(a.out).write_text(json.dumps(
            {"mode": a.mode, "model": MODEL, "raw": results, "summary": summary}, indent=2))
        print(f"\n[mb7] wrote {a.out}")
def main():
p = argparse.ArgumentParser()
p.add_argument("--mode", choices=["baseline", "layerwise"], required=True)
p.add_argument("--src-host", default="127.0.0.1")
p.add_argument("--dst-host", default="127.0.0.1")
p.add_argument("--src-port", type=int, default=8000)
p.add_argument("--dst-port", type=int, default=8001)
p.add_argument("--src-bp", type=int, default=8998)
p.add_argument("--dst-bp", type=int, default=8999)
p.add_argument("--sizes", default="8192,32768,65536")
p.add_argument("--repeats", type=int, default=3)
p.add_argument("--bg-load", type=int, default=0,
help="N concurrent background decode streams across A+B")
p.add_argument("--out", default="mb7_result.json")
args = p.parse_args()
asyncio.run(main_async(args))
if __name__ == "__main__":
main()