diff --git a/scripts/sample_trace.py b/scripts/sample_trace.py index 0c76348..887d7ba 100644 --- a/scripts/sample_trace.py +++ b/scripts/sample_trace.py @@ -3,25 +3,29 @@ Preserves: - Complete session structure (all turns within a session kept together) - Original arrival timing (re-zeroed to t=0 but NOT compressed) - - hash_ids for KV cache reuse patterns + - KV cache reuse patterns (both intra-session AND cross-session sharing) - Request type distribution -Sampling strategy: - 1. Group requests by session (derived from parent_chat_id chains) - 2. Randomly sample a fraction of sessions (--sample-ratio) - OR sample until target request count (--target-requests) - 3. Re-zero timestamps so first event starts at t=0 - 4. The resulting QPS is proportional to the sample ratio, - preserving the production arrival pattern +Sampling strategy (--sample-ratio): + 1. Take a contiguous time window from the trace (all sessions whose + first request falls within the window). This preserves cross-session + hash block sharing because sessions that share system prompts appear + together in the same time region. + 2. Within the window, randomly thin sessions by ratio to control QPS. + 3. Re-zero timestamps so first event starts at t=0. + +The window is sized so that (window_sessions * thin_ratio) ≈ target count. +Thin ratio is set high enough (≥0.5) to keep cross-session block sharing +intact; the window width is narrowed to compensate. Usage: - # Sample 1.6% of sessions (e.g., 8 GPUs / 500 cluster GPUs) + # Sample for 8 GPUs from a ~500-GPU cluster python scripts/sample_trace.py \\ --input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\ - --output traces/sampled_ratio016.jsonl \\ + --output traces/sampled.jsonl \\ --sample-ratio 0.016 --seed 42 - # Sample by request count (legacy) + # Sample by request count (legacy, no sharing preservation) python scripts/sample_trace.py \\ --input ... --output ... --target-requests 1000 --seed 42 """ @@ -32,7 +36,6 @@ import argparse import collections import json import random -import sys from pathlib import Path @@ -68,16 +71,15 @@ def sample_sessions( target_requests: int | None = None, seed: int, ) -> list[str]: - """Select sessions by ratio or until target request count.""" - all_sids = list(rows_by_session.keys()) + """Sample sessions preserving KV cache reuse.""" rng = random.Random(seed) - rng.shuffle(all_sids) if sample_ratio is not None: - n_select = max(1, int(len(all_sids) * sample_ratio)) - return all_sids[:n_select] + return _sample_window_then_thin(rows_by_session, sample_ratio, rng) if target_requests is not None: + all_sids = list(rows_by_session.keys()) + rng.shuffle(all_sids) selected = [] total = 0 for sid in all_sids: @@ -90,6 +92,59 @@ def sample_sessions( raise ValueError("Must specify --sample-ratio or --target-requests") +def _sample_window_then_thin( + rows_by_session: dict[str, list[dict]], + ratio: float, + rng: random.Random, +) -> list[str]: + """Window + thin sampling that preserves cross-session sharing. + + 1. Compute first-request timestamp for each session. + 2. Pick a contiguous time window sized so that + window_sessions * thin_ratio ≈ total_sessions * ratio. + thin_ratio is kept >= 0.5 to preserve cross-session sharing. + 3. Randomly drop (1 - thin_ratio) of sessions within the window. + """ + # Session start times (timestamp of first request) + session_starts: list[tuple[float, str]] = [] + for sid, rows in rows_by_session.items(): + t0 = min(float(r["timestamp"]) for r in rows) + session_starts.append((t0, sid)) + session_starts.sort() + + total_sessions = len(session_starts) + target_n = max(1, int(total_sessions * ratio)) + + # Determine thin_ratio and window size + # thin_ratio >= 0.5 to preserve sharing; prefer 1.0 if window fits + # window_sessions = target_n / thin_ratio + thin_ratio = min(1.0, max(0.5, ratio * 10)) + window_sessions = int(target_n / thin_ratio) + window_sessions = min(window_sessions, total_sessions) + + # Pick window start: random position in the trace + max_start = total_sessions - window_sessions + if max_start <= 0: + window_start = 0 + else: + window_start = rng.randint(0, max_start) + + window_sids = [sid for _, sid in session_starts[window_start:window_start + window_sessions]] + + # Thin within window + if thin_ratio >= 1.0: + selected = window_sids + else: + selected = [sid for sid in window_sids if rng.random() < thin_ratio] + + # Ensure we don't overshoot target by too much + if len(selected) > target_n * 1.2: + rng.shuffle(selected) + selected = selected[:int(target_n * 1.1)] + + return selected + + def build_output( rows_by_session: dict[str, list[dict]], selected: list[str], @@ -130,17 +185,13 @@ def print_summary( span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0 qps = n_requests / span_s if span_s > 0 else 0 - session_starts = {} + # Hash block sharing + block_freq: dict[int, int] = collections.Counter() for r in out_rows: - sid = r["session_id"] - ts = float(r["timestamp"]) - if sid not in session_starts: - session_starts[sid] = ts - - # hash_ids overlap - all_hashes = set() - for r in out_rows: - all_hashes.update(r.get("hash_ids", [])) + for h in r.get("hash_ids", []): + block_freq[h] += 1 + total_blocks = len(block_freq) + shared_blocks = sum(1 for c in block_freq.values() if c > 1) print(f"Sampled: {n_sessions} sessions, {n_requests} requests") print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)") @@ -152,7 +203,7 @@ def print_summary( f"avg={sum(output_lens)/len(output_lens):.0f}") print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)") print(f" QPS: {qps:.2f} req/s") - print(f" Unique hash blocks: {len(all_hashes)}") + print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_blocks*100/total_blocks:.1f}%)") def main() -> None: @@ -165,7 +216,7 @@ def main() -> None: p.add_argument("--sample-ratio", type=float, default=None, help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)") p.add_argument("--target-requests", type=int, default=None, - help="Target number of requests (legacy, stops after session that crosses it)") + help="Target number of requests (legacy, no sharing preservation)") p.add_argument("--seed", type=int, default=42) args = p.parse_args()