"""Analyze the 10pp APC gap: what gets evicted and why.

Replays a sampled request trace across ``N_INSTANCES`` simulated instances
and, for every request, compares the cached-prefix length under two cache
models kept per instance:

* an unbounded ("infinite") cache that remembers every block ever seen, and
* a block-level LRU cache holding at most ``KV_CAPACITY_BLOCKS`` blocks.

The token difference between the two is the loss caused by eviction. It is
split into multi-turn requests (the request has a parent chat) and
single-turn requests.
"""
import json
from collections import OrderedDict

TRACE_PATH = "traces/sampled_1000req_seed42.jsonl"

# Close the trace file deterministically and tolerate blank lines
# (e.g. a trailing newline) instead of crashing json.loads.
with open(TRACE_PATH, encoding="utf-8") as trace_file:
    rows = [json.loads(line) for line in trace_file if line.strip()]
rows.sort(key=lambda r: float(r["timestamp"]))

BLOCK_SIZE = 512  # tokens per block; converts block hit counts to tokens
KV_CAPACITY_BLOCKS = 550  # LRU capacity per instance, in blocks
N_INSTANCES = 8
ROUTING_PROBE_BLOCKS = 10  # leading blocks probed when routing a new session


class LRUCache:
    """Fixed-capacity set of block keys with LRU eviction and an eviction count."""

    def __init__(self, cap):
        self.cap = cap
        self.cache = OrderedDict()  # recency order: front = least recently used
        self.evictions = 0

    def peek(self, k):
        """Return whether *k* is cached, without updating its recency."""
        return k in self.cache

    def access(self, k):
        """Touch *k*.

        On a hit, move *k* to the most-recently-used end and return True.
        On a miss, insert *k*, evict from the least-recently-used end while
        over capacity, and return False.
        """
        if k in self.cache:
            self.cache.move_to_end(k)
            return True
        self.cache[k] = True
        while len(self.cache) > self.cap:
            self.cache.popitem(last=False)
            self.evictions += 1
        return False


def _pick_instance(hids, load):
    """Choose the instance for a session seen for the first time.

    Each instance is scored by how many of the first ``ROUTING_PROBE_BLOCKS``
    block hashes its LRU cache currently holds. Ties, including the all-cold
    case where no instance holds any of them, go to the instance that has
    served the fewest requests so far, then to the lowest index.
    """
    # A fixed tie-break (always instance 0) is not usable: an instance only
    # gains cached blocks after it has been chosen, so every session would
    # land on instance 0 and the other instances would never be used.
    probe = hids[:ROUTING_PROBE_BLOCKS]

    def score(j):
        cached = lru_caches[j].cache
        return (sum(1 for hid in probe if hid in cached), -load[j])

    # max() returns the first maximal element -> lowest index on full ties.
    return max(range(N_INSTANCES), key=score)


inf_seen = [set() for _ in range(N_INSTANCES)]  # per-instance unbounded cache
lru_caches = [LRUCache(KV_CAPACITY_BLOCKS) for _ in range(N_INSTANCES)]
inst_load = [0] * N_INSTANCES  # requests routed to each instance so far
session_aff = {}  # session id (str) -> instance the session is pinned to
chat_to_session = {}  # chat_id -> session id (str)

loss_intra = 0  # multi-turn: prior turn evicted
loss_cross = 0  # single-turn: shared prefix evicted
total_loss = 0
total_inf_hits = 0
total_lru_hits = 0
total_tokens = 0
per_req = []  # one record per request that lost tokens to eviction

for idx, r in enumerate(rows):
    il = r["input_length"]
    hids = r.get("hash_ids", [])
    cid = r["chat_id"]
    pid = r["parent_chat_id"]

    # Session id: the explicit field when present and not null. Otherwise it
    # is inherited from the parent chat (falling back to the parent id), or
    # is the chat's own id for a root chat. It is normalised to str so that
    # session_aff keys are consistent across turns and r["sid"][:8] below
    # cannot fail.
    sid = r.get("session_id")
    if sid is None:
        sid = str(cid) if pid < 0 else chat_to_session.get(pid, str(pid))
    sid = str(sid)
    chat_to_session[cid] = sid
    is_mt = pid >= 0

    # A session sticks to the instance chosen for its first request.
    if sid not in session_aff:
        session_aff[sid] = _pick_instance(hids, inst_load)
    inst = session_aff[sid]
    inst_load[inst] += 1

    # Infinite cache: length of the leading run of blocks already seen here.
    seen = inf_seen[inst]
    inf_h = 0
    for hid in hids:
        if hid not in seen:
            break
        inf_h += 1
    seen.update(hids)

    # LRU cache: length of the leading run of blocks still resident.
    lru = lru_caches[inst]
    lru_h = 0
    for hid in hids:
        if not lru.peek(hid):
            break
        lru.access(hid)
        lru_h += 1
    # NOTE(review): blocks are touched head-first. hids[0] therefore becomes
    # this request's least recently used block and is evicted before the
    # tail, and losing it zeroes the whole prefix hit. vLLM's APC instead
    # evicts the end of the longest prefix first among equally recent
    # blocks. Touching reversed(hids) here would model that, but it would
    # change the reported numbers.
    for hid in hids:
        lru.access(hid)

    inf_tok = inf_h * BLOCK_SIZE
    lru_tok = lru_h * BLOCK_SIZE
    loss = inf_tok - lru_tok
    total_inf_hits += inf_tok
    total_lru_hits += lru_tok
    total_tokens += il

    if loss > 0:
        total_loss += loss
        if is_mt:
            loss_intra += loss
        else:
            loss_cross += loss
        per_req.append({
            "idx": idx,
            "input": il,
            "inf_hit": inf_h,
            "lru_hit": lru_h,
            "loss_blocks": inf_h - lru_h,
            "loss_tok": loss,
            "mt": is_mt,
            "sid": sid,
            "turn": r.get("turn", 1),
            "n_blocks": len(hids),
        })

sep = "=" * 70
print(sep)
print(" EVICTION LOSS ANALYSIS")
print(sep)
print()
token_denom = max(total_tokens, 1)  # avoid ZeroDivisionError on an empty trace
print(" Infinite APC: %.1f%%" % (total_inf_hits / token_denom * 100))
print(" LRU APC: %.1f%%" % (total_lru_hits / token_denom * 100))
print(" Gap: %.1f pp (%s tokens lost)" % (
    (total_inf_hits - total_lru_hits) / token_denom * 100,
    "{:,}".format(total_loss)))
print()
loss_denom = max(total_loss, 1)
print(" Loss by type:")
print(" Multi-turn (prior turn KV evicted): %s tok (%.0f%%)" % (
    "{:,}".format(loss_intra), loss_intra * 100 / loss_denom))
print(" Single-turn (shared prefix evicted): %s tok (%.0f%%)" % (
    "{:,}".format(loss_cross), loss_cross * 100 / loss_denom))
print()
print(" Requests with loss: %d / %d" % (len(per_req), len(rows)))
print()
print(" Top-15 by loss:")
print(" %4s %7s %5s %5s %5s %7s %3s %8s %4s" % (
    "#", "input", "inf_h", "lru_h", "loss", "tok", "mt", "session", "turn"))
for rec in sorted(per_req, key=lambda x: -x["loss_tok"])[:15]:
    print(" %4d %7d %5d %5d %5d %7d %3s %8s %4d" % (
        rec["idx"], rec["input"], rec["inf_hit"], rec["lru_hit"],
        rec["loss_blocks"], rec["loss_tok"],
        "Y" if rec["mt"] else "N", rec["sid"][:8], rec["turn"]))

# Instance-level analysis
print()
print(" Per-instance:")
for i in range(N_INSTANCES):
    n = len(inf_seen[i])
    e = lru_caches[i].evictions
    overflow = n / KV_CAPACITY_BLOCKS
    print(" inst_%d: %5d unique blocks, overflow=%.1fx, evictions=%d" % (
        i, n, overflow, e))

# Time gap analysis: for lost requests, how long between
# the block being deposited and being needed again?
print()
print(" Temporal analysis of evicted blocks:")

# Replay the trace using the instance routing the main simulation recorded
# in session_aff, instead of pinning everything to instance 0. Reuse
# distance is counted on a per-instance clock (requests routed to that
# instance), because only those requests compete for that instance's LRU
# capacity. A parallel LRU replay (same capacity, blocks touched in request
# order) flags reaccesses whose block had already been evicted; it is
# intended to mirror the main loop's LRU state.
routed = {str(k): v for k, v in session_aff.items()}  # tolerate raw or str keys
chat_to_session2 = {}
inst_clock = [0] * N_INSTANCES
# Per instance: block hash -> per-instance clock value of its last access.
block_deposit_time = [{} for _ in range(N_INSTANCES)]
replay_lru = [LRUCache(KV_CAPACITY_BLOCKS) for _ in range(N_INSTANCES)]
gaps = []  # reuse distance of every reaccessed block
evicted_gaps = []  # subset whose block was no longer in the replayed LRU

for r in rows:
    hids = r.get("hash_ids", [])
    cid = r["chat_id"]
    pid = r["parent_chat_id"]
    # Session id: explicit session_id, else inherited from the parent chat,
    # else the chat's own id. Normalised to str so it matches `routed`.
    sid = r.get("session_id")
    if sid is None:
        sid = str(cid) if pid < 0 else chat_to_session2.get(pid, str(pid))
    sid = str(sid)
    chat_to_session2[cid] = sid
    inst = routed.get(sid, 0)

    now = inst_clock[inst]
    inst_clock[inst] += 1
    deposits = block_deposit_time[inst]
    lru = replay_lru[inst]
    for hid in hids:
        if hid in deposits:
            gap = now - deposits[hid]
            gaps.append(gap)
            if not lru.peek(hid):
                evicted_gaps.append(gap)
        deposits[hid] = now
        lru.access(hid)


def _print_reuse(header, values):
    """Print nearest-rank p10/p50/p90/max and a <=10 / 10-100 / >100 split.

    *values* must be non-empty.
    """
    ordered = sorted(values)
    count = len(ordered)

    def pct(q):
        return ordered[min(int(q * count), count - 1)]

    short = sum(1 for g in ordered if g <= 10)
    medium = sum(1 for g in ordered if 10 < g <= 100)
    long_ = count - short - medium  # everything > 100
    print(header)
    print(" p10=%d p50=%d p90=%d max=%d" % (
        pct(.1), pct(.5), pct(.9), ordered[-1]))
    print(" <=10 req: %d (%.0f%%) 10-100: %d (%.0f%%) >100: %d (%.0f%%)" % (
        short, short * 100 / count,
        medium, medium * 100 / count,
        long_, long_ * 100 / count))


if gaps:
    _print_reuse(
        " Block reuse distance (requests between deposit and reaccess):", gaps)
if evicted_gaps:
    _print_reuse(
        " Evicted-block reuse distance (reaccesses that missed the LRU):",
        evicted_gaps)
elif gaps:
    print(" No reaccessed block had been evicted.")