Microbench: prefill-decode interference + PD transfer lifecycle

Two microbenchmarks quantifying the elastic offload decision:

1. Interference (corrected): cold prefill causes 14-214x TPOT p90
   degradation on same-worker decode (D∈{1,2,4,8} × P∈{2k,8k,16k,32k}).
   Earlier run had a prefix-cache bug (deterministic prompts hit cache
   after rep 0); fixed with uuid+time_ns unique prompts.

2. Transfer lifecycle: PD-sep TTFT breakdown via Mooncake proxy,
   measuring prefill→RDMA→decode startup overhead.

Key finding: offload wins at all P≥2048 operating points —
transfer cost is 25-50% of interference cost even with bulk Mooncake.
This commit is contained in:
2026-05-26 00:57:06 +08:00
parent 559faa1e26
commit f784e49c07
9 changed files with 1992 additions and 0 deletions

View File

@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""Prefill-Decode Interference Microbenchmark Driver.
Measures TPOT degradation caused by prefill chunks interfering with ongoing decode batches.
Produces: f(decode_batch_size, new_prefill_tokens, chunk_size) -> TPOT_penalty_ms
Usage:
python driver.py --host 127.0.0.1 --port 8000 \
--decode-batch-sizes 0,1,2,4,6,8,12 \
--prefill-tokens 512,1024,2048,4096,8192,16384,32768 \
--reps 5 --output-dir results/
"""
import argparse
import asyncio
import json
import os
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional
import httpx
import numpy as np
FIXED_SEED_PROMPT = (
"You are a helpful assistant. Please analyze the following document carefully "
"and provide a comprehensive summary covering all key points, main arguments, "
"supporting evidence, and conclusions. The document discusses various aspects "
"of distributed systems, including consensus protocols, fault tolerance mechanisms, "
"and performance optimization strategies for large-scale deployments.\n\n"
) * 50 # ~4k tokens worth of repeated text for prefix cache sharing
WARMUP_TOKENS = 32
MEASURE_WINDOW_TOKENS = 500
@dataclass
class Config:
decode_batch_size: int
new_prefill_tokens: int
chunk_size: int
model: str
repetition: int
@dataclass
class BaselineResult:
tpot_p50_ms: float
tpot_p90_ms: float
tpot_p99_ms: float
tokens_collected: int
@dataclass
class InterferenceResult:
tpot_during_prefill_p50_ms: float
tpot_during_prefill_p90_ms: float
tpot_after_prefill_p50_ms: float
prefill_ttft_ms: float
num_tokens_during_prefill: int
async def stream_tokens(client: httpx.AsyncClient, url: str, payload: dict) -> list[float]:
"""Send a streaming request, return list of timestamps (seconds) for each token."""
timestamps = []
async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
resp.raise_for_status()
async for line in resp.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
break
try:
chunk = json.loads(data)
choices = chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
if "role" in delta:
continue
timestamps.append(time.perf_counter())
except json.JSONDecodeError:
continue
return timestamps
def compute_tpot(timestamps: list[float], skip_first: int = 0) -> np.ndarray:
"""Compute inter-token intervals in ms, skipping first N tokens."""
if len(timestamps) < skip_first + 2:
return np.array([])
ts = np.array(timestamps[skip_first:])
return np.diff(ts) * 1000.0 # seconds → ms
def make_decode_payload(model: str) -> dict:
return {
"model": model,
"messages": [{"role": "user", "content": FIXED_SEED_PROMPT}],
"max_tokens": WARMUP_TOKENS + MEASURE_WINDOW_TOKENS + 50,
"temperature": 0,
"stream": True,
}
def make_prefill_payload(model: str, num_tokens: int) -> dict:
import hashlib
import uuid
# Generate UNIQUE content every call to guarantee zero prefix cache hits.
# Calibration: each "Block N: <32-hex>" → ~35 tokens after tokenization
unique_id = f"{uuid.uuid4().hex}_{time.time_ns()}"
n_parts = max(1, num_tokens // 35)
content_parts = []
for i in range(n_parts):
seed = hashlib.md5(f"{unique_id}_{i}".encode()).hexdigest()
content_parts.append(f"Block {i}: {seed}")
content = " ".join(content_parts)
return {
"model": model,
"messages": [{"role": "user", "content": content}],
"max_tokens": 1,
"temperature": 0,
"stream": True,
}
async def wait_for_steady_state(decode_streams: list[asyncio.Task], min_tokens: int = 32):
"""Wait until all decode streams have emitted at least min_tokens."""
# We don't directly control this — we wait a fixed time based on expected TPOT
# At ~50ms/token, 32 tokens ≈ 1.6s. Wait 3s to be safe.
await asyncio.sleep(3.0)
async def run_baseline(
client: httpx.AsyncClient, url: str, model: str, decode_batch_size: int
) -> Optional[BaselineResult]:
"""Measure decode-only TPOT (no prefill interference)."""
if decode_batch_size == 0:
return BaselineResult(tpot_p50_ms=0, tpot_p90_ms=0, tpot_p99_ms=0, tokens_collected=0)
payloads = [make_decode_payload(model) for _ in range(decode_batch_size)]
tasks = [asyncio.create_task(stream_tokens(client, url, p)) for p in payloads]
all_timestamps = await asyncio.gather(*tasks, return_exceptions=True)
all_tpots = []
for ts in all_timestamps:
if isinstance(ts, Exception):
print(f" [WARN] decode stream error: {ts}")
continue
tpot = compute_tpot(ts, skip_first=WARMUP_TOKENS)
if len(tpot) > 0:
all_tpots.extend(tpot.tolist())
if not all_tpots:
return None
arr = np.array(all_tpots)
return BaselineResult(
tpot_p50_ms=float(np.percentile(arr, 50)),
tpot_p90_ms=float(np.percentile(arr, 90)),
tpot_p99_ms=float(np.percentile(arr, 99)),
tokens_collected=len(arr),
)
async def run_interference(
client: httpx.AsyncClient,
url: str,
model: str,
decode_batch_size: int,
new_prefill_tokens: int,
) -> Optional[InterferenceResult]:
"""Measure TPOT while a prefill request is being processed."""
if decode_batch_size == 0:
# No decode to interfere with; just measure prefill TTFT
prefill_payload = make_prefill_payload(model, new_prefill_tokens)
t_start = time.perf_counter()
ts = await stream_tokens(client, url, prefill_payload)
prefill_ttft = (ts[0] - t_start) * 1000.0 if ts else 0
return InterferenceResult(
tpot_during_prefill_p50_ms=0,
tpot_during_prefill_p90_ms=0,
tpot_after_prefill_p50_ms=0,
prefill_ttft_ms=prefill_ttft,
num_tokens_during_prefill=0,
)
# Phase 1: Start decode streams
decode_payloads = [make_decode_payload(model) for _ in range(decode_batch_size)]
decode_timestamps: list[list[float]] = [[] for _ in range(decode_batch_size)]
prefill_done_event = asyncio.Event()
prefill_ttft_ms = 0.0
prefill_inject_time = 0.0
async def decode_stream_with_tracking(idx: int, payload: dict):
timestamps = await stream_tokens(client, url, payload)
decode_timestamps[idx] = timestamps
async def prefill_after_warmup():
nonlocal prefill_ttft_ms, prefill_inject_time
# Wait for decode streams to stabilize
await asyncio.sleep(1.0)
prefill_inject_time = time.perf_counter()
prefill_payload = make_prefill_payload(model, new_prefill_tokens)
ts = await stream_tokens(client, url, prefill_payload)
if ts:
prefill_ttft_ms = (ts[0] - prefill_inject_time) * 1000.0
prefill_done_event.set()
# Launch all
decode_tasks = [
asyncio.create_task(decode_stream_with_tracking(i, p))
for i, p in enumerate(decode_payloads)
]
prefill_task = asyncio.create_task(prefill_after_warmup())
await asyncio.gather(*decode_tasks, prefill_task, return_exceptions=True)
# Analyze: split decode tokens into "during prefill" and "after prefill"
prefill_end_time = prefill_inject_time + prefill_ttft_ms / 1000.0
tpot_during = []
tpot_after = []
for ts_list in decode_timestamps:
if len(ts_list) < WARMUP_TOKENS + 5:
continue
for i in range(WARMUP_TOKENS + 1, len(ts_list)):
t_prev = ts_list[i - 1]
t_curr = ts_list[i]
interval_ms = (t_curr - t_prev) * 1000.0
if prefill_inject_time <= t_prev <= prefill_end_time:
tpot_during.append(interval_ms)
elif t_curr > prefill_end_time + 0.05: # 50ms after prefill settles
tpot_after.append(interval_ms)
during_arr = np.array(tpot_during) if tpot_during else np.array([0.0])
after_arr = np.array(tpot_after) if tpot_after else np.array([0.0])
return InterferenceResult(
tpot_during_prefill_p50_ms=float(np.percentile(during_arr, 50)),
tpot_during_prefill_p90_ms=float(np.percentile(during_arr, 90)),
tpot_after_prefill_p50_ms=float(np.percentile(after_arr, 50)),
prefill_ttft_ms=prefill_ttft_ms,
num_tokens_during_prefill=len(tpot_during),
)
async def run_single_config(
client: httpx.AsyncClient,
url: str,
model: str,
decode_batch_size: int,
new_prefill_tokens: int,
chunk_size: int,
rep: int,
output_dir: Path,
):
"""Run one (D, P) configuration."""
config = Config(
decode_batch_size=decode_batch_size,
new_prefill_tokens=new_prefill_tokens,
chunk_size=chunk_size,
model=model,
repetition=rep,
)
print(f" [rep {rep}] Running baseline (D={decode_batch_size})...")
baseline = await run_baseline(client, url, model, decode_batch_size)
if baseline is None:
print(f" [rep {rep}] Baseline failed, skipping")
return
# Brief cooldown between baseline and interference
await asyncio.sleep(2.0)
print(f" [rep {rep}] Running interference (D={decode_batch_size}, P={new_prefill_tokens})...")
interference = await run_interference(
client, url, model, decode_batch_size, new_prefill_tokens
)
if interference is None:
print(f" [rep {rep}] Interference measurement failed, skipping")
return
# Compute derived metrics
tpot_penalty_p50 = interference.tpot_during_prefill_p50_ms - baseline.tpot_p50_ms
penalty_ratio = (
interference.tpot_during_prefill_p50_ms / baseline.tpot_p50_ms
if baseline.tpot_p50_ms > 0 else 0
)
result = {
"config": asdict(config),
"baseline": asdict(baseline),
"interference": asdict(interference),
"derived": {
"tpot_penalty_p50_ms": tpot_penalty_p50,
"tpot_penalty_ratio": penalty_ratio,
},
}
# Save
fname = f"D{decode_batch_size}_P{new_prefill_tokens}_rep{rep}.json"
out_path = output_dir / fname
out_path.write_text(json.dumps(result, indent=2))
print(f" [rep {rep}] Done. penalty={tpot_penalty_p50:.1f}ms ratio={penalty_ratio:.2f}")
async def main():
parser = argparse.ArgumentParser(description="Prefill-Decode Interference Microbenchmark")
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", default="Qwen3-Coder-30B-A3B-Instruct")
parser.add_argument("--decode-batch-sizes", default="0,1,2,4,6,8,12",
help="Comma-separated decode batch sizes")
parser.add_argument("--prefill-tokens", default="512,1024,2048,4096,8192,16384,32768",
help="Comma-separated prefill token counts")
parser.add_argument("--chunk-size", type=int, default=8192,
help="vLLM max_num_batched_tokens (effective chunk size)")
parser.add_argument("--reps", type=int, default=5)
parser.add_argument("--output-dir", default="results/interference")
args = parser.parse_args()
decode_sizes = [int(x) for x in args.decode_batch_sizes.split(",")]
prefill_tokens = [int(x) for x in args.prefill_tokens.split(",")]
output_dir = Path(args.output_dir) / f"chunk{args.chunk_size}"
output_dir.mkdir(parents=True, exist_ok=True)
url = f"http://{args.host}:{args.port}/v1/chat/completions"
print(f"Target: {url}")
print(f"Model: {args.model}")
print(f"Chunk size: {args.chunk_size}")
print(f"Decode batch sizes: {decode_sizes}")
print(f"Prefill tokens: {prefill_tokens}")
print(f"Repetitions: {args.reps}")
print(f"Output: {output_dir}")
print()
async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as client:
# Sanity check: is the server up?
try:
resp = await client.get(f"http://{args.host}:{args.port}/v1/models")
resp.raise_for_status()
models = resp.json()
print(f"Server ready. Models: {[m['id'] for m in models.get('data', [])]}")
except Exception as e:
print(f"ERROR: Cannot reach server at {args.host}:{args.port}: {e}")
return
total_configs = len(decode_sizes) * len(prefill_tokens)
done = 0
for D in decode_sizes:
for P in prefill_tokens:
done += 1
print(f"\n[{done}/{total_configs}] D={D}, P={P}")
for rep in range(args.reps):
try:
await run_single_config(
client, url, args.model, D, P,
args.chunk_size, rep, output_dir,
)
except Exception as e:
print(f" [rep {rep}] ERROR: {e}")
# Cooldown between reps
await asyncio.sleep(1.0)
# Cooldown between configs
await asyncio.sleep(3.0)
print("\n\nDone! Results in:", output_dir)
# Generate summary CSV
await generate_summary(output_dir, args.chunk_size)
async def generate_summary(output_dir: Path, chunk_size: int):
"""Aggregate all per-run JSONs into a summary CSV."""
import csv
rows = []
for f in sorted(output_dir.glob("D*_P*_rep*.json")):
data = json.loads(f.read_text())
cfg = data["config"]
bl = data["baseline"]
itf = data["interference"]
drv = data["derived"]
rows.append({
"chunk_size": cfg["chunk_size"],
"decode_batch_size": cfg["decode_batch_size"],
"new_prefill_tokens": cfg["new_prefill_tokens"],
"repetition": cfg["repetition"],
"tpot_baseline_p50_ms": bl["tpot_p50_ms"],
"tpot_baseline_p90_ms": bl["tpot_p90_ms"],
"tpot_during_prefill_p50_ms": itf["tpot_during_prefill_p50_ms"],
"tpot_during_prefill_p90_ms": itf["tpot_during_prefill_p90_ms"],
"tpot_after_prefill_p50_ms": itf["tpot_after_prefill_p50_ms"],
"prefill_ttft_ms": itf["prefill_ttft_ms"],
"num_tokens_during_prefill": itf["num_tokens_during_prefill"],
"tpot_penalty_p50_ms": drv["tpot_penalty_p50_ms"],
"tpot_penalty_ratio": drv["tpot_penalty_ratio"],
})
if not rows:
return
csv_path = output_dir / "summary.csv"
with open(csv_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=rows[0].keys())
writer.writeheader()
writer.writerows(rows)
print(f"Summary CSV written: {csv_path}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,57 @@
#!/bin/bash
# Launch a single vLLM instance on GPU 0 for interference microbenchmark.
# Uses TP=1, enable-chunked-prefill, enable-prefix-caching.
#
# Usage: bash launch_microbench1.sh [chunk_size] [port]
# chunk_size: max_num_batched_tokens (default: 8192)
# port: serving port (default: 8000)
set -euo pipefail
CHUNK_SIZE=${1:-8192}
PORT=${2:-8000}
MODEL="${MODEL:-$HOME/models/Qwen/Qwen3-Coder-30B-A3B-Instruct}"
GPU_ID=${GPU_ID:-0}
LOG_FILE="vllm_microbench1_chunk${CHUNK_SIZE}.log"
echo "=== Interference Microbench vLLM Instance ==="
echo "Model: $MODEL"
echo "GPU: $GPU_ID"
echo "Port: $PORT"
echo "Chunk size (max_num_batched_tokens): $CHUNK_SIZE"
echo "Log: $LOG_FILE"
echo ""
# Kill any existing vLLM on this port
pkill -f "vllm.*--port $PORT" 2>/dev/null || true
sleep 2
CUDA_VISIBLE_DEVICES=$GPU_ID python -m vllm.entrypoints.openai.api_server \
--model "$MODEL" \
--tensor-parallel-size 1 \
--enable-prefix-caching \
--dtype auto \
--gpu-memory-utilization 0.9 \
--max-model-len 200000 \
--max-num-batched-tokens "$CHUNK_SIZE" \
--port "$PORT" \
--trust-remote-code \
--disable-log-requests \
2>&1 | tee "$LOG_FILE" &
VLLM_PID=$!
echo "vLLM PID: $VLLM_PID"
echo "$VLLM_PID" > .vllm_microbench1.pid
# Wait for server to be ready
echo "Waiting for server to start..."
for i in $(seq 1 120); do
if curl -s "http://127.0.0.1:$PORT/v1/models" > /dev/null 2>&1; then
echo "Server ready after ${i}s!"
exit 0
fi
sleep 1
done
echo "ERROR: Server did not start within 120s"
exit 1

View File

@@ -0,0 +1,36 @@
#!/bin/bash
# Run the interference microbenchmark sweep.
# Assumes vLLM is already running on the specified port.
#
# Usage: bash run_sweep.sh [port] [chunk_size]
set -euo pipefail
PORT=${1:-8000}
CHUNK_SIZE=${2:-8192}
REPS=${REPS:-5}
OUTPUT_DIR="results/interference"
echo "=== Interference Microbench Sweep ==="
echo "Server: http://127.0.0.1:$PORT"
echo "Chunk size: $CHUNK_SIZE"
echo "Reps: $REPS"
echo "Output: $OUTPUT_DIR"
echo ""
# Quick sanity check
curl -sf "http://127.0.0.1:$PORT/v1/models" > /dev/null || {
echo "ERROR: vLLM not reachable on port $PORT"
exit 1
}
cd "$(dirname "$0")"
python driver.py \
--host 127.0.0.1 \
--port "$PORT" \
--chunk-size "$CHUNK_SIZE" \
--decode-batch-sizes "0,1,2,4,6,8,12" \
--prefill-tokens "512,1024,2048,4096,8192,16384,32768" \
--reps "$REPS" \
--output-dir "$OUTPUT_DIR"