New replayer/srr.py drives a Poisson session-arrival load against the existing proxy, with strict per-session turn sequentiality, explicit warmup/steady/drain windows, and per-arrival fresh session_id + request_id so APC/session-affinity counters are not contaminated by repeated draws from the trace pool. Writes window_summary.json with attempted/completed/errored split by window so latency tails can be read on the steady-state window only. Required by Batch 4 SRR sweep; trace-timestamp dispatch in replay.py cannot drive arrival rate independently. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
321 lines
11 KiB
Python
321 lines
11 KiB
Python
"""Open-loop session-causal SRR (Sustainable Request Rate) loadgen.
|
|
|
|
Differs from `replayer.replay` in three ways:

- Sessions arrive at Poisson rate `lambda` (sessions/s) independent of the
  trace timestamps. The trace is treated as a *pool* of session templates;
  each arrival picks one and replays its turns sequentially.
- A warmup window and a steady-state window are explicit; per-window
  attempted / completed / errored counters are written to
  `<output stem>.window_summary.json` so latency tails are computable on
  steady state only.
- Each arrival gets a fresh `session_id` and per-turn `request_id`,
  so APC and session affinity are not contaminated by repeated draws.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import random
|
|
import time
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, replace
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from .metrics import IncrementalMetricSink, RequestMetrics, write_summary_json
|
|
from .replay import (
|
|
ReplayConfig,
|
|
_apply_realized_prefix,
|
|
_build_prompt_token_ids,
|
|
_dispatch_request,
|
|
)
|
|
from .trace import TraceRequest, load_trace
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SrrConfig:
    """Settings for one open-loop SRR run.

    Values are validated on construction so that a bad setting fails
    immediately instead of partway through a run.

    Raises:
        ValueError: if ``arrival_rate`` is not a positive finite number, if
            any of ``warmup_s`` / ``steady_s`` / ``drain_s`` is negative (or
            NaN), or if ``concurrency_limit`` is below 1.
    """

    trace_path: Path  # trace file handed to load_trace
    output_path: Path  # metrics sink path; its stem also names the window summary
    endpoint_url: str
    arrival_rate: float  # sessions per second
    warmup_s: float = 60.0
    steady_s: float = 300.0
    drain_s: float = 60.0
    concurrency_limit: int = 2000  # size of the request semaphore
    request_timeout_s: float = 600.0
    model_name: str = "default"
    session_pool_size: int | None = None  # cap on distinct session templates
    rng_seed: int = 42
    request_limit: int | None = None  # cap on raw trace lines loaded

    def __post_init__(self) -> None:
        # The chained comparison is False for NaN as well as for <= 0 and inf.
        # A zero rate makes random.expovariate divide by zero; a negative or
        # infinite rate yields non-positive delays, i.e. an unthrottled flood.
        if not 0.0 < self.arrival_rate < float("inf"):
            raise ValueError(
                "arrival_rate must be a positive, finite number of "
                f"sessions/s, got {self.arrival_rate!r}"
            )
        for field_name in ("warmup_s", "steady_s", "drain_s"):
            seconds = getattr(self, field_name)
            # `not >=` (rather than `<`) so NaN is rejected too.
            if not seconds >= 0.0:
                raise ValueError(
                    f"{field_name} must be >= 0, got {seconds!r}"
                )
        # A semaphore of size 0 would block every request forever.
        if self.concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be >= 1, got {self.concurrency_limit!r}"
            )
|
|
|
|
|
|
def _build_session_pool(
|
|
requests: list[TraceRequest],
|
|
pool_cap: int | None,
|
|
) -> list[list[TraceRequest]]:
|
|
"""Group trace requests by session id and keep turn order."""
|
|
by_session: dict[str, list[TraceRequest]] = defaultdict(list)
|
|
for r in requests:
|
|
by_session[r.session_id].append(r)
|
|
sessions = []
|
|
for sid, turns in by_session.items():
|
|
turns.sort(key=lambda r: (r.turn_id, r.timestamp_s))
|
|
sessions.append(turns)
|
|
if pool_cap is not None and pool_cap < len(sessions):
|
|
sessions = sessions[:pool_cap]
|
|
return sessions
|
|
|
|
|
|
def _window_for(t_unix: float, warmup_end: float, steady_end: float) -> str:
|
|
if t_unix < warmup_end:
|
|
return "warmup"
|
|
if t_unix < steady_end:
|
|
return "steady"
|
|
return "drain"
|
|
|
|
|
|
def _clone_session_for_arrival(
|
|
template: list[TraceRequest],
|
|
arrival_idx: int,
|
|
) -> list[TraceRequest]:
|
|
"""Return a copy of the session turns with fresh ids for this arrival.
|
|
|
|
The replayer uses request.request_id as X-Request-Id and request.session_id
|
|
as X-Session-Id. Reusing a template's ids across arrivals would alias on
|
|
both server side and proxy. Stamp `srr<arrival_idx>` to disambiguate.
|
|
"""
|
|
new_session_id = f"srr{arrival_idx}_{template[0].session_id}"
|
|
cloned: list[TraceRequest] = []
|
|
for t in template:
|
|
cloned.append(replace(
|
|
t,
|
|
session_id=new_session_id,
|
|
request_id=f"{new_session_id}:{t.turn_id}:{t.chat_id}",
|
|
))
|
|
return cloned
|
|
|
|
|
|
async def _run_one_session(
    *,
    turns: list[TraceRequest],
    config: SrrConfig,
    replay_cfg: ReplayConfig,
    client: httpx.AsyncClient,
    request_sem: asyncio.Semaphore,
    sink: IncrementalMetricSink,
    counters: dict[str, dict[str, int]],
    window_for_now,
    deadline_unix: float,
) -> None:
    """Replay the turns of one session strictly one after another.

    For each turn the wall-clock dispatch time picks the window (via
    ``window_for_now``); the attempt and its eventual outcome are both
    charged to that same window in ``counters``. Every metric is appended to
    ``sink``. After a successful turn, its prompt tokens plus its output
    tokens become the context passed to ``_apply_realized_prefix`` for the
    next turn; after a failed turn the previous context is kept. The session
    stops as soon as a turn would be dispatched after ``deadline_unix``.

    ``config`` is accepted for interface symmetry but is not read here.
    """
    context_so_far: list[int] = []
    for turn in turns:
        dispatched_at = time.time()
        if dispatched_at > deadline_unix:
            # Past the drain deadline: abandon the remaining turns.
            break

        bucket = window_for_now(dispatched_at)
        counters["attempted"][bucket] += 1

        raw_prompt = _build_prompt_token_ids(turn)
        prompt_ids = _apply_realized_prefix(raw_prompt, context_so_far)
        outcome = await _dispatch_request(
            client=client,
            config=replay_cfg,
            req=turn,
            prompt_token_ids=prompt_ids,
            sem=request_sem,
        )

        metric = outcome.metric
        await sink.append(metric)

        if metric.error is not None:
            counters["errored"][bucket] += 1
            continue

        counters["completed"][bucket] += 1
        context_so_far = prompt_ids + outcome.output_token_ids
|
|
|
|
|
|
async def _settle_sessions(
    session_tasks: list[asyncio.Task],
    timeout_s: float,
) -> tuple[int, int]:
    """Wait for session tasks, cancel stragglers, and reap every task.

    Args:
        session_tasks: tasks created for the arrivals of this run.
        timeout_s: how long to wait before cancelling unfinished tasks.

    Returns:
        ``(cancelled, failed)``: how many tasks were still pending at the
        timeout and had to be cancelled, and how many tasks ended by raising
        an exception other than cancellation.
    """
    if not session_tasks:
        return 0, 0

    _, pending = await asyncio.wait(session_tasks, timeout=timeout_s)
    for task in pending:
        task.cancel()
    if pending:
        # cancel() only requests cancellation. Awaiting the tasks lets them
        # unwind before the caller closes the HTTP client and metric sink.
        await asyncio.gather(*pending, return_exceptions=True)

    failed = 0
    for task in session_tasks:
        if task.cancelled():
            continue
        exc = task.exception()
        if exc is not None:
            # Retrieve and report the exception; otherwise it is lost.
            failed += 1
            logger.warning("SRR session task failed: %r", exc)
    return len(pending), failed


async def run_srr(config: SrrConfig) -> dict[str, Any]:
    """Drive Poisson session arrivals against the endpoint and summarize.

    Sessions are drawn from the trace-derived pool with a seeded RNG and
    started until the steady window ends; in-flight sessions then get until
    the drain deadline to finish, after which they are cancelled. Arrival
    times follow an absolute schedule (cumulative sum of the sampled
    exponential gaps) so that event-loop lag does not lower the realized
    arrival rate. The per-window counters and run metadata are written to
    ``<output stem>.window_summary.json`` next to ``config.output_path``.

    Args:
        config: run settings.

    Returns:
        The summary dict that was written to the window summary file. It
        holds the original keys plus ``sessions_cancelled`` (sessions cut
        off at the drain deadline) and ``sessions_failed`` (session tasks
        that raised).

    Raises:
        ValueError: if ``config.arrival_rate`` is not a positive finite
            number, or if the trace yields an empty session pool.
    """
    # False for NaN, zero, negatives and inf: expovariate would divide by
    # zero or produce non-positive delays (an unthrottled arrival flood).
    if not 0.0 < config.arrival_rate < float("inf"):
        raise ValueError(
            "arrival_rate must be a positive, finite number of sessions/s, "
            f"got {config.arrival_rate!r}"
        )

    requests = load_trace(config.trace_path, request_limit=config.request_limit)
    pool = _build_session_pool(requests, config.session_pool_size)
    if not pool:
        raise ValueError(f"empty session pool from trace {config.trace_path}")

    rng = random.Random(config.rng_seed)
    request_sem = asyncio.Semaphore(config.concurrency_limit)

    replay_cfg = ReplayConfig(
        trace_path=config.trace_path,
        output_path=config.output_path,
        endpoint_url=config.endpoint_url,
        model_name=config.model_name,
        concurrency_limit=config.concurrency_limit,
        request_timeout_s=config.request_timeout_s,
    )

    limits = httpx.Limits(
        # Never let the connection pool be tighter than the request
        # semaphore, or requests would queue invisibly inside httpx.
        max_connections=max(2000, config.concurrency_limit),
        max_keepalive_connections=500,
        keepalive_expiry=30.0,
    )

    counters: dict[str, dict[str, int]] = {
        "attempted": defaultdict(int),
        "completed": defaultdict(int),
        "errored": defaultdict(int),
    }
    inter_arrivals: list[float] = []

    run_start_unix = time.time()
    warmup_end_unix = run_start_unix + config.warmup_s
    steady_end_unix = warmup_end_unix + config.steady_s
    drain_end_unix = steady_end_unix + config.drain_s

    def _window(t: float) -> str:
        return _window_for(t, warmup_end_unix, steady_end_unix)

    logger.info(
        "SRR start: pool=%d lambda=%.4f sess/s warmup=%.0fs steady=%.0fs drain=%.0fs",
        len(pool), config.arrival_rate,
        config.warmup_s, config.steady_s, config.drain_s,
    )

    arrival_idx = 0
    sessions_cancelled = 0
    sessions_failed = 0
    session_tasks: list[asyncio.Task] = []
    next_arrival_unix = run_start_unix

    # Open the sink last so nothing above can raise while it is open.
    sink = IncrementalMetricSink(config.output_path)
    try:
        async with httpx.AsyncClient(
            timeout=config.request_timeout_s,
            trust_env=False,
            limits=limits,
        ) as client:
            try:
                while time.time() < steady_end_unix:
                    dt = rng.expovariate(config.arrival_rate)
                    inter_arrivals.append(dt)
                    next_arrival_unix += dt
                    if next_arrival_unix >= steady_end_unix:
                        # Next arrival falls outside the steady window; do
                        # not burn drain time sleeping towards it.
                        break
                    # Sleep until the scheduled instant. When behind
                    # schedule this is sleep(0), which still yields to the
                    # running sessions.
                    await asyncio.sleep(
                        max(0.0, next_arrival_unix - time.time())
                    )
                    if time.time() >= steady_end_unix:
                        break
                    template = rng.choice(pool)
                    cloned = _clone_session_for_arrival(template, arrival_idx)
                    arrival_idx += 1
                    session_tasks.append(asyncio.create_task(_run_one_session(
                        turns=cloned,
                        config=config,
                        replay_cfg=replay_cfg,
                        client=client,
                        request_sem=request_sem,
                        sink=sink,
                        counters=counters,
                        window_for_now=_window,
                        deadline_unix=drain_end_unix,
                    )))

                # Drain phase: stop accepting new arrivals, wait for in-flight.
                drain_timeout = max(0.0, drain_end_unix - time.time())
                logger.info(
                    "SRR drain: waiting up to %.1fs for %d in-flight sessions",
                    drain_timeout,
                    sum(1 for t in session_tasks if not t.done()),
                )
                sessions_cancelled, sessions_failed = await _settle_sessions(
                    session_tasks, drain_timeout,
                )
            finally:
                # Only non-empty when leaving abnormally (error or outer
                # cancellation): stop sessions while the client is still open.
                leftovers = [t for t in session_tasks if not t.done()]
                for task in leftovers:
                    task.cancel()
                if leftovers:
                    await asyncio.gather(*leftovers, return_exceptions=True)
    finally:
        sink.close()

    summary = {
        "run_start_unix": run_start_unix,
        "warmup_end_unix": warmup_end_unix,
        "steady_end_unix": steady_end_unix,
        "drain_end_unix": drain_end_unix,
        "arrival_rate": config.arrival_rate,
        "session_pool_size": len(pool),
        "sessions_arrived": arrival_idx,
        "sessions_cancelled": sessions_cancelled,
        "sessions_failed": sessions_failed,
        "inter_arrival_mean_s": (
            sum(inter_arrivals) / len(inter_arrivals)
            if inter_arrivals else None
        ),
        "attempted": dict(counters["attempted"]),
        "completed": dict(counters["completed"]),
        "errored": dict(counters["errored"]),
        "rng_seed": config.rng_seed,
    }
    summary_path = config.output_path.with_name(
        config.output_path.stem + ".window_summary.json"
    )
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))
    logger.info(
        "SRR done: arrived=%d attempted=%s completed=%s errored=%s",
        arrival_idx,
        dict(counters["attempted"]),
        dict(counters["completed"]),
        dict(counters["errored"]),
    )
    return summary
|
|
|
|
|
|
def _parse_args(argv: list[str] | None = None):
|
|
import argparse
|
|
|
|
p = argparse.ArgumentParser(
|
|
description="Open-loop Poisson SRR loadgen with session-causal turns",
|
|
)
|
|
p.add_argument("--trace", type=Path, required=True)
|
|
p.add_argument("--output", type=Path, required=True)
|
|
p.add_argument("--endpoint", type=str, required=True)
|
|
p.add_argument("--model", type=str, default="default")
|
|
p.add_argument("--arrival-rate", type=float, required=True,
|
|
help="Sessions per second (Poisson)")
|
|
p.add_argument("--warmup-s", type=float, default=60.0)
|
|
p.add_argument("--steady-s", type=float, default=300.0)
|
|
p.add_argument("--drain-s", type=float, default=60.0)
|
|
p.add_argument("--concurrency-limit", type=int, default=2000)
|
|
p.add_argument("--request-timeout", type=float, default=600.0)
|
|
p.add_argument("--session-pool-size", type=int, default=None)
|
|
p.add_argument("--request-limit", type=int, default=None)
|
|
p.add_argument("--rng-seed", type=int, default=42)
|
|
p.add_argument("-v", "--verbose", action="store_true")
|
|
return p.parse_args(argv)
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point.

    Parses the arguments, configures logging (DEBUG with ``-v``, INFO
    otherwise), runs the SRR load to completion and prints the resulting
    summary as sorted, indented JSON on stdout.
    """
    args = _parse_args(argv)

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Trailing slashes are stripped from the endpoint before it is stored.
    endpoint = args.endpoint.rstrip("/")

    config = SrrConfig(
        # where to read from / write to / send to
        trace_path=args.trace,
        output_path=args.output,
        endpoint_url=endpoint,
        model_name=args.model,
        # load shape
        arrival_rate=args.arrival_rate,
        warmup_s=args.warmup_s,
        steady_s=args.steady_s,
        drain_s=args.drain_s,
        # limits
        concurrency_limit=args.concurrency_limit,
        request_timeout_s=args.request_timeout,
        session_pool_size=args.session_pool_size,
        request_limit=args.request_limit,
        # reproducibility
        rng_seed=args.rng_seed,
    )

    run_summary = asyncio.run(run_srr(config))
    rendered = json.dumps(run_summary, indent=2, sort_keys=True)
    print(rendered)


if __name__ == "__main__":
    main()
|