A4: open-loop session-causal SRR loadgen

New replayer/srr.py drives a Poisson session-arrival load against the
existing proxy, with strict per-session turn sequentiality, explicit
warmup/steady/drain windows, and per-arrival fresh session_id +
request_id so APC/session-affinity counters are not contaminated by
repeated draws from the trace pool. Writes <output-stem>.window_summary.json
with attempted/completed/errored counts split by window (warmup/steady/drain)
so latency tails can be read on the steady-state window only.

Required by Batch 4 SRR sweep; trace-timestamp dispatch in replay.py
cannot drive arrival rate independently.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 16:19:20 +08:00
parent 5816aad731
commit f42c715ec1
2 changed files with 420 additions and 0 deletions

320
replayer/srr.py Normal file
View File

@@ -0,0 +1,320 @@
"""Open-loop session-causal SRR (Sustainable Request Rate) loadgen.
Differs from `replayer.replay` in three ways:
- Sessions arrive at Poisson rate `lambda` (sessions/s) independent of the
trace timestamps. The trace is treated as a *pool* of session templates;
each arrival picks one and replays its turns sequentially.
- Warmup, steady-state and drain windows are explicit; per-window
attempted / completed / errored counters are written to
`<output stem>.window_summary.json` so latency tails are computable on
steady state only.
- Each arrival gets a fresh `session_id` and per-turn `request_id`,
so APC and session affinity are not contaminated by repeated draws.
"""
from __future__ import annotations
import asyncio
import json
import logging
import random
import time
from collections import defaultdict
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Any
import httpx
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
from .replay import (
ReplayConfig,
_apply_realized_prefix,
_build_prompt_token_ids,
_dispatch_request,
)
from .trace import TraceRequest, load_trace
logger = logging.getLogger(__name__)
@dataclass
class SrrConfig:
    """Settings for one open-loop SRR run.

    Values are validated at construction time so that a bad rate or window
    fails immediately instead of part-way through a run.

    Raises:
        ValueError: if ``arrival_rate`` is not a positive finite number, if
            any of ``warmup_s`` / ``steady_s`` / ``drain_s`` is negative (or
            NaN), or if ``concurrency_limit`` is smaller than 1.
    """

    trace_path: Path  # trace file the session templates are loaded from
    output_path: Path  # metric sink target; the summary path is derived from its stem
    endpoint_url: str
    arrival_rate: float  # sessions per second
    warmup_s: float = 60.0
    steady_s: float = 300.0
    drain_s: float = 60.0
    concurrency_limit: int = 2000
    request_timeout_s: float = 600.0
    model_name: str = "default"
    session_pool_size: int | None = None  # cap on distinct session templates
    rng_seed: int = 42
    request_limit: int | None = None  # cap on raw trace lines loaded

    def __post_init__(self) -> None:
        # A zero rate makes random.expovariate divide by zero; a negative,
        # NaN or infinite rate yields non-positive/NaN sleeps, which would
        # turn the arrival loop into an unthrottled spawn loop.
        if not 0.0 < self.arrival_rate < float("inf"):
            raise ValueError(
                "arrival_rate must be a positive finite number, "
                f"got {self.arrival_rate!r}"
            )
        for field_name in ("warmup_s", "steady_s", "drain_s"):
            seconds = getattr(self, field_name)
            # `not >=` (rather than `<`) also rejects NaN.
            if not seconds >= 0.0:
                raise ValueError(
                    f"{field_name} must be non-negative, got {seconds!r}"
                )
        # The limit sizes an asyncio.Semaphore: a negative value is rejected
        # by asyncio itself and zero could never admit a request.
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}"
            )
def _build_session_pool(
    requests: list[TraceRequest],
    pool_cap: int | None,
) -> list[list[TraceRequest]]:
    """Group trace requests into per-session turn lists.

    Sessions appear in order of first occurrence in ``requests``; turns
    inside a session are ordered by ``(turn_id, timestamp_s)``.

    Args:
        requests: flat list of trace requests, in any order.
        pool_cap: maximum number of sessions to keep (the first ``pool_cap``
            in first-occurrence order), or ``None`` for no cap.

    Returns:
        One list of turns per retained session.

    Raises:
        ValueError: if ``pool_cap`` is negative. A negative cap used to be
            applied as a slice bound and silently dropped sessions from the
            tail of the pool.
    """
    if pool_cap is not None and pool_cap < 0:
        raise ValueError(f"pool_cap must be non-negative, got {pool_cap}")

    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
    for req in requests:
        by_session[req.session_id].append(req)

    # dict preserves insertion order, so the cap below is deterministic.
    sessions = [
        sorted(turns, key=lambda r: (r.turn_id, r.timestamp_s))
        for turns in by_session.values()
    ]
    return sessions if pool_cap is None else sessions[:pool_cap]
def _window_for(t_unix: float, warmup_end: float, steady_end: float) -> str:
    """Name the measurement window that the timestamp ``t_unix`` falls in.

    Bounds are checked in order and are exclusive upper bounds: anything
    before ``warmup_end`` is ``"warmup"``, anything before ``steady_end`` is
    ``"steady"``, and everything else is ``"drain"``.
    """
    for upper_bound, label in ((warmup_end, "warmup"), (steady_end, "steady")):
        if t_unix < upper_bound:
            return label
    return "drain"
def _clone_session_for_arrival(
    template: list[TraceRequest],
    arrival_idx: int,
) -> list[TraceRequest]:
    """Copy a session template, re-stamping ids so this arrival is unique.

    Every turn receives the same new session id,
    ``srr<arrival_idx>_<template session id>``, and a request id built from
    that session id plus the turn's ``turn_id`` and ``chat_id``. All other
    fields are carried over unchanged via ``dataclasses.replace``, so two
    arrivals drawn from the same template never share session or request ids.
    """
    fresh_sid = f"srr{arrival_idx}_{template[0].session_id}"
    return [
        replace(
            turn,
            session_id=fresh_sid,
            request_id=f"{fresh_sid}:{turn.turn_id}:{turn.chat_id}",
        )
        for turn in template
    ]
async def _run_one_session(
    *,
    turns: list[TraceRequest],
    config: SrrConfig,
    replay_cfg: ReplayConfig,
    client: httpx.AsyncClient,
    request_sem: asyncio.Semaphore,
    sink: IncrementalMetricSink,
    counters: dict[str, dict[str, int]],
    window_for_now,
    deadline_unix: float,
) -> None:
    """Replay one session's turns strictly one after another.

    Each turn is attributed to the window that is current at its dispatch
    time: ``attempted`` is bumped before the request is sent, and either
    ``completed`` or ``errored`` is bumped for that same window once the
    request returns. Every resulting metric is appended to ``sink``. The
    session stops silently as soon as a turn would be dispatched after
    ``deadline_unix``.

    ``config`` is accepted for signature symmetry but is not read here.
    """
    # Prompt + output tokens of the most recent successful turn; handed to
    # _apply_realized_prefix when building the next turn's prompt.
    context_so_far: list[int] = []
    for turn in turns:
        dispatched_at = time.time()
        if dispatched_at > deadline_unix:
            return

        bucket = window_for_now(dispatched_at)
        counters["attempted"][bucket] += 1

        prompt_ids = _apply_realized_prefix(
            _build_prompt_token_ids(turn),
            context_so_far,
        )
        outcome = await _dispatch_request(
            client=client,
            config=replay_cfg,
            req=turn,
            prompt_token_ids=prompt_ids,
            sem=request_sem,
        )
        record = outcome.metric
        await sink.append(record)

        if record.error is not None:
            # A failed turn leaves the carried context as it was; later
            # turns of the session are still attempted.
            counters["errored"][bucket] += 1
            continue
        counters["completed"][bucket] += 1
        context_so_far = prompt_ids + outcome.output_token_ids
async def run_srr(config: SrrConfig) -> dict[str, Any]:
    """Drive one open-loop SRR run and return its per-window summary.

    Sessions are started at exponentially distributed inter-arrival times
    (rate ``config.arrival_rate`` per second) from run start until the end of
    the steady window. Each arrival draws a random template from the session
    pool, clones it with fresh ids and replays its turns in a dedicated task.
    After the last arrival, in-flight sessions get until the end of the drain
    window to finish; whatever is still running is then cancelled and awaited
    before the HTTP client and the metric sink are closed.

    The summary is also written next to ``config.output_path`` as
    ``<output stem>.window_summary.json``.

    Args:
        config: run settings.

    Returns:
        A dict holding the window boundaries (unix seconds), the configured
        arrival rate, pool size, number of arrived sessions, the mean of the
        sampled inter-arrival gaps, the per-window ``attempted`` /
        ``completed`` / ``errored`` counters and the RNG seed.

    Raises:
        ValueError: if ``config.arrival_rate`` is not a positive finite
            number, or if the trace yields an empty session pool.
    """
    # expovariate(0) divides by zero, and a negative/NaN/infinite rate gives
    # non-positive or NaN gaps that would make the arrival loop spin.
    if not 0.0 < config.arrival_rate < float("inf"):
        raise ValueError(
            "arrival_rate must be a positive finite number, "
            f"got {config.arrival_rate!r}"
        )

    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    pool = _build_session_pool(requests, config.session_pool_size)
    if not pool:
        raise ValueError(f"empty session pool from trace {config.trace_path}")

    rng = random.Random(config.rng_seed)
    request_sem = asyncio.Semaphore(config.concurrency_limit)
    replay_cfg = ReplayConfig(
        trace_path=config.trace_path,
        output_path=config.output_path,
        endpoint_url=config.endpoint_url,
        model_name=config.model_name,
        concurrency_limit=config.concurrency_limit,
        request_timeout_s=config.request_timeout_s,
    )

    run_start_unix = time.time()
    warmup_end_unix = run_start_unix + config.warmup_s
    steady_end_unix = warmup_end_unix + config.steady_s
    drain_end_unix = steady_end_unix + config.drain_s

    def _window(t: float) -> str:
        return _window_for(t, warmup_end_unix, steady_end_unix)

    counters: dict[str, dict[str, int]] = {
        "attempted": defaultdict(int),
        "completed": defaultdict(int),
        "errored": defaultdict(int),
    }
    inter_arrivals: list[float] = []

    # Keep the connection pool at least as large as the request concurrency
    # limit; otherwise requests above 2000 would queue inside httpx instead
    # of being governed by the semaphore (presumably acquired in
    # _dispatch_request — verify there).
    limits = httpx.Limits(
        max_connections=max(2000, config.concurrency_limit),
        max_keepalive_connections=500,
        keepalive_expiry=30.0,
    )

    logger.info(
        "SRR start: pool=%d lambda=%.4f sess/s warmup=%.0fs steady=%.0fs drain=%.0fs",
        len(pool), config.arrival_rate,
        config.warmup_s, config.steady_s, config.drain_s,
    )

    arrival_idx = 0
    session_tasks: list[asyncio.Task] = []
    # Arrivals are scheduled on an absolute timeline. Sleeping for each gap
    # relative to "now" would add loop overhead and sleep overshoot to every
    # gap and push the realized rate below the configured one.
    next_arrival_unix = run_start_unix

    # Opened last so that a failure in the setup above cannot leak it.
    sink = IncrementalMetricSink(config.output_path)
    try:
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            try:
                while True:
                    dt = rng.expovariate(config.arrival_rate)
                    inter_arrivals.append(dt)
                    next_arrival_unix += dt
                    if next_arrival_unix >= steady_end_unix:
                        break
                    # Always yields to the loop, even when behind schedule.
                    await asyncio.sleep(
                        max(0.0, next_arrival_unix - time.time())
                    )
                    if time.time() >= steady_end_unix:
                        break
                    template = rng.choice(pool)
                    cloned = _clone_session_for_arrival(template, arrival_idx)
                    arrival_idx += 1
                    session_tasks.append(asyncio.create_task(_run_one_session(
                        turns=cloned,
                        config=config,
                        replay_cfg=replay_cfg,
                        client=client,
                        request_sem=request_sem,
                        sink=sink,
                        counters=counters,
                        window_for_now=_window,
                        deadline_unix=drain_end_unix,
                    )))

                # Drain phase: stop accepting new arrivals, wait for in-flight.
                drain_timeout = max(0.0, drain_end_unix - time.time())
                logger.info("SRR drain: waiting up to %.1fs for %d in-flight sessions",
                            drain_timeout, sum(1 for t in session_tasks if not t.done()))
                if session_tasks:
                    await asyncio.wait(session_tasks, timeout=drain_timeout)
            finally:
                # Runs on normal completion and on abnormal exit alike:
                # cancel stragglers and wait for them to unwind while the
                # client and the sink are still open.
                for task in session_tasks:
                    if not task.done():
                        task.cancel()
                if session_tasks:
                    outcomes = await asyncio.gather(
                        *session_tasks, return_exceptions=True,
                    )
                    for outcome in outcomes:
                        # CancelledError is a BaseException and is skipped
                        # here; anything else is an unexpected failure.
                        if isinstance(outcome, Exception):
                            logger.error("SRR session task failed: %r", outcome)
    finally:
        sink.close()

    summary = {
        "run_start_unix": run_start_unix,
        "warmup_end_unix": warmup_end_unix,
        "steady_end_unix": steady_end_unix,
        "drain_end_unix": drain_end_unix,
        "arrival_rate": config.arrival_rate,
        "session_pool_size": len(pool),
        "sessions_arrived": arrival_idx,
        "inter_arrival_mean_s": (
            sum(inter_arrivals) / len(inter_arrivals)
            if inter_arrivals else None
        ),
        "attempted": dict(counters["attempted"]),
        "completed": dict(counters["completed"]),
        "errored": dict(counters["errored"]),
        "rng_seed": config.rng_seed,
    }
    summary_path = config.output_path.with_name(
        config.output_path.stem + ".window_summary.json"
    )
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))
    logger.info(
        "SRR done: arrived=%d attempted=%s completed=%s errored=%s",
        arrival_idx,
        dict(counters["attempted"]),
        dict(counters["completed"]),
        dict(counters["errored"]),
    )
    return summary
def _parse_args(argv: list[str] | None = None):
    """Parse the SRR loadgen command line.

    ``argv`` is handed to ``ArgumentParser.parse_args`` unchanged, so ``None``
    means "read ``sys.argv``". Returns the resulting ``argparse.Namespace``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Open-loop Poisson SRR loadgen with session-causal turns",
    )
    # (flags, add_argument keyword arguments), registered in this order so
    # the generated --help output keeps its layout.
    option_table = (
        (("--trace",), {"type": Path, "required": True}),
        (("--output",), {"type": Path, "required": True}),
        (("--endpoint",), {"type": str, "required": True}),
        (("--model",), {"type": str, "default": "default"}),
        (("--arrival-rate",), {
            "type": float,
            "required": True,
            "help": "Sessions per second (Poisson)",
        }),
        (("--warmup-s",), {"type": float, "default": 60.0}),
        (("--steady-s",), {"type": float, "default": 300.0}),
        (("--drain-s",), {"type": float, "default": 60.0}),
        (("--concurrency-limit",), {"type": int, "default": 2000}),
        (("--request-timeout",), {"type": float, "default": 600.0}),
        (("--session-pool-size",), {"type": int, "default": None}),
        (("--request-limit",), {"type": int, "default": None}),
        (("--rng-seed",), {"type": int, "default": 42}),
        (("-v", "--verbose"), {"action": "store_true"}),
    )
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)
    return parser.parse_args(argv)
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse arguments, run the loadgen, print the summary.

    Logging is configured at DEBUG when ``--verbose`` is given and at INFO
    otherwise. Trailing slashes are stripped from the endpoint before it is
    placed in the run configuration. The summary returned by ``run_srr`` is
    printed to stdout as indented, key-sorted JSON.
    """
    cli = _parse_args(argv)

    log_level = logging.DEBUG if cli.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    run_config = SrrConfig(
        trace_path=cli.trace,
        output_path=cli.output,
        endpoint_url=cli.endpoint.rstrip("/"),
        arrival_rate=cli.arrival_rate,
        warmup_s=cli.warmup_s,
        steady_s=cli.steady_s,
        drain_s=cli.drain_s,
        concurrency_limit=cli.concurrency_limit,
        request_timeout_s=cli.request_timeout,
        model_name=cli.model,
        session_pool_size=cli.session_pool_size,
        rng_seed=cli.rng_seed,
        request_limit=cli.request_limit,
    )

    result = asyncio.run(run_srr(run_config))
    print(json.dumps(result, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()

100
tests/test_srr_loadgen.py Normal file
View File

@@ -0,0 +1,100 @@
"""Tests for A4 SRR loadgen helpers (no network I/O)."""
from __future__ import annotations
import random
from replayer.srr import (
_build_session_pool,
_clone_session_for_arrival,
_window_for,
)
from replayer.trace import TraceRequest
def _mk_req(session_id: str, turn: int, chat_id: int, ts: float) -> TraceRequest:
    """Build a small ``TraceRequest`` fixture for the given session and turn.

    Turn 0 gets ``parent_chat_id == -1``; any later turn points at
    ``chat_id - 1``. Lengths, request type and ``hash_ids`` are fixed filler
    values.
    """
    parent = -1 if turn == 0 else chat_id - 1
    fields = {
        "request_id": f"{session_id}:{turn}:{chat_id}:{turn}",
        "session_id": session_id,
        "chat_id": chat_id,
        "parent_chat_id": parent,
        "timestamp_s": ts,
        "input_length": 100,
        "output_length": 10,
        "request_type": "user",
        "turn_id": turn,
        "hash_ids": (chat_id,),
    }
    return TraceRequest(**fields)
def test_build_session_pool_groups_and_orders_turns():
    """Interleaved, out-of-order turns are grouped per session in turn order."""
    shuffled = [
        _mk_req("s1", 1, 11, 2.0),
        _mk_req("s2", 0, 20, 1.5),
        _mk_req("s1", 0, 10, 1.0),
    ]

    pool = _build_session_pool(shuffled, pool_cap=None)

    assert len(pool) == 2
    sessions = {}
    for turns in pool:
        sessions[turns[0].session_id] = turns
    s1_turn_ids = [t.turn_id for t in sessions["s1"]]
    assert s1_turn_ids == [0, 1]
    s2_turn_ids = [t.turn_id for t in sessions["s2"]]
    assert s2_turn_ids == [0]
def test_build_session_pool_honors_pool_cap():
    """A cap below the number of distinct sessions truncates the pool."""
    single_turn_sessions = []
    for i in range(5):
        single_turn_sessions.append(_mk_req(f"s{i}", 0, i, float(i)))

    capped = _build_session_pool(single_turn_sessions, pool_cap=2)

    assert len(capped) == 2
def test_window_for_classifies_correctly():
    """Window bounds are exclusive upper bounds: the boundary starts the next window."""
    warmup_end, steady_end = 100.0, 400.0
    expectations = (
        (50.0, "warmup"),
        (100.0, "steady"),
        (399.999, "steady"),
        (400.0, "drain"),
        (500.0, "drain"),
    )
    for t_unix, expected in expectations:
        assert _window_for(t_unix, warmup_end, steady_end) == expected
def test_clone_session_uses_fresh_ids_so_arrivals_do_not_alias():
    """Two arrivals of one template share ids within, but never across, arrivals."""
    template = [_mk_req("orig", 0, 100, 1.0), _mk_req("orig", 1, 101, 2.0)]
    first = _clone_session_for_arrival(template, arrival_idx=7)
    second = _clone_session_for_arrival(template, arrival_idx=8)

    for clone in (first, second):
        head = clone[0]
        assert head.session_id == clone[1].session_id  # within an arrival
        assert head.session_id.startswith("srr")
        assert head.request_id != template[0].request_id

    assert first[0].session_id != second[0].session_id
    assert first[0].request_id != second[0].request_id
def test_clone_session_preserves_token_payload_fields():
    """Cloning rewrites ids only; length, hash and turn fields are carried over."""
    seed = _mk_req("orig", 0, 100, 1.0)
    # Same identity fields as the seed fixture, with a distinctive payload.
    heavy = TraceRequest(
        request_id=seed.request_id,
        session_id=seed.session_id,
        chat_id=seed.chat_id,
        parent_chat_id=seed.parent_chat_id,
        timestamp_s=seed.timestamp_s,
        input_length=4000,
        output_length=300,
        request_type="user",
        turn_id=0,
        hash_ids=(1, 2, 3, 4, 5),
    )

    cloned = _clone_session_for_arrival([heavy], arrival_idx=1)

    only_turn = cloned[0]
    assert only_turn.input_length == 4000
    assert only_turn.output_length == 300
    assert only_turn.hash_ids == (1, 2, 3, 4, 5)
    assert only_turn.turn_id == 0
def test_poisson_inter_arrival_mean_matches_rate():
    """Sanity check: seeded exponential draws average close to 1 / rate."""
    rate = 5.0  # 5 sess/s -> mean inter-arrival ~ 0.2 s
    expected_mean = 1.0 / rate
    rng = random.Random(0)
    draws = []
    for _ in range(20000):
        draws.append(rng.expovariate(rate))
    observed_mean = sum(draws) / len(draws)
    assert abs(observed_mean - expected_mean) < 0.01