"""Analyze theoretical vs actual KV cache hit ratio for the agentic trace.

The script replays a JSONL request trace in timestamp order and reports:

1. the prefix-cache hit ratio of a single instance with an infinite cache,
2. the hit ratio when requests are spread round-robin over several
   instances, each with its own cache,
3. the hit ratio when every request is routed to the instance holding the
   longest cached prefix,
4. session / multi-turn structure and cross-request block sharing,
5. what the numbers imply for prefill/decode (PD) separation.

Each trace row is read as a dict with ``timestamp``, ``input_length``,
``chat_id``, ``parent_chat_id``, ``turn`` and optionally ``hash_ids`` and
``session_id``.
"""
import json
from collections import Counter

TRACE_PATH = "traces/sampled_1000req_seed42.jsonl"
# Number of tokens credited for every cached hash block.
BLOCK_SIZE = 512
# Number of serving instances simulated in the multi-instance sections.
NUM_INSTANCES = 4


def load_rows(path=TRACE_PATH):
    """Load the JSONL trace at *path* and return its rows sorted by timestamp.

    The file is closed deterministically and blank lines (e.g. a trailing
    newline at end of file) are skipped instead of crashing ``json.loads``.
    """
    with open(path, encoding="utf-8") as trace_file:
        rows = [json.loads(line) for line in trace_file if line.strip()]
    rows.sort(key=lambda r: float(r["timestamp"]))
    return rows


def floor_pct(part, whole):
    """Return ``part`` as a floored integer percentage of ``whole`` (0 if empty)."""
    return part * 100 // whole if whole else 0


def percentile(sorted_values, q):
    """Return the element at quantile *q* of an already sorted list (0 if empty)."""
    if not sorted_values:
        return 0
    last = len(sorted_values) - 1
    return sorted_values[min(int(q * len(sorted_values)), last)]


def _prefix_hits(hash_ids, cached_blocks):
    """Count the leading blocks of *hash_ids* that are present in *cached_blocks*."""
    hits = 0
    for hid in hash_ids:
        if hid not in cached_blocks:
            # Only an unbroken prefix is reusable; stop at the first miss.
            break
        hits += 1
    return hits


def _hit_tokens(hit_blocks, input_len, block_size):
    """Convert a block hit count to tokens, never exceeding the request length.

    Without the clamp a fully cached request whose last block is partial
    would be credited more tokens than it contains.
    NOTE(review): presumably hash_ids holds ceil(input_length / block_size)
    entries, so the final block is usually partial - confirm against the
    trace generator.
    """
    return min(hit_blocks * block_size, input_len)


def simulate_single_instance(rows, block_size=BLOCK_SIZE):
    """Replay *rows* against one instance with an unbounded prefix cache.

    Returns ``(total_tokens, total_cached, per_req)`` where ``per_req`` has
    one dict per row with ``input_length``, ``cached_tokens``, ``new_tokens``
    and ``ratio``.
    """
    total_tokens = 0
    total_cached = 0
    seen_blocks = set()
    per_req = []
    for r in rows:
        input_len = r["input_length"]
        hash_ids = r.get("hash_ids", [])
        hits = _prefix_hits(hash_ids, seen_blocks)
        cached_tokens = _hit_tokens(hits, input_len, block_size)
        total_tokens += input_len
        total_cached += cached_tokens
        # After serving the request all of its blocks are resident.
        seen_blocks.update(hash_ids)
        per_req.append({
            "input_length": input_len,
            "cached_tokens": cached_tokens,
            "new_tokens": input_len - cached_tokens,
            "ratio": cached_tokens / input_len if input_len > 0 else 0,
        })
    return total_tokens, total_cached, per_req


def simulate_round_robin(rows, num_instances=NUM_INSTANCES, block_size=BLOCK_SIZE):
    """Replay *rows* round-robin over per-instance caches.

    Returns ``(total_tokens, cached_tokens)`` summed over all instances.
    """
    caches = [set() for _ in range(num_instances)]
    total = 0
    cached = 0
    for i, r in enumerate(rows):
        cache = caches[i % num_instances]
        input_len = r["input_length"]
        hash_ids = r.get("hash_ids", [])
        total += input_len
        cached += _hit_tokens(_prefix_hits(hash_ids, cache), input_len, block_size)
        cache.update(hash_ids)
    return total, cached


def simulate_cache_aware(rows, num_instances=NUM_INSTANCES, block_size=BLOCK_SIZE):
    """Replay *rows*, routing each to the instance with the longest cached prefix.

    Ties, including the case where no instance has any hit, go to the
    lowest-numbered instance. Returns ``(total_tokens, cached_tokens)``.
    """
    caches = [set() for _ in range(num_instances)]
    total = 0
    cached = 0
    for r in rows:
        input_len = r["input_length"]
        hash_ids = r.get("hash_ids", [])
        best_inst, best_hit = 0, 0
        for inst, cache in enumerate(caches):
            hit = _prefix_hits(hash_ids, cache)
            if hit > best_hit:
                best_inst, best_hit = inst, hit
        total += input_len
        cached += _hit_tokens(best_hit, input_len, block_size)
        caches[best_inst].update(hash_ids)
    return total, cached


def group_sessions(rows):
    """Group *rows* into sessions, returning ``{session_id: [row, ...]}``.

    A row's explicit ``session_id`` wins. Otherwise a root chat (negative
    ``parent_chat_id``) starts a session named after its own ``chat_id`` and
    a child inherits its parent's session, falling back to the parent id
    when the parent has not been seen.
    """
    sessions = {}
    chat_to_session = {}
    for r in rows:
        cid = int(r["chat_id"])
        pid = int(r["parent_chat_id"])
        if pid < 0:
            derived = str(cid)
        else:
            derived = chat_to_session.get(pid, str(pid))
        sid = str(r.get("session_id", derived))
        chat_to_session[cid] = sid
        sessions.setdefault(sid, []).append(r)
    return sessions


def multi_turn_reuse(sessions, block_size=BLOCK_SIZE):
    """Measure block reuse inside multi-turn sessions.

    Sessions with fewer than two turns are ignored. Returns
    ``(reused_tokens, new_tokens)`` at block granularity: every block counts
    as *block_size* tokens.
    """
    reused = 0
    new = 0
    for turns in sessions.values():
        if len(turns) < 2:
            continue
        prev_blocks = set()
        for t in sorted(turns, key=lambda r: r["turn"]):
            for hid in t.get("hash_ids", []):
                if hid in prev_blocks:
                    reused += block_size
                else:
                    new += block_size
                    prev_blocks.add(hid)
    return reused, new


def block_frequencies(rows):
    """Return a ``Counter`` of how many times each hash block is referenced."""
    return Counter(hid for r in rows for hid in r.get("hash_ids", []))


def main(path=TRACE_PATH, num_instances=NUM_INSTANCES, block_size=BLOCK_SIZE):
    """Run every analysis on the trace at *path* and print the report."""
    rows = load_rows(path)
    if not rows:
        print(f" No requests found in {path}; nothing to analyze.")
        return

    sep = "=" * 70

    # === 1. Theoretical max: infinite cache, single instance ===
    total_tokens, total_cached, per_req = simulate_single_instance(rows, block_size)
    must_prefill = total_tokens - total_cached

    print(sep)
    print(" THEORETICAL KV CACHE HIT (infinite cache, single instance)")
    print(sep)
    print(f" Total input tokens: {total_tokens:>14,}")
    print(f" Cacheable (prefix hit): {total_cached:>14,} ({floor_pct(total_cached, total_tokens)}%)")
    print(f" Must prefill (new): {must_prefill:>14,} ({floor_pct(must_prefill, total_tokens)}%)")

    non_empty = [s for s in per_req if s["input_length"] > 0]
    ratios = sorted(s["ratio"] for s in non_empty)
    new_tokens = sorted(s["new_tokens"] for s in non_empty)
    mean_ratio = sum(ratios) / len(ratios) if ratios else 0

    print(f"\n Per-request cache hit ratio:")
    print(f" p10={percentile(ratios,.1)*100:.1f}% p50={percentile(ratios,.5)*100:.1f}% p90={percentile(ratios,.9)*100:.1f}% mean={mean_ratio*100:.1f}%")
    high = sum(1 for ratio in ratios if ratio > 0.5)
    very_high = sum(1 for ratio in ratios if ratio > 0.9)
    zero = sum(1 for ratio in ratios if ratio == 0)
    print(f" 0% hit (cold start): {zero} ({floor_pct(zero, len(ratios))}%)")
    print(f" >50% hit: {high} ({floor_pct(high, len(ratios))}%)")
    print(f" >90% hit: {very_high} ({floor_pct(very_high, len(ratios))}%)")
    print(f"\n Actual new tokens to prefill per request:")
    print(f" p10={percentile(new_tokens,.1):>7,} p50={percentile(new_tokens,.5):>7,} p90={percentile(new_tokens,.9):>7,} max={max(new_tokens, default=0):>7,}")

    # === 2. N-instance split (simulating DP=N or N prefill instances) ===
    print(f"\n{sep}")
    print(f" {num_instances}-INSTANCE SPLIT (round-robin, per-instance cache)")
    print(sep)
    rr_total, rr_cached = simulate_round_robin(rows, num_instances, block_size)
    rr_pct = floor_pct(rr_cached, rr_total)
    print(f" Cache hit ratio (RR): {rr_pct}%")

    # === 3. Cache-aware routing (route to instance with best prefix match) ===
    print(f"\n{sep}")
    print(f" {num_instances}-INSTANCE CACHE-AWARE ROUTING")
    print(sep)
    ca_total, ca_cached = simulate_cache_aware(rows, num_instances, block_size)
    ca_pct = floor_pct(ca_cached, ca_total)
    print(f" Cache hit ratio: {ca_pct}%")
    # Subtract the displayed percentages so the delta always agrees with
    # them, and let the format spec pick the sign (no "+-3pp").
    print(f" vs RR: {rr_pct}% -> {ca_pct}% ({ca_pct - rr_pct:+d}pp)")

    # === 4. Session structure analysis ===
    print(f"\n{sep}")
    print(" SESSION & MULTI-TURN ANALYSIS")
    print(sep)
    sessions = group_sessions(rows)
    multi = {k: v for k, v in sessions.items() if len(v) > 1}
    print(f" Sessions: {len(sessions)} total, {len(multi)} multi-turn ({floor_pct(len(multi), len(sessions))}%)")

    # Multi-turn: cache hit in turn 2+
    mt_reuse, mt_new = multi_turn_reuse(multi, block_size)
    print(f" Multi-turn intra-session reuse: {floor_pct(mt_reuse, mt_new + mt_reuse)}% of tokens")
    print(f" (Turn 2+ reuses KV from prior turns in same session)")

    # Single-turn: cross-session sharing via system prompt
    block_freq = block_frequencies(rows)
    shared = {k: v for k, v in block_freq.items() if v > 1}
    top = block_freq.most_common(5)
    print(f"\n Cross-session block sharing:")
    print(f" Unique blocks: {len(block_freq):,}")
    print(f" Shared (ref>1): {len(shared):,} ({floor_pct(len(shared), len(block_freq))}%)")
    print(f" Top-5 block ref counts: {[c for _,c in top]}")
    print(f" (Shared blocks = system prompt / common code context)")

    # === 5. Implication for PD separation ===
    print(f"\n{sep}")
    print(" IMPLICATION FOR PD SEPARATION")
    print(sep)
    actual_prefill_pct = floor_pct(must_prefill, total_tokens)
    avg_prefill = sum(new_tokens) // len(new_tokens) if new_tokens else 0
    print(f" With perfect caching, only {actual_prefill_pct}% of tokens need actual prefill compute.")
    print(f" The remaining {100-actual_prefill_pct}% are prefix cache hits (skip prefill, reuse KV).")
    print(f" This means PD separation's prefill overhead is much smaller than it appears:")
    print(f" - Nominal avg input: {total_tokens//len(rows):,} tokens/request")
    print(f" - Actual avg prefill: {avg_prefill:,} tokens/request (after cache hit)")
    print(f" - KV transfer size is also reduced (only transfer new blocks)")


if __name__ == "__main__":
    main()