"""Compute theoretical APC upper bound from a trace's hash_ids. vLLM prefix caching is block-level (BLOCK_SIZE token chunks) and chain- aware: block i hits the cache iff hash_ids[0..i] matches some previously seen request's hash_ids[0..i]. The trace's `hash_ids` field is the same hash chain. Three variants: - any_session : trie built across all previously seen requests - intra_session : trie scoped to the request's own session_id - shared_prefix : trie built from blocks at position 0, that appear across >= K distinct sessions (system-prompt proxy) """ from __future__ import annotations import argparse import json from collections import defaultdict from pathlib import Path BLOCK_SIZE_DEFAULT = 512 def _walk(trie: dict, hashes: list[int]) -> int: depth = 0 node = trie for h in hashes: if h in node: depth += 1 node = node[h] else: break return depth def _insert(trie: dict, hashes: list[int]) -> None: node = trie for h in hashes: node = node.setdefault(h, {}) def _resolve_session(row: dict, chat_to_session: dict[int, str]) -> str: if "session_id" in row: return str(row["session_id"]) cid = int(row["chat_id"]) pcid = int(row["parent_chat_id"]) if pcid < 0: sid = str(cid) else: sid = chat_to_session.get(pcid, str(pcid)) chat_to_session[cid] = sid return sid def main() -> None: p = argparse.ArgumentParser() p.add_argument("--trace", type=Path, required=True) p.add_argument("--block-size", type=int, default=BLOCK_SIZE_DEFAULT) p.add_argument("--shared-prefix-min-sessions", type=int, default=8) p.add_argument("--out", type=Path, default=None) args = p.parse_args() rows: list[dict] = [] chat_to_session: dict[int, str] = {} with args.trace.open("r", encoding="utf-8") as fh: for line in fh: line = line.strip() if not line: continue r = json.loads(line) r["session_id"] = _resolve_session(r, chat_to_session) rows.append(r) rows.sort(key=lambda r: float(r.get("timestamp", 0.0))) global_trie: dict = {} session_tries: dict[str, dict] = defaultdict(dict) # First-position block stats: how many sessions hit each top-level # hash. We approximate "system prefix" as a block seen at position 0 # across many sessions. pos0_session_set: dict[int, set[str]] = defaultdict(set) for r in rows: hids = list(r.get("hash_ids") or []) if hids: pos0_session_set[hids[0]].add(r["session_id"]) shared_pos0 = {h for h, s in pos0_session_set.items() if len(s) >= args.shared_prefix_min_sessions} total_input = 0 cache_any = 0 cache_intra = 0 cache_shared_only = 0 per_session_input: dict[str, int] = defaultdict(int) per_session_intra: dict[str, int] = defaultdict(int) n_with_any_hit = 0 n_with_intra_hit = 0 n_with_shared_hit = 0 for r in rows: hids = list(r.get("hash_ids") or []) input_len = int(r.get("input_length") or 0) sid = r["session_id"] total_input += input_len per_session_input[sid] += input_len g_depth = _walk(global_trie, hids) s_depth = _walk(session_tries[sid], hids) # Shared-prefix-only: greedy depth, but stop at first non-shared # pos0 block (because then no later blocks can be from system # prefix as a contiguous chain). sh_depth = 0 if hids and hids[0] in shared_pos0: sh_depth = 1 # subsequent blocks at deeper positions are NOT modeled as # "shared system" in this conservative bound. # accumulate g_tokens = min(g_depth * args.block_size, input_len) s_tokens = min(s_depth * args.block_size, input_len) sh_tokens = min(sh_depth * args.block_size, input_len) cache_any += g_tokens cache_intra += s_tokens cache_shared_only += sh_tokens per_session_intra[sid] += s_tokens if g_tokens > 0: n_with_any_hit += 1 if s_tokens > 0: n_with_intra_hit += 1 if sh_tokens > 0: n_with_shared_hit += 1 # update tries _insert(global_trie, hids) _insert(session_tries[sid], hids) out = { "trace": str(args.trace), "n_requests": len(rows), "n_sessions": len({r["session_id"] for r in rows}), "block_size": args.block_size, "shared_prefix_min_sessions": args.shared_prefix_min_sessions, "total_input_tokens": total_input, "apc_upper_any_session": cache_any / total_input, "apc_upper_intra_session": cache_intra / total_input, "apc_upper_shared_prefix_only": cache_shared_only / total_input, "cached_tokens_any_session": cache_any, "cached_tokens_intra_session": cache_intra, "cached_tokens_shared_prefix_only": cache_shared_only, "n_requests_any_hit": n_with_any_hit, "n_requests_intra_hit": n_with_intra_hit, "n_requests_shared_hit": n_with_shared_hit, "n_shared_pos0_blocks": len(shared_pos0), } text = json.dumps(out, indent=2) if args.out: args.out.write_text(text) print(text) if __name__ == "__main__": main()