Files
agentic-kvc/paper/data/f2a_mixture_sweep.py
Gahow Wang 19c443e3bc paper f2a: reuse-topology decomposition + mixture-sensitivity sweep
Full-trace analysis backing figure 2a on the real 2h cluster trace:

- f2a_reuse_topology_analyze.py: infinite-KV-cache (LRU) decomposition of
  prefix-cache reuse hits into intra-session vs cross-session, by most-recent
  prior holder of each content-addressed block.
- f2a_mixture_sweep.py: sensitivity of the intra/cross split to the
  single-turn session fraction (tests whether the 93%-intra sample vs 54.6%
  full-trace gap is session-mixture selection bias) -- keep all multi-turn
  sessions, downsample single-turn to each target fraction, reclassify.

Includes the result JSONs for both.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-01 01:03:40 +08:00

121 lines
4.5 KiB
Python

#!/usr/bin/env python3
"""
f2a sensitivity: how does the intra/cross reuse split move as we change the
single-turn session fraction? (Tests whether the old 93%-intra sample vs 54.6%
full-trace gap is just session-mixture selection bias.)
Keep ALL multi-turn sessions; downsample single-turn sessions to hit each target
single-turn fraction f. Re-run the LRU (last-touched), reuse-hits-only
classification on the filtered request stream.
python3 f2a_mixture_sweep.py ~/ali-trace/.../051315-051317.jsonl /tmp/f2a_sweep.json
"""
import sys, json, time, random
from collections import Counter, defaultdict
# Fail with the usage text (module docstring) instead of an IndexError traceback.
if len(sys.argv) < 2:
    sys.exit(__doc__ or "usage: f2a_mixture_sweep.py TRACE.jsonl [OUT.json]")
PATH = sys.argv[1]
OUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/f2a_sweep.json"
# Fixed seed: the single-turn downsampling in the sweep is reproducible.
random.seed(0)
t0 = time.time()

# chat_id -> parent chat_id; 0 marks "no parent" (parent_chat_id missing/null).
chat_parent = {}
# One (timestamp, chat_id, hash_ids) tuple per request, in file order here;
# the list is sorted by timestamp before replay.
records = []
with open(PATH, encoding="utf-8") as f:  # JSON text is UTF-8 by spec
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines (e.g. trailing newlines) in the JSONL
        d = json.loads(line)
        cid = d["chat_id"]
        pc = d.get("parent_chat_id")
        chat_parent[cid] = 0 if pc is None else pc
        records.append((d.get("timestamp", 0.0), cid, d.get("hash_ids") or []))
sys.stderr.write(f"[{time.time()-t0:.0f}s] loaded {len(records)}\n")
# Memo: chat_id -> session root, shared across all resolve_root() calls.
root_cache = {}


def resolve_root(cid):
    """Return the session root of request ``cid``.

    Follows ``chat_parent`` links upward until reaching one of two chats:
    - a chat with no parent (sentinel 0);
    - a chat whose parent chat_id never appears in the trace.
    That chat_id is returned. Every chat visited on the walk is memoised in
    ``root_cache``, so each chat is walked at most once across all calls.

    A parent cycle (malformed data) is broken at the first chat whose parent
    was already visited on the current walk. That chat becomes the root of the
    whole cycle. Unlike a fixed depth cap, this never labels a legitimately
    long chain with a mid-chain node, which would split one session in two.
    """
    path = []        # chats visited on this walk; memoised once the root is known
    on_path = set()  # same chats, for O(1) cycle detection
    cur = cid
    while cur not in root_cache:
        parent = chat_parent.get(cur, 0)
        if parent == 0 or parent not in chat_parent or parent in on_path:
            root = cur
            root_cache[cur] = root
            break
        path.append(cur)
        on_path.add(cur)
        cur = parent
    else:
        # Reached an already-resolved chat: reuse its root.
        root = root_cache[cur]
    for node in path:
        root_cache[node] = root
    return root
# Replay order: stable sort by timestamp (ties keep file order).
records.sort(key=lambda rec: rec[0])
# roots[i] is the session root of records[i]; the two lists stay index-aligned.
roots = list(map(resolve_root, (rec[1] for rec in records)))
req_per_root = Counter(roots)
# Split sessions by request count. Both lists keep the Counter's first-seen
# order, which the seeded random.sample() in the sweep relies on.
single_roots = []
multi_roots = []
for root, n_req in req_per_root.items():
    if n_req == 1:
        single_roots.append(root)
    else:
        multi_roots.append(root)
M = len(multi_roots)
sys.stderr.write(f"[{time.time()-t0:.0f}s] roots: single={len(single_roots)} multi={M}\n")
# Exclusive upper edges of the reuse-gap histogram buckets:
# [<1, 1-10, 10-60, 60-300, 300-1800, 1800-3600, >=3600] (seconds, going by
# the "le60s" naming used for the summary columns).
GAP_EDGES = [1, 10, 60, 300, 1800, 3600, float("inf")]


def gbucket(g):
    """Return the histogram bucket index for gap ``g``.

    The index is the first i with ``g < GAP_EDGES[i]``. A value that is not
    below any edge (+inf or NaN) lands in the last bucket.
    """
    return next(
        (idx for idx, edge in enumerate(GAP_EDGES) if g < edge),
        len(GAP_EDGES) - 1,
    )
def classify(kept):  # kept=None -> keep all
    """Replay the request stream against an unbounded block store.

    Walks ``records``/``roots`` in order. Requests whose session root is not in
    ``kept`` are skipped; ``kept=None`` keeps everything. Each block hash of a
    kept request is classified as:
      * never seen before            -> ``new``;
      * last touched by the same root -> ``intra`` reuse;
      * last touched by another root  -> ``cross`` reuse.
    Reuses are also histogrammed by the time since the hash was last touched
    (clamped at 0, bucketed via ``gbucket``). The hash's most recent holder is
    then updated to the current root/timestamp. Nothing is ever evicted.

    Returns ``(intra, cross, new, intra_gap_hist, cross_gap_hist)``.
    """
    n_buckets = len(GAP_EDGES)
    intra_hist = [0] * n_buckets
    cross_hist = [0] * n_buckets
    intra = cross = new = 0
    holder = {}  # block hash -> (root, timestamp) of its most recent toucher

    stream = zip(records, roots)
    if kept is not None:
        stream = ((rec, root) for rec, root in stream if root in kept)

    for (ts, _cid, hashes), root in stream:
        for h in hashes:
            prev_root, prev_ts = holder.get(h, (None, None))
            if prev_root is None:
                new += 1
            else:
                bucket = gbucket(max(0.0, ts - prev_ts))
                if prev_root == root:
                    intra += 1
                    intra_hist[bucket] += 1
                else:
                    cross += 1
                    cross_hist[bucket] += 1
            holder[h] = (root, ts)
    return intra, cross, new, intra_hist, cross_hist
def cum_le(rec, idx):
    """Return the fraction of histogram ``rec`` in buckets ``0..idx`` inclusive.

    An all-zero (or empty) histogram yields 0.0 instead of dividing by zero.
    """
    total = sum(rec)
    if not total:
        total = 1
    head = rec[:idx + 1]
    return sum(head) / total
# Sweep points: "full" = the unmodified trace; a number = the target
# single-turn session fraction after downsampling. The second tuple element
# is unused.
targets = [("full", None), (0.75, None), (0.50, None),
           (0.25, None), (0.10, None), (0.00, None)]
rows = []
for label, _ in targets:
    n_single_all = len(single_roots)
    if label == "full":
        kept = None  # classify() treats None as "keep every session"
        S = n_single_all
    else:
        f = float(label)
        # Keep all M multi-turn sessions plus S single-turn ones so that
        # S / (S + M) ~= f, capped at the single-turn sessions that exist.
        S = min(n_single_all, round(M * f / (1 - f))) if f < 1 else n_single_all
        # The RNG is shared across targets: draw only when actually
        # downsampling, so each row's sample depends on the earlier rows' draws.
        keep_single = (set(random.sample(single_roots, S))
                       if S < n_single_all else set(single_roots))
        kept = set(multi_roots) | keep_single
    # `or 1` guards: an empty trace (S + M == 0), or a subset with no block
    # reuse at all, reports 0.0 instead of raising ZeroDivisionError.
    f_actual = S / ((S + M) or 1)
    intra, cross, new, rec_i, rec_c = classify(kept)
    reuse = intra + cross
    reuse_denom = reuse or 1
    n_sess = (n_single_all + M) if kept is None else len(kept)
    row = {
        "target": label, "single_turn_frac": round(f_actual, 4), "n_sessions": n_sess,
        "new": new, "intra": intra, "cross": cross, "reuse": reuse,
        "intra_frac_of_reuse": round(intra / reuse_denom, 4),
        "cross_frac_of_reuse": round(cross / reuse_denom, 4),
        # Gap buckets 0..2 of GAP_EDGES cover gaps strictly below 60 s.
        "intra_le60s": round(cum_le(rec_i, 2), 4),
        "cross_le60s": round(cum_le(rec_c, 2), 4),
    }
    rows.append(row)
    sys.stderr.write(f"[{time.time()-t0:.0f}s] f={row['single_turn_frac']}: "
                     f"intra={row['intra_frac_of_reuse']} cross={row['cross_frac_of_reuse']}\n")
# Persist the sweep. The context manager flushes and closes the file
# deterministically instead of relying on the temporary handle being collected.
with open(OUT, "w", encoding="utf-8") as fh:
    json.dump({"rows": rows, "n_single": len(single_roots), "n_multi": M}, fh, indent=2)
# Human-readable summary. The "<=60s" columns come from the *_le60s row fields:
# the share of reuse hits in gap buckets 0..2, i.e. gaps strictly under 60 s.
print(f"{'single-turn%':>12} {'sessions':>10} {'intra%':>8} {'cross%':>8} {'intra<=60s':>11} {'cross<=60s':>11}")
for r in rows:
    print(f"{r['single_turn_frac']*100:>11.1f}% {r['n_sessions']:>10} "
          f"{r['intra_frac_of_reuse']*100:>7.1f}% {r['cross_frac_of_reuse']*100:>7.1f}% "
          f"{r['intra_le60s']*100:>10.1f}% {r['cross_le60s']*100:>10.1f}%")