Files
agentic-kvc/scripts/sample_trace.py
Gahow Wang bf037594c4 Production-realistic baseline: APC 67.5%, TPOT +139% from interference
Updated methodology:
- Window+thin sampling preserves cross-session sharing (48% vs 16%)
- --max-single-turn-ratio 0.3 boosts multi-turn to 70%
- --window-seconds 600 for 10-min contiguous window
- Trace-driven replay (no session limit, no time compression)
- Daily config: --requests 850 (~13 min, APC~76%)

Key result: TPOT p90=0.175s (vs 0.073s in legacy 1-req/GPU setup),
confirming prefill-decode interference is real at production concurrency.
APC 67.5% (vs 44%) from better KV reuse preservation.

Also fixed KV reuse breakdown: 62% intra-session / 38% cross-session
(was incorrectly reported as 91% / 9%).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-23 15:44:34 +08:00

296 lines
11 KiB
Python

"""Sample sessions from the full cluster-scale trace to fit a single machine.
Preserves:
- Complete session structure (all turns within a session kept together)
- Original arrival timing (re-zeroed to t=0 but NOT compressed)
- KV cache reuse patterns (both intra-session AND cross-session sharing)
- Request type distribution
Sampling strategy (--sample-ratio):
1. Take a contiguous time window from the trace (all sessions whose
first request falls within the window). This preserves cross-session
hash block sharing because sessions that share system prompts appear
together in the same time region.
2. Within the window, randomly thin sessions by ratio to control QPS.
3. Re-zero timestamps so first event starts at t=0.
The window is sized so that (window_sessions * thin_ratio) ≈ target count.
Thin ratio is set high enough (≥0.5) to keep cross-session block sharing
intact; the window width is narrowed to compensate.
Usage:
# Sample for 8 GPUs from a ~500-GPU cluster
python scripts/sample_trace.py \\
--input ~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl \\
--output traces/sampled.jsonl \\
--sample-ratio 0.016 --seed 42
# Sample by request count (legacy, no sharing preservation)
python scripts/sample_trace.py \\
--input ... --output ... --target-requests 1000 --seed 42
"""
from __future__ import annotations
import argparse
import collections
import json
import random
from pathlib import Path
def load_raw_rows(path: Path) -> dict[str, list[dict]]:
    """Load a JSONL trace and group its rows by resolved session id.

    Each non-blank line is parsed as one JSON object (a "row").  The session
    a row belongs to is resolved in this priority order:

    1. an explicit ``session_id`` field (stringified);
    2. ``parent_chat_id < 0`` -> the row starts a new session named
       ``str(chat_id)``;
    3. otherwise the session previously recorded for the parent chat,
       falling back to ``str(parent_chat_id)`` when the parent has not been
       seen yet.

    Resolution is single-pass, so it relies on parents mostly appearing
    before their children in file order: a child seen before a parent that
    itself has a parent can land in a different session than that parent.

    Every row is annotated in place with a private ``_session_id`` key.
    Sessions keep the order of their first appearance in the file, and rows
    within a session keep file order.  Blank / whitespace-only lines are
    skipped.

    Raises:
        json.JSONDecodeError: on a malformed non-blank line.
        KeyError: if a row lacks ``chat_id`` or ``parent_chat_id``.
    """
    chat_to_session: dict[int, str] = {}
    rows_by_session: dict[str, list[dict]] = collections.OrderedDict()
    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                # Tolerate blank lines (e.g. an extra trailing newline),
                # which json.loads would otherwise reject.
                continue
            row = json.loads(line)
            cid = int(row["chat_id"])
            pid = int(row["parent_chat_id"])
            if "session_id" in row:
                sid = str(row["session_id"])
            elif pid < 0:
                sid = str(cid)
            else:
                sid = chat_to_session.get(pid, str(pid))
            chat_to_session[cid] = sid
            row["_session_id"] = sid
            rows_by_session.setdefault(sid, []).append(row)
    return rows_by_session
def sample_sessions(
    rows_by_session: dict[str, list[dict]],
    *,
    sample_ratio: float | None = None,
    target_requests: int | None = None,
    max_single_turn_ratio: float | None = None,
    window_seconds: float | None = None,
    seed: int,
) -> list[str]:
    """Choose which sessions go into the sampled trace.

    One primary strategy is applied; ``sample_ratio`` wins if both are set:

    * ``sample_ratio`` -- window-then-thin sampling, which keeps sessions
      that share prefixes together (``window_seconds`` optionally fixes the
      window duration).
    * ``target_requests`` -- legacy mode: shuffle all sessions and take
      whole sessions until the running request count reaches the target
      (the session that crosses the threshold is included).

    When ``max_single_turn_ratio`` is given, single-turn sessions are then
    thinned so they make up at most that fraction of the result.

    All randomness is drawn from a single ``random.Random(seed)``, so the
    same arguments always produce the same selection.

    Raises:
        ValueError: if neither ``sample_ratio`` nor ``target_requests`` is set.
    """
    rng = random.Random(seed)
    if sample_ratio is None and target_requests is None:
        raise ValueError("Must specify --sample-ratio or --target-requests")

    if sample_ratio is not None:
        picked = _sample_window_then_thin(rows_by_session, sample_ratio,
                                          window_seconds, rng)
    else:
        shuffled = list(rows_by_session)
        rng.shuffle(shuffled)
        # Find the first prefix whose request total reaches the target;
        # if none does, every session is taken.
        cut = len(shuffled)
        running = 0
        for pos, sid in enumerate(shuffled):
            running += len(rows_by_session[sid])
            if running >= target_requests:
                cut = pos + 1
                break
        picked = shuffled[:cut]

    if max_single_turn_ratio is None:
        return picked
    return _cap_single_turn(rows_by_session, picked,
                            max_single_turn_ratio, rng)
def _cap_single_turn(
rows_by_session: dict[str, list[dict]],
selected: list[str],
max_ratio: float,
rng: random.Random,
) -> list[str]:
"""Thin single-turn sessions so they are at most max_ratio of total sessions."""
multi = [s for s in selected if len(rows_by_session[s]) > 1]
single = [s for s in selected if len(rows_by_session[s]) == 1]
# max_ratio of TOTAL sessions should be single-turn
# n_single / (n_single + n_multi) <= max_ratio
# n_single <= max_ratio * n_multi / (1 - max_ratio)
max_single = int(max_ratio * len(multi) / (1 - max_ratio))
if len(single) <= max_single:
return selected
rng.shuffle(single)
return multi + single[:max_single]
def _sample_window_then_thin(
rows_by_session: dict[str, list[dict]],
ratio: float,
window_seconds: float | None,
rng: random.Random,
) -> list[str]:
"""Window + thin sampling that preserves cross-session sharing.
1. Compute first-request timestamp for each session.
2. Pick a contiguous time window:
- If --window-seconds given: use that duration, thin by ratio within it.
- Otherwise: auto-size so window_sessions * thin_ratio ≈ target.
3. Keep all sessions whose first request falls within the window.
4. Randomly thin sessions within the window to hit target count.
"""
session_starts: list[tuple[float, str]] = []
for sid, rows in rows_by_session.items():
t0 = min(float(r["timestamp"]) for r in rows)
session_starts.append((t0, sid))
session_starts.sort()
total_sessions = len(session_starts)
target_n = max(1, int(total_sessions * ratio))
trace_start = session_starts[0][0]
trace_end = session_starts[-1][0]
trace_duration = trace_end - trace_start
if window_seconds is not None:
# Fixed window: pick a random start, thin to hit target ratio
max_start_t = trace_end - window_seconds
if max_start_t <= trace_start:
win_start_t = trace_start
else:
win_start_t = trace_start + rng.random() * (max_start_t - trace_start)
win_end_t = win_start_t + window_seconds
window_sids = [sid for t, sid in session_starts
if win_start_t <= t <= win_end_t]
# Thin to target
if len(window_sids) > target_n:
thin_ratio = target_n / len(window_sids)
window_sids = [s for s in window_sids if rng.random() < thin_ratio]
return window_sids
# Auto-size window
thin_ratio = min(1.0, max(0.5, ratio * 10))
window_sessions = min(int(target_n / thin_ratio), total_sessions)
max_start = total_sessions - window_sessions
window_start = rng.randint(0, max_start) if max_start > 0 else 0
window_sids = [sid for _, sid in
session_starts[window_start:window_start + window_sessions]]
if thin_ratio < 1.0:
window_sids = [s for s in window_sids if rng.random() < thin_ratio]
if len(window_sids) > target_n * 1.2:
rng.shuffle(window_sids)
window_sids = window_sids[:int(target_n * 1.1)]
return window_sids
def build_output(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
) -> list[dict]:
    """Flatten the chosen sessions into one chronologically ordered row list.

    Each output row is a shallow copy of an input row with every private key
    (name starting with ``_``) dropped and ``session_id`` set to the session
    it was grouped under.  Rows are ordered by ``float(timestamp)`` with a
    stable sort, so ties keep selection order.  They are then shifted so the
    earliest row sits at ``0.0``.  This is a shift only: inter-arrival gaps
    are not compressed.  Input rows are left unmodified.
    """
    def _public_copy(row: dict, sid: str) -> dict:
        # Copy without private annotations, tagged with the resolved session.
        cleaned = {key: val for key, val in row.items()
                   if not key.startswith("_")}
        cleaned["session_id"] = sid
        return cleaned

    flattened = sorted(
        (_public_copy(row, sid)
         for sid in selected
         for row in rows_by_session[sid]),
        key=lambda rec: float(rec["timestamp"]),
    )
    if flattened:
        origin = float(flattened[0]["timestamp"])
        for rec in flattened:
            rec["timestamp"] = float(rec["timestamp"]) - origin
    return flattened
def print_summary(
    rows_by_session: dict[str, list[dict]],
    selected: list[str],
    out_rows: list[dict],
) -> None:
    """Print human-readable statistics about the sampled trace.

    Args:
        rows_by_session: all sessions keyed by id (used for turn counts).
        selected: session ids included in the sample.
        out_rows: output rows; assumed to be time-sorted and re-zeroed (as
            produced by ``build_output``), so the last row's timestamp is
            taken as the trace span.

    An empty sample, or rows without ``hash_ids``, is reported rather than
    raising ``ZeroDivisionError``/``ValueError``.  This matters because the
    CLI writes the output file only after the summary is printed.
    """
    n_sessions = len(selected)
    n_requests = len(out_rows)
    if not selected or not out_rows:
        # min()/max()/averages below are undefined for an empty sample.
        print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
        print(" (empty sample: no statistics to report)")
        return

    turns_per_session = [len(rows_by_session[s]) for s in selected]
    multi_turn = sum(1 for t in turns_per_session if t > 1)
    input_lens = [r["input_length"] for r in out_rows]
    output_lens = [r["output_length"] for r in out_rows]
    span_s = float(out_rows[-1]["timestamp"])
    qps = n_requests / span_s if span_s > 0 else 0

    # Hash block sharing: a block counts as shared if it occurs in more
    # than one request.
    block_freq = collections.Counter(
        h for r in out_rows for h in r.get("hash_ids", []))
    total_blocks = len(block_freq)
    shared_blocks = sum(1 for c in block_freq.values() if c > 1)
    # Traces without hash_ids have no blocks; report 0% instead of dividing by 0.
    shared_pct = shared_blocks * 100 / total_blocks if total_blocks else 0.0

    print(f"Sampled: {n_sessions} sessions, {n_requests} requests")
    print(f" Multi-turn sessions: {multi_turn} ({multi_turn/n_sessions*100:.1f}%)")
    print(f" Turns/session: min={min(turns_per_session)} max={max(turns_per_session)} "
          f"avg={sum(turns_per_session)/len(turns_per_session):.1f}")
    print(f" Input length: min={min(input_lens)} max={max(input_lens)} "
          f"avg={sum(input_lens)/len(input_lens):.0f}")
    print(f" Output length: min={min(output_lens)} max={max(output_lens)} "
          f"avg={sum(output_lens)/len(output_lens):.0f}")
    print(f" Trace span: {span_s:.1f}s ({span_s/60:.1f} min)")
    print(f" QPS: {qps:.2f} req/s")
    print(f" Hash blocks: {total_blocks} unique, {shared_blocks} shared ({shared_pct:.1f}%)")
def main() -> None:
    """Command-line entry point: load, sample, summarize, and write a trace.

    Numeric arguments are validated up front with ``ArgumentParser.error``
    (which exits with status 2), so bad values fail fast instead of
    surfacing later as a ZeroDivisionError or a meaningless sample.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--input", type=Path, required=True,
                   help="Path to the full trace JSONL file")
    p.add_argument("--output", type=Path, required=True,
                   help="Path to write sampled trace JSONL")
    p.add_argument("--sample-ratio", type=float, default=None,
                   help="Fraction of sessions to sample (e.g. 0.016 for 8/500 GPU ratio)")
    p.add_argument("--target-requests", type=int, default=None,
                   help="Target number of requests (legacy, no sharing preservation)")
    p.add_argument("--max-single-turn-ratio", type=float, default=None,
                   help="Cap single-turn sessions to this fraction of total (e.g. 0.3)")
    p.add_argument("--window-seconds", type=float, default=None,
                   help="Time window duration in seconds (e.g. 600 for 10 min)")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    if args.sample_ratio is None and args.target_requests is None:
        p.error("Must specify --sample-ratio or --target-requests")
    if args.sample_ratio is not None and args.sample_ratio <= 0:
        p.error("--sample-ratio must be > 0")
    if args.target_requests is not None and args.target_requests <= 0:
        p.error("--target-requests must be > 0")
    # The cap divides by (1 - ratio), so 1.0 and above are rejected here;
    # omit the flag to disable the cap.
    if (args.max_single_turn_ratio is not None
            and not 0.0 <= args.max_single_turn_ratio < 1.0):
        p.error("--max-single-turn-ratio must be in [0, 1)")
    if args.window_seconds is not None and args.window_seconds <= 0:
        p.error("--window-seconds must be > 0")
    # sample_sessions() prefers --sample-ratio and uses --window-seconds only
    # with it; say so rather than silently ignoring a flag.
    if args.sample_ratio is not None and args.target_requests is not None:
        print("Note: --target-requests is ignored because --sample-ratio is set")
    if args.sample_ratio is None and args.window_seconds is not None:
        print("Note: --window-seconds only applies with --sample-ratio; ignored")

    print(f"Loading trace from {args.input} ...")
    rows_by_session = load_raw_rows(args.input)
    total_sessions = len(rows_by_session)
    total_requests = sum(len(v) for v in rows_by_session.values())
    print(f"Full trace: {total_sessions} sessions, {total_requests} requests")

    selected = sample_sessions(
        rows_by_session,
        sample_ratio=args.sample_ratio,
        target_requests=args.target_requests,
        max_single_turn_ratio=args.max_single_turn_ratio,
        window_seconds=args.window_seconds,
        seed=args.seed,
    )
    out_rows = build_output(rows_by_session, selected)
    print_summary(rows_by_session, selected, out_rows)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n"
                      for row in out_rows)
    print(f"\nWrote {len(out_rows)} rows to {args.output}")


if __name__ == "__main__":
    main()