gen_synthetic_trace (vanilla Poisson, zero prefix reuse — the regime where PD-disagg is expected to win), mutate_trace (morph reuse/burst/skew toward the agentic regime), and plot_crossover. Emits the replayer's JSONL schema. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
"""Reverse-ablation trace surgeon for the PD-disagg crossover study.

Takes the REAL agentic trace and neutralizes ONE agentic property at a time,
so we can see which removal restores PD-disagg to parity with colocation.
This is the subtractive complement to the additive synthetic sweep (D1-D5).

Neutralizations (compose freely; each defaults to off). Whatever the order on
the command line, they are applied as: max-turns, max-input, set-output,
unique-hash, uniform-arrival -- so e.g. --unique-hash sees hash_ids that
--max-input has already truncated.

  --set-output N     set every output_length to N
                     (kill short-output -> test decode-benefit starvation)
  --max-input N      clamp input_length to at most N; requests that get clamped
                     keep only their first ceil(N/512) hash_ids (at least one)
                     (kill huge-prefill -> test prefill-bound bottleneck)
  --uniform-arrival  respace requests evenly over the original span, order preserved
                     (kill burstiness -> test arrival variance)
  --unique-hash      replace all hash_ids with globally-unique ids, keeping each
                     request's block count (at least one)
                     (kill intra-session reuse -> test cache/affinity)
  --max-turns N      keep only the first N turns of each session
                     (flatten session skew / heavy tail)

The schema is preserved so the replayer consumes the output unchanged; the only
addition is that every row carries a string session_id (derived from the
chat_id/parent_chat_id tree when the input trace has none).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
# Tokens covered by one hash_ids entry: a request with L input tokens spans
# ceil(L / BLOCK_SIZE) hash blocks (used by --max-input and --unique-hash).
BLOCK_SIZE: int = 512

# First id handed out by --unique-hash; fresh ids count upward from here.
# NOTE(review): presumably chosen to sit above the ids found in real traces so
# rewritten ids are easy to recognize -- confirm.
UNIQUE_HASH_BASE: int = 2 * 10**9
|
|
|
|
|
|
def load_rows(path: Path) -> list[dict]:
    """Read a JSONL trace into a list of row dicts, one per non-blank line.

    Each line is stripped of surrounding whitespace; lines that are empty after
    stripping are skipped. The file is decoded as UTF-8 explicitly (JSON text is
    UTF-8 by spec) instead of the platform's locale-dependent default encoding.

    Args:
        path: location of the JSONL file.

    Returns:
        The parsed rows, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with path.open(encoding="utf-8") as fh:
        return [json.loads(line) for line in map(str.strip, fh) if line]
|
|
|
|
|
|
def resolve_session(row: dict, chat_to_session: dict) -> str:
    """Return *row*'s session id as a string.

    An explicit ``session_id`` key wins and is returned (stringified) without
    touching *chat_to_session*. Otherwise the session comes from the chat tree:
    a row whose ``parent_chat_id`` is negative starts a session named after its
    own ``chat_id``; any other row inherits the session recorded for its parent,
    falling back to a session named after the parent id when the parent has not
    been recorded yet. The result is recorded under the row's ``chat_id`` so
    later children can resolve through it.
    """
    if "session_id" in row:
        return str(row["session_id"])

    chat_id = int(row["chat_id"])
    parent_id = int(row["parent_chat_id"])
    if parent_id < 0:
        # Root of a conversation tree: it names its own session.
        session = str(chat_id)
    else:
        session = chat_to_session.get(parent_id, str(parent_id))

    chat_to_session[chat_id] = session
    return session
|
|
|
|
|
|
def _positive_int(text: str) -> int:
    """argparse ``type=`` callable: parse *text* as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def _arrival(row: dict) -> float:
    """Sort key: a row's ``timestamp``, treating a missing one as 0.0."""
    return row.get("timestamp", 0.0)


def _keep_first_turns(rows: list[dict], max_turns: int) -> list[dict]:
    """Keep each session's first *max_turns* rows; return them in arrival order."""
    kept: dict = defaultdict(int)
    out = []
    # Order by turn within a session, then by arrival, so traces without an
    # explicit "turn" field still keep their earliest requests rather than
    # whatever happens to come first in the file.
    by_session = sorted(rows, key=lambda r: (r["session_id"], r.get("turn", 0), _arrival(r)))
    for r in by_session:
        sid = r["session_id"]
        if kept[sid] < max_turns:
            kept[sid] += 1
            out.append(r)
    out.sort(key=_arrival)
    return out


def _clamp_inputs(rows: list[dict], max_input: int) -> None:
    """Clamp input_length to *max_input* in place; truncate clamped rows' hash_ids."""
    keep_blocks = max(1, math.ceil(max_input / BLOCK_SIZE))
    for r in rows:
        if r["input_length"] > max_input:
            r["input_length"] = max_input
            # Keeps the leading entries so hash_ids matches the shorter input.
            r["hash_ids"] = list(r.get("hash_ids", []))[:keep_blocks]


def _set_outputs(rows: list[dict], output_length: int) -> None:
    """Overwrite every row's output_length in place."""
    for r in rows:
        r["output_length"] = output_length


def _make_hashes_unique(rows: list[dict]) -> None:
    """Replace every row's hash_ids in place with fresh, never-repeated ids."""
    nxt = UNIQUE_HASH_BASE
    for r in rows:
        # Same block count as before (derived from input_length when the row
        # has no hash_ids), never fewer than one.
        n = max(1, len(r.get("hash_ids", [])) or math.ceil(r["input_length"] / BLOCK_SIZE))
        r["hash_ids"] = list(range(nxt, nxt + n))
        nxt += n


def _respace_uniformly(rows: list[dict]) -> None:
    """Sort *rows* by arrival in place, then spread timestamps evenly over the original span."""
    rows.sort(key=_arrival)
    if not rows:
        return
    start = _arrival(rows[0])
    span = _arrival(rows[-1]) - start
    steps = max(len(rows) - 1, 1)
    for i, r in enumerate(rows):
        r["timestamp"] = round(start + span * i / steps, 6)


def _write_jsonl(rows: list[dict], path: Path) -> None:
    """Write *rows* to *path* as JSONL, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(r) + "\n" for r in rows)


def _quantile(values: list, q: float):
    """Pick the element at index int(q * len) of already-sorted *values* (0 if empty)."""
    if not values:
        return 0
    return values[min(int(q * len(values)), len(values) - 1)]


def _print_summary(rows: list[dict], path: Path, applied: list[str]) -> None:
    """Print the row count, applied neutralizations, and length percentiles."""
    inputs = sorted(r["input_length"] for r in rows)
    outs = sorted(r["output_length"] for r in rows)
    print(f"wrote {len(rows)} rows -> {path}")
    print(f" neutralized: {applied or ['none (passthrough)']}")
    print(f" input p50={_quantile(inputs, .5)} p90={_quantile(inputs, .9)} "
          f"p99={_quantile(inputs, .99)} "
          f"output p50={_quantile(outs, .5)} p90={_quantile(outs, .9)}")


def main() -> None:
    """CLI entry point: load a trace, apply the requested neutralizations, write it out.

    Neutralizations run in a fixed order -- max-turns, max-input, set-output,
    unique-hash, uniform-arrival -- regardless of command-line order, so e.g.
    --unique-hash sizes its fresh ids from hash_ids already truncated by
    --max-input. Every output row carries a string ``session_id`` and rows are
    written sorted by timestamp. Numeric options must be >= 1 (argparse exits
    with a usage error otherwise).
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--input", type=Path, required=True,
                   help="source JSONL trace")
    p.add_argument("--output", type=Path, required=True,
                   help="destination JSONL trace (parent directories are created)")
    p.add_argument("--set-output", type=_positive_int, default=None, metavar="N",
                   help="set every output_length to N")
    p.add_argument("--max-input", type=_positive_int, default=None, metavar="N",
                   help="clamp input_length to N and truncate clamped rows' hash_ids")
    p.add_argument("--uniform-arrival", action="store_true",
                   help="respace arrivals evenly over the original span")
    p.add_argument("--unique-hash", action="store_true",
                   help="replace hash_ids with globally-unique ids")
    p.add_argument("--max-turns", type=_positive_int, default=None, metavar="N",
                   help="keep only the first N turns of each session")
    args = p.parse_args()

    rows = load_rows(args.input)
    # Give every row an explicit string session_id. When derived from the chat
    # tree, a child only inherits its parent's session if the parent row
    # appears earlier in the file.
    chat_to_session: dict = {}
    for r in rows:
        r["session_id"] = resolve_session(r, chat_to_session)

    applied: list[str] = []

    if args.max_turns is not None:
        rows = _keep_first_turns(rows, args.max_turns)
        applied.append(f"max_turns={args.max_turns}")

    if args.max_input is not None:
        _clamp_inputs(rows, args.max_input)
        applied.append(f"max_input={args.max_input}")

    if args.set_output is not None:
        _set_outputs(rows, args.set_output)
        applied.append(f"set_output={args.set_output}")

    if args.unique_hash:
        _make_hashes_unique(rows)
        applied.append("unique_hash")

    if args.uniform_arrival:
        _respace_uniformly(rows)
        applied.append("uniform_arrival")

    rows.sort(key=_arrival)
    _write_jsonl(rows, args.output)
    _print_summary(rows, args.output, applied)


if __name__ == "__main__":
    main()
|