Files
agentic-kvc/replayer/replay.py
Gahow Wang fafc44da79 MB5 PD reuse-centric ablation: tooling, data, Fig 1-3
Three-axis controlled ablation of PD-colo vs PD-disagg on synthetic regular
traces (closed-loop, controlled reuse via REPLAY_NO_REALIZED_PREFIX) on the
clean stack (e13391e gated off).

  Axis 1 (Fig 1) -- reuse 6%->94% at N=8, in8192/out256
  Axis 2 (Fig 2) -- shape in2048/out2048 -> in32768/out64 at N=8, reuse~70%
  Axis 3 (Fig 3) -- concurrency N=8/16/32/64 at reuse~71%, in8192/out256

Findings:
  * APC parity colo=PD at every reuse (5.5/22/44/66/77/82%) -- contamination
    fix validated.
  * PD edge erodes 1.57x->1.10x with reuse; prefill GPUs strand 26%->9%.
  * Shape: PD-best peaks mid-sweep (1.34x at in8192/out512); wrong PD ratio
    catastrophic at prefill extreme (in32768/out64 pd2 = 378/400, p99 432s).
  * Concurrency: PD wins N<=32 (1.23-1.29x), TIPS at N=64 -- pd2/pd4
    crater (APC 71%->1.4%, TPS -30%) while colo scales cleanly.

Infrastructure:
  * replayer: --max-inflight-sessions, --inter-turn-think, --no-realized-prefix
    (env-defaulted via REPLAY_MAX_INFLIGHT, REPLAY_INTER_TURN_THINK_S,
    REPLAY_NO_REALIZED_PREFIX).
  * mb5_run.sh: writes bench_config.json + gpu_util.csv + run_window.json +
    instance_apc.txt + metrics.jsonl for bench_report/fig_agg ingest.
  * fig_agg.py: per-arm GPU role split + producer-side APC; --json mode.
  * gpu_util_report.py: companion per-GPU util report from gpu_util.csv.
  * partial_summary.py: stats from in-flight replay_metrics.jsonl
    (works before metrics.summary.json exists).

Data: analysis/mb5_pd_ablation/fig{1,2,3}.json (24 + 20 + 16 rows).
Figures: figs/mb5_pd_ablation/fig{1_reuse,2_shape,3_concurrency}_axis.png.
2026-05-31 20:14:46 +08:00

507 lines
20 KiB
Python

"""Trace replayer — send requests to vLLM following trace timing.
Uses hash_ids from the trace to construct synthetic prompts that
reproduce realistic prefix-cache hit patterns.
Key behaviors:
- Trace-driven dispatch: each request is sent at its trace timestamp.
No artificial concurrency limits or time compression.
- Per-session sequencing: turns within a session are sent in order,
each waiting for the previous to complete before dispatching.
If a turn completes after its successor's timestamp, the successor
fires immediately (no waiting for a past timestamp).
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import random as _random
import httpx
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
from .trace import TraceRequest, load_trace
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Number of synthetic token IDs generated per trace hash_id (one "block").
BLOCK_SIZE = 512
# NOTE(review): presumably the vocabulary size of the target model's
# tokenizer — confirm against the served model.
VOCAB_SIZE = 151936
# Inclusive bounds handed to randint when drawing synthetic token IDs; they
# keep generated IDs away from both ends of the vocabulary.
# NOTE(review): presumably to avoid special/reserved token IDs — confirm.
TOKEN_RANGE_START = 100
TOKEN_RANGE_END = VOCAB_SIZE - 100
# Memo of hash_id -> generated token block, filled by _hash_id_to_token_ids.
# It is never evicted, so it grows with the number of distinct hash_ids seen.
_block_cache: dict[int, list[int]] = {}
def _hash_id_to_token_ids(hash_id: int) -> list[int]:
    """Return the BLOCK_SIZE synthetic token IDs that stand for *hash_id*.

    The block is drawn from a PRNG seeded with the hash_id, so equal hash_ids
    always map to equal token blocks. Blocks are memoised in ``_block_cache``
    and the cached list object itself is returned, so callers must treat it
    as read-only.
    """
    try:
        return _block_cache[hash_id]
    except KeyError:
        pass
    draw = _random.Random(hash_id).randint
    block = [draw(TOKEN_RANGE_START, TOKEN_RANGE_END) for _ in range(BLOCK_SIZE)]
    _block_cache[hash_id] = block
    return block
@dataclass
class ReplayConfig:
    """Settings for one trace replay run.

    The first three fields are required; everything else has a default that
    keeps the plain trace-timestamp replay behavior.
    """

    # Trace file to replay and the file the per-request metrics are written to.
    trace_path: Path
    output_path: Path
    # One URL, or several separated by commas for round-robin dispatch,
    # e.g. "http://host:8000,http://host:8001".
    endpoint_url: str
    # Upper bound on requests in flight at once (size of the request semaphore).
    concurrency_limit: int = 2000
    # Per-request HTTP timeout, in seconds.
    request_timeout_s: float = 600.0
    # Optional cap on how many trace requests are loaded; None loads them all.
    request_limit: int | None = None
    # Value sent as the "model" field of every completion request.
    model_name: str = "default"
    # Cap on concurrently running sessions; None means unlimited.
    max_inflight_sessions: int | None = None

    # Closed-loop think-time mode. When set, absolute trace timestamps are
    # ignored for follow-up turns: turn 1 fires as soon as the session is
    # admitted, and every later turn fires a FIXED think-time after the
    # previous turn COMPLETES. Together with max_inflight_sessions=N this gives
    # a stable N-user closed loop (no open-loop runaway) and so removes the
    # "immediate retrigger under load" artifact.
    inter_turn_think_s: float | None = None

    # Controlled-reuse mode. When True, _apply_realized_prefix is skipped and
    # each turn's prompt is exactly the hash-built tokens, so prefix-cache
    # reuse is governed solely by the generated hash_ids (shared prefix blocks
    # hit, fresh delta blocks miss). Needed for the reuse-fraction sweep, where
    # the realized prefix would otherwise force every fixed-length turn to
    # ≈ the prior turn (≈100% reuse regardless).
    # Leave False (realized prefix ON) for the real agentic trace.
    no_realized_prefix: bool = False

    # Timing of turns within a session:
    #   "tracets" (Mode 1): fire at the absolute trace timestamp, which is
    #       effectively max(prev_finished, trace_ts); think-time collapses to 0
    #       whenever the system is behind (the amplification-inflation suspect).
    #   "thinktime" (Mode 2): turn 1 at trace arrival; turn k at
    #       prev_finished + time_to_parent_chat (real production gap).
    dispatch_mode: str = "tracets"

    # Wall-clock deadline for the whole replay, in seconds. Once exceeded the
    # replay stops awaiting in-flight sessions, cancels them and writes the
    # summary over whatever completed; turns that never ran are counted as
    # failures so completion% stays honest (request_count == full trace).
    # None means no deadline (the default, original behavior unchanged). Used
    # to bound the slow drain of a collapsed config in a sweep.
    # Env: REPLAY_MAX_DURATION.
    max_duration_s: float | None = None
def _skipped_metric() -> "RequestMetrics":
    """Build the placeholder failure row for a turn that never ran.

    Used to pad the results when a max_duration cutoff stops the replay early.
    Only the non-None ``error`` matters: the row counts toward request/error
    totals but is excluded from latency/ttft/tpot percentiles (successes only).
    """
    placeholder = dict(
        request_id="deadline_skipped",
        session_id="",
        turn_id=-1,
        trace_timestamp_s=0.0,
        input_length=0,
        output_length=0,
        request_type="skipped",
        effective_input_length=None,
        cached_tokens=0,
        latency_s=None,
        ttft_s=None,
        tpot_s=None,
        error="deadline_skipped",
    )
    return RequestMetrics(**placeholder)
def _build_prompt_token_ids(req: TraceRequest) -> list[int]:
    """Turn a trace request's hash_ids into a prompt of exactly input_length tokens.

    Same hash_id prefix → same token ID prefix → APC cache hit in vLLM.
    Each hash_id contributes one deterministic token block. If the blocks fall
    short of ``req.input_length`` the remainder is filled from a PRNG seeded
    with ``req.chat_id``; if they overshoot, the result is truncated.
    """
    tokens: list[int] = [
        tok for hid in req.hash_ids for tok in _hash_id_to_token_ids(hid)
    ]
    # Deterministic filler so a given chat always gets the same padding.
    filler = _random.Random(req.chat_id)
    while len(tokens) < req.input_length:
        tokens.append(filler.randint(TOKEN_RANGE_START, TOKEN_RANGE_END))
    return tokens[:req.input_length]
@dataclass
class _SessionState:
    """Per-session replay state: the ordered turns and the metrics gathered so far."""

    # Identifier shared by every turn of this session.
    session_id: str
    # Turns of the session, in the order they will be dispatched.
    turns: list[TraceRequest]
    # One entry per turn already dispatched; starts empty for each instance.
    metrics: list[RequestMetrics] = field(default_factory=list)
@dataclass
class _DispatchResult:
    """Outcome of one dispatched request."""

    # Metrics row recorded for the request.
    metric: RequestMetrics
    # Token IDs the engine actually generated (empty if none were reported).
    output_token_ids: list[int]
# Number of endpoints handed out so far; drives the round-robin rotation.
_endpoint_counter = 0


def _pick_endpoint(config: ReplayConfig) -> str:
    """Return the next endpoint URL, round-robin across ``config.endpoint_url``.

    ``endpoint_url`` is a comma-separated list; surrounding whitespace is
    stripped from each entry. Blank entries (e.g. from a trailing or doubled
    comma) are ignored so they never take a turn in the rotation. If every
    entry is blank the stripped entries are used unchanged, so the caller
    still gets a string and the resulting request fails in the usual,
    recorded way instead of raising here.
    """
    global _endpoint_counter
    candidates = [part.strip() for part in config.endpoint_url.split(",")]
    # str.split always yields at least one item, so the fallback is never empty.
    endpoints = [url for url in candidates if url] or candidates
    url = endpoints[_endpoint_counter % len(endpoints)]
    _endpoint_counter += 1
    return url
async def _dispatch_request(
    *,
    client: httpx.AsyncClient,
    config: ReplayConfig,
    req: TraceRequest,
    prompt_token_ids: list[int],
    sem: asyncio.Semaphore,
) -> _DispatchResult:
    """Stream one /v1/completions request and measure it.

    The prompt is sent as token IDs and the output length is pinned by setting
    both ``max_tokens`` and ``min_tokens`` to ``max(1, req.output_length)``.
    Timers start before the request semaphore is acquired, so latency and TTFT
    include any time spent waiting for a slot.

    Exceptions raised while streaming are not propagated: their ``repr``
    (truncated to 300 characters) is stored in the metric's ``error``. A
    response whose token count differs from the requested count is also
    recorded as an error.

    Returns:
        The metrics row together with the generated token IDs.
    """
    endpoint = _pick_endpoint(config)
    target_output_tokens = max(1, req.output_length)
    payload = {
        "model": config.model_name,
        "prompt": prompt_token_ids,
        "max_tokens": target_output_tokens,
        "min_tokens": target_output_tokens,
        "temperature": 0,
        "return_token_ids": True,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    req_headers = {
        "X-Session-Id": req.session_id,
        "X-Request-Id": req.request_id,
    }

    ttft_s = None
    t_first_token_unix: float | None = None
    n_output = 0
    cached_tokens = 0
    finish_reason = None
    err = None
    token_times: list[float] = []
    output_token_ids: list[int] = []

    start = time.perf_counter()
    t_dispatch_unix = time.time()
    async with sem:
        try:
            async with client.stream(
                "POST",
                f"{endpoint}/v1/completions",
                json=payload,
                headers=req_headers,
                timeout=config.request_timeout_s,
            ) as resp:
                resp.raise_for_status()
                async for raw_line in resp.aiter_lines():
                    # Only "data:" lines of the event stream carry payloads.
                    if not raw_line or not raw_line.startswith("data:"):
                        continue
                    data = raw_line[5:].strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue

                    choices = chunk.get("choices", [])
                    if choices:
                        now = time.perf_counter()
                        choice = choices[0]
                        delta = choice.get("text", "")
                        chunk_token_ids = choice.get("token_ids")

                        # Prefer the engine-reported token IDs; fall back to
                        # counting a non-empty text delta as a single token.
                        if isinstance(chunk_token_ids, list):
                            clean_ids = [
                                int(t) for t in chunk_token_ids
                                if isinstance(t, int)
                            ]
                            arrivals = len(clean_ids)
                            output_token_ids.extend(clean_ids)
                        elif delta:
                            arrivals = 1
                        else:
                            arrivals = 0

                        if arrivals:
                            if ttft_s is None:
                                ttft_s = now - start
                                t_first_token_unix = time.time()
                            # Every token of a chunk shares its arrival time.
                            token_times.extend([now] * arrivals)

                        fr = choice.get("finish_reason")
                        if fr:
                            finish_reason = fr

                    # The usage record may arrive in a chunk without choices.
                    usage = chunk.get("usage")
                    if usage:
                        n_output = usage.get("completion_tokens", n_output)
                        cached_tokens = _extract_cached_tokens(usage)
        except Exception as exc:
            err = repr(exc)[:300]

    end = time.perf_counter()
    t_finish_unix = time.time()
    e2e = end - start

    # Token IDs are the most trustworthy count, then the usage record, then
    # the number of observed arrivals.
    if output_token_ids:
        n_output = len(output_token_ids)
    elif n_output == 0 and token_times:
        n_output = len(token_times)

    if err is None and n_output != target_output_tokens:
        err = (
            "output_token_mismatch "
            f"requested={target_output_tokens} actual={n_output}"
        )

    # Mean gap between consecutive token arrivals; 0.0 with fewer than two.
    tpot = 0.0
    if len(token_times) > 1:
        gaps = [later - earlier
                for earlier, later in zip(token_times, token_times[1:])]
        tpot = sum(gaps) / len(gaps)

    metric = RequestMetrics(
        request_id=req.request_id,
        session_id=req.session_id,
        turn_id=req.turn_id,
        trace_timestamp_s=req.timestamp_s,
        input_length=req.input_length,
        output_length=req.output_length,
        request_type=req.request_type,
        effective_input_length=len(prompt_token_ids),
        cached_tokens=cached_tokens,
        latency_s=e2e,
        ttft_s=ttft_s,
        tpot_s=tpot,
        actual_output_tokens=n_output,
        requested_output_tokens=req.output_length,
        finish_reason=finish_reason,
        error=err,
        t_dispatch_unix=t_dispatch_unix,
        t_first_token_unix=t_first_token_unix,
        t_finish_unix=t_finish_unix,
        proxy_request_id=req.request_id,
        endpoint_url=endpoint,
        trace_hash_ids=req.hash_ids,
    )
    return _DispatchResult(metric=metric, output_token_ids=output_token_ids)
def _apply_realized_prefix(
    prompt_token_ids: list[int],
    realized_context: list[int],
) -> list[int]:
    """Overlay the engine-realized tokens onto the start of the prompt.

    The first ``min(len(prompt), len(realized))`` tokens of the prompt are
    replaced by the corresponding realized tokens; the prompt length never
    changes. With an empty ``realized_context`` the prompt list is returned
    as-is (same object); otherwise a new list is returned.
    """
    if not realized_context:
        return prompt_token_ids
    overlap = min(len(prompt_token_ids), len(realized_context))
    return realized_context[:overlap] + prompt_token_ids[overlap:]
def _extract_cached_tokens(usage: dict) -> int:
    """Read the cached-token count from a usage record.

    ``usage["prompt_tokens_details"]["cached_tokens"]`` is preferred; if it is
    absent or falsy, the top-level ``usage["cached_tokens"]`` is used. Returns
    0 when neither carries a value.
    """
    details = usage.get("prompt_tokens_details")
    nested = details.get("cached_tokens", 0) if isinstance(details, dict) else 0
    return int(nested or usage.get("cached_tokens", 0) or 0)
async def _run_session(
    *,
    state: _SessionState,
    config: ReplayConfig,
    client: httpx.AsyncClient,
    request_sem: asyncio.Semaphore,
    earliest_ts: float,
    sweep_start: float,
    sink: IncrementalMetricSink,
    session_sem: asyncio.Semaphore | None = None,
) -> list[RequestMetrics]:
    """Replay every turn of one session in order and return its metrics.

    When ``session_sem`` is given, a slot is held for the whole session. Each
    turn waits according to the configured timing mode, is dispatched, and its
    metric is appended to ``state.metrics`` and to ``sink``. After a successful
    turn, the prompt plus the generated tokens become the realized context that
    is overlaid on the next turn's prompt (unless ``no_realized_prefix`` is
    set). A failed turn leaves the realized context as it was.
    """

    async def wait_until_trace_time(turn: TraceRequest) -> None:
        # Sleep until the turn's trace offset has elapsed on the wall clock;
        # return at once if that moment has already passed.
        target_wall = turn.timestamp_s - earliest_ts
        elapsed = time.perf_counter() - sweep_start
        if elapsed < target_wall:
            await asyncio.sleep(target_wall - elapsed)

    if session_sem is not None:
        await session_sem.acquire()
    realized_context: list[int] = []
    try:
        for turn_idx, req in enumerate(state.turns):
            is_first_turn = turn_idx == 0
            if config.dispatch_mode == "thinktime":
                # Mode 2: the first turn keeps the session's absolute trace
                # arrival; later turns wait the REAL per-record think-time
                # after the previous turn finished, so think-time does not
                # collapse under load.
                if is_first_turn:
                    await wait_until_trace_time(req)
                else:
                    think = req.time_to_parent_chat_s
                    await asyncio.sleep(0.0 if think is None else think)
            elif config.inter_turn_think_s is not None:
                # Closed loop: the first turn fires on admission; later turns
                # wait a fixed think-time AFTER the previous turn finished.
                # There is no absolute schedule, hence no "fire immediately
                # because the timestamp is in the past".
                if not is_first_turn:
                    await asyncio.sleep(config.inter_turn_think_s)
            else:
                # Mode 1: dispatch at the request's absolute trace timestamp.
                await wait_until_trace_time(req)

            prompt = _build_prompt_token_ids(req)
            if not config.no_realized_prefix:
                prompt = _apply_realized_prefix(prompt, realized_context)

            outcome = await _dispatch_request(
                client=client,
                config=config,
                req=req,
                prompt_token_ids=prompt,
                sem=request_sem,
            )
            turn_metric = outcome.metric
            state.metrics.append(turn_metric)
            await sink.append(turn_metric)
            if turn_metric.error is None:
                realized_context = prompt + outcome.output_token_ids
    finally:
        if session_sem is not None:
            session_sem.release()
    return state.metrics
# Prometheus sample-name prefix -> key in the aggregated snapshot dict.
_PREFIX_CACHE_COUNTER_KEYS = {
    "vllm:prefix_cache_queries_total": "queries",
    "vllm:prefix_cache_hits_total": "hits",
    "vllm:external_prefix_cache_queries_total": "external_queries",
    "vllm:external_prefix_cache_hits_total": "external_hits",
}


def _parse_prefix_cache_counters(metrics_text: str) -> dict[str, float]:
    """Sum the prefix-cache counter samples found in one /metrics payload."""
    counts = dict.fromkeys(_PREFIX_CACHE_COUNTER_KEYS.values(), 0.0)
    for line in metrics_text.splitlines():
        for prefix, key in _PREFIX_CACHE_COUNTER_KEYS.items():
            if not line.startswith(prefix):
                continue
            try:
                # The sample value is the last whitespace-separated field.
                counts[key] += float(line.split()[-1])
            except (ValueError, IndexError):
                # One malformed sample must not discard the rest of the scrape.
                logger.debug("Skipping unparsable metrics line: %r", line)
            break
    return counts


async def _snapshot_prefix_cache_metrics(url_csv: str) -> dict[str, float]:
    """Scrape vLLM /metrics for prefix cache counters (aggregated across endpoints).

    Args:
        url_csv: Comma-separated endpoint base URLs; blank entries are ignored.

    Returns:
        A dict with the keys ``queries``, ``hits``, ``external_queries`` and
        ``external_hits``, each summed over all endpoints that could be
        scraped.

    The scrape is best-effort: an endpoint that cannot be reached or answers
    with an error status is skipped and contributes nothing. Such failures are
    logged, because a snapshot that is missing an endpoint makes the
    before/after deltas computed from it unreliable.
    """
    total = {"queries": 0.0, "hits": 0.0,
             "external_queries": 0.0, "external_hits": 0.0}
    endpoints = [e.strip() for e in url_csv.split(",") if e.strip()]
    if not endpoints:
        return total
    async with httpx.AsyncClient(timeout=10) as c:
        for url in endpoints:
            try:
                r = await c.get(f"{url}/metrics")
                r.raise_for_status()
            except Exception as exc:
                logger.warning(
                    "Could not scrape prefix-cache metrics from %s: %r",
                    url, exc)
                continue
            for key, value in _parse_prefix_cache_counters(r.text).items():
                total[key] += value
    return total
async def replay_trace(config: ReplayConfig) -> list[RequestMetrics]:
    """Main entry: load trace, replay against endpoint, return metrics.

    Requests are grouped into sessions (turns ordered by turn_id, then
    timestamp) and each session runs as its own task. Prefix-cache counters
    are scraped before and after the replay and their deltas are added to the
    summary JSON written next to ``config.output_path``.

    When ``config.max_duration_s`` is set and expires, sessions still running
    are cancelled and every turn that never ran is represented by a
    "deadline_skipped" failure row, so the returned list always has one entry
    per trace request.

    Returns:
        One metrics row per trace request, or an empty list (and no files)
        when the trace holds no requests.
    """
    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    if not requests:
        return []

    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for r in requests:
        by_session[r.session_id].append(r)
    for turns in by_session.values():
        turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
    # Sessions start in the order of their first turn's trace timestamp.
    sessions = sorted(by_session.items(), key=lambda kv: kv[1][0].timestamp_s)

    earliest_ts = sessions[0][1][0].timestamp_s
    latest_ts = max(r.timestamp_s for r in requests)
    trace_span = latest_ts - earliest_ts

    request_sem = asyncio.Semaphore(config.concurrency_limit)
    session_sem = (
        asyncio.Semaphore(config.max_inflight_sessions)
        if config.max_inflight_sessions and config.max_inflight_sessions > 0
        else None
    )
    sink = IncrementalMetricSink(config.output_path)

    n_sessions = len(sessions)
    n_requests = len(requests)
    qps = n_requests / trace_span if trace_span > 0 else 0
    logger.info("Replaying %d sessions (%d requests) over %.0fs (%.2f req/s)",
                n_sessions, n_requests, trace_span, qps)
    if session_sem is not None:
        logger.info("Session admission cap: %d concurrent sessions",
                    config.max_inflight_sessions)

    pre_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)
    sweep_start = time.perf_counter()
    try:
        limits = httpx.Limits(
            max_connections=2000,
            max_keepalive_connections=500,
            keepalive_expiry=30.0,
        )
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            states = [_SessionState(session_id=sid, turns=turns)
                      for sid, turns in sessions]
            tasks = [
                asyncio.create_task(_run_session(
                    state=st, config=config, client=client,
                    request_sem=request_sem,
                    earliest_ts=earliest_ts, sweep_start=sweep_start,
                    sink=sink,
                    session_sem=session_sem,
                ))
                for st in states
            ]
            if config.max_duration_s and config.max_duration_s > 0:
                done, pending = await asyncio.wait(
                    tasks, timeout=config.max_duration_s)
                # asyncio.wait never raises a task's exception, so retrieve
                # and report failures here; otherwise a crashed session would
                # vanish and its turns would be mislabelled as deadline skips.
                for finished in done:
                    if finished.cancelled():
                        continue
                    task_exc = finished.exception()
                    if task_exc is not None:
                        logger.error("Session task failed: %r", task_exc,
                                     exc_info=task_exc)
                if pending:
                    logger.warning(
                        "max_duration %.0fs reached: cancelling %d in-flight "
                        "session(s); un-run turns counted as failures",
                        config.max_duration_s, len(pending))
                    for t in pending:
                        t.cancel()
                    # Let the cancellations finish before the client closes.
                    await asyncio.gather(*pending, return_exceptions=True)
            else:
                await asyncio.gather(*tasks)
    finally:
        sink.close()
    sweep_elapsed = time.perf_counter() - sweep_start
    post_metrics = await _snapshot_prefix_cache_metrics(config.endpoint_url)

    # Build from the session states (identical to the gather return in the
    # uncapped path) so partially-completed (cancelled) sessions still
    # contribute their finished turns; pad un-run turns as failures so
    # request_count == trace.
    flat = [m for st in states for m in st.metrics]
    missing = n_requests - len(flat)
    if missing > 0:
        flat.extend(_skipped_metric() for _ in range(missing))

    summary_path = config.output_path.with_suffix(".summary.json")
    write_summary_json(summary_path, flat)

    # Compute aggregate prefix cache hit ratio from /metrics deltas
    delta_queries = post_metrics.get("queries", 0) - pre_metrics.get("queries", 0)
    delta_hits = post_metrics.get("hits", 0) - pre_metrics.get("hits", 0)
    hit_ratio = delta_hits / delta_queries if delta_queries > 0 else 0.0
    delta_ext_queries = (post_metrics.get("external_queries", 0)
                         - pre_metrics.get("external_queries", 0))
    delta_ext_hits = (post_metrics.get("external_hits", 0)
                      - pre_metrics.get("external_hits", 0))
    ext_hit_ratio = (delta_ext_hits / delta_ext_queries
                     if delta_ext_queries > 0 else 0.0)

    logger.info("Done: %d/%d succeeded in %.1fs",
                sum(1 for m in flat if m.error is None), len(flat), sweep_elapsed)
    logger.info("Prefix cache: local=%.1f%% external=%.1f%%",
                hit_ratio * 100, ext_hit_ratio * 100)

    # Append cache stats to the summary that write_summary_json just produced.
    summary = json.loads(summary_path.read_text())
    summary["prefix_cache_queries_tokens"] = int(delta_queries)
    summary["prefix_cache_hits_tokens"] = int(delta_hits)
    summary["prefix_cache_hit_ratio"] = hit_ratio
    summary["external_cache_queries_tokens"] = int(delta_ext_queries)
    summary["external_cache_hits_tokens"] = int(delta_ext_hits)
    summary["external_cache_hit_ratio"] = ext_hit_ratio
    summary["wall_clock_s"] = sweep_elapsed
    summary["trace_span_s"] = trace_span
    summary["amplification"] = sweep_elapsed / trace_span if trace_span > 0 else None
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))
    logger.info("Summary written to %s", summary_path)
    return flat