Files
agentic-kvc/microbench/fresh_setup/gen_synthetic_trace.py
Gahow Wang bad512d3c5 PD-disagg crossover: synthetic-trace generator + morpher + plotter
gen_synthetic_trace (vanilla Poisson, zero prefix reuse — the regime where
PD-disagg is expected to win), mutate_trace (morph reuse/burst/skew toward
the agentic regime), and plot_crossover. Emits the replayer's JSONL schema.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-29 11:53:21 +08:00

242 lines
9.6 KiB
Python

"""Generate synthetic traces for the PD-disagg crossover study.
Emits the same JSONL schema the replayer consumes (chat_id, parent_chat_id,
timestamp, input_length, output_length, type, turn, hash_ids, session_id),
so no replayer change is needed.
Phase A ("vanilla") workload — the textbook regime where PD-disagg is
expected to win:
- Poisson arrivals at a fixed mean QPS.
- Fixed input / output length.
- Every request is its own single-turn session (parent_chat_id = -1).
- hash_ids are globally unique, so there is ZERO prefix-cache reuse and the
prefix-cache confound (PD round-robin loses cache, 8C keeps it) is removed
from the comparison by construction.
Phase B ("multiturn", --mode multiturn) adds intra-session prefix reuse: each
turn re-sends the conversation so far plus fresh user tokens, and
--heavy-frac adds session skew. Burst arrivals are not implemented here yet.
Usage:
python gen_synthetic_trace.py --out traces/vanilla_q1.6_in1024_out256.jsonl \
--qps 1.6 --duration-s 600 --input-len 1024 --output-len 256 --seed 42
"""
from __future__ import annotations
import argparse
import json
import random
from pathlib import Path
# Tokens per hash block; must match replayer.replay.BLOCK_SIZE.
BLOCK_SIZE = 512
# Synthetic hash ids start far above the real-trace hash range (~1.2e7), so a
# synthetic trace can never accidentally share a block hash with a real one.
HASH_BASE = 10 ** 9
def n_blocks_for(input_length: int) -> int:
    """Return how many hash blocks a prompt of ``input_length`` tokens gets.

    This is the number of complete ``BLOCK_SIZE``-token blocks (floor
    division), but never fewer than one, so even a short prompt carries a
    single block hash.
    """
    complete_blocks = input_length // BLOCK_SIZE
    return complete_blocks if complete_blocks >= 1 else 1
def gen_vanilla(
    *,
    qps: float,
    duration_s: float,
    input_len: int,
    output_len: int,
    seed: int,
) -> list[dict]:
    """Generate the Phase-A "vanilla" workload.

    Poisson arrivals, fixed lengths, every request a unique single-turn
    session with globally-unique block hashes (zero reuse).

    Args:
        qps: mean arrival rate (requests/s) of the Poisson process; must be > 0.
        duration_s: arrivals whose timestamp exceeds this are dropped.
        input_len: input length copied into every row; also sets the per-row
            block count via ``n_blocks_for``.
        output_len: output length copied into every row.
        seed: seed for a private ``random.Random`` so traces are reproducible.

    Returns:
        Rows in the replayer JSONL schema, in non-decreasing timestamp order.

    Raises:
        ValueError: if ``qps`` is not positive. ``expovariate(0)`` divides by
            zero, and a negative rate makes ``t`` decrease forever, so the
            loop below would never terminate.
    """
    if not qps > 0:  # also rejects NaN
        raise ValueError(f"qps must be > 0, got {qps!r}")
    rng = random.Random(seed)
    # Every request has the same input length, so the block count is
    # loop-invariant.
    nb = n_blocks_for(input_len)
    rows: list[dict] = []
    t = 0.0
    next_hash = HASH_BASE
    chat_id = 1
    while True:
        # Exponential inter-arrival -> Poisson process at rate `qps`.
        t += rng.expovariate(qps)
        if t > duration_s:
            break
        rows.append({
            "chat_id": chat_id,
            "parent_chat_id": -1,
            "timestamp": round(t, 6),
            "input_length": input_len,
            "output_length": output_len,
            "type": "synthetic",
            "turn": 1,
            # Fresh, never-reused ids -> zero prefix-cache reuse by construction.
            "hash_ids": list(range(next_hash, next_hash + nb)),
            "session_id": str(chat_id),
        })
        next_hash += nb
        chat_id += 1
    return rows
def _sample_turns(rng: random.Random, turns_mean: float, turns_max: int,
heavy_frac: float) -> int:
"""Geometric-ish turn count, with a heavy-tailed minority (session skew)."""
if heavy_frac > 0 and rng.random() < heavy_frac:
return turns_max
cont = max(0.0, 1.0 - 1.0 / max(turns_mean, 1.0))
t = 1
while t < turns_max and rng.random() < cont:
t += 1
return t
def gen_multiturn(
    *,
    session_qps: float,
    duration_s: float,
    turns_mean: float,
    turns_max: int,
    heavy_frac: float,
    first_input: int,
    new_user_tokens: int,
    output_len: int,
    inter_turn_gap_s: float,
    seed: int,
) -> list[dict]:
    """Multi-turn agentic-like sessions with intra-session prefix reuse.

    Session start times form a Poisson process at ``session_qps`` over
    ``duration_s``. Each session then gets ``_sample_turns(...)`` turns.
    Turn 1 sends ``first_input`` tokens. Every later turn re-sends the whole
    conversation so far (all earlier prompts and outputs) and appends
    ``new_user_tokens`` of fresh context, so the context grows each turn and
    a prefix-caching server can hit on the growing prefix. Consecutive turns
    are separated by an exponential think-time with mean
    ``inter_turn_gap_s``; if that is <= 0, all turns share the session start
    time. Turns are laid out after the session start, so late sessions may
    have turns with timestamps past ``duration_s``.

    Block hashes follow the conversation at token granularity. The
    conversation is cut into ``BLOCK_SIZE``-token blocks, and a turn's
    ``hash_ids`` are the hashes of its first ``n_blocks_for(input_length)``
    blocks. A block that was already complete in the conversation before this
    turn keeps its hash (a shared prefix block). Every other block gets a
    fresh, globally unique hash. This keeps
    ``len(hash_ids) == n_blocks_for(input_length)``, the same invariant
    ``gen_vanilla`` produces, instead of rounding every sub-block user or
    output increment up to a whole extra block.

    Returns:
        Rows in the replayer JSONL schema, sorted by timestamp.

    Raises:
        ValueError: if ``session_qps`` is not positive (the arrival loop would
            divide by zero or never terminate), or if ``first_input``,
            ``new_user_tokens`` or ``output_len`` is negative.
    """
    if not session_qps > 0:  # also rejects NaN
        raise ValueError(f"session_qps must be > 0, got {session_qps!r}")
    if min(first_input, new_user_tokens, output_len) < 0:
        raise ValueError("first_input, new_user_tokens and output_len must be >= 0")

    rng = random.Random(seed)
    rows: list[dict] = []
    next_hash = HASH_BASE
    chat_id = 1

    # Generate session start times (Poisson), then expand each into turns.
    starts: list[float] = []
    t = 0.0
    while True:
        t += rng.expovariate(session_qps)
        if t > duration_s:
            break
        starts.append(t)

    for s_idx, start in enumerate(starts):
        session_id = f"s{s_idx}"
        n_turns = _sample_turns(rng, turns_mean, turns_max, heavy_frac)
        # Hashes of the conversation's COMPLETE blocks so far. The loop keeps
        # len(conv_hashes) == conv_len // BLOCK_SIZE.
        conv_hashes: list[int] = []
        conv_len = 0  # conversation length in tokens (prior prompts + outputs)
        prev_chat = -1
        ts = start
        for turn in range(1, n_turns + 1):
            added = first_input if turn == 1 else new_user_tokens
            input_len = conv_len + added
            n_prompt = n_blocks_for(input_len)
            # Complete blocks of the prior conversation are the shared prefix;
            # the prompt's remaining blocks contain new tokens -> fresh hashes.
            reused = conv_hashes[:n_prompt]
            n_fresh = n_prompt - len(reused)
            turn_hashes = reused + list(range(next_hash, next_hash + n_fresh))
            next_hash += n_fresh
            rows.append({
                "chat_id": chat_id,
                "parent_chat_id": prev_chat,
                "timestamp": round(ts, 6),
                "input_length": input_len,
                "output_length": output_len,
                "type": "synthetic_agentic",
                "turn": turn,
                "hash_ids": turn_hashes,
                "session_id": session_id,
            })
            # Append this turn's output to the conversation. Blocks that were
            # already complete in the prompt keep their hashes. Blocks completed
            # only now (holding the prompt's partial tail and/or output tokens)
            # did not exist before, so they get fresh hashes.
            conv_len = input_len + output_len
            n_stable = input_len // BLOCK_SIZE
            n_tail = conv_len // BLOCK_SIZE - n_stable
            conv_hashes = turn_hashes[:n_stable] + list(
                range(next_hash, next_hash + n_tail))
            next_hash += n_tail
            prev_chat = chat_id
            chat_id += 1
            # Think-time before the next turn (drawn after every turn, including
            # the last, so the RNG stream stays stable across sessions).
            ts += rng.expovariate(1.0 / inter_turn_gap_s) if inter_turn_gap_s > 0 else 0.0

    rows.sort(key=lambda r: r["timestamp"])
    return rows
def _percentile(sorted_vals: list[int], q: float) -> int:
    """Return the element at index ``int(q * len)`` (clamped to the last
    index) of an already-sorted list, or 0 for an empty list."""
    if not sorted_vals:
        return 0
    return sorted_vals[min(int(q * len(sorted_vals)), len(sorted_vals) - 1)]


def main() -> None:
    """CLI entry point.

    Generates a trace in the selected ``--mode`` and writes it to ``--out`` as
    JSONL, one row per line. It also writes the generator settings plus
    ``n_requests`` and ``block_size`` to ``<out>.config.json`` (e.g.
    ``t.jsonl`` -> ``t.jsonl.config.json``) and prints a short summary.
    Invalid arguments exit through ``ArgumentParser.error``.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--out", type=Path, required=True, help="output trace JSONL")
    p.add_argument("--mode", choices=["vanilla", "multiturn"], default="vanilla")
    p.add_argument("--qps", type=float, help="vanilla: mean Poisson request rate; "
                   "multiturn: mean Poisson SESSION rate")
    p.add_argument("--duration-s", type=float, default=600.0, help="trace span (s)")
    p.add_argument("--input-len", type=int, help="vanilla: fixed input length")
    p.add_argument("--output-len", type=int, required=True)
    p.add_argument("--seed", type=int, default=42)
    # multiturn knobs
    p.add_argument("--turns-mean", type=float, default=4.0)
    p.add_argument("--turns-max", type=int, default=40)
    p.add_argument("--heavy-frac", type=float, default=0.0,
                   help="fraction of sessions that are heavy (turns_max) — session skew")
    p.add_argument("--first-input", type=int, default=2048,
                   help="multiturn: turn-1 input length")
    p.add_argument("--new-user-tokens", type=int, default=256,
                   help="multiturn: fresh user tokens added each subsequent turn")
    p.add_argument("--inter-turn-gap-s", type=float, default=1.6,
                   help="multiturn: mean think-time between turns")
    args = p.parse_args()

    # Validate with p.error() rather than `assert`, because asserts are
    # stripped under `python -O`. The rate must be strictly positive:
    # expovariate(0) divides by zero, and a negative rate runs the arrival
    # clock backwards forever.
    if args.qps is None or not args.qps > 0:
        p.error(f"{args.mode} needs --qps > 0")
    if args.output_len < 0:
        p.error("--output-len must be >= 0")

    if args.mode == "vanilla":
        if args.input_len is None or args.input_len < 1:
            p.error("vanilla needs --input-len >= 1")
        rows = gen_vanilla(
            qps=args.qps, duration_s=args.duration_s,
            input_len=args.input_len, output_len=args.output_len, seed=args.seed,
        )
        cfg = {"mode": "vanilla", "qps": args.qps, "duration_s": args.duration_s,
               "input_len": args.input_len, "output_len": args.output_len,
               "seed": args.seed, "reuse": "none"}
    else:
        if args.turns_max < 1:
            p.error("--turns-max must be >= 1")
        if not 0.0 <= args.heavy_frac <= 1.0:
            p.error("--heavy-frac must be in [0, 1]")
        if args.first_input < 0 or args.new_user_tokens < 0:
            p.error("--first-input and --new-user-tokens must be >= 0")
        rows = gen_multiturn(
            session_qps=args.qps, duration_s=args.duration_s,
            turns_mean=args.turns_mean, turns_max=args.turns_max,
            heavy_frac=args.heavy_frac, first_input=args.first_input,
            new_user_tokens=args.new_user_tokens, output_len=args.output_len,
            inter_turn_gap_s=args.inter_turn_gap_s, seed=args.seed,
        )
        cfg = {"mode": "multiturn", "session_qps": args.qps,
               "duration_s": args.duration_s, "turns_mean": args.turns_mean,
               "turns_max": args.turns_max, "heavy_frac": args.heavy_frac,
               "first_input": args.first_input, "new_user_tokens": args.new_user_tokens,
               "output_len": args.output_len, "inter_turn_gap_s": args.inter_turn_gap_s,
               "seed": args.seed, "reuse": "intra-session"}

    args.out.parent.mkdir(parents=True, exist_ok=True)
    with args.out.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(r) + "\n" for r in rows)

    cfg["n_requests"] = len(rows)
    cfg["block_size"] = BLOCK_SIZE
    # "x.jsonl" -> "x.jsonl.config.json" (the sidecar keeps the full trace name).
    cfg_path = args.out.with_suffix(args.out.suffix + ".config.json")
    cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")

    span = rows[-1]["timestamp"] - rows[0]["timestamp"] if rows else 0.0
    eff_qps = len(rows) / span if span > 0 else 0.0
    print(f"wrote {len(rows)} requests to {args.out} (mode={args.mode})")
    print(f"  target qps={args.qps} effective req qps={eff_qps:.3f} span={span:.1f}s")
    if args.mode == "vanilla":
        print(f"  input_len={args.input_len} output_len={args.output_len} "
              f"(blocks/req={n_blocks_for(args.input_len)}, zero reuse)")
    else:
        n_sessions = len({r["session_id"] for r in rows})
        inputs = sorted(r["input_length"] for r in rows)
        print(f"  sessions={n_sessions} turns/session~{len(rows)/max(n_sessions,1):.1f} "
              f"input p50={_percentile(inputs, .5)} p90={_percentile(inputs, .9)} "
              f"p99={_percentile(inputs, .99)} "
              f"output_len={args.output_len} (intra-session reuse)")
    print(f"  config -> {cfg_path}")


if __name__ == "__main__":
    main()