Root cause of the 10.1 pp APC gap: multi-turn sessions' KV blocks are evicted between turns by cold-start prefills (66% of the loss). The inter-turn gap is only 2 requests (p50), but an LRU cache of 550 blocks cannot protect 93 blocks/session across 14-21 concurrent sessions. Three approaches designed: (A) session-sticky routing with KV reservation (proxy-only, no vLLM change); (B) a two-tier KV cache, GPU + DRAM offload via Mooncake; (C) prefill-aware eviction (LFU/ARC instead of LRU; requires a vLLM patch). Next: simulate LRU vs. LFU vs. "infinite-for-MT" to quantify upper bounds, then implement Approach A (lowest effort, immediately benchmarkable). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
185 lines
5.6 KiB
Python
185 lines
5.6 KiB
Python
"""Analyze the 10pp APC gap: what gets evicted and why."""
|
|
import json
|
|
from collections import OrderedDict
|
|
|
|
# Request trace in JSON Lines form: one JSON object per line. Records are read
# below via keys timestamp, input_length, hash_ids, chat_id, parent_chat_id
# and, optionally, session_id / turn.
TRACE_PATH = "traces/sampled_1000req_seed42.jsonl"

# `with` closes the handle deterministically; blank lines (e.g. a trailing
# newline) are skipped because json.loads rejects empty input.
with open(TRACE_PATH, encoding="utf-8") as trace_file:
    rows = [json.loads(line) for line in trace_file if line.strip()]

# Replay requests in arrival order.
rows.sort(key=lambda rec: float(rec["timestamp"]))

BLOCK_SIZE = 512  # tokens represented by one hash id (block -> token conversion)
KV_CAPACITY_BLOCKS = 550  # LRU capacity of each simulated instance, in blocks
N_INSTANCES = 8  # number of simulated serving instances

class LRUCache:
    """Fixed-capacity set of keys with least-recently-used eviction.

    ``cache`` is an OrderedDict ordered from least- to most-recently used;
    ``evictions`` counts how many keys have been dropped to respect ``cap``.
    """

    def __init__(self, cap):
        self.cap = cap
        self.cache = OrderedDict()
        self.evictions = 0

    def peek(self, k):
        """Report whether ``k`` is cached without changing its recency."""
        return k in self.cache

    def access(self, k):
        """Touch ``k``; return True on a hit, False on a miss.

        A hit moves ``k`` to the most-recently-used end. A miss inserts ``k``
        and then evicts from the least-recently-used end until the cache is
        back within ``cap``.
        """
        store = self.cache
        if k not in store:
            store[k] = True
            self._shrink()
            return False
        store.move_to_end(k)
        return True

    def _shrink(self):
        """Evict oldest entries until the cache holds at most ``cap`` keys."""
        store = self.cache
        while len(store) > self.cap:
            store.popitem(last=False)
            self.evictions += 1

# Per-instance state for the two caches compared on the same request stream:
# an unbounded set ("infinite" cache, upper bound on prefix reuse) and a
# capacity-limited LRU cache.
inf_seen = [set() for _ in range(N_INSTANCES)]
lru_caches = [LRUCache(KV_CAPACITY_BLOCKS) for _ in range(N_INSTANCES)]
session_aff = {}  # session id -> instance index (sessions are sticky)
chat_to_session = {}  # chat_id -> session id, used to resolve follow-up turns
# Requests routed to each instance so far. Used to break routing ties, so
# sessions with no cached prefix anywhere are spread across instances instead
# of all landing on instance 0.
inst_requests = [0] * N_INSTANCES
# Number of leading hash ids probed when routing a new session.
ROUTE_PROBE_BLOCKS = 10

loss_intra = 0  # multi-turn: prior turn evicted
loss_cross = 0  # single-turn: shared prefix evicted
total_loss = 0
total_inf_hits = 0
total_lru_hits = 0
total_tokens = 0
per_req = []  # one entry per request whose LRU hit fell short of infinite


def _prefix_len(hids, contains):
    """Return how many leading ids of *hids* satisfy *contains* (stop at first miss)."""
    n = 0
    for hid in hids:
        if not contains(hid):
            break
        n += 1
    return n


def _pick_instance(hids):
    """Choose the instance for a session's first request.

    The longest LRU-cached prefix among the first ROUTE_PROBE_BLOCKS ids wins.
    Ties go to the instance with the fewest requests routed so far, then to
    the lowest index.
    """
    probe = hids[:ROUTE_PROBE_BLOCKS]
    return max(
        range(N_INSTANCES),
        key=lambda j: (_prefix_len(probe, lru_caches[j].peek),
                       -inst_requests[j], -j),
    )


for idx, r in enumerate(rows):
    il = r["input_length"]
    hids = r.get("hash_ids", [])
    cid = r["chat_id"]
    pid = r["parent_chat_id"]
    # Without an explicit session_id, a root turn (pid < 0) names its own
    # session and a follow-up turn inherits its parent chat's session.
    sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session.get(pid, str(pid)))
    chat_to_session[cid] = str(sid)
    is_mt = pid >= 0

    if sid not in session_aff:
        session_aff[sid] = _pick_instance(hids)
    inst = session_aff[sid]
    inst_requests[inst] += 1

    # Infinite cache: length of the already-seen prefix, then remember all blocks.
    inf_h = _prefix_len(hids, inf_seen[inst].__contains__)
    inf_seen[inst].update(hids)

    # LRU cache: measure the cached prefix first (peek does not reorder), then
    # touch every block in request order, inserting misses and evicting.
    lru = lru_caches[inst]
    lru_h = _prefix_len(hids, lru.peek)
    for hid in hids:
        lru.access(hid)

    # Reused tokens can never exceed the prompt length. Block-based counts can
    # overshoot when the last hashed block is only partly filled (presumably
    # when input_length is not a multiple of BLOCK_SIZE), so cap at il.
    inf_tok = min(inf_h * BLOCK_SIZE, il)
    lru_tok = min(lru_h * BLOCK_SIZE, il)
    loss = inf_tok - lru_tok

    total_inf_hits += inf_tok
    total_lru_hits += lru_tok
    total_tokens += il

    if loss > 0:
        total_loss += loss
        if is_mt:
            loss_intra += loss
        else:
            loss_cross += loss
        per_req.append({
            "idx": idx, "input": il, "inf_hit": inf_h, "lru_hit": lru_h,
            "loss_blocks": inf_h - lru_h, "loss_tok": loss,
            "mt": is_mt, "sid": str(sid), "turn": r.get("turn", 1),
            "n_blocks": len(hids),
        })

sep = "=" * 70
# Denominators clamped to 1 so an empty trace prints 0% instead of raising
# ZeroDivisionError.
tok_denom = max(total_tokens, 1)
loss_denom = max(total_loss, 1)

print(sep)
print(" EVICTION LOSS ANALYSIS")
print(sep)
print()
print(" Infinite APC: %.1f%%" % (total_inf_hits / tok_denom * 100))
print(" LRU APC: %.1f%%" % (total_lru_hits / tok_denom * 100))
print(" Gap: %.1f pp (%s tokens lost)" % (
    (total_inf_hits - total_lru_hits) / tok_denom * 100,
    "{:,}".format(total_loss)))
print()
print(" Loss by type:")
print(" Multi-turn (prior turn KV evicted): %s tok (%.0f%%)" % (
    "{:,}".format(loss_intra), loss_intra * 100 / loss_denom))
print(" Single-turn (shared prefix evicted): %s tok (%.0f%%)" % (
    "{:,}".format(loss_cross), loss_cross * 100 / loss_denom))
print()
print(" Requests with loss: %d / %d" % (len(per_req), len(rows)))

print()
print(" Top-15 by loss:")
print(" %4s %7s %5s %5s %5s %7s %3s %8s %4s" % (
    "#", "input", "inf_h", "lru_h", "loss", "tok", "mt", "session", "turn"))
for r in sorted(per_req, key=lambda x: -x["loss_tok"])[:15]:
    # str() so a non-string session_id from the trace can still be truncated.
    print(" %4d %7d %5d %5d %5d %7d %3s %8s %4d" % (
        r["idx"], r["input"], r["inf_hit"], r["lru_hit"],
        r["loss_blocks"], r["loss_tok"],
        "Y" if r["mt"] else "N", str(r["sid"])[:8], r["turn"]))

# Instance-level analysis: unique blocks ever seen vs. per-instance LRU capacity.
print()
print(" Per-instance:")
for i in range(N_INSTANCES):
    n = len(inf_seen[i])
    e = lru_caches[i].evictions
    overflow = n / KV_CAPACITY_BLOCKS
    print(" inst_%d: %5d unique blocks, overflow=%.1fx, evictions=%d" % (
        i, n, overflow, e))

# Reuse-distance analysis: for every block requested again on the same
# instance, count how many requests (positions in the time-sorted trace) passed
# since that block was last touched there.
# NOTE: this covers every reused block, not only the ones the LRU pass evicted.
print()
print(" Temporal analysis of evicted blocks:")
# Last trace index at which each block was touched, per instance
block_deposit_time = [{} for _ in range(N_INSTANCES)]
gaps = []

# Re-scan using the simulation pass's session -> instance routing (session_aff).
# Reuse across instances cannot hit that instance's prefix cache, so only
# same-instance reuse is measured. sid is derived exactly as in that pass, so
# the lookup key matches.
session_aff2 = {}
chat_to_session2 = {}
for idx, r in enumerate(rows):
    hids = r.get("hash_ids", [])
    cid = r["chat_id"]
    pid = r["parent_chat_id"]
    sid = r.get("session_id", str(cid) if pid < 0 else chat_to_session2.get(pid, str(pid)))
    chat_to_session2[cid] = str(sid)
    inst = session_aff2.setdefault(sid, session_aff.get(sid, 0))

    last_touch = block_deposit_time[inst]
    for hid in hids:
        if hid in last_touch:
            gaps.append(idx - last_touch[hid])
        last_touch[hid] = idx


def _gap_percentile(q):
    """Return the q-quantile (0 <= q <= 1) of the already-sorted ``gaps`` list."""
    return gaps[min(int(q * len(gaps)), len(gaps) - 1)]


if gaps:
    gaps.sort()
    print(" Block reuse distance (requests between deposit and reaccess):")
    print(" p10=%d p50=%d p90=%d max=%d" % (
        _gap_percentile(.1), _gap_percentile(.5), _gap_percentile(.9), gaps[-1]))
    n_gaps = len(gaps)
    short = sum(1 for g in gaps if g <= 10)
    medium = sum(1 for g in gaps if 10 < g <= 100)
    long_ = n_gaps - short - medium  # everything > 100
    print(" <=10 req: %d (%.0f%%) 10-100: %d (%.0f%%) >100: %d (%.0f%%)" % (
        short, short * 100 / n_gaps,
        medium, medium * 100 / n_gaps,
        long_, long_ * 100 / n_gaps))