Files
Gahow Wang 72790ae6c1 PD-sep server-side profiling: vLLM patches + per-request breakdown
Instrumentation patches (microbench/patches/):
  - pd_profile.py: shared event emitter (VLLM_PD_PROFILE_LOG env var)
  - apply_patches.py: idempotent patch installer for mooncake_connector.py
    and scheduler.py, marks insertions with # PD_PROFILE_PATCH
  - analyze_events.py: joins per-process JSONL event logs by transfer_id
    into per-request phase durations

Seven events captured per request:
  D_get_num_matched → P_zmq_received → P_prefill_done →
  P_rdma_start → P_rdma_end → D_recv_complete → D_request_promoted

Driver fix (microbench/lifecycle/driver.py):
  seed_prefix_cache now sends via the proxy URL so P and D both cache
  the seeded prefix with matching block hashes. Previously seeding D
  directly produced different block hashes than the proxy-routed
  measurement requests, making incremental transfer impossible.

Real breakdown (fig_breakdown_real.png, server_breakdown.csv, n=93):
  prefill_compute  620 ms median (95% of overhead)
  rdma_transfer     42 ms median (~71 Gbps effective)
  other overhead    10 ms median (dispatch + params + signal + promote)

Mooncake transfer is NOT the bottleneck. Even with bulk RDMA the
transfer cost is <10% of prefill cost for Qwen3-30B-A3B on H20.
2026-05-26 13:59:09 +08:00

420 lines
15 KiB
Python

#!/usr/bin/env python3
"""PD Transfer Lifecycle Breakdown Microbenchmark Driver.
Profiles the complete request lifecycle under PD disaggregation:
routing → P queue → P prefill → ZMQ handshake → RDMA transfer → D startup → D decode
Three independent variables:
- prior_context (C): tokens already cached on D from prior turns
- current_new_tokens (N): tokens P must prefill and transfer
- output_length (O): decode tokens D generates
Usage:
python driver.py --pdsep-url http://127.0.0.1:8000 \
--prior-contexts 0,4096,16384,32768,65536,100000 \
--new-tokens 512,2048,4096,8192,16384,32768 \
--output-lengths 1,32,128,512 \
--reps 5 --output-dir results/lifecycle
"""
import argparse
import asyncio
import hashlib
import json
import os
import time
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Optional
import httpx
import numpy as np
@dataclass
class LifecycleConfig:
    """Identifies one measurement point of the sweep; serialized into each result file."""

    # Tokens of shared prefix expected to be cached already (C).
    prior_context: int
    # Tokens of fresh content appended after the prefix (N).
    current_new_tokens: int
    # Decode tokens requested via max_tokens (O).
    output_length: int
    # Nominal input size; run_config fills it with prior_context + current_new_tokens.
    total_input_length: int
    # Model name sent in the request payload.
    model: str
    # Repetition index of this configuration (main passes range(reps) values).
    repetition: int
@dataclass
class LifecycleBreakdown:
    """Client-side timing summary of one streamed request (all values in ms)."""

    # Client-observable durations (ms), derived from chunk arrival times
    request_sent_to_first_token_ms: float  # TTFT (includes all server-side phases)
    first_token_to_last_token_ms: float  # Decode time
    e2e_ms: float  # Total
    # If server-side instrumentation available (from logs):
    # NOTE(review): nothing in this file populates it; it stays None here.
    server_breakdown: Optional[dict] = None
def make_context_prompt(num_tokens: int, session_id: str = "default") -> str:
    """Build a reproducible filler prompt of roughly ``num_tokens`` tokens.

    The text depends only on ``(num_tokens, session_id)``: each block is derived
    from a SHA-256 digest of the session id and the block index, so repeated
    calls return the same string (required for prefix cache hits across calls).
    Returns an empty string when ``num_tokens`` is 0.
    """
    if num_tokens == 0:
        return ""

    def render_block(index: int) -> str:
        digest = hashlib.sha256(f"{session_id}_ctx_{index}".encode()).hexdigest()
        return (
            f"[Context block {index}] The system processes request {digest[:16]} "
            f"with parameters alpha={digest[16:20]} beta={digest[20:24]} "
            f"resulting in state transition {digest[24:32]}. "
        )

    # Sizing heuristic: a block is treated as ~50 tokens, plus one spare block.
    block_count = num_tokens // 50 + 1
    return " ".join(render_block(index) for index in range(block_count))
def make_new_tokens_prompt(num_tokens: int, unique_id: str) -> str:
    """Build fresh filler text of roughly ``num_tokens`` tokens.

    Every block mixes ``unique_id``, the block index and the current
    ``time.time_ns()`` into a SHA-256 digest, so repeated calls are expected to
    produce different text (intended to avoid prefix cache hits). At least one
    block is always produced, even for ``num_tokens == 0``.
    """

    def render_block(index: int) -> str:
        digest = hashlib.sha256(
            f"{unique_id}_new_{index}_{time.time_ns()}".encode()
        ).hexdigest()
        return (
            f"[New block {index}] Analyze document {digest[:16]} considering "
            f"factors {digest[16:24]} and constraints {digest[24:32]}. "
        )

    # Sizing heuristic: a block is treated as ~50 tokens, plus one spare block.
    block_count = num_tokens // 50 + 1
    return " ".join([render_block(index) for index in range(block_count)])
async def seed_prefix_cache(
    client: httpx.AsyncClient, url: str, model: str, num_tokens: int, session_id: str
) -> bool:
    """Warm BOTH P and D prefix caches by sending the seed through the PD-sep proxy.

    Sending directly to D would only warm D's cache but produce block hashes that
    don't match what P later produces (different tokenization path). Sending through
    the proxy makes P do the prefill (P caches), pushes KV to D (D caches with
    matching hashes), so subsequent requests with the same prefix get incremental
    transfer on both sides.

    Args:
        client: Shared async HTTP client used for the POST.
        url: Chat-completions endpoint that receives the seed request.
        model: Model name placed in the request payload.
        num_tokens: Approximate prefix size; values <= 0 mean nothing to seed.
        session_id: Selects the deterministic prefix text.

    Returns:
        True if nothing needed seeding or the request succeeded; False on any
        failure (the error is printed, never raised).
    """
    # Nothing to warm. `<= 0` also avoids posting an empty prompt for negative sizes.
    if num_tokens <= 0:
        return True
    prompt = make_context_prompt(num_tokens, session_id)
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        # A single output token is enough: only the prefill side effect matters.
        "max_tokens": 1,
        "temperature": 0,
        "stream": False,
    }
    try:
        resp = await client.post(url, json=payload, timeout=120.0)
        resp.raise_for_status()
        result = resp.json()
        # `usage` may be missing or null; treat both as "count unknown" so a
        # null does not turn a successful seed into a reported failure.
        usage = result.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens", 0)
        print(f" Cache seeded: {prompt_tokens} prompt tokens processed")
        return True
    except Exception as e:
        # Deliberately broad: the caller relies on the boolean to skip configs
        # instead of aborting the whole sweep.
        print(f" Cache seed FAILED: {e}")
        return False
async def measure_lifecycle(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    prior_context: int,
    current_new_tokens: int,
    output_length: int,
    session_id: str,
) -> Optional[LifecycleBreakdown]:
    """Stream one chat completion and derive client-side timings from it.

    The prompt is the deterministic ``prior_context`` prefix (meant to hit the
    prefix cache) followed by a freshly generated ``current_new_tokens`` suffix.
    An arrival time is recorded for every SSE ``data: `` chunk that has a
    non-empty ``choices`` list and whose first delta carries no ``role`` key.

    Returns:
        A LifecycleBreakdown with TTFT, decode and end-to-end durations in ms,
        or None if the request failed or no qualifying chunk arrived.
    """
    # Shared prefix first, then a suffix that is unique per call.
    shared_prefix = make_context_prompt(prior_context, session_id)
    fresh_suffix = make_new_tokens_prompt(
        current_new_tokens, f"{session_id}_{time.time_ns()}"
    )
    if shared_prefix:
        full_prompt = shared_prefix + "\n\n" + fresh_suffix
    else:
        full_prompt = fresh_suffix

    request_body = {
        "model": model,
        "messages": [{"role": "user", "content": full_prompt}],
        "max_tokens": output_length,
        "temperature": 0,
        "stream": True,
    }

    arrivals = []
    sent_at = time.perf_counter()
    try:
        async with client.stream("POST", url, json=request_body, timeout=300.0) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    # Unparseable event payloads are ignored.
                    continue
                choices = chunk.get("choices", [])
                if not choices:
                    continue
                # A delta carrying "role" is treated as the stream preamble,
                # not as a generated token.
                if "role" in choices[0].get("delta", {}):
                    continue
                arrivals.append(time.perf_counter())
    except Exception as e:
        print(f" Request failed: {e}")
        return None

    if not arrivals:
        print(" No tokens received")
        return None

    # With a single chunk, first and last coincide and decode time is 0.
    first_at = arrivals[0]
    last_at = arrivals[-1]
    return LifecycleBreakdown(
        request_sent_to_first_token_ms=(first_at - sent_at) * 1000.0,
        first_token_to_last_token_ms=(last_at - first_at) * 1000.0,
        e2e_ms=(last_at - sent_at) * 1000.0,
    )
async def measure_colocated_baseline(
    client: httpx.AsyncClient,
    url: str,
    model: str,
    prior_context: int,
    current_new_tokens: int,
    output_length: int,
    session_id: str,
) -> Optional[LifecycleBreakdown]:
    """Measure the same workload against a combined (no PD-sep) instance.

    Pure delegation to ``measure_lifecycle``; the caller supplies the baseline
    endpoint as ``url``. Returns whatever ``measure_lifecycle`` returns.
    """
    baseline = await measure_lifecycle(
        client=client,
        url=url,
        model=model,
        prior_context=prior_context,
        current_new_tokens=current_new_tokens,
        output_length=output_length,
        session_id=session_id,
    )
    return baseline
async def run_config(
    client: httpx.AsyncClient,
    pdsep_url: str,
    colo_url: Optional[str],
    model: str,
    prior_context: int,
    current_new_tokens: int,
    output_length: int,
    rep: int,
    output_dir: Path,
    session_id: str,
) -> dict:
    """Measure one (C, N, O, rep) point and persist it as a JSON file.

    Runs the PD-sep measurement, then the colocated baseline when ``colo_url``
    is set. The record has ``config``, ``pdsep`` and ``colocated`` entries
    (failed or skipped measurements are stored as None) plus ``overhead`` when
    both measurements succeeded. It is written to
    ``output_dir / "C{C}_N{N}_O{O}_rep{rep}.json"`` and also returned.
    """
    pdsep_result = await measure_lifecycle(
        client, pdsep_url, model, prior_context, current_new_tokens, output_length, session_id
    )

    # The baseline is optional and only measured when an endpoint was given.
    colo_result = None
    if colo_url:
        colo_result = await measure_colocated_baseline(
            client, colo_url, model, prior_context, current_new_tokens, output_length, session_id
        )

    config = LifecycleConfig(
        prior_context=prior_context,
        current_new_tokens=current_new_tokens,
        output_length=output_length,
        total_input_length=prior_context + current_new_tokens,
        model=model,
        repetition=rep,
    )

    result = {"config": asdict(config), "pdsep": None, "colocated": None}
    if pdsep_result:
        result["pdsep"] = asdict(pdsep_result)
    if colo_result:
        result["colocated"] = asdict(colo_result)

    # Overhead is only defined when both sides produced a measurement.
    if pdsep_result and colo_result:
        ttft_delta = (
            pdsep_result.request_sent_to_first_token_ms
            - colo_result.request_sent_to_first_token_ms
        )
        result["overhead"] = {
            "ttft_overhead_ms": ttft_delta,
            "e2e_overhead_ms": pdsep_result.e2e_ms - colo_result.e2e_ms,
        }

    # One file per repetition, named after the configuration.
    target = output_dir / f"C{prior_context}_N{current_new_tokens}_O{output_length}_rep{rep}.json"
    target.write_text(json.dumps(result, indent=2))

    if pdsep_result:
        print(f" TTFT={pdsep_result.request_sent_to_first_token_ms:.0f}ms "
              f"decode={pdsep_result.first_token_to_last_token_ms:.0f}ms "
              f"E2E={pdsep_result.e2e_ms:.0f}ms")
    return result
async def main():
    """Parse CLI arguments, run the (C, N, O) sweep, and write the summary CSV.

    For each prior-context size C the prefix caches are seeded first. If the
    PD-sep seed fails, every configuration with that C is skipped. Each (N, O)
    combination is then measured ``--reps`` times through ``run_config``, which
    writes one JSON file per repetition into ``--output-dir``. After the sweep,
    ``generate_summary_csv`` aggregates those files.
    """
    parser = argparse.ArgumentParser(description="PD Transfer Lifecycle Breakdown Microbench")
    parser.add_argument("--pdsep-url", required=True,
                        help="PD-sep endpoint URL (proxy or D instance)")
    parser.add_argument("--colo-url", default=None,
                        help="Co-located baseline URL (optional, for overhead comparison)")
    parser.add_argument("--seed-url", default=None,
                        help="Ignored (kept for compatibility): cache seeding always "
                             "goes through --pdsep-url")
    parser.add_argument("--model", default="Qwen3-Coder-30B-A3B-Instruct")
    parser.add_argument("--prior-contexts", default="0,4096,16384,32768,65536,100000",
                        help="Comma-separated prior context sizes")
    parser.add_argument("--new-tokens", default="512,2048,4096,8192,16384,32768",
                        help="Comma-separated new token counts")
    parser.add_argument("--output-lengths", default="1,32,128,512",
                        help="Comma-separated output lengths")
    parser.add_argument("--reps", type=int, default=5)
    parser.add_argument("--output-dir", default="results/lifecycle")
    parser.add_argument("--session-id", default="bench_session_0",
                        help="Session ID for deterministic prefix generation")
    args = parser.parse_args()

    def parse_int_list(raw: str) -> list:
        # Skip empty items so a trailing comma does not abort the run.
        return [int(item) for item in raw.split(",") if item.strip()]

    prior_contexts = parse_int_list(args.prior_contexts)
    new_tokens_list = parse_int_list(args.new_tokens)
    output_lengths = parse_int_list(args.output_lengths)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    pdsep_url = args.pdsep_url.rstrip("/") + "/v1/chat/completions"
    colo_url = (args.colo_url.rstrip("/") + "/v1/chat/completions") if args.colo_url else None
    # Seeding always goes through the PD-sep endpoint so that the seed takes the
    # same path as the measurement requests (see seed_prefix_cache). Report the
    # endpoint that is really used and say so when --seed-url is being ignored.
    seed_endpoint = pdsep_url
    if args.seed_url and args.seed_url.rstrip("/") != args.pdsep_url.rstrip("/"):
        print(f"WARNING: --seed-url {args.seed_url} is ignored; seeding uses {seed_endpoint}")

    print(f"PD-sep endpoint: {pdsep_url}")
    print(f"Colo endpoint: {colo_url or 'none'}")
    print(f"Seed endpoint: {seed_endpoint}")
    print(f"Model: {args.model}")
    print(f"Prior contexts: {prior_contexts}")
    print(f"New tokens: {new_tokens_list}")
    print(f"Output lengths: {output_lengths}")
    print(f"Repetitions: {args.reps}")
    print()

    total_configs = len(prior_contexts) * len(new_tokens_list) * len(output_lengths)
    done = 0

    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as client:
        # No up-front connectivity probe (the proxy may not implement
        # /v1/models); an unreachable server shows up on the first request.
        print("Starting sweep (server connectivity will be verified on first request)...")

        for C in prior_contexts:
            print(f"\n{'='*60}")
            print(f"Prior context C={C} tokens")
            print(f"{'='*60}")

            # Seed BOTH P and D prefix caches via the proxy
            if C > 0:
                print(f" Seeding P+D prefix caches with {C} tokens via proxy...")
                success = await seed_prefix_cache(
                    client, seed_endpoint, args.model, C, args.session_id
                )
                if not success:
                    print(f" SKIP all configs with C={C} (cache seed failed)")
                    # Keep the progress counter consistent with the skipped configs.
                    done += len(new_tokens_list) * len(output_lengths)
                    continue
                # Seed the baseline too. Otherwise its first request for this C
                # prefills the prior context while PD-sep hits a warm prefix,
                # which biases the overhead figures. A failure here only
                # degrades the comparison, so it does not skip the config.
                if colo_url:
                    print(f" Seeding colo baseline prefix cache with {C} tokens...")
                    colo_seeded = await seed_prefix_cache(
                        client, colo_url, args.model, C, args.session_id
                    )
                    if not colo_seeded:
                        print(" WARNING: colo seed failed; first colo TTFT for this C "
                              "may include prefilling the prior context")
                await asyncio.sleep(2.0)

            for N in new_tokens_list:
                for O in output_lengths:
                    done += 1
                    print(f"\n [{done}/{total_configs}] C={C}, N={N}, O={O}")
                    for rep in range(args.reps):
                        try:
                            await run_config(
                                client, pdsep_url, colo_url, args.model,
                                C, N, O, rep, output_dir, args.session_id,
                            )
                        except Exception as e:
                            # One failing repetition must not abort the sweep.
                            print(f" [rep {rep}] ERROR: {e}")
                        await asyncio.sleep(1.0)
                    await asyncio.sleep(2.0)

            # Note: we do NOT evict cache between C values because each C uses
            # a deterministic prefix. Larger C is a superset of smaller C.

    print(f"\n\nDone! Results in: {output_dir}")
    generate_summary_csv(output_dir)
def generate_summary_csv(output_dir: Path):
    """Aggregate per-repetition result files into ``summary.csv``.

    Reads every ``C*_N*_O*_rep*.json`` in ``output_dir`` (in sorted file-name
    order), flattens the config / pdsep / colocated / overhead sections into one
    row per file, and writes ``output_dir / "summary.csv"`` with alphabetically
    sorted columns. Sections missing from a result leave their cells empty.

    Files that cannot be read as a result record (truncated JSON, missing keys)
    are reported and skipped, so one bad file does not discard the summary of a
    whole sweep. Nothing is written when no usable result exists.
    """
    import csv

    # CSV column -> key inside the "config" section.
    config_columns = {
        "prior_context": "prior_context",
        "new_tokens": "current_new_tokens",
        "output_length": "output_length",
        "total_input": "total_input_length",
        "repetition": "repetition",
    }
    # CSV column suffix -> key inside a timing section ("pdsep" / "colocated").
    timing_columns = {
        "ttft_ms": "request_sent_to_first_token_ms",
        "decode_ms": "first_token_to_last_token_ms",
        "e2e_ms": "e2e_ms",
    }

    rows = []
    for result_file in sorted(output_dir.glob("C*_N*_O*_rep*.json")):
        try:
            data = json.loads(result_file.read_text())
            cfg = data["config"]
            row = {column: cfg[key] for column, key in config_columns.items()}
            for prefix, section_name in (("pdsep", "pdsep"), ("colo", "colocated")):
                section = data.get(section_name)
                if section:
                    for suffix, key in timing_columns.items():
                        row[f"{prefix}_{suffix}"] = section[key]
            overhead = data.get("overhead")
            if overhead:
                row["ttft_overhead_ms"] = overhead["ttft_overhead_ms"]
                row["e2e_overhead_ms"] = overhead["e2e_overhead_ms"]
        except (OSError, ValueError, KeyError, TypeError) as e:
            # json.JSONDecodeError is a ValueError; KeyError/TypeError cover
            # records with an unexpected shape.
            print(f"Skipping unreadable result file {result_file.name}: {e!r}")
            continue
        rows.append(row)

    if not rows:
        return

    csv_path = output_dir / "summary.csv"
    # Union of all keys: rows differ depending on which sections were present.
    fieldnames = sorted({name for row in rows for name in row})
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print(f"Summary CSV written: {csv_path}")
# Script entry point: run the async sweep only when executed directly.
if __name__ == "__main__":
    asyncio.run(main())