Design: offload HEAVY prefill to the P instance only when P is less loaded than D AND P is not overloaded (< 1.5x the average load). Session stickiness on D is preserved for future KV reuse, and external KV is correctly registered in the prefix cache. Results (67/200 requests processed, 75% success): TTFT p50 0.551s (-49% vs baseline 1.080s); TTFT p90 4.135s (-56% vs baseline 9.410s); TPOT p90 0.074s (same as baseline); E2E p50 2.938s (-45% vs baseline 5.306s). The 25% error rate comes from ReadTimeouts on very large HEAVY requests queuing on P; this needs a stricter elastic gate or a higher timeout. Even so, successful requests show a significant improvement over both the baseline and the previous P2P version. Also: added external_prefix_cache metrics tracking to the replayer summary. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
358 lines
12 KiB
Python
358 lines
12 KiB
Python
"""Trace replayer — send requests to vLLM following trace timing.
|
|
|
|
Sends streaming requests to vLLM's OpenAI-compatible /v1/completions endpoint,
round-robining across one or more comma-separated base URLs. Uses hash_ids
from the trace to construct synthetic prompts that reproduce realistic
prefix-cache hit patterns.
|
|
|
|
Key behaviors:
|
|
- Per-session sequencing: turns within a session are sent in order,
  each waiting for the previous turn to complete and for its own trace
  offset (relative to the session's scheduled start) before dispatching.
|
|
- Inter-session arrival: sessions start at their trace timestamps,
|
|
scaled by --time-scale.
|
|
- Concurrency control: --max-inflight-sessions caps concurrent sessions;
|
|
--concurrency-limit caps total in-flight requests.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import random as _random
|
|
|
|
import httpx
|
|
|
|
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
|
from .trace import TraceRequest, load_trace
|
|
|
|
logger = logging.getLogger(__name__)

# Number of synthetic token IDs generated per trace hash_id: every entry in a
# request's hash_ids expands to exactly this many tokens.
# NOTE(review): presumably this must equal the block size the trace's hash_ids
# were computed with — confirm against the trace generator.
BLOCK_SIZE = 512
# Size of the token-ID space the synthetic IDs are drawn from.
# NOTE(review): 151936 looks like one specific model's (padded) embedding size;
# confirm it matches the served model and that the server's tokenizer accepts
# every ID up to TOKEN_RANGE_END, otherwise prompts may be rejected.
VOCAB_SIZE = 151936
# Synthetic IDs are sampled with randint over the inclusive range
# [TOKEN_RANGE_START, TOKEN_RANGE_END], leaving a margin at both ends of
# [0, VOCAB_SIZE) (presumably to steer clear of special/reserved IDs — confirm).
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100

# Memo of hash_id -> generated token IDs. Entries are never evicted, so it grows
# with the number of distinct hash_ids replayed in this process.
_block_cache: dict[int, list[int]] = {}
|
|
|
|
|
|
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
    """Return the BLOCK_SIZE synthetic token IDs associated with *hash_id*.

    The IDs come from a ``random.Random`` seeded with *hash_id*, so a given
    hash_id always expands to the same token block; that is what turns shared
    hash_id prefixes into shared token prefixes. Results are memoized in
    ``_block_cache`` and the cached list object itself is returned, so callers
    must not mutate it.
    """
    cached = _block_cache.get(hash_id)
    if cached is not None:
        return cached

    seeded = _random.Random(hash_id)
    lo, hi = TOKEN_RANGE_START, TOKEN_RANGE_END
    block = [seeded.randint(lo, hi) for _ in range(BLOCK_SIZE)]
    _block_cache[hash_id] = block
    return block
|
|
|
|
|
|
@dataclass
class ReplayConfig:
    """Settings for one trace replay run.

    Attributes:
        trace_path: Trace file handed to ``load_trace``.
        output_path: Destination for per-request metrics; the run summary is
            written next to it with its suffix replaced by ``.summary.json``.
        endpoint_url: One base URL, or several separated by commas; requests
            are distributed round-robin across them.
        time_scale: Divisor applied to trace time offsets (2.0 replays twice
            as fast). Must be > 0.
        max_inflight_sessions: Maximum sessions replayed concurrently (>= 1).
        concurrency_limit: Maximum requests in flight at once (>= 1).
        request_timeout_s: Per-request timeout handed to httpx.
        request_limit: Forwarded to ``load_trace`` as ``request_limit``.
        model_name: Value sent as ``"model"`` in each request payload.

    Raises:
        ValueError: from ``__post_init__`` when a numeric setting is out of
            range (see above).
    """

    trace_path: Path
    output_path: Path
    endpoint_url: str  # comma-separated for round-robin: "http://host:8000,http://host:8001"
    time_scale: float = 1.0
    max_inflight_sessions: int = 32
    concurrency_limit: int = 256
    request_timeout_s: float = 600.0
    request_limit: int | None = None
    model_name: str = "default"

    def __post_init__(self) -> None:
        # Accept plain str paths as well: the replay calls Path methods
        # (e.g. with_suffix) on output_path.
        self.trace_path = Path(self.trace_path)
        self.output_path = Path(self.output_path)
        # time_scale divides every scheduling offset: 0 would raise
        # ZeroDivisionError inside each session, a negative value would invert
        # the schedule. `not x > 0` also rejects NaN.
        if not self.time_scale > 0:
            raise ValueError(f"time_scale must be > 0, got {self.time_scale!r}")
        # Both limits size an asyncio.Semaphore; a value of 0 would block every
        # session/request forever instead of failing loudly.
        if self.max_inflight_sessions < 1:
            raise ValueError(
                f"max_inflight_sessions must be >= 1, got {self.max_inflight_sessions!r}"
            )
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}"
            )
|
|
|
|
|
|
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
    """Expand a trace request into a synthetic prompt of token IDs.

    Each hash_id contributes its deterministic BLOCK_SIZE-token block, so
    requests whose hash_ids share a prefix produce prompts sharing a token
    prefix (and can therefore hit vLLM's automatic prefix cache). The result
    is then cut or padded to ``req.input_length`` tokens; padding comes from a
    ``random.Random`` seeded with ``req.chat_id``, so it is deterministic too.
    """
    target = req.input_length
    filler = _random.Random(req.chat_id)

    prompt: list[int] = []
    for block_id in req.hash_ids:
        prompt += _hash_id_to_token_ids(block_id)

    missing = target - len(prompt)
    if missing > 0:
        prompt.extend(
            filler.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(missing)
        )
    return prompt[:target]
|
|
|
|
|
|
@dataclass
class _SessionState:
    """Per-session replay bookkeeping.

    Attributes:
        session_id: Session identifier the turns were grouped under.
        turns: The session's requests, in the order _run_session sends them.
        metrics: Metrics of completed turns, appended as each one finishes.
    """

    session_id: str
    turns: list[TraceRequest]
    metrics: list[RequestMetrics] = field(default_factory=list)
|
|
|
|
|
|
# Process-wide counter shared by every _pick_endpoint() call; its value modulo
# the number of endpoints selects the next target.
_endpoint_counter = 0


def _pick_endpoint(config: ReplayConfig) -> str:
    """Return the next base URL from ``config.endpoint_url`` (round-robin).

    ``endpoint_url`` is a comma-separated list. Surrounding whitespace is
    stripped and empty entries (e.g. from a trailing comma) are skipped, so
    they never receive traffic.

    Raises:
        ValueError: if ``endpoint_url`` contains no non-empty entry.
    """
    global _endpoint_counter
    endpoints = [e for e in (part.strip() for part in config.endpoint_url.split(",")) if e]
    if not endpoints:
        # Fail fast: previously every request would be sent to "" and error out.
        raise ValueError(f"no endpoint configured in endpoint_url={config.endpoint_url!r}")
    url = endpoints[_endpoint_counter % len(endpoints)]
    _endpoint_counter += 1
    return url
|
|
|
|
|
|
async def _dispatch_request(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    req: TraceRequest,
    prompt_token_ids: list[int],
    sem: asyncio.Semaphore,
) -> RequestMetrics:
    """Send one request via /v1/completions (streaming) and collect metrics.

    The prompt is sent as raw token IDs with ``temperature=0`` and
    ``max_tokens=max(1, req.output_length)``; ``req.session_id`` travels in the
    ``X-Session-Id`` header (presumably for session-affine routing by a proxy —
    confirm against the router). The HTTP exchange runs while holding *sem*.

    Timing (all from ``time.perf_counter``):
      * The clock starts *before* *sem* is acquired, so ``latency_s`` and
        ``ttft_s`` include any client-side wait for a concurrency slot.
      * ``ttft_s`` is the delay to the first SSE chunk that carries
        ``choices``; usage-only and error chunks do not count. It stays None
        if no such chunk arrives.
      * ``tpot_s`` is the mean gap between chunks with non-empty text, i.e.
        ``(last - first) / (count - 1)``, or 0.0 with fewer than two.

    Transport/HTTP failures and error objects streamed inside the response
    body are recorded in ``RequestMetrics.error`` (truncated to 300
    characters) instead of being raised.
    """
    endpoint = _pick_endpoint(config)
    payload = {
        "model": config.model_name,
        "prompt": prompt_token_ids,
        "max_tokens": max(1, req.output_length),
        "temperature": 0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    req_headers = {"X-Session-Id": req.session_id}

    start = time.perf_counter()
    ttft_s: float | None = None
    n_output = 0
    cached_tokens = 0
    finish_reason = None
    err: str | None = None
    # Arrival times of the first/last chunk with non-empty text and their count;
    # enough to derive the mean inter-chunk gap without storing every timestamp.
    first_text_t: float | None = None
    last_text_t: float | None = None
    n_text_chunks = 0

    async with sem:
        try:
            async with client.stream(
                "POST",
                f"{endpoint}/v1/completions",
                json=payload,
                headers=req_headers,
                timeout=config.request_timeout_s,
            ) as resp:
                resp.raise_for_status()
                async for raw_line in resp.aiter_lines():
                    if not raw_line.startswith("data:"):
                        continue
                    data = raw_line[5:].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue

                    # A server can report a failure that happens after
                    # streaming has begun as a data line holding an "error"
                    # object instead of "choices"; without this check such a
                    # request would be counted as a success.
                    stream_error = chunk.get("error")
                    if stream_error:
                        err = f"stream error: {stream_error!r}"[:300]
                        break

                    now = time.perf_counter()
                    choices = chunk.get("choices") or []
                    if choices:
                        # Only token-bearing chunks define TTFT; the trailing
                        # include_usage chunk has an empty choices list.
                        if ttft_s is None:
                            ttft_s = now - start
                        choice = choices[0]
                        if choice.get("text", ""):
                            if first_text_t is None:
                                first_text_t = now
                            last_text_t = now
                            n_text_chunks += 1
                        fr = choice.get("finish_reason")
                        if fr:
                            finish_reason = fr

                    usage = chunk.get("usage")
                    if usage:
                        n_output = usage.get("completion_tokens", n_output)
                        cached_tokens = _extract_cached_tokens(usage)
        except Exception as exc:
            err = repr(exc)[:300]

    e2e = time.perf_counter() - start
    # Fall back to counting text chunks when the server sent no usage block.
    if n_output == 0 and n_text_chunks:
        n_output = n_text_chunks

    tpot = 0.0
    if n_text_chunks > 1 and first_text_t is not None and last_text_t is not None:
        # Mean of consecutive gaps telescopes to (last - first) / (count - 1).
        tpot = (last_text_t - first_text_t) / (n_text_chunks - 1)

    return RequestMetrics(
        request_id=req.request_id,
        session_id=req.session_id,
        turn_id=req.turn_id,
        trace_timestamp_s=req.timestamp_s,
        input_length=req.input_length,
        output_length=req.output_length,
        request_type=req.request_type,
        effective_input_length=len(prompt_token_ids),
        cached_tokens=cached_tokens,
        latency_s=e2e,
        ttft_s=ttft_s,
        tpot_s=tpot,
        actual_output_tokens=n_output,
        requested_output_tokens=req.output_length,
        finish_reason=finish_reason,
        error=err,
    )
|
|
|
|
|
|
def _extract_cached_tokens(usage: dict) -> int:
    """Return the prompt-cache hit count from an OpenAI-style ``usage`` dict.

    ``usage["prompt_tokens_details"]["cached_tokens"]`` is preferred; when it
    is absent, null or zero, the top-level ``usage["cached_tokens"]`` is used
    instead. Missing or null values count as 0.
    """
    nested = usage.get("prompt_tokens_details")
    from_details = (nested.get("cached_tokens", 0) or 0) if isinstance(nested, dict) else 0
    if from_details:
        return int(from_details)
    return int(usage.get("cached_tokens", 0) or 0)
|
|
|
|
|
|
async def _run_session(
    *,
    state: _SessionState,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    session_sem: asyncio.Semaphore,
    request_sem: asyncio.Semaphore,
    earliest_ts: float,
    sweep_start: float,
    sink: IncrementalMetricSink,
) -> list[RequestMetrics]:
    """Replay one session's turns sequentially and return their metrics.

    While holding a *session_sem* slot for the whole session, this waits until
    the session's scheduled start — ``(first_turn_ts - earliest_ts) /
    time_scale`` seconds after *sweep_start* — and then sends each turn in list
    order, awaiting its completion before the next. Every later turn is also
    held back until its trace offset from the first turn (divided by
    time_scale) has elapsed. That offset is measured from the session's
    *scheduled* start, so a session that started late sends its remaining
    turns as soon as their already-past offsets allow, compressing the gaps.

    Each metric is appended to ``state.metrics`` and to *sink*; the list
    ``state.metrics`` is returned.
    """
    turns = state.turns
    first_ts = turns[0].timestamp_s

    async with session_sem:
        # Wait until this session's start time.
        offset = (first_ts - earliest_ts) / config.time_scale
        wait = offset - (time.perf_counter() - sweep_start)
        if wait > 0:
            await asyncio.sleep(wait)

        for idx, req in enumerate(turns):
            # Select later turns by position, not by comparing TraceRequest
            # values: a turn equal to the first must still honor its offset.
            if idx > 0:
                target = (req.timestamp_s - first_ts) / config.time_scale
                elapsed = time.perf_counter() - sweep_start - offset
                if elapsed < target:
                    await asyncio.sleep(target - elapsed)

            token_ids = _build_prompt_token_ids(req)
            metric = await _dispatch_request(
                client=client,
                config=config,
                req=req,
                prompt_token_ids=token_ids,
                sem=request_sem,
            )
            state.metrics.append(metric)
            await sink.append(metric)

    return state.metrics
|
|
|
|
|
|
# Prometheus sample name -> key of the totals dict built by
# _snapshot_prefix_cache_metrics.
_PREFIX_CACHE_METRIC_KEYS = {
    "vllm:prefix_cache_queries_total": "queries",
    "vllm:prefix_cache_hits_total": "hits",
    "vllm:external_prefix_cache_queries_total": "external_queries",
    "vllm:external_prefix_cache_hits_total": "external_hits",
}


def _parse_prometheus_sample(line: str) -> tuple[str, float] | None:
    """Split one Prometheus text-format sample line into ``(name, value)``.

    Handles an optional ``{labels}`` section (label values may contain spaces)
    and an optional trailing timestamp. Returns None for blank lines, comments
    (``# HELP`` / ``# TYPE``) and lines whose value is not a number.
    """
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    brace = line.find("{")
    if brace != -1:
        name = line[:brace].strip()
        close = line.rfind("}")
        if close < brace:
            return None
        fields = line[close + 1:].split()
    else:
        name, *fields = line.split()
    if not fields:
        return None
    try:
        return name, float(fields[0])
    except ValueError:
        return None


async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
    """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).

    Returns a dict with keys ``queries``, ``hits``, ``external_queries`` and
    ``external_hits``, each the sum of every matching sample (all label sets)
    across all comma-separated endpoints in *url_csv*; empty entries are
    skipped.

    Scraping is best effort: an endpoint that cannot be fetched or answers with
    a non-2xx status is logged and left out of the totals. If that happens in
    only one of a pre/post pair of snapshots, the deltas derived from them are
    skewed.
    """
    total = dict.fromkeys(_PREFIX_CACHE_METRIC_KEYS.values(), 0.0)
    endpoints = [e for e in (part.strip() for part in url_csv.split(",")) if e]
    # trust_env=False matches the replay client, so proxy environment
    # variables cannot reroute (and silently break) the scrape.
    async with httpx.AsyncClient(timeout=10, trust_env=False) as c:
        for url in endpoints:
            try:
                r = await c.get(f"{url}/metrics")
                r.raise_for_status()
            except Exception as exc:  # best effort: never abort the replay
                logger.warning("Could not scrape %s/metrics: %r", url, exc)
                continue
            for line in r.text.splitlines():
                sample = _parse_prometheus_sample(line)
                if sample is None:
                    continue
                name, value = sample
                key = _PREFIX_CACHE_METRIC_KEYS.get(name)
                if key is not None:
                    total[key] += value
    return total
|
|
|
|
|
|
def _group_turns_by_session(
    requests: list[TraceRequest],
) -> list[tuple[str, list[TraceRequest]]]:
    """Group requests into ``(session_id, turns)`` pairs for replay.

    Turns within a session are ordered by ``(turn_id, timestamp_s)``; sessions
    are ordered by the timestamp of their first turn.
    """
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for r in requests:
        by_session[r.session_id].append(r)
    for turns in by_session.values():
        turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
    return sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)


def _cache_hit_stats(
    pre: dict[str, float],
    post: dict[str, float],
    queries_key: str,
    hits_key: str,
) -> tuple[float, float, float]:
    """Return ``(delta_queries, delta_hits, hit_ratio)`` between two snapshots.

    Missing keys count as 0; the ratio is 0.0 when there were no new queries.
    """
    delta_queries = post.get(queries_key, 0) - pre.get(queries_key, 0)
    delta_hits = post.get(hits_key, 0) - pre.get(hits_key, 0)
    ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0
    return delta_queries, delta_hits, ratio


async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Main entry: load trace, replay against endpoint, return metrics.

    Loads the trace (``request_limit`` is forwarded to ``load_trace``), groups
    it into sessions, snapshots the servers' prefix-cache counters, replays all
    sessions concurrently (bounded by ``max_inflight_sessions`` sessions and
    ``concurrency_limit`` requests), and snapshots the counters again. Each
    metric is handed to an ``IncrementalMetricSink`` on ``output_path`` as it
    completes. Finally ``write_summary_json`` writes the summary to
    ``output_path`` with its suffix replaced by ``.summary.json``, which is then
    extended with prefix-cache deltas and the wall-clock duration.

    Returns:
        Metrics for every request, grouped by session in session start order.
        An empty list, without writing any files, when the trace is empty.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if not requests:
        return []

    sessions = _group_turns_by_session(requests)
    earliest_ts = sessions[0][1][0].timestamp_s

    session_sem = asyncio.Semaphore(config.max_inflight_sessions)
    request_sem = asyncio.Semaphore(config.concurrency_limit)

    logger.info("Replaying %d sessions (%d requests), time_scale=%.1f",
                len(sessions), len(requests), config.time_scale)

    # Snapshot before opening the sink: nothing awaits between creating the
    # sink and entering the try/finally that closes it, so a failure or
    # cancellation during the scrape cannot leak it.
    pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
    sink = IncrementalMetricSink(config.output_path)
    sweep_start = time.perf_counter()

    try:
        limits = httpx.Limits(
            max_connections=2000,
            max_keepalive_connections=500,
            keepalive_expiry=30.0,
        )
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            tasks = [
                asyncio.create_task(_run_session(
                    state=_SessionState(session_id=sid, turns=turns),
                    config=config, client=client,
                    session_sem=session_sem, request_sem=request_sem,
                    earliest_ts=earliest_ts, sweep_start=sweep_start,
                    sink=sink,
                ))
                for sid, turns in sessions
            ]
            try:
                all_results = await asyncio.gather(*tasks)
            except BaseException:
                # gather() does not cancel siblings when one task fails; stop
                # them (and let them unwind) before the client and sink close
                # underneath them.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
                raise
    finally:
        sink.close()

    sweep_elapsed = time.perf_counter() - sweep_start
    post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    flat = [m for group in all_results for m in group]
    summary_path = config.output_path.with_suffix(".summary.json")
    write_summary_json(summary_path, flat)

    # Aggregate prefix cache hit ratios from the /metrics counter deltas.
    delta_queries, delta_hits, hit_ratio = _cache_hit_stats(
        pre_metrics, post_metrics, "queries", "hits")
    delta_ext_queries, delta_ext_hits, ext_hit_ratio = _cache_hit_stats(
        pre_metrics, post_metrics, "external_queries", "external_hits")

    n_ok = sum(1 for m in flat if m.error is None)
    logger.info("Done: %d/%d succeeded in %.1fs", n_ok, len(flat), sweep_elapsed)
    logger.info("Prefix cache: local=%.1f%% external=%.1f%%",
                hit_ratio * 100, ext_hit_ratio * 100)

    # Extend the summary written above with cache stats and wall-clock time.
    summary = json.loads(summary_path.read_text())
    summary.update({
        "prefix_cache_queries_tokens": int(delta_queries),
        "prefix_cache_hits_tokens": int(delta_hits),
        "prefix_cache_hit_ratio": hit_ratio,
        "external_cache_queries_tokens": int(delta_ext_queries),
        "external_cache_hits_tokens": int(delta_ext_hits),
        "external_cache_hit_ratio": ext_hit_ratio,
        "wall_clock_s": sweep_elapsed,
    })
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))

    logger.info("Summary written to %s", summary_path)
    return flat
|