Cache policy simulation: routing quality dominates, not eviction policy
With balanced session-sticky routing: LRU APC = 49.2% (only 1.8pp below infinite 51.0%) LFU APC = 43.5% (worse than LRU!) SessionProtLRU = 49.0% (no improvement) The previous 10.1pp gap was from routing imbalance (all traffic to inst_0), not from cache eviction policy. Balanced routing recovers 5.9pp of the gap. Multi-turn sessions get 80.1% APC with simple LRU + session-sticky routing because inter-turn gap is only 2 requests (LRU naturally keeps it warm). Conclusion: fix routing balance, not cache policy. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
256
scripts/simulate_cache_policies.py
Normal file
256
scripts/simulate_cache_policies.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Simulate different KV cache eviction policies on the agentic trace.
|
||||
|
||||
Compares:
|
||||
1. LRU (current vLLM default)
|
||||
2. LFU (Least Frequently Used)
|
||||
3. Session-protected LRU (multi-turn session blocks get eviction immunity)
|
||||
4. Infinite cache (upper bound)
|
||||
|
||||
All use the same cache-aware session-sticky routing with balanced placement.
|
||||
"""
|
||||
|
||||
import json, statistics
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
rows = [json.loads(l) for l in open("traces/sampled_1000req_seed42.jsonl")]
|
||||
rows.sort(key=lambda r: float(r["timestamp"]))
|
||||
|
||||
BLOCK_SIZE = 512
|
||||
KV_CAPACITY_BLOCKS = 550
|
||||
N_INSTANCES = 8
|
||||
|
||||
# Build session map
|
||||
chat_to_session = {}
|
||||
session_turns = defaultdict(list)
|
||||
for idx, r in enumerate(rows):
|
||||
cid = r["chat_id"]
|
||||
pid = r["parent_chat_id"]
|
||||
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
|
||||
chat_to_session[cid] = str(sid)
|
||||
session_turns[str(sid)].append(idx)
|
||||
|
||||
multi_turn_sessions = {s for s, turns in session_turns.items() if len(turns) > 1}
|
||||
|
||||
|
||||
class InfiniteCache:
|
||||
"""No eviction, unlimited capacity."""
|
||||
def __init__(self):
|
||||
self.cache = set()
|
||||
def peek(self, k):
|
||||
return k in self.cache
|
||||
def access(self, k, **kw):
|
||||
hit = k in self.cache
|
||||
self.cache.add(k)
|
||||
return hit
|
||||
|
||||
|
||||
class LRUCache:
|
||||
"""Standard LRU."""
|
||||
def __init__(self, cap):
|
||||
self.cap = cap
|
||||
self.cache = OrderedDict()
|
||||
def peek(self, k):
|
||||
return k in self.cache
|
||||
def access(self, k, **kw):
|
||||
if k in self.cache:
|
||||
self.cache.move_to_end(k)
|
||||
return True
|
||||
self.cache[k] = 0
|
||||
while len(self.cache) > self.cap:
|
||||
self.cache.popitem(last=False)
|
||||
return False
|
||||
|
||||
|
||||
class LFUCache:
|
||||
"""Least Frequently Used with aging."""
|
||||
def __init__(self, cap):
|
||||
self.cap = cap
|
||||
self.cache = {} # key -> freq
|
||||
self.age_counter = 0
|
||||
self.insert_order = {} # key -> insertion time (for tiebreak)
|
||||
def peek(self, k):
|
||||
return k in self.cache
|
||||
def access(self, k, **kw):
|
||||
self.age_counter += 1
|
||||
if k in self.cache:
|
||||
self.cache[k] += 1
|
||||
return True
|
||||
self.cache[k] = 1
|
||||
self.insert_order[k] = self.age_counter
|
||||
while len(self.cache) > self.cap:
|
||||
# Evict: lowest freq, tiebreak by oldest insert
|
||||
victim = min(self.cache.keys(),
|
||||
key=lambda x: (self.cache[x], self.insert_order.get(x, 0)))
|
||||
del self.cache[victim]
|
||||
self.insert_order.pop(victim, None)
|
||||
return False
|
||||
|
||||
|
||||
class SessionProtectedLRU:
|
||||
"""LRU but blocks tagged with active multi-turn session are protected."""
|
||||
def __init__(self, cap, protected_budget_ratio=0.4):
|
||||
self.cap = cap
|
||||
self.protected_budget = int(cap * protected_budget_ratio)
|
||||
self.cache = OrderedDict()
|
||||
self.protected = set() # block IDs currently protected
|
||||
|
||||
def peek(self, k):
|
||||
return k in self.cache
|
||||
|
||||
def protect(self, block_ids):
|
||||
"""Mark blocks as protected (from session routing)."""
|
||||
for bid in block_ids:
|
||||
if bid in self.cache:
|
||||
self.protected.add(bid)
|
||||
# Trim protected set if over budget
|
||||
while len(self.protected) > self.protected_budget:
|
||||
# Remove oldest protected
|
||||
for k in self.cache:
|
||||
if k in self.protected:
|
||||
self.protected.discard(k)
|
||||
break
|
||||
|
||||
def unprotect(self, block_ids):
|
||||
for bid in block_ids:
|
||||
self.protected.discard(bid)
|
||||
|
||||
def access(self, k, **kw):
|
||||
if k in self.cache:
|
||||
self.cache.move_to_end(k)
|
||||
return True
|
||||
self.cache[k] = True
|
||||
while len(self.cache) > self.cap:
|
||||
# Evict: oldest that is NOT protected
|
||||
evicted = False
|
||||
for candidate in list(self.cache.keys()):
|
||||
if candidate not in self.protected:
|
||||
del self.cache[candidate]
|
||||
evicted = True
|
||||
break
|
||||
if not evicted:
|
||||
# All are protected; evict oldest anyway
|
||||
self.cache.popitem(last=False)
|
||||
return False
|
||||
|
||||
|
||||
# Balanced session-sticky routing (distribute by KV size)
|
||||
def build_routing(rows, n_instances):
|
||||
"""Assign sessions to instances, balanced by total KV tokens."""
|
||||
session_kv = defaultdict(int)
|
||||
session_first_idx = {}
|
||||
for idx, r in enumerate(rows):
|
||||
cid = r["chat_id"]
|
||||
pid = r["parent_chat_id"]
|
||||
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
|
||||
session_kv[str(sid)] += r["input_length"]
|
||||
if str(sid) not in session_first_idx:
|
||||
session_first_idx[str(sid)] = idx
|
||||
|
||||
# Assign: greedy bin packing by KV size
|
||||
inst_load = [0] * n_instances
|
||||
assignment = {}
|
||||
for sid in sorted(session_kv.keys(), key=lambda s: -session_kv[s]):
|
||||
best = min(range(n_instances), key=lambda i: inst_load[i])
|
||||
assignment[sid] = best
|
||||
inst_load[best] += session_kv[sid]
|
||||
|
||||
return assignment
|
||||
|
||||
|
||||
routing = build_routing(rows, N_INSTANCES)
|
||||
|
||||
# Check routing balance
|
||||
inst_reqs = defaultdict(int)
|
||||
for r in rows:
|
||||
cid = r["chat_id"]
|
||||
pid = r["parent_chat_id"]
|
||||
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
|
||||
inst_reqs[routing[str(sid)]] += 1
|
||||
print("Routing balance: %s" % dict(sorted(inst_reqs.items())))
|
||||
|
||||
|
||||
# Run all policies
|
||||
policies = {
|
||||
"Infinite": [InfiniteCache() for _ in range(N_INSTANCES)],
|
||||
"LRU": [LRUCache(KV_CAPACITY_BLOCKS) for _ in range(N_INSTANCES)],
|
||||
"LFU": [LFUCache(KV_CAPACITY_BLOCKS) for _ in range(N_INSTANCES)],
|
||||
"SessionProtLRU": [SessionProtectedLRU(KV_CAPACITY_BLOCKS, 0.4) for _ in range(N_INSTANCES)],
|
||||
}
|
||||
|
||||
results = {name: {"hits": 0, "total": 0, "mt_hits": 0, "mt_total": 0,
|
||||
"st_hits": 0, "st_total": 0} for name in policies}
|
||||
|
||||
# Track active sessions for protected LRU
|
||||
active_mt_blocks = defaultdict(set) # sid -> set of block_ids
|
||||
|
||||
for idx, r in enumerate(rows):
|
||||
il = r["input_length"]
|
||||
hids = r.get("hash_ids", [])
|
||||
cid = r["chat_id"]
|
||||
pid = r["parent_chat_id"]
|
||||
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
|
||||
is_mt = str(sid) in multi_turn_sessions
|
||||
inst = routing[str(sid)]
|
||||
|
||||
for name, caches in policies.items():
|
||||
cache = caches[inst]
|
||||
|
||||
# Prefix hit
|
||||
hit_blocks = 0
|
||||
for hid in hids:
|
||||
if cache.peek(hid):
|
||||
cache.access(hid)
|
||||
hit_blocks += 1
|
||||
else:
|
||||
break
|
||||
# Insert remaining
|
||||
for hid in hids[hit_blocks:]:
|
||||
cache.access(hid)
|
||||
|
||||
hit_tokens = hit_blocks * BLOCK_SIZE
|
||||
results[name]["hits"] += hit_tokens
|
||||
results[name]["total"] += il
|
||||
if is_mt:
|
||||
results[name]["mt_hits"] += hit_tokens
|
||||
results[name]["mt_total"] += il
|
||||
else:
|
||||
results[name]["st_hits"] += hit_tokens
|
||||
results[name]["st_total"] += il
|
||||
|
||||
# For SessionProtLRU: protect multi-turn session blocks after each turn
|
||||
if is_mt:
|
||||
active_mt_blocks[str(sid)] = set(hids)
|
||||
for name, caches in policies.items():
|
||||
if name == "SessionProtLRU":
|
||||
caches[inst].protect(hids)
|
||||
|
||||
# Print results
|
||||
sep = "=" * 70
|
||||
print()
|
||||
print(sep)
|
||||
print(" CACHE POLICY SIMULATION (1000 req, 8 inst, session-sticky balanced)")
|
||||
print(sep)
|
||||
|
||||
fmt = " %-18s %7s %7s %7s"
|
||||
print(fmt % ("Policy", "APC", "MT APC", "ST APC"))
|
||||
print(" " + "-" * 44)
|
||||
|
||||
for name in ["Infinite", "LRU", "LFU", "SessionProtLRU"]:
|
||||
r = results[name]
|
||||
apc = r["hits"] / r["total"] * 100 if r["total"] > 0 else 0
|
||||
mt_apc = r["mt_hits"] / r["mt_total"] * 100 if r["mt_total"] > 0 else 0
|
||||
st_apc = r["st_hits"] / r["st_total"] * 100 if r["st_total"] > 0 else 0
|
||||
print(fmt % (name, "%.1f%%" % apc, "%.1f%%" % mt_apc, "%.1f%%" % st_apc))
|
||||
|
||||
print()
|
||||
print(" APC = overall, MT APC = multi-turn sessions only, ST APC = single-turn only")
|
||||
|
||||
# Deltas
|
||||
lru_apc = results["LRU"]["hits"] / results["LRU"]["total"] * 100
|
||||
inf_apc = results["Infinite"]["hits"] / results["Infinite"]["total"] * 100
|
||||
for name in ["LFU", "SessionProtLRU"]:
|
||||
apc = results[name]["hits"] / results[name]["total"] * 100
|
||||
recovered = apc - lru_apc
|
||||
total_gap = inf_apc - lru_apc
|
||||
print(" %s: +%.1fpp over LRU (%.0f%% of gap recovered)" % (
|
||||
name, recovered, recovered / total_gap * 100 if total_gap > 0 else 0))
|
||||
Reference in New Issue
Block a user