Fix sampler: window+thin preserves cross-session KV cache sharing

Random session sampling destroys cross-session hash block sharing
(52% -> 16%) because sessions sharing system prompts get scattered.

New approach: take a contiguous time window from the trace (preserving
temporal locality of shared-prefix sessions), then thin within the
window to hit target QPS. This preserves both intra-session reuse
(62% of reusable tokens) and cross-session sharing (38%).

Results (block sharing rate):
  Old random r=0.002:  16.0%  ->  Window+thin: 29.7%
  Old random r=0.016:  19.5%  ->  Window+thin: 42.7%
  Full trace baseline: 52%

Also corrected the "91% intra-session" claim: actual split is
62% intra / 38% cross (token-level), making cross-session sharing
preservation critical for valid APC benchmarks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-23 14:03:12 +08:00
parent 4089ffd63f
commit 1e1e2e774d

View File

@@ -3,25 +3,29 @@
Preserves: Preserves:
- Complete session structure (all turns within a session kept together) - Complete session structure (all turns within a session kept together)
- Original arrival timing (re-zeroed to t=0 but NOT compressed) - Original arrival timing (re-zeroed to t=0 but NOT compressed)
- hash_ids for KV cache reuse patterns - KV cache reuse patterns (both intra-session AND cross-session sharing)
- Request type distribution - Request type distribution
Sampling strategy: Sampling strategy (--sample-ratio):
1. Group requests by session (derived from parent_chat_id chains) 1. Take a contiguous time window from the trace (all sessions whose
2. Randomly sample a fraction of sessions (--sample-ratio) first request falls within the window). This preserves cross-session
OR sample until target request count (--target-requests) hash block sharing because sessions that share system prompts appear
3. Re-zero timestamps so first event starts at t=0 together in the same time region.
4. The resulting QPS is proportional to the sample ratio, 2. Within the window, randomly thin sessions by ratio to control QPS.
preserving the production arrival pattern 3. Re-zero timestamps so first event starts at t=0.
The window is sized so that (window_sessions * thin_ratio) ≈ target count.
Thin ratio is set high enough (≥0.5) to keep cross-session block sharing
intact; the window width is narrowed to compensate.
Usage: Usage:
# Sample 1.6% of sessions (e.g., 8 GPUs / 500 cluster GPUs) # Sample for 8 GPUs from a ~500-GPU cluster
python scripts/sample_trace.py \\ python scripts/sample_trace.py \\
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\ --input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
--output traces/sampled_ratio016.jsonl \\ --output traces/sampled.jsonl \\
--sample-ratio 0.016 --seed 42 --sample-ratio 0.016 --seed 42
# Sample by request count (legacy) # Sample by request count (legacy, no sharing preservation)
python scripts/sample_trace.py \\ python scripts/sample_trace.py \\
--input ... --output ... --target-requests 1000 --seed 42 --input ... --output ... --target-requests 1000 --seed 42
""" """
@@ -32,7 +36,6 @@ import argparse
import collections import collections
import json import json
import random import random
import sys
from pathlib import Path from pathlib import Path
@@ -68,16 +71,15 @@ def sample_sessions(
target_requests: int | None = None, target_requests: int | None = None,
seed: int, seed: int,
) -> list[str]: ) -> list[str]:
"""Select sessions by ratio or until target request count.""" """Sample sessions preserving KV cache reuse."""
all_sids = list(rows_by_session.keys())
rng = random.Random(seed) rng = random.Random(seed)
rng.shuffle(all_sids)
if sample_ratio is not None: if sample_ratio is not None:
n_select = max(1, int(len(all_sids) * sample_ratio)) return _sample_window_then_thin(rows_by_session, sample_ratio, rng)
return all_sids[:n_select]
if target_requests is not None: if target_requests is not None:
all_sids = list(rows_by_session.keys())
rng.shuffle(all_sids)
selected = [] selected = []
total = 0 total = 0
for sid in all_sids: for sid in all_sids:
@@ -90,6 +92,59 @@ def sample_sessions(
raise ValueError("Must specify --sample-ratio or --target-requests") raise ValueError("Must specify --sample-ratio or --target-requests")
def _sample_window_then_thin(
rows_by_session: dict[str, list[dict]],
ratio: float,
rng: random.Random,
) -> list[str]:
"""Window + thin sampling that preserves cross-session sharing.
1. Compute first-request timestamp for each session.
2. Pick a contiguous time window sized so that
window_sessions * thin_ratio ≈ total_sessions * ratio.
thin_ratio is kept >= 0.5 to preserve cross-session sharing.
3. Randomly drop (1 - thin_ratio) of sessions within the window.
"""
# Session start times (timestamp of first request)
session_starts: list[tuple[float, str]] = []
for sid, rows in rows_by_session.items():
t0 = min(float(r["timestamp"]) for r in rows)
session_starts.append((t0, sid))
session_starts.sort()
total_sessions = len(session_starts)
target_n = max(1, int(total_sessions * ratio))
# Determine thin_ratio and window size
# thin_ratio >= 0.5 to preserve sharing; prefer 1.0 if window fits
# window_sessions = target_n / thin_ratio
thin_ratio = min(1.0, max(0.5, ratio * 10))
window_sessions = int(target_n / thin_ratio)
window_sessions = min(window_sessions, total_sessions)
# Pick window start: random position in the trace
max_start = total_sessions - window_sessions
if max_start <= 0:
window_start = 0
else:
window_start = rng.randint(0, max_start)
window_sids = [sid for _, sid in session_starts[window_start:window_start + window_sessions]]
# Thin within window
if thin_ratio >= 1.0:
selected = window_sids
else:
selected = [sid for sid in window_sids if rng.random() < thin_ratio]
# Ensure we don't overshoot target by too much
if len(selected) > target_n * 1.2:
rng.shuffle(selected)
selected = selected[:int(target_n * 1.1)]
return selected
def build_output( def build_output(
rows_by_session: dict[str, list[dict]], rows_by_session: dict[str, list[dict]],
selected: list[str], selected: list[str],
@@ -130,17 +185,13 @@ def print_summary(
span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0 span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0
qps = n_requests / span_s if span_s > 0 else 0 qps = n_requests / span_s if span_s > 0 else 0
session_starts = {} # Hash block sharing
block_freq: dict[int, int] = collections.Counter()
for r in out_rows: for r in out_rows:
sid = r["session_id"] for h in r.get("hash_ids", []):
ts = float(r["timestamp"]) block_freq[h] += 1
if sid not in session_starts: total_blocks = len(block_freq)
session_starts[sid] = ts shared_blocks = sum(1 for c in block_freq.values() if c > 1)
# hash_ids overlap
all_hashes = set()
for r in out_rows:
all_hashes.update(r.get("hash_ids", []))
print(f"Sampled: {n_sessions} sessions, {n_requests} requests") print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)") print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
@@ -152,7 +203,7 @@ def print_summary(
f"avg={sum(output_lens)/len(output_lens):.0f}") f"avg={sum(output_lens)/len(output_lens):.0f}")
print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)") print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
print(f" QPS: {qps:.2f} req/s") print(f" QPS: {qps:.2f} req/s")
print(f" Unique hash blocks: {len(all_hashes)}") print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_blocks*100/total_blocks:.1f}%)")
def main() -> None: def main() -> None:
@@ -165,7 +216,7 @@ def main() -> None:
p.add_argument("--sample-ratio", type=float, default=None, p.add_argument("--sample-ratio", type=float, default=None,
help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)") help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)")
p.add_argument("--target-requests", type=int, default=None, p.add_argument("--target-requests", type=int, default=None,
help="Target number of requests (legacy, stops after session that crosses it)") help="Target number of requests (legacy, no sharing preservation)")
p.add_argument("--seed", type=int, default=42) p.add_argument("--seed", type=int, default=42)
args = p.parse_args() args = p.parse_args()