Files
agentic-kvc/microbench/fresh_setup/mutate_trace.py
Gahow Wang bad512d3c5 PD-disagg crossover: synthetic-trace generator + morpher + plotter
gen_synthetic_trace (vanilla Poisson, zero prefix reuse — the regime where
PD-disagg is expected to win), mutate_trace (morph reuse/burst/skew toward
the agentic regime), and plot_crossover. Emits the replayer's JSONL schema.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-05-29 11:53:21 +08:00

137 lines
5.2 KiB
Python

"""Reverse-ablation trace surgeon for the PD-disagg crossover study.
Takes the REAL agentic trace and neutralizes ONE agentic property at a time,
so we can see which removal restores PD-disagg to parity with colocation.
This is the subtractive complement to the additive synthetic sweep (D1-D5).
Neutralizations (compose freely; each defaults to off):
--set-output N set every output_length to N (kill short-output -> test decode-benefit starvation)
--max-input N clamp input_length to N, truncating hash_ids to ceil(N/512)
(kill huge-prefill -> test prefill-bound bottleneck)
--uniform-arrival respace requests evenly over the original span, order preserved
(kill burstiness -> test arrival variance)
--unique-hash replace all hash_ids with globally-unique ids
(kill intra-session reuse -> test cache/affinity)
--max-turns N keep only the first N turns of each session
(flatten session skew / heavy tail)
The schema is preserved so the replayer consumes the output unchanged.
"""
from __future__ import annotations
import argparse
import json
import math
from collections import defaultdict
from pathlib import Path
# Granularity of one hash_ids entry: the code below sizes a row's hash_ids
# list as ceil(input_length / BLOCK_SIZE).
# NOTE(review): presumably input_length is measured in tokens — confirm.
BLOCK_SIZE = 512
# First id handed out by --unique-hash; ids are allocated counting upward
# from this value.
# NOTE(review): presumably chosen large to stay clear of the hash ids that
# occur in real traces — confirm against the trace generator.
UNIQUE_HASH_BASE = 2_000_000_000
def load_rows(path: Path) -> list[dict]:
    """Read a JSONL trace file and return its rows in file order.

    Every non-blank line is parsed with ``json.loads``; blank and
    whitespace-only lines are skipped.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    instead of the platform's locale encoding, so a trace containing
    non-ASCII text loads identically on every host.

    Args:
        path: Location of the JSONL file to read.

    Returns:
        One parsed object per non-blank line.

    Raises:
        OSError: If ``path`` cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with path.open(encoding="utf-8") as fh:
        stripped_lines = (raw.strip() for raw in fh)
        return [json.loads(text) for text in stripped_lines if text]
def resolve_session(row: dict, chat_to_session: dict) -> str:
    """Return the session id for ``row`` and record it for later child rows.

    Resolution rules:

    * If the row already carries ``session_id``, that value (as ``str``) wins.
    * Otherwise the row must carry ``chat_id`` and ``parent_chat_id``:
      a negative ``parent_chat_id`` marks a session root, whose session id is
      its own ``chat_id``; any other row inherits the session recorded for
      its parent, falling back to ``str(parent_chat_id)`` when the parent has
      not been seen.

    In both cases the row's ``chat_id`` (when usable) is stored in
    ``chat_to_session`` so that rows processed later can inherit from it.
    Rows must therefore be fed parent-before-child.

    Args:
        row: One trace row.
        chat_to_session: Mutable ``chat_id -> session id`` map shared across
            calls; updated in place.

    Returns:
        The session id as a string.

    Raises:
        KeyError: If the row has neither ``session_id`` nor both
            ``chat_id`` and ``parent_chat_id``.
    """
    if "session_id" in row:
        session = str(row["session_id"])
        # Register this chat too: in a trace that mixes rows with and without
        # session_id, a later child that only has parent_chat_id must inherit
        # this explicit session rather than fall back to str(parent id).
        chat_id = row.get("chat_id")
        if chat_id is not None:
            try:
                chat_to_session[int(chat_id)] = session
            except (TypeError, ValueError):
                # A non-numeric chat_id can never match an int parent key, so
                # there is nothing to record; the explicit session still holds.
                pass
        return session

    chat_id, parent_id = int(row["chat_id"]), int(row["parent_chat_id"])
    if parent_id < 0:
        session = str(chat_id)
    else:
        session = chat_to_session.get(parent_id, str(parent_id))
    chat_to_session[chat_id] = session
    return session
def _timestamp(row: dict) -> float:
    """Sort key: the row's timestamp, treating a missing one as 0.0."""
    return row.get("timestamp", 0.0)


def _keep_first_turns(rows: list[dict], max_turns: int) -> list[dict]:
    """Return at most ``max_turns`` earliest rows per session, time-ordered."""
    # Timestamp is the final tie-break so that "first" stays chronological
    # even when rows have no (or duplicate) "turn" values.
    ordered = sorted(
        rows,
        key=lambda r: (r["session_id"], r.get("turn", 0), _timestamp(r)),
    )
    kept_per_session: dict = defaultdict(int)
    kept = []
    for row in ordered:
        session = row["session_id"]
        if kept_per_session[session] < max_turns:
            kept_per_session[session] += 1
            kept.append(row)
    kept.sort(key=_timestamp)
    return kept


def _clamp_input(rows: list[dict], max_input: int) -> None:
    """Cap input_length at ``max_input`` in place, trimming hash_ids to match."""
    for row in rows:
        if row["input_length"] > max_input:
            row["input_length"] = max_input
            # Always keep at least one block, even for a very small clamp.
            keep_blocks = max(1, math.ceil(row["input_length"] / BLOCK_SIZE))
            row["hash_ids"] = list(row.get("hash_ids", []))[:keep_blocks]


def _assign_unique_hashes(rows: list[dict]) -> None:
    """Replace every row's hash_ids in place with ids no other row shares."""
    next_id = UNIQUE_HASH_BASE
    for row in rows:
        # Keep the existing block count; when the row has no hash_ids, derive
        # the count from input_length. Never emit fewer than one id.
        existing = len(row.get("hash_ids", []))
        count = max(1, existing or math.ceil(row["input_length"] / BLOCK_SIZE))
        row["hash_ids"] = list(range(next_id, next_id + count))
        next_id += count


def _respace_uniformly(rows: list[dict]) -> None:
    """Sort rows by time, then spread timestamps evenly over the same span."""
    rows.sort(key=_timestamp)
    if not rows:
        return
    stamps = [_timestamp(row) for row in rows]
    start = stamps[0]
    span = stamps[-1] - start if len(stamps) > 1 else 0.0
    steps = max(len(rows) - 1, 1)  # avoids division by zero for a single row
    for index, row in enumerate(rows):
        row["timestamp"] = round(start + (span * index / steps), 6)


def _quantile(sorted_values: list, fraction: float):
    """Return the element at ``fraction`` of a sorted list (0 if empty)."""
    if not sorted_values:
        return 0
    last = len(sorted_values) - 1
    return sorted_values[min(int(fraction * len(sorted_values)), last)]


def main() -> None:
    """Command-line entry point: load a trace, neutralize, write, summarize.

    Reads the JSONL trace given by ``--input``, fills in ``session_id`` on
    every row, applies the requested neutralizations in a fixed order
    (max-turns, max-input, set-output, unique-hash, uniform-arrival), writes
    the rows sorted by timestamp to ``--output`` as JSONL (creating parent
    directories as needed), and prints a short summary to stdout.

    Negative values for ``--set-output``, ``--max-input`` and ``--max-turns``
    are rejected through ``parser.error`` (usage message, exit status 2).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--set-output", type=int, default=None)
    parser.add_argument("--max-input", type=int, default=None)
    parser.add_argument("--uniform-arrival", action="store_true")
    parser.add_argument("--unique-hash", action="store_true")
    parser.add_argument("--max-turns", type=int, default=None)
    args = parser.parse_args()

    # A negative limit would silently drop every row or write negative
    # lengths into the trace, so fail fast instead.
    numeric_options = (
        ("--set-output", args.set_output),
        ("--max-input", args.max_input),
        ("--max-turns", args.max_turns),
    )
    for flag, value in numeric_options:
        if value is not None and value < 0:
            parser.error(f"{flag} must be non-negative, got {value}")

    rows = load_rows(args.input)

    # Ensure session_id is present on every row.
    chat_to_session: dict = {}
    for row in rows:
        row["session_id"] = resolve_session(row, chat_to_session)

    applied = []

    # --- flatten session skew: keep first N turns per session ---
    if args.max_turns is not None:
        rows = _keep_first_turns(rows, args.max_turns)
        applied.append(f"max_turns={args.max_turns}")

    # --- clamp input + truncate hash_ids ---
    if args.max_input is not None:
        _clamp_input(rows, args.max_input)
        applied.append(f"max_input={args.max_input}")

    # --- set fixed output length ---
    if args.set_output is not None:
        for row in rows:
            row["output_length"] = args.set_output
        applied.append(f"set_output={args.set_output}")

    # --- kill reuse: globally-unique hashes ---
    if args.unique_hash:
        _assign_unique_hashes(rows)
        applied.append("unique_hash")

    # --- de-burst: uniform arrival over original span (order preserved) ---
    if args.uniform_arrival:
        _respace_uniformly(rows)
        applied.append("uniform_arrival")

    rows.sort(key=_timestamp)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row) + "\n" for row in rows)

    inputs = sorted(row["input_length"] for row in rows)
    outs = sorted(row["output_length"] for row in rows)
    print(f"wrote {len(rows)} rows -> {args.output}")
    print(f" neutralized: {applied or ['none (passthrough)']}")
    print(f" input p50={_quantile(inputs, .5)} p90={_quantile(inputs, .9)} "
          f"p99={_quantile(inputs, .99)} "
          f"output p50={_quantile(outs, .5)} p90={_quantile(outs, .9)}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()