"""Sample sessions from the full cluster-scale trace to fit a single machine. Preserves: - Complete session structure (all turns within a session kept together) - Original arrival timing (re-zeroed to t=0 but NOT compressed) - KV cache reuse patterns (both intra-session AND cross-session sharing) - Request type distribution Sampling strategy (--sample-ratio): 1. Take a contiguous time window from the trace (all sessions whose first request falls within the window). This preserves cross-session hash block sharing because sessions that share system prompts appear together in the same time region. 2. Within the window, randomly thin sessions by ratio to control QPS. 3. Re-zero timestamps so first event starts at t=0. The window is sized so that (window_sessions * thin_ratio) ≈ target count. Thin ratio is set high enough (≥0.5) to keep cross-session block sharing intact; the window width is narrowed to compensate. Usage: # Sample for 8 GPUs from a ~500-GPU cluster python scripts/sample_trace.py \\ --input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\ --output traces/sampled.jsonl \\ --sample-ratio 0.016 --seed 42 # Sample by request count (legacy, no sharing preservation) python scripts/sample_trace.py \\ --input ... --output ... --target-requests 1000 --seed 42 """ from __future__ import annotations import argparse import collections import json import random from pathlib import Path def load_raw_rows(path: Path) -> dict[str, list[dict]]: """Load trace, group rows by resolved session_id. Preserve file order.""" chat_to_session: dict[int, str] = {} rows_by_session: dict[str, list[dict]] = collections.OrderedDict() with path.open("r", encoding="utf-8") as fh: for line in fh: row = json.loads(line) cid = int(row["chat_id"]) pid = int(row["parent_chat_id"]) if "session_id" in row: sid = str(row["session_id"]) elif pid < 0: sid = str(cid) else: sid = chat_to_session.get(pid, str(pid)) chat_to_session[cid] = sid row["_session_id"] = sid rows_by_session.setdefault(sid, []).append(row) return rows_by_session def sample_sessions( rows_by_session: dict[str, list[dict]], *, sample_ratio: float | None = None, target_requests: int | None = None, max_single_turn_ratio: float | None = None, window_seconds: float | None = None, seed: int, ) -> list[str]: """Sample sessions preserving KV cache reuse.""" rng = random.Random(seed) if sample_ratio is not None: selected = _sample_window_then_thin(rows_by_session, sample_ratio, window_seconds, rng) elif target_requests is not None: all_sids = list(rows_by_session.keys()) rng.shuffle(all_sids) selected = [] total = 0 for sid in all_sids: selected.append(sid) total += len(rows_by_session[sid]) if total >= target_requests: break else: raise ValueError("Must specify --sample-ratio or --target-requests") if max_single_turn_ratio is not None: selected = _cap_single_turn(rows_by_session, selected, max_single_turn_ratio, rng) return selected def _cap_single_turn( rows_by_session: dict[str, list[dict]], selected: list[str], max_ratio: float, rng: random.Random, ) -> list[str]: """Thin single-turn sessions so they are at most max_ratio of total sessions.""" multi = [s for s in selected if len(rows_by_session[s]) > 1] single = [s for s in selected if len(rows_by_session[s]) == 1] # max_ratio of TOTAL sessions should be single-turn # n_single / (n_single + n_multi) <= max_ratio # n_single <= max_ratio * n_multi / (1 - max_ratio) max_single = int(max_ratio * len(multi) / (1 - max_ratio)) if len(single) <= max_single: return selected rng.shuffle(single) return multi + single[:max_single] def _sample_window_then_thin( rows_by_session: dict[str, list[dict]], ratio: float, window_seconds: float | None, rng: random.Random, ) -> list[str]: """Window + thin sampling that preserves cross-session sharing. 1. Compute first-request timestamp for each session. 2. Pick a contiguous time window: - If --window-seconds given: use that duration, thin by ratio within it. - Otherwise: auto-size so window_sessions * thin_ratio ≈ target. 3. Keep all sessions whose first request falls within the window. 4. Randomly thin sessions within the window to hit target count. """ session_starts: list[tuple[float, str]] = [] for sid, rows in rows_by_session.items(): t0 = min(float(r["timestamp"]) for r in rows) session_starts.append((t0, sid)) session_starts.sort() total_sessions = len(session_starts) target_n = max(1, int(total_sessions * ratio)) trace_start = session_starts[0][0] trace_end = session_starts[-1][0] trace_duration = trace_end - trace_start if window_seconds is not None: # Fixed window: pick a random start, thin to hit target ratio max_start_t = trace_end - window_seconds if max_start_t <= trace_start: win_start_t = trace_start else: win_start_t = trace_start + rng.random() * (max_start_t - trace_start) win_end_t = win_start_t + window_seconds window_sids = [sid for t, sid in session_starts if win_start_t <= t <= win_end_t] # Thin to target if len(window_sids) > target_n: thin_ratio = target_n / len(window_sids) window_sids = [s for s in window_sids if rng.random() < thin_ratio] return window_sids # Auto-size window thin_ratio = min(1.0, max(0.5, ratio * 10)) window_sessions = min(int(target_n / thin_ratio), total_sessions) max_start = total_sessions - window_sessions window_start = rng.randint(0, max_start) if max_start > 0 else 0 window_sids = [sid for _, sid in session_starts[window_start:window_start + window_sessions]] if thin_ratio < 1.0: window_sids = [s for s in window_sids if rng.random() < thin_ratio] if len(window_sids) > target_n * 1.2: rng.shuffle(window_sids) window_sids = window_sids[:int(target_n * 1.1)] return window_sids def build_output( rows_by_session: dict[str, list[dict]], selected: list[str], ) -> list[dict]: """Build output rows with re-zeroed timestamps (no time compression).""" out_rows = [] for sid in selected: for row in rows_by_session[sid]: out = {k: v for k, v in row.items() if not k.startswith("_")} out["session_id"] = sid out_rows.append(out) out_rows.sort(key=lambda r: float(r["timestamp"])) if not out_rows: return out_rows t0 = float(out_rows[0]["timestamp"]) for row in out_rows: row["timestamp"] = float(row["timestamp"]) - t0 return out_rows def print_summary( rows_by_session: dict[str, list[dict]], selected: list[str], out_rows: list[dict], ) -> None: n_sessions = len(selected) n_requests = len(out_rows) turns_per_session = [len(rows_by_session[s]) for s in selected] multi_turn = sum(1 for t in turns_per_session if t > 1) input_lens = [r["input_length"] for r in out_rows] output_lens = [r["output_length"] for r in out_rows] span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0 qps = n_requests / span_s if span_s > 0 else 0 # Hash block sharing block_freq: dict[int, int] = collections.Counter() for r in out_rows: for h in r.get("hash_ids", []): block_freq[h] += 1 total_blocks = len(block_freq) shared_blocks = sum(1 for c in block_freq.values() if c > 1) print(f"Sampled: {n_sessions} sessions, {n_requests} requests") print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)") print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} " f"avg={sum(turns_per_session)/len(turns_per_session):.1f}") print(f" Input length: min={min(input_lens)} max={max(input_lens)} " f"avg={sum(input_lens)/len(input_lens):.0f}") print(f" Output length: min={min(output_lens)} max={max(output_lens)} " f"avg={sum(output_lens)/len(output_lens):.0f}") print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)") print(f" QPS: {qps:.2f} req/s") print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_blocks*100/total_blocks:.1f}%)") def main() -> None: p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) p.add_argument("--input", type=Path, required=True, help="Path to the full trace JSONL file") p.add_argument("--output", type=Path, required=True, help="Path to write sampled trace JSONL") p.add_argument("--sample-ratio", type=float, default=None, help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)") p.add_argument("--target-requests", type=int, default=None, help="Target number of requests (legacy, no sharing preservation)") p.add_argument("--max-single-turn-ratio", type=float, default=None, help="Cap single-turn sessions to this fraction of total (e.g. 0.3)") p.add_argument("--window-seconds", type=float, default=None, help="Time window duration in seconds (e.g. 600 for 10 min)") p.add_argument("--seed", type=int, default=42) args = p.parse_args() if args.sample_ratio is None and args.target_requests is None: p.error("Must specify --sample-ratio or --target-requests") print(f"Loading trace from {args.input} ...") rows_by_session = load_raw_rows(args.input) total_sessions = len(rows_by_session) total_requests = sum(len(v) for v in rows_by_session.values()) print(f"Full trace: {total_sessions} sessions, {total_requests} requests") selected = sample_sessions( rows_by_session, sample_ratio=args.sample_ratio, target_requests=args.target_requests, max_single_turn_ratio=args.max_single_turn_ratio, window_seconds=args.window_seconds, seed=args.seed, ) out_rows = build_output(rows_by_session, selected) print_summary(rows_by_session, selected, out_rows) args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w", encoding="utf-8") as fh: for row in out_rows: fh.write(json.dumps(row, ensure_ascii=False) + "\n") print(f"\nWrote {len(out_rows)} rows to {args.output}") if __name__ == "__main__": main()