Fix sampler: window+thin preserves cross-session KV cache sharing
Random session sampling destroys cross-session hash block sharing (52% -> 16%) because sessions sharing system prompts get scattered. New approach: take a contiguous time window from the trace (preserving temporal locality of shared-prefix sessions), then thin within the window to hit target QPS. This preserves both intra-session reuse (62% of reusable tokens) and cross-session sharing (38%). Results (block sharing rate): Old random r=0.002: 16.0% -> Window+thin: 29.7% Old random r=0.016: 19.5% -> Window+thin: 42.7% Full trace baseline: 52% Also corrected the "91% intra-session" claim: actual split is 62% intra / 38% cross (token-level), making cross-session sharing preservation critical for valid APC benchmarks. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,25 +3,29 @@
|
|||||||
Preserves:
|
Preserves:
|
||||||
- Complete session structure (all turns within a session kept together)
|
- Complete session structure (all turns within a session kept together)
|
||||||
- Original arrival timing (re-zeroed to t=0 but NOT compressed)
|
- Original arrival timing (re-zeroed to t=0 but NOT compressed)
|
||||||
- hash_ids for KV cache reuse patterns
|
- KV cache reuse patterns (both intra-session AND cross-session sharing)
|
||||||
- Request type distribution
|
- Request type distribution
|
||||||
|
|
||||||
Sampling strategy:
|
Sampling strategy (--sample-ratio):
|
||||||
1. Group requests by session (derived from parent_chat_id chains)
|
1. Take a contiguous time window from the trace (all sessions whose
|
||||||
2. Randomly sample a fraction of sessions (--sample-ratio)
|
first request falls within the window). This preserves cross-session
|
||||||
OR sample until target request count (--target-requests)
|
hash block sharing because sessions that share system prompts appear
|
||||||
3. Re-zero timestamps so first event starts at t=0
|
together in the same time region.
|
||||||
4. The resulting QPS is proportional to the sample ratio,
|
2. Within the window, randomly thin sessions by ratio to control QPS.
|
||||||
preserving the production arrival pattern
|
3. Re-zero timestamps so first event starts at t=0.
|
||||||
|
|
||||||
|
The window is sized so that (window_sessions * thin_ratio) ≈ target count.
|
||||||
|
Thin ratio is set high enough (≥0.5) to keep cross-session block sharing
|
||||||
|
intact; the window width is narrowed to compensate.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
# Sample 1.6% of sessions (e.g., 8 GPUs / 500 cluster GPUs)
|
# Sample for 8 GPUs from a ~500-GPU cluster
|
||||||
python scripts/sample_trace.py \\
|
python scripts/sample_trace.py \\
|
||||||
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
|
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
|
||||||
--output traces/sampled_ratio016.jsonl \\
|
--output traces/sampled.jsonl \\
|
||||||
--sample-ratio 0.016 --seed 42
|
--sample-ratio 0.016 --seed 42
|
||||||
|
|
||||||
# Sample by request count (legacy)
|
# Sample by request count (legacy, no sharing preservation)
|
||||||
python scripts/sample_trace.py \\
|
python scripts/sample_trace.py \\
|
||||||
--input ... --output ... --target-requests 1000 --seed 42
|
--input ... --output ... --target-requests 1000 --seed 42
|
||||||
"""
|
"""
|
||||||
@@ -32,7 +36,6 @@ import argparse
|
|||||||
import collections
|
import collections
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -68,16 +71,15 @@ def sample_sessions(
|
|||||||
target_requests: int | None = None,
|
target_requests: int | None = None,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Select sessions by ratio or until target request count."""
|
"""Sample sessions preserving KV cache reuse."""
|
||||||
all_sids = list(rows_by_session.keys())
|
|
||||||
rng = random.Random(seed)
|
rng = random.Random(seed)
|
||||||
rng.shuffle(all_sids)
|
|
||||||
|
|
||||||
if sample_ratio is not None:
|
if sample_ratio is not None:
|
||||||
n_select = max(1, int(len(all_sids) * sample_ratio))
|
return _sample_window_then_thin(rows_by_session, sample_ratio, rng)
|
||||||
return all_sids[:n_select]
|
|
||||||
|
|
||||||
if target_requests is not None:
|
if target_requests is not None:
|
||||||
|
all_sids = list(rows_by_session.keys())
|
||||||
|
rng.shuffle(all_sids)
|
||||||
selected = []
|
selected = []
|
||||||
total = 0
|
total = 0
|
||||||
for sid in all_sids:
|
for sid in all_sids:
|
||||||
@@ -90,6 +92,59 @@ def sample_sessions(
|
|||||||
raise ValueError("Must specify --sample-ratio or --target-requests")
|
raise ValueError("Must specify --sample-ratio or --target-requests")
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_window_then_thin(
|
||||||
|
rows_by_session: dict[str, list[dict]],
|
||||||
|
ratio: float,
|
||||||
|
rng: random.Random,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Window + thin sampling that preserves cross-session sharing.
|
||||||
|
|
||||||
|
1. Compute first-request timestamp for each session.
|
||||||
|
2. Pick a contiguous time window sized so that
|
||||||
|
window_sessions * thin_ratio ≈ total_sessions * ratio.
|
||||||
|
thin_ratio is kept >= 0.5 to preserve cross-session sharing.
|
||||||
|
3. Randomly drop (1 - thin_ratio) of sessions within the window.
|
||||||
|
"""
|
||||||
|
# Session start times (timestamp of first request)
|
||||||
|
session_starts: list[tuple[float, str]] = []
|
||||||
|
for sid, rows in rows_by_session.items():
|
||||||
|
t0 = min(float(r["timestamp"]) for r in rows)
|
||||||
|
session_starts.append((t0, sid))
|
||||||
|
session_starts.sort()
|
||||||
|
|
||||||
|
total_sessions = len(session_starts)
|
||||||
|
target_n = max(1, int(total_sessions * ratio))
|
||||||
|
|
||||||
|
# Determine thin_ratio and window size
|
||||||
|
# thin_ratio >= 0.5 to preserve sharing; prefer 1.0 if window fits
|
||||||
|
# window_sessions = target_n / thin_ratio
|
||||||
|
thin_ratio = min(1.0, max(0.5, ratio * 10))
|
||||||
|
window_sessions = int(target_n / thin_ratio)
|
||||||
|
window_sessions = min(window_sessions, total_sessions)
|
||||||
|
|
||||||
|
# Pick window start: random position in the trace
|
||||||
|
max_start = total_sessions - window_sessions
|
||||||
|
if max_start <= 0:
|
||||||
|
window_start = 0
|
||||||
|
else:
|
||||||
|
window_start = rng.randint(0, max_start)
|
||||||
|
|
||||||
|
window_sids = [sid for _, sid in session_starts[window_start:window_start + window_sessions]]
|
||||||
|
|
||||||
|
# Thin within window
|
||||||
|
if thin_ratio >= 1.0:
|
||||||
|
selected = window_sids
|
||||||
|
else:
|
||||||
|
selected = [sid for sid in window_sids if rng.random() < thin_ratio]
|
||||||
|
|
||||||
|
# Ensure we don't overshoot target by too much
|
||||||
|
if len(selected) > target_n * 1.2:
|
||||||
|
rng.shuffle(selected)
|
||||||
|
selected = selected[:int(target_n * 1.1)]
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
|
||||||
def build_output(
|
def build_output(
|
||||||
rows_by_session: dict[str, list[dict]],
|
rows_by_session: dict[str, list[dict]],
|
||||||
selected: list[str],
|
selected: list[str],
|
||||||
@@ -130,17 +185,13 @@ def print_summary(
|
|||||||
span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0
|
span_s = float(out_rows[-1]["timestamp"]) if out_rows else 0
|
||||||
qps = n_requests / span_s if span_s > 0 else 0
|
qps = n_requests / span_s if span_s > 0 else 0
|
||||||
|
|
||||||
session_starts = {}
|
# Hash block sharing
|
||||||
|
block_freq: dict[int, int] = collections.Counter()
|
||||||
for r in out_rows:
|
for r in out_rows:
|
||||||
sid = r["session_id"]
|
for h in r.get("hash_ids", []):
|
||||||
ts = float(r["timestamp"])
|
block_freq[h] += 1
|
||||||
if sid not in session_starts:
|
total_blocks = len(block_freq)
|
||||||
session_starts[sid] = ts
|
shared_blocks = sum(1 for c in block_freq.values() if c > 1)
|
||||||
|
|
||||||
# hash_ids overlap
|
|
||||||
all_hashes = set()
|
|
||||||
for r in out_rows:
|
|
||||||
all_hashes.update(r.get("hash_ids", []))
|
|
||||||
|
|
||||||
print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
|
print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
|
||||||
print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
|
print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
|
||||||
@@ -152,7 +203,7 @@ def print_summary(
|
|||||||
f"avg={sum(output_lens)/len(output_lens):.0f}")
|
f"avg={sum(output_lens)/len(output_lens):.0f}")
|
||||||
print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
|
print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
|
||||||
print(f" QPS: {qps:.2f} req/s")
|
print(f" QPS: {qps:.2f} req/s")
|
||||||
print(f" Unique hash blocks: {len(all_hashes)}")
|
print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_blocks*100/total_blocks:.1f}%)")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
@@ -165,7 +216,7 @@ def main() -> None:
|
|||||||
p.add_argument("--sample-ratio", type=float, default=None,
|
p.add_argument("--sample-ratio", type=float, default=None,
|
||||||
help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)")
|
help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)")
|
||||||
p.add_argument("--target-requests", type=int, default=None,
|
p.add_argument("--target-requests", type=int, default=None,
|
||||||
help="Target number of requests (legacy, stops after session that crosses it)")
|
help="Target number of requests (legacy, no sharing preservation)")
|
||||||
p.add_argument("--seed", type=int, default=42)
|
p.add_argument("--seed", type=int, default=42)
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user