Two microbenchmarks quantifying the elastic offload decision:
1. Interference (corrected): cold prefill causes 14-214x TPOT p90
degradation on same-worker decode (D∈{1,2,4,8} × P∈{2k,8k,16k,32k}).
An earlier run had a prefix-cache bug (deterministic prompts hit the prefix
cache after rep 0); it was fixed by making every prefill prompt unique (uuid + time_ns).
2. Transfer lifecycle: PD-sep TTFT breakdown via Mooncake proxy,
measuring prefill→RDMA→decode startup overhead.
Key finding: offload wins at every operating point with P ≥ 2048 —
the transfer cost is only 25–50% of the interference cost, even with bulk Mooncake transfer.
423 lines · 15 KiB · Python
#!/usr/bin/env python3
"""Prefill-Decode Interference Microbenchmark Driver.

Measures TPOT degradation caused by prefill chunks interfering with ongoing decode batches.

Produces: f(decode_batch_size, new_prefill_tokens, chunk_size) -> TPOT_penalty_ms

For every (decode batch size D, prefill size P) pair and every repetition the driver:
  1. runs D concurrent decode-only streams to obtain a baseline TPOT;
  2. re-runs D decode streams and, about 1 s after they start, injects one prefill
     request built from unique text of roughly P tokens, then splits the decode
     inter-token intervals into "during prefill" and "after prefill" windows.
One JSON file per run, plus summary.csv, is written to <output-dir>/chunk<chunk-size>/.

Note: --chunk-size is only recorded as a label (config field and output directory).
The driver does not configure the server; launch it with the matching
max_num_batched_tokens yourself.

Usage:
    python driver.py --host 127.0.0.1 --port 8000 \\
        --decode-batch-sizes 0,1,2,4,6,8,12 \\
        --prefill-tokens 512,1024,2048,4096,8192,16384,32768 \\
        --reps 5 --output-dir results/
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass, asdict
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
import numpy as np
|
|
|
|
|
|
# Decode-stream prompt: one filler paragraph repeated 50 times. Every decode stream
# sends this identical text so the streams share a common prefix (intended for
# prefix-cache sharing). Token count depends on the tokenizer; the original estimate
# was ~4k tokens — TODO confirm for the model in use.
FIXED_SEED_PROMPT = 50 * "".join((
    "You are a helpful assistant. Please analyze the following document carefully ",
    "and provide a comprehensive summary covering all key points, main arguments, ",
    "supporting evidence, and conclusions. The document discusses various aspects ",
    "of distributed systems, including consensus protocols, fault tolerance mechanisms, ",
    "and performance optimization strategies for large-scale deployments.\n\n",
))

# Number of leading decode tokens excluded from every TPOT statistic.
WARMUP_TOKENS = 32
# Decode tokens requested beyond the warm-up, forming the measurement window.
MEASURE_WINDOW_TOKENS = 500
|
|
|
|
|
|
@dataclass
class Config:
    """Identifies a single measured run: sweep point, chunk-size label, model, repetition."""

    # D: number of concurrent decode streams.
    decode_batch_size: int
    # P: requested size (approximate tokens) of the injected prefill prompt.
    new_prefill_tokens: int
    # Label taken from --chunk-size; this driver never sends it to the server.
    chunk_size: int
    # Model name placed in every request payload.
    model: str
    # 0-based repetition index within the (D, P) point.
    repetition: int
|
|
|
|
|
|
@dataclass
class BaselineResult:
    """Decode-only TPOT percentiles in milliseconds, pooled across all decode streams."""

    # Median inter-token interval (ms).
    tpot_p50_ms: float
    # 90th-percentile inter-token interval (ms).
    tpot_p90_ms: float
    # 99th-percentile inter-token interval (ms).
    tpot_p99_ms: float
    # Number of pooled inter-token intervals (post warm-up) the percentiles come from.
    tokens_collected: int
|
|
|
|
|
|
@dataclass
class InterferenceResult:
    """Decode TPOT observed while / after a concurrent prefill, plus that prefill's TTFT."""

    # Median decode interval (ms) for intervals that began during the prefill window.
    tpot_during_prefill_p50_ms: float
    # 90th-percentile decode interval (ms) within the prefill window.
    tpot_during_prefill_p90_ms: float
    # Median decode interval (ms) after the prefill completed.
    tpot_after_prefill_p50_ms: float
    # Time from prefill injection to its first streamed token (ms).
    prefill_ttft_ms: float
    # Count of decode intervals classified as "during prefill".
    num_tokens_during_prefill: int
|
|
|
|
|
|
async def stream_tokens(client: httpx.AsyncClient, url: str, payload: dict) -> list[float]:
    """POST *payload* as a streaming request and time-stamp each token chunk.

    Reads the server-sent-event stream line by line. A ``time.perf_counter()``
    reading (seconds) is recorded for every ``data: `` JSON chunk that has a
    non-empty ``choices`` list and whose first delta has no ``role`` key.
    Reading stops at ``data: [DONE]``. Lines that are not valid JSON are ignored;
    HTTP errors propagate from ``raise_for_status()``.

    NOTE(review): a trailing chunk that only carries a finish_reason (no content)
    would also be time-stamped — confirm against the server's stream format.
    """
    arrivals: list[float] = []
    async with client.stream("POST", url, json=payload, timeout=300.0) as response:
        response.raise_for_status()
        async for raw in response.aiter_lines():
            if not raw.startswith("data: "):
                continue
            body = raw[6:]
            if body.strip() == "[DONE]":
                break
            try:
                event = json.loads(body)
            except json.JSONDecodeError:
                continue
            first_choices = event.get("choices", [])
            if not first_choices:
                continue
            # The role-announcement chunk carries no generated token; skip it.
            if "role" in first_choices[0].get("delta", {}):
                continue
            arrivals.append(time.perf_counter())
    return arrivals
|
|
|
|
|
|
def compute_tpot(timestamps: list[float], skip_first: int = 0) -> np.ndarray:
    """Return inter-token gaps in milliseconds after dropping the first *skip_first* stamps.

    An empty array is returned when fewer than two timestamps remain.
    """
    if len(timestamps) >= skip_first + 2:
        window = np.array(timestamps[skip_first:])
        return 1000.0 * np.diff(window)  # seconds -> ms
    return np.array([])
|
|
|
|
|
|
def make_decode_payload(model: str) -> dict:
    """Build the streaming chat request shared by every decode stream.

    All decode streams send FIXED_SEED_PROMPT with temperature 0. max_tokens
    covers the warm-up plus measurement window with 50 tokens of slack.
    """
    token_budget = WARMUP_TOKENS + MEASURE_WINDOW_TOKENS + 50
    return dict(
        model=model,
        messages=[{"role": "user", "content": FIXED_SEED_PROMPT}],
        max_tokens=token_budget,
        temperature=0,
        stream=True,
    )
|
|
|
|
|
|
def make_prefill_payload(model: str, num_tokens: int) -> dict:
    """Build a single-token streaming request whose prompt is unique on every call.

    The prompt is ``max(1, num_tokens // 35)`` space-separated "Block i: <md5 hex>"
    pieces. Each hash is derived from a per-call uuid4 + time_ns nonce, so no two
    calls share content. That defeats prefix caching between requests.
    Calibration: each piece is assumed to be ~35 tokens after tokenization.
    """
    import hashlib
    import uuid

    nonce = f"{uuid.uuid4().hex}_{time.time_ns()}"
    piece_count = max(1, num_tokens // 35)
    pieces = (
        f"Block {idx}: {hashlib.md5(f'{nonce}_{idx}'.encode()).hexdigest()}"
        for idx in range(piece_count)
    )
    return {
        "model": model,
        "messages": [{"role": "user", "content": " ".join(pieces)}],
        "max_tokens": 1,
        "temperature": 0,
        "stream": True,
    }
|
|
|
|
|
|
async def wait_for_steady_state(decode_streams: list[asyncio.Task], min_tokens: int = 32):
    """Pause for a fixed 3 seconds to let decode streams warm up.

    Neither *decode_streams* nor *min_tokens* is inspected; token progress is not
    tracked. The 3 s figure assumes roughly 50 ms/token, so 32 tokens take about
    1.6 s, padded for safety.
    """
    settle_seconds = 3.0
    await asyncio.sleep(settle_seconds)
|
|
|
|
|
|
async def run_baseline(
    client: httpx.AsyncClient, url: str, model: str, decode_batch_size: int
) -> Optional[BaselineResult]:
    """Measure decode-only TPOT with *decode_batch_size* concurrent streams and no prefill.

    Post-warm-up inter-token intervals from every successful stream are pooled
    before computing percentiles. Streams that raise are reported and skipped.
    Returns an all-zero result for a batch size of 0, and None when no intervals
    were collected.
    """
    if decode_batch_size == 0:
        return BaselineResult(tpot_p50_ms=0, tpot_p90_ms=0, tpot_p99_ms=0, tokens_collected=0)

    streams = [
        asyncio.create_task(stream_tokens(client, url, make_decode_payload(model)))
        for _ in range(decode_batch_size)
    ]
    outcomes = await asyncio.gather(*streams, return_exceptions=True)

    pooled: list[float] = []
    for outcome in outcomes:
        if isinstance(outcome, Exception):
            print(f" [WARN] decode stream error: {outcome}")
            continue
        pooled.extend(compute_tpot(outcome, skip_first=WARMUP_TOKENS).tolist())

    if not pooled:
        return None

    samples = np.array(pooled)
    p50, p90, p99 = (float(np.percentile(samples, q)) for q in (50, 90, 99))
    return BaselineResult(
        tpot_p50_ms=p50,
        tpot_p90_ms=p90,
        tpot_p99_ms=p99,
        tokens_collected=len(samples),
    )
|
|
|
|
|
|
def _split_decode_intervals(
    decode_timestamps: list[list[float]],
    window_start: float,
    window_end: float,
    settle_s: float = 0.05,
) -> tuple[list[float], list[float]]:
    """Bucket post-warm-up decode inter-token intervals (ms) relative to a prefill window.

    An interval whose starting token arrived within [window_start, window_end] counts
    as "during prefill". One whose ending token arrived more than *settle_s* seconds
    after window_end counts as "after prefill". All others are dropped. Streams with
    fewer than WARMUP_TOKENS + 5 timestamps are ignored.

    Returns:
        (during_ms, after_ms) lists of interval lengths in milliseconds.
    """
    during: list[float] = []
    after: list[float] = []
    for ts_list in decode_timestamps:
        if len(ts_list) < WARMUP_TOKENS + 5:
            continue
        # Consecutive (previous, current) token pairs, starting at the first
        # post-warm-up token.
        for t_prev, t_curr in zip(ts_list[WARMUP_TOKENS:], ts_list[WARMUP_TOKENS + 1:]):
            interval_ms = (t_curr - t_prev) * 1000.0
            if window_start <= t_prev <= window_end:
                during.append(interval_ms)
            elif t_curr > window_end + settle_s:
                after.append(interval_ms)
    return during, after


async def run_interference(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    decode_batch_size: int,
    new_prefill_tokens: int,
) -> Optional[InterferenceResult]:
    """Measure decode TPOT while one cold prefill request runs concurrently.

    With decode_batch_size == 0 only the prefill's TTFT is measured. Otherwise
    decode_batch_size decode streams are started. One prefill request of
    ~new_prefill_tokens unique tokens is injected 1 s later. Decode inter-token
    intervals are then split into a "during prefill" window (injection to the
    prefill's first token) and an "after prefill" window.

    Returns:
        The InterferenceResult, or None (after printing a warning) when the
        measurement is unusable: the prefill request raised or streamed no
        token, or no decode stream produced enough tokens to analyse.
        Exceptions from the prefill in the D == 0 path propagate to the caller.
    """
    if decode_batch_size == 0:
        # No decode to interfere with; just measure prefill TTFT.
        prefill_payload = make_prefill_payload(model, new_prefill_tokens)
        t_start = time.perf_counter()
        ts = await stream_tokens(client, url, prefill_payload)
        if not ts:
            print(" [WARN] prefill stream produced no tokens")
            return None
        return InterferenceResult(
            tpot_during_prefill_p50_ms=0,
            tpot_during_prefill_p90_ms=0,
            tpot_after_prefill_p50_ms=0,
            prefill_ttft_ms=(ts[0] - t_start) * 1000.0,
            num_tokens_during_prefill=0,
        )

    decode_timestamps: list[list[float]] = [[] for _ in range(decode_batch_size)]
    prefill_ttft_ms: Optional[float] = None  # stays None if the prefill streams no token
    prefill_inject_time = 0.0

    async def decode_stream_with_tracking(idx: int, payload: dict) -> None:
        decode_timestamps[idx] = await stream_tokens(client, url, payload)

    async def prefill_after_warmup() -> None:
        nonlocal prefill_ttft_ms, prefill_inject_time
        # Build the unique prompt before the clock starts so prompt generation
        # is not counted as TTFT.
        prefill_payload = make_prefill_payload(model, new_prefill_tokens)
        # Let the decode streams get going before injecting the prefill.
        await asyncio.sleep(1.0)
        prefill_inject_time = time.perf_counter()
        ts = await stream_tokens(client, url, prefill_payload)
        if ts:
            prefill_ttft_ms = (ts[0] - prefill_inject_time) * 1000.0

    decode_tasks = [
        asyncio.create_task(decode_stream_with_tracking(i, make_decode_payload(model)))
        for i in range(decode_batch_size)
    ]
    prefill_task = asyncio.create_task(prefill_after_warmup())
    outcomes = await asyncio.gather(*decode_tasks, prefill_task, return_exceptions=True)

    # gather(return_exceptions=True) hands failures back as values. Surface them
    # instead of analysing default data (a failed prefill would otherwise look
    # like a 0 ms prefill with zero interference).
    *decode_outcomes, prefill_outcome = outcomes
    for outcome in decode_outcomes:
        if isinstance(outcome, BaseException):
            print(f" [WARN] decode stream error: {outcome}")
    if isinstance(prefill_outcome, BaseException):
        print(f" [WARN] prefill stream error: {prefill_outcome}")
        return None
    if prefill_ttft_ms is None:
        print(" [WARN] prefill stream produced no tokens")
        return None

    usable = [ts for ts in decode_timestamps if len(ts) >= WARMUP_TOKENS + 5]
    if not usable:
        print(" [WARN] no decode stream produced enough tokens to analyse")
        return None

    prefill_end_time = prefill_inject_time + prefill_ttft_ms / 1000.0
    tpot_during, tpot_after = _split_decode_intervals(
        usable, prefill_inject_time, prefill_end_time
    )

    # An empty window falls back to a single 0.0 sample; such runs are
    # identifiable by num_tokens_during_prefill == 0.
    during_arr = np.array(tpot_during) if tpot_during else np.array([0.0])
    after_arr = np.array(tpot_after) if tpot_after else np.array([0.0])

    return InterferenceResult(
        tpot_during_prefill_p50_ms=float(np.percentile(during_arr, 50)),
        tpot_during_prefill_p90_ms=float(np.percentile(during_arr, 90)),
        tpot_after_prefill_p50_ms=float(np.percentile(after_arr, 50)),
        prefill_ttft_ms=prefill_ttft_ms,
        num_tokens_during_prefill=len(tpot_during),
    )
|
|
|
|
|
|
async def run_single_config(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    decode_batch_size: int,
    new_prefill_tokens: int,
    chunk_size: int,
    rep: int,
    output_dir: Path,
):
    """Run one (D, P) repetition: baseline, 2 s cooldown, interference, then save JSON.

    The record holds the config, both measurement results, and derived metrics:
    the p50 TPOT penalty in ms and the during/baseline p50 ratio (0 when the
    baseline p50 is 0). It is written to output_dir/D{D}_P{P}_rep{rep}.json.
    Returns early without writing when either measurement returns None.
    """
    config = Config(
        decode_batch_size=decode_batch_size,
        new_prefill_tokens=new_prefill_tokens,
        chunk_size=chunk_size,
        model=model,
        repetition=rep,
    )
    tag = f" [rep {rep}]"

    print(f"{tag} Running baseline (D={decode_batch_size})...")
    baseline = await run_baseline(client, url, model, decode_batch_size)
    if baseline is None:
        print(f"{tag} Baseline failed, skipping")
        return

    # Let the server drain between the two phases.
    await asyncio.sleep(2.0)

    print(f"{tag} Running interference (D={decode_batch_size}, P={new_prefill_tokens})...")
    interference = await run_interference(
        client, url, model, decode_batch_size, new_prefill_tokens
    )
    if interference is None:
        print(f"{tag} Interference measurement failed, skipping")
        return

    base_p50 = baseline.tpot_p50_ms
    during_p50 = interference.tpot_during_prefill_p50_ms
    tpot_penalty_p50 = during_p50 - base_p50
    penalty_ratio = during_p50 / base_p50 if base_p50 > 0 else 0

    record = {
        "config": asdict(config),
        "baseline": asdict(baseline),
        "interference": asdict(interference),
        "derived": {
            "tpot_penalty_p50_ms": tpot_penalty_p50,
            "tpot_penalty_ratio": penalty_ratio,
        },
    }

    out_path = output_dir / f"D{decode_batch_size}_P{new_prefill_tokens}_rep{rep}.json"
    out_path.write_text(json.dumps(record, indent=2))
    print(f"{tag} Done. penalty={tpot_penalty_p50:.1f}ms ratio={penalty_ratio:.2f}")
|
|
|
|
|
|
async def main():
    """CLI entry point: sweep the (decode batch size, prefill size) grid and summarise.

    Parses arguments and creates <output-dir>/chunk<chunk-size>. It checks that
    the server answers /v1/models (aborting otherwise), then runs every
    repetition of every grid point with 1 s / 3 s cooldowns. Per-run errors are
    printed and the sweep continues. Finally it writes summary.csv.
    """
    parser = argparse.ArgumentParser(description="Prefill-Decode Interference Microbenchmark")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", default="Qwen3-Coder-30B-A3B-Instruct")
    parser.add_argument("--decode-batch-sizes", default="0,1,2,4,6,8,12",
                        help="Comma-separated decode batch sizes")
    parser.add_argument("--prefill-tokens", default="512,1024,2048,4096,8192,16384,32768",
                        help="Comma-separated prefill token counts")
    parser.add_argument("--chunk-size", type=int, default=8192,
                        help="vLLM max_num_batched_tokens (effective chunk size)")
    parser.add_argument("--reps", type=int, default=5)
    parser.add_argument("--output-dir", default="results/interference")
    args = parser.parse_args()

    decode_sizes = list(map(int, args.decode_batch_sizes.split(",")))
    prefill_tokens = list(map(int, args.prefill_tokens.split(",")))

    output_dir = Path(args.output_dir).joinpath(f"chunk{args.chunk_size}")
    output_dir.mkdir(parents=True, exist_ok=True)

    server_root = f"http://{args.host}:{args.port}"
    url = server_root + "/v1/chat/completions"

    banner = (
        ("Target", url),
        ("Model", args.model),
        ("Chunk size", args.chunk_size),
        ("Decode batch sizes", decode_sizes),
        ("Prefill tokens", prefill_tokens),
        ("Repetitions", args.reps),
        ("Output", output_dir),
    )
    for label, value in banner:
        print(f"{label}: {value}")
    print()

    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as client:
        # Fail fast if the server is not reachable.
        try:
            resp = await client.get(server_root + "/v1/models")
            resp.raise_for_status()
            served = [entry['id'] for entry in resp.json().get('data', [])]
            print(f"Server ready. Models: {served}")
        except Exception as e:
            print(f"ERROR: Cannot reach server at {args.host}:{args.port}: {e}")
            return

        grid = [(D, P) for D in decode_sizes for P in prefill_tokens]
        total_configs = len(grid)

        for done, (D, P) in enumerate(grid, start=1):
            print(f"\n[{done}/{total_configs}] D={D}, P={P}")

            for rep in range(args.reps):
                try:
                    await run_single_config(
                        client, url, args.model, D, P,
                        args.chunk_size, rep, output_dir,
                    )
                except Exception as e:
                    print(f" [rep {rep}] ERROR: {e}")
                await asyncio.sleep(1.0)  # cooldown between reps

            await asyncio.sleep(3.0)  # cooldown between configs

    print("\n\nDone! Results in:", output_dir)
    await generate_summary(output_dir, args.chunk_size)
|
|
|
|
|
|
async def generate_summary(output_dir: Path, chunk_size: int):
    """Flatten every D*_P*_rep*.json run file in *output_dir* into summary.csv.

    Files are read in sorted-name order. Each row's chunk size comes from that
    run's saved config; the *chunk_size* argument is not read. Nothing is
    written when no run files exist.
    """
    import csv

    def flatten(record: dict) -> dict:
        cfg, bl = record["config"], record["baseline"]
        itf, drv = record["interference"], record["derived"]
        return {
            "chunk_size": cfg["chunk_size"],
            "decode_batch_size": cfg["decode_batch_size"],
            "new_prefill_tokens": cfg["new_prefill_tokens"],
            "repetition": cfg["repetition"],
            "tpot_baseline_p50_ms": bl["tpot_p50_ms"],
            "tpot_baseline_p90_ms": bl["tpot_p90_ms"],
            "tpot_during_prefill_p50_ms": itf["tpot_during_prefill_p50_ms"],
            "tpot_during_prefill_p90_ms": itf["tpot_during_prefill_p90_ms"],
            "tpot_after_prefill_p50_ms": itf["tpot_after_prefill_p50_ms"],
            "prefill_ttft_ms": itf["prefill_ttft_ms"],
            "num_tokens_during_prefill": itf["num_tokens_during_prefill"],
            "tpot_penalty_p50_ms": drv["tpot_penalty_p50_ms"],
            "tpot_penalty_ratio": drv["tpot_penalty_ratio"],
        }

    run_files = sorted(output_dir.glob("D*_P*_rep*.json"))
    rows = [flatten(json.loads(path.read_text())) for path in run_files]
    if not rows:
        return

    csv_path = output_dir / "summary.csv"
    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
    print(f"Summary CSV written: {csv_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: run the async benchmark driver to completion on a fresh event loop.
    asyncio.run(main())
|