PD-disagg crossover: synthetic-trace generator + morpher + plotter
gen_synthetic_trace (vanilla Poisson, zero prefix reuse — the regime where PD-disagg is expected to win), mutate_trace (morph reuse/burst/skew toward the agentic regime), and plot_crossover. Emits the replayer's JSONL schema. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
241
microbench/fresh_setup/gen_synthetic_trace.py
Normal file
241
microbench/fresh_setup/gen_synthetic_trace.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""Generate synthetic traces for the PD-disagg crossover study.
|
||||||
|
|
||||||
|
Emits the same JSONL schema the replayer consumes (chat_id, parent_chat_id,
|
||||||
|
timestamp, input_length, output_length, type, turn, hash_ids, session_id),
|
||||||
|
so no replayer change is needed.
|
||||||
|
|
||||||
|
Phase A ("vanilla") workload — the textbook regime where PD-disagg is
|
||||||
|
expected to win:
|
||||||
|
- Poisson arrivals at a fixed mean QPS.
|
||||||
|
- Fixed input / output length.
|
||||||
|
- Every request is its own single-turn session (parent_chat_id = -1).
|
||||||
|
- hash_ids are globally unique, so there is ZERO prefix-cache reuse and the
|
||||||
|
prefix-cache confound (PD round-robin loses cache, 8C keeps it) is removed
|
||||||
|
from the comparison by construction.
|
||||||
|
|
||||||
|
Later morph dimensions (multi-turn reuse, burst arrivals, session skew) are
|
||||||
|
intentionally NOT implemented here yet; this file owns the vanilla baseline.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python gen_synthetic_trace.py --out traces/vanilla_q1.6_in1024_out256.jsonl \
|
||||||
|
--qps 1.6 --duration-s 600 --input-len 1024 --output-len 256 --seed 42
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512 # must match replayer.replay.BLOCK_SIZE
|
||||||
|
# Start unique hash ids well above the real-trace hash range (~1.2e7) so a
|
||||||
|
# synthetic trace never accidentally shares a block hash with anything else.
|
||||||
|
HASH_BASE = 1_000_000_000
|
||||||
|
|
||||||
|
|
||||||
|
def n_blocks_for(input_length: int) -> int:
|
||||||
|
return max(1, input_length // BLOCK_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_vanilla(
|
||||||
|
*,
|
||||||
|
qps: float,
|
||||||
|
duration_s: float,
|
||||||
|
input_len: int,
|
||||||
|
output_len: int,
|
||||||
|
seed: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Poisson arrivals, fixed lengths, every request a unique single-turn
|
||||||
|
session with globally-unique block hashes (zero reuse)."""
|
||||||
|
rng = random.Random(seed)
|
||||||
|
rows: list[dict] = []
|
||||||
|
t = 0.0
|
||||||
|
next_hash = HASH_BASE
|
||||||
|
chat_id = 1
|
||||||
|
while True:
|
||||||
|
# Exponential inter-arrival -> Poisson process at rate `qps`.
|
||||||
|
t += rng.expovariate(qps)
|
||||||
|
if t > duration_s:
|
||||||
|
break
|
||||||
|
nb = n_blocks_for(input_len)
|
||||||
|
hash_ids = list(range(next_hash, next_hash + nb))
|
||||||
|
next_hash += nb
|
||||||
|
rows.append({
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"parent_chat_id": -1,
|
||||||
|
"timestamp": round(t, 6),
|
||||||
|
"input_length": input_len,
|
||||||
|
"output_length": output_len,
|
||||||
|
"type": "synthetic",
|
||||||
|
"turn": 1,
|
||||||
|
"hash_ids": hash_ids,
|
||||||
|
"session_id": str(chat_id),
|
||||||
|
})
|
||||||
|
chat_id += 1
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_turns(rng: random.Random, turns_mean: float, turns_max: int,
|
||||||
|
heavy_frac: float) -> int:
|
||||||
|
"""Geometric-ish turn count, with a heavy-tailed minority (session skew)."""
|
||||||
|
if heavy_frac > 0 and rng.random() < heavy_frac:
|
||||||
|
return turns_max
|
||||||
|
cont = max(0.0, 1.0 - 1.0 / max(turns_mean, 1.0))
|
||||||
|
t = 1
|
||||||
|
while t < turns_max and rng.random() < cont:
|
||||||
|
t += 1
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def gen_multiturn(
|
||||||
|
*,
|
||||||
|
session_qps: float,
|
||||||
|
duration_s: float,
|
||||||
|
turns_mean: float,
|
||||||
|
turns_max: int,
|
||||||
|
heavy_frac: float,
|
||||||
|
first_input: int,
|
||||||
|
new_user_tokens: int,
|
||||||
|
output_len: int,
|
||||||
|
inter_turn_gap_s: float,
|
||||||
|
seed: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Multi-turn agentic-like sessions with intra-session prefix reuse.
|
||||||
|
|
||||||
|
Each session's turn k re-sends the whole conversation-so-far as its prompt
|
||||||
|
(cumulative hash_ids prefix = prior turns' input+output blocks) then appends
|
||||||
|
`new_user_tokens` of fresh context, so vLLM sees a high intra-session prefix-
|
||||||
|
cache hit on the growing prefix — exactly the agentic multi-turn pattern.
|
||||||
|
Context grows each turn; outputs are short; inter-turn gap models think-time.
|
||||||
|
"""
|
||||||
|
rng = random.Random(seed)
|
||||||
|
rows: list[dict] = []
|
||||||
|
next_hash = HASH_BASE
|
||||||
|
chat_id = 1
|
||||||
|
|
||||||
|
# Generate session start times (Poisson), then expand each into turns.
|
||||||
|
starts: list[float] = []
|
||||||
|
t = 0.0
|
||||||
|
while True:
|
||||||
|
t += rng.expovariate(session_qps)
|
||||||
|
if t > duration_s:
|
||||||
|
break
|
||||||
|
starts.append(t)
|
||||||
|
|
||||||
|
for s_idx, start in enumerate(starts):
|
||||||
|
session_id = f"s{s_idx}"
|
||||||
|
n_turns = _sample_turns(rng, turns_mean, turns_max, heavy_frac)
|
||||||
|
session_hashes: list[int] = [] # cumulative blocks of the conversation
|
||||||
|
ctx_len = 0 # cumulative prompt tokens (prior turns)
|
||||||
|
prev_chat = -1
|
||||||
|
ts = start
|
||||||
|
for turn in range(1, n_turns + 1):
|
||||||
|
added = first_input if turn == 1 else new_user_tokens
|
||||||
|
input_len = ctx_len + added
|
||||||
|
n_new = max(1, added // BLOCK_SIZE)
|
||||||
|
new_blocks = list(range(next_hash, next_hash + n_new))
|
||||||
|
next_hash += n_new
|
||||||
|
turn_hashes = session_hashes + new_blocks
|
||||||
|
rows.append({
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"parent_chat_id": prev_chat,
|
||||||
|
"timestamp": round(ts, 6),
|
||||||
|
"input_length": input_len,
|
||||||
|
"output_length": output_len,
|
||||||
|
"type": "synthetic_agentic",
|
||||||
|
"turn": turn,
|
||||||
|
"hash_ids": turn_hashes,
|
||||||
|
"session_id": session_id,
|
||||||
|
})
|
||||||
|
# Conversation grows by the new user tokens AND this turn's output.
|
||||||
|
n_out_blocks = max(1, output_len // BLOCK_SIZE)
|
||||||
|
session_hashes = turn_hashes + list(range(next_hash, next_hash + n_out_blocks))
|
||||||
|
next_hash += n_out_blocks
|
||||||
|
ctx_len = input_len + output_len
|
||||||
|
prev_chat = chat_id
|
||||||
|
chat_id += 1
|
||||||
|
ts += rng.expovariate(1.0 / inter_turn_gap_s) if inter_turn_gap_s > 0 else 0.0
|
||||||
|
|
||||||
|
rows.sort(key=lambda r: r["timestamp"])
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
p = argparse.ArgumentParser(description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
p.add_argument("--out", type=Path, required=True, help="output trace JSONL")
|
||||||
|
p.add_argument("--mode", choices=["vanilla", "multiturn"], default="vanilla")
|
||||||
|
p.add_argument("--qps", type=float, help="vanilla: mean Poisson request rate; "
|
||||||
|
"multiturn: mean Poisson SESSION rate")
|
||||||
|
p.add_argument("--duration-s", type=float, default=600.0, help="trace span (s)")
|
||||||
|
p.add_argument("--input-len", type=int, help="vanilla: fixed input length")
|
||||||
|
p.add_argument("--output-len", type=int, required=True)
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
# multiturn knobs
|
||||||
|
p.add_argument("--turns-mean", type=float, default=4.0)
|
||||||
|
p.add_argument("--turns-max", type=int, default=40)
|
||||||
|
p.add_argument("--heavy-frac", type=float, default=0.0,
|
||||||
|
help="fraction of sessions that are heavy (turns_max) — session skew")
|
||||||
|
p.add_argument("--first-input", type=int, default=2048,
|
||||||
|
help="multiturn: turn-1 input length")
|
||||||
|
p.add_argument("--new-user-tokens", type=int, default=256,
|
||||||
|
help="multiturn: fresh user tokens added each subsequent turn")
|
||||||
|
p.add_argument("--inter-turn-gap-s", type=float, default=1.6,
|
||||||
|
help="multiturn: mean think-time between turns")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
if args.mode == "vanilla":
|
||||||
|
assert args.qps and args.input_len, "vanilla needs --qps and --input-len"
|
||||||
|
rows = gen_vanilla(
|
||||||
|
qps=args.qps, duration_s=args.duration_s,
|
||||||
|
input_len=args.input_len, output_len=args.output_len, seed=args.seed,
|
||||||
|
)
|
||||||
|
cfg = {"mode": "vanilla", "qps": args.qps, "duration_s": args.duration_s,
|
||||||
|
"input_len": args.input_len, "output_len": args.output_len,
|
||||||
|
"seed": args.seed, "reuse": "none"}
|
||||||
|
else:
|
||||||
|
assert args.qps, "multiturn needs --qps (session rate)"
|
||||||
|
rows = gen_multiturn(
|
||||||
|
session_qps=args.qps, duration_s=args.duration_s,
|
||||||
|
turns_mean=args.turns_mean, turns_max=args.turns_max,
|
||||||
|
heavy_frac=args.heavy_frac, first_input=args.first_input,
|
||||||
|
new_user_tokens=args.new_user_tokens, output_len=args.output_len,
|
||||||
|
inter_turn_gap_s=args.inter_turn_gap_s, seed=args.seed,
|
||||||
|
)
|
||||||
|
cfg = {"mode": "multiturn", "session_qps": args.qps,
|
||||||
|
"duration_s": args.duration_s, "turns_mean": args.turns_mean,
|
||||||
|
"turns_max": args.turns_max, "heavy_frac": args.heavy_frac,
|
||||||
|
"first_input": args.first_input, "new_user_tokens": args.new_user_tokens,
|
||||||
|
"output_len": args.output_len, "inter_turn_gap_s": args.inter_turn_gap_s,
|
||||||
|
"seed": args.seed, "reuse": "intra-session"}
|
||||||
|
|
||||||
|
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with args.out.open("w", encoding="utf-8") as fh:
|
||||||
|
for r in rows:
|
||||||
|
fh.write(json.dumps(r) + "\n")
|
||||||
|
|
||||||
|
cfg["n_requests"] = len(rows)
|
||||||
|
cfg["block_size"] = BLOCK_SIZE
|
||||||
|
cfg_path = args.out.with_suffix(args.out.suffix + ".config.json")
|
||||||
|
cfg_path.write_text(json.dumps(cfg, indent=2))
|
||||||
|
|
||||||
|
span = rows[-1]["timestamp"] - rows[0]["timestamp"] if rows else 0.0
|
||||||
|
eff_qps = len(rows) / span if span > 0 else 0.0
|
||||||
|
print(f"wrote {len(rows)} requests to {args.out} (mode={args.mode})")
|
||||||
|
print(f" target qps={args.qps} effective req qps={eff_qps:.3f} span={span:.1f}s")
|
||||||
|
if args.mode == "vanilla":
|
||||||
|
print(f" input_len={args.input_len} output_len={args.output_len} "
|
||||||
|
f"(blocks/req={n_blocks_for(args.input_len)}, zero reuse)")
|
||||||
|
else:
|
||||||
|
n_sessions = len({r["session_id"] for r in rows})
|
||||||
|
inputs = sorted(r["input_length"] for r in rows)
|
||||||
|
p = lambda v, q: v[min(int(q * len(v)), len(v) - 1)] if v else 0
|
||||||
|
print(f" sessions={n_sessions} turns/session~{len(rows)/max(n_sessions,1):.1f} "
|
||||||
|
f"input p50={p(inputs,.5)} p90={p(inputs,.9)} p99={p(inputs,.99)} "
|
||||||
|
f"output_len={args.output_len} (intra-session reuse)")
|
||||||
|
print(f" config -> {cfg_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
136
microbench/fresh_setup/mutate_trace.py
Normal file
136
microbench/fresh_setup/mutate_trace.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""Reverse-ablation trace surgeon for the PD-disagg crossover study.
|
||||||
|
|
||||||
|
Takes the REAL agentic trace and neutralizes ONE agentic property at a time,
|
||||||
|
so we can see which removal restores PD-disagg to parity with colocation.
|
||||||
|
This is the subtractive complement to the additive synthetic sweep (D1-D5).
|
||||||
|
|
||||||
|
Neutralizations (compose freely; each defaults to off):
|
||||||
|
--set-output N set every output_length to N (kill short-output -> test decode-benefit starvation)
|
||||||
|
--max-input N clamp input_length to N, truncating hash_ids to ceil(N/512)
|
||||||
|
(kill huge-prefill -> test prefill-bound bottleneck)
|
||||||
|
--uniform-arrival respace requests evenly over the original span, order preserved
|
||||||
|
(kill burstiness -> test arrival variance)
|
||||||
|
--unique-hash replace all hash_ids with globally-unique ids
|
||||||
|
(kill intra-session reuse -> test cache/affinity)
|
||||||
|
--max-turns N keep only the first N turns of each session
|
||||||
|
(flatten session skew / heavy tail)
|
||||||
|
|
||||||
|
The schema is preserved so the replayer consumes the output unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
UNIQUE_HASH_BASE = 2_000_000_000
|
||||||
|
|
||||||
|
|
||||||
|
def load_rows(path: Path) -> list[dict]:
|
||||||
|
rows = []
|
||||||
|
with path.open() as fh:
|
||||||
|
for line in fh:
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
rows.append(json.loads(line))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_session(row: dict, chat_to_session: dict) -> str:
|
||||||
|
if "session_id" in row:
|
||||||
|
return str(row["session_id"])
|
||||||
|
cid, pid = int(row["chat_id"]), int(row["parent_chat_id"])
|
||||||
|
sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
|
||||||
|
chat_to_session[cid] = sid
|
||||||
|
return sid
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
p = argparse.ArgumentParser(description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
p.add_argument("--input", type=Path, required=True)
|
||||||
|
p.add_argument("--output", type=Path, required=True)
|
||||||
|
p.add_argument("--set-output", type=int, default=None)
|
||||||
|
p.add_argument("--max-input", type=int, default=None)
|
||||||
|
p.add_argument("--uniform-arrival", action="store_true")
|
||||||
|
p.add_argument("--unique-hash", action="store_true")
|
||||||
|
p.add_argument("--max-turns", type=int, default=None)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
rows = load_rows(args.input)
|
||||||
|
# ensure session_id present
|
||||||
|
c2s: dict = {}
|
||||||
|
for r in rows:
|
||||||
|
r["session_id"] = resolve_session(r, c2s)
|
||||||
|
|
||||||
|
applied = []
|
||||||
|
|
||||||
|
# --- flatten session skew: keep first N turns per session ---
|
||||||
|
if args.max_turns is not None:
|
||||||
|
kept_count: dict = defaultdict(int)
|
||||||
|
rows.sort(key=lambda r: (r["session_id"], r.get("turn", 0)))
|
||||||
|
out = []
|
||||||
|
for r in rows:
|
||||||
|
if kept_count[r["session_id"]] < args.max_turns:
|
||||||
|
kept_count[r["session_id"]] += 1
|
||||||
|
out.append(r)
|
||||||
|
rows = out
|
||||||
|
rows.sort(key=lambda r: r.get("timestamp", 0.0))
|
||||||
|
applied.append(f"max_turns={args.max_turns}")
|
||||||
|
|
||||||
|
# --- clamp input + truncate hash_ids ---
|
||||||
|
if args.max_input is not None:
|
||||||
|
for r in rows:
|
||||||
|
if r["input_length"] > args.max_input:
|
||||||
|
r["input_length"] = args.max_input
|
||||||
|
keep_blocks = max(1, math.ceil(r["input_length"] / BLOCK_SIZE))
|
||||||
|
r["hash_ids"] = list(r.get("hash_ids", []))[:keep_blocks]
|
||||||
|
applied.append(f"max_input={args.max_input}")
|
||||||
|
|
||||||
|
# --- set fixed output length ---
|
||||||
|
if args.set_output is not None:
|
||||||
|
for r in rows:
|
||||||
|
r["output_length"] = args.set_output
|
||||||
|
applied.append(f"set_output={args.set_output}")
|
||||||
|
|
||||||
|
# --- kill reuse: globally-unique hashes ---
|
||||||
|
if args.unique_hash:
|
||||||
|
nxt = UNIQUE_HASH_BASE
|
||||||
|
for r in rows:
|
||||||
|
n = max(1, len(r.get("hash_ids", [])) or
|
||||||
|
math.ceil(r["input_length"] / BLOCK_SIZE))
|
||||||
|
r["hash_ids"] = list(range(nxt, nxt + n))
|
||||||
|
nxt += n
|
||||||
|
applied.append("unique_hash")
|
||||||
|
|
||||||
|
# --- de-burst: uniform arrival over original span (order preserved) ---
|
||||||
|
if args.uniform_arrival:
|
||||||
|
rows.sort(key=lambda r: r.get("timestamp", 0.0))
|
||||||
|
ts = [r.get("timestamp", 0.0) for r in rows]
|
||||||
|
span = (ts[-1] - ts[0]) if len(ts) > 1 else 0.0
|
||||||
|
n = len(rows)
|
||||||
|
for i, r in enumerate(rows):
|
||||||
|
r["timestamp"] = round(ts[0] + (span * i / max(n - 1, 1)), 6)
|
||||||
|
applied.append("uniform_arrival")
|
||||||
|
|
||||||
|
rows.sort(key=lambda r: r.get("timestamp", 0.0))
|
||||||
|
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with args.output.open("w") as fh:
|
||||||
|
for r in rows:
|
||||||
|
fh.write(json.dumps(r) + "\n")
|
||||||
|
|
||||||
|
inputs = sorted(r["input_length"] for r in rows)
|
||||||
|
outs = sorted(r["output_length"] for r in rows)
|
||||||
|
q = lambda v, p: v[min(int(p * len(v)), len(v) - 1)] if v else 0
|
||||||
|
print(f"wrote {len(rows)} rows -> {args.output}")
|
||||||
|
print(f" neutralized: {applied or ['none (passthrough)']}")
|
||||||
|
print(f" input p50={q(inputs,.5)} p90={q(inputs,.9)} p99={q(inputs,.99)} "
|
||||||
|
f"output p50={q(outs,.5)} p90={q(outs,.9)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
115
microbench/fresh_setup/plot_crossover.py
Normal file
115
microbench/fresh_setup/plot_crossover.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
"""Render the PD-disagg crossover figure from analyze_goodput.py JSONs.
|
||||||
|
|
||||||
|
Two sweeps bracket the prefill<->decode bottleneck axis:
|
||||||
|
D2: fixed input, grow OUTPUT -> decode-bound -> PD_advantage rises above 1
|
||||||
|
D1: fixed output, grow INPUT -> prefill-bound -> PD_advantage falls below 1
|
||||||
|
|
||||||
|
Top row: PD_advantage (4P+4D / colo SLO-goodput) vs swept dim, y=1 = crossover.
|
||||||
|
Bottom row: completion rate, colo vs 4P+4D.
|
||||||
|
Agentic operating region (input p50~33k, output p50~92) annotated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
CROSS = Path("analysis/crossover")
|
||||||
|
OUT = Path("figs/crossover_pd_advantage.png")
|
||||||
|
|
||||||
|
|
||||||
|
def load_series(prefix: str, key_re: str):
|
||||||
|
pts = []
|
||||||
|
for f in sorted(CROSS.glob(f"{prefix}*_goodput.json")):
|
||||||
|
m = re.search(key_re, f.name)
|
||||||
|
if not m:
|
||||||
|
continue
|
||||||
|
d = json.loads(f.read_text())
|
||||||
|
g = d["slo_grid"][0]["arms"]
|
||||||
|
pts.append({
|
||||||
|
"x": int(m.group(1)),
|
||||||
|
"adv": g["4P+4D"]["pd_advantage"],
|
||||||
|
"colo_att": g["8C-proxy"]["attainment"],
|
||||||
|
"pd_att": g["4P+4D"]["attainment"],
|
||||||
|
"colo_compl": d["arms"]["8C-proxy"]["completion_rate"],
|
||||||
|
"pd_compl": d["arms"]["4P+4D"]["completion_rate"],
|
||||||
|
"colo_ampl": d["arms"]["8C-proxy"]["amplification"],
|
||||||
|
"pd_ampl": d["arms"]["4P+4D"]["amplification"],
|
||||||
|
})
|
||||||
|
pts.sort(key=lambda p: p["x"])
|
||||||
|
return pts
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
d2 = load_series("d2_o", r"d2_o(\d+)_") # x = output length
|
||||||
|
d1 = load_series("d1_i", r"d1_i(\d+)_") # x = input length
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(2, 2, figsize=(12, 8))
|
||||||
|
|
||||||
|
def adv_plot(a, pts, xlabel, title, agentic_x, agentic_label):
|
||||||
|
CAP = 50.0 # display cap for "colo=0, PD>0" (PD wins outright)
|
||||||
|
line_x, line_y = [], []
|
||||||
|
for p in pts:
|
||||||
|
adv, colo_att, pd_att = p["adv"], p["colo_att"], p["pd_att"]
|
||||||
|
finite = isinstance(adv, (int, float)) and adv == adv
|
||||||
|
if finite:
|
||||||
|
line_x.append(p["x"]); line_y.append(max(adv, 1e-3))
|
||||||
|
elif colo_att == 0 and pd_att > 0: # PD wins outright
|
||||||
|
line_x.append(p["x"]); line_y.append(CAP)
|
||||||
|
a.annotate("PD wins\n(colo=0)", (p["x"], CAP), fontsize=7, ha="center", va="top")
|
||||||
|
else: # both fail SLO -> not a PD win
|
||||||
|
a.scatter([p["x"]], [1.0], marker="x", color="gray", zorder=5)
|
||||||
|
a.annotate("both\nfail SLO", (p["x"], 1.0), fontsize=7, ha="center",
|
||||||
|
va="bottom", color="gray")
|
||||||
|
a.plot(line_x, line_y, "o-", color="tab:blue", lw=2)
|
||||||
|
a.axhline(1.0, color="k", ls="--", lw=1, label="crossover (PD=colo)")
|
||||||
|
a.set_xscale("log"); a.set_yscale("log")
|
||||||
|
a.set_xlabel(xlabel)
|
||||||
|
a.set_ylabel("PD_advantage = goodput(4P+4D)/goodput(colo)")
|
||||||
|
a.set_title(title)
|
||||||
|
a.axvline(agentic_x, color="tab:red", ls=":", lw=1.5)
|
||||||
|
a.annotate(agentic_label, (agentic_x, 1.2), color="tab:red", fontsize=8,
|
||||||
|
rotation=90, va="bottom", ha="right")
|
||||||
|
a.legend(fontsize=8, loc="best")
|
||||||
|
a.grid(True, which="both", alpha=0.3)
|
||||||
|
|
||||||
|
def compl_plot(a, pts, xlabel, title):
|
||||||
|
xs = [p["x"] for p in pts]
|
||||||
|
a.plot(xs, [100*p["colo_compl"] for p in pts], "s-", color="tab:orange", label="colo (8C-proxy)")
|
||||||
|
a.plot(xs, [100*p["pd_compl"] for p in pts], "o-", color="tab:blue", label="PD (4P+4D)")
|
||||||
|
a.set_xscale("log")
|
||||||
|
a.set_xlabel(xlabel)
|
||||||
|
a.set_ylabel("completion rate (%)")
|
||||||
|
a.set_title(title)
|
||||||
|
a.set_ylim(0, 105)
|
||||||
|
a.legend(fontsize=8, loc="best")
|
||||||
|
a.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
adv_plot(ax[0][0], d2, "output length (tokens), fixed input=2048, q12",
|
||||||
|
"D2 decode-bound sweep — PD wins as output grows",
|
||||||
|
92, "agentic out~92")
|
||||||
|
adv_plot(ax[0][1], d1, "input length (tokens), fixed output=64, q4",
|
||||||
|
"D1 prefill-bound sweep — PD collapses as input grows",
|
||||||
|
33533, "agentic in~33k")
|
||||||
|
compl_plot(ax[1][0], d2, "output length (tokens)", "D2 completion")
|
||||||
|
compl_plot(ax[1][1], d1, "input length (tokens)", "D1 completion")
|
||||||
|
|
||||||
|
fig.suptitle("PD-disaggregation vs colocation: the prefill<->decode bottleneck crossover\n"
|
||||||
|
"(single node 8xH20, vLLM 0.18.1 chunked-prefill; zero-reuse synthetic)",
|
||||||
|
fontsize=11)
|
||||||
|
fig.tight_layout(rect=[0, 0, 1, 0.96])
|
||||||
|
OUT.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fig.savefig(OUT, dpi=130)
|
||||||
|
print(f"wrote {OUT}")
|
||||||
|
# also dump the numeric series for the record
|
||||||
|
print("D2 (output -> adv):", [(p["x"], round(p["adv"],2) if p["adv"]==p["adv"] else "inf") for p in d2])
|
||||||
|
print("D1 (input -> adv):", [(p["x"], round(p["adv"],2) if p["adv"]==p["adv"] else "inf") for p in d1])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user