Files
agentic-kvc/paper/data/f2a_reuse_topology_analyze.py
Gahow Wang 19c443e3bc paper f2a: reuse-topology decomposition + mixture-sensitivity sweep
Full-trace analysis backing figure 2a on the real 2h cluster trace:

- f2a_reuse_topology_analyze.py: infinite-KV-cache (LRU) decomposition of
  prefix-cache reuse hits into intra-session vs cross-session, by most-recent
  prior holder of each content-addressed block.
- f2a_mixture_sweep.py: sensitivity of the intra/cross split to the
  single-turn session fraction (tests whether the 93%-intra sample vs 54.6%
  full-trace gap is session-mixture selection bias) -- keep all multi-turn
  sessions, downsample single-turn to each target fraction, reclassify.

Includes the result JSONs for both.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-01 01:03:40 +08:00

183 lines
6.7 KiB
Python

#!/usr/bin/env python3
"""
f2a reuse topology — full-trace, infinite-KV-cache decomposition (LRU semantics).
Question: on the real 2h cluster trace, assuming an *infinite* KV cache (nothing
ever evicted), where do prefix-cache REUSE HITS come from?
We classify only reuse hits (the 1st occurrence of a block is `new` = irreducible
prefill; it is reported only as context for the APC ceiling, not in the split).
Blocks (content-addressed `hash_id`) are processed in timestamp order. For each hit we
look at the block's **most recent prior holder** (last computed OR used = LRU):
intra : last touch was the SAME session (parent_chat_id chain)
cross : last touch was a DIFFERENT session
After classifying, the block's last-holder / last-time are updated to the current
request (LRU refresh). The reuse "recency" is the **LRU reuse distance** = time since
the block was last touched (what a finite TTL/LRU cache would need to retain).
`cross` is further resolved by *block popularity* = number of distinct sessions that
ever touch the block: a handful of hugely-popular blocks are the shared system/tool
prefix; low-popularity cross blocks are genuine cross-session content.
Run on dash2 (trace lives there):
python3 f2a_reuse_topology_analyze.py \
~/ali-trace/trace-glm5.1-formatted/051315-051317.jsonl /tmp/f2a_result.json
"""
import sys, json, time
from collections import defaultdict
# ---- CLI arguments --------------------------------------------------------
# argv[1] = input trace (one JSON object per line), required.
# argv[2] = output JSON path, optional.
if len(sys.argv) < 2:
    # Fail with a usage line instead of a bare IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} TRACE.jsonl [OUT.json]")
PATH = sys.argv[1]
OUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/f2a_result.json"
POP_CAP = 4096  # cap per-block root set; >= this is "very shared", buckets unaffected

t0 = time.time()  # start time, used only for the elapsed-seconds prefix on stderr

# ---- Load the trace --------------------------------------------------------
chat_parent = {}  # chat_id -> parent_chat_id, or 0 when the record has no parent
records = []  # (ts, chat_id, hash_ids)
total_input_tokens = 0  # sum of `input_length` over all requests
total_blocks = 0  # sum of len(hash_ids) over all requests
turn1 = 0  # requests whose `turn` is 1 (or missing / null)
n = 0  # number of requests loaded
with open(PATH, encoding="utf-8") as f:
    for line in f:
        # Blank lines (e.g. a trailing newline at EOF) are not records;
        # json.loads would raise on them.
        if not line.strip():
            continue
        d = json.loads(line)
        cid = d["chat_id"]
        pc = d.get("parent_chat_id")
        chat_parent[cid] = 0 if pc is None else pc
        hs = d.get("hash_ids") or []
        # Null-guard the timestamp the same way as the other optional fields,
        # so a `null` value cannot break the timestamp sort later on.
        records.append((d.get("timestamp") or 0.0, cid, hs))
        total_input_tokens += d.get("input_length", 0) or 0
        total_blocks += len(hs)
        if (d.get("turn", 1) or 1) == 1:
            turn1 += 1
        n += 1
sys.stderr.write(f"[{time.time()-t0:.0f}s] loaded {n} reqs, {total_blocks} block-occ\n")
# Session identity = the head of the parent_chat_id chain. The head is the
# first chat whose parent is 0 (no parent recorded) or whose parent is not a
# key of `chat_parent` (i.e. was never loaded from the trace).
root_cache = {}  # chat_id -> resolved session root (memo)


def resolve_root(cid):
    """Return the session root of chat `cid`, memoising the result.

    Walks `chat_parent` upwards from `cid` until it reaches an already
    cached chat, a chat with no usable parent, or the 100000-hop safety
    limit. Every chat visited on the way, and `cid` itself, is then
    recorded in `root_cache` with the resolved root.
    """
    path = []
    node = cid
    while node not in root_cache:
        parent = chat_parent.get(node, 0)
        if parent == 0 or parent not in chat_parent:
            # `node` has no parent inside the loaded trace: it is the head.
            root = node
            break
        path.append(node)
        node = parent
        if len(path) > 100000:
            # Safety valve against pathological (e.g. cyclic) parent chains.
            root = node
            break
    else:
        # Loop ended because `node` was already resolved earlier.
        root = root_cache[node]
    root_cache.update(dict.fromkeys(path, root))
    root_cache[cid] = root
    return root
# Replay order is ascending timestamp (first tuple field). list.sort is
# stable, so requests sharing a timestamp keep their load order.
records.sort(key=lambda rec: rec[0])
sys.stderr.write(f"[{time.time()-t0:.0f}s] sorted by ts\n")

# Per-block state maintained while replaying the trace.
last_root = {}  # block -> root of MOST RECENT holder (LRU)
last_ts = {}  # block -> ts of most recent touch (LRU)
roots_of = defaultdict(set)  # block -> set of distinct roots (capped) = popularity
intra_cnt = defaultdict(int)  # block -> intra reuse hits
cross_cnt = defaultdict(int)  # block -> cross reuse hits

# Global occurrence counters: first touches, same-session hits, other-session hits.
new = 0
intra = 0
cross = 0
# LRU reuse distance of a hit = consumer timestamp minus the block's last-touch
# timestamp. GAP_EDGES holds the exclusive upper bound (seconds) of each
# bucket; GAP_LABELS names the buckets in the same order.
GAP_EDGES = [1, 10, 60, 300, 1800, 3600, float("inf")]
GAP_LABELS = ["<1s", "1-10s", "10-60s", "1-5min", "5-30min", "30-60min", ">60min"]

# Per-bucket hit histograms, one slot per edge.
rec_intra = [0 for _ in GAP_EDGES]
rec_cross = [0 for _ in GAP_EDGES]


def gap_bucket(g):
    """Return the index of the first GAP_EDGES entry strictly greater than `g`.

    When no edge is greater than `g` (e.g. `g` is infinite), the last bucket
    index is returned.
    """
    overflow = len(GAP_EDGES) - 1
    return next((idx for idx, upper in enumerate(GAP_EDGES) if g < upper), overflow)
# Replay the requests (already sorted by timestamp) and classify every block
# occurrence. A block with no recorded holder is a first compute (`new`);
# otherwise it is a reuse hit, attributed to the block's most recent holder
# *before* that holder is overwritten by the current request.
for req_ts, req_cid, req_blocks in records:
    if not req_blocks:
        continue
    session = resolve_root(req_cid)
    for blk in req_blocks:
        holder = last_root.get(blk)
        if holder is not None:
            # Reuse hit: bucket by time since last touch (clamped at 0).
            bucket = gap_bucket(max(0.0, req_ts - last_ts[blk]))
            if holder != session:
                cross += 1
                cross_cnt[blk] += 1
                rec_cross[bucket] += 1
            else:
                intra += 1
                intra_cnt[blk] += 1
                rec_intra[bucket] += 1
        else:
            new += 1  # first compute: not a hit
        # LRU refresh: the block is now held by the current session.
        last_root[blk] = session
        last_ts[blk] = req_ts
        # Popularity: distinct session roots touching the block, capped at POP_CAP.
        holders = roots_of[blk]
        if len(holders) < POP_CAP:
            holders.add(session)
sys.stderr.write(f"[{time.time()-t0:.0f}s] classified: new={new} intra={intra} cross={cross}\n")
# Popularity of a block = number of distinct session roots recorded for it.
# POP_EDGES[i] is the exclusive upper bound of bucket i; POP_LABELS names the
# buckets in the same order.
POP_EDGES = [2, 10, 100, 1000, float("inf")]
POP_LABELS = ["1 (private)", "2-9", "10-99", "100-999", ">=1000"]


def pop_bucket(p):
    """Return the popularity bucket index for a distinct-session count `p`.

    Counts of 1 or less go to bucket 0. Larger counts go to the first bucket
    (from index 1) whose upper edge exceeds `p`, falling back to the last
    bucket when no edge does.
    """
    if p <= 1:
        return 0
    idx = 1
    while idx < len(POP_EDGES):
        if p < POP_EDGES[idx]:
            return idx
        idx += 1
    return len(POP_LABELS) - 1
# Per-popularity-bucket totals: number of distinct blocks, and the intra /
# cross reuse hits those blocks received. Three independent lists.
pop_blocks, pop_intra, pop_cross = ([0] * len(POP_LABELS) for _ in range(3))
for blk in last_root:
    bucket = pop_bucket(len(roots_of[blk]))
    pop_blocks[bucket] += 1
    # .get avoids inserting zero entries into the defaultdict counters.
    pop_intra[bucket] += intra_cnt.get(blk, 0)
    pop_cross[bucket] += cross_cnt.get(blk, 0)
def _ratio(num, den):
    """Return num / den, or 0.0 when den is zero.

    Keeps the report writable for an empty trace or a trace with no reuse
    hits, instead of failing with ZeroDivisionError after the full replay.
    """
    return num / den if den else 0.0


# Effective block size: input tokens per block occurrence.
eff_blk = _ratio(total_input_tokens, total_blocks)
total_occ = new + intra + cross  # every block occurrence that was classified
reuse = intra + cross  # reuse hits only (first computes excluded)
result = {
    "trace": PATH,
    "semantics": "LRU last-touched; reuse-hits only (new excluded from split)",
    "n_requests": n,
    "n_sessions": len({resolve_root(c) for c in chat_parent}),
    "turn1_frac": _ratio(turn1, n),
    "block_size_tokens_eff": eff_blk,
    "total_input_tokens": total_input_tokens,
    "total_block_occ": total_occ,
    "distinct_blocks": len(last_root),
    "new_occ": new,  # context only
    "apc_ceiling": _ratio(reuse, total_occ),  # context only
    # REUSE-ONLY decomposition (the headline)
    "reuse_total": reuse,
    "reuse": {"intra": intra, "cross": cross},
    "reuse_frac": {"intra": _ratio(intra, reuse), "cross": _ratio(cross, reuse)},
    # cross resolved by popularity (over reuse hits)
    "pop_labels": POP_LABELS,
    "pop_blocks": pop_blocks,
    "pop_intra": pop_intra,
    "pop_cross": pop_cross,
    # LRU reuse-distance recency (over reuse hits)
    "gap_labels": GAP_LABELS,
    "rec_intra": rec_intra,
    "rec_cross": rec_cross,
}
with open(OUT, "w", encoding="utf-8") as f:
    json.dump(result, f, indent=2)
sys.stderr.write(f"[{time.time()-t0:.0f}s] wrote {OUT}\n")
# Human-readable summary on stdout: headline numbers as JSON, then the two
# histograms as aligned text tables.
headline_keys = (
    "n_requests",
    "n_sessions",
    "distinct_blocks",
    "reuse_total",
    "reuse_frac",
    "apc_ceiling",
)
headline = {key: result[key] for key in headline_keys}
print(json.dumps(headline, indent=2))
print(f"new(context)={new} intra={intra} cross={cross}")

print("popularity blocks / intra-hits / cross-hits:")
for lab, n_blk, n_in, n_x in zip(POP_LABELS, pop_blocks, pop_intra, pop_cross):
    print(f" {lab:>12}: {n_blk:>10} | {n_in:>11} | {n_x:>11}")

print("LRU reuse-distance intra / cross:")
for lab, n_in, n_x in zip(GAP_LABELS, rec_intra, rec_cross):
    print(f" {lab:>8}: {n_in:>11} | {n_x:>11}")