Updated methodology:
- Window+thin sampling preserves cross-session sharing (48% vs 16%)
- --max-single-turn-ratio 0.3 boosts multi-turn to 70%
- --window-seconds 600 for 10-min contiguous window
- Trace-driven replay (no session limit, no time compression)
- Daily config: --requests 850 (~13 min, APC~76%)

Key result: TPOT p90=0.175s (vs 0.073s in legacy 1-req/GPU setup), confirming that prefill-decode interference is real at production concurrency. APC 67.5% (vs 44%) from better KV reuse preservation.

Also fixed KV reuse breakdown: 62% intra-session / 38% cross-session (was incorrectly reported as 91% / 9%).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
296 lines
11 KiB
Python
296 lines
11 KiB
Python
"""Sample sessions from the full cluster-scale trace to fit a single machine.

Preserves:
- Complete session structure (all turns within a session kept together)
- Original arrival timing (re-zeroed to t=0 but NOT compressed)
- KV cache reuse patterns (both intra-session AND cross-session sharing)
- Request type distribution

Sampling strategy (--sample-ratio):
1. Take a contiguous time window from the trace (all sessions whose
   first request falls within the window). This preserves cross-session
   hash block sharing because sessions that share system prompts appear
   together in the same time region.
2. Within the window, randomly thin sessions by ratio to control QPS.
3. Re-zero timestamps so the first event starts at t=0.

By default the window is sized so that (window_sessions * thin_ratio) ≈ target
count. The thin ratio is set high enough (≥0.5) to keep cross-session block
sharing intact; the window width is narrowed to compensate. When
--window-seconds is given, the window duration is fixed instead and the
sessions inside it are thinned as far as needed to approach the target count.

Usage:
    # Sample for 8 GPUs from a ~500-GPU cluster
    python scripts/sample_trace.py \\
        --input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
        --output traces/sampled.jsonl \\
        --sample-ratio 0.016 --seed 42

    # Sample by request count (legacy, no sharing preservation)
    python scripts/sample_trace.py \\
        --input ... --output ... --target-requests 1000 --seed 42
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import collections
|
|
import json
|
|
import random
|
|
from pathlib import Path
|
|
|
|
|
|
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
    """Load a JSONL trace and group its rows by resolved session id.

    The session id of each row is resolved as follows:

    * an explicit ``session_id`` field wins;
    * otherwise a row whose ``parent_chat_id`` is negative starts a session
      named after its own ``chat_id``;
    * otherwise the row inherits the session already recorded for its parent
      chat, falling back to the parent's id if the parent was never seen.

    Every row is tagged in place with a private ``_session_id`` key.  Sessions
    appear in the returned mapping in order of first appearance and rows
    inside a session keep file order.  Blank or whitespace-only lines are
    skipped instead of aborting the load.

    Args:
        path: Location of the JSONL trace file (one JSON object per line).

    Returns:
        Ordered mapping of session id to the list of rows in that session.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row lacks ``chat_id`` or ``parent_chat_id``.
    """
    chat_to_session: dict[int, str] = {}
    rows_by_session: dict[str, list[dict]] = collections.OrderedDict()

    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            # Trailing newlines / blank separators are not records.
            if not line.strip():
                continue

            row = json.loads(line)
            cid = int(row["chat_id"])
            pid = int(row["parent_chat_id"])

            if "session_id" in row:
                sid = str(row["session_id"])
            elif pid < 0:
                sid = str(cid)
            else:
                sid = chat_to_session.get(pid, str(pid))
            # Remember this chat so that later turns can resolve through it.
            chat_to_session[cid] = sid

            row["_session_id"] = sid
            rows_by_session.setdefault(sid, []).append(row)

    return rows_by_session
|
|
|
|
|
|
def sample_sessions(
    rows_by_session: dict[str, list[dict]],
    *,
    sample_ratio: float | None = None,
    target_requests: int | None = None,
    max_single_turn_ratio: float | None = None,
    window_seconds: float | None = None,
    seed: int,
) -> list[str]:
    """Choose which session ids to keep, preserving KV cache reuse.

    Exactly one strategy is applied, ``sample_ratio`` taking precedence:

    * ``sample_ratio``: window-then-thin sampling (optionally with a fixed
      ``window_seconds`` window).
    * ``target_requests``: shuffle all sessions and take whole sessions until
      the accumulated request count reaches the target.

    If ``max_single_turn_ratio`` is given, single-turn sessions are thinned
    afterwards.  All randomness comes from one ``random.Random(seed)``.

    Raises:
        ValueError: If neither ``sample_ratio`` nor ``target_requests`` is set.
    """
    rng = random.Random(seed)

    if sample_ratio is None and target_requests is None:
        raise ValueError("Must specify --sample-ratio or --target-requests")

    if sample_ratio is not None:
        picked = _sample_window_then_thin(
            rows_by_session, sample_ratio, window_seconds, rng
        )
    else:
        shuffled = list(rows_by_session)
        rng.shuffle(shuffled)

        picked = []
        still_needed = target_requests
        for sid in shuffled:
            picked.append(sid)
            still_needed -= len(rows_by_session[sid])
            if still_needed <= 0:
                break

    if max_single_turn_ratio is None:
        return picked
    return _cap_single_turn(rows_by_session, picked, max_single_turn_ratio, rng)
|
|
|
|
|
|
def _cap_single_turn(
|
|
rows_by_session: dict[str, list[dict]],
|
|
selected: list[str],
|
|
max_ratio: float,
|
|
rng: random.Random,
|
|
) -> list[str]:
|
|
"""Thin single-turn sessions so they are at most max_ratio of total sessions."""
|
|
multi = [s for s in selected if len(rows_by_session[s]) > 1]
|
|
single = [s for s in selected if len(rows_by_session[s]) == 1]
|
|
|
|
# max_ratio of TOTAL sessions should be single-turn
|
|
# n_single / (n_single + n_multi) <= max_ratio
|
|
# n_single <= max_ratio * n_multi / (1 - max_ratio)
|
|
max_single = int(max_ratio * len(multi) / (1 - max_ratio))
|
|
if len(single) <= max_single:
|
|
return selected
|
|
|
|
rng.shuffle(single)
|
|
return multi + single[:max_single]
|
|
|
|
|
|
def _sample_window_then_thin(
|
|
rows_by_session: dict[str, list[dict]],
|
|
ratio: float,
|
|
window_seconds: float | None,
|
|
rng: random.Random,
|
|
) -> list[str]:
|
|
"""Window + thin sampling that preserves cross-session sharing.
|
|
|
|
1. Compute first-request timestamp for each session.
|
|
2. Pick a contiguous time window:
|
|
- If --window-seconds given: use that duration, thin by ratio within it.
|
|
- Otherwise: auto-size so window_sessions * thin_ratio ≈ target.
|
|
3. Keep all sessions whose first request falls within the window.
|
|
4. Randomly thin sessions within the window to hit target count.
|
|
"""
|
|
session_starts: list[tuple[float, str]] = []
|
|
for sid, rows in rows_by_session.items():
|
|
t0 = min(float(r["timestamp"]) for r in rows)
|
|
session_starts.append((t0, sid))
|
|
session_starts.sort()
|
|
|
|
total_sessions = len(session_starts)
|
|
target_n = max(1, int(total_sessions * ratio))
|
|
trace_start = session_starts[0][0]
|
|
trace_end = session_starts[-1][0]
|
|
trace_duration = trace_end - trace_start
|
|
|
|
if window_seconds is not None:
|
|
# Fixed window: pick a random start, thin to hit target ratio
|
|
max_start_t = trace_end - window_seconds
|
|
if max_start_t <= trace_start:
|
|
win_start_t = trace_start
|
|
else:
|
|
win_start_t = trace_start + rng.random() * (max_start_t - trace_start)
|
|
win_end_t = win_start_t + window_seconds
|
|
|
|
window_sids = [sid for t, sid in session_starts
|
|
if win_start_t <= t <= win_end_t]
|
|
# Thin to target
|
|
if len(window_sids) > target_n:
|
|
thin_ratio = target_n / len(window_sids)
|
|
window_sids = [s for s in window_sids if rng.random() < thin_ratio]
|
|
return window_sids
|
|
|
|
# Auto-size window
|
|
thin_ratio = min(1.0, max(0.5, ratio * 10))
|
|
window_sessions = min(int(target_n / thin_ratio), total_sessions)
|
|
|
|
max_start = total_sessions - window_sessions
|
|
window_start = rng.randint(0, max_start) if max_start > 0 else 0
|
|
window_sids = [sid for _, sid in
|
|
session_starts[window_start:window_start + window_sessions]]
|
|
|
|
if thin_ratio < 1.0:
|
|
window_sids = [s for s in window_sids if rng.random() < thin_ratio]
|
|
|
|
if len(window_sids) > target_n * 1.2:
|
|
rng.shuffle(window_sids)
|
|
window_sids = window_sids[:int(target_n * 1.1)]
|
|
|
|
return window_sids
|
|
|
|
|
|
def build_output(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
) -> list[dict]:
    """Flatten the selected sessions into time-ordered output rows.

    Each output row is a copy of the source row without keys that start with
    an underscore, and with ``session_id`` set to the resolved session id.
    Rows are ordered by timestamp and then shifted so the earliest one is at
    0; intervals between rows are left untouched (no time compression).
    """
    flattened = [
        {
            **{key: val for key, val in row.items() if not key.startswith("_")},
            "session_id": sid,
        }
        for sid in selected
        for row in rows_by_session[sid]
    ]
    flattened.sort(key=lambda rec: float(rec["timestamp"]))

    if flattened:
        origin = float(flattened[0]["timestamp"])
        for rec in flattened:
            rec["timestamp"] = float(rec["timestamp"]) - origin

    return flattened
|
|
|
|
|
|
def print_summary(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
    out_rows: list[dict],
) -> None:
    """Print statistics about the sampled trace to stdout.

    Reports session and request counts, the multi-turn share, turns per
    session, input/output length ranges, trace span, QPS, and how many
    distinct ``hash_ids`` blocks occur in more than one request.

    Args:
        rows_by_session: Mapping of session id to that session's rows.
        selected: Session ids that were sampled.
        out_rows: Output rows, ordered by timestamp; each needs
            ``input_length``, ``output_length`` and ``timestamp``.
            ``hash_ids`` is optional.

    An empty sample prints only the headline count instead of failing on
    empty-sequence aggregates.
    """
    n_sessions = len(selected)
    n_requests = len(out_rows)
    print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
    if not n_sessions or not n_requests:
        # min()/max()/averages below are undefined for an empty sample.
        return

    turns_per_session = [len(rows_by_session[s]) for s in selected]
    multi_turn = sum(1 for t in turns_per_session if t > 1)

    input_lens = [r["input_length"] for r in out_rows]
    output_lens = [r["output_length"] for r in out_rows]

    # out_rows is time-ordered, so the last timestamp is the trace span.
    span_s = float(out_rows[-1]["timestamp"])
    qps = n_requests / span_s if span_s > 0 else 0

    # Hash block sharing: count in how many requests each block id appears.
    block_freq: collections.Counter = collections.Counter()
    for r in out_rows:
        block_freq.update(r.get("hash_ids") or ())
    total_blocks = len(block_freq)
    shared_blocks = sum(1 for c in block_freq.values() if c > 1)
    # Traces without hash_ids have no blocks at all; report 0% then.
    shared_pct = shared_blocks * 100 / total_blocks if total_blocks else 0.0

    print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
    print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
          f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
    print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
          f"avg={sum(input_lens)/len(input_lens):.0f}")
    print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
          f"avg={sum(output_lens)/len(output_lens):.0f}")
    print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
    print(f" QPS: {qps:.2f} req/s")
    print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_pct:.1f}%)")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Parses the arguments, loads the full trace, samples sessions, prints a
    summary of the sample, and writes the sampled rows as JSONL to the
    requested output path (creating parent directories as needed).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--input", type=Path, required=True,
        help="Path to the full trace JSONL file",
    )
    parser.add_argument(
        "--output", type=Path, required=True,
        help="Path to write sampled trace JSONL",
    )
    parser.add_argument(
        "--sample-ratio", type=float, default=None,
        help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)",
    )
    parser.add_argument(
        "--target-requests", type=int, default=None,
        help="Target number of requests (legacy, no sharing preservation)",
    )
    parser.add_argument(
        "--max-single-turn-ratio", type=float, default=None,
        help="Cap single-turn sessions to this fraction of total (e.g. 0.3)",
    )
    parser.add_argument(
        "--window-seconds", type=float, default=None,
        help="Time window duration in seconds (e.g. 600 for 10 min)",
    )
    parser.add_argument("--seed", type=int, default=42)
    opts = parser.parse_args()

    # At least one sampling strategy has to be chosen on the command line.
    if opts.sample_ratio is None and opts.target_requests is None:
        parser.error("Must specify --sample-ratio or --target-requests")

    print(f"Loading trace from {opts.input} ...")
    sessions = load_raw_rows(opts.input)
    session_count = len(sessions)
    request_count = sum(len(rows) for rows in sessions.values())
    print(f"Full trace: {session_count} sessions, {request_count} requests")

    chosen = sample_sessions(
        sessions,
        sample_ratio=opts.sample_ratio,
        target_requests=opts.target_requests,
        max_single_turn_ratio=opts.max_single_turn_ratio,
        window_seconds=opts.window_seconds,
        seed=opts.seed,
    )
    sampled_rows = build_output(sessions, chosen)
    print_summary(sessions, chosen, sampled_rows)

    destination = opts.output
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n"
            for record in sampled_rows
        )
    print(f"\nWrote {len(sampled_rows)} rows to {destination}")
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|