Add --inter-turn-think (env REPLAY_INTER_TURN_THINK_S): turn 1 fires on session admission, each later turn a FIXED think-time after the previous turn COMPLETES, ignoring absolute trace timestamps. Combined with --max-inflight-sessions (env REPLAY_MAX_INFLIGHT) this is a stable N-user closed loop, removing the open-loop "fire immediately because timestamp is in the past" retrigger artifact. Needed for the dispatch-coupling (wall-clock amplification) sweep. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
444 lines
16 KiB
Python
444 lines
16 KiB
Python
"""Trace replayer — send requests to vLLM following trace timing.
|
|
|
|
Uses hash_ids from the trace to construct synthetic prompts that
|
|
reproduce realistic prefix-cache hit patterns.
|
|
|
|
Key behaviors:
|
|
- Trace-driven dispatch: each request is sent at its trace timestamp.
|
|
No artificial concurrency limits or time compression.
|
|
- Per-session sequencing: turns within a session are sent in order,
|
|
each waiting for the previous to complete before dispatching.
|
|
If a turn completes after its successor's timestamp, the successor
|
|
fires immediately (no waiting for a past timestamp).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import random as _random
|
|
|
|
import httpx
|
|
|
|
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
|
from .trace import TraceRequest, load_trace
|
|
|
|
logger = logging.getLogger(__name__)

# Every trace hash_id expands into one block of this many synthetic token IDs.
BLOCK_SIZE = 512
# Vocabulary size assumed when drawing synthetic token IDs.
# NOTE(review): presumably matches the served model's tokenizer — confirm.
VOCAB_SIZE = 151936
# Synthetic IDs are drawn from the inclusive range
# [TOKEN_RANGE_START, TOKEN_RANGE_END], staying 100 IDs away from both ends
# of the vocabulary (presumably to avoid special tokens — TODO confirm).
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100

# Lazily filled memo: hash_id -> its BLOCK_SIZE-long token-ID block.
_block_cache: dict[int, list[int]] = {}
|
|
|
|
|
|
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
    """Return the deterministic BLOCK_SIZE-token block for ``hash_id``.

    The block is drawn from a ``random.Random`` seeded with ``hash_id``, so
    the same hash_id always yields the same tokens. The result is memoized in
    ``_block_cache``, and the cached list itself is returned, so callers must
    not mutate it.
    """
    cached = _block_cache.get(hash_id)
    if cached is not None:
        return cached
    gen = _random.Random(hash_id)
    block = [
        gen.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)
        for _ in range(BLOCK_SIZE)
    ]
    _block_cache[hash_id] = block
    return block
|
|
|
|
|
|
@dataclass
class ReplayConfig:
    """Settings for one trace-replay run.

    Fields:
        trace_path: trace file handed to ``load_trace``.
        output_path: destination of the per-request metric sink; the run
            summary goes to ``output_path.with_suffix(".summary.json")``.
        endpoint_url: one base URL, or several separated by commas
            (e.g. "http://host:8000,http://host:8001"); requests are spread
            round-robin across them.
        concurrency_limit: maximum number of requests in flight at once.
        request_timeout_s: per-request HTTP timeout, in seconds.
        request_limit: forwarded to ``load_trace``; None means no limit.
        model_name: sent as the ``model`` field of each completion request.
        max_inflight_sessions: maximum number of sessions admitted at once;
            None (or a non-positive value) disables the cap.
        inter_turn_think_s: enables closed-loop think-time mode when set.
            Absolute trace timestamps are then ignored for pacing: a
            session's first turn fires as soon as the session is admitted,
            and each later turn fires this many seconds after the previous
            turn completes. Together with ``max_inflight_sessions=N`` this is
            a stable N-user closed loop, without the open-loop "fire
            immediately because the timestamp is already past" artifact.
    """

    trace_path: Path
    output_path: Path
    endpoint_url: str
    concurrency_limit: int = 2000
    request_timeout_s: float = 600.0
    request_limit: int | None = None
    model_name: str = "default"
    max_inflight_sessions: int | None = None
    inter_turn_think_s: float | None = None
|
|
|
|
|
|
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
    """Expand a trace request into the synthetic prompt token IDs for vLLM.

    Each of ``req.hash_ids`` contributes its deterministic block, so two
    requests that share a hash_id prefix share a token-ID prefix and can hit
    vLLM's automatic prefix cache. The concatenation is then padded with
    filler drawn from an RNG seeded by ``req.chat_id`` until it reaches
    ``req.input_length`` tokens, and finally truncated to ``req.input_length``.
    """
    tokens: list[int] = []
    for hid in req.hash_ids:
        tokens += _hash_id_to_token_ids(hid)

    # NOTE(review): if chat_id can be None, Random(None) seeds from OS
    # entropy and the padding is not reproducible — confirm TraceRequest
    # always provides a value.
    filler = _random.Random(req.chat_id)
    wanted = req.input_length
    while len(tokens) < wanted:
        tokens.append(filler.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
    return tokens[:wanted]
|
|
|
|
|
|
@dataclass
class _SessionState:
    """Per-session replay state: the ordered turns and the metrics gathered so far."""

    session_id: str
    # Turns in dispatch order (sorted by the caller).
    turns: list[TraceRequest]
    # One entry per dispatched turn, appended as each turn completes.
    metrics: list[RequestMetrics] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class _DispatchResult:
    """Outcome of a single dispatched request."""

    # Timing, token-count and error information for the request.
    metric: RequestMetrics
    # Output token IDs reported by the server (empty when it sent none);
    # reused to build the next turn's realized prefix.
    output_token_ids: list[int]
|
|
|
|
|
|
# Round-robin cursor shared by every request in the process.
_endpoint_counter = 0


def _pick_endpoint(config: ReplayConfig) -> str:
    """Return the next endpoint from ``config.endpoint_url`` in round-robin order.

    ``endpoint_url`` is a comma-separated list of base URLs. Whitespace around
    each entry is stripped, and empty entries (e.g. from a trailing comma)
    are ignored. Without this, ``"http://a:8000,http://b:8000,"`` would send
    every third request to an empty URL, and each of those requests would fail.

    Raises:
        ValueError: if ``endpoint_url`` contains no non-empty entry.
    """
    global _endpoint_counter
    endpoints = [e.strip() for e in config.endpoint_url.split(",") if e.strip()]
    if not endpoints:
        raise ValueError(f"no endpoint in endpoint_url={config.endpoint_url!r}")
    url = endpoints[_endpoint_counter % len(endpoints)]
    _endpoint_counter += 1
    return url
|
|
|
|
|
|
def _mean_inter_token_gap(token_times: list[float]) -> float:
    """Return the mean gap between consecutive ``token_times`` (0.0 if fewer than two)."""
    if len(token_times) < 2:
        return 0.0
    gaps = [nxt - cur for cur, nxt in zip(token_times, token_times[1:])]
    return sum(gaps) / len(gaps)


async def _dispatch_request(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    req: TraceRequest,
    prompt_token_ids: list[int],
    sem: asyncio.Semaphore,
) -> _DispatchResult:
    """Send one request via /v1/completions (streaming) and collect metrics.

    The request asks for exactly ``max(1, req.output_length)`` output tokens
    (``min_tokens == max_tokens``, temperature 0) and streams the response,
    recording when each output token arrives.

    The clocks start only after a slot in ``sem`` is held, so
    ``t_dispatch_unix``, ``ttft_s`` and ``latency_s`` describe the request as
    actually sent. They do not include time spent queued behind the
    client-side concurrency cap.

    Errors raised while streaming (transport, HTTP status, malformed chunks)
    are not propagated. Their truncated ``repr`` is stored in the metric's
    ``error`` field instead. A clean stream whose output-token count differs
    from the requested count is flagged as ``output_token_mismatch``.

    Returns:
        ``_DispatchResult`` holding the request's ``RequestMetrics`` and the
        output token IDs reported by the server (empty if none were sent).
    """
    endpoint = _pick_endpoint(config)
    # At least one token, so every request yields a first token to time.
    target_output_tokens = max(1, req.output_length)
    payload = {
        "model": config.model_name,
        "prompt": prompt_token_ids,
        "max_tokens": target_output_tokens,
        "min_tokens": target_output_tokens,
        "temperature": 0,
        "return_token_ids": True,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    req_headers = {
        "X-Session-Id": req.session_id,
        "X-Request-Id": req.request_id,
    }

    ttft_s: float | None = None
    t_first_token_unix: float | None = None
    n_output = 0
    cached_tokens = 0
    finish_reason = None
    err: str | None = None
    # perf_counter() arrival time per output token (per text chunk when the
    # server returns no token IDs).
    token_times: list[float] = []
    output_token_ids: list[int] = []

    async with sem:
        # Start timing only once a request slot is held (see docstring).
        start = time.perf_counter()
        t_dispatch_unix = time.time()
        try:
            async with client.stream(
                "POST",
                f"{endpoint}/v1/completions",
                json=payload,
                headers=req_headers,
                timeout=config.request_timeout_s,
            ) as resp:
                resp.raise_for_status()
                async for raw_line in resp.aiter_lines():
                    # SSE framing: only "data:" lines carry payloads.
                    if not raw_line or not raw_line.startswith("data:"):
                        continue
                    data = raw_line[5:].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue

                    choices = chunk.get("choices", [])
                    if choices:
                        now = time.perf_counter()
                        delta = choices[0].get("text", "")
                        chunk_token_ids = choices[0].get("token_ids")
                        if isinstance(chunk_token_ids, list):
                            clean_ids = [
                                int(t) for t in chunk_token_ids
                                if isinstance(t, int)
                            ]
                            if clean_ids:
                                if ttft_s is None:
                                    ttft_s = now - start
                                    t_first_token_unix = time.time()
                                output_token_ids.extend(clean_ids)
                                # Tokens batched in one chunk share its arrival time.
                                token_times.extend([now] * len(clean_ids))
                        elif delta:
                            # No token IDs from the server: time text chunks instead.
                            if ttft_s is None:
                                ttft_s = now - start
                                t_first_token_unix = time.time()
                            token_times.append(now)
                        fr = choices[0].get("finish_reason")
                        if fr:
                            finish_reason = fr

                    usage = chunk.get("usage")
                    if usage:
                        n_output = usage.get("completion_tokens", n_output)
                        cached_tokens = _extract_cached_tokens(usage)
        except Exception as exc:
            # Recorded rather than raised: one failed request must not
            # abort the whole replay.
            err = repr(exc)[:300]
        end = time.perf_counter()
        t_finish_unix = time.time()

    e2e = end - start
    # Prefer the exact token-ID count, then usage.completion_tokens, and
    # finally the number of timed text chunks.
    if output_token_ids:
        n_output = len(output_token_ids)
    elif n_output == 0 and token_times:
        n_output = len(token_times)
    if err is None and n_output != target_output_tokens:
        err = (
            "output_token_mismatch "
            f"requested={target_output_tokens} actual={n_output}"
        )

    return _DispatchResult(
        metric=RequestMetrics(
            request_id=req.request_id,
            session_id=req.session_id,
            turn_id=req.turn_id,
            trace_timestamp_s=req.timestamp_s,
            input_length=req.input_length,
            output_length=req.output_length,
            request_type=req.request_type,
            effective_input_length=len(prompt_token_ids),
            cached_tokens=cached_tokens,
            latency_s=e2e,
            ttft_s=ttft_s,
            tpot_s=_mean_inter_token_gap(token_times),
            actual_output_tokens=n_output,
            # What was actually sent as max/min_tokens (the trace value is
            # kept in output_length); matches the mismatch error above.
            requested_output_tokens=target_output_tokens,
            finish_reason=finish_reason,
            error=err,
            t_dispatch_unix=t_dispatch_unix,
            t_first_token_unix=t_first_token_unix,
            t_finish_unix=t_finish_unix,
            proxy_request_id=req.request_id,
            endpoint_url=endpoint,
            trace_hash_ids=req.hash_ids,
        ),
        output_token_ids=output_token_ids,
    )
|
|
|
|
|
|
def _apply_realized_prefix(
    prompt_token_ids: list[int],
    realized_context: list[int],
) -> list[int]:
    """Overlay the engine-realized session context onto the start of a prompt.

    The first ``min(len(prompt), len(context))`` positions are taken from
    ``realized_context``, and the rest of the prompt is kept. The result
    always has the prompt's length and is a new list. An empty
    ``realized_context`` returns ``prompt_token_ids`` itself, unchanged.
    """
    if realized_context:
        overlap = min(len(prompt_token_ids), len(realized_context))
        return realized_context[:overlap] + prompt_token_ids[overlap:]
    return prompt_token_ids
|
|
|
|
|
|
def _extract_cached_tokens(usage: dict) -> int:
    """Return the cached-prompt-token count from a streamed ``usage`` object.

    Prefers ``usage["prompt_tokens_details"]["cached_tokens"]``. When that is
    missing, null or zero, falls back to a top-level
    ``usage["cached_tokens"]``. Missing or null values count as 0.
    """
    details = usage.get("prompt_tokens_details")
    nested = (details.get("cached_tokens", 0) or 0) if isinstance(details, dict) else 0
    if nested != 0:
        return int(nested)
    return int(usage.get("cached_tokens", 0) or 0)
|
|
|
|
|
|
async def _await_turn_slot(
    turn_idx: int,
    req: TraceRequest,
    config: ReplayConfig,
    earliest_ts: float,
    sweep_start: float,
) -> None:
    """Sleep until turn ``turn_idx`` (``req``) may be dispatched.

    Closed-loop mode (``config.inter_turn_think_s`` is set): the first turn
    is not delayed. Every later turn sleeps the fixed think-time starting
    now, i.e. right after the previous turn completed. There is no absolute
    schedule, so a turn never fires early just because a timestamp is past.

    Trace mode: sleep until ``req.timestamp_s - earliest_ts`` seconds have
    elapsed since ``sweep_start``. A target already in the past fires at once.
    """
    think_s = config.inter_turn_think_s
    if think_s is None:
        due_in = (req.timestamp_s - earliest_ts) - (time.perf_counter() - sweep_start)
        if due_in > 0:
            await asyncio.sleep(due_in)
    elif turn_idx:
        await asyncio.sleep(think_s)


async def _run_session(
    *,
    state: _SessionState,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    request_sem: asyncio.Semaphore,
    earliest_ts: float,
    sweep_start: float,
    sink: IncrementalMetricSink,
    session_sem: asyncio.Semaphore | None = None,
) -> list[RequestMetrics]:
    """Replay every turn of one session in order and return its metrics.

    With ``session_sem``, the session first waits for an admission slot and
    holds it until it finishes (or fails). For each turn:

    1. Pace it with ``_await_turn_slot``.
    2. Build its prompt from the trace, overlaid with the realized context
       of the last successful turn (that turn's prompt plus the tokens the
       server generated).
    3. Dispatch it under ``request_sem``.

    Each metric is appended to ``state.metrics`` and then to ``sink``.

    Returns ``state.metrics`` (the same list object).
    """
    if session_sem is not None:
        await session_sem.acquire()
    # Prompt + generated tokens of the most recent successful turn.
    context_tokens: list[int] = []
    try:
        for turn_idx, turn in enumerate(state.turns):
            await _await_turn_slot(turn_idx, turn, config, earliest_ts, sweep_start)

            prompt_ids = _apply_realized_prefix(
                _build_prompt_token_ids(turn),
                context_tokens,
            )
            outcome = await _dispatch_request(
                client=client,
                config=config,
                req=turn,
                prompt_token_ids=prompt_ids,
                sem=request_sem,
            )
            state.metrics.append(outcome.metric)
            await sink.append(outcome.metric)
            # A failed turn leaves the previous realized context in place.
            if outcome.metric.error is None:
                context_tokens = prompt_ids + outcome.output_token_ids
    finally:
        if session_sem is not None:
            session_sem.release()

    return state.metrics
|
|
|
|
|
|
# Prometheus sample names scraped from vLLM's /metrics -> result keys.
_PREFIX_CACHE_COUNTERS = {
    "vllm:prefix_cache_queries_total": "queries",
    "vllm:prefix_cache_hits_total": "hits",
    "vllm:external_prefix_cache_queries_total": "external_queries",
    "vllm:external_prefix_cache_hits_total": "external_hits",
}


def _parse_prometheus_sample(line: str) -> tuple[str, float] | None:
    """Split one Prometheus text-format sample line into ``(name, value)``.

    Returns None for blank lines, ``#`` (HELP/TYPE/comment) lines, and lines
    whose value is not a float. A ``{...}`` label set may contain spaces, and
    an optional trailing timestamp is ignored (taking the last token would
    return the timestamp instead of the value).
    """
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    brace = line.find("{")
    space = line.find(" ")
    if brace != -1 and (space == -1 or brace < space):
        name = line[:brace]
        rest = line[line.rfind("}") + 1:]
    else:
        parts = line.split(maxsplit=1)
        if len(parts) < 2:
            return None
        name, rest = parts
    fields = rest.split()
    if not fields:
        return None
    try:
        return name, float(fields[0])
    except ValueError:
        return None


async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
    """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).

    For every endpoint in the comma-separated ``url_csv``, each counter in
    ``_PREFIX_CACHE_COUNTERS`` is summed over all of its label sets. Returns a
    dict with keys ``queries``, ``hits``, ``external_queries`` and
    ``external_hits``.

    Best-effort: an endpoint that cannot be scraped contributes nothing. The
    failure is logged, because a missing endpoint in either the pre- or
    post-run snapshot skews the deltas computed from them. Proxy environment
    variables are ignored (``trust_env=False``), matching the client used
    for the replayed requests.
    """
    total = {key: 0.0 for key in _PREFIX_CACHE_COUNTERS.values()}
    endpoints = [e.strip() for e in url_csv.split(",") if e.strip()]
    async with httpx.AsyncClient(timeout=10, trust_env=False) as c:
        for url in endpoints:
            try:
                r = await c.get(f"{url}/metrics")
                r.raise_for_status()
                text = r.text
            except Exception as exc:
                logger.warning("Could not scrape %s/metrics: %r", url, exc)
                continue
            for line in text.splitlines():
                sample = _parse_prometheus_sample(line)
                if sample is None:
                    continue
                key = _PREFIX_CACHE_COUNTERS.get(sample[0])
                if key is not None:
                    total[key] += sample[1]
    return total
|
|
|
|
|
|
def _group_by_session(
    requests: list[TraceRequest],
) -> list[tuple[str, list[TraceRequest]]]:
    """Group requests by session_id.

    Turns within a session are ordered by (turn_id, timestamp_s); sessions
    are ordered by the timestamp of their first turn.
    """
    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for r in requests:
        by_session[r.session_id].append(r)
    for turns in by_session.values():
        turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
    return sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)


def _prefix_cache_summary(
    pre: dict[str, float],
    post: dict[str, float],
) -> dict[str, float | int]:
    """Turn pre/post prefix-cache counter snapshots into summary fields.

    Token counts are the counter deltas over the sweep. Ratios are
    hits / queries, or 0.0 when there were no queries.
    """
    def delta(key: str) -> float:
        return post.get(key, 0) - pre.get(key, 0)

    queries, hits = delta("queries"), delta("hits")
    ext_queries, ext_hits = delta("external_queries"), delta("external_hits")
    return {
        "prefix_cache_queries_tokens": int(queries),
        "prefix_cache_hits_tokens": int(hits),
        "prefix_cache_hit_ratio": hits / queries if queries > 0 else 0.0,
        "external_cache_queries_tokens": int(ext_queries),
        "external_cache_hits_tokens": int(ext_hits),
        "external_cache_hit_ratio": ext_hits / ext_queries if ext_queries > 0 else 0.0,
    }


async def _gather_sessions(tasks: list[asyncio.Task]) -> list[list[RequestMetrics]]:
    """Await every session task and return their results in task order.

    ``asyncio.gather`` propagates the first task failure immediately but
    leaves the other tasks running. Those tasks share the HTTP client and
    metric sink that the caller closes on the way out. On any failure or
    cancellation, the rest are therefore cancelled and drained before
    re-raising.
    """
    try:
        return await asyncio.gather(*tasks)
    except BaseException:
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise


async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Main entry: load trace, replay against endpoint, return metrics.

    Groups the trace into sessions and runs one task per session against
    ``config.endpoint_url``. Per-request metrics are streamed to
    ``config.output_path``. Afterwards a summary is written to
    ``config.output_path.with_suffix(".summary.json")``. It contains:

    - the fields produced by ``write_summary_json``;
    - prefix-cache statistics (vLLM ``/metrics`` deltas between snapshots
      taken before and after the sweep);
    - wall-clock time, trace span, and amplification (wall clock / span);
    - the pacing settings ``inter_turn_think_s`` and ``max_inflight_sessions``.

    Returns an empty list, without contacting any endpoint, when the trace
    is empty. Otherwise returns all metrics, grouped by session in session
    start order. If a session task raises, the remaining tasks are cancelled
    and the exception propagates.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if not requests:
        return []

    sessions = _group_by_session(requests)
    # Global min/max: a session's first turn by turn_id is not guaranteed to
    # carry its (or the trace's) earliest timestamp.
    earliest_ts = min(r.timestamp_s for r in requests)
    latest_ts = max(r.timestamp_s for r in requests)
    trace_span = latest_ts - earliest_ts

    request_sem = asyncio.Semaphore(config.concurrency_limit)
    session_sem = (
        asyncio.Semaphore(config.max_inflight_sessions)
        if config.max_inflight_sessions and config.max_inflight_sessions > 0
        else None
    )

    n_sessions = len(sessions)
    n_requests = len(requests)
    qps = n_requests / trace_span if trace_span > 0 else 0
    logger.info("Replaying %d sessions (%d requests) over %.0fs (%.2f req/s)",
                n_sessions, n_requests, trace_span, qps)
    if session_sem is not None:
        logger.info("Session admission cap: %d concurrent sessions",
                    config.max_inflight_sessions)
    if config.inter_turn_think_s is not None:
        logger.info("Closed-loop mode: %.3fs think-time between turns "
                    "(trace timestamps ignored)", config.inter_turn_think_s)
        if session_sem is None:
            logger.warning("Closed-loop mode without a session cap: every "
                           "session's first turn fires immediately")

    pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
    # Created after the pre-snapshot so it cannot leak if that step fails.
    sink = IncrementalMetricSink(config.output_path)
    sweep_start = time.perf_counter()

    try:
        limits = httpx.Limits(
            # Keep the pool at least as large as the request cap. Otherwise
            # requests admitted by request_sem would wait for a pooled
            # connection, and that wait would count as latency (or end in a
            # pool timeout).
            max_connections=max(2000, config.concurrency_limit),
            max_keepalive_connections=500,
            keepalive_expiry=30.0,
        )
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            tasks = [
                asyncio.create_task(_run_session(
                    state=_SessionState(session_id=sid, turns=turns),
                    config=config, client=client,
                    request_sem=request_sem,
                    earliest_ts=earliest_ts, sweep_start=sweep_start,
                    sink=sink,
                    session_sem=session_sem,
                ))
                for sid, turns in sessions
            ]
            all_results = await _gather_sessions(tasks)
    finally:
        sink.close()

    sweep_elapsed = time.perf_counter() - sweep_start
    post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    flat = [m for group in all_results for m in group]
    summary_path = config.output_path.with_suffix(".summary.json")
    write_summary_json(summary_path, flat)

    cache_stats = _prefix_cache_summary(pre_metrics, post_metrics)
    logger.info("Done: %d/%d succeeded in %.1fs",
                sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)
    logger.info("Prefix cache: local=%.1f%% external=%.1f%%",
                cache_stats["prefix_cache_hit_ratio"] * 100,
                cache_stats["external_cache_hit_ratio"] * 100)

    # Augment the summary written above with run-level statistics.
    summary = json.loads(summary_path.read_text())
    summary.update(cache_stats)
    summary["wall_clock_s"] = sweep_elapsed
    summary["trace_span_s"] = trace_span
    summary["amplification"] = sweep_elapsed / trace_span if trace_span > 0 else None
    summary["inter_turn_think_s"] = config.inter_turn_think_s
    summary["max_inflight_sessions"] = config.max_inflight_sessions
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))

    logger.info("Summary written to %s", summary_path)
    return flat
|