Agentic workload PD separation analysis with trace-driven benchmarks

Systematic study of prefill-decode disaggregation for agentic LLM workloads
using production GLM-5.1 coder trace (2.1M requests, 71B input tokens).

Key findings:
- Cache-aware routing improves TPOT p90 by 15% and raises the APC (automatic
  prefix caching) hit rate from 20.8% to 44.7% without PD separation,
  matching PD-Sep's decode isolation benefit
- PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain
  when using the same cache-aware scheduler
- Prefill remains compute-bound even at 95% KV cache reuse (arithmetic
  intensity, AI, >1000x vs decode AI <2), but absolute FLOPs drop 71% from
  cache hits
- For agentic MoE workloads, cache-aware routing > PD separation

Infrastructure:
- Trace sampler preserving session structure + hash_ids for prefix sharing
- Async trace replayer with streaming TTFT/TPOT/E2E measurement
- Unified cache-aware + token-level load-balanced global scheduler proxy
  supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes
- vLLM 0.18.1 scheduler patch for KV transfer abort race condition
- Roofline analysis tool for prefill/decode compute characterization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 21:21:57 +08:00
commit 05592e6adc
22 changed files with 2837 additions and 0 deletions

0
replayer/__init__.py Normal file
View File

55
replayer/__main__.py Normal file
View File

@@ -0,0 +1,55 @@
"""CLI entry point: python -m replayer replay ..."""
from __future__ import annotations
import argparse
import asyncio
import logging
from pathlib import Path
from .replay import ReplayConfig, replay_trace
def main() -> None:
    """Parse command-line options, run the replay, and print a result summary.

    ``--endpoint`` may hold several comma-separated server URLs. Every entry
    is stripped of surrounding whitespace and trailing slashes (and empty
    entries are dropped) before it is handed to the replayer, because the
    request path is appended directly to each endpoint string.
    """
    p = argparse.ArgumentParser(description="Trace replayer for vLLM benchmarking")
    p.add_argument("--trace", type=Path, required=True, help="Sampled trace JSONL")
    p.add_argument("--output", type=Path, required=True, help="Output metrics JSONL")
    p.add_argument("--endpoint", type=str, required=True,
                   help="vLLM server URL (e.g. http://localhost:8000)")
    p.add_argument("--model", type=str, default="default", help="Model name for API")
    p.add_argument("--time-scale", type=float, default=1.0,
                   help="Time compression (>1 = faster)")
    p.add_argument("--max-inflight-sessions", type=int, default=32)
    p.add_argument("--concurrency-limit", type=int, default=256)
    p.add_argument("--request-timeout", type=float, default=600.0)
    p.add_argument("--request-limit", type=int, default=None,
                   help="Limit number of requests to replay")
    p.add_argument("-v", "--verbose", action="store_true")
    args = p.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Normalise each endpoint on its own: a single rstrip("/") on the whole
    # argument would only clean the last entry of a comma-separated list.
    endpoints = [part.strip().rstrip("/") for part in args.endpoint.split(",")]
    endpoints = [url for url in endpoints if url]
    if not endpoints:
        p.error("--endpoint must contain at least one non-empty URL")

    config = ReplayConfig(
        trace_path=args.trace,
        output_path=args.output,
        endpoint_url=",".join(endpoints),
        model_name=args.model,
        time_scale=args.time_scale,
        max_inflight_sessions=args.max_inflight_sessions,
        concurrency_limit=args.concurrency_limit,
        request_timeout_s=args.request_timeout,
        request_limit=args.request_limit,
    )
    results = asyncio.run(replay_trace(config))

    succeeded = sum(1 for r in results if r.error is None)
    print(f"\nDone: {succeeded}/{len(results)} requests succeeded")
    if not results:
        # An empty result list means nothing was replayed, so do not point
        # the user at metrics/summary files for this run.
        print("No requests were replayed; no metrics were written.")
        return
    print(f"Metrics: {args.output}")
    print(f"Summary: {args.output.with_suffix('.summary.json')}")


if __name__ == "__main__":
    main()

107
replayer/metrics.py Normal file
View File

@@ -0,0 +1,107 @@
"""Per-request metrics collection and summary reporting."""
from __future__ import annotations
import asyncio
import json
import statistics
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class RequestMetrics:
    """Immutable record of one replayed request and what was measured for it.

    Field order is part of the contract (positional construction and the
    layout produced by ``dataclasses.asdict``), so it must not be changed.
    """

    # --- identity of the request within the trace ---
    request_id: str
    session_id: str
    turn_id: int
    trace_timestamp_s: float

    # --- what the trace asked for ---
    input_length: int
    output_length: int
    request_type: str

    # --- what was actually sent / observed ---
    effective_input_length: int | None
    cached_tokens: int
    latency_s: float | None
    ttft_s: float | None
    tpot_s: float | None

    # --- optional details; None when not available ---
    actual_output_tokens: int | None = None
    requested_output_tokens: int | None = None
    finish_reason: str | None = None
    error: str | None = None
class IncrementalMetricSink:
    """Write each metric record to a JSONL file as soon as it is produced.

    Every record is serialised to one JSON line and flushed immediately, so
    the records written so far survive a crash of the replayer.
    """

    def __init__(self, path: Path):
        """Create (or truncate) ``path`` and keep it open for writing.

        Missing parent directories are created. Any previous content of the
        file is discarded.
        """
        self.path = path
        path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = asyncio.Lock()
        # A single open in "w" mode truncates stale content; buffering=1
        # selects line buffering for text files.
        self._fh = path.open("w", encoding="utf-8", buffering=1)

    async def append(self, metric: RequestMetrics) -> None:
        """Serialise ``metric`` (a dataclass instance) and append it as one line."""
        line = json.dumps(asdict(metric), sort_keys=True) + "\n"
        async with self._lock:
            self._fh.write(line)
            self._fh.flush()

    def close(self) -> None:
        """Close the underlying file; safe to call more than once.

        Closing is best-effort with respect to I/O errors so that a failure
        here cannot mask an exception already propagating through a caller's
        ``finally`` block. Anything other than ``OSError`` is not suppressed.
        """
        if self._fh.closed:
            return
        try:
            # close() flushes pending data itself.
            self._fh.close()
        except OSError:
            # Every record was already flushed in append(); nothing to recover.
            pass
def write_summary_json(path: Path, rows: list[RequestMetrics]) -> None:
    """Aggregate per-request metrics and write them to ``path`` as JSON.

    Only rows whose ``error`` is ``None`` contribute to the latency, TTFT,
    TPOT, token and cache statistics; ``request_count`` and ``error_count``
    cover all rows. Parent directories of ``path`` are created if missing.
    """
    ok_rows = [row for row in rows if row.error is None]

    input_tokens = sum(row.input_length for row in ok_rows)
    cached_total = sum(row.cached_tokens for row in ok_rows)
    if input_tokens > 0:
        usage_hit_ratio = cached_total / input_tokens
    else:
        usage_hit_ratio = 0.0

    output_token_counts = [
        float(row.actual_output_tokens)
        for row in ok_rows
        if row.actual_output_tokens is not None
    ]

    summary: dict[str, Any] = {
        "request_count": len(rows),
        "success_count": len(ok_rows),
        # Every row is either successful or errored, so this is the remainder.
        "error_count": len(rows) - len(ok_rows),
        "latency_stats_s": _stats(
            [row.latency_s for row in ok_rows if row.latency_s is not None]
        ),
        "ttft_stats_s": _stats(
            [row.ttft_s for row in ok_rows if row.ttft_s is not None]
        ),
        "tpot_stats_s": _stats(
            [row.tpot_s for row in ok_rows if row.tpot_s is not None]
        ),
        "cache_hit_request_count": sum(1 for row in ok_rows if row.cached_tokens > 0),
        "total_input_tokens": input_tokens,
        "total_cached_tokens": cached_total,
        "prefix_cache_hit_ratio": usage_hit_ratio,
        "cached_tokens_stats": _stats([float(row.cached_tokens) for row in ok_rows]),
        "actual_output_tokens_stats": _stats(output_token_counts),
    }

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as out:
        json.dump(summary, out, indent=2, sort_keys=True)
def _stats(values: list[float | None]) -> dict[str, float] | None:
clean = [v for v in values if v is not None]
if not clean:
return None
clean.sort()
return {
"count": float(len(clean)),
"mean": statistics.fmean(clean),
"p50": _percentile(clean, 0.50),
"p90": _percentile(clean, 0.90),
"p99": _percentile(clean, 0.99),
}
def _percentile(sorted_vals: list[float], pct: float) -> float:
if len(sorted_vals) == 1:
return sorted_vals[0]
idx = round((len(sorted_vals) - 1) * pct)
return sorted_vals[idx]

343
replayer/replay.py Normal file
View File

@@ -0,0 +1,343 @@
"""Trace replayer — send requests to vLLM following trace timing.
Requests are sent to vLLM's OpenAI-compatible /v1/completions endpoint with
streaming enabled. Uses hash_ids from the trace to construct
synthetic prompts that reproduce realistic prefix-cache hit patterns.
Key behaviors:
- Per-session sequencing: turns within a session are sent in order,
each waiting for the previous to complete before dispatching.
- Inter-session arrival: sessions start at their trace timestamps,
scaled by --time-scale.
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
--concurrency-limit caps total in-flight requests.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import random as _random
import httpx
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
from .trace import TraceRequest, load_trace
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Number of synthetic token IDs generated for each trace hash_id
# (see _hash_id_to_token_ids).
BLOCK_SIZE = 512
# NOTE(review): presumably the vocabulary size of the served model's
# tokenizer — confirm against the model actually being benchmarked.
VOCAB_SIZE = 151936
# Bounds passed to randint() when drawing synthetic token IDs; randint
# includes both ends. NOTE(review): the 100-ID margin at each end looks like
# it is meant to avoid special tokens — confirm.
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100
# Memo: hash_id -> token-ID block already generated for it. Entries are never
# evicted, so it grows with the number of distinct hash_ids seen.
_block_cache: dict[int, list[int]] = {}
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
    """Return the BLOCK_SIZE token IDs that stand in for ``hash_id``.

    The IDs come from a PRNG seeded with ``hash_id``, so the same hash_id
    always yields the same block. Results are memoised in ``_block_cache``;
    the returned list is the cached object itself and must not be mutated.
    """
    try:
        return _block_cache[hash_id]
    except KeyError:
        pass

    block_rng = _random.Random(hash_id)
    tokens: list[int] = []
    for _ in range(BLOCK_SIZE):
        tokens.append(block_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
    _block_cache[hash_id] = tokens
    return tokens
@dataclass
class ReplayConfig:
    """Settings for one trace replay run.

    Raises:
        ValueError: on construction, if ``time_scale`` is not a positive
            number or if ``max_inflight_sessions`` / ``concurrency_limit``
            is smaller than 1.
    """

    trace_path: Path  # input trace (JSONL)
    output_path: Path  # per-request metrics output (JSONL)
    endpoint_url: str  # comma-separated for round-robin: "http://host:8000,http://host:8001"
    time_scale: float = 1.0  # trace time offsets are divided by this value
    max_inflight_sessions: int = 32  # size of the per-session semaphore
    concurrency_limit: int = 256  # size of the per-request semaphore
    request_timeout_s: float = 600.0  # passed to httpx as the request timeout
    request_limit: int | None = None  # cap on requests loaded from the trace
    model_name: str = "default"  # "model" field of the request payload

    def __post_init__(self) -> None:
        """Reject values that would break the replay only after it started."""
        # Written as "not > 0" so that NaN is rejected as well. A zero
        # time_scale would otherwise raise ZeroDivisionError mid-replay.
        if not self.time_scale > 0:
            raise ValueError(f"time_scale must be > 0, got {self.time_scale!r}")
        # These sizes feed asyncio.Semaphore: 0 would block every session or
        # request forever, a negative value is rejected by Semaphore itself.
        if self.max_inflight_sessions < 1:
            raise ValueError(
                f"max_inflight_sessions must be >= 1, got {self.max_inflight_sessions!r}"
            )
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}"
            )
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
    """Assemble the synthetic prompt for ``req`` as a list of token IDs.

    Each hash_id contributes its deterministic token block, so requests that
    share a hash_id prefix also share a token prefix (which is what lets the
    server's prefix cache hit). The result is then padded or truncated to
    exactly ``req.input_length`` tokens; padding tokens come from a PRNG
    seeded with ``req.chat_id`` and are therefore reproducible.
    """
    tokens: list[int] = []
    for block_id in req.hash_ids:
        tokens += _hash_id_to_token_ids(block_id)

    filler_rng = _random.Random(req.chat_id)
    shortfall = req.input_length - len(tokens)
    if shortfall > 0:
        # Top up with deterministic filler until input_length is reached.
        tokens.extend(
            filler_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)
            for _ in range(shortfall)
        )
    # Drops any surplus when the hash blocks overshoot input_length.
    return tokens[:req.input_length]
@dataclass
class _SessionState:
session_id: str
turns: list[TraceRequest]
metrics: list[RequestMetrics] = field(default_factory=list)
# Number of endpoints handed out so far; drives the round-robin choice.
_endpoint_counter = 0


def _pick_endpoint(config: ReplayConfig) -> str:
    """Return the next endpoint from ``config.endpoint_url``, round-robin.

    ``endpoint_url`` is a comma-separated list. Each entry is stripped of
    surrounding whitespace and trailing slashes, and empty entries (such as
    the one produced by a trailing comma) are ignored, so the returned value
    can be prefixed to a path beginning with "/".

    Raises:
        ValueError: if ``endpoint_url`` contains no non-empty entry.
    """
    global _endpoint_counter
    candidates = (part.strip().rstrip("/") for part in config.endpoint_url.split(","))
    endpoints = [url for url in candidates if url]
    if not endpoints:
        raise ValueError(
            f"endpoint_url contains no usable endpoint: {config.endpoint_url!r}"
        )
    url = endpoints[_endpoint_counter % len(endpoints)]
    _endpoint_counter += 1
    return url
async def _dispatch_request(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    req: TraceRequest,
    prompt_token_ids: list[int],
    sem: asyncio.Semaphore,
) -> RequestMetrics:
    """Send one request via /v1/completions (streaming) and collect metrics.

    Args:
        client: shared HTTP client used to open the streaming POST.
        config: supplies the endpoint list, model name and request timeout.
        req: trace entry being replayed; its identifiers and lengths are
            copied into the returned metrics.
        prompt_token_ids: prompt sent as the "prompt" field of the payload.
        sem: semaphore held for the duration of the HTTP exchange.

    Returns:
        A RequestMetrics record. Any exception raised during the exchange is
        caught and stored (truncated ``repr``) in ``error`` instead of being
        propagated. ``tpot_s`` is ``None`` when fewer than two text-bearing
        chunks were received, because no inter-chunk gap exists to average.

    Note:
        The clock starts before ``sem`` is acquired, so ``latency_s`` and
        ``ttft_s`` include any time spent waiting for the semaphore.
    """
    endpoint = _pick_endpoint(config)
    payload = {
        "model": config.model_name,
        "prompt": prompt_token_ids,
        "max_tokens": max(1, req.output_length),
        "temperature": 0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    start = time.perf_counter()
    ttft_s: float | None = None
    n_output = 0
    cached_tokens = 0
    finish_reason = None
    err = None
    # Arrival time of every chunk that carried non-empty text.
    token_times: list[float] = []

    async with sem:
        try:
            async with client.stream(
                "POST",
                f"{endpoint}/v1/completions",
                json=payload,
                timeout=config.request_timeout_s,
            ) as resp:
                resp.raise_for_status()
                async for raw_line in resp.aiter_lines():
                    # Only "data:" lines of the event stream carry payloads.
                    if not raw_line or not raw_line.startswith("data:"):
                        continue
                    data = raw_line[5:].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue
                    now = time.perf_counter()
                    # TTFT is taken at the first chunk that parses as JSON.
                    if ttft_s is None:
                        ttft_s = now - start
                    choices = chunk.get("choices", [])
                    if choices:
                        delta = choices[0].get("text", "")
                        if delta:
                            token_times.append(now)
                        fr = choices[0].get("finish_reason")
                        if fr:
                            finish_reason = fr
                    usage = chunk.get("usage")
                    if usage:
                        n_output = usage.get("completion_tokens", n_output)
                        cached_tokens = _extract_cached_tokens(usage)
        except Exception as exc:
            # Deliberately broad: a failed request becomes an error row
            # rather than aborting the whole replay.
            err = repr(exc)[:300]

    end = time.perf_counter()
    e2e = end - start

    # Fall back to the number of text chunks when no usage block arrived.
    if n_output == 0 and token_times:
        n_output = len(token_times)

    # Mean gap between consecutive text chunks. Left as None (not 0.0) when
    # it cannot be measured, so that summary statistics, which skip None,
    # are not dragged down by artificial zeros.
    tpot: float | None = None
    if len(token_times) > 1:
        gaps = [later - earlier
                for earlier, later in zip(token_times, token_times[1:])]
        tpot = sum(gaps) / len(gaps)

    return RequestMetrics(
        request_id=req.request_id,
        session_id=req.session_id,
        turn_id=req.turn_id,
        trace_timestamp_s=req.timestamp_s,
        input_length=req.input_length,
        output_length=req.output_length,
        request_type=req.request_type,
        effective_input_length=len(prompt_token_ids),
        cached_tokens=cached_tokens,
        latency_s=e2e,
        ttft_s=ttft_s,
        tpot_s=tpot,
        actual_output_tokens=n_output,
        requested_output_tokens=req.output_length,
        finish_reason=finish_reason,
        error=err,
    )
def _extract_cached_tokens(usage: dict) -> int:
ct = 0
details = usage.get("prompt_tokens_details")
if isinstance(details, dict):
ct = details.get("cached_tokens", 0) or 0
if ct == 0:
ct = usage.get("cached_tokens", 0) or 0
return int(ct)
async def _run_session(
    *,
    state: _SessionState,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    session_sem: asyncio.Semaphore,
    request_sem: asyncio.Semaphore,
    earliest_ts: float,
    sweep_start: float,
    sink: IncrementalMetricSink,
) -> list[RequestMetrics]:
    """Replay all turns of one session, in order, following trace timing.

    A slot of ``session_sem`` is held for the whole session, including the
    initial wait for the session's scheduled start. Trace time differences
    are divided by ``config.time_scale``. Each turn is dispatched only after
    the previous one has completed; its metric is appended to
    ``state.metrics`` and written to ``sink``. Returns ``state.metrics``.
    """
    async with session_sem:
        first_turn = state.turns[0]

        # Scheduled start of this session, relative to the start of the sweep.
        offset = (first_turn.timestamp_s - earliest_ts) / config.time_scale
        start_delay = offset - (time.perf_counter() - sweep_start)
        if start_delay > 0:
            await asyncio.sleep(start_delay)

        for req in state.turns:
            if req != first_turn:
                # Later turns keep their trace offset from the first turn,
                # measured against the session's scheduled start.
                target = (req.timestamp_s - first_turn.timestamp_s) / config.time_scale
                elapsed = time.perf_counter() - sweep_start - offset
                if elapsed < target:
                    await asyncio.sleep(target - elapsed)

            prompt = _build_prompt_token_ids(req)
            result = await _dispatch_request(
                client=client,
                config=config,
                req=req,
                prompt_token_ids=prompt,
                sem=request_sem,
            )
            state.metrics.append(result)
            await sink.append(result)

    return state.metrics
async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
    """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).

    Args:
        url_csv: comma-separated endpoint base URLs; whitespace, trailing
            slashes and empty entries are ignored.

    Returns:
        ``{"queries": ..., "hits": ...}`` summed over every endpoint that
        could be scraped. The scrape is best-effort: an endpoint that fails
        (connection error, non-2xx status, unparsable value) contributes
        nothing and a warning is logged, so that a bad scrape is visible
        instead of silently skewing the before/after deltas.
    """
    total = {"queries": 0.0, "hits": 0.0}
    candidates = (part.strip().rstrip("/") for part in url_csv.split(","))
    endpoints = [url for url in candidates if url]
    # trust_env=False matches the client used for the replay itself, so both
    # reach the servers the same way regardless of proxy environment variables.
    async with httpx.AsyncClient(timeout=10, trust_env=False) as c:
        for url in endpoints:
            queries = 0.0
            hits = 0.0
            try:
                r = await c.get(f"{url}/metrics")
                r.raise_for_status()
                for line in r.text.splitlines():
                    # The sample value is the last whitespace-separated field.
                    if line.startswith("vllm:prefix_cache_queries_total"):
                        queries += float(line.split()[-1])
                    elif line.startswith("vllm:prefix_cache_hits_total"):
                        hits += float(line.split()[-1])
            except Exception as exc:
                # Best-effort by design, but never silent.
                logger.warning(
                    "Could not scrape prefix-cache metrics from %s: %r", url, exc
                )
                continue
            # Added only after the whole response parsed, so a failure midway
            # cannot leave a partial count in the totals.
            total["queries"] += queries
            total["hits"] += hits
    return total
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Main entry: load trace, replay against endpoint, return metrics.

    Requests are grouped by session, each session's turns are ordered by
    ``(turn_id, timestamp_s)``, and sessions are started in order of their
    first turn's timestamp. Per-request metrics are streamed to
    ``config.output_path``; a summary is written next to it with the suffix
    ``.summary.json``.

    Returns:
        All collected metrics, or an empty list (and no output files) when
        the trace yields no requests.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if not requests:
        return []

    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for r in requests:
        by_session[r.session_id].append(r)
    for turns in by_session.values():
        turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
    sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)
    earliest_ts = sessions[0][1][0].timestamp_s

    session_sem = asyncio.Semaphore(config.max_inflight_sessions)
    request_sem = asyncio.Semaphore(config.concurrency_limit)

    n_sessions = len(sessions)
    n_requests = len(requests)
    logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
                n_sessions, n_requests, config.time_scale)

    pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    # Opened directly before the try block so it is always closed.
    sink = IncrementalMetricSink(config.output_path)
    sweep_start = time.perf_counter()
    try:
        limits = httpx.Limits(
            max_connections=2000,
            max_keepalive_connections=500,
            keepalive_expiry=30.0,
        )
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            tasks = [
                asyncio.create_task(_run_session(
                    state=_SessionState(session_id=sid, turns=turns),
                    config=config, client=client,
                    session_sem=session_sem, request_sem=request_sem,
                    earliest_ts=earliest_ts, sweep_start=sweep_start,
                    sink=sink,
                ))
                for sid, turns in sessions
            ]
            all_results = await asyncio.gather(*tasks)
    finally:
        sink.close()

    sweep_elapsed = time.perf_counter() - sweep_start
    post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    flat = [m for group in all_results for m in group]
    summary_path = config.output_path.with_suffix(".summary.json")
    write_summary_json(summary_path, flat)

    # Server-side view of the prefix cache: counter deltas across the sweep.
    delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
    delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)

    logger.info("Done: %d/%d succeeded in %.1fs",
                sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)

    # Append cache stats to summary
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    summary["prefix_cache_queries_tokens"] = int(delta_queries)
    summary["prefix_cache_hits_tokens"] = int(delta_hits)
    if delta_queries > 0:
        hit_ratio = delta_hits / delta_queries
        logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)",
                    hit_ratio * 100, int(delta_hits), int(delta_queries))
        summary["prefix_cache_hit_ratio"] = hit_ratio
    else:
        # Without usable server counters, keep the ratio the summary already
        # holds instead of overwriting it with a meaningless 0.0.
        logger.warning(
            "Server prefix-cache counters unavailable or unchanged; "
            "keeping the hit ratio already present in the summary"
        )
    summary["wall_clock_s"] = sweep_elapsed
    summary_path.write_text(
        json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8"
    )
    logger.info("Summary written to %s", summary_path)
    return flat

84
replayer/trace.py Normal file
View File

@@ -0,0 +1,84 @@
"""Trace data structures and loader for the Ali agentic-coder trace format.
Trace format (one JSON per line):
chat_id, parent_chat_id, timestamp, input_length, output_length,
type, turn, hash_ids[]
Sessions are derived from parent_chat_id chains:
- parent_chat_id == -1 → new session root
- parent_chat_id >= 0 → belongs to the same session as the parent
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class TraceRequest:
    """One immutable trace entry, as produced by ``load_trace``.

    Field order is part of the contract for positional construction.
    """

    # --- identifiers ---
    request_id: str
    session_id: str
    chat_id: int
    parent_chat_id: int  # negative when the entry has no parent

    # --- timing and size ---
    timestamp_s: float
    input_length: int
    output_length: int

    # --- classification ---
    request_type: str
    turn_id: int

    # Stored as a tuple so the frozen instance stays hashable.
    hash_ids: tuple[int, ...]
def load_trace(
    path: Path,
    *,
    request_limit: int | None = None,
) -> list[TraceRequest]:
    """Load trace and resolve session IDs from parent_chat_id chains.

    Args:
        path: JSONL trace file with one JSON object per line.
        request_limit: stop after this many requests have been loaded;
            ``None`` loads the whole file.

    Returns:
        The requests in file order. A row's own ``session_id`` is used when
        present; otherwise the session is derived from ``parent_chat_id``.
        Blank lines are skipped.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a row lacks one of the required fields.
    """
    chat_to_session: dict[int, str] = {}
    requests: list[TraceRequest] = []
    with path.open("r", encoding="utf-8") as fh:
        for idx, line in enumerate(fh):
            if request_limit is not None and len(requests) >= request_limit:
                break
            # Tolerate blank lines instead of failing inside json.loads.
            # idx remains the physical line index, so request IDs built from
            # it do not shift.
            if not line.strip():
                continue
            row = json.loads(line)
            chat_id = int(row["chat_id"])
            parent_chat_id = int(row["parent_chat_id"])
            if "session_id" in row:
                session_id = str(row["session_id"])
            else:
                session_id = _resolve_session_id(
                    chat_id, parent_chat_id, chat_to_session,
                )
            # Recorded in both branches so later children can find their
            # parent's session.
            chat_to_session[chat_id] = session_id
            requests.append(TraceRequest(
                request_id=f"{session_id}:{row['turn']}:{chat_id}:{idx}",
                session_id=session_id,
                chat_id=chat_id,
                parent_chat_id=parent_chat_id,
                timestamp_s=float(row["timestamp"]),
                input_length=int(row["input_length"]),
                output_length=int(row["output_length"]),
                request_type=str(row["type"]),
                turn_id=int(row["turn"]),
                hash_ids=tuple(int(h) for h in row.get("hash_ids", [])),
            ))
    return requests
def _resolve_session_id(
chat_id: int,
parent_chat_id: int,
chat_to_session: dict[int, str],
) -> str:
if parent_chat_id < 0:
session_id = str(chat_id)
else:
session_id = chat_to_session.get(parent_chat_id, str(parent_chat_id))
chat_to_session[chat_id] = session_id
return session_id