From f42c715ec18aa4c98adcc9b553d3cac2bac2c4ef Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Mon, 25 May 2026 16:19:20 +0800
Subject: [PATCH] A4: open-loop session-causal SRR loadgen

New replayer/srr.py drives a Poisson session-arrival load against the
existing proxy, with strict per-session turn sequentiality, explicit
warmup/steady/drain windows, and per-arrival fresh session_id +
request_id so APC/session-affinity counters are not contaminated by
repeated draws from the trace pool. Writes window_summary.json with
attempted/completed/errored split by window so latency tails can be
read on the steady-state window only.

Required by Batch 4 SRR sweep; trace-timestamp dispatch in replay.py
cannot drive arrival rate independently.

Co-Authored-By: Claude Opus 4.7
---
 replayer/srr.py           | 371 ++++++++++++++++++++++++++++++++++++++
 tests/test_srr_loadgen.py | 124 +++++++++++++
 2 files changed, 495 insertions(+)
 create mode 100644 replayer/srr.py
 create mode 100644 tests/test_srr_loadgen.py

diff --git a/replayer/srr.py b/replayer/srr.py
new file mode 100644
index 0000000..7726cf0
--- /dev/null
+++ b/replayer/srr.py
@@ -0,0 +1,371 @@
+"""Open-loop session-causal SRR (Sustainable Request Rate) loadgen.
+
+Differs from `replayer.replay` in three ways:
+
+- Sessions arrive at Poisson rate `lambda` (sessions/s) independent of the
+  trace timestamps. The trace is treated as a *pool* of session templates;
+  each arrival picks one and replays its turns sequentially.
+- A warmup window and a steady-state window are explicit; per-window
+  attempted / completed / error counters are written to
+  `window_summary.json` so latency tails are computable on steady state
+  only.
+- Each arrival gets a fresh `session_id` and per-turn `request_id`,
+  so APC and session affinity are not contaminated by repeated draws.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import math
+import random
+import time
+from collections import defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import Any
+
+import httpx
+
+from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
+from .replay import (
+    ReplayConfig,
+    _apply_realized_prefix,
+    _build_prompt_token_ids,
+    _dispatch_request,
+)
+from .trace import TraceRequest, load_trace
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SrrConfig:
+    """Settings for one SRR run: paths, arrival rate, windows and limits."""
+
+    trace_path: Path
+    output_path: Path
+    endpoint_url: str
+    arrival_rate: float  # sessions per second
+    warmup_s: float = 60.0
+    steady_s: float = 300.0
+    drain_s: float = 60.0
+    concurrency_limit: int = 2000
+    request_timeout_s: float = 600.0
+    model_name: str = "default"
+    session_pool_size: int | None = None  # cap on distinct session templates
+    rng_seed: int = 42
+    request_limit: int | None = None  # cap on raw trace lines loaded
+
+
+def _build_session_pool(
+    requests: list[TraceRequest],
+    pool_cap: int | None,
+) -> list[list[TraceRequest]]:
+    """Group trace requests by session id and keep turn order.
+
+    Sessions come back in first-seen order; when `pool_cap` is given only the
+    first `pool_cap` sessions are kept. A negative cap raises ValueError
+    instead of silently dropping sessions from the end of the pool.
+    """
+    if pool_cap is not None and pool_cap < 0:
+        raise ValueError(f"pool_cap must be >= 0, got {pool_cap}")
+    by_session: dict[str, list[TraceRequest]] = defaultdict(list)
+    for r in requests:
+        by_session[r.session_id].append(r)
+    sessions = [
+        sorted(turns, key=lambda r: (r.turn_id, r.timestamp_s))
+        for turns in by_session.values()
+    ]
+    if pool_cap is not None:
+        sessions = sessions[:pool_cap]
+    return sessions
+
+
+def _window_for(t_unix: float, warmup_end: float, steady_end: float) -> str:
+    """Classify a timestamp as "warmup", "steady" or "drain".
+
+    Both boundaries are half-open: a timestamp equal to `warmup_end` is
+    "steady" and one equal to `steady_end` is "drain".
+    """
+    if t_unix < warmup_end:
+        return "warmup"
+    if t_unix < steady_end:
+        return "steady"
+    return "drain"
+
+
+def _clone_session_for_arrival(
+    template: list[TraceRequest],
+    arrival_idx: int,
+) -> list[TraceRequest]:
+    """Return a copy of the session turns with fresh ids for this arrival.
+
+    The replayer uses request.request_id as X-Request-Id and request.session_id
+    as X-Session-Id. Reusing a template's ids across arrivals would alias on
+    both server side and proxy. Stamp `srr` to disambiguate.
+    """
+    new_session_id = f"srr{arrival_idx}_{template[0].session_id}"
+    return [
+        replace(
+            t,
+            session_id=new_session_id,
+            request_id=f"{new_session_id}:{t.turn_id}:{t.chat_id}",
+        )
+        for t in template
+    ]
+
+
+async def _run_one_session(
+    *,
+    turns: list[TraceRequest],
+    config: SrrConfig,
+    replay_cfg: ReplayConfig,
+    client: httpx.AsyncClient,
+    request_sem: asyncio.Semaphore,
+    sink: IncrementalMetricSink,
+    counters: dict[str, dict[str, int]],
+    window_for_now: Callable[[float], str],
+    deadline_unix: float,
+) -> None:
+    """Replay one session's turns strictly in order.
+
+    Each turn is attributed to the window in which it was dispatched, for
+    both the attempted and the completed/errored counters. Turns that would
+    start after `deadline_unix` are not dispatched. A failed turn does not
+    stop the session; the realized context is only extended on success.
+    """
+    realized_context: list[int] = []
+    for req in turns:
+        t_dispatch_unix = time.time()
+        if t_dispatch_unix > deadline_unix:
+            return
+        window = window_for_now(t_dispatch_unix)
+        counters["attempted"][window] += 1
+
+        token_ids = _apply_realized_prefix(
+            _build_prompt_token_ids(req),
+            realized_context,
+        )
+        result = await _dispatch_request(
+            client=client,
+            config=replay_cfg,
+            req=req,
+            prompt_token_ids=token_ids,
+            sem=request_sem,
+        )
+        metric = result.metric
+        await sink.append(metric)
+        if metric.error is None:
+            counters["completed"][window] += 1
+            realized_context = token_ids + result.output_token_ids
+        else:
+            counters["errored"][window] += 1
+
+
+async def run_srr(config: SrrConfig) -> dict[str, Any]:
+    """Run one open-loop SRR experiment and return the window summary.
+
+    Sessions are spawned with exponential inter-arrival gaps until the end of
+    the steady window; in-flight sessions then get until the end of the drain
+    window to finish before they are cancelled. The summary dict is also
+    written next to `config.output_path`, named after its stem with a
+    `.window_summary.json` suffix.
+
+    Raises:
+        ValueError: if `config.arrival_rate` is not a finite positive number,
+            or if the trace yields no sessions.
+    """
+    # expovariate() divides by the rate: 0 raises ZeroDivisionError, and a
+    # negative, infinite or NaN rate does not yield usable sleep gaps.
+    if not (math.isfinite(config.arrival_rate) and config.arrival_rate > 0):
+        raise ValueError(
+            f"arrival_rate must be a finite number > 0, got {config.arrival_rate}"
+        )
+
+    requests = load_trace(config.trace_path, request_limit=config.request_limit)
+    pool = _build_session_pool(requests, config.session_pool_size)
+    if not pool:
+        raise ValueError(f"empty session pool from trace {config.trace_path}")
+
+    rng = random.Random(config.rng_seed)
+    sink = IncrementalMetricSink(config.output_path)
+    request_sem = asyncio.Semaphore(config.concurrency_limit)
+
+    replay_cfg = ReplayConfig(
+        trace_path=config.trace_path,
+        output_path=config.output_path,
+        endpoint_url=config.endpoint_url,
+        model_name=config.model_name,
+        concurrency_limit=config.concurrency_limit,
+        request_timeout_s=config.request_timeout_s,
+    )
+
+    run_start_unix = time.time()
+    warmup_end_unix = run_start_unix + config.warmup_s
+    steady_end_unix = warmup_end_unix + config.steady_s
+    drain_end_unix = steady_end_unix + config.drain_s
+
+    def _window(t: float) -> str:
+        return _window_for(t, warmup_end_unix, steady_end_unix)
+
+    counters: dict[str, dict[str, int]] = {
+        "attempted": defaultdict(int),
+        "completed": defaultdict(int),
+        "errored": defaultdict(int),
+    }
+    inter_arrivals: list[float] = []
+
+    # Never let the connection pool be smaller than the request concurrency;
+    # that would add client-side queueing on top of the semaphore.
+    limits = httpx.Limits(
+        max_connections=max(2000, config.concurrency_limit),
+        max_keepalive_connections=500,
+        keepalive_expiry=30.0,
+    )
+
+    logger.info(
+        "SRR start: pool=%d lambda=%.4f sess/s warmup=%.0fs steady=%.0fs drain=%.0fs",
+        len(pool), config.arrival_rate,
+        config.warmup_s, config.steady_s, config.drain_s,
+    )
+
+    arrival_idx = 0
+    session_tasks: list[asyncio.Task] = []
+    try:
+        async with httpx.AsyncClient(
+            timeout=config.request_timeout_s,
+            trust_env=False,
+            limits=limits,
+        ) as client:
+            while time.time() < steady_end_unix:
+                dt = rng.expovariate(config.arrival_rate)
+                inter_arrivals.append(dt)
+                await asyncio.sleep(dt)
+                if time.time() >= steady_end_unix:
+                    break
+                template = rng.choice(pool)
+                cloned = _clone_session_for_arrival(template, arrival_idx)
+                arrival_idx += 1
+                task = asyncio.create_task(_run_one_session(
+                    turns=cloned,
+                    config=config,
+                    replay_cfg=replay_cfg,
+                    client=client,
+                    request_sem=request_sem,
+                    sink=sink,
+                    counters=counters,
+                    window_for_now=_window,
+                    deadline_unix=drain_end_unix,
+                ))
+                session_tasks.append(task)
+
+            # Drain phase: stop accepting new arrivals, wait for in-flight.
+            drain_timeout = max(0.0, drain_end_unix - time.time())
+            logger.info("SRR drain: waiting up to %.1fs for %d in-flight sessions",
+                        drain_timeout, sum(1 for t in session_tasks if not t.done()))
+            if session_tasks:
+                _, pending = await asyncio.wait(
+                    session_tasks, timeout=drain_timeout,
+                )
+                for t in pending:
+                    t.cancel()
+                # Await every task while the client and sink are still open:
+                # this lets cancellation finish and retrieves exceptions that
+                # would otherwise vanish with the task object.
+                outcomes = await asyncio.gather(
+                    *session_tasks, return_exceptions=True,
+                )
+                for outcome in outcomes:
+                    # CancelledError is a BaseException, so drained-out
+                    # sessions are not reported as failures here.
+                    if isinstance(outcome, Exception):
+                        logger.error("SRR session task failed: %r", outcome)
+    finally:
+        sink.close()
+
+    summary = {
+        "run_start_unix": run_start_unix,
+        "warmup_end_unix": warmup_end_unix,
+        "steady_end_unix": steady_end_unix,
+        "drain_end_unix": drain_end_unix,
+        "arrival_rate": config.arrival_rate,
+        "session_pool_size": len(pool),
+        "sessions_arrived": arrival_idx,
+        "inter_arrival_mean_s": (
+            sum(inter_arrivals) / len(inter_arrivals)
+            if inter_arrivals else None
+        ),
+        "attempted": dict(counters["attempted"]),
+        "completed": dict(counters["completed"]),
+        "errored": dict(counters["errored"]),
+        "rng_seed": config.rng_seed,
+    }
+    summary_path = config.output_path.with_name(
+        config.output_path.stem + ".window_summary.json"
+    )
+    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))
+    logger.info(
+        "SRR done: arrived=%d attempted=%s completed=%s errored=%s",
+        arrival_idx,
+        dict(counters["attempted"]),
+        dict(counters["completed"]),
+        dict(counters["errored"]),
+    )
+    return summary
+
+
+def _parse_args(argv: list[str] | None = None):
+    """Parse the SRR loadgen command line (defaults to sys.argv[1:])."""
+    import argparse
+
+    p = argparse.ArgumentParser(
+        description="Open-loop Poisson SRR loadgen with session-causal turns",
+    )
+    p.add_argument("--trace", type=Path, required=True)
+    p.add_argument("--output", type=Path, required=True)
+    p.add_argument("--endpoint", type=str, required=True)
+    p.add_argument("--model", type=str, default="default")
+    p.add_argument("--arrival-rate", type=float, required=True,
+                   help="Sessions per second (Poisson)")
+    p.add_argument("--warmup-s", type=float, default=60.0)
+    p.add_argument("--steady-s", type=float, default=300.0)
+    p.add_argument("--drain-s", type=float, default=60.0)
+    p.add_argument("--concurrency-limit", type=int, default=2000)
+    p.add_argument("--request-timeout", type=float, default=600.0)
+    p.add_argument("--session-pool-size", type=int, default=None)
+    p.add_argument("--request-limit", type=int, default=None)
+    p.add_argument("--rng-seed", type=int, default=42)
+    p.add_argument("-v", "--verbose", action="store_true")
+    return p.parse_args(argv)
+
+
+def main(argv: list[str] | None = None) -> None:
+    """CLI entry point: run one SRR experiment and print its summary as JSON."""
+    args = _parse_args(argv)
+    logging.basicConfig(
+        level=logging.DEBUG if args.verbose else logging.INFO,
+        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+    )
+    config = SrrConfig(
+        trace_path=args.trace,
+        output_path=args.output,
+        endpoint_url=args.endpoint.rstrip("/"),
+        model_name=args.model,
+        arrival_rate=args.arrival_rate,
+        warmup_s=args.warmup_s,
+        steady_s=args.steady_s,
+        drain_s=args.drain_s,
+        concurrency_limit=args.concurrency_limit,
+        request_timeout_s=args.request_timeout,
+        session_pool_size=args.session_pool_size,
+        request_limit=args.request_limit,
+        rng_seed=args.rng_seed,
+    )
+    summary = asyncio.run(run_srr(config))
+    print(json.dumps(summary, indent=2, sort_keys=True))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_srr_loadgen.py b/tests/test_srr_loadgen.py
new file mode 100644
index 0000000..52b033f
--- /dev/null
+++ b/tests/test_srr_loadgen.py
@@ -0,0 +1,124 @@
+"""Tests for A4 SRR loadgen helpers (no network I/O)."""
+
+from __future__ import annotations
+
+import asyncio
+import random
+
+import pytest
+
+from replayer.srr import (
+    SrrConfig,
+    _build_session_pool,
+    _clone_session_for_arrival,
+    _window_for,
+    run_srr,
+)
+from replayer.trace import TraceRequest
+
+
+def _mk_req(session_id: str, turn: int, chat_id: int, ts: float) -> TraceRequest:
+    return TraceRequest(
+        request_id=f"{session_id}:{turn}:{chat_id}:{turn}",
+        session_id=session_id,
+        chat_id=chat_id,
+        parent_chat_id=-1 if turn == 0 else chat_id - 1,
+        timestamp_s=ts,
+        input_length=100,
+        output_length=10,
+        request_type="user",
+        turn_id=turn,
+        hash_ids=(chat_id,),
+    )
+
+
+def test_build_session_pool_groups_and_orders_turns():
+    reqs = [
+        _mk_req("s1", 1, 11, 2.0),
+        _mk_req("s2", 0, 20, 1.5),
+        _mk_req("s1", 0, 10, 1.0),
+    ]
+    pool = _build_session_pool(reqs, pool_cap=None)
+    assert len(pool) == 2
+    by_sid = {turns[0].session_id: turns for turns in pool}
+    assert [t.turn_id for t in by_sid["s1"]] == [0, 1]
+    assert [t.turn_id for t in by_sid["s2"]] == [0]
+
+
+def test_build_session_pool_honors_pool_cap():
+    reqs = [_mk_req(f"s{i}", 0, i, float(i)) for i in range(5)]
+    pool = _build_session_pool(reqs, pool_cap=2)
+    assert len(pool) == 2
+
+
+def test_build_session_pool_rejects_negative_pool_cap():
+    reqs = [_mk_req(f"s{i}", 0, i, float(i)) for i in range(3)]
+    with pytest.raises(ValueError):
+        _build_session_pool(reqs, pool_cap=-1)
+
+
+def test_window_for_classifies_correctly():
+    warmup_end = 100.0
+    steady_end = 400.0
+    assert _window_for(50.0, warmup_end, steady_end) == "warmup"
+    assert _window_for(100.0, warmup_end, steady_end) == "steady"
+    assert _window_for(399.999, warmup_end, steady_end) == "steady"
+    assert _window_for(400.0, warmup_end, steady_end) == "drain"
+    assert _window_for(500.0, warmup_end, steady_end) == "drain"
+
+
+def test_clone_session_uses_fresh_ids_so_arrivals_do_not_alias():
+    template = [_mk_req("orig", 0, 100, 1.0), _mk_req("orig", 1, 101, 2.0)]
+    clone_a = _clone_session_for_arrival(template, arrival_idx=7)
+    clone_b = _clone_session_for_arrival(template, arrival_idx=8)
+
+    for c in (clone_a, clone_b):
+        assert c[0].session_id == c[1].session_id  # within an arrival
+        assert c[0].session_id.startswith("srr")
+        assert c[0].request_id != template[0].request_id
+
+    assert clone_a[0].session_id != clone_b[0].session_id
+    assert clone_a[0].request_id != clone_b[0].request_id
+
+
+def test_clone_session_preserves_token_payload_fields():
+    template = [_mk_req("orig", 0, 100, 1.0)]
+    template = [TraceRequest(
+        request_id=template[0].request_id,
+        session_id=template[0].session_id,
+        chat_id=template[0].chat_id,
+        parent_chat_id=template[0].parent_chat_id,
+        timestamp_s=template[0].timestamp_s,
+        input_length=4000,
+        output_length=300,
+        request_type="user",
+        turn_id=0,
+        hash_ids=(1, 2, 3, 4, 5),
+    )]
+    cloned = _clone_session_for_arrival(template, arrival_idx=1)
+    assert cloned[0].input_length == 4000
+    assert cloned[0].output_length == 300
+    assert cloned[0].hash_ids == (1, 2, 3, 4, 5)
+    assert cloned[0].turn_id == 0
+
+
+def test_poisson_inter_arrival_mean_matches_rate():
+    """Sanity check on the exponential RNG used for arrivals."""
+    rng = random.Random(0)
+    rate = 5.0  # 5 sess/s -> mean inter-arrival ~ 0.2 s
+    samples = [rng.expovariate(rate) for _ in range(20000)]
+    mean = sum(samples) / len(samples)
+    assert abs(mean - 1.0 / rate) < 0.01
+
+
+@pytest.mark.parametrize("rate", [0.0, -1.0, float("inf"), float("nan")])
+def test_run_srr_rejects_unusable_arrival_rate(tmp_path, rate):
+    config = SrrConfig(
+        trace_path=tmp_path / "unused.jsonl",
+        output_path=tmp_path / "out.jsonl",
+        endpoint_url="http://127.0.0.1:9",
+        arrival_rate=rate,
+    )
+    # Validation happens before the trace is read, so no file is needed.
+    with pytest.raises(ValueError):
+        asyncio.run(run_srr(config))