Files
agentic-kvc/scripts/compute_apc_upper_bound.py
Gahow Wang b7902061d1 Window 1 analysis: APC upper bound, B2 window-overlap, figure renderer
Three CPU-only analysis pieces that turn raw Window 1 artifacts into
publishable numbers and figures.

scripts/compute_apc_upper_bound.py
  Block-level trie walk over hash_ids to compute the theoretical APC
  ceiling on a trace, decomposed into intra-session / any-session /
  shared-prefix-only. Gives a fixed reference for what each routing
  policy could *possibly* achieve. w600 result: 79.6% intra-session,
  80.3% any-session, 0.1% shared-prefix.

analysis/characterization/b2_sweep_analysis.py (rewrite)
  Previous version used joined_analysis.interference_index() which
  labeled overlap = "any prefill in any other request during this
  decode". With short-prompt decode load this is always true
  (everyone's prefill overlaps everyone else's decode); n_overlap
  was 239/240 even in the different-worker control.

  New version labels overlap iff the decode's [t_first_token, t_finish]
  intersects an actual large *injection* window, computed from the
  cell's "prefill"-tagged metric rows. Different-worker control now
  cleanly sits at idx ≈ 1.0, same-worker scales monotonically.

analysis/characterization/render_window1_figures.py
  Renders 8 PNGs from the result JSONs: B3 latency / APC vs ceiling
  / APC vs hotspot scatter / per-worker TTFT / failure breakdown,
  B2 TPOT and TTFT curves (overlap vs clean and idx), reuse
  decomposition, KV footprint.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 23:24:54 +08:00

160 lines
5.3 KiB
Python

"""Compute theoretical APC upper bound from a trace's hash_ids.
vLLM prefix caching is block-level (BLOCK_SIZE token chunks) and chain-
aware: block i hits the cache iff hash_ids[0..i] matches some previously
seen request's hash_ids[0..i]. The trace's `hash_ids` field is the same
hash chain.
Three variants:
- any_session : trie built across all previously seen requests
- intra_session : trie scoped to the request's own session_id
- shared_prefix : credits at most the position-0 block of each request, and
only when that block appears across >= K distinct sessions anywhere in
the trace (K = --shared-prefix-min-sessions; system-prompt proxy)
"""
from __future__ import annotations
import argparse
import json
from collections import defaultdict
from pathlib import Path
# Default for --block-size: tokens represented by one entry of a request's
# hash_ids list (matched block depth is multiplied by this to get tokens).
# NOTE(review): should equal the block size the trace's hash_ids were
# generated with — confirm against the trace source.
BLOCK_SIZE_DEFAULT = 512
def _walk(trie: dict, hashes: list[int]) -> int:
depth = 0
node = trie
for h in hashes:
if h in node:
depth += 1
node = node[h]
else:
break
return depth
def _insert(trie: dict, hashes: list[int]) -> None:
node = trie
for h in hashes:
node = node.setdefault(h, {})
def _resolve_session(row: dict, chat_to_session: dict[int, str]) -> str:
if "session_id" in row:
return str(row["session_id"])
cid = int(row["chat_id"])
pcid = int(row["parent_chat_id"])
if pcid < 0:
sid = str(cid)
else:
sid = chat_to_session.get(pcid, str(pcid))
chat_to_session[cid] = sid
return sid
def _load_rows(trace: Path) -> list[dict]:
    """Read a JSONL trace, sort it by timestamp and attach ``session_id``.

    Blank lines are skipped. Rows are sorted by ``timestamp`` (missing ->
    0.0; the sort is stable, so ties keep file order) *before* sessions are
    resolved, so a parent chat is registered before any child that refers
    to it even when the file itself is not time-ordered.
    """
    with trace.open("r", encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    rows.sort(key=lambda r: float(r.get("timestamp", 0.0)))
    chat_to_session: dict[int, str] = {}
    for r in rows:
        r["session_id"] = _resolve_session(r, chat_to_session)
    return rows


def _shared_pos0_blocks(rows: list[dict], min_sessions: int) -> set[int]:
    """Return position-0 hashes that occur in >= *min_sessions* sessions.

    Counted over the whole trace (look-ahead), as a proxy for a system
    prompt shared across sessions.
    """
    sessions_by_hash: dict[int, set[str]] = defaultdict(set)
    for r in rows:
        hids = list(r.get("hash_ids") or [])
        if hids:
            sessions_by_hash[hids[0]].add(r["session_id"])
    return {h for h, s in sessions_by_hash.items() if len(s) >= min_sessions}


def main() -> None:
    """Compute theoretical APC upper bounds for a JSONL trace.

    Requests are replayed in timestamp order. For each request, its
    ``hash_ids`` chain is matched against:

    - any-session: a trie of every earlier request;
    - intra-session: a trie of earlier requests in the same session;
    - shared-prefix-only: the position-0 block alone, credited when it is a
      cross-session shared block that an earlier request already cached.

    Matched blocks are converted to tokens (``depth * block_size``, clipped
    to ``input_length``) and summed. The JSON summary is printed and, when
    ``--out`` is given, also written there as UTF-8.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--trace", type=Path, required=True)
    p.add_argument("--block-size", type=int, default=BLOCK_SIZE_DEFAULT)
    p.add_argument("--shared-prefix-min-sessions", type=int, default=8)
    p.add_argument("--out", type=Path, default=None)
    args = p.parse_args()
    if args.block_size <= 0:
        p.error("--block-size must be a positive integer")

    rows = _load_rows(args.trace)
    shared_pos0 = _shared_pos0_blocks(rows, args.shared_prefix_min_sessions)

    global_trie: dict = {}
    session_tries: dict[str, dict] = defaultdict(dict)
    total_input = 0
    cache_any = 0
    cache_intra = 0
    cache_shared_only = 0
    n_with_any_hit = 0
    n_with_intra_hit = 0
    n_with_shared_hit = 0
    for r in rows:
        hids = list(r.get("hash_ids") or [])
        input_len = int(r.get("input_length") or 0)
        sid = r["session_id"]
        total_input += input_len

        # Longest already-cached prefix: across all sessions / own session.
        g_depth = _walk(global_trie, hids)
        s_depth = _walk(session_tries[sid], hids)
        # Shared-prefix-only: credit at most the position-0 block, and only
        # if it is a shared block AND an earlier request already cached it
        # (g_depth >= 1 iff hids[0] is a top-level key of the global trie).
        # Without the second condition the first-ever occurrence would count
        # as a hit, which no cache could serve. Deeper blocks are not
        # modeled as shared system prefix in this conservative bound.
        sh_depth = 1 if g_depth >= 1 and hids[0] in shared_pos0 else 0

        # Blocks -> tokens; the last block may be partial, so clip.
        g_tokens = min(g_depth * args.block_size, input_len)
        s_tokens = min(s_depth * args.block_size, input_len)
        sh_tokens = min(sh_depth * args.block_size, input_len)
        cache_any += g_tokens
        cache_intra += s_tokens
        cache_shared_only += sh_tokens
        n_with_any_hit += g_tokens > 0
        n_with_intra_hit += s_tokens > 0
        n_with_shared_hit += sh_tokens > 0

        # Make this request's blocks visible to later requests.
        _insert(global_trie, hids)
        _insert(session_tries[sid], hids)

    def _ratio(cached: int) -> float:
        # Guard against an empty trace / no input_length fields.
        return cached / total_input if total_input else 0.0

    out = {
        "trace": str(args.trace),
        "n_requests": len(rows),
        "n_sessions": len({r["session_id"] for r in rows}),
        "block_size": args.block_size,
        "shared_prefix_min_sessions": args.shared_prefix_min_sessions,
        "total_input_tokens": total_input,
        "apc_upper_any_session": _ratio(cache_any),
        "apc_upper_intra_session": _ratio(cache_intra),
        "apc_upper_shared_prefix_only": _ratio(cache_shared_only),
        "cached_tokens_any_session": cache_any,
        "cached_tokens_intra_session": cache_intra,
        "cached_tokens_shared_prefix_only": cache_shared_only,
        "n_requests_any_hit": n_with_any_hit,
        "n_requests_intra_hit": n_with_intra_hit,
        "n_requests_shared_hit": n_with_shared_hit,
        "n_shared_pos0_blocks": len(shared_pos0),
    }
    text = json.dumps(out, indent=2)
    if args.out:
        args.out.write_text(text, encoding="utf-8")
    print(text)


if __name__ == "__main__":
    main()