PD-disagg crossover: synthetic-trace generator + morpher + plotter

gen_synthetic_trace (vanilla Poisson, zero prefix reuse — the regime where
PD-disagg is expected to win), mutate_trace (morph reuse/burst/skew toward
the agentic regime), and plot_crossover. Emits the replayer's JSONL schema.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 11:53:21 +08:00
parent 41a0c1c48f
commit bad512d3c5
3 changed files with 492 additions and 0 deletions

View File

@@ -0,0 +1,241 @@
"""Generate synthetic traces for the PD-disagg crossover study.
Emits the same JSONL schema the replayer consumes (chat_id, parent_chat_id,
timestamp, input_length, output_length, type, turn, hash_ids, session_id),
so no replayer change is needed.
Phase A ("vanilla") workload — the textbook regime where PD-disagg is
expected to win:
- Poisson arrivals at a fixed mean QPS.
- Fixed input / output length.
- Every request is its own single-turn session (parent_chat_id = -1).
- hash_ids are globally unique, so there is ZERO prefix-cache reuse and the
prefix-cache confound (PD round-robin loses cache, 8C keeps it) is removed
from the comparison by construction.
Later morph dimensions (multi-turn reuse, burst arrivals, session skew) are
intentionally NOT implemented here yet; this file owns the vanilla baseline.
Usage:
python gen_synthetic_trace.py --out traces/vanilla_q1.6_in1024_out256.jsonl \
--qps 1.6 --duration-s 600 --input-len 1024 --output-len 256 --seed 42
"""
from __future__ import annotations
import argparse
import json
import random
from pathlib import Path
BLOCK_SIZE = 512 # must match replayer.replay.BLOCK_SIZE
# Start unique hash ids well above the real-trace hash range (~1.2e7) so a
# synthetic trace never accidentally shares a block hash with anything else.
HASH_BASE = 1_000_000_000
def n_blocks_for(input_length: int) -> int:
return max(1, input_length // BLOCK_SIZE)
def gen_vanilla(
*,
qps: float,
duration_s: float,
input_len: int,
output_len: int,
seed: int,
) -> list[dict]:
"""Poisson arrivals, fixed lengths, every request a unique single-turn
session with globally-unique block hashes (zero reuse)."""
rng = random.Random(seed)
rows: list[dict] = []
t = 0.0
next_hash = HASH_BASE
chat_id = 1
while True:
# Exponential inter-arrival -> Poisson process at rate `qps`.
t += rng.expovariate(qps)
if t > duration_s:
break
nb = n_blocks_for(input_len)
hash_ids = list(range(next_hash, next_hash + nb))
next_hash += nb
rows.append({
"chat_id": chat_id,
"parent_chat_id": -1,
"timestamp": round(t, 6),
"input_length": input_len,
"output_length": output_len,
"type": "synthetic",
"turn": 1,
"hash_ids": hash_ids,
"session_id": str(chat_id),
})
chat_id += 1
return rows
def _sample_turns(rng: random.Random, turns_mean: float, turns_max: int,
heavy_frac: float) -> int:
"""Geometric-ish turn count, with a heavy-tailed minority (session skew)."""
if heavy_frac > 0 and rng.random() < heavy_frac:
return turns_max
cont = max(0.0, 1.0 - 1.0 / max(turns_mean, 1.0))
t = 1
while t < turns_max and rng.random() < cont:
t += 1
return t
def gen_multiturn(
*,
session_qps: float,
duration_s: float,
turns_mean: float,
turns_max: int,
heavy_frac: float,
first_input: int,
new_user_tokens: int,
output_len: int,
inter_turn_gap_s: float,
seed: int,
) -> list[dict]:
"""Multi-turn agentic-like sessions with intra-session prefix reuse.
Each session's turn k re-sends the whole conversation-so-far as its prompt
(cumulative hash_ids prefix = prior turns' input+output blocks) then appends
`new_user_tokens` of fresh context, so vLLM sees a high intra-session prefix-
cache hit on the growing prefix — exactly the agentic multi-turn pattern.
Context grows each turn; outputs are short; inter-turn gap models think-time.
"""
rng = random.Random(seed)
rows: list[dict] = []
next_hash = HASH_BASE
chat_id = 1
# Generate session start times (Poisson), then expand each into turns.
starts: list[float] = []
t = 0.0
while True:
t += rng.expovariate(session_qps)
if t > duration_s:
break
starts.append(t)
for s_idx, start in enumerate(starts):
session_id = f"s{s_idx}"
n_turns = _sample_turns(rng, turns_mean, turns_max, heavy_frac)
session_hashes: list[int] = [] # cumulative blocks of the conversation
ctx_len = 0 # cumulative prompt tokens (prior turns)
prev_chat = -1
ts = start
for turn in range(1, n_turns + 1):
added = first_input if turn == 1 else new_user_tokens
input_len = ctx_len + added
n_new = max(1, added // BLOCK_SIZE)
new_blocks = list(range(next_hash, next_hash + n_new))
next_hash += n_new
turn_hashes = session_hashes + new_blocks
rows.append({
"chat_id": chat_id,
"parent_chat_id": prev_chat,
"timestamp": round(ts, 6),
"input_length": input_len,
"output_length": output_len,
"type": "synthetic_agentic",
"turn": turn,
"hash_ids": turn_hashes,
"session_id": session_id,
})
# Conversation grows by the new user tokens AND this turn's output.
n_out_blocks = max(1, output_len // BLOCK_SIZE)
session_hashes = turn_hashes + list(range(next_hash, next_hash + n_out_blocks))
next_hash += n_out_blocks
ctx_len = input_len + output_len
prev_chat = chat_id
chat_id += 1
ts += rng.expovariate(1.0 / inter_turn_gap_s) if inter_turn_gap_s > 0 else 0.0
rows.sort(key=lambda r: r["timestamp"])
return rows
def main() -> None:
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--out", type=Path, required=True, help="output trace JSONL")
p.add_argument("--mode", choices=["vanilla", "multiturn"], default="vanilla")
p.add_argument("--qps", type=float, help="vanilla: mean Poisson request rate; "
"multiturn: mean Poisson SESSION rate")
p.add_argument("--duration-s", type=float, default=600.0, help="trace span (s)")
p.add_argument("--input-len", type=int, help="vanilla: fixed input length")
p.add_argument("--output-len", type=int, required=True)
p.add_argument("--seed", type=int, default=42)
# multiturn knobs
p.add_argument("--turns-mean", type=float, default=4.0)
p.add_argument("--turns-max", type=int, default=40)
p.add_argument("--heavy-frac", type=float, default=0.0,
help="fraction of sessions that are heavy (turns_max) — session skew")
p.add_argument("--first-input", type=int, default=2048,
help="multiturn: turn-1 input length")
p.add_argument("--new-user-tokens", type=int, default=256,
help="multiturn: fresh user tokens added each subsequent turn")
p.add_argument("--inter-turn-gap-s", type=float, default=1.6,
help="multiturn: mean think-time between turns")
args = p.parse_args()
if args.mode == "vanilla":
assert args.qps and args.input_len, "vanilla needs --qps and --input-len"
rows = gen_vanilla(
qps=args.qps, duration_s=args.duration_s,
input_len=args.input_len, output_len=args.output_len, seed=args.seed,
)
cfg = {"mode": "vanilla", "qps": args.qps, "duration_s": args.duration_s,
"input_len": args.input_len, "output_len": args.output_len,
"seed": args.seed, "reuse": "none"}
else:
assert args.qps, "multiturn needs --qps (session rate)"
rows = gen_multiturn(
session_qps=args.qps, duration_s=args.duration_s,
turns_mean=args.turns_mean, turns_max=args.turns_max,
heavy_frac=args.heavy_frac, first_input=args.first_input,
new_user_tokens=args.new_user_tokens, output_len=args.output_len,
inter_turn_gap_s=args.inter_turn_gap_s, seed=args.seed,
)
cfg = {"mode": "multiturn", "session_qps": args.qps,
"duration_s": args.duration_s, "turns_mean": args.turns_mean,
"turns_max": args.turns_max, "heavy_frac": args.heavy_frac,
"first_input": args.first_input, "new_user_tokens": args.new_user_tokens,
"output_len": args.output_len, "inter_turn_gap_s": args.inter_turn_gap_s,
"seed": args.seed, "reuse": "intra-session"}
args.out.parent.mkdir(parents=True, exist_ok=True)
with args.out.open("w", encoding="utf-8") as fh:
for r in rows:
fh.write(json.dumps(r) + "\n")
cfg["n_requests"] = len(rows)
cfg["block_size"] = BLOCK_SIZE
cfg_path = args.out.with_suffix(args.out.suffix + ".config.json")
cfg_path.write_text(json.dumps(cfg, indent=2))
span = rows[-1]["timestamp"] - rows[0]["timestamp"] if rows else 0.0
eff_qps = len(rows) / span if span > 0 else 0.0
print(f"wrote {len(rows)} requests to {args.out} (mode={args.mode})")
print(f" target qps={args.qps} effective req qps={eff_qps:.3f} span={span:.1f}s")
if args.mode == "vanilla":
print(f" input_len={args.input_len} output_len={args.output_len} "
f"(blocks/req={n_blocks_for(args.input_len)}, zero reuse)")
else:
n_sessions = len({r["session_id"] for r in rows})
inputs = sorted(r["input_length"] for r in rows)
p = lambda v, q: v[min(int(q * len(v)), len(v) - 1)] if v else 0
print(f" sessions={n_sessions} turns/session~{len(rows)/max(n_sessions,1):.1f} "
f"input p50={p(inputs,.5)} p90={p(inputs,.9)} p99={p(inputs,.99)} "
f"output_len={args.output_len} (intra-session reuse)")
print(f" config -> {cfg_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,136 @@
"""Reverse-ablation trace surgeon for the PD-disagg crossover study.
Takes the REAL agentic trace and neutralizes ONE agentic property at a time,
so we can see which removal restores PD-disagg to parity with colocation.
This is the subtractive complement to the additive synthetic sweep (D1-D5).
Neutralizations (compose freely; each defaults to off):
--set-output N set every output_length to N (kill short-output -> test decode-benefit starvation)
--max-input N clamp input_length to N, truncating hash_ids to ceil(N/512)
(kill huge-prefill -> test prefill-bound bottleneck)
--uniform-arrival respace requests evenly over the original span, order preserved
(kill burstiness -> test arrival variance)
--unique-hash replace all hash_ids with globally-unique ids
(kill intra-session reuse -> test cache/affinity)
--max-turns N keep only the first N turns of each session
(flatten session skew / heavy tail)
The schema is preserved so the replayer consumes the output unchanged.
"""
from __future__ import annotations
import argparse
import json
import math
from collections import defaultdict
from pathlib import Path
BLOCK_SIZE = 512
UNIQUE_HASH_BASE = 2_000_000_000
def load_rows(path: Path) -> list[dict]:
rows = []
with path.open() as fh:
for line in fh:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
def resolve_session(row: dict, chat_to_session: dict) -> str:
if "session_id" in row:
return str(row["session_id"])
cid, pid = int(row["chat_id"]), int(row["parent_chat_id"])
sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
chat_to_session[cid] = sid
return sid
def main() -> None:
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("--input", type=Path, required=True)
p.add_argument("--output", type=Path, required=True)
p.add_argument("--set-output", type=int, default=None)
p.add_argument("--max-input", type=int, default=None)
p.add_argument("--uniform-arrival", action="store_true")
p.add_argument("--unique-hash", action="store_true")
p.add_argument("--max-turns", type=int, default=None)
args = p.parse_args()
rows = load_rows(args.input)
# ensure session_id present
c2s: dict = {}
for r in rows:
r["session_id"] = resolve_session(r, c2s)
applied = []
# --- flatten session skew: keep first N turns per session ---
if args.max_turns is not None:
kept_count: dict = defaultdict(int)
rows.sort(key=lambda r: (r["session_id"], r.get("turn", 0)))
out = []
for r in rows:
if kept_count[r["session_id"]] < args.max_turns:
kept_count[r["session_id"]] += 1
out.append(r)
rows = out
rows.sort(key=lambda r: r.get("timestamp", 0.0))
applied.append(f"max_turns={args.max_turns}")
# --- clamp input + truncate hash_ids ---
if args.max_input is not None:
for r in rows:
if r["input_length"] > args.max_input:
r["input_length"] = args.max_input
keep_blocks = max(1, math.ceil(r["input_length"] / BLOCK_SIZE))
r["hash_ids"] = list(r.get("hash_ids", []))[:keep_blocks]
applied.append(f"max_input={args.max_input}")
# --- set fixed output length ---
if args.set_output is not None:
for r in rows:
r["output_length"] = args.set_output
applied.append(f"set_output={args.set_output}")
# --- kill reuse: globally-unique hashes ---
if args.unique_hash:
nxt = UNIQUE_HASH_BASE
for r in rows:
n = max(1, len(r.get("hash_ids", [])) or
math.ceil(r["input_length"] / BLOCK_SIZE))
r["hash_ids"] = list(range(nxt, nxt + n))
nxt += n
applied.append("unique_hash")
# --- de-burst: uniform arrival over original span (order preserved) ---
if args.uniform_arrival:
rows.sort(key=lambda r: r.get("timestamp", 0.0))
ts = [r.get("timestamp", 0.0) for r in rows]
span = (ts[-1] - ts[0]) if len(ts) > 1 else 0.0
n = len(rows)
for i, r in enumerate(rows):
r["timestamp"] = round(ts[0] + (span * i / max(n - 1, 1)), 6)
applied.append("uniform_arrival")
rows.sort(key=lambda r: r.get("timestamp", 0.0))
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w") as fh:
for r in rows:
fh.write(json.dumps(r) + "\n")
inputs = sorted(r["input_length"] for r in rows)
outs = sorted(r["output_length"] for r in rows)
q = lambda v, p: v[min(int(p * len(v)), len(v) - 1)] if v else 0
print(f"wrote {len(rows)} rows -> {args.output}")
print(f" neutralized: {applied or ['none (passthrough)']}")
print(f" input p50={q(inputs,.5)} p90={q(inputs,.9)} p99={q(inputs,.99)} "
f"output p50={q(outs,.5)} p90={q(outs,.9)}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,115 @@
"""Render the PD-disagg crossover figure from analyze_goodput.py JSONs.
Two sweeps bracket the prefill<->decode bottleneck axis:
D2: fixed input, grow OUTPUT -> decode-bound -> PD_advantage rises above 1
D1: fixed output, grow INPUT -> prefill-bound -> PD_advantage falls below 1
Top row: PD_advantage (4P+4D / colo SLO-goodput) vs swept dim, y=1 = crossover.
Bottom row: completion rate, colo vs 4P+4D.
Agentic operating region (input p50~33k, output p50~92) annotated.
"""
from __future__ import annotations
import json
import re
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
CROSS = Path("analysis/crossover")
OUT = Path("figs/crossover_pd_advantage.png")
def load_series(prefix: str, key_re: str):
pts = []
for f in sorted(CROSS.glob(f"{prefix}*_goodput.json")):
m = re.search(key_re, f.name)
if not m:
continue
d = json.loads(f.read_text())
g = d["slo_grid"][0]["arms"]
pts.append({
"x": int(m.group(1)),
"adv": g["4P+4D"]["pd_advantage"],
"colo_att": g["8C-proxy"]["attainment"],
"pd_att": g["4P+4D"]["attainment"],
"colo_compl": d["arms"]["8C-proxy"]["completion_rate"],
"pd_compl": d["arms"]["4P+4D"]["completion_rate"],
"colo_ampl": d["arms"]["8C-proxy"]["amplification"],
"pd_ampl": d["arms"]["4P+4D"]["amplification"],
})
pts.sort(key=lambda p: p["x"])
return pts
def main():
d2 = load_series("d2_o", r"d2_o(\d+)_") # x = output length
d1 = load_series("d1_i", r"d1_i(\d+)_") # x = input length
fig, ax = plt.subplots(2, 2, figsize=(12, 8))
def adv_plot(a, pts, xlabel, title, agentic_x, agentic_label):
CAP = 50.0 # display cap for "colo=0, PD>0" (PD wins outright)
line_x, line_y = [], []
for p in pts:
adv, colo_att, pd_att = p["adv"], p["colo_att"], p["pd_att"]
finite = isinstance(adv, (int, float)) and adv == adv
if finite:
line_x.append(p["x"]); line_y.append(max(adv, 1e-3))
elif colo_att == 0 and pd_att > 0: # PD wins outright
line_x.append(p["x"]); line_y.append(CAP)
a.annotate("PD wins\n(colo=0)", (p["x"], CAP), fontsize=7, ha="center", va="top")
else: # both fail SLO -> not a PD win
a.scatter([p["x"]], [1.0], marker="x", color="gray", zorder=5)
a.annotate("both\nfail SLO", (p["x"], 1.0), fontsize=7, ha="center",
va="bottom", color="gray")
a.plot(line_x, line_y, "o-", color="tab:blue", lw=2)
a.axhline(1.0, color="k", ls="--", lw=1, label="crossover (PD=colo)")
a.set_xscale("log"); a.set_yscale("log")
a.set_xlabel(xlabel)
a.set_ylabel("PD_advantage = goodput(4P+4D)/goodput(colo)")
a.set_title(title)
a.axvline(agentic_x, color="tab:red", ls=":", lw=1.5)
a.annotate(agentic_label, (agentic_x, 1.2), color="tab:red", fontsize=8,
rotation=90, va="bottom", ha="right")
a.legend(fontsize=8, loc="best")
a.grid(True, which="both", alpha=0.3)
def compl_plot(a, pts, xlabel, title):
xs = [p["x"] for p in pts]
a.plot(xs, [100*p["colo_compl"] for p in pts], "s-", color="tab:orange", label="colo (8C-proxy)")
a.plot(xs, [100*p["pd_compl"] for p in pts], "o-", color="tab:blue", label="PD (4P+4D)")
a.set_xscale("log")
a.set_xlabel(xlabel)
a.set_ylabel("completion rate (%)")
a.set_title(title)
a.set_ylim(0, 105)
a.legend(fontsize=8, loc="best")
a.grid(True, alpha=0.3)
adv_plot(ax[0][0], d2, "output length (tokens), fixed input=2048, q12",
"D2 decode-bound sweep — PD wins as output grows",
92, "agentic out~92")
adv_plot(ax[0][1], d1, "input length (tokens), fixed output=64, q4",
"D1 prefill-bound sweep — PD collapses as input grows",
33533, "agentic in~33k")
compl_plot(ax[1][0], d2, "output length (tokens)", "D2 completion")
compl_plot(ax[1][1], d1, "input length (tokens)", "D1 completion")
fig.suptitle("PD-disaggregation vs colocation: the prefill<->decode bottleneck crossover\n"
"(single node 8xH20, vLLM 0.18.1 chunked-prefill; zero-reuse synthetic)",
fontsize=11)
fig.tight_layout(rect=[0, 0, 1, 0.96])
OUT.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(OUT, dpi=130)
print(f"wrote {OUT}")
# also dump the numeric series for the record
print("D2 (output -> adv):", [(p["x"], round(p["adv"],2) if p["adv"]==p["adv"] else "inf") for p in d2])
print("D1 (input -> adv):", [(p["x"], round(p["adv"],2) if p["adv"]==p["adv"] else "inf") for p in d1])
if __name__ == "__main__":
main()