mb7 with background decode load (8/instance). Critical-path transfer overhead stays ~constant ~90ms for layerwise vs 158/239/749ms baseline (up to 7.9x at 32k), prefill not slowed, KV correct. Confirms the overlap holds on busy instances. DESIGN.md updated with idle-vs-load table + the two blockers (chunk-safety, concurrent-transfer safety) that the full 1200-req trace needs. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
269 lines
11 KiB
Python
269 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""MB7: correctness + perf of layer-wise KV push vs post-hoc transfer.
|
|
|
|
Two 2-instance modes against A (src/producer) and B (dst/consumer):
|
|
|
|
baseline : prefill A (await) -> THEN B pulls (post-hoc full transfer).
|
|
T_total = T_prefill + T_xfer (sequential)
|
|
layerwise : dispatch B's remote-prefill (handshake) and A's prefill
|
|
CONCURRENTLY, so A pushes each layer as it computes it.
|
|
If overlap works, T_total ~= max(T_prefill, T_xfer) ~= T_prefill.
|
|
|
|
Reference: T_prefill_only = a plain prefill on A with no transfer.
|
|
|
|
Correctness: after the transfer, a plain follow-up to B on the same prompt
|
|
must report cached_tokens >= ~prompt_len (the KV actually landed on B).
|
|
|
|
The connector mode is selected by the launcher (run_mb7.sh): baseline uses the
|
|
stock connector; layerwise deploys mooncake_connector.LAYERWISE.py +
|
|
MOONCAKE_LAYERWISE=1. This script just drives the requests and measures.
|
|
|
|
Usage:
|
|
python mb7_layerwise.py --mode layerwise --sizes 8192,32768,65536 --repeats 3 \
|
|
--src-port 8000 --dst-port 8001 --src-bp 8998 --dst-bp 8999 --out mb7.json
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import statistics
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
|
|
# Model identifier sent as the "model" field of every completion request and
# recorded in the output JSON.
MODEL = "/home/admin/cpfs/wjh/models/Qwen/Qwen3-Coder-30B-A3B-Instruct"

# KV-cache bytes per prompt token; only used to report the transfer size in
# GiB (size * KV_PER_TOK / 2**30).
# NOTE(review): presumably derived from MODEL's layer/KV-head configuration --
# confirm whenever MODEL changes.
KV_PER_TOK = 96 * 1024
|
|
|
|
|
|
def synth_prompt(seed: int, n: int) -> list[int]:
    """Return a reproducible pseudo-random prompt of ``n`` token ids.

    Args:
        seed: Seed for a private ``random.Random`` instance, so the global
            random state is neither read nor modified.
        n: Number of token ids to generate.

    Returns:
        A list of ``n`` integers, each drawn from the inclusive range
        [100, 150000]. The same ``seed`` always yields the same sequence, and
        a shorter prompt is a prefix of a longer one built from that seed.
    """
    import random

    draw = random.Random(seed).randint
    tokens = []
    for _ in range(n):
        tokens.append(draw(100, 150000))
    return tokens
|
|
|
|
|
|
async def get_engine_id(client, host, bp):
    """Fetch the engine id advertised on an instance's bootstrap port.

    Args:
        client: Async HTTP client exposing ``get`` (an ``httpx.AsyncClient``
            in this script).
        host: Host name or IP of the instance.
        bp: Bootstrap port serving the ``/query`` endpoint.

    Returns:
        The ``engine_id`` value found under key ``"0"`` of the JSON reply.

    Raises:
        Whatever ``client.get`` / ``raise_for_status`` raise, plus
        ``KeyError`` if the reply lacks ``"0"`` or ``"engine_id"``.
    """
    url = f"http://{host}:{bp}/query"
    response = await client.get(url)
    response.raise_for_status()
    body = response.json()
    return body["0"]["engine_id"]
|
|
|
|
|
|
async def completion(client, host, port, prompt, max_tokens, ktp=None):
    """POST one non-streaming ``/v1/completions`` request and time it.

    Args:
        client: Async HTTP client exposing ``post``.
        host: Target host.
        port: Target HTTP port.
        prompt: Prompt forwarded verbatim as the ``prompt`` field (token-id
            lists in this script).
        max_tokens: Value of the ``max_tokens`` field.
        ktp: Optional dict sent as ``kv_transfer_params``; omitted from the
            request when falsy (``None`` or empty).

    Returns:
        ``(elapsed_seconds, response_json)``. The elapsed time covers only the
        POST round trip and is taken before the status check.

    Raises:
        Whatever ``client.post`` / ``raise_for_status`` raise.
    """
    url = f"http://{host}:{port}/v1/completions"
    body = dict(
        model=MODEL,
        prompt=prompt,
        max_tokens=max_tokens,
        # NOTE(review): both branches of this conditional evaluate to 1;
        # kept as written because the intended behaviour is unclear.
        min_tokens=max_tokens if max_tokens == 1 else 1,
        temperature=0.0,
        stream=False,
    )
    if ktp:
        body["kv_transfer_params"] = ktp

    started = time.perf_counter()
    response = await client.post(url, json=body, timeout=600.0)
    elapsed = time.perf_counter() - started

    response.raise_for_status()
    return elapsed, response.json()
|
|
|
|
|
|
def cached_of(resp) -> int:
    """Extract the cached-token count from a completion response dict.

    Looks first at ``usage.prompt_tokens_details.cached_tokens`` and then at
    ``usage.cached_tokens``; the first truthy value wins.

    Args:
        resp: Parsed JSON response (a dict). ``usage`` and
            ``prompt_tokens_details`` may be missing or ``None``.

    Returns:
        The cached-token count, or 0 when neither location holds a truthy
        value.
    """
    usage = resp.get("usage") or {}
    details = usage.get("prompt_tokens_details") or {}
    # Nested location takes priority; fall back to the flat one.
    for source in (details, usage):
        hit = source.get("cached_tokens", 0)
        if hit:
            return hit
    return 0
|
|
|
|
|
|
async def _stream_completion(client, host, port, prompt, max_tokens):
    """Run one streaming completion and drain the response, discarding it.

    Used only to generate load: the streamed bytes are read to the end so the
    request stays active until the server finishes, but nothing is returned.

    Args:
        client: Async HTTP client exposing ``stream``.
        host: Target host.
        port: Target HTTP port.
        prompt: Prompt forwarded as the ``prompt`` field.
        max_tokens: Value of the ``max_tokens`` field.

    Raises:
        Whatever ``client.stream`` / ``raise_for_status`` raise.
    """
    url = f"http://{host}:{port}/v1/completions"
    body = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "min_tokens": 1,
        "temperature": 0.0,
        "stream": True,
    }
    async with client.stream("POST", url, json=body, timeout=600.0) as response:
        response.raise_for_status()
        async for _chunk in response.aiter_bytes():
            continue
|
|
|
|
|
|
class BackgroundLoad:
    """Hold N concurrent long-decode streams across endpoints to keep busy.

    Each worker repeatedly issues a streaming completion against one endpoint
    (workers are assigned round-robin by index) until ``stop`` is called.
    """

    def __init__(self, client, endpoints, n, prompt_tokens=2000, out_tokens=6000):
        """Store the configuration; no requests are sent until ``start``.

        Args:
            client: Async HTTP client shared by all workers.
            endpoints: Sequence of ``(host, port)`` pairs to load.
            n: Number of concurrent worker streams.
            prompt_tokens: Prompt length of each background request.
            out_tokens: ``max_tokens`` of each background request.
        """
        self.client = client
        self.endpoints = endpoints
        self.n = n
        self.pt = prompt_tokens
        self.ot = out_tokens
        self._stop = asyncio.Event()
        self._tasks = []

    async def _w(self, idx):
        """Worker loop: stream completions back-to-back until stopped."""
        target = self.endpoints[idx % len(self.endpoints)]
        host, port = target
        first_seed = 800000 + idx
        rounds = 0
        while not self._stop.is_set():
            # A new seed per round gives every request a different prompt.
            seed = first_seed + rounds
            try:
                await _stream_completion(
                    self.client, host, port, synth_prompt(seed, self.pt), self.ot
                )
            except Exception:
                # Best effort: back off briefly and keep the load going.
                await asyncio.sleep(0.5)
            rounds += 1

    def start(self):
        """Spawn the ``n`` worker tasks (requires a running event loop)."""
        workers = []
        for idx in range(self.n):
            workers.append(asyncio.create_task(self._w(idx)))
        self._tasks = workers

    async def stop(self):
        """Signal the workers, cancel their tasks and wait for them to end."""
        self._stop.set()
        pending = self._tasks
        for worker in pending:
            worker.cancel()
        # return_exceptions=True keeps the CancelledErrors from propagating.
        await asyncio.gather(*pending, return_exceptions=True)
|
|
|
|
|
|
async def num_running(client, host, port):
    """Best-effort read of ``vllm:num_requests_running`` from ``/metrics``.

    Args:
        client: Async HTTP client exposing ``get``.
        host: Target host.
        port: Target HTTP port.

    Returns:
        The value of the first metrics line starting with
        ``vllm:num_requests_running`` truncated to an int, or -1 when the
        request fails, the line is absent, or its value cannot be parsed.
    """
    metric = "vllm:num_requests_running"
    try:
        response = await client.get(f"http://{host}:{port}/metrics", timeout=5.0)
        matching = (ln for ln in response.text.splitlines() if ln.startswith(metric))
        first = next(matching, None)
        if first is not None:
            # The sample value is the last whitespace-separated field.
            return int(float(first.split()[-1]))
    except Exception:
        # Deliberately swallowed: callers treat -1 as "unknown".
        pass
    return -1
|
|
|
|
|
|
async def prefill_only(client, host, port, prompt):
    """Reference: plain prefill cost on A, no transfer.

    Sends a single-token completion without ``kv_transfer_params`` and returns
    only its wall-clock duration in seconds.
    """
    timing = await completion(client, host, port, prompt, max_tokens=1)
    elapsed, _response = timing
    return elapsed
|
|
|
|
|
|
async def measure_baseline(client, A, B, src_eid, src_bp_addr, prompt, seed):
    """Time the sequential path: prefill on A, then B pulls the KV afterwards.

    Args:
        client: Async HTTP client passed through to ``completion``.
        A: ``(host, port)`` of the source/producer instance.
        B: ``(host, port)`` of the destination/consumer instance.
        src_eid: Engine id of A, sent to B as ``remote_engine_id``.
        src_bp_addr: Bootstrap address of A, sent as ``remote_bootstrap_addr``.
        prompt: Prompt used for every request of this measurement.
        seed: Accepted for signature parity with the caller; not used here.

    Returns:
        Dict with ``t_prefill_s`` (A's request), ``t_xfer_s`` (B's
        remote-prefill request), ``t_total_s`` (both, back to back) and
        ``cached`` (cached-token count of a plain follow-up request to B).
    """
    transfer_id = uuid.uuid4().hex
    producer_ktp = {"do_remote_decode": True, "transfer_id": transfer_id}
    consumer_ktp = {
        "do_remote_prefill": True,
        "transfer_id": transfer_id,
        "remote_engine_id": src_eid,
        "remote_bootstrap_addr": src_bp_addr,
    }

    started = time.perf_counter()
    prefill_s, _ = await completion(client, *A, prompt, 1, ktp=producer_ktp)
    xfer_s, _ = await completion(client, *B, prompt, 1, ktp=consumer_ktp)
    total_s = time.perf_counter() - started

    # correctness: B follow-up should hit cache
    _, followup = await completion(client, *B, prompt, 1)
    return {
        "t_prefill_s": prefill_s,
        "t_xfer_s": xfer_s,
        "t_total_s": total_s,
        "cached": cached_of(followup),
    }
|
|
|
|
|
|
async def measure_layerwise(client, A, B, src_eid, src_bp_addr, prompt, seed):
    """Dispatch B handshake + A prefill concurrently => layer-wise overlap.

    Args:
        client: Async HTTP client passed through to ``completion``.
        A: ``(host, port)`` of the source/producer instance.
        B: ``(host, port)`` of the destination/consumer instance.
        src_eid: Engine id of A, sent to B as ``remote_engine_id``.
        src_bp_addr: Bootstrap address of A, sent as ``remote_bootstrap_addr``.
        prompt: Prompt used for every request of this measurement.
        seed: Accepted for signature parity with the caller; not used here.

    Returns:
        Dict with ``t_A_s`` and ``t_B_s`` (per-request durations),
        ``t_total_s`` (wall time until both requests finished, which includes
        the 50 ms head start given to B) and ``cached`` (cached-token count of
        a plain follow-up request to B).
    """
    transfer_id = uuid.uuid4().hex
    started = time.perf_counter()

    async def consumer_side():
        pull_ktp = {
            "do_remote_prefill": True,
            "transfer_id": transfer_id,
            "remote_engine_id": src_eid,
            "remote_bootstrap_addr": src_bp_addr,
        }
        return await completion(client, *B, prompt, 1, ktp=pull_ktp)

    async def producer_side():
        # small head start for B's handshake to reach A before A's forward
        await asyncio.sleep(0.05)
        push_ktp = {"do_remote_decode": True, "transfer_id": transfer_id}
        return await completion(client, *A, prompt, 1, ktp=push_ktp)

    # B is scheduled first so its request is issued before A's.
    consumer = asyncio.create_task(consumer_side())
    producer = asyncio.create_task(producer_side())
    b_result, a_result = await asyncio.gather(consumer, producer)
    t_b, _ = b_result
    t_a, _ = a_result
    total_s = time.perf_counter() - started

    # Same correctness probe as the baseline: B should now hit its cache.
    _, followup = await completion(client, *B, prompt, 1)
    return {
        "t_A_s": t_a,
        "t_B_s": t_b,
        "t_total_s": total_s,
        "cached": cached_of(followup),
    }
|
|
|
|
|
|
async def main_async(a):
    """Run the MB7 sweep for one mode, print a summary and write JSON results.

    For every requested size and repeat this measures a prefill-only reference
    on A (with a differently seeded prompt, so it does not share a cache entry
    with the measured prompt), then the baseline or layer-wise transfer, and
    finally checks that B reports at least 90% of the prompt as cached.

    Args:
        a: Parsed CLI namespace. Reads ``mode``, ``sizes`` (comma-separated
            ints), ``repeats``, ``bg_load``, ``src_host``/``src_port``/
            ``src_bp``, ``dst_host``/``dst_port`` and ``out``.

    Raises:
        ValueError: If an entry of ``a.sizes`` is not an integer.
        Any HTTP/client error raised by the request helpers. Background load,
        if started, is stopped before such an error propagates.
    """
    # Tolerate stray whitespace and a trailing comma in --sizes.
    sizes = [int(s) for s in a.sizes.split(",") if s.strip()]
    A = (a.src_host, a.src_port)
    B = (a.dst_host, a.dst_port)
    measure = measure_baseline if a.mode == "baseline" else measure_layerwise
    limits = httpx.Limits(max_connections=64, max_keepalive_connections=64)
    results = []

    async with httpx.AsyncClient(limits=limits, trust_env=False) as client:
        src_eid = await get_engine_id(client, a.src_host, a.src_bp)
        src_bp_addr = f"http://{a.src_host}:{a.src_bp}"
        print(f"[mb7] mode={a.mode} bg_load={a.bg_load} src_eid={src_eid[:16]}...")

        loader = None
        if a.bg_load > 0:
            loader = BackgroundLoad(client, [A, B], a.bg_load)
            loader.start()

        # try/finally: the background streams must be torn down even when a
        # measurement request fails, not only on the success path.
        try:
            if loader:
                print(f"[mb7] ramping background load ({a.bg_load}) ...")
                for _ in range(40):
                    await asyncio.sleep(1.0)
                    na = await num_running(client, *A)
                    nb = await num_running(client, *B)
                    if na >= 1 and nb >= 1:
                        print(f"[mb7] busy: A_run={na} B_run={nb}")
                        break
                else:
                    # Never saw both instances busy: say so instead of
                    # silently measuring an (at least partly) idle system.
                    print(f"[mb7] WARNING: background load not observed on both "
                          f"instances (A_run={na} B_run={nb}); continuing anyway")

            for sz in sizes:
                for rep in range(a.repeats):
                    prompt = synth_prompt(sz * 100 + rep, sz)
                    # reference prefill-only cost (fresh prompt, different seed so no cache)
                    t_pf_only = await prefill_only(
                        client, *A, synth_prompt(sz * 100 + rep + 555, sz))
                    row = await measure(client, A, B, src_eid, src_bp_addr,
                                        prompt, sz * 100 + rep)
                    row.update({"mode": a.mode, "size": sz, "rep": rep,
                                "t_prefill_only_s": t_pf_only,
                                "kv_gib": sz * KV_PER_TOK / 2**30,
                                "correct": row["cached"] >= int(sz * 0.9)})
                    results.append(row)
                    extra = (f"xfer={row.get('t_xfer_s', 0)*1000:.0f}ms"
                             if a.mode == "baseline"
                             else f"tA={row.get('t_A_s',0)*1000:.0f}ms tB={row.get('t_B_s',0)*1000:.0f}ms")
                    print(f" sz={sz:>6} rep={rep} pf_only={t_pf_only*1000:6.0f}ms "
                          f"total={row['t_total_s']*1000:7.0f}ms {extra} "
                          f"cached={row['cached']}/{sz} correct={row['correct']}")
        finally:
            if loader:
                await loader.stop()

    # summary
    print(f"\n=== {a.mode} (bg={a.bg_load}) summary ===")
    print(f"{'size':>7} {'n':>2} {'pf_only_ms':>11} {'total_ms':>9} "
          f"{'overhead_ms':>12} {'correct':>8}")
    summary = []
    for sz in sizes:
        rs = [r for r in results if r["size"] == sz]
        if not rs:
            continue
        pf = statistics.median(r["t_prefill_only_s"] for r in rs) * 1000
        tot = statistics.median(r["t_total_s"] for r in rs) * 1000
        allok = all(r["correct"] for r in rs)
        # overhead = total - prefill_only = the part NOT hidden behind prefill
        overhead = tot - pf
        summary.append({"size": sz, "n": len(rs), "pf_only_ms": pf,
                        "total_ms": tot, "overhead_ms": overhead,
                        "all_correct": allok})
        print(f"{sz:>7} {len(rs):>2} {pf:>11.0f} {tot:>9.0f} {overhead:>12.0f} "
              f"{str(allok):>8}")

    # bg_load is recorded so idle and loaded runs can be told apart later.
    Path(a.out).write_text(json.dumps(
        {"mode": a.mode, "bg_load": a.bg_load, "model": MODEL,
         "raw": results, "summary": summary}, indent=2))
    print(f"\n[mb7] wrote {a.out}")
|
|
|
|
|
|
def main():
    """Parse the command line and run the benchmark to completion.

    ``--mode`` is required; every other flag has a default. The parsed
    namespace is handed to ``main_async`` under ``asyncio.run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["baseline", "layerwise"], required=True)

    # Endpoint hosts for A (src) and B (dst).
    for flag in ("--src-host", "--dst-host"):
        parser.add_argument(flag, default="127.0.0.1")

    # HTTP ports followed by bootstrap ports, src before dst.
    int_defaults = (
        ("--src-port", 8000),
        ("--dst-port", 8001),
        ("--src-bp", 8998),
        ("--dst-bp", 8999),
    )
    for flag, default in int_defaults:
        parser.add_argument(flag, type=int, default=default)

    parser.add_argument("--sizes", default="8192,32768,65536")
    parser.add_argument("--repeats", type=int, default=3)
    parser.add_argument(
        "--bg-load",
        type=int,
        default=0,
        help="N concurrent background decode streams across A+B",
    )
    parser.add_argument("--out", default="mb7_result.json")

    namespace = parser.parse_args()
    asyncio.run(main_async(namespace))
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|