"""Cap per-session turn count to isolate session-mass effects in B3.

Input trace is grouped by session_id (or reconstructed from parent_chat_id
chains). Sessions with more than --max-turns turns are truncated to keep only
their first N turns, ordered by the ``turn`` field when present and by trace
order otherwise.

The output preserves the original line ordering (timestamp order); rows with
equal or missing timestamps keep their relative input order.
"""

from __future__ import annotations

import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path


def _resolve_session_id(row: dict, chat_to_session: dict[int, str]) -> str:
    """Return the session id for *row*, reconstructing it when absent.

    A row carrying an explicit ``session_id`` is returned as-is (stringified)
    and is NOT recorded in *chat_to_session*. Otherwise the session is derived
    from the ``chat_id`` / ``parent_chat_id`` chain: a negative parent marks
    the start of a new session named after the chat itself; any other parent
    inherits the session already recorded for it, falling back to the parent's
    own id when the parent has not been seen.

    Args:
        row: One parsed trace record.
        chat_to_session: Mutable chat_id -> session id map; updated in place
            for every row whose session had to be reconstructed.

    Returns:
        The session id as a string.

    Raises:
        KeyError: If the row has neither ``session_id`` nor both ``chat_id``
            and ``parent_chat_id``.
    """
    if "session_id" in row:
        return str(row["session_id"])
    cid = int(row["chat_id"])
    pcid = int(row["parent_chat_id"])
    if pcid < 0:
        sid = str(cid)
    else:
        sid = chat_to_session.get(pcid, str(pcid))
    # Remember this chat so its children resolve to the same session.
    chat_to_session[cid] = sid
    return sid


def _cap_rows(
    rows: list[tuple[str, dict]], max_turns: int
) -> tuple[list[dict], dict[str, int]]:
    """Keep at most *max_turns* earliest turns per session.

    Args:
        rows: ``(session_id, record)`` pairs in input (trace) order.
        max_turns: Per-session cap; values below zero behave like zero.

    Returns:
        A pair ``(capped_rows, kept)`` where ``capped_rows`` is the surviving
        records ordered by timestamp (ties and missing timestamps fall back to
        input order) and ``kept`` maps every session id to its kept-row count.
    """
    positions: dict[str, list[int]] = defaultdict(list)
    for idx, (sid, _) in enumerate(rows):
        positions[sid].append(idx)

    # A negative slice bound would keep rows from the front; clamp instead so
    # a non-positive cap drops everything.
    limit = max(max_turns, 0)
    kept: dict[str, int] = {}
    survivors: list[int] = []
    for sid, indices in positions.items():
        # Stable sort: rows without a "turn" field stay in trace order.
        indices.sort(key=lambda i: rows[i][1].get("turn", 0))
        chosen = indices[:limit]
        kept[sid] = len(chosen)
        survivors.extend(chosen)

    # The input index breaks timestamp ties so that rows with equal or absent
    # timestamps come out in their original line order rather than grouped
    # by session.
    survivors.sort(key=lambda i: (rows[i][1].get("timestamp", 0.0), i))
    return [rows[i][1] for i in survivors], kept


def main(argv: list[str] | None = None) -> None:
    """Read a JSONL trace, cap turns per session, and write the result.

    Every output record carries a ``session_id`` field (the explicit or
    reconstructed one). Summary statistics are printed to stdout.

    Args:
        argv: Command-line arguments; ``None`` means use ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(description="Cap per-session turn count")
    p.add_argument("--input", type=Path, required=True)
    p.add_argument("--output", type=Path, required=True)
    p.add_argument("--max-turns", type=int, default=8,
                   help="Keep at most N earliest turns per session")
    args = p.parse_args(argv)

    chat_to_session: dict[int, str] = {}
    rows: list[tuple[str, dict]] = []
    with args.input.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            sid = _resolve_session_id(row, chat_to_session)
            row["session_id"] = sid
            rows.append((sid, row))

    in_n = len(rows)
    sessions = len({sid for sid, _ in rows})

    capped_rows, kept = _cap_rows(rows, args.max_turns)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(r) + "\n" for r in capped_rows)

    print(f"input rows: {in_n}, sessions: {sessions}")
    print(f"capped rows: {len(capped_rows)} (max_turns={args.max_turns})")
    dropped = in_n - len(capped_rows)
    print(f"dropped: {dropped} ({100 * dropped / max(in_n, 1):.1f}%)")
    if capped_rows:
        turns_dist = Counter(kept.values())
        top = sorted(turns_dist.items())[:6]
        print(f"turns/session (capped) sample: {top}")


if __name__ == "__main__":
    main()