Systematic study of prefill-decode (PD) disaggregation for agentic LLM workloads, using a production GLM-5.1 coder trace (2.1M requests, 71B input tokens).

Key findings:
- Cache-aware routing improves TPOT p90 by 15% and raises the automatic prefix caching (APC) hit rate from 20.8% to 44.7% without PD separation, matching PD-Sep's decode-isolation benefit.
- PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler.
- Prefill remains compute-bound even at 95% KV cache reuse (arithmetic intensity (AI) >1000x vs. decode AI <2), but absolute FLOPs drop 71% thanks to cache hits.
- For agentic MoE workloads, cache-aware routing > PD separation.

Infrastructure:
- Trace sampler preserving session structure + hash_ids for prefix sharing
- Async trace replayer with streaming TTFT/TPOT/E2E measurement
- Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes
- vLLM 0.18.1 scheduler patch for a KV-transfer abort race condition
- Roofline analysis tool for prefill/decode compute characterization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
344 lines
12 KiB
Python
344 lines
12 KiB
Python
"""Trace replayer — send requests to vLLM following trace timing.
|
|
|
|
Supports both vLLM's /v1/completions (OpenAI-compatible) and /generate
|
|
(SGLang-style) endpoints. Uses hash_ids from the trace to construct
|
|
synthetic prompts that reproduce realistic prefix-cache hit patterns.
|
|
|
|
Key behaviors:
|
|
- Per-session sequencing: turns within a session are sent in order,
|
|
each waiting for the previous to complete before dispatching.
|
|
- Inter-session arrival: sessions start at their trace timestamps,
|
|
scaled by --time-scale.
|
|
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
|
|
--concurrency-limit caps total in-flight requests.
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import json
import logging
import random as _random
import time
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import httpx

from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
from .trace import TraceRequest, load_trace
|
|
|
|
logger = logging.getLogger(__name__)

# Number of synthetic token IDs generated per trace hash_id.
BLOCK_SIZE = 512
# NOTE(review): synthetic IDs are drawn from [TOKEN_RANGE_START, TOKEN_RANGE_END],
# which assumes the served model's vocabulary has at least VOCAB_SIZE entries.
# IDs outside the model's vocabulary may be rejected — confirm against the
# model config before a run.
VOCAB_SIZE = 151936
# Keep a 100-ID margin at both ends of the vocabulary.
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100

# Upper bound on memoised blocks. Each entry holds BLOCK_SIZE Python ints, so an
# unbounded cache would grow with the number of distinct hash_ids in the trace.
_BLOCK_CACHE_MAX_ENTRIES = 8192

# LRU memo: hash_id -> token IDs, most recently used entry last.
_block_cache: OrderedDict[int, list[int]] = OrderedDict()


def _hash_id_to_token_ids(hash_id: int) -> list[int]:
    """Deterministically map a hash_id to BLOCK_SIZE token IDs.

    The IDs come from ``random.Random(hash_id)``, so the same hash_id always
    yields the same list regardless of cache state. Results are memoised in an
    LRU capped at ``_BLOCK_CACHE_MAX_ENTRIES`` entries. The returned list is
    the cached object itself, so callers must copy it rather than mutate it.
    """
    cached = _block_cache.get(hash_id)
    if cached is not None:
        _block_cache.move_to_end(hash_id)
        return cached

    rng = _random.Random(hash_id)
    ids = [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)]
    _block_cache[hash_id] = ids
    # Evict least-recently-used blocks; they are regenerated identically if needed.
    while len(_block_cache) > _BLOCK_CACHE_MAX_ENTRIES:
        _block_cache.popitem(last=False)
    return ids
|
|
|
|
|
|
@dataclass
class ReplayConfig:
    """Settings for one trace-replay run.

    ``trace_path`` and ``output_path`` may also be given as ``str``; they are
    normalised to ``Path`` in ``__post_init__``.

    Raises:
        ValueError: if ``endpoint_url`` names no endpoint, ``time_scale`` is
            not a positive number, or either concurrency cap is below 1
            (a zero-sized semaphore would block the replay forever).
    """

    trace_path: Path
    output_path: Path
    endpoint_url: str  # comma-separated for round-robin: "http://host:8000,http://host:8001"
    # Trace time gaps are divided by this value (2.0 replays twice as fast).
    time_scale: float = 1.0
    # Maximum number of sessions replayed concurrently.
    max_inflight_sessions: int = 32
    # Maximum number of requests in flight across all sessions.
    concurrency_limit: int = 256
    # Timeout passed to httpx for each request.
    request_timeout_s: float = 600.0
    # Forwarded to load_trace to cap how many requests are loaded.
    request_limit: int | None = None
    # Sent as "model" in every /v1/completions payload.
    model_name: str = "default"

    def __post_init__(self) -> None:
        self.trace_path = Path(self.trace_path)
        self.output_path = Path(self.output_path)
        if not any(part.strip() for part in self.endpoint_url.split(",")):
            raise ValueError(f"endpoint_url names no endpoint: {self.endpoint_url!r}")
        # Written as `not (x > 0)` so NaN is rejected too.
        if not self.time_scale > 0:
            raise ValueError(f"time_scale must be > 0, got {self.time_scale!r}")
        if self.max_inflight_sessions < 1:
            raise ValueError(
                f"max_inflight_sessions must be >= 1, got {self.max_inflight_sessions!r}"
            )
        if self.concurrency_limit < 1:
            raise ValueError(f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}")
|
|
|
|
|
|
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
    """Expand a trace request into ``req.input_length`` prompt token IDs.

    Each entry of ``req.hash_ids`` contributes its deterministic block from
    ``_hash_id_to_token_ids``, in order. Requests whose hash_ids share a prefix
    therefore share a token-ID prefix, which is what turns into an APC cache
    hit in vLLM. If the blocks fall short of ``input_length``, the remainder
    is filled from ``random.Random(req.chat_id)``. The result is then cut to
    ``input_length``.
    """
    token_ids: list[int] = [
        tok for hid in req.hash_ids for tok in _hash_id_to_token_ids(hid)
    ]
    # Seeded per chat_id, so the padding is reproducible across runs.
    filler_rng = _random.Random(req.chat_id)
    want = req.input_length
    missing = want - len(token_ids)
    if missing > 0:
        token_ids += [
            filler_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(missing)
        ]
    return token_ids[:want]
|
|
|
|
|
|
@dataclass
class _SessionState:
    """Mutable per-session replay state.

    Holds the session's requests in dispatch order and accumulates one
    ``RequestMetrics`` per completed turn.
    """

    # Key the session's requests were grouped under.
    session_id: str
    # Requests to replay, in the order they should be dispatched.
    turns: list[TraceRequest]
    # Collected metrics; each instance gets its own fresh list.
    metrics: list[RequestMetrics] = field(default_factory=list)
|
|
|
|
|
|
# Global round-robin position shared by every _pick_endpoint call in this process.
_endpoint_counter = 0


def _pick_endpoint(config: ReplayConfig) -> str:
    """Round-robin across comma-separated endpoints.

    Entries are whitespace-stripped and lose any trailing ``/`` (so that
    ``f"{url}/v1/completions"`` has no double slash). Empty entries, such as
    the one produced by a trailing comma, are skipped instead of receiving a
    share of the traffic.

    NOTE(review): this is per-request rotation, so consecutive turns of one
    session usually land on different endpoints — confirm that is the
    intended baseline when several endpoints are listed.

    Raises:
        ValueError: if ``config.endpoint_url`` contains no usable endpoint.
    """
    global _endpoint_counter
    endpoints = [
        url for url in (part.strip().rstrip("/") for part in config.endpoint_url.split(","))
        if url
    ]
    if not endpoints:
        raise ValueError(f"no endpoints in endpoint_url={config.endpoint_url!r}")
    url = endpoints[_endpoint_counter % len(endpoints)]
    _endpoint_counter += 1
    return url
|
|
|
|
|
|
async def _dispatch_request(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    req: TraceRequest,
    prompt_token_ids: list[int],
    sem: asyncio.Semaphore,
) -> RequestMetrics:
    """Send one request via /v1/completions (streaming) and collect metrics.

    The endpoint is picked round-robin before waiting on ``sem``. The TTFT/E2E
    clock starts only once ``sem`` is held, so time spent queued behind this
    replayer's own ``concurrency_limit`` is not counted as server latency.

    Args:
        client: Shared HTTP client used for the streaming POST.
        config: Supplies the endpoints, ``model_name`` and ``request_timeout_s``.
        req: Trace entry; its ids, lengths and type are copied into the result.
        prompt_token_ids: Prompt sent as a list of token IDs.
        sem: Caps the number of requests in flight across all sessions.

    Returns:
        A ``RequestMetrics`` for the request. Failures are recorded, not
        raised: a transport exception, a non-2xx status (with the response
        body), or an error object inside the stream is stored in ``error``,
        truncated to 300 characters.
    """
    endpoint = _pick_endpoint(config)
    # NOTE(review): no ``ignore_eos`` is sent, so generation may stop before
    # ``max_tokens``; ``actual_output_tokens`` records what was produced.
    payload = {
        "model": config.model_name,
        "prompt": prompt_token_ids,
        "max_tokens": max(1, req.output_length),
        "temperature": 0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }

    ttft_s: float | None = None
    n_output = 0
    cached_tokens = 0
    finish_reason = None
    err: str | None = None
    # Arrival times of chunks whose first choice carries non-empty text.
    token_times: list[float] = []

    async with sem:
        start = time.perf_counter()
        try:
            async with client.stream(
                "POST",
                f"{endpoint}/v1/completions",
                json=payload,
                timeout=config.request_timeout_s,
            ) as resp:
                if not resp.is_success:
                    # Keep the server's reason (e.g. a validation message),
                    # which raise_for_status() would discard.
                    body = (await resp.aread()).decode("utf-8", errors="replace")
                    err = f"HTTP {resp.status_code}: {body}"[:300]
                else:
                    async for raw_line in resp.aiter_lines():
                        if not raw_line.startswith("data:"):
                            continue
                        data = raw_line[5:].strip()
                        if data == "[DONE]":
                            break
                        try:
                            chunk = json.loads(data)
                        except json.JSONDecodeError:
                            continue
                        if not isinstance(chunk, dict):
                            continue

                        # An error object inside a 200 stream still means the
                        # request failed; do not report it as a success.
                        if "error" in chunk or chunk.get("object") == "error":
                            err = f"stream error: {data}"[:300]
                            break

                        choices = chunk.get("choices") or []
                        if choices:
                            now = time.perf_counter()
                            # The first chunk with a choice marks the first
                            # generated token. Usage-only chunks do not count.
                            if ttft_s is None:
                                ttft_s = now - start
                            choice = choices[0]
                            if choice.get("text"):
                                token_times.append(now)
                            fr = choice.get("finish_reason")
                            if fr:
                                finish_reason = fr

                        usage = chunk.get("usage")
                        if usage:
                            completion = usage.get("completion_tokens")
                            if isinstance(completion, int):
                                n_output = completion
                            cached_tokens = _extract_cached_tokens(usage)
        except Exception as exc:
            err = repr(exc)[:300]
        e2e = time.perf_counter() - start

    if n_output == 0 and token_times:
        n_output = len(token_times)

    # TPOT over the window from the first to the last text-bearing chunk.
    # The denominator is the token count (server-reported when available),
    # not the chunk count: one chunk can carry several tokens, and counting
    # chunks would overstate the time per token.
    tpot = 0.0
    if len(token_times) > 1:
        tpot = (token_times[-1] - token_times[0]) / (max(n_output, len(token_times)) - 1)

    return RequestMetrics(
        request_id=req.request_id,
        session_id=req.session_id,
        turn_id=req.turn_id,
        trace_timestamp_s=req.timestamp_s,
        input_length=req.input_length,
        output_length=req.output_length,
        request_type=req.request_type,
        effective_input_length=len(prompt_token_ids),
        cached_tokens=cached_tokens,
        latency_s=e2e,
        ttft_s=ttft_s,
        tpot_s=tpot,
        actual_output_tokens=n_output,
        requested_output_tokens=req.output_length,
        finish_reason=finish_reason,
        error=err,
    )
|
|
|
|
|
|
def _extract_cached_tokens(usage: dict) -> int:
    """Return the number of cached prompt tokens reported in a ``usage`` object.

    ``usage["prompt_tokens_details"]["cached_tokens"]`` is preferred. A
    top-level ``usage["cached_tokens"]`` is used when the nested value is
    missing, null or zero. Returns 0 when neither is set.
    """
    details = usage.get("prompt_tokens_details")
    nested = (details.get("cached_tokens", 0) or 0) if isinstance(details, dict) else 0
    return int(nested or usage.get("cached_tokens", 0) or 0)
|
|
|
|
|
|
async def _run_session(
    *,
    state: _SessionState,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    session_sem: asyncio.Semaphore,
    request_sem: asyncio.Semaphore,
    earliest_ts: float,
    sweep_start: float,
    sink: IncrementalMetricSink,
) -> list[RequestMetrics]:
    """Replay one session's turns sequentially while holding a session slot.

    The session is scheduled to start ``(first_ts - earliest_ts) / time_scale``
    seconds after ``sweep_start``. Turn ``i > 0`` is held until
    ``(ts_i - first_ts) / time_scale`` seconds after that scheduled start, and
    is never sent before turn ``i - 1`` has completed. Each metric is appended
    to ``state.metrics`` and passed to ``sink`` as soon as it is available.

    Returns:
        ``state.metrics``, with one entry per turn; unchanged (empty) for a
        session with no turns.
    """
    if not state.turns:
        return state.metrics

    first_ts = state.turns[0].timestamp_s
    # Scheduled start of this session, in seconds after sweep_start.
    offset = (first_ts - earliest_ts) / config.time_scale

    async with session_sem:
        # Wait until this session's start time.
        wait = offset - (time.perf_counter() - sweep_start)
        if wait > 0:
            await asyncio.sleep(wait)

        for idx, req in enumerate(state.turns):
            if idx > 0:
                # Intra-session: wait for the turn's offset relative to the first turn.
                # NOTE(review): that offset is anchored to the *scheduled* session
                # start, so if the session started late or an earlier turn overran
                # its gap, later turns go out back-to-back instead of keeping the
                # trace gaps — confirm this is intended.
                target = (req.timestamp_s - first_ts) / config.time_scale
                elapsed = time.perf_counter() - sweep_start - offset
                if elapsed < target:
                    await asyncio.sleep(target - elapsed)

            token_ids = _build_prompt_token_ids(req)
            metric = await _dispatch_request(
                client=client, config=config, req=req,
                prompt_token_ids=token_ids, sem=request_sem,
            )
            state.metrics.append(metric)
            await sink.append(metric)

    return state.metrics
|
|
|
|
|
|
def _parse_prefix_cache_counters(text: str) -> dict[str, float]:
    """Sum vLLM prefix-cache counter samples found in Prometheus text output.

    Returns ``{"queries": ..., "hits": ...}``. A counter that appears with
    several label sets is summed. The sample value is read after the label
    block, so label values containing spaces and an optional trailing
    timestamp are handled. Comment lines and unparsable values are skipped.
    """
    wanted = {
        "vllm:prefix_cache_queries_total": "queries",
        "vllm:prefix_cache_hits_total": "hits",
    }
    totals = {"queries": 0.0, "hits": 0.0}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if "{" in line:
            name, _, rest = line.partition("{")
            rest = rest.rpartition("}")[2]
        else:
            parts = line.split(None, 1)
            name, rest = parts[0], (parts[1] if len(parts) > 1 else "")
        key = wanted.get(name.strip())
        if key is None:
            continue
        fields = rest.split()
        if not fields:
            continue
        try:
            totals[key] += float(fields[0])
        except ValueError:
            continue
    return totals


async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
    """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).

    Best effort: an endpoint that cannot be scraped, or that returns a
    non-2xx status, is logged as a warning and left out of the totals.
    Without that warning, a scrape failure before or after the run would
    silently skew the hit-ratio deltas.

    Returns:
        ``{"queries": float, "hits": float}`` summed over the reachable endpoints.
    """
    total = {"queries": 0.0, "hits": 0.0}
    endpoints = [
        url for url in (part.strip().rstrip("/") for part in url_csv.split(","))
        if url
    ]
    async with httpx.AsyncClient(timeout=10) as c:
        for url in endpoints:
            try:
                r = await c.get(f"{url}/metrics")
                r.raise_for_status()
            except Exception as exc:
                logger.warning(
                    "Could not scrape %s/metrics (%r); prefix-cache stats exclude it",
                    url, exc,
                )
                continue
            for key, value in _parse_prefix_cache_counters(r.text).items():
                total[key] += value
    return total
|
|
|
|
|
|
def _group_sessions(
    requests: list[TraceRequest],
) -> list[tuple[str, list[TraceRequest]]]:
    """Group requests by session_id and order them for replay.

    Turns inside a session are sorted by ``(turn_id, timestamp_s)``. Sessions
    are then sorted by the timestamp of their first turn after that sort.
    Both sorts are stable, so ties keep input order.
    """
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for r in requests:
        by_session[r.session_id].append(r)
    for turns in by_session.values():
        turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
    return sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)


async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Main entry: load trace, replay against endpoint, return metrics.

    Steps:
    1. Load and group the trace.
    2. Snapshot the servers' prefix-cache counters.
    3. Replay every session concurrently, bounded by
       ``max_inflight_sessions`` and ``concurrency_limit``.
    4. Snapshot the counters again.
    5. Write ``output_path.with_suffix(".summary.json")`` via
       ``write_summary_json``, then add the prefix-cache deltas and wall-clock
       time to it.

    If any session task raises, the remaining session tasks are cancelled and
    awaited before the exception propagates, so none keep running against a
    closed client or sink.

    Returns:
        All per-request metrics, grouped by session in start order. Returns
        ``[]`` without writing a summary if the trace is empty.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if not requests:
        return []

    sessions = _group_sessions(requests)
    earliest_ts = sessions[0][1][0].timestamp_s

    session_sem = asyncio.Semaphore(config.max_inflight_sessions)
    request_sem = asyncio.Semaphore(config.concurrency_limit)

    logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
                len(sessions), len(requests), config.time_scale)

    pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    # Open the sink only after the pre-run scrape, so that a failure or
    # cancellation during the scrape cannot leak it.
    sink = IncrementalMetricSink(config.output_path)
    sweep_start = time.perf_counter()

    try:
        limits = httpx.Limits(
            # Keep the pool at least as large as the request cap so it never
            # becomes a hidden, tighter concurrency limit.
            max_connections=max(2000, config.concurrency_limit),
            max_keepalive_connections=500,
            keepalive_expiry=30.0,
        )
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            tasks = [
                asyncio.create_task(_run_session(
                    state=_SessionState(session_id=sid, turns=turns),
                    config=config, client=client,
                    session_sem=session_sem, request_sem=request_sem,
                    earliest_ts=earliest_ts, sweep_start=sweep_start,
                    sink=sink,
                ))
                for sid, turns in sessions
            ]
            try:
                all_results = await asyncio.gather(*tasks)
            except BaseException:
                # gather() does not cancel the other tasks when one fails;
                # cancel them so no session outlives the client and the sink.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
                raise
    finally:
        sink.close()

    sweep_elapsed = time.perf_counter() - sweep_start
    post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    flat = [m for group in all_results for m in group]
    summary_path = config.output_path.with_suffix(".summary.json")
    write_summary_json(summary_path, flat)

    # Compute aggregate prefix cache hit ratio from /metrics deltas.
    delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
    delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)
    if delta_queries < 0 or delta_hits < 0:
        # Presumably a server restart or a failed scrape; the ratio is not trustworthy.
        logger.warning(
            "Prefix-cache counters decreased during the run (queries %+.0f, hits %+.0f); "
            "hit ratio is unreliable", delta_queries, delta_hits,
        )
    hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0

    n_ok = sum(1 for m in flat if m.error is None)
    logger.info("Done: %d/%d succeeded in %.1fs", n_ok, len(flat), sweep_elapsed)
    logger.info("Prefix cache: %.1f%% hit ratio (%d/%d tokens)",
                hit_ratio * 100, int(delta_hits), int(delta_queries))

    # Append cache stats to the summary written above.
    summary = json.loads(summary_path.read_text())
    summary["prefix_cache_queries_tokens"] = int(delta_queries)
    summary["prefix_cache_hits_tokens"] = int(delta_hits)
    summary["prefix_cache_hit_ratio"] = hit_ratio
    summary["wall_clock_s"] = sweep_elapsed
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))

    logger.info("Summary written to %s", summary_path)
    return flat
|