agentic-kvc/microbench/fresh_setup/mb1_driver.py
Gahow Wang 029821c1b6 MB1: prefill-decode interference under chunked-prefill default; §3.2 headline
Single-GPU bench on dash1 GPU 0 (vanilla vLLM 0.18.1, chunked-prefill on,
no kv_connector). 3 decode batch sizes × 5 prefill sizes × 3 reps.

Method recap (driver: microbench/interference/driver.py, repurposed):
- Pin D streaming decode requests at constant max_tokens
- Inject one prefill-only request (max_tokens=1) of varying input length
- Bin decode-stream token timestamps into "during prefill" vs baseline
- Headline metric: effective per-stream TPOT during the prefill burst,
  = prefill_ttft / (num_tokens_during_prefill / D). This is the average
  rate at which each decode stream produces tokens during the burst.
  p50 of inter-token intervals is deceptive (chunked-prefill makes most
  intervals look normal); the burst-average gives the true cost.
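
A minimal sketch of the headline metric described above, assuming the three inputs are the per-run `prefill_ttft_ms`, `num_tokens_during_prefill`, and decode batch size D recorded by the driver. The helper name `effective_tpot_ms` is illustrative and is not taken from analyze_mb1.py; the example inputs are round made-up numbers, not measurements.

```python
def effective_tpot_ms(prefill_ttft_ms: float,
                      num_tokens_during_prefill: int,
                      decode_batch_size: int) -> float:
    """Burst-average per-stream TPOT: prefill_ttft / (tokens_during / D)."""
    if decode_batch_size <= 0 or num_tokens_during_prefill <= 0:
        # No decode tokens were observed during the burst: the streams
        # made no measurable progress, so report an unbounded TPOT.
        return float("inf")
    tokens_per_stream = num_tokens_during_prefill / decode_batch_size
    return prefill_ttft_ms / tokens_per_stream


# Illustrative inputs only: 8 streams emit 40 tokens in total during a
# 100 ms prefill, i.e. 5 tokens per stream, so 20 ms per token.
print(effective_tpot_ms(100.0, 40, 8))
```

Unlike a percentile of inter-token intervals, this average charges every stream for the full burst duration, so a few very long stalls cannot hide behind many normal-looking intervals.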

Results (D=8 row, the most agentic-realistic case):
  P (tokens) | prefill_ttft | per-stream TPOT during prefill | penalty
        2048 |       143 ms |                          32 ms |      4×
        8192 |       583 ms |                         114 ms |     15×
       32768 |      4520 ms |                         388 ms |     52×
       65536 |     15615 ms |                         757 ms |     99×
      131072 |     56991 ms |                        1419 ms |    183×

Baseline TPOT at D=8: ~7.7 ms. So during a 131k-token prefill burst
each ongoing decode is running ~183× slower (i.e. essentially halted)
for ~57 seconds.
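
As a rough cross-check, the penalty column can be recomputed from the TPOT values in the D=8 table and the quoted ~7.7 ms baseline. This is a sketch, not the aggregation in analyze_mb1.py: `penalty_ratio` is an illustrative helper, and because the baseline is only quoted approximately, the recomputed ratios can differ from the table by a few percent.

```python
def penalty_ratio(tpot_during_ms: float, baseline_tpot_ms: float) -> float:
    """Slowdown of a decode stream relative to its decode-only baseline."""
    if baseline_tpot_ms <= 0:
        raise ValueError("baseline TPOT must be positive")
    return tpot_during_ms / baseline_tpot_ms


BASELINE_TPOT_MS = 7.7  # approximate D=8 baseline quoted above

# (prefill tokens, per-stream TPOT during prefill in ms) from the D=8 table
D8_ROWS = [(2048, 32), (8192, 114), (32768, 388), (65536, 757), (131072, 1419)]

for prefill_tokens, tpot_ms in D8_ROWS:
    ratio = penalty_ratio(tpot_ms, BASELINE_TPOT_MS)
    print(f"P={prefill_tokens:>6}: ~{ratio:.0f}x slower than baseline")
```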

§3.2 implication: PD-disagg's promised phase-isolation benefit per
agentic request is bounded by the decode duration, which is 50–200 ms
for tool-call output. MB2 says the KV-transfer cost of PD-disagg
is 300 ms – 10 s for agentic-size requests. Cost > benefit for every
KV size above ~80 MiB (well below trace mean 192 MiB).
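
The cost-versus-benefit argument can be sketched as simple arithmetic over the two ranges quoted above (50–200 ms decode duration, 300 ms – 10 s KV transfer). The function name `net_benefit_ms` is hypothetical, and the model assumes the best case for disaggregation, namely that the entire decode would otherwise have been stalled.

```python
def net_benefit_ms(decode_duration_ms: float, kv_transfer_ms: float) -> float:
    """Upper bound on the net gain from PD-disagg for one request.

    The benefit cannot exceed the decode duration (the most interference
    a request can be spared); the cost is the KV-transfer time.
    """
    return decode_duration_ms - kv_transfer_ms


# Most favourable corner: longest decode (200 ms), cheapest transfer (300 ms).
best_case = net_benefit_ms(200.0, 300.0)
# Least favourable corner: shortest decode (50 ms), slowest transfer (10 s).
worst_case = net_benefit_ms(50.0, 10_000.0)
print(best_case, worst_case)
```

Both corners are negative, which is the sense in which the transfer cost exceeds the benefit ceiling over the quoted ranges.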

The new figs/pd_cost_vs_benefit.png overlays MB1 benefit ceiling
(50–200 ms band, capped by decode) onto MB2 transfer cost curve and
marks the agentic-distribution waypoints (trace mean, p90, p95, p99)
on the x-axis. Across the entire agentic distribution, the cost curve
sits above the benefit band.

Adds:
- microbench/fresh_setup/mb1_launch.sh: single-GPU vLLM launcher (no
  kv_connector, default chunked_prefill=on, max_num_batched_tokens=8192)
- microbench/fresh_setup/mb1_driver.py: copy of the existing
  microbench/interference/driver.py for cpfs deployment
- microbench/fresh_setup/analyze_mb1.py: aggregator emitting
  per-(D, P) effective-TPOT-during + max PD-disagg-benefit table
- microbench/fresh_setup/plot_mb1.py: mb1 standalone +
  pd_cost_vs_benefit headline figure
- analysis/mb1/summary.csv: 45 raw rows from the sweep
- analysis/mb1/breakdown.json: per-(D, P) aggregate
- analysis/mb1/README.md: persistent doc
- figs/mb1_interference.png: effective TPOT during prefill, one line per D
- figs/pd_cost_vs_benefit.png: §3.2 headline (cost > benefit everywhere)

Caveats noted in README:
- chunk_tokens=8192 only; Sarathi-Serve's smaller chunks would
  interleave decode more aggressively. Chunk-size sensitivity is
  flagged as next run.
- D ≤ 8; higher D may saturate or shrink the penalty further.
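
To make the chunk-size caveat concrete, the sketch below counts roughly how many chunks a prefill is split into for a given token budget. It is a simplification: it ignores the decode tokens that share the per-step budget, and the nominal prefill size may differ from the tokenized prompt length. The 2048-token chunk is a hypothetical comparison point, not a configuration that was run.

```python
import math


def prefill_chunks(prefill_tokens: int, chunk_tokens: int) -> int:
    """Approximate number of chunks needed to prefill a prompt."""
    if prefill_tokens <= 0 or chunk_tokens <= 0:
        raise ValueError("token counts must be positive")
    return math.ceil(prefill_tokens / chunk_tokens)


# Configuration used in this run: chunk_tokens=8192.
print(prefill_chunks(131072, 8192))
# Hypothetical smaller chunk for comparison.
print(prefill_chunks(131072, 2048))
```

More, smaller chunks give the scheduler more points at which decode work can be interleaved, which is why chunk-size sensitivity is worth a follow-up run.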

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-27 21:25:09 +08:00

#!/usr/bin/env python3
"""Prefill-Decode Interference Microbenchmark Driver.

Measures TPOT degradation caused by prefill chunks interfering with ongoing decode batches.
Produces: f(decode_batch_size, new_prefill_tokens, chunk_size) -> TPOT_penalty_ms

Usage:
    python driver.py --host 127.0.0.1 --port 8000 \
        --decode-batch-sizes 0,1,2,4,6,8,12 \
        --prefill-tokens 512,1024,2048,4096,8192,16384,32768 \
        --reps 5 --output-dir results/
"""
import argparse
import asyncio
import json
import os
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional

import httpx
import numpy as np

FIXED_SEED_PROMPT = (
    "You are a helpful assistant. Please analyze the following document carefully "
    "and provide a comprehensive summary covering all key points, main arguments, "
    "supporting evidence, and conclusions. The document discusses various aspects "
    "of distributed systems, including consensus protocols, fault tolerance mechanisms, "
    "and performance optimization strategies for large-scale deployments.\n\n"
) * 50  # ~4k tokens worth of repeated text for prefix cache sharing

WARMUP_TOKENS = 32
MEASURE_WINDOW_TOKENS = 500


@dataclass
class Config:
    decode_batch_size: int
    new_prefill_tokens: int
    chunk_size: int
    model: str
    repetition: int


@dataclass
class BaselineResult:
    tpot_p50_ms: float
    tpot_p90_ms: float
    tpot_p99_ms: float
    tokens_collected: int


@dataclass
class InterferenceResult:
    tpot_during_prefill_p50_ms: float
    tpot_during_prefill_p90_ms: float
    tpot_after_prefill_p50_ms: float
    prefill_ttft_ms: float
    num_tokens_during_prefill: int


async def stream_tokens(client: httpx.AsyncClient, url: str, payload: dict) -> list[float]:
    """Send a streaming request, return list of timestamps (seconds) for each token."""
    timestamps = []
    async with client.stream("POST", url, json=payload, timeout=300.0) as resp:
        resp.raise_for_status()
        async for line in resp.aiter_lines():
            if line.startswith("data: "):
                data = line[6:]
                if data.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                    choices = chunk.get("choices", [])
                    if not choices:
                        continue
                    delta = choices[0].get("delta", {})
                    # Skip the role-announcement chunk; it carries no token.
                    if "role" in delta:
                        continue
                    timestamps.append(time.perf_counter())
                except json.JSONDecodeError:
                    continue
    return timestamps


def compute_tpot(timestamps: list[float], skip_first: int = 0) -> np.ndarray:
    """Compute inter-token intervals in ms, skipping first N tokens."""
    if len(timestamps) < skip_first + 2:
        return np.array([])
    ts = np.array(timestamps[skip_first:])
    return np.diff(ts) * 1000.0  # seconds → ms


def make_decode_payload(model: str) -> dict:
    return {
        "model": model,
        "messages": [{"role": "user", "content": FIXED_SEED_PROMPT}],
        "max_tokens": WARMUP_TOKENS + MEASURE_WINDOW_TOKENS + 50,
        "temperature": 0,
        "stream": True,
    }


def make_prefill_payload(model: str, num_tokens: int) -> dict:
    import hashlib
    import uuid

    # Generate UNIQUE content every call to guarantee zero prefix cache hits.
    # Calibration: each "Block N: <32-hex>" → ~35 tokens after tokenization
    unique_id = f"{uuid.uuid4().hex}_{time.time_ns()}"
    n_parts = max(1, num_tokens // 35)
    content_parts = []
    for i in range(n_parts):
        seed = hashlib.md5(f"{unique_id}_{i}".encode()).hexdigest()
        content_parts.append(f"Block {i}: {seed}")
    content = " ".join(content_parts)
    return {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 1,
        "temperature": 0,
        "stream": True,
    }


async def wait_for_steady_state(decode_streams: list[asyncio.Task], min_tokens: int = 32):
    """Sleep long enough for all decode streams to have emitted about min_tokens.

    Token counts are not inspected; both arguments are currently unused.
    """
    # We don't directly control this, so we wait a fixed time based on expected TPOT.
    # At ~50ms/token, 32 tokens ≈ 1.6s. Wait 3s to be safe.
    await asyncio.sleep(3.0)


async def run_baseline(
    client: httpx.AsyncClient, url: str, model: str, decode_batch_size: int
) -> Optional[BaselineResult]:
    """Measure decode-only TPOT (no prefill interference)."""
    if decode_batch_size == 0:
        return BaselineResult(tpot_p50_ms=0, tpot_p90_ms=0, tpot_p99_ms=0, tokens_collected=0)

    payloads = [make_decode_payload(model) for _ in range(decode_batch_size)]
    tasks = [asyncio.create_task(stream_tokens(client, url, p)) for p in payloads]
    all_timestamps = await asyncio.gather(*tasks, return_exceptions=True)

    all_tpots = []
    for ts in all_timestamps:
        if isinstance(ts, Exception):
            print(f" [WARN] decode stream error: {ts}")
            continue
        tpot = compute_tpot(ts, skip_first=WARMUP_TOKENS)
        if len(tpot) > 0:
            all_tpots.extend(tpot.tolist())

    if not all_tpots:
        return None

    arr = np.array(all_tpots)
    return BaselineResult(
        tpot_p50_ms=float(np.percentile(arr, 50)),
        tpot_p90_ms=float(np.percentile(arr, 90)),
        tpot_p99_ms=float(np.percentile(arr, 99)),
        tokens_collected=len(arr),
    )


async def run_interference(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    decode_batch_size: int,
    new_prefill_tokens: int,
) -> Optional[InterferenceResult]:
    """Measure TPOT while a prefill request is being processed."""
    if decode_batch_size == 0:
        # No decode to interfere with; just measure prefill TTFT
        prefill_payload = make_prefill_payload(model, new_prefill_tokens)
        t_start = time.perf_counter()
        ts = await stream_tokens(client, url, prefill_payload)
        prefill_ttft = (ts[0] - t_start) * 1000.0 if ts else 0
        return InterferenceResult(
            tpot_during_prefill_p50_ms=0,
            tpot_during_prefill_p90_ms=0,
            tpot_after_prefill_p50_ms=0,
            prefill_ttft_ms=prefill_ttft,
            num_tokens_during_prefill=0,
        )

    # Phase 1: Start decode streams
    decode_payloads = [make_decode_payload(model) for _ in range(decode_batch_size)]
    decode_timestamps: list[list[float]] = [[] for _ in range(decode_batch_size)]
    prefill_done_event = asyncio.Event()
    prefill_ttft_ms = 0.0
    prefill_inject_time = 0.0

    async def decode_stream_with_tracking(idx: int, payload: dict):
        timestamps = await stream_tokens(client, url, payload)
        decode_timestamps[idx] = timestamps

    async def prefill_after_warmup():
        nonlocal prefill_ttft_ms, prefill_inject_time
        # Wait for decode streams to stabilize
        await asyncio.sleep(1.0)
        prefill_inject_time = time.perf_counter()
        prefill_payload = make_prefill_payload(model, new_prefill_tokens)
        ts = await stream_tokens(client, url, prefill_payload)
        if ts:
            prefill_ttft_ms = (ts[0] - prefill_inject_time) * 1000.0
        prefill_done_event.set()

    # Launch all
    decode_tasks = [
        asyncio.create_task(decode_stream_with_tracking(i, p))
        for i, p in enumerate(decode_payloads)
    ]
    prefill_task = asyncio.create_task(prefill_after_warmup())
    await asyncio.gather(*decode_tasks, prefill_task, return_exceptions=True)

    # Analyze: split decode tokens into "during prefill" and "after prefill"
    prefill_end_time = prefill_inject_time + prefill_ttft_ms / 1000.0
    tpot_during = []
    tpot_after = []
    for ts_list in decode_timestamps:
        if len(ts_list) < WARMUP_TOKENS + 5:
            continue
        for i in range(WARMUP_TOKENS + 1, len(ts_list)):
            t_prev = ts_list[i - 1]
            t_curr = ts_list[i]
            interval_ms = (t_curr - t_prev) * 1000.0
            if prefill_inject_time <= t_prev <= prefill_end_time:
                tpot_during.append(interval_ms)
            elif t_curr > prefill_end_time + 0.05:  # 50ms after prefill settles
                tpot_after.append(interval_ms)

    during_arr = np.array(tpot_during) if tpot_during else np.array([0.0])
    after_arr = np.array(tpot_after) if tpot_after else np.array([0.0])
    return InterferenceResult(
        tpot_during_prefill_p50_ms=float(np.percentile(during_arr, 50)),
        tpot_during_prefill_p90_ms=float(np.percentile(during_arr, 90)),
        tpot_after_prefill_p50_ms=float(np.percentile(after_arr, 50)),
        prefill_ttft_ms=prefill_ttft_ms,
        num_tokens_during_prefill=len(tpot_during),
    )


async def run_single_config(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    decode_batch_size: int,
    new_prefill_tokens: int,
    chunk_size: int,
    rep: int,
    output_dir: Path,
):
    """Run one (D, P) configuration."""
    config = Config(
        decode_batch_size=decode_batch_size,
        new_prefill_tokens=new_prefill_tokens,
        chunk_size=chunk_size,
        model=model,
        repetition=rep,
    )

    print(f" [rep {rep}] Running baseline (D={decode_batch_size})...")
    baseline = await run_baseline(client, url, model, decode_batch_size)
    if baseline is None:
        print(f" [rep {rep}] Baseline failed, skipping")
        return

    # Brief cooldown between baseline and interference
    await asyncio.sleep(2.0)

    print(f" [rep {rep}] Running interference (D={decode_batch_size}, P={new_prefill_tokens})...")
    interference = await run_interference(
        client, url, model, decode_batch_size, new_prefill_tokens
    )
    if interference is None:
        print(f" [rep {rep}] Interference measurement failed, skipping")
        return

    # Compute derived metrics
    tpot_penalty_p50 = interference.tpot_during_prefill_p50_ms - baseline.tpot_p50_ms
    penalty_ratio = (
        interference.tpot_during_prefill_p50_ms / baseline.tpot_p50_ms
        if baseline.tpot_p50_ms > 0 else 0
    )

    result = {
        "config": asdict(config),
        "baseline": asdict(baseline),
        "interference": asdict(interference),
        "derived": {
            "tpot_penalty_p50_ms": tpot_penalty_p50,
            "tpot_penalty_ratio": penalty_ratio,
        },
    }

    # Save
    fname = f"D{decode_batch_size}_P{new_prefill_tokens}_rep{rep}.json"
    out_path = output_dir / fname
    out_path.write_text(json.dumps(result, indent=2))
    print(f" [rep {rep}] Done. penalty={tpot_penalty_p50:.1f}ms ratio={penalty_ratio:.2f}")


async def main():
    parser = argparse.ArgumentParser(description="Prefill-Decode Interference Microbenchmark")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", default="Qwen3-Coder-30B-A3B-Instruct")
    parser.add_argument("--decode-batch-sizes", default="0,1,2,4,6,8,12",
                        help="Comma-separated decode batch sizes")
    parser.add_argument("--prefill-tokens", default="512,1024,2048,4096,8192,16384,32768",
                        help="Comma-separated prefill token counts")
    parser.add_argument("--chunk-size", type=int, default=8192,
                        help="vLLM max_num_batched_tokens (effective chunk size)")
    parser.add_argument("--reps", type=int, default=5)
    parser.add_argument("--output-dir", default="results/interference")
    args = parser.parse_args()

    decode_sizes = [int(x) for x in args.decode_batch_sizes.split(",")]
    prefill_tokens = [int(x) for x in args.prefill_tokens.split(",")]
    output_dir = Path(args.output_dir) / f"chunk{args.chunk_size}"
    output_dir.mkdir(parents=True, exist_ok=True)
    url = f"http://{args.host}:{args.port}/v1/chat/completions"

    print(f"Target: {url}")
    print(f"Model: {args.model}")
    print(f"Chunk size: {args.chunk_size}")
    print(f"Decode batch sizes: {decode_sizes}")
    print(f"Prefill tokens: {prefill_tokens}")
    print(f"Repetitions: {args.reps}")
    print(f"Output: {output_dir}")
    print()

    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as client:
        # Sanity check: is the server up?
        try:
            resp = await client.get(f"http://{args.host}:{args.port}/v1/models")
            resp.raise_for_status()
            models = resp.json()
            print(f"Server ready. Models: {[m['id'] for m in models.get('data', [])]}")
        except Exception as e:
            print(f"ERROR: Cannot reach server at {args.host}:{args.port}: {e}")
            return

        total_configs = len(decode_sizes) * len(prefill_tokens)
        done = 0
        for D in decode_sizes:
            for P in prefill_tokens:
                done += 1
                print(f"\n[{done}/{total_configs}] D={D}, P={P}")
                for rep in range(args.reps):
                    try:
                        await run_single_config(
                            client, url, args.model, D, P,
                            args.chunk_size, rep, output_dir,
                        )
                    except Exception as e:
                        print(f" [rep {rep}] ERROR: {e}")
                    # Cooldown between reps
                    await asyncio.sleep(1.0)
                # Cooldown between configs
                await asyncio.sleep(3.0)

    print("\n\nDone! Results in:", output_dir)
    # Generate summary CSV
    await generate_summary(output_dir, args.chunk_size)


async def generate_summary(output_dir: Path, chunk_size: int):
    """Aggregate all per-run JSONs into a summary CSV."""
    import csv

    rows = []
    for f in sorted(output_dir.glob("D*_P*_rep*.json")):
        data = json.loads(f.read_text())
        cfg = data["config"]
        bl = data["baseline"]
        itf = data["interference"]
        drv = data["derived"]
        rows.append({
            "chunk_size": cfg["chunk_size"],
            "decode_batch_size": cfg["decode_batch_size"],
            "new_prefill_tokens": cfg["new_prefill_tokens"],
            "repetition": cfg["repetition"],
            "tpot_baseline_p50_ms": bl["tpot_p50_ms"],
            "tpot_baseline_p90_ms": bl["tpot_p90_ms"],
            "tpot_during_prefill_p50_ms": itf["tpot_during_prefill_p50_ms"],
            "tpot_during_prefill_p90_ms": itf["tpot_during_prefill_p90_ms"],
            "tpot_after_prefill_p50_ms": itf["tpot_after_prefill_p50_ms"],
            "prefill_ttft_ms": itf["prefill_ttft_ms"],
            "num_tokens_during_prefill": itf["num_tokens_during_prefill"],
            "tpot_penalty_p50_ms": drv["tpot_penalty_p50_ms"],
            "tpot_penalty_ratio": drv["tpot_penalty_ratio"],
        })

    if not rows:
        return

    csv_path = output_dir / "summary.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
    print(f"Summary CSV written: {csv_path}")


if __name__ == "__main__":
    asyncio.run(main())