A4: open-loop session-causal SRR loadgen

New replayer/srr.py drives a Poisson session-arrival load against the
existing proxy, with strict per-session turn sequentiality, explicit
warmup/steady/drain windows, and per-arrival fresh session_id +
request_id so APC/session-affinity counters are not contaminated by
repeated draws from the trace pool. Writes window_summary.json with
attempted/completed/errored split by window so latency tails can be
read on the steady-state window only.

Required by Batch 4 SRR sweep; trace-timestamp dispatch in replay.py
cannot drive arrival rate independently.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 16:19:20 +08:00
parent 5816aad731
commit f42c715ec1
2 changed files with 420 additions and 0 deletions

320
replayer/srr.py Normal file
View File

@@ -0,0 +1,320 @@
"""Open-loop session-causal SRR (Sustainable Request Rate) loadgen.
Differs from `replayer.replay` in three ways:
- Sessions arrive at Poisson rate `lambda` (sessions/s) independent of the
trace timestamps. The trace is treated as a *pool* of session templates;
each arrival picks one and replays its turns sequentially.
- A warmup window and a steady-state window are explicit; per-window
attempted / completed / errored counters are written to
`<output stem>.window_summary.json` so latency tails are computable on
steady state only.
- Each arrival gets a fresh `session_id` and per-turn `request_id`,
so APC and session affinity are not contaminated by repeated draws.
"""
from __future__ import annotations
import asyncio
import json
import logging
import random
import time
from collections import defaultdict
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Any
import httpx
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
from .replay import (
ReplayConfig,
_apply_realized_prefix,
_build_prompt_token_ids,
_dispatch_request,
)
from .trace import TraceRequest, load_trace
logger = logging.getLogger(__name__)
@dataclass
class SrrConfig:
    """Settings for one open-loop SRR run.

    Raises:
        ValueError: on construction, if ``arrival_rate`` is not a positive
            finite number or ``concurrency_limit`` is smaller than 1.
    """

    trace_path: Path  # trace file the session templates are loaded from
    output_path: Path  # metrics sink path; the window summary is written beside it
    endpoint_url: str
    arrival_rate: float  # sessions per second
    warmup_s: float = 60.0
    steady_s: float = 300.0
    drain_s: float = 60.0
    concurrency_limit: int = 2000
    request_timeout_s: float = 600.0
    model_name: str = "default"
    session_pool_size: int | None = None  # cap on distinct session templates
    rng_seed: int = 42
    request_limit: int | None = None  # cap on raw trace lines loaded

    def __post_init__(self) -> None:
        # A zero rate makes random.expovariate divide by zero; a negative or
        # infinite rate yields non-positive inter-arrival gaps, which would
        # turn the arrival loop into a busy spin. The chained comparison is
        # also False for NaN.
        if not 0 < self.arrival_rate < float("inf"):
            raise ValueError(
                "arrival_rate must be a positive finite number, "
                f"got {self.arrival_rate!r}"
            )
        # A semaphore built from a limit of 0 can never be acquired.
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}"
            )
def _build_session_pool(
    requests: list[TraceRequest],
    pool_cap: int | None,
) -> list[list[TraceRequest]]:
    """Group trace requests into per-session turn lists.

    Sessions are returned in the order their first request appears in
    ``requests``; the turns inside each session are sorted by
    ``(turn_id, timestamp_s)``.

    Args:
        requests: Flat list of trace requests, in any order.
        pool_cap: Maximum number of sessions to keep, or ``None`` for all.
            The first ``pool_cap`` sessions (by first appearance) are kept.

    Returns:
        One list of turns per session.

    Raises:
        ValueError: if ``pool_cap`` is negative. A negative cap would
            otherwise be read as a from-the-end slice and silently drop
            sessions from the tail.
    """
    if pool_cap is not None and pool_cap < 0:
        raise ValueError(f"pool_cap must be non-negative, got {pool_cap}")

    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for request in requests:
        by_session[request.session_id].append(request)

    # dict preserves insertion order, so sessions keep first-seen order.
    sessions = [
        sorted(turns, key=lambda r: (r.turn_id, r.timestamp_s))
        for turns in by_session.values()
    ]
    if pool_cap is not None:
        sessions = sessions[:pool_cap]
    return sessions
def _window_for(t_unix: float, warmup_end: float, steady_end: float) -> str:
    """Name the measurement window that the instant ``t_unix`` falls into.

    The boundaries are checked in order, each as an exclusive upper bound:
    before ``warmup_end`` is "warmup", otherwise before ``steady_end`` is
    "steady", and anything else is "drain".
    """
    upper_bounds = ((warmup_end, "warmup"), (steady_end, "steady"))
    for bound, label in upper_bounds:
        if t_unix < bound:
            return label
    return "drain"
def _clone_session_for_arrival(
    template: list[TraceRequest],
    arrival_idx: int,
) -> list[TraceRequest]:
    """Copy a session template, giving every turn ids unique to this arrival.

    The replayer sends request.request_id as X-Request-Id and
    request.session_id as X-Session-Id, so replaying a template under its
    original ids would make separate arrivals alias each other on both the
    server and the proxy. Every copy therefore carries the session id
    ``srr<arrival_idx>_<template session id>`` and a request id of the form
    ``<new session id>:<turn_id>:<chat_id>``. The template is not modified.
    """
    stamped_session = f"srr{arrival_idx}_{template[0].session_id}"
    return [
        replace(
            turn,
            session_id=stamped_session,
            request_id=f"{stamped_session}:{turn.turn_id}:{turn.chat_id}",
        )
        for turn in template
    ]
async def _run_one_session(
    *,
    turns: list[TraceRequest],
    config: SrrConfig,
    replay_cfg: ReplayConfig,
    client: httpx.AsyncClient,
    request_sem: asyncio.Semaphore,
    sink: IncrementalMetricSink,
    counters: dict[str, dict[str, int]],
    window_for_now,
    deadline_unix: float,
) -> None:
    """Replay one session's turns strictly one after another.

    Each turn is awaited to completion before the next one is dispatched.
    A turn is attributed to the window in effect at its dispatch time, and
    that same window receives its "completed" or "errored" count. A failed
    turn does not end the session: later turns are still sent, with the
    realized context left as it was after the last successful turn.

    Args:
        turns: The session's requests, already in replay order.
        config: SRR settings; accepted for interface symmetry, not read here.
        replay_cfg: Configuration handed to ``_dispatch_request``.
        client: Shared HTTP client.
        request_sem: Semaphore handed to ``_dispatch_request``.
        sink: Receives one metric record per dispatched turn.
        counters: ``{"attempted" | "completed" | "errored": {window: n}}``,
            updated in place.
        window_for_now: Callable mapping a unix timestamp to a window name.
        deadline_unix: Once the wall clock is past this, no further turn of
            the session is dispatched.
    """
    context_so_far: list[int] = []
    for turn in turns:
        dispatched_at = time.time()
        if dispatched_at > deadline_unix:
            # Out of time: abandon whatever turns are left.
            return

        bucket = window_for_now(dispatched_at)
        counters["attempted"][bucket] += 1

        prompt_ids = _apply_realized_prefix(
            _build_prompt_token_ids(turn),
            context_so_far,
        )
        outcome = await _dispatch_request(
            client=client,
            config=replay_cfg,
            req=turn,
            prompt_token_ids=prompt_ids,
            sem=request_sem,
        )
        record = outcome.metric
        await sink.append(record)

        if record.error is not None:
            counters["errored"][bucket] += 1
            continue
        counters["completed"][bucket] += 1
        # Only a successful turn extends the context seen by later turns.
        context_so_far = prompt_ids + outcome.output_token_ids
async def run_srr(config: SrrConfig) -> dict[str, Any]:
    """Drive one open-loop SRR run and return its per-window summary.

    Session arrivals are drawn from a Poisson process with rate
    ``config.arrival_rate`` from run start until the end of the steady
    window. Each arrival picks a random session template from the trace,
    stamps it with fresh ids and replays it in its own task. Once the steady
    window closes no new session starts; sessions still in flight have until
    the end of the drain window to finish and are cancelled after that.

    Args:
        config: Run settings.

    Returns:
        The summary dict, which is also written as JSON to
        ``<output stem>.window_summary.json`` next to ``config.output_path``.

    Raises:
        ValueError: if ``config.arrival_rate`` is not a positive finite
            number, or if the trace yields no sessions.
    """
    # expovariate(0) raises ZeroDivisionError and a negative/infinite rate
    # yields non-positive gaps that would spin the arrival loop.
    if not 0 < config.arrival_rate < float("inf"):
        raise ValueError(
            "arrival_rate must be a positive finite number, "
            f"got {config.arrival_rate!r}"
        )

    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    pool = _build_session_pool(requests, config.session_pool_size)
    if not pool:
        raise ValueError(f"empty session pool from trace {config.trace_path}")

    rng = random.Random(config.rng_seed)
    sink = IncrementalMetricSink(config.output_path)
    request_sem = asyncio.Semaphore(config.concurrency_limit)
    replay_cfg = ReplayConfig(
        trace_path=config.trace_path,
        output_path=config.output_path,
        endpoint_url=config.endpoint_url,
        model_name=config.model_name,
        concurrency_limit=config.concurrency_limit,
        request_timeout_s=config.request_timeout_s,
    )

    run_start_unix = time.time()
    warmup_end_unix = run_start_unix + config.warmup_s
    steady_end_unix = warmup_end_unix + config.steady_s
    drain_end_unix = steady_end_unix + config.drain_s

    def _window(t: float) -> str:
        return _window_for(t, warmup_end_unix, steady_end_unix)

    counters: dict[str, dict[str, int]] = {
        "attempted": defaultdict(int),
        "completed": defaultdict(int),
        "errored": defaultdict(int),
    }
    inter_arrivals: list[float] = []

    # The connection pool must never be the tighter bound: with a fixed cap,
    # a concurrency_limit above it would be throttled by the client instead
    # of by the semaphore.
    limits = httpx.Limits(
        max_connections=max(2000, config.concurrency_limit),
        max_keepalive_connections=500,
        keepalive_expiry=30.0,
    )
    logger.info(
        "SRR start: pool=%d lambda=%.4f sess/s warmup=%.0fs steady=%.0fs drain=%.0fs",
        len(pool), config.arrival_rate,
        config.warmup_s, config.steady_s, config.drain_s,
    )

    arrival_idx = 0
    session_tasks: list[asyncio.Task] = []
    # Arrivals follow an absolute schedule (cumulative sum of the drawn
    # gaps). Sleeping a relative gap after each spawn would add task-creation
    # time and event-loop lag to every gap and push the offered rate below
    # the requested lambda.
    next_arrival_unix = run_start_unix
    try:
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            try:
                while True:
                    dt = rng.expovariate(config.arrival_rate)
                    inter_arrivals.append(dt)
                    next_arrival_unix += dt
                    if next_arrival_unix >= steady_end_unix:
                        # No further arrival fits; idle until the steady
                        # window closes so the drain phase starts on time.
                        await asyncio.sleep(
                            max(0.0, steady_end_unix - time.time())
                        )
                        break
                    # A zero delay still yields to in-flight sessions when
                    # the loop is behind schedule.
                    await asyncio.sleep(
                        max(0.0, next_arrival_unix - time.time())
                    )
                    if time.time() >= steady_end_unix:
                        break
                    template = rng.choice(pool)
                    cloned = _clone_session_for_arrival(template, arrival_idx)
                    arrival_idx += 1
                    session_tasks.append(asyncio.create_task(_run_one_session(
                        turns=cloned,
                        config=config,
                        replay_cfg=replay_cfg,
                        client=client,
                        request_sem=request_sem,
                        sink=sink,
                        counters=counters,
                        window_for_now=_window,
                        deadline_unix=drain_end_unix,
                    )))

                # Drain phase: stop accepting new arrivals, wait for in-flight.
                drain_timeout = max(0.0, drain_end_unix - time.time())
                logger.info("SRR drain: waiting up to %.1fs for %d in-flight sessions",
                            drain_timeout, sum(1 for t in session_tasks if not t.done()))
                if session_tasks:
                    await asyncio.wait(session_tasks, timeout=drain_timeout)
            finally:
                # Whatever is still running (drain timeout, or an error in
                # the loop above) is cancelled and awaited here, while the
                # HTTP client and the sink are still open, so cancellation
                # never unwinds against closed resources.
                leftovers = [t for t in session_tasks if not t.done()]
                for task in leftovers:
                    task.cancel()
                if leftovers:
                    await asyncio.gather(*leftovers, return_exceptions=True)
    finally:
        sink.close()

    # Every task is finished by now; retrieve each outcome so crashes are
    # reported instead of surfacing as "exception was never retrieved".
    sessions_cancelled = 0
    sessions_failed = 0
    for task in session_tasks:
        if task.cancelled():
            sessions_cancelled += 1
            continue
        crash = task.exception()
        if crash is not None:
            sessions_failed += 1
            logger.error("SRR session task failed", exc_info=crash)

    summary = {
        "run_start_unix": run_start_unix,
        "warmup_end_unix": warmup_end_unix,
        "steady_end_unix": steady_end_unix,
        "drain_end_unix": drain_end_unix,
        "arrival_rate": config.arrival_rate,
        "session_pool_size": len(pool),
        "sessions_arrived": arrival_idx,
        "sessions_cancelled": sessions_cancelled,
        "sessions_failed": sessions_failed,
        "inter_arrival_mean_s": (
            sum(inter_arrivals) / len(inter_arrivals)
            if inter_arrivals else None
        ),
        "attempted": dict(counters["attempted"]),
        "completed": dict(counters["completed"]),
        "errored": dict(counters["errored"]),
        "rng_seed": config.rng_seed,
    }
    summary_path = config.output_path.with_name(
        config.output_path.stem + ".window_summary.json"
    )
    summary_path.write_text(
        json.dumps(summary, indent=2, sort_keys=True),
        encoding="utf-8",
    )
    logger.info(
        "SRR done: arrived=%d attempted=%s completed=%s errored=%s",
        arrival_idx,
        dict(counters["attempted"]),
        dict(counters["completed"]),
        dict(counters["errored"]),
    )
    return summary
def _parse_args(argv: list[str] | None = None):
    """Parse the SRR loadgen command line.

    Args:
        argv: Arguments to parse; ``None`` makes argparse read the process
            command line.

    Returns:
        The populated ``argparse.Namespace``.
    """
    import argparse

    # (flags, add_argument keywords), listed in the order shown by --help.
    option_table = (
        (("--trace",), {"type": Path, "required": True}),
        (("--output",), {"type": Path, "required": True}),
        (("--endpoint",), {"type": str, "required": True}),
        (("--model",), {"type": str, "default": "default"}),
        (("--arrival-rate",), {
            "type": float,
            "required": True,
            "help": "Sessions per second (Poisson)",
        }),
        (("--warmup-s",), {"type": float, "default": 60.0}),
        (("--steady-s",), {"type": float, "default": 300.0}),
        (("--drain-s",), {"type": float, "default": 60.0}),
        (("--concurrency-limit",), {"type": int, "default": 2000}),
        (("--request-timeout",), {"type": float, "default": 600.0}),
        (("--session-pool-size",), {"type": int, "default": None}),
        (("--request-limit",), {"type": int, "default": None}),
        (("--rng-seed",), {"type": int, "default": 42}),
        (("-v", "--verbose"), {"action": "store_true"}),
    )

    parser = argparse.ArgumentParser(
        description="Open-loop Poisson SRR loadgen with session-causal turns",
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args(argv)
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point for the SRR loadgen.

    Parses ``argv``, configures logging (DEBUG with ``--verbose``, INFO
    otherwise), runs ``run_srr`` to completion and prints the resulting
    summary as JSON on stdout.
    """
    args = _parse_args(argv)

    log_level = logging.INFO
    if args.verbose:
        log_level = logging.DEBUG
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    settings = {
        "trace_path": args.trace,
        "output_path": args.output,
        # A trailing slash on the endpoint is dropped before use.
        "endpoint_url": args.endpoint.rstrip("/"),
        "model_name": args.model,
        "arrival_rate": args.arrival_rate,
        "warmup_s": args.warmup_s,
        "steady_s": args.steady_s,
        "drain_s": args.drain_s,
        "concurrency_limit": args.concurrency_limit,
        "request_timeout_s": args.request_timeout,
        "session_pool_size": args.session_pool_size,
        "request_limit": args.request_limit,
        "rng_seed": args.rng_seed,
    }
    report = asyncio.run(run_srr(SrrConfig(**settings)))
    print(json.dumps(report, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()