KV cache lifecycle design + eviction loss analysis

Root cause of 10.1pp APC gap: multi-turn sessions' KV evicted between
turns by cold-start prefills (66% of loss). Inter-turn gap is only 2
requests p50, but LRU cache (550 blocks) can't protect 93 blocks/session
across 14-21 concurrent sessions.

Three approaches designed:
  A. Session-sticky routing with KV reservation (proxy-only, no vLLM change)
  B. Two-tier KV cache: GPU + DRAM offload via Mooncake
  C. Prefill-aware eviction (LFU/ARC instead of LRU, vLLM patch)

Next: simulate LRU vs LFU vs "infinite-for-MT" to quantify upper bounds,
then implement Approach A (lowest effort, immediate benchmark).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 01:27:22 +08:00
parent d11d9f5cb9
commit 10636b1ab1
2 changed files with 347 additions and 0 deletions

184
scripts/analyze_eviction.py Normal file
View File

@@ -0,0 +1,184 @@
"""Analyze the 10pp APC gap: what gets evicted and why."""
import json
from collections import OrderedDict
rows = [json.loads(l) for l in open("traces/sampled_1000req_seed42.jsonl")]
rows.sort(key=lambda r: float(r["timestamp"]))
BLOCK_SIZE = 512
KV_CAPACITY_BLOCKS = 550
N_INSTANCES = 8
class LRUCache:
def __init__(self, cap):
self.cap = cap
self.cache = OrderedDict()
self.evictions = 0
def peek(self, k):
return k in self.cache
def access(self, k):
if k in self.cache:
self.cache.move_to_end(k)
return True
self.cache[k] = True
while len(self.cache) > self.cap:
self.cache.popitem(last=False)
self.evictions += 1
return False
inf_seen = [set() for _ in range(N_INSTANCES)]
lru_caches = [LRUCache(KV_CAPACITY_BLOCKS) for _ in range(N_INSTANCES)]
session_aff = {}
chat_to_session = {}
loss_intra = 0 # multi-turn: prior turn evicted
loss_cross = 0 # single-turn: shared prefix evicted
total_loss = 0
total_inf_hits = 0
total_lru_hits = 0
total_tokens = 0
per_req = []
for idx, r in enumerate(rows):
il = r["input_length"]
hids = r.get("hash_ids", [])
cid = r["chat_id"]
pid = r["parent_chat_id"]
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
chat_to_session[cid] = str(sid)
is_mt = pid >= 0
if sid in session_aff:
inst = session_aff[sid]
else:
best_inst, best_h = 0, 0
for j in range(N_INSTANCES):
h = sum(1 for hid in hids[:10] if hid in lru_caches[j].cache)
if h > best_h:
best_h = h
best_inst = j
inst = best_inst
session_aff[sid] = inst
# Infinite
inf_h = 0
for hid in hids:
if hid in inf_seen[inst]:
inf_h += 1
else:
break
for hid in hids:
inf_seen[inst].add(hid)
# LRU
lru_h = 0
for hid in hids:
if lru_caches[inst].peek(hid):
lru_caches[inst].access(hid)
lru_h += 1
else:
break
for hid in hids:
lru_caches[inst].access(hid)
inf_tok = inf_h * BLOCK_SIZE
lru_tok = lru_h * BLOCK_SIZE
loss = inf_tok - lru_tok
total_inf_hits += inf_tok
total_lru_hits += lru_tok
total_tokens += il
if loss > 0:
total_loss += loss
if is_mt:
loss_intra += loss
else:
loss_cross += loss
per_req.append({
"idx": idx, "input": il, "inf_hit": inf_h, "lru_hit": lru_h,
"loss_blocks": inf_h - lru_h, "loss_tok": loss,
"mt": is_mt, "sid": sid, "turn": r.get("turn", 1),
"n_blocks": len(hids),
})
sep = "=" * 70
print(sep)
print(" EVICTION LOSS ANALYSIS")
print(sep)
print()
print(" Infinite APC: %.1f%%" % (total_inf_hits / total_tokens * 100))
print(" LRU APC: %.1f%%" % (total_lru_hits / total_tokens * 100))
print(" Gap: %.1f pp (%s tokens lost)" % (
(total_inf_hits - total_lru_hits) / total_tokens * 100,
"{:,}".format(total_loss)))
print()
print(" Loss by type:")
print(" Multi-turn (prior turn KV evicted): %s tok (%.0f%%)" % (
"{:,}".format(loss_intra), loss_intra * 100 / max(total_loss, 1)))
print(" Single-turn (shared prefix evicted): %s tok (%.0f%%)" % (
"{:,}".format(loss_cross), loss_cross * 100 / max(total_loss, 1)))
print()
print(" Requests with loss: %d / %d" % (len(per_req), len(rows)))
print()
print(" Top-15 by loss:")
print(" %4s %7s %5s %5s %5s %7s %3s %8s %4s" % (
"#", "input", "inf_h", "lru_h", "loss", "tok", "mt", "session", "turn"))
for r in sorted(per_req, key=lambda x: -x["loss_tok"])[:15]:
print(" %4d %7d %5d %5d %5d %7d %3s %8s %4d" % (
r["idx"], r["input"], r["inf_hit"], r["lru_hit"],
r["loss_blocks"], r["loss_tok"],
"Y" if r["mt"] else "N", r["sid"][:8], r["turn"]))
# Instance-level analysis
print()
print(" Per-instance:")
for i in range(N_INSTANCES):
n = len(inf_seen[i])
e = lru_caches[i].evictions
overflow = n / KV_CAPACITY_BLOCKS
print(" inst_%d: %5d unique blocks, overflow=%.1fx, evictions=%d" % (
i, n, overflow, e))
# Time gap analysis: for lost requests, how long between
# the block being deposited and being needed again?
print()
print(" Temporal analysis of evicted blocks:")
# Track when each block was last inserted, per instance
block_deposit_time = [{} for _ in range(N_INSTANCES)]
gaps = []
# Re-scan
session_aff2 = {}
chat_to_session2 = {}
for idx, r in enumerate(rows):
hids = r.get("hash_ids", [])
cid = r["chat_id"]
pid = r["parent_chat_id"]
sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session2.get(pid, str(pid)))
chat_to_session2[cid] = str(sid)
if sid in session_aff2:
inst = session_aff2[sid]
else:
inst = 0 # simplified
session_aff2[sid] = inst
for hid in hids:
if hid in block_deposit_time[inst]:
gap = idx - block_deposit_time[inst][hid]
gaps.append(gap)
block_deposit_time[inst][hid] = idx
if gaps:
gaps.sort()
p = lambda q: gaps[min(int(q * len(gaps)), len(gaps) - 1)]
print(" Block reuse distance (requests between deposit and reaccess):")
print(" p10=%d p50=%d p90=%d max=%d" % (p(.1), p(.5), p(.9), max(gaps)))
short = sum(1 for g in gaps if g <= 10)
medium = sum(1 for g in gaps if 10 < g <= 100)
long_ = sum(1 for g in gaps if g > 100)
print(" <=10 req: %d (%.0f%%) 10-100: %d (%.0f%%) >100: %d (%.0f%%)" % (
short, short * 100 / len(gaps),
medium, medium * 100 / len(gaps),
long_, long_ * 100 / len(gaps)))