"""Trace replayer — send requests to vLLM following trace timing. Uses hash_ids from the trace to construct synthetic prompts that reproduce realistic prefix-cache hit patterns. Key behaviors: - Trace-driven dispatch: each request is sent at its trace timestamp. No artificial concurrency limits or time compression. - Per-session sequencing: turns within a session are sent in order, each waiting for the previous to complete before dispatching. If a turn completes after its successor's timestamp, the successor fires immediately (no waiting for a past timestamp). """ from __future__ import annotations import asyncio import json import logging import time from collections import defaultdict from dataclasses import dataclass, field from pathlib import Path from typing import Any import random as _random import httpx from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json from .trace import TraceRequest, load_trace logger = logging.getLogger(__name__) BLOCK_SIZE = 512 VOCAB_SIZE = 151936 TOKEN_RANGE_START = 100 TOKEN_RANGE_END = VOCAB_SIZE - 100 _block_cache: dict[int, list[int]] = {} def _hash_id_to_token_ids(hash_id: int) -> list[int]: """Deterministically map a hash_id to BLOCK_SIZE token IDs.""" if hash_id in _block_cache: return _block_cache[hash_id] rng = _random.Random(hash_id) ids = [rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)] _block_cache[hash_id] = ids return ids @dataclass class ReplayConfig: trace_path: Path output_path: Path endpoint_url: str # comma-separated for round-robin: "http://host:8000,http://host:8001" concurrency_limit: int = 2000 request_timeout_s: float = 600.0 request_limit: int | None = None model_name: str = "default" max_inflight_sessions: int | None = None # cap on concurrent sessions; None = unlimited # Closed-loop think-time mode: if set, ignore absolute trace timestamps for # subsequent turns — fire turn 1 on session admission, then each later turn a # FIXED think-time after the previous turn COMPLETES. Combined with # max_inflight_sessions=N this is a stable N-user closed-loop (no open-loop # runaway), so it removes the "immediate retrigger under load" artifact. inter_turn_think_s: float | None = None # Dispatch timing for intra-session turns: # "tracets" (Mode 1): fire at absolute trace timestamp -> effectively # max(prev_finished, trace_ts); collapses think-time to 0 when # the system is behind (the amplification-inflation suspect). # "thinktime" (Mode 2): turn-1 at trace arrival; turn-k at # prev_finished + time_to_parent_chat (real production gap). dispatch_mode: str = "tracets" def _build_prompt_token_ids(req: TraceRequest) -> list[int]: """Build token IDs from hash_ids for prefix-cache-aware replay. Same hash_id prefix → same token ID prefix → APC cache hit in vLLM. """ ids: list[int] = [] for hid in req.hash_ids: ids.extend(_hash_id_to_token_ids(hid)) # Pad to input_length with deterministic tokens pad_rng = _random.Random(req.chat_id) while len(ids) < req.input_length: ids.append(pad_rng.randint(TOKEN_RANGE_START, TOKEN_RANGE_END)) return ids[:req.input_length] @dataclass class _SessionState: session_id: str turns: list[TraceRequest] metrics: list[RequestMetrics] = field(default_factory=list) @dataclass class _DispatchResult: metric: RequestMetrics output_token_ids: list[int] _endpoint_counter = 0 def _pick_endpoint(config: ReplayConfig) -> str: """Round-robin across comma-separated endpoints.""" global _endpoint_counter endpoints = [e.strip() for e in config.endpoint_url.split(",")] url = endpoints[_endpoint_counter % len(endpoints)] _endpoint_counter += 1 return url async def _dispatch_request( *, client: httpx.AsyncClient, config: ReplayConfig, req: TraceRequest, prompt_token_ids: list[int], sem: asyncio.Semaphore, ) -> _DispatchResult: """Send one request via /v1/completions (streaming) and collect metrics.""" endpoint = _pick_endpoint(config) target_output_tokens = max(1, req.output_length) payload = { "model": config.model_name, "prompt": prompt_token_ids, "max_tokens": target_output_tokens, "min_tokens": target_output_tokens, "temperature": 0, "return_token_ids": True, "stream": True, "stream_options": {"include_usage": True}, } start = time.perf_counter() t_dispatch_unix = time.time() t_first_token_unix: float | None = None ttft_s = None n_output = 0 cached_tokens = 0 finish_reason = None err = None token_times: list[float] = [] output_token_ids: list[int] = [] req_headers = { "X-Session-Id": req.session_id, "X-Request-Id": req.request_id, } async with sem: try: async with client.stream( "POST", f"{endpoint}/v1/completions", json=payload, headers=req_headers, timeout=config.request_timeout_s, ) as resp: resp.raise_for_status() async for raw_line in resp.aiter_lines(): if not raw_line or not raw_line.startswith("data:"): continue data = raw_line[5:].strip() if data == "[DONE]": break try: chunk = json.loads(data) except json.JSONDecodeError: continue choices = chunk.get("choices", []) if choices: now = time.perf_counter() delta = choices[0].get("text", "") chunk_token_ids = choices[0].get("token_ids") if isinstance(chunk_token_ids, list): clean_ids = [ int(t) for t in chunk_token_ids if isinstance(t, int) ] if clean_ids: if ttft_s is None: ttft_s = now - start t_first_token_unix = time.time() output_token_ids.extend(clean_ids) token_times.extend([now] * len(clean_ids)) elif delta: if ttft_s is None: ttft_s = now - start t_first_token_unix = time.time() token_times.append(now) fr = choices[0].get("finish_reason") if fr: finish_reason = fr usage = chunk.get("usage") if usage: n_output = usage.get("completion_tokens", n_output) cached_tokens = _extract_cached_tokens(usage) except Exception as exc: err = repr(exc)[:300] end = time.perf_counter() t_finish_unix = time.time() e2e = end - start if output_token_ids: n_output = len(output_token_ids) elif n_output == 0 and token_times: n_output = len(token_times) if err is None and n_output != target_output_tokens: err = ( "output_token_mismatch " f"requested={target_output_tokens} actual={n_output}" ) tpot = 0.0 if len(token_times) > 1: inter_token = [token_times[i+1] - token_times[i] for i in range(len(token_times) - 1)] tpot = sum(inter_token) / len(inter_token) return _DispatchResult( metric=RequestMetrics( request_id=req.request_id, session_id=req.session_id, turn_id=req.turn_id, trace_timestamp_s=req.timestamp_s, input_length=req.input_length, output_length=req.output_length, request_type=req.request_type, effective_input_length=len(prompt_token_ids), cached_tokens=cached_tokens, latency_s=e2e, ttft_s=ttft_s, tpot_s=tpot, actual_output_tokens=n_output, requested_output_tokens=req.output_length, finish_reason=finish_reason, error=err, t_dispatch_unix=t_dispatch_unix, t_first_token_unix=t_first_token_unix, t_finish_unix=t_finish_unix, proxy_request_id=req.request_id, endpoint_url=endpoint, trace_hash_ids=req.hash_ids, ), output_token_ids=output_token_ids, ) def _apply_realized_prefix( prompt_token_ids: list[int], realized_context: list[int], ) -> list[int]: """Replace the reusable session prefix with engine-realized tokens.""" if not realized_context: return prompt_token_ids out = prompt_token_ids.copy() n = min(len(out), len(realized_context)) out[:n] = realized_context[:n] return out def _extract_cached_tokens(usage: dict) -> int: ct = 0 details = usage.get("prompt_tokens_details") if isinstance(details, dict): ct = details.get("cached_tokens", 0) or 0 if ct == 0: ct = usage.get("cached_tokens", 0) or 0 return int(ct) async def _run_session( *, state: _SessionState, config: ReplayConfig, client: httpx.AsyncClient, request_sem: asyncio.Semaphore, earliest_ts: float, sweep_start: float, sink: IncrementalMetricSink, session_sem: asyncio.Semaphore | None = None, ) -> list[RequestMetrics]: if session_sem is not None: await session_sem.acquire() realized_context: list[int] = [] try: for turn_idx, req in enumerate(state.turns): if config.dispatch_mode == "thinktime": # Mode 2: turn-1 at absolute trace arrival (preserve session # schedule); later turns wait the REAL per-record think-time after # the previous turn completed -> no think-collapse under load. if turn_idx == 0: target_wall = (req.timestamp_s - earliest_ts) elapsed = time.perf_counter() - sweep_start if elapsed < target_wall: await asyncio.sleep(target_wall - elapsed) else: think = req.time_to_parent_chat_s await asyncio.sleep(think if think is not None else 0.0) elif config.inter_turn_think_s is not None: # Closed-loop: turn 1 fires on admission; later turns wait a fixed # think-time AFTER the previous turn completed (no absolute schedule, # so no "fire immediately because timestamp is in the past"). if turn_idx > 0: await asyncio.sleep(config.inter_turn_think_s) else: # Mode 1: dispatch at the request's absolute trace timestamp. target_wall = (req.timestamp_s - earliest_ts) elapsed = time.perf_counter() - sweep_start if elapsed < target_wall: await asyncio.sleep(target_wall - elapsed) token_ids = _apply_realized_prefix( _build_prompt_token_ids(req), realized_context, ) result = await _dispatch_request( client=client, config=config, req=req, prompt_token_ids=token_ids, sem=request_sem, ) metric = result.metric state.metrics.append(metric) await sink.append(metric) if metric.error is None: realized_context = token_ids + result.output_token_ids finally: if session_sem is not None: session_sem.release() return state.metrics async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]: """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).""" total = {"queries": 0.0, "hits": 0.0, "external_queries": 0.0, "external_hits": 0.0} endpoints = [e.strip() for e in url_csv.split(",")] async with httpx.AsyncClient(timeout=10) as c: for url in endpoints: try: r = await c.get(f"{url}/metrics") for line in r.text.split("\n"): if line.startswith("vllm:prefix_cache_queries_total"): total["queries"] += float(line.split()[-1]) elif line.startswith("vllm:prefix_cache_hits_total"): total["hits"] += float(line.split()[-1]) elif line.startswith("vllm:external_prefix_cache_queries_total"): total["external_queries"] += float(line.split()[-1]) elif line.startswith("vllm:external_prefix_cache_hits_total"): total["external_hits"] += float(line.split()[-1]) except Exception: pass return total async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]: """Main entry: load trace, replay against endpoint, return metrics.""" requests = load_trace(config.trace_path, request_limit=config.request_limit) if not requests: return [] by_session: dict[str, list[TraceRequest]] = defaultdict(list) for r in requests: by_session[r.session_id].append(r) for sid in by_session: by_session[sid].sort(key=lambda r: (r.turn_id, r.timestamp_s)) sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s) earliest_ts = sessions[0][1][0].timestamp_s latest_ts = max(r.timestamp_s for r in requests) trace_span = latest_ts - earliest_ts request_sem = asyncio.Semaphore(config.concurrency_limit) session_sem = ( asyncio.Semaphore(config.max_inflight_sessions) if config.max_inflight_sessions and config.max_inflight_sessions > 0 else None ) sink = IncrementalMetricSink(config.output_path) n_sessions = len(sessions) n_requests = len(requests) qps = n_requests / trace_span if trace_span > 0 else 0 logger.info("Replaying %d sessions (%d requests) over %.0fs (%.2f req/s)", n_sessions, n_requests, trace_span, qps) if session_sem is not None: logger.info("Session admission cap: %d concurrent sessions", config.max_inflight_sessions) pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url) sweep_start = time.perf_counter() try: limits = httpx.Limits( max_connections=2000, max_keepalive_connections=500, keepalive_expiry=30.0, ) async with httpx.AsyncClient( timeout=config.request_timeout_s, trust_env=False, limits=limits, ) as client: tasks = [ asyncio.create_task(_run_session( state=_SessionState(session_id=sid, turns=turns), config=config, client=client, request_sem=request_sem, earliest_ts=earliest_ts, sweep_start=sweep_start, sink=sink, session_sem=session_sem, )) for sid, turns in sessions ] all_results = await asyncio.gather(*tasks) finally: sink.close() sweep_elapsed = time.perf_counter() - sweep_start post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url) flat = [m for group in all_results for m in group] summary_path = config.output_path.with_suffix(".summary.json") write_summary_json(summary_path, flat) # Compute aggregate prefix cache hit ratio from /metrics deltas delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0) delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0) hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0 delta_ext_queries = post_metrics.get("external_queries", 0) - pre_metrics.get("external_queries", 0) delta_ext_hits = post_metrics.get("external_hits", 0) - pre_metrics.get("external_hits", 0) ext_hit_ratio = delta_ext_hits / delta_ext_queries if delta_ext_queries > 0 else 0.0 logger.info("Done: %d/%d succeeded in %.1fs", sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed) logger.info("Prefix cache: local=%.1f%% external=%.1f%%", hit_ratio * 100, ext_hit_ratio * 100) # Append cache stats to summary import json as _json summary = _json.loads(summary_path.read_text()) summary["prefix_cache_queries_tokens"] = int(delta_queries) summary["prefix_cache_hits_tokens"] = int(delta_hits) summary["prefix_cache_hit_ratio"] = hit_ratio summary["external_cache_queries_tokens"] = int(delta_ext_queries) summary["external_cache_hits_tokens"] = int(delta_ext_hits) summary["external_cache_hit_ratio"] = ext_hit_ratio summary["wall_clock_s"] = sweep_elapsed summary["trace_span_s"] = trace_span summary["amplification"] = sweep_elapsed / trace_span if trace_span > 0 else None summary_path.write_text(_json.dumps(summary, indent=2, sort_keys=True)) logger.info("Summary written to %s", summary_path) return flat