A4: open-loop session-causal SRR loadgen
New replayer/srr.py drives a Poisson session-arrival load against the existing proxy, with strict per-session turn sequentiality, explicit warmup/steady/drain windows, and per-arrival fresh session_id + request_id so APC/session-affinity counters are not contaminated by repeated draws from the trace pool. Writes window_summary.json with attempted/completed/errored split by window so latency tails can be read on the steady-state window only. Required by Batch 4 SRR sweep; trace-timestamp dispatch in replay.py cannot drive arrival rate independently. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
320
replayer/srr.py
Normal file
320
replayer/srr.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Open-loop session-causal SRR (Sustainable Request Rate) loadgen.
|
||||
|
||||
Differs from `replayer.replay` in three ways:
|
||||
|
||||
- Sessions arrive at Poisson rate `lambda` (sessions/s) independent of the
|
||||
trace timestamps. The trace is treated as a *pool* of session templates;
|
||||
each arrival picks one and replays its turns sequentially.
|
||||
- A warmup window and a steady-state window are explicit; per-window
|
||||
attempted / completed / error counters are written to
|
||||
`window_summary.json` so latency tails are computable on steady state
|
||||
only.
|
||||
- Each arrival gets a fresh `session_id` and per-turn `request_id`,
|
||||
so APC and session affinity are not contaminated by repeated draws.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
||||
from .replay import (
|
||||
ReplayConfig,
|
||||
_apply_realized_prefix,
|
||||
_build_prompt_token_ids,
|
||||
_dispatch_request,
|
||||
)
|
||||
from .trace import TraceRequest, load_trace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SrrConfig:
    """Settings for a single open-loop SRR run.

    The first four fields are required; every other field falls back to the
    default shown here.
    """

    # Trace file to load and the path handed to the metric sink.
    trace_path: Path
    output_path: Path
    # Endpoint forwarded to the replay dispatcher.
    endpoint_url: str
    # Poisson session-arrival rate, in sessions per second.
    arrival_rate: float
    # Durations, in seconds, of the three consecutive run windows.
    warmup_s: float = 60.0
    steady_s: float = 300.0
    drain_s: float = 60.0
    # Size of the semaphore shared by all request dispatches.
    concurrency_limit: int = 2000
    request_timeout_s: float = 600.0
    model_name: str = "default"
    # Cap on distinct session templates; None keeps every session.
    session_pool_size: int | None = None
    rng_seed: int = 42
    # Cap on raw trace lines loaded; None loads the whole trace.
    request_limit: int | None = None
|
||||
|
||||
|
||||
def _build_session_pool(
|
||||
requests: list[TraceRequest],
|
||||
pool_cap: int | None,
|
||||
) -> list[list[TraceRequest]]:
|
||||
"""Group trace requests by session id and keep turn order."""
|
||||
by_session: dict[str, list[TraceRequest]] = defaultdict(list)
|
||||
for r in requests:
|
||||
by_session[r.session_id].append(r)
|
||||
sessions = []
|
||||
for sid, turns in by_session.items():
|
||||
turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
|
||||
sessions.append(turns)
|
||||
if pool_cap is not None and pool_cap < len(sessions):
|
||||
sessions = sessions[:pool_cap]
|
||||
return sessions
|
||||
|
||||
|
||||
def _window_for(t_unix: float, warmup_end: float, steady_end: float) -> str:
|
||||
if t_unix < warmup_end:
|
||||
return "warmup"
|
||||
if t_unix < steady_end:
|
||||
return "steady"
|
||||
return "drain"
|
||||
|
||||
|
||||
def _clone_session_for_arrival(
|
||||
template: list[TraceRequest],
|
||||
arrival_idx: int,
|
||||
) -> list[TraceRequest]:
|
||||
"""Return a copy of the session turns with fresh ids for this arrival.
|
||||
|
||||
The replayer uses request.request_id as X-Request-Id and request.session_id
|
||||
as X-Session-Id. Reusing a template's ids across arrivals would alias on
|
||||
both server side and proxy. Stamp `srr<arrival_idx>` to disambiguate.
|
||||
"""
|
||||
new_session_id = f"srr{arrival_idx}_{template[0].session_id}"
|
||||
cloned: list[TraceRequest] = []
|
||||
for t in template:
|
||||
cloned.append(replace(
|
||||
t,
|
||||
session_id=new_session_id,
|
||||
request_id=f"{new_session_id}:{t.turn_id}:{t.chat_id}",
|
||||
))
|
||||
return cloned
|
||||
|
||||
|
||||
async def _run_one_session(
    *,
    turns: list[TraceRequest],
    config: SrrConfig,
    replay_cfg: ReplayConfig,
    client: httpx.AsyncClient,
    request_sem: asyncio.Semaphore,
    sink: IncrementalMetricSink,
    counters: dict[str, dict[str, int]],
    window_for_now,
    deadline_unix: float,
) -> None:
    """Replay one session's turns one after another.

    Each turn is awaited to completion before the next is dispatched. A turn
    is attributed to the window in which it was dispatched, for both the
    "attempted" counter and its "completed"/"errored" outcome. Once the
    wall clock has passed ``deadline_unix`` the remaining turns are dropped.
    """
    # Tokens of the last successful turn (prompt + output); handed to
    # _apply_realized_prefix together with the next turn's prompt.
    carried_context: list[int] = []
    for turn in turns:
        dispatched_at = time.time()
        if dispatched_at > deadline_unix:
            return
        bucket = window_for_now(dispatched_at)
        counters["attempted"][bucket] += 1

        prompt_ids = _apply_realized_prefix(
            _build_prompt_token_ids(turn),
            carried_context,
        )
        outcome = await _dispatch_request(
            client=client,
            config=replay_cfg,
            req=turn,
            prompt_token_ids=prompt_ids,
            sem=request_sem,
        )
        record = outcome.metric
        await sink.append(record)

        if record.error is not None:
            # A failed turn leaves the carried context untouched; the
            # session still moves on to its next turn.
            counters["errored"][bucket] += 1
            continue
        counters["completed"][bucket] += 1
        carried_context = prompt_ids + outcome.output_token_ids
|
||||
|
||||
|
||||
async def run_srr(config: SrrConfig) -> dict[str, Any]:
    """Run one open-loop SRR experiment and return its window summary.

    Sessions are launched as a Poisson process at ``config.arrival_rate``
    from the start of the run until the end of the steady window. Each
    arrival picks a random session template from the trace, clones it with
    fresh ids, and replays its turns in a dedicated task. After the steady
    window no new sessions start; in-flight sessions get until the end of
    the drain window to finish and are cancelled afterwards.

    Args:
        config: Run settings (trace, endpoint, arrival rate, window lengths).

    Returns:
        The summary dict, which is also written as JSON to
        ``<output stem>.window_summary.json`` next to ``config.output_path``.

    Raises:
        ValueError: If ``config.arrival_rate`` is not a positive finite
            number, or if the trace yields no sessions.
    """
    # Reject rates that expovariate cannot turn into a usable schedule:
    # 0 divides by zero, negative/infinite rates yield non-positive gaps
    # (a tight spawn loop), and NaN fails both comparisons.
    if not (0.0 < config.arrival_rate < float("inf")):
        raise ValueError(
            "arrival_rate must be a positive finite number, "
            f"got {config.arrival_rate!r}"
        )

    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    pool = _build_session_pool(requests, config.session_pool_size)
    if not pool:
        raise ValueError(f"empty session pool from trace {config.trace_path}")

    rng = random.Random(config.rng_seed)
    request_sem = asyncio.Semaphore(config.concurrency_limit)

    replay_cfg = ReplayConfig(
        trace_path=config.trace_path,
        output_path=config.output_path,
        endpoint_url=config.endpoint_url,
        model_name=config.model_name,
        concurrency_limit=config.concurrency_limit,
        request_timeout_s=config.request_timeout_s,
    )

    # Keep the connection pool at least as large as the request semaphore so
    # a request holding a permit never queues on the pool instead.
    limits = httpx.Limits(
        max_connections=max(2000, config.concurrency_limit),
        max_keepalive_connections=500,
        keepalive_expiry=30.0,
    )

    # Opened last so that nothing above can fail while the sink is open but
    # not yet covered by the try/finally below.
    sink = IncrementalMetricSink(config.output_path)

    run_start_unix = time.time()
    warmup_end_unix = run_start_unix + config.warmup_s
    steady_end_unix = warmup_end_unix + config.steady_s
    drain_end_unix = steady_end_unix + config.drain_s

    def _window(t: float) -> str:
        return _window_for(t, warmup_end_unix, steady_end_unix)

    counters: dict[str, dict[str, int]] = {
        "attempted": defaultdict(int),
        "completed": defaultdict(int),
        "errored": defaultdict(int),
    }
    inter_arrivals: list[float] = []

    logger.info(
        "SRR start: pool=%d lambda=%.4f sess/s warmup=%.0fs steady=%.0fs drain=%.0fs",
        len(pool), config.arrival_rate,
        config.warmup_s, config.steady_s, config.drain_s,
    )

    arrival_idx = 0
    session_tasks: list[asyncio.Task] = []
    try:
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            # Arrivals follow an absolute timeline (run start plus the
            # cumulative sampled gaps). Sleeping each gap relative to "now"
            # would add loop overhead to every inter-arrival and let the
            # realized rate sag below lambda exactly when the system is busy.
            next_arrival_unix = run_start_unix
            while True:
                dt = rng.expovariate(config.arrival_rate)
                inter_arrivals.append(dt)
                next_arrival_unix += dt
                if next_arrival_unix >= steady_end_unix:
                    break
                # When behind schedule this is sleep(0): it still yields to
                # the event loop so running sessions keep making progress.
                await asyncio.sleep(max(0.0, next_arrival_unix - time.time()))
                if time.time() >= steady_end_unix:
                    break
                template = rng.choice(pool)
                cloned = _clone_session_for_arrival(template, arrival_idx)
                arrival_idx += 1
                task = asyncio.create_task(_run_one_session(
                    turns=cloned,
                    config=config,
                    replay_cfg=replay_cfg,
                    client=client,
                    request_sem=request_sem,
                    sink=sink,
                    counters=counters,
                    window_for_now=_window,
                    deadline_unix=drain_end_unix,
                ))
                session_tasks.append(task)

            # Drain phase: stop accepting new arrivals, wait for in-flight.
            drain_timeout = max(0.0, drain_end_unix - time.time())
            logger.info("SRR drain: waiting up to %.1fs for %d in-flight sessions",
                        drain_timeout, sum(1 for t in session_tasks if not t.done()))
            if session_tasks:
                done, pending = await asyncio.wait(
                    session_tasks, timeout=drain_timeout,
                )
                for t in pending:
                    t.cancel()
                if pending:
                    # Let the cancellations unwind while the HTTP client and
                    # the sink are still open.
                    await asyncio.gather(*pending, return_exceptions=True)
                for t in done:
                    if t.cancelled():
                        continue
                    exc = t.exception()
                    if exc is not None:
                        # Surface crashes that would otherwise only appear as
                        # "Task exception was never retrieved" at shutdown.
                        logger.error(
                            "SRR session task failed: %r", exc, exc_info=exc,
                        )
    finally:
        sink.close()

    summary = {
        "run_start_unix": run_start_unix,
        "warmup_end_unix": warmup_end_unix,
        "steady_end_unix": steady_end_unix,
        "drain_end_unix": drain_end_unix,
        "arrival_rate": config.arrival_rate,
        "session_pool_size": len(pool),
        "sessions_arrived": arrival_idx,
        "inter_arrival_mean_s": (
            sum(inter_arrivals) / len(inter_arrivals)
            if inter_arrivals else None
        ),
        "attempted": dict(counters["attempted"]),
        "completed": dict(counters["completed"]),
        "errored": dict(counters["errored"]),
        "rng_seed": config.rng_seed,
    }
    summary_path = config.output_path.with_name(
        config.output_path.stem + ".window_summary.json"
    )
    summary_path.write_text(
        json.dumps(summary, indent=2, sort_keys=True), encoding="utf-8",
    )
    logger.info(
        "SRR done: arrived=%d attempted=%s completed=%s errored=%s",
        arrival_idx,
        dict(counters["attempted"]),
        dict(counters["completed"]),
        dict(counters["errored"]),
    )
    return summary
|
||||
|
||||
|
||||
def _parse_args(argv: list[str] | None = None):
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description="Open-loop Poisson SRR loadgen with session-causal turns",
|
||||
)
|
||||
p.add_argument("--trace", type=Path, required=True)
|
||||
p.add_argument("--output", type=Path, required=True)
|
||||
p.add_argument("--endpoint", type=str, required=True)
|
||||
p.add_argument("--model", type=str, default="default")
|
||||
p.add_argument("--arrival-rate", type=float, required=True,
|
||||
help="Sessions per second (Poisson)")
|
||||
p.add_argument("--warmup-s", type=float, default=60.0)
|
||||
p.add_argument("--steady-s", type=float, default=300.0)
|
||||
p.add_argument("--drain-s", type=float, default=60.0)
|
||||
p.add_argument("--concurrency-limit", type=int, default=2000)
|
||||
p.add_argument("--request-timeout", type=float, default=600.0)
|
||||
p.add_argument("--session-pool-size", type=int, default=None)
|
||||
p.add_argument("--request-limit", type=int, default=None)
|
||||
p.add_argument("--rng-seed", type=int, default=42)
|
||||
p.add_argument("-v", "--verbose", action="store_true")
|
||||
return p.parse_args(argv)
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse arguments, run the loadgen, print the summary.

    The summary returned by ``run_srr`` is printed to stdout as sorted,
    indented JSON.
    """
    args = _parse_args(argv)

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # A trailing slash on the endpoint is stripped before it is stored.
    endpoint = args.endpoint.rstrip("/")
    config = SrrConfig(
        trace_path=args.trace,
        output_path=args.output,
        endpoint_url=endpoint,
        model_name=args.model,
        arrival_rate=args.arrival_rate,
        warmup_s=args.warmup_s,
        steady_s=args.steady_s,
        drain_s=args.drain_s,
        concurrency_limit=args.concurrency_limit,
        request_timeout_s=args.request_timeout,
        session_pool_size=args.session_pool_size,
        request_limit=args.request_limit,
        rng_seed=args.rng_seed,
    )

    summary = asyncio.run(run_srr(config))
    print(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()
|
||||
100
tests/test_srr_loadgen.py
Normal file
100
tests/test_srr_loadgen.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Tests for A4 SRR loadgen helpers (no network I/O)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
|
||||
from replayer.srr import (
|
||||
_build_session_pool,
|
||||
_clone_session_for_arrival,
|
||||
_window_for,
|
||||
)
|
||||
from replayer.trace import TraceRequest
|
||||
|
||||
|
||||
def _mk_req(session_id: str, turn: int, chat_id: int, ts: float) -> TraceRequest:
    """Build a small TraceRequest fixture for one turn of a session.

    Turn 0 gets ``parent_chat_id`` -1; later turns point at ``chat_id - 1``.
    """
    if turn == 0:
        parent = -1
    else:
        parent = chat_id - 1
    fields = dict(
        request_id=f"{session_id}:{turn}:{chat_id}:{turn}",
        session_id=session_id,
        chat_id=chat_id,
        parent_chat_id=parent,
        timestamp_s=ts,
        input_length=100,
        output_length=10,
        request_type="user",
        turn_id=turn,
        hash_ids=(chat_id,),
    )
    return TraceRequest(**fields)
|
||||
|
||||
|
||||
def test_build_session_pool_groups_and_orders_turns():
    """Turns are grouped per session and sorted by turn id."""
    # s1's turns are deliberately given out of order and interleaved with s2.
    s1_second = _mk_req("s1", 1, 11, 2.0)
    s2_only = _mk_req("s2", 0, 20, 1.5)
    s1_first = _mk_req("s1", 0, 10, 1.0)

    pool = _build_session_pool([s1_second, s2_only, s1_first], pool_cap=None)

    assert len(pool) == 2
    turn_ids_by_session = {
        turns[0].session_id: [t.turn_id for t in turns] for turns in pool
    }
    assert turn_ids_by_session["s1"] == [0, 1]
    assert turn_ids_by_session["s2"] == [0]
|
||||
|
||||
|
||||
def test_build_session_pool_honors_pool_cap():
    """A cap smaller than the number of sessions truncates the pool."""
    reqs = []
    for i in range(5):
        # Five single-turn sessions, s0 through s4.
        reqs.append(_mk_req(f"s{i}", 0, i, float(i)))

    capped = _build_session_pool(reqs, pool_cap=2)

    assert len(capped) == 2
|
||||
|
||||
|
||||
def test_window_for_classifies_correctly():
    """Both boundaries belong to the later window (half-open intervals)."""
    warmup_end, steady_end = 100.0, 400.0
    expectations = [
        (50.0, "warmup"),
        (100.0, "steady"),
        (399.999, "steady"),
        (400.0, "drain"),
        (500.0, "drain"),
    ]
    for t_unix, expected in expectations:
        assert _window_for(t_unix, warmup_end, steady_end) == expected
|
||||
|
||||
|
||||
def test_clone_session_uses_fresh_ids_so_arrivals_do_not_alias():
    """Two arrivals cloned from one template never share session or request ids."""
    template = [_mk_req("orig", 0, 100, 1.0), _mk_req("orig", 1, 101, 2.0)]
    first = _clone_session_for_arrival(template, arrival_idx=7)
    second = _clone_session_for_arrival(template, arrival_idx=8)

    for clone in (first, second):
        head, tail = clone[0], clone[1]
        assert head.session_id == tail.session_id  # within an arrival
        assert head.session_id.startswith("srr")
        assert head.request_id != template[0].request_id

    # Across arrivals the ids must differ.
    assert first[0].session_id != second[0].session_id
    assert first[0].request_id != second[0].request_id
|
||||
|
||||
|
||||
def test_clone_session_preserves_token_payload_fields():
    """Cloning rewrites ids only; lengths, hashes and turn id are carried over."""
    base = _mk_req("orig", 0, 100, 1.0)
    # Same identity fields as the fixture, but with a distinctive payload.
    rich = TraceRequest(
        request_id=base.request_id,
        session_id=base.session_id,
        chat_id=base.chat_id,
        parent_chat_id=base.parent_chat_id,
        timestamp_s=base.timestamp_s,
        input_length=4000,
        output_length=300,
        request_type="user",
        turn_id=0,
        hash_ids=(1, 2, 3, 4, 5),
    )

    (clone,) = _clone_session_for_arrival([rich], arrival_idx=1)

    assert clone.input_length == 4000
    assert clone.output_length == 300
    assert clone.hash_ids == (1, 2, 3, 4, 5)
    assert clone.turn_id == 0
|
||||
|
||||
|
||||
def test_poisson_inter_arrival_mean_matches_rate():
    """Seeded exponential draws average to roughly 1/rate."""
    rate = 5.0  # 5 sess/s -> mean inter-arrival ~ 0.2 s
    n_draws = 20000
    gen = random.Random(0)
    draws = []
    for _ in range(n_draws):
        draws.append(gen.expovariate(rate))
    observed_mean = sum(draws) / n_draws
    expected_mean = 1.0 / rate
    assert abs(observed_mean - expected_mean) < 0.01
|
||||
Reference in New Issue
Block a user