Random session sampling destroys cross-session hash block sharing (52% -> 16%) because sessions sharing system prompts get scattered. New approach: take a contiguous time window from the trace (preserving temporal locality of shared-prefix sessions), then thin within the window to hit target QPS. This preserves both intra-session reuse (62% of reusable tokens) and cross-session sharing (38%). Results (block sharing rate): Old random r=0.002: 16.0% -> Window+thin: 29.7% Old random r=0.016: 19.5% -> Window+thin: 42.7% Full trace baseline: 52% Also corrected the "91% intra-session" claim: actual split is 62% intra / 38% cross (token-level), making cross-session sharing preservation critical for valid APC benchmarks. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
252 lines
8.9 KiB
Python
252 lines
8.9 KiB
Python
"""Sample sessions from the full cluster-scale trace to fit a single machine.
|
|
|
|
Preserves:
|
|
- Complete session structure (all turns within a session kept together)
|
|
- Original arrival timing (re-zeroed to t=0 but NOT compressed)
|
|
- KV cache reuse patterns (both intra-session AND cross-session sharing)
|
|
- Request type distribution
|
|
|
|
Sampling strategy (--sample-ratio):
|
|
1. Take a contiguous time window from the trace (all sessions whose
|
|
first request falls within the window). This preserves cross-session
|
|
hash block sharing because sessions that share system prompts appear
|
|
together in the same time region.
|
|
2. Within the window, randomly thin sessions by ratio to control QPS.
|
|
3. Re-zero timestamps so first event starts at t=0.
|
|
|
|
The window is sized so that (window_sessions * thin_ratio) ≈ target count.
|
|
Thin ratio is set high enough (≥0.5) to keep cross-session block sharing
|
|
intact; the window width is narrowed to compensate.
|
|
|
|
Usage:
|
|
# Sample for 8 GPUs from a ~500-GPU cluster
|
|
python scripts/sample_trace.py \\
|
|
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
|
|
--output traces/sampled.jsonl \\
|
|
--sample-ratio 0.016 --seed 42
|
|
|
|
# Sample by request count (legacy, no sharing preservation)
|
|
python scripts/sample_trace.py \\
|
|
--input ... --output ... --target-requests 1000 --seed 42
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import collections
|
|
import json
|
|
import random
|
|
from pathlib import Path
|
|
|
|
|
|
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
    """Load a JSONL trace and group its rows by resolved session id.

    Each row's session id is resolved by the first rule that applies:

    1. an explicit ``session_id`` field (stringified);
    2. for a root request (``parent_chat_id < 0``), its own ``chat_id``;
    3. otherwise the session already assigned to its parent chat, or -- if
       the parent has not been seen earlier in the file -- the parent's id.

    The resolved id is stored on each row as ``_session_id``. Sessions appear
    in the order they are first seen, and rows within a session keep their
    file order. Blank or whitespace-only lines are skipped.

    Args:
        path: JSONL file; every non-blank line must be a JSON object with
            integer-convertible ``chat_id`` and ``parent_chat_id`` fields.

    Returns:
        Mapping of session id -> list of that session's rows.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a row lacks ``chat_id`` or ``parent_chat_id``.
    """
    chat_to_session: dict[int, str] = {}
    rows_by_session: dict[str, list[dict]] = collections.OrderedDict()

    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            # json.loads raises on empty input, so tolerate blank lines
            # (e.g. an extra trailing newline at the end of the file).
            if not line.strip():
                continue
            row = json.loads(line)
            cid = int(row["chat_id"])
            pid = int(row["parent_chat_id"])

            if "session_id" in row:
                sid = str(row["session_id"])
            elif pid < 0:
                sid = str(cid)
            else:
                # Falling back to the parent id groups requests that share an
                # unseen parent into the same session.
                sid = chat_to_session.get(pid, str(pid))
            # Record every chat so later children can inherit its session.
            chat_to_session[cid] = sid

            row["_session_id"] = sid
            rows_by_session.setdefault(sid, []).append(row)

    return rows_by_session
|
|
|
|
|
|
def sample_sessions(
|
|
rows_by_session: dict[str, list[dict]],
|
|
*,
|
|
sample_ratio: float | None = None,
|
|
target_requests: int | None = None,
|
|
seed: int,
|
|
) -> list[str]:
|
|
"""Sample sessions preserving KV cache reuse."""
|
|
rng = random.Random(seed)
|
|
|
|
if sample_ratio is not None:
|
|
return _sample_window_then_thin(rows_by_session, sample_ratio, rng)
|
|
|
|
if target_requests is not None:
|
|
all_sids = list(rows_by_session.keys())
|
|
rng.shuffle(all_sids)
|
|
selected = []
|
|
total = 0
|
|
for sid in all_sids:
|
|
selected.append(sid)
|
|
total += len(rows_by_session[sid])
|
|
if total >= target_requests:
|
|
break
|
|
return selected
|
|
|
|
raise ValueError("Must specify --sample-ratio or --target-requests")
|
|
|
|
|
|
def _sample_window_then_thin(
|
|
rows_by_session: dict[str, list[dict]],
|
|
ratio: float,
|
|
rng: random.Random,
|
|
) -> list[str]:
|
|
"""Window + thin sampling that preserves cross-session sharing.
|
|
|
|
1. Compute first-request timestamp for each session.
|
|
2. Pick a contiguous time window sized so that
|
|
window_sessions * thin_ratio ≈ total_sessions * ratio.
|
|
thin_ratio is kept >= 0.5 to preserve cross-session sharing.
|
|
3. Randomly drop (1 - thin_ratio) of sessions within the window.
|
|
"""
|
|
# Session start times (timestamp of first request)
|
|
session_starts: list[tuple[float, str]] = []
|
|
for sid, rows in rows_by_session.items():
|
|
t0 = min(float(r["timestamp"]) for r in rows)
|
|
session_starts.append((t0, sid))
|
|
session_starts.sort()
|
|
|
|
total_sessions = len(session_starts)
|
|
target_n = max(1, int(total_sessions * ratio))
|
|
|
|
# Determine thin_ratio and window size
|
|
# thin_ratio >= 0.5 to preserve sharing; prefer 1.0 if window fits
|
|
# window_sessions = target_n / thin_ratio
|
|
thin_ratio = min(1.0, max(0.5, ratio * 10))
|
|
window_sessions = int(target_n / thin_ratio)
|
|
window_sessions = min(window_sessions, total_sessions)
|
|
|
|
# Pick window start: random position in the trace
|
|
max_start = total_sessions - window_sessions
|
|
if max_start <= 0:
|
|
window_start = 0
|
|
else:
|
|
window_start = rng.randint(0, max_start)
|
|
|
|
window_sids = [sid for _, sid in session_starts[window_start:window_start + window_sessions]]
|
|
|
|
# Thin within window
|
|
if thin_ratio >= 1.0:
|
|
selected = window_sids
|
|
else:
|
|
selected = [sid for sid in window_sids if rng.random() < thin_ratio]
|
|
|
|
# Ensure we don't overshoot target by too much
|
|
if len(selected) > target_n * 1.2:
|
|
rng.shuffle(selected)
|
|
selected = selected[:int(target_n * 1.1)]
|
|
|
|
return selected
|
|
|
|
|
|
def build_output(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
) -> list[dict]:
    """Flatten the selected sessions into time-ordered output rows.

    Every row of every selected session is copied without its private
    (underscore-prefixed) keys and stamped with ``session_id``. The copies are
    stably sorted by ``timestamp`` and shifted so the earliest one is at 0.0.
    Relative spacing is kept as-is; nothing is compressed. Input rows are not
    modified.

    Args:
        rows_by_session: session id -> list of that session's rows.
        selected: ids of the sessions to emit.

    Returns:
        New list of row dicts with float, re-zeroed timestamps.
    """
    records = [
        {
            **{key: value for key, value in raw.items() if not key.startswith("_")},
            "session_id": sid,
        }
        for sid in selected
        for raw in rows_by_session[sid]
    ]
    records.sort(key=lambda rec: float(rec["timestamp"]))

    if records:
        origin = float(records[0]["timestamp"])
        for rec in records:
            rec["timestamp"] = float(rec["timestamp"]) - origin
    return records
|
|
|
|
|
|
def print_summary(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
    out_rows: list[dict],
) -> None:
    """Print statistics about the sampled trace to stdout.

    Reports session/request counts, turns per session, input/output length
    ranges, trace span, average QPS and hash-block sharing. A hash block is
    counted as shared when its id occurs more than once across all requests'
    ``hash_ids``.

    Args:
        rows_by_session: session id -> rows (used for turns per session).
        selected: selected session ids.
        out_rows: rows about to be written. Each needs ``timestamp``,
            ``input_length`` and ``output_length``; ``hash_ids`` is optional.
            The span is taken from the last row's timestamp, which assumes
            rows are sorted and re-zeroed as ``build_output`` produces them.
    """
    n_sessions = len(selected)
    n_requests = len(out_rows)
    print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
    if not selected or not out_rows:
        # Nothing more to report; the min()/average computations below would
        # raise on an empty sample.
        return

    turns_per_session = [len(rows_by_session[s]) for s in selected]
    multi_turn = sum(1 for t in turns_per_session if t > 1)

    input_lens = [r["input_length"] for r in out_rows]
    output_lens = [r["output_length"] for r in out_rows]

    span_s = float(out_rows[-1]["timestamp"])
    qps = n_requests / span_s if span_s > 0 else 0

    # Hash block sharing (a missing or null hash_ids contributes no blocks)
    block_freq = collections.Counter(
        h for r in out_rows for h in (r.get("hash_ids") or [])
    )
    total_blocks = len(block_freq)
    shared_blocks = sum(1 for c in block_freq.values() if c > 1)
    # Traces without hash_ids have zero blocks; avoid dividing by zero.
    shared_pct = shared_blocks * 100 / total_blocks if total_blocks else 0.0

    print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
    print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
          f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
    print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
          f"avg={sum(input_lens)/len(input_lens):.0f}")
    print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
          f"avg={sum(output_lens)/len(output_lens):.0f}")
    print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
    print(f" QPS: {qps:.2f} req/s")
    print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_pct:.1f}%)")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: load, sample, write and summarize a trace.

    Exits with a usage error (status 2) when neither or both of
    ``--sample-ratio`` / ``--target-requests`` are given, or when their
    values are out of range.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--input", type=Path, required=True,
                   help="Path to the full trace JSONL file")
    p.add_argument("--output", type=Path, required=True,
                   help="Path to write sampled trace JSONL")
    p.add_argument("--sample-ratio", type=float, default=None,
                   help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)")
    p.add_argument("--target-requests", type=int, default=None,
                   help="Target number of requests (legacy, no sharing preservation)")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    if args.sample_ratio is None and args.target_requests is None:
        p.error("Must specify --sample-ratio or --target-requests")
    # sample_sessions silently prefers --sample-ratio when both are set;
    # reject the ambiguous combination rather than ignore a flag.
    if args.sample_ratio is not None and args.target_requests is not None:
        p.error("Specify only one of --sample-ratio or --target-requests")
    if args.sample_ratio is not None and not 0 < args.sample_ratio <= 1:
        p.error("--sample-ratio must be in (0, 1]")
    if args.target_requests is not None and args.target_requests <= 0:
        p.error("--target-requests must be a positive integer")

    print(f"Loading trace from {args.input} ...")
    rows_by_session = load_raw_rows(args.input)
    total_sessions = len(rows_by_session)
    total_requests = sum(len(v) for v in rows_by_session.values())
    print(f"Full trace: {total_sessions} sessions, {total_requests} requests")

    selected = sample_sessions(
        rows_by_session,
        sample_ratio=args.sample_ratio,
        target_requests=args.target_requests,
        seed=args.seed,
    )

    out_rows = build_output(rows_by_session, selected)

    # Write the sample before summarizing so a failure while computing the
    # summary cannot lose the output file.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as fh:
        for row in out_rows:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")

    print_summary(rows_by_session, selected, out_rows)
    print(f"\nWrote {len(out_rows)} rows to {args.output}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|