"""Simulate different KV cache eviction policies on the agentic trace.

Compares:
  1. LRU (current vLLM default)
  2. LFU (Least Frequently Used)
  3. Session-protected LRU (blocks of multi-turn sessions are skipped by
     eviction, up to a protection budget)
  4. Infinite cache (upper bound)

All policies see the same session-sticky routing: each session is pinned to one
instance, and sessions are balanced across instances by their total input
tokens. Run as a script to print the prefix-cache hit rate (APC) per policy.
"""
import json
import statistics  # noqa: F401 -- currently unused
from collections import Counter, OrderedDict, defaultdict

TRACE_PATH = "traces/sampled_1000req_seed42.jsonl"
# Tokens credited per cached hash block; assumed to match the block size the
# trace's hash_ids were computed with -- confirm against the trace generator.
BLOCK_SIZE = 512
KV_CAPACITY_BLOCKS = 550  # per-instance cache capacity, in blocks
N_INSTANCES = 8
PROTECTED_BUDGET_RATIO = 0.4  # fraction of SessionProtLRU capacity that may be protected
POLICY_ORDER = ("Infinite", "LRU", "LFU", "SessionProtLRU")
RESULT_KEYS = ("hits", "total", "mt_hits", "mt_total", "st_hits", "st_total")


# ---------------------------------------------------------------------------
# Trace loading and session reconstruction
# ---------------------------------------------------------------------------

def load_rows(path=TRACE_PATH):
    """Read the JSONL trace at *path*; return its records sorted by timestamp.

    Blank lines are skipped and the file is closed before returning.
    """
    with open(path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    rows.sort(key=lambda r: float(r["timestamp"]))
    return rows


def assign_sessions(rows):
    """Resolve the session id (a str) of every row in a single pass.

    A row uses its own ``session_id`` when present. Otherwise a root request
    (``parent_chat_id < 0``) starts a session named after its ``chat_id``, and a
    follow-up inherits its parent's session, or is named after the parent id
    when the parent has not been seen earlier in *rows*.

    Returns:
        (session_ids, chat_to_session): ``session_ids[i]`` is the session of
        ``rows[i]``; ``chat_to_session`` maps each ``chat_id`` to its session.
    """
    chat_to_session = {}
    session_ids = []
    for r in rows:
        cid = r["chat_id"]
        if "session_id" in r:
            sid = r["session_id"]
        elif r["parent_chat_id"] < 0:
            sid = cid
        else:
            pid = r["parent_chat_id"]
            sid = chat_to_session.get(pid, pid)
        sid = str(sid)
        chat_to_session[cid] = sid
        session_ids.append(sid)
    return session_ids, chat_to_session


def find_multi_turn_sessions(session_ids):
    """Return the set of session ids that occur on more than one row."""
    return {sid for sid, n in Counter(session_ids).items() if n > 1}


# ---------------------------------------------------------------------------
# Cache policies. All expose peek(k) (read-only membership test) and
# access(k) (touch/insert k, evicting as needed; returns True on a hit).
# ---------------------------------------------------------------------------

class InfiniteCache:
    """No eviction, unlimited capacity (upper bound on reuse)."""

    def __init__(self):
        self.cache = set()

    def peek(self, k):
        """Return True if block *k* is cached, without changing any state."""
        return k in self.cache

    def access(self, k, **kw):
        """Insert block *k* if absent; return True if it was already cached."""
        hit = k in self.cache
        self.cache.add(k)
        return hit


class LRUCache:
    """Standard LRU: on overflow, evict the least recently accessed block."""

    def __init__(self, cap):
        self.cap = cap
        self.cache = OrderedDict()  # ordered least -> most recently used

    def peek(self, k):
        """Return True if block *k* is cached, without refreshing its recency."""
        return k in self.cache

    def access(self, k, **kw):
        """Touch or insert block *k*; return True on a hit."""
        if k in self.cache:
            self.cache.move_to_end(k)
            return True
        self.cache[k] = 0
        while len(self.cache) > self.cap:
            self.cache.popitem(last=False)
        return False


class LFUCache:
    """Least Frequently Used eviction.

    Each resident block carries an access count. On overflow the block with the
    lowest count is evicted, with ties going to the block inserted earliest.
    Counts never decay (there is no aging) and restart at 1 when an evicted block
    is re-inserted, so a newly inserted block is itself evicted when every other
    resident block has a higher count. Eviction scans the whole cache.
    """

    def __init__(self, cap):
        self.cap = cap
        self.cache = {}  # key -> access count while resident
        self.age_counter = 0  # logical clock, bumped on every access
        self.insert_order = {}  # key -> clock value at insertion (tie-break)

    def peek(self, k):
        """Return True if block *k* is cached, without changing its count."""
        return k in self.cache

    def access(self, k, **kw):
        """Count an access to *k*, inserting it if absent; return True on a hit."""
        self.age_counter += 1
        if k in self.cache:
            self.cache[k] += 1
            return True
        self.cache[k] = 1
        self.insert_order[k] = self.age_counter
        while len(self.cache) > self.cap:
            victim = min(self.cache, key=lambda x: (self.cache[x], self.insert_order[x]))
            del self.cache[victim]
            del self.insert_order[victim]
        return False


class SessionProtectedLRU:
    """LRU whose eviction skips a bounded set of "protected" blocks.

    Blocks are protected through :meth:`protect`; the simulation below calls it
    with the blocks of every multi-turn request after that request is served.
    At most ``protected_budget`` blocks are protected at once. When the budget is
    exceeded, protection is dropped from the least recently used protected blocks
    first. Protection is not tied to whether a session is still active: it only
    ends through that trimming, :meth:`unprotect`, or the all-protected fallback
    in :meth:`access`.
    """

    def __init__(self, cap, protected_budget_ratio=0.4):
        self.cap = cap
        self.protected_budget = int(cap * protected_budget_ratio)
        self.cache = OrderedDict()  # ordered least -> most recently used
        self.protected = set()  # block IDs currently protected (always resident)

    def peek(self, k):
        """Return True if block *k* is cached, without refreshing its recency."""
        return k in self.cache

    def protect(self, block_ids):
        """Protect the resident blocks among *block_ids*, then enforce the budget."""
        for bid in block_ids:
            if bid in self.cache:
                self.protected.add(bid)
        excess = len(self.protected) - self.protected_budget
        if excess <= 0:
            return
        # Walk from least to most recently used, unprotecting the oldest first.
        for k in self.cache:
            if k in self.protected:
                self.protected.discard(k)
                excess -= 1
                if excess == 0:
                    break

    def unprotect(self, block_ids):
        """Remove protection from *block_ids* (ids that are not protected are ignored)."""
        for bid in block_ids:
            self.protected.discard(bid)

    def access(self, k, **kw):
        """Touch or insert block *k*; return True on a hit.

        On overflow, evict the least recently used unprotected block. If every
        block is protected, fall back to evicting the LRU block anyway.
        """
        if k in self.cache:
            self.cache.move_to_end(k)
            return True
        self.cache[k] = True
        while len(self.cache) > self.cap:
            for victim in self.cache:  # least -> most recently used
                if victim not in self.protected:
                    break
            else:
                # Everything is protected: plain LRU, and drop the victim's
                # protection so `protected` never names evicted blocks.
                victim = next(iter(self.cache))
                self.protected.discard(victim)
            del self.cache[victim]
        return False


# ---------------------------------------------------------------------------
# Routing
# ---------------------------------------------------------------------------

def build_routing(rows, n_instances, session_ids=None):
    """Assign sessions to instances, balanced by total KV (input) tokens.

    Sessions are visited in decreasing order of summed ``input_length`` (ties in
    order of first appearance), and each is placed on the currently least-loaded
    instance.

    Args:
        rows: trace records.
        n_instances: number of serving instances.
        session_ids: optional per-row session ids, parallel to *rows*; computed
            with :func:`assign_sessions` when omitted.

    Returns:
        dict mapping session id -> instance index.
    """
    if session_ids is None:
        session_ids, _ = assign_sessions(rows)
    session_kv = defaultdict(int)
    for r, sid in zip(rows, session_ids):
        session_kv[sid] += r["input_length"]

    inst_load = [0] * n_instances
    assignment = {}
    for sid in sorted(session_kv, key=lambda s: -session_kv[s]):
        best = min(range(n_instances), key=inst_load.__getitem__)
        assignment[sid] = best
        inst_load[best] += session_kv[sid]
    return assignment


def print_routing_balance(routing, session_ids):
    """Print how many requests each instance receives under *routing*."""
    inst_reqs = Counter(routing[sid] for sid in session_ids)
    print("Routing balance: %s" % dict(sorted(inst_reqs.items())))


# ---------------------------------------------------------------------------
# Simulation
# ---------------------------------------------------------------------------

def make_policies(n_instances=N_INSTANCES, capacity=KV_CAPACITY_BLOCKS,
                  protected_ratio=PROTECTED_BUDGET_RATIO):
    """Build one cache per instance for every policy, keyed by policy name."""
    return {
        "Infinite": [InfiniteCache() for _ in range(n_instances)],
        "LRU": [LRUCache(capacity) for _ in range(n_instances)],
        "LFU": [LFUCache(capacity) for _ in range(n_instances)],
        "SessionProtLRU": [SessionProtectedLRU(capacity, protected_ratio)
                           for _ in range(n_instances)],
    }


def serve_request(cache, hash_ids):
    """Serve one request against *cache*; return the number of prefix-hit blocks.

    The prefix hit stops at the first block not already cached (checked with
    peek(), which does not change state). Every block of the request is then
    accessed, so hits are refreshed and misses inserted (possibly evicting).
    """
    hit_blocks = 0
    for hid in hash_ids:
        if not cache.peek(hid):
            break
        cache.access(hid)
        hit_blocks += 1
    for hid in hash_ids[hit_blocks:]:
        cache.access(hid)
    return hit_blocks


def run_simulation(rows, session_ids, routing, policies, multi_turn, block_size=BLOCK_SIZE):
    """Replay *rows* against every policy and return per-policy token counters.

    Args:
        rows: trace records in replay order; ``input_length`` is read and
            ``hash_ids`` (default empty) lists the request's block hashes.
        session_ids: session of each row, parallel to *rows*.
        routing: session id -> instance index.
        policies: policy name -> list of per-instance caches.
        multi_turn: session ids counted as multi-turn ("mt"); others are "st".
        block_size: tokens credited per prefix-hit block.

    Returns:
        policy name -> dict with the counters in ``RESULT_KEYS`` (hit and total
        input tokens, overall and split into multi-/single-turn requests).
    """
    results = {name: dict.fromkeys(RESULT_KEYS, 0) for name in policies}
    for r, sid in zip(rows, session_ids):
        il = r["input_length"]
        hids = r.get("hash_ids", [])
        is_mt = sid in multi_turn
        inst = routing[sid]
        group = "mt" if is_mt else "st"
        for name, caches in policies.items():
            hit_blocks = serve_request(caches[inst], hids)
            # hit_blocks * block_size overshoots input_length when the last
            # hashed block is partial; never credit more tokens than the prompt.
            hit_tokens = min(hit_blocks * block_size, il)
            res = results[name]
            res["hits"] += hit_tokens
            res["total"] += il
            res[group + "_hits"] += hit_tokens
            res[group + "_total"] += il
        if is_mt:
            # Protect this multi-turn request's blocks on caches that support it.
            for caches in policies.values():
                protect = getattr(caches[inst], "protect", None)
                if protect is not None:
                    protect(hids)
    return results


# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------

def apc_pct(hits, total):
    """Return hits / total as a percentage, or 0 when *total* is 0."""
    return hits / total * 100 if total > 0 else 0


def print_report(results, n_requests, n_instances):
    """Print the APC table and each policy's improvement over LRU."""
    sep = "=" * 70
    print()
    print(sep)
    print(" CACHE POLICY SIMULATION (%d req, %d inst, session-sticky balanced)"
          % (n_requests, n_instances))
    print(sep)
    fmt = " %-18s %7s %7s %7s"
    print(fmt % ("Policy", "APC", "MT APC", "ST APC"))
    print(" " + "-" * 44)
    for name in (n for n in POLICY_ORDER if n in results):
        r = results[name]
        print(fmt % (
            name,
            "%.1f%%" % apc_pct(r["hits"], r["total"]),
            "%.1f%%" % apc_pct(r["mt_hits"], r["mt_total"]),
            "%.1f%%" % apc_pct(r["st_hits"], r["st_total"]),
        ))
    print()
    print(" APC = overall, MT APC = multi-turn sessions only, ST APC = single-turn only")

    # Deltas: how much of the LRU -> Infinite gap each policy closes.
    if "LRU" not in results or "Infinite" not in results:
        return
    lru_apc = apc_pct(results["LRU"]["hits"], results["LRU"]["total"])
    inf_apc = apc_pct(results["Infinite"]["hits"], results["Infinite"]["total"])
    total_gap = inf_apc - lru_apc
    for name in ("LFU", "SessionProtLRU"):
        if name not in results:
            continue
        recovered = apc_pct(results[name]["hits"], results[name]["total"]) - lru_apc
        share = recovered / total_gap * 100 if total_gap > 0 else 0
        print(" %s: %+.1fpp over LRU (%.0f%% of gap recovered)" % (name, recovered, share))


def main(trace_path=TRACE_PATH):
    """Run every policy over the trace at *trace_path*, print the report, return results."""
    rows = load_rows(trace_path)
    session_ids, _ = assign_sessions(rows)
    multi_turn = find_multi_turn_sessions(session_ids)
    routing = build_routing(rows, N_INSTANCES, session_ids)
    print_routing_balance(routing, session_ids)
    policies = make_policies()
    results = run_simulation(rows, session_ids, routing, policies, multi_turn)
    print_report(results, len(rows), N_INSTANCES)
    return results


if __name__ == "__main__":
    main()