Systematic study of prefill-decode disaggregation for agentic LLM workloads using production GLM-5.1 coder trace (2.1M requests, 71B input tokens).

Key findings:
- Cache-aware routing improves TPOT p90 by 15% and APC from 20.8% to 44.7% without PD separation, matching PD-Sep's decode isolation benefit
- PD separation adds +72% TTFT overhead (KV transfer) with no TPOT gain when using the same cache-aware scheduler
- Prefill remains compute-bound even at 95% KV cache reuse (AI >1000x vs decode AI <2), but absolute FLOPs drop 71% from cache hits
- For agentic MoE workloads, cache-aware routing > PD separation

Infrastructure:
- Trace sampler preserving session structure + hash_ids for prefix sharing
- Async trace replayer with streaming TTFT/TPOT/E2E measurement
- Unified cache-aware + token-level load-balanced global scheduler proxy supporting both PD-colocated and PD-disaggregated (Mooncake/RDMA) modes
- vLLM 0.18.1 scheduler patch for KV transfer abort race condition
- Roofline analysis tool for prefill/decode compute characterization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
197 lines
6.7 KiB
Python
197 lines
6.7 KiB
Python
"""Analyze theoretical vs actual KV cache hit ratio for the agentic trace."""
|
|
import json
|
|
from collections import Counter
|
|
|
|
# Load the sampled trace: one JSON object per line. The context manager closes
# the file deterministically, and the explicit encoding keeps the read
# independent of the platform's locale default.
with open("traces/sampled_1000req_seed42.jsonl", encoding="utf-8") as trace_file:
    # Blank lines (e.g. a trailing newline at end of file) are skipped instead
    # of being handed to json.loads, which would raise on them.
    rows = [json.loads(line) for line in trace_file if line.strip()]

# Replay order: ascending by the numeric value of each record's "timestamp".
rows.sort(key=lambda r: float(r["timestamp"]))

# Number of tokens that every entry of a request's "hash_ids" list is counted
# as when block hits are converted to token counts further down the script.
BLOCK_SIZE = 512

# Number of serving instances simulated in sections 2 and 3.
NUM_INSTANCES = 4


def _prefix_hit_blocks(hash_ids, cached):
    """Return how many leading entries of *hash_ids* are present in *cached*.

    Counting stops at the first miss: a block that appears in *cached* after
    a missing block is not counted as a hit.
    """
    hits = 0
    for hid in hash_ids:
        if hid not in cached:
            break
        hits += 1
    return hits


def _hit_tokens(hit_blocks, input_len):
    """Convert a block hit count to tokens, capped at the request length.

    ``hit_blocks * BLOCK_SIZE`` can exceed ``input_len`` when the last block
    of a request is only partially filled; without the cap the cached-token
    totals overshoot and per-request ratios can exceed 1.
    """
    return min(hit_blocks * BLOCK_SIZE, input_len)


def _pct(part, whole):
    """Return ``part * 100 // whole``, or 0 when *whole* is zero."""
    return part * 100 // whole if whole else 0


def p(v, q):
    """Return the element of the sorted sequence *v* at quantile *q*.

    The index is ``int(q * len(v))``, capped at the last valid index.
    *v* must be non-empty.
    """
    return v[min(int(q * len(v)), len(v) - 1)]


# === 1. Theoretical max: infinite cache, single instance ===
total_tokens = 0
total_cached = 0
seen_blocks = set()
per_req = []

for r in rows:
    input_len = r["input_length"]
    hash_ids = r.get("hash_ids", [])
    total_tokens += input_len

    cached_blocks = _prefix_hit_blocks(hash_ids, seen_blocks)
    cached_tokens = _hit_tokens(cached_blocks, input_len)
    total_cached += cached_tokens
    # Nothing is ever evicted: every block of the request becomes cached.
    seen_blocks.update(hash_ids)

    per_req.append({
        "input_length": input_len,
        "cached_tokens": cached_tokens,
        "new_tokens": input_len - cached_tokens,
        "ratio": cached_tokens / input_len if input_len > 0 else 0,
    })

sep = "=" * 70
print(sep)
print(" THEORETICAL KV CACHE HIT (infinite cache, single instance)")
print(sep)
print(f" Total input tokens: {total_tokens:>14,}")
print(f" Cacheable (prefix hit): {total_cached:>14,} ({_pct(total_cached, total_tokens)}%)")
print(
    f" Must prefill (new): {total_tokens - total_cached:>14,}"
    f" ({_pct(total_tokens - total_cached, total_tokens)}%)"
)

# Per-request distributions only consider requests with a non-empty input.
ratios = sorted(s["ratio"] for s in per_req if s["input_length"] > 0)
new_tokens = sorted(s["new_tokens"] for s in per_req if s["input_length"] > 0)

high = sum(1 for ratio in ratios if ratio > 0.5)
very_high = sum(1 for ratio in ratios if ratio > 0.9)
zero = sum(1 for ratio in ratios if ratio == 0)

print("\n Per-request cache hit ratio:")
if ratios:
    mean_ratio = sum(ratios) / len(ratios)
    print(
        f" p10={p(ratios, .1) * 100:.1f}% p50={p(ratios, .5) * 100:.1f}%"
        f" p90={p(ratios, .9) * 100:.1f}% mean={mean_ratio * 100:.1f}%"
    )
    print(f" 0% hit (cold start): {zero} ({_pct(zero, len(ratios))}%)")
    print(f" >50% hit: {high} ({_pct(high, len(ratios))}%)")
    print(f" >90% hit: {very_high} ({_pct(very_high, len(ratios))}%)")

    print("\n Actual new tokens to prefill per request:")
    print(
        f" p10={p(new_tokens, .1):>7,} p50={p(new_tokens, .5):>7,}"
        f" p90={p(new_tokens, .9):>7,} max={max(new_tokens):>7,}"
    )
else:
    # Percentiles, mean and max are undefined without any non-empty request.
    print(" (no requests with input_length > 0)")

# === 2. N-instance split (simulating DP=N or N prefill instances) ===
print(f"\n{sep}")
print(f" {NUM_INSTANCES}-INSTANCE SPLIT (round-robin, per-instance cache)")
print(sep)

instance_seen = [set() for _ in range(NUM_INSTANCES)]
inst_total = [0] * NUM_INSTANCES
inst_cached = [0] * NUM_INSTANCES

for i, r in enumerate(rows):
    # Round-robin: the i-th request (in timestamp order) goes to instance i mod N.
    inst = i % NUM_INSTANCES
    input_len = r["input_length"]
    hash_ids = r.get("hash_ids", [])
    inst_total[inst] += input_len

    cached_blocks = _prefix_hit_blocks(hash_ids, instance_seen[inst])
    inst_cached[inst] += _hit_tokens(cached_blocks, input_len)
    instance_seen[inst].update(hash_ids)

rr_total = sum(inst_total)
rr_cached = sum(inst_cached)
rr_pct = _pct(rr_cached, rr_total)
print(f" Cache hit ratio (RR): {rr_pct}%")

# === 3. Cache-aware routing (route to instance with best prefix match) ===
print(f"\n{sep}")
print(f" {NUM_INSTANCES}-INSTANCE CACHE-AWARE ROUTING")
print(sep)

ca_seen = [set() for _ in range(NUM_INSTANCES)]
ca_total = [0] * NUM_INSTANCES
ca_cached = [0] * NUM_INSTANCES

for r in rows:
    input_len = r["input_length"]
    hash_ids = r.get("hash_ids", [])

    # Pick the instance with the longest cached prefix. Ties (including the
    # all-miss case) go to the instance with the fewest routed tokens so far.
    # Always falling back to instance 0 on a tie would leave the other
    # instances without any request, collapsing this into the single-instance
    # simulation of section 1.
    best_inst = 0
    best_hit = -1
    for inst in range(NUM_INSTANCES):
        hit = _prefix_hit_blocks(hash_ids, ca_seen[inst])
        longer_prefix = hit > best_hit
        lighter_tie = hit == best_hit and ca_total[inst] < ca_total[best_inst]
        if longer_prefix or lighter_tie:
            best_hit = hit
            best_inst = inst

    ca_total[best_inst] += input_len
    ca_cached[best_inst] += _hit_tokens(best_hit, input_len)
    ca_seen[best_inst].update(hash_ids)

ca_total_sum = sum(ca_total)
ca_cached_sum = sum(ca_cached)
ca_pct = _pct(ca_cached_sum, ca_total_sum)
print(f" Cache hit ratio: {ca_pct}%")
# The delta is the difference of the two percentages printed on this line, so
# the three numbers always agree with each other.
print(f" vs RR: {rr_pct}% -> {ca_pct}% ({ca_pct - rr_pct:+d}pp)")

# === 4. Session structure analysis ===
print(f"\n{sep}")
print(" SESSION & MULTI-TURN ANALYSIS")
print(sep)

# Group requests into sessions. A record's explicit "session_id" wins when the
# key is present. Otherwise a request with a negative parent_chat_id starts a
# session named after its own chat_id, and any other request joins its
# parent's session (or one named after the parent id if the parent has not
# been seen in this sample).
sessions = {}
chat_to_session = {}
for r in rows:
    cid = int(r["chat_id"])
    pid = int(r["parent_chat_id"])
    if "session_id" in r:
        sid = r["session_id"]
    elif pid < 0:
        sid = str(cid)
    else:
        sid = chat_to_session.get(pid, str(pid))
    chat_to_session[cid] = str(sid)
    sessions.setdefault(str(sid), []).append(r)

multi = {k: v for k, v in sessions.items() if len(v) > 1}
single = {k: v for k, v in sessions.items() if len(v) == 1}

# Every percentage below falls back to 0 when its denominator is zero (empty
# trace, sample without multi-turn sessions, records without hash_ids)
# instead of raising ZeroDivisionError.
multi_pct = len(multi) * 100 // len(sessions) if sessions else 0
print(f" Sessions: {len(sessions)} total, {len(multi)} multi-turn ({multi_pct}%)")

# Multi-turn: walk each session in turn order and count every block as reused
# when an earlier block of the same session carried the same hash id, and as
# new otherwise. Each block is counted as BLOCK_SIZE tokens.
mt_new = 0
mt_reuse = 0
for turns in multi.values():
    turns.sort(key=lambda r: r["turn"])
    prev_blocks = set()
    for t in turns:
        for hid in t.get("hash_ids", []):
            if hid in prev_blocks:
                mt_reuse += BLOCK_SIZE
            else:
                mt_new += BLOCK_SIZE
                prev_blocks.add(hid)

mt_total_tok = mt_new + mt_reuse
reuse_pct = mt_reuse * 100 // mt_total_tok if mt_total_tok else 0
print(f" Multi-turn intra-session reuse: {reuse_pct}% of tokens")
print(" (Turn 2+ reuses KV from prior turns in same session)")

# Block sharing across the whole sample: how many requests reference each
# block hash (counted over all rows, single- and multi-turn alike).
block_freq = Counter(hid for r in rows for hid in r.get("hash_ids", []))

shared = {k: v for k, v in block_freq.items() if v > 1}
top = block_freq.most_common(5)
shared_pct = len(shared) * 100 // len(block_freq) if block_freq else 0
print("\n Cross-session block sharing:")
print(f" Unique blocks: {len(block_freq):,}")
print(f" Shared (ref>1): {len(shared):,} ({shared_pct}%)")
print(f" Top-5 block ref counts: {[c for _, c in top]}")
print(" (Shared blocks = system prompt / common code context)")

# === 5. Implication for PD separation ===
# Summarise the single-instance, infinite-cache totals: which share of input
# tokens still needs prefill compute, and the average request size before and
# after prefix-cache hits.
print("\n" + sep)
print(" IMPLICATION FOR PD SEPARATION")
print(sep)

# Integer percentage (floored) of input tokens that were not prefix-cache hits.
uncached_tokens = total_tokens - total_cached
actual_prefill_pct = uncached_tokens * 100 // total_tokens
cache_hit_pct = 100 - actual_prefill_pct
print(f" With perfect caching, only {actual_prefill_pct}% of tokens need actual prefill compute.")
print(f" The remaining {cache_hit_pct}% are prefix cache hits (skip prefill, reuse KV).")
print(" This means PD separation's prefill overhead is much smaller than it appears:")

# Nominal average is taken over every request in the trace.
nominal_avg = total_tokens // len(rows)
print(f" - Nominal avg input: {nominal_avg:,} tokens/request")

# Actual average is taken over requests with a non-empty input only.
new_per_req = sorted(
    entry["new_tokens"] for entry in per_req if entry["input_length"] > 0
)
actual_avg = sum(new_per_req) // len(new_per_req)
print(f" - Actual avg prefill: {actual_avg:,} tokens/request (after cache hit)")
print(" - KV transfer size is also reduced (only transfer new blocks)")