With balanced session-sticky routing: LRU APC = 49.2% (only 1.8pp below the infinite-cache 51.0%); LFU APC = 43.5% (worse than LRU!); SessionProtLRU APC = 49.0% (no improvement). The previous 10.1pp gap came from routing imbalance (all traffic went to inst_0), not from the cache eviction policy. Balanced routing recovers 5.9pp of the gap. Multi-turn sessions reach 80.1% APC with simple LRU + session-sticky routing because the inter-turn gap is only 2 requests, so LRU naturally keeps their blocks warm. Conclusion: fix routing balance, not cache policy. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
257 lines
8.3 KiB
Python
257 lines
8.3 KiB
Python
"""Simulate different KV cache eviction policies on the agentic trace.
|
|
|
|
Compares:
|
|
1. LRU (current vLLM default)
|
|
2. LFU (Least Frequently Used)
|
|
3. Session-protected LRU (multi-turn session blocks get eviction immunity)
|
|
4. Infinite cache (upper bound)
|
|
|
|
All use the same cache-aware session-sticky routing with balanced placement.
|
|
"""
|
|
|
|
import json, statistics
|
|
from collections import OrderedDict, defaultdict
|
|
|
|
# Load the sampled request trace (one JSON object per line) and replay it in
# timestamp order.  The context manager closes the file handle as soon as the
# trace has been parsed instead of leaving it open for the rest of the run.
with open("traces/sampled_1000req_seed42.jsonl", encoding="utf-8") as trace_file:
    # Blank lines (e.g. a trailing newline at end of file) are skipped so they
    # do not abort the run with a JSON decode error.
    rows = [json.loads(line) for line in trace_file if line.strip()]
rows.sort(key=lambda r: float(r["timestamp"]))

# Simulation parameters.
BLOCK_SIZE = 512  # tokens credited per cached block when counting hit tokens
KV_CAPACITY_BLOCKS = 550  # capacity, in blocks, of each bounded per-instance cache
N_INSTANCES = 8  # number of simulated serving instances
|
|
|
|
# Derive a session id for every request and group request indices by session.
#   chat_to_session: chat_id -> session id (str)
#   session_turns:   session id (str) -> indices into `rows`, in replay order
chat_to_session = {}
session_turns = defaultdict(list)
for idx, r in enumerate(rows):
    cid = r["chat_id"]
    pid = r["parent_chat_id"]
    if pid < 0:
        # A negative parent id means the request starts its own session.
        inherited = str(cid)
    else:
        # Follow-up turn: reuse the parent's session, or fall back to the
        # parent's own id when the parent has not been recorded.
        inherited = chat_to_session.get(pid, str(pid))
    # An explicit "session_id" field on the row wins over the derived value.
    sid = r.get("session_id", inherited)
    session_key = str(sid)
    chat_to_session[cid] = session_key
    session_turns[session_key].append(idx)

# Sessions that contain at least two requests.
multi_turn_sessions = {
    key for key, turn_indices in session_turns.items() if len(turn_indices) >= 2
}
|
|
|
|
|
|
class InfiniteCache:
    """Cache that never evicts anything (the upper bound for hit counts)."""

    def __init__(self):
        # Every key that has ever been accessed.
        self.cache = set()

    def peek(self, k):
        """Return True if ``k`` is cached; the cache is not modified."""
        return k in self.cache

    def access(self, k, **kw):
        """Record an access to ``k``.

        Returns True when ``k`` was already cached (hit); otherwise stores it
        and returns False (miss).  Extra keyword arguments are ignored.
        """
        if k in self.cache:
            return True
        self.cache.add(k)
        return False
|
|
|
|
|
|
class LRUCache:
    """Fixed-capacity cache that evicts the least recently used key."""

    def __init__(self, cap):
        self.cap = cap  # maximum number of keys kept
        # Keys ordered from least recently used (front) to most recent (back).
        self.cache = OrderedDict()

    def peek(self, k):
        """Return True if ``k`` is cached, without refreshing its recency."""
        return k in self.cache

    def access(self, k, **kw):
        """Touch ``k`` and report whether it was a hit.

        A hit moves ``k`` to the most-recent position and returns True.  A miss
        inserts ``k``, evicts from the least-recent end until the cache fits in
        ``cap`` again, and returns False.  Extra keyword arguments are ignored.
        """
        entries = self.cache
        if k not in entries:
            entries[k] = 0
            while len(entries) > self.cap:
                entries.popitem(last=False)
            return False
        entries.move_to_end(k)
        return True
|
|
|
|
|
|
class LFUCache:
    """Least-frequently-used cache; ties go to the oldest inserted key.

    Access counts only ever grow while a key stays cached: no decay is applied
    to them.  ``age_counter`` is a logical clock that is used solely to
    timestamp insertions for the tiebreak.
    """

    def __init__(self, cap):
        self.cap = cap  # maximum number of keys kept
        self.cache = {}  # key -> access count
        self.age_counter = 0  # bumped on every access() call
        self.insert_order = {}  # key -> clock value when the key was inserted

    def peek(self, k):
        """Return True if ``k`` is cached, without counting an access."""
        return k in self.cache

    def access(self, k, **kw):
        """Touch ``k`` and report whether it was a hit.

        A hit increments the key's count and returns True.  A miss inserts the
        key with count 1, then evicts the key with the smallest
        (count, insertion time) until the cache fits in ``cap``, and returns
        False.  Extra keyword arguments are ignored.
        """
        self.age_counter += 1
        counts = self.cache
        if k not in counts:
            counts[k] = 1
            self.insert_order[k] = self.age_counter
            while len(counts) > self.cap:
                # Linear scan for the lowest count, oldest insertion first.
                victim, _ = min(
                    counts.items(),
                    key=lambda item: (item[1], self.insert_order.get(item[0], 0)),
                )
                del counts[victim]
                self.insert_order.pop(victim, None)
            return False
        counts[k] += 1
        return True
|
|
|
|
|
|
class SessionProtectedLRU:
    """LRU cache in which a bounded set of blocks is shielded from eviction.

    Blocks passed to :meth:`protect` are skipped when a victim is chosen, so
    eviction removes the least recently used *unprotected* block.  At most
    ``int(cap * protected_budget_ratio)`` blocks stay protected at a time.

    Args:
        cap: Maximum number of blocks kept in the cache.
        protected_budget_ratio: Fraction of ``cap`` that may be protected.
    """

    def __init__(self, cap, protected_budget_ratio=0.4):
        self.cap = cap
        self.protected_budget = int(cap * protected_budget_ratio)
        # Blocks ordered from least recently used (front) to most recent (back).
        self.cache = OrderedDict()
        self.protected = set()  # block IDs currently protected

    def peek(self, k):
        """Return True if ``k`` is cached, without refreshing its recency."""
        return k in self.cache

    def protect(self, block_ids):
        """Mark the given blocks as protected (only those currently cached).

        If the protected set then exceeds the budget, protection is dropped
        from the least recently used protected blocks until it fits.
        """
        for bid in block_ids:
            if bid in self.cache:
                self.protected.add(bid)
        if len(self.protected) <= self.protected_budget:
            return
        # Single pass in LRU order.  This removes the same blocks as rescanning
        # from the front for every removal, but in linear time, and it always
        # terminates -- even when the budget cannot be reached (e.g. a
        # negative budget), where a rescanning loop would spin forever.
        for k in self.cache:
            if k in self.protected:
                self.protected.discard(k)
                if len(self.protected) <= self.protected_budget:
                    break

    def unprotect(self, block_ids):
        """Remove protection from the given blocks; unknown ids are ignored."""
        for bid in block_ids:
            self.protected.discard(bid)

    def access(self, k, **kw):
        """Touch ``k`` and report whether it was a hit.

        A hit moves ``k`` to the most-recent position and returns True.  A miss
        inserts ``k``, evicts until the cache fits in ``cap``, and returns
        False.  Extra keyword arguments are ignored.
        """
        if k in self.cache:
            self.cache.move_to_end(k)
            return True
        self.cache[k] = True
        while len(self.cache) > self.cap:
            # Evict the oldest block that is NOT protected.  Breaking right
            # after the deletion means the iterator is never advanced over a
            # mutated dict, so no copy of the key list is needed.
            for candidate in self.cache:
                if candidate not in self.protected:
                    del self.cache[candidate]
                    break
            else:
                # Everything is protected: evict the oldest anyway, and forget
                # its protection so the protected set never references a block
                # that is no longer cached.
                evicted_key, _ = self.cache.popitem(last=False)
                self.protected.discard(evicted_key)
        return False
|
|
|
|
|
|
# Balanced session-sticky routing (distribute by KV size)
|
|
def build_routing(rows, n_instances):
    """Assign each session to an instance, balancing total input tokens.

    The session of a row is its explicit ``"session_id"`` when present.
    Otherwise a row with a negative ``"parent_chat_id"`` starts a session named
    after its own ``"chat_id"``, and any other row joins its parent's session
    (falling back to the parent's id when the parent has not appeared earlier
    in ``rows``).  The chat-to-session map is built locally, so the function
    depends only on its arguments.

    Sessions are then placed greedily: largest summed ``"input_length"`` first,
    each onto the instance with the smallest load so far (lowest index wins a
    tie).

    Args:
        rows: Request dicts with ``"chat_id"``, ``"parent_chat_id"`` and
            ``"input_length"``; parents are expected to precede their children.
        n_instances: Number of instances to spread sessions over.

    Returns:
        Dict mapping session id (str) to instance index.

    Raises:
        ValueError: If there are sessions to place but ``n_instances`` is 0.
    """
    session_kv = defaultdict(int)
    chat_session = {}
    for r in rows:
        cid = r["chat_id"]
        pid = r["parent_chat_id"]
        derived = str(cid) if pid < 0 else chat_session.get(pid, str(pid))
        sid = str(r.get("session_id", derived))
        chat_session[cid] = sid
        session_kv[sid] += r["input_length"]

    # Assign: greedy bin packing by KV size, heaviest session first.  The sort
    # is stable, so equally heavy sessions keep first-seen order.
    inst_load = [0] * n_instances
    assignment = {}
    for sid in sorted(session_kv, key=lambda s: -session_kv[s]):
        best = min(range(n_instances), key=inst_load.__getitem__)
        assignment[sid] = best
        inst_load[best] += session_kv[sid]

    return assignment
|
|
|
|
|
|
routing = build_routing(rows, N_INSTANCES)

# Sanity check: count how many requests land on each instance and print the
# counts ordered by instance index.
inst_reqs = defaultdict(int)
for r in rows:
    cid = r["chat_id"]
    pid = r["parent_chat_id"]
    if pid < 0:
        inherited = str(cid)
    else:
        inherited = chat_to_session.get(pid, str(pid))
    sid = r.get("session_id", inherited)
    inst_reqs[routing[str(sid)]] += 1
per_instance = {inst: inst_reqs[inst] for inst in sorted(inst_reqs)}
print("Routing balance: %s" % (per_instance,))
|
|
|
|
|
|
# Run all policies
|
|
# How to build one cache for each policy under comparison.
cache_factories = {
    "Infinite": InfiniteCache,
    "LRU": lambda: LRUCache(KV_CAPACITY_BLOCKS),
    "LFU": lambda: LFUCache(KV_CAPACITY_BLOCKS),
    "SessionProtLRU": lambda: SessionProtectedLRU(KV_CAPACITY_BLOCKS, 0.4),
}

# policy name -> one independent cache per instance
policies = {
    name: [make_cache() for _ in range(N_INSTANCES)]
    for name, make_cache in cache_factories.items()
}

# Per-policy token counters: overall, multi-turn (mt_*) and single-turn (st_*).
counter_keys = ("hits", "total", "mt_hits", "mt_total", "st_hits", "st_total")
results = {name: dict.fromkeys(counter_keys, 0) for name in policies}

# Track active sessions for protected LRU
active_mt_blocks = defaultdict(set)  # sid -> set of block_ids
|
|
|
|
# Replay the trace: every request goes to its session's instance and is run
# against that instance's cache under each policy.
# Looked up once; stays None if the policy is not configured.
session_protected_caches = policies.get("SessionProtLRU")

for r in rows:
    il = r["input_length"]
    hids = r.get("hash_ids", [])
    cid = r["chat_id"]
    pid = r["parent_chat_id"]
    sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
    session_key = str(sid)
    is_mt = session_key in multi_turn_sessions
    inst = routing[session_key]

    for name, caches in policies.items():
        cache = caches[inst]
        stats = results[name]

        # Prefix hit: count leading blocks already cached, stopping at the
        # first miss.  Hits are touched so their recency/frequency updates.
        hit_blocks = 0
        for hid in hids:
            if not cache.peek(hid):
                break
            cache.access(hid)
            hit_blocks += 1
        # Insert remaining
        for hid in hids[hit_blocks:]:
            cache.access(hid)

        # A block credits BLOCK_SIZE tokens, but the request's final block may
        # be only partly filled.  Clamp so a request never contributes more
        # hit tokens than it has input tokens (hits <= total must hold).
        hit_tokens = min(hit_blocks * BLOCK_SIZE, il)
        stats["hits"] += hit_tokens
        stats["total"] += il
        bucket = "mt" if is_mt else "st"
        stats[bucket + "_hits"] += hit_tokens
        stats[bucket + "_total"] += il

    # For SessionProtLRU: protect multi-turn session blocks after each turn
    if is_mt:
        active_mt_blocks[session_key] = set(hids)
        if session_protected_caches is not None:
            session_protected_caches[inst].protect(hids)
|
|
|
|
# Print results
def _pct(hits, total):
    """Return hits/total as a percentage, or 0 when total is not positive."""
    return hits / total * 100 if total > 0 else 0


sep = "=" * 70
print()
print(sep)
print(" CACHE POLICY SIMULATION (1000 req, 8 inst, session-sticky balanced)")
print(sep)

fmt = " %-18s %7s %7s %7s"
print(fmt % ("Policy", "APC", "MT APC", "ST APC"))
print(" " + "-" * 44)

for name in ["Infinite", "LRU", "LFU", "SessionProtLRU"]:
    stats = results[name]
    apc = _pct(stats["hits"], stats["total"])
    mt_apc = _pct(stats["mt_hits"], stats["mt_total"])
    st_apc = _pct(stats["st_hits"], stats["st_total"])
    print(fmt % (name, "%.1f%%" % apc, "%.1f%%" % mt_apc, "%.1f%%" % st_apc))

print()
print(" APC = overall, MT APC = multi-turn sessions only, ST APC = single-turn only")

# Deltas relative to LRU.  The baselines go through the same zero-total guard
# as the table above, so an empty trace prints zeros instead of raising
# ZeroDivisionError.
lru_apc = _pct(results["LRU"]["hits"], results["LRU"]["total"])
inf_apc = _pct(results["Infinite"]["hits"], results["Infinite"]["total"])
total_gap = inf_apc - lru_apc
for name in ["LFU", "SessionProtLRU"]:
    apc = _pct(results[name]["hits"], results[name]["total"])
    recovered = apc - lru_apc
    # %+.1f emits the sign itself; a hardcoded "+" rendered a policy that is
    # worse than LRU as "+-5.7pp".
    print(" %s: %+.1fpp over LRU (%.0f%% of gap recovered)" % (
        name, recovered, recovered / total_gap * 100 if total_gap > 0 else 0))
|