Files
agentic-kvc/scripts/sample_trace.py
Gahow Wang 1e1e2e774d Fix sampler: window+thin preserves cross-session KV cache sharing
Random session sampling destroys cross-session hash block sharing
(the block sharing rate drops from 52% to 16%) because sessions that
share system prompts end up scattered across the trace.

New approach: take a contiguous time window from the trace (preserving
temporal locality of shared-prefix sessions), then thin within the
window to hit the target QPS. This preserves both intra-session reuse
and cross-session sharing, which account for 62% and 38% of reusable
tokens, respectively.

Results (block sharing rate):
  Old random r=0.002:  16.0%  ->  Window+thin: 29.7%
  Old random r=0.016:  19.5%  ->  Window+thin: 42.7%
  Full trace baseline: 52%

Also corrected the "91% intra-session" claim: actual split is
62% intra / 38% cross (token-level), making cross-session sharing
preservation critical for valid APC benchmarks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-23 14:03:12 +08:00

252 lines
8.9 KiB
Python

"""Sample sessions from the full cluster-scale trace to fit a single machine.

Preserves:
  - Complete session structure (all turns within a session kept together)
  - Original arrival timing (re-zeroed to t=0 but NOT compressed)
  - KV cache reuse patterns (both intra-session AND cross-session sharing)
  - Request type distribution

Sampling strategy (--sample-ratio):
  1. Take a contiguous time window from the trace (all sessions whose
     first request falls within the window). This preserves cross-session
     hash block sharing because sessions that share system prompts appear
     together in the same time region.
  2. Within the window, randomly thin sessions by ratio to control QPS.
  3. Re-zero timestamps so the first event starts at t=0.

The window is sized so that (window_sessions * thin_ratio) ≈ target count.
The thin ratio is sample_ratio * 10 clamped to [0.5, 1.0], which is high
enough to keep cross-session block sharing intact. As a result, the window
holds at most twice the target number of sessions, so it stays narrow.

Usage:
  # Sample for 8 GPUs from a ~500-GPU cluster
  python scripts/sample_trace.py \\
      --input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
      --output traces/sampled.jsonl \\
      --sample-ratio 0.016 --seed 42

  # Sample by request count (legacy, no sharing preservation)
  python scripts/sample_trace.py \\
      --input ... --output ... --target-requests 1000 --seed 42
"""
from __future__ import annotations
import argparse
import collections
import json
import random
from pathlib import Path
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
    """Load a JSONL trace and group its rows by resolved session id.

    Session id resolution, per row:
      - an explicit ``session_id`` field wins;
      - otherwise a row with ``parent_chat_id < 0`` starts a new session
        keyed by its own ``chat_id``;
      - otherwise the row inherits the session of its parent chat. If the
        parent has not been seen yet, ``str(parent_chat_id)`` is used, so
        this assumes parents appear before their children in the file.

    Sessions (and rows within a session) keep their file order. Each parsed
    row is annotated with a private ``_session_id`` key. Blank or
    whitespace-only lines are skipped.

    Args:
        path: Path to the trace JSONL file (read as UTF-8).

    Returns:
        Mapping of session id -> list of row dicts, in first-seen order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row lacks ``chat_id`` or ``parent_chat_id``.
    """
    chat_to_session: dict[int, str] = {}
    rows_by_session: dict[str, list[dict]] = collections.OrderedDict()
    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line), which
                # json.loads would otherwise reject.
                continue
            row = json.loads(line)
            cid = int(row["chat_id"])
            pid = int(row["parent_chat_id"])
            if "session_id" in row:
                sid = str(row["session_id"])
            elif pid < 0:
                sid = str(cid)
            else:
                sid = chat_to_session.get(pid, str(pid))
            chat_to_session[cid] = sid
            row["_session_id"] = sid
            rows_by_session.setdefault(sid, []).append(row)
    return rows_by_session
def sample_sessions(
rows_by_session: dict[str, list[dict]],
*,
sample_ratio: float | None = None,
target_requests: int | None = None,
seed: int,
) -> list[str]:
"""Sample sessions preserving KV cache reuse."""
rng = random.Random(seed)
if sample_ratio is not None:
return _sample_window_then_thin(rows_by_session, sample_ratio, rng)
if target_requests is not None:
all_sids = list(rows_by_session.keys())
rng.shuffle(all_sids)
selected = []
total = 0
for sid in all_sids:
selected.append(sid)
total += len(rows_by_session[sid])
if total >= target_requests:
break
return selected
raise ValueError("Must specify --sample-ratio or --target-requests")
def _sample_window_then_thin(
    rows_by_session: dict[str, list[dict]],
    ratio: float,
    rng: random.Random,
) -> list[str]:
    """Window + thin sampling that preserves cross-session sharing.

    1. Order sessions by the timestamp of their earliest request.
    2. Pick a random contiguous run of ``window_sessions`` sessions in that
       order, sized so that
       ``window_sessions * thin_ratio ≈ total_sessions * ratio``.
       ``thin_ratio`` is ``ratio * 10`` clamped to ``[0.5, 1.0]``, so the
       window holds at most twice the target number of sessions.
    3. Keep each window session independently with probability
       ``thin_ratio``. If that overshoots the target by more than 20%, keep
       a random subset of ``int(target_n * 1.1)`` sessions instead. If
       thinning keeps nothing, keep one random window session so a
       non-empty trace never yields an empty sample.

    Args:
        rows_by_session: Trace rows grouped by session id; every row must
            carry a ``timestamp`` convertible to float.
        ratio: Fraction of sessions to sample. The target is at least one
            session.
        rng: Random source. For non-empty results its draws are made in the
            same order as before the empty-sample fallback was added, so a
            given seed reproduces earlier samples.

    Returns:
        Selected session ids (order is not significant). Empty only when
        ``rows_by_session`` is empty.
    """
    # (first-request timestamp, session id), ascending; ties broken by id.
    session_starts = sorted(
        (min(float(r["timestamp"]) for r in rows), sid)
        for sid, rows in rows_by_session.items()
    )
    total_sessions = len(session_starts)
    target_n = max(1, int(total_sessions * ratio))

    # Thinning too aggressively breaks cross-session sharing, so never keep
    # less than half of the window; ratios >= 0.1 keep the whole window.
    thin_ratio = min(1.0, max(0.5, ratio * 10))
    window_sessions = min(int(target_n / thin_ratio), total_sessions)

    # Random window position within the time-ordered sessions.
    max_start = total_sessions - window_sessions
    window_start = rng.randint(0, max_start) if max_start > 0 else 0
    window_sids = [
        sid for _, sid in session_starts[window_start:window_start + window_sessions]
    ]

    if thin_ratio >= 1.0:
        selected = window_sids
    else:
        selected = [sid for sid in window_sids if rng.random() < thin_ratio]
        if not selected and window_sids:
            # Small windows (e.g. 2 sessions at thin_ratio 0.5) can lose every
            # session; downstream summaries need at least one.
            selected = [rng.choice(window_sids)]

    # Cap overshoot caused by the independent per-session coin flips.
    if len(selected) > target_n * 1.2:
        rng.shuffle(selected)
        selected = selected[:int(target_n * 1.1)]
    return selected
def build_output(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
) -> list[dict]:
    """Flatten the selected sessions into output rows on a re-zeroed clock.

    Each output row is a copy of the source row without private
    ``_``-prefixed keys, with ``session_id`` set to its session. Rows are
    ordered by timestamp, and every timestamp is shifted so the earliest
    one becomes 0.0. Inter-arrival gaps are kept as-is, with no time
    compression. The source rows are not modified.
    """
    records = [
        dict(
            ((key, value) for key, value in row.items() if not key.startswith("_")),
            session_id=sid,
        )
        for sid in selected
        for row in rows_by_session[sid]
    ]
    if not records:
        return records
    records.sort(key=lambda rec: float(rec["timestamp"]))
    origin = float(records[0]["timestamp"])
    for rec in records:
        rec["timestamp"] = float(rec["timestamp"]) - origin
    return records
def print_summary(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
    out_rows: list[dict],
) -> None:
    """Print human-readable statistics about a sampled trace to stdout.

    Args:
        rows_by_session: Full trace rows grouped by session id; used to
            count turns per selected session.
        selected: Selected session ids.
        out_rows: Output rows, expected to be sorted by timestamp and
            re-zeroed to start at 0 (the last row's timestamp is reported
            as the trace span). Each row must carry ``input_length``,
            ``output_length`` and ``timestamp``; ``hash_ids`` is optional.

    When nothing was sampled, only the header line is printed. A hash block
    counts as "shared" when its id occurs more than once across all output
    rows. Rows without ``hash_ids`` yield a 0.0% sharing rate instead of a
    division by zero.
    """
    n_sessions = len(selected)
    n_requests = len(out_rows)
    print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
    if not selected or not out_rows:
        # Every statistic below divides by a count or takes min()/max() of a
        # non-empty list; there is nothing meaningful left to report.
        return

    turns_per_session = [len(rows_by_session[s]) for s in selected]
    multi_turn = sum(1 for t in turns_per_session if t > 1)
    input_lens = [r["input_length"] for r in out_rows]
    output_lens = [r["output_length"] for r in out_rows]
    span_s = float(out_rows[-1]["timestamp"])
    qps = n_requests / span_s if span_s > 0 else 0

    # Occurrence count of every hash block across all requests.
    block_freq = collections.Counter(
        h for r in out_rows for h in (r.get("hash_ids") or [])
    )
    total_blocks = len(block_freq)
    shared_blocks = sum(1 for c in block_freq.values() if c > 1)
    shared_pct = shared_blocks * 100 / total_blocks if total_blocks else 0.0

    print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
    print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
          f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
    print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
          f"avg={sum(input_lens)/len(input_lens):.0f}")
    print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
          f"avg={sum(output_lens)/len(output_lens):.0f}")
    print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
    print(f" QPS: {qps:.2f} req/s")
    print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_pct:.1f}%)")
def main() -> None:
    """Command-line entry point.

    Parses arguments, loads the trace from ``--input``, samples sessions,
    prints a summary, and writes the sampled rows as JSONL to ``--output``.
    The output's parent directories are created if needed.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--input", type=Path, required=True,
                   help="Path to the full trace JSONL file")
    p.add_argument("--output", type=Path, required=True,
                   help="Path to write sampled trace JSONL")
    # The two sampling modes are alternatives. Without this group, passing
    # both would silently ignore --target-requests (sample_ratio takes
    # precedence in sample_sessions).
    mode = p.add_mutually_exclusive_group()
    mode.add_argument("--sample-ratio", type=float, default=None,
                      help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)")
    mode.add_argument("--target-requests", type=int, default=None,
                      help="Target number of requests (legacy, no sharing preservation)")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()
    if args.sample_ratio is None and args.target_requests is None:
        p.error("Must specify --sample-ratio or --target-requests")
    # `not (0 < x <= 1)` also rejects NaN.
    if args.sample_ratio is not None and not 0.0 < args.sample_ratio <= 1.0:
        p.error("--sample-ratio must be in (0, 1]")
    if args.target_requests is not None and args.target_requests <= 0:
        p.error("--target-requests must be a positive integer")

    print(f"Loading trace from {args.input} ...")
    rows_by_session = load_raw_rows(args.input)
    total_sessions = len(rows_by_session)
    total_requests = sum(len(v) for v in rows_by_session.values())
    print(f"Full trace: {total_sessions} sessions, {total_requests} requests")

    selected = sample_sessions(
        rows_by_session,
        sample_ratio=args.sample_ratio,
        target_requests=args.target_requests,
        seed=args.seed,
    )
    out_rows = build_output(rows_by_session, selected)
    print_summary(rows_by_session, selected, out_rows)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as fh:
        for row in out_rows:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f"\nWrote {len(out_rows)} rows to {args.output}")


if __name__ == "__main__":
    main()