"""KV-cache working-set sizing for agentic request traces.

Sweeps GPU type, model and parallelism to answer one question: how much HBM
does the prefix cache need to hold the blocks these traces actually reuse?

What it computes
----------------
``hash_ids`` in the traces are global, content-addressed block ids: identical
content maps to the same id, so a repeated id is a reuse. vLLM's prefix cache
works at block granularity, so the cluster-wide KV footprint at any instant is
the set of distinct block ids that must be resident. Placing sessions on
instances only moves blocks between GPUs; it does not change that aggregate,
so the analysis is independent of placement.

Three working-set notions are reported, swept over a retention window T:

    W_all         every block is retained forever (true upper bound)
    W_oracle      a block is kept from its first use to its last use
                  (Belady-style foresight floor)
    W_denning(T)  distinct blocks touched in (t-T, t]  (realistic TTL=T LRU)

together with the APC hit rate captured at each T, which can be checked
against the reuse ceiling (the fraction of block appearances that repeat an
earlier id).

Hardware model
--------------
KV pool per serving replica =
    gpus_per_replica * hbm_per_gpu - model_weights - activation_reserve
with gpus_per_replica = TP * PP. TP/EP shard weights and KV across the
replica's GPUs; because the *aggregate* KV pool is what gets sized, only
gpus_per_replica and the total weight footprint matter.

KV bytes per token:
    GQA/MHA : 2 * L * kv_heads * head_dim * kv_dtype_bytes
    MLA     : L * (kv_lora_rank + qk_rope_head_dim) * kv_dtype_bytes
(matches kvcache-simulator/src/config.rs::kv_block_bytes)

All sizes are reported in GB = 1e9 bytes, matching the simulator's
``hbm_bytes`` 1e9 convention.
"""
from __future__ import annotations

import argparse
import json

import numpy as np

# Decimal gigabyte; every size this tool prints or plots is divided by this.
GB = 1e9

# Nominal HBM per GPU, in GB (decimal).
GPU_HBM_GB = { "H100": 80, "H200": 141, "H20": 96, "H20-141G": 141, "A100-40G": 40, "A100-80G": 80, "B200": 192, "B300": 288, "GB200": 192, } # ----------------------------------------------------------------------------- model def load_model(config_json: str) -> dict: v = json.load(open(config_json)) L = int(v["num_hidden_layers"]) out = {"name": v.get("model_type", "?"), "L": L} if "kv_lora_rank" in v: # MLA (DeepSeek / GLM-MoE-DSA) out["mla"] = True out["kv_lora_rank"] = int(v["kv_lora_rank"]) out["qk_rope_head_dim"] = int(v["qk_rope_head_dim"]) else: # GQA / MHA out["mla"] = False H = int(v.get("num_attention_heads", 0)) out["kv_heads"] = int(v.get("num_key_value_heads", H) or H) out["head_dim"] = int(v.get("head_dim") or (v["hidden_size"] // H)) return out def kv_bytes_per_token(model: dict, kv_dtype_bytes: int) -> int: L = model["L"] if model["mla"]: return L * (model["kv_lora_rank"] + model["qk_rope_head_dim"]) * kv_dtype_bytes return 2 * L * model["kv_heads"] * model["head_dim"] * kv_dtype_bytes # ----------------------------------------------------------------------------- trace def load_trace(path: str, min_ts=None, max_ts=None): ids, ts = [], [] n = dropped = 0 with open(path) as fh: for line in fh: line = line.strip() if not line: continue r = json.loads(line) h = r.get("hash_ids") if isinstance(h, str): h = json.loads(h) if not h: continue t = float(r.get("timestamp", 0.0)) if (min_ts is not None and t < min_ts) or (max_ts is not None and t > max_ts): dropped += 1 continue ids.extend(h) ts.extend([t] * len(h)) n += 1 if dropped: print(f" (clipped {dropped} reqs outside [{min_ts}, {max_ts}])") return n, np.asarray(ids, dtype=np.int64), np.asarray(ts, dtype=np.float64) def _sweep_peak(starts, ends): """Peak concurrency of intervals [start, end); ends applied before starts at ties.""" ev = np.concatenate([starts, ends]) d = np.concatenate([np.ones(len(starts), np.int64), -np.ones(len(ends), np.int64)]) order = np.lexsort((d, ev)) # at equal time: -1 (end) 
before +1 (start) return int(np.cumsum(d[order]).max()) def _series(starts, ends, grid): s = np.sort(starts); e = np.sort(ends) return np.searchsorted(s, grid, side="right") - np.searchsorted(e, grid, side="right") def compute_working_set(ids, ts, taus, series_taus=()): """Return dict with appearance stats + per-tau Denning peaks + oracle/all. For each T in series_taus, also return the full W(t) time series on `grid`.""" A = len(ids) order = np.lexsort((ts, ids)) ids_s, ts_s = ids[order], ts[order] same_prev = np.empty(A, bool); same_prev[0] = False same_prev[1:] = ids_s[1:] == ids_s[:-1] same_next = np.empty(A, bool); same_next[-1] = False same_next[:-1] = ids_s[:-1] == ids_s[1:] prev_gap = np.full(A, np.inf); prev_gap[1:][same_prev[1:]] = (ts_s[1:] - ts_s[:-1])[same_prev[1:]] next_gap = np.full(A, np.inf); next_gap[:-1][same_next[:-1]] = (ts_s[1:] - ts_s[:-1])[same_next[:-1]] n_unique = int((~same_prev).sum()) grid = np.linspace(ts.min(), ts.max(), 400) # oracle [first,last] first = np.full(ids.max() + 1, np.inf); last = np.full(ids.max() + 1, -np.inf) np.minimum.at(first, ids, ts); np.maximum.at(last, ids, ts) seen = np.isfinite(first) oracle_peak = _sweep_peak(first[seen], last[seen]) rows = [] series = {} for T in taus: enter = ts_s[prev_gap > T] exit_ = ts_s[next_gap > T] + T peak = _sweep_peak(enter, exit_) ser = _series(enter, exit_, grid) rows.append({ "tau": T, "peak_blocks": peak, "p99_blocks": float(np.percentile(ser, 99)), "p50_blocks": float(np.percentile(ser, 50)), "apc": float((prev_gap <= T).sum() / A), }) if T in series_taus: series[T] = ser return { "A": A, "n_unique": n_unique, "n_reuse": A - n_unique, "apc_ceiling": (A - n_unique) / A, "oracle_peak_blocks": oracle_peak, "span": float(ts.max() - ts.min()), "grid_s": grid - grid.min(), "series": series, "taus": rows, } # ----------------------------------------------------------------------------- plot def plot(ws, hw, block_bytes, label, out_path): import matplotlib matplotlib.use("Agg") import 
matplotlib.pyplot as plt bgb = block_bytes / GB pool = hw["kv_pool_gb"] # KV pool per node (= per replica) gpr = hw["gpus_per_replica"] ceil = ws["apc_ceiling"] * 100 oracle_nodes = ws["oracle_peak_blocks"] * bgb / pool # all operating points, out to the largest retention window (~50 nodes) rows = list(ws["taus"]) nodes = np.array([r["peak_blocks"] * bgb / pool for r in rows]) apc = np.array([r["apc"] * 100 for r in rows]) XMAX_L = 53 # left panel x-axis (nodes), shows up to T=1800s (~52 nodes) XMAX_R = 16 # right panel y-axis (nodes) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) # ===== panel 1: benefit vs cost -- APC you get per cluster size ===== ax1.plot(nodes, apc, "o-", color="#1f77b4", lw=2, ms=7, zorder=4, label="TTL-LRU cache") # interpolated APC exactly at the 1-node budget apc_at_1 = float(np.interp(1.0, nodes, apc)) ax1.scatter([1], [apc_at_1], s=90, facecolors="none", edgecolors="#ff7f0e", lw=2, zorder=6) ax1.annotate(f"1 node -> ~{apc_at_1:.0f}% APC\n(TTL model; real LRU higher)", (1, apc_at_1), textcoords="offset points", xytext=(14, 8), fontsize=9, color="#ff7f0e", va="bottom") # label the well-separated decision-zone points for r, x, y in zip(rows, nodes, apc): if x >= 1.5: ax1.annotate(f"{r['tau']:g}s", (x, y), textcoords="offset points", xytext=(5, 6), fontsize=9) ax1.annotate("T<=10s reuse:\nall < 1.4 nodes", (1.5, 18), fontsize=8.5, color="#1f77b4", ha="left") # diminishing returns past the oracle point ax1.annotate("diminishing returns:\n14 -> 52 nodes buys only +6pp", (30, 64), fontsize=9, color="#555", ha="center") # budget + ceiling ax1.axvspan(0, 1, color="#2ca02c", alpha=.08) ax1.axvline(1, ls="--", color="#2ca02c", lw=1.8) ax1.text(1.6, 96, "1 B300 node (your budget)", color="#2ca02c", fontsize=9, va="top") ax1.scatter([oracle_nodes], [ceil], marker="*", s=340, color="#d62728", zorder=7) ax1.annotate(f"ceiling {ceil:.1f}% — oracle/LRU\nreaches it at {oracle_nodes:.0f} nodes", (oracle_nodes, ceil), textcoords="offset points", 
xytext=(12, -4), fontsize=9, color="#d62728", ha="left", va="top") ax1.axhline(ceil, ls=":", color="#d62728", alpha=.5) ax1.set_xlim(0, XMAX_L); ax1.set_ylim(0, 100) ax1.set_xticks(range(0, 51, 10)); ax1.set_xticks(range(0, XMAX_L, 5), minor=True) ax1.set_xlabel(f"# nodes of GPU HBM needed (1 node = {gpr}x {hw['gpu']} = {pool:.0f} GB KV)") ax1.set_ylabel("Prefix-cache hit rate (APC %)") ax1.set_title("Benefit vs cost: APC per cluster size", fontweight="bold") ax1.grid(alpha=.3); ax1.grid(alpha=.15, which="minor"); ax1.legend(loc="lower right") # ===== panel 2: working set W(t) over time (steady -> peak ~ median) ===== apc_of = {r["tau"]: r["apc"] * 100 for r in ws["taus"]} t_min = ws["grid_s"] / 60.0 # minutes colors = {2: "#2ca02c", 30: "#ff7f0e", 300: "#1f77b4"} for T, ser in sorted(ws["series"].items()): y = ser * bgb / pool c = colors.get(T, "#777") ax2.plot(t_min, y, lw=1.8, color=c, label=f"keep {T:g}s reuse (APC {apc_of[T]:.0f}%)") ax2.axhline(float(np.median(y)), ls=":", color=c, alpha=.6, lw=1) ax2.axhline(1, ls="--", color="#2ca02c", lw=1.6, alpha=.8) ax2.text(t_min.max(), 1, " 1-node budget", color="#2ca02c", fontsize=8.5, va="center") ax2.axhline(oracle_nodes, ls="--", color="#d62728", lw=1.6, alpha=.8) ax2.text(t_min.max(), oracle_nodes, " ceiling: 14 nodes", color="#d62728", fontsize=8.5, va="center") ax2.set_ylim(0, XMAX_R); ax2.set_yticks(range(0, XMAX_R + 1, 2)) ax2.set_xlim(0, t_min.max()) ax2.set_xlabel("wall-clock time into the trace (min)") ax2.set_ylabel("# nodes of GPU HBM resident (W(t))") ax2.set_title("Working set over time (flat -> peak ~ median)", fontweight="bold") ax2.grid(alpha=.3); ax2.legend(loc="center right", fontsize=9) fig.suptitle(label, fontsize=13, fontweight="bold") fig.tight_layout(rect=[0, 0, 1, 0.97]) fig.savefig(out_path, dpi=130) print(f" figure -> {out_path}") # ----------------------------------------------------------------------------- main def main(): ap = argparse.ArgumentParser() ap.add_argument("trace") 
ap.add_argument("--model-config", required=True, help="path to HF config.json") ap.add_argument("--gpu", required=True, choices=sorted(GPU_HBM_GB)) ap.add_argument("--tp", type=int, default=8) ap.add_argument("--pp", type=int, default=1) ap.add_argument("--ep", type=int, default=0, help="informational only (KV unchanged by EP)") ap.add_argument("--kv-dtype-bytes", type=int, default=1, help="1=FP8, 2=BF16") ap.add_argument("--weight-gb", type=float, required=True, help="total resident model weights, GB") ap.add_argument("--activation-gb", type=float, default=32.0, help="activation+ctx reserve, GB") ap.add_argument("--block-size", type=int, default=512) ap.add_argument("--min-ts", type=float, default=None, help="drop reqs with timestamp < this") ap.add_argument("--max-ts", type=float, default=None, help="drop reqs with timestamp > this") ap.add_argument("--label", default="") ap.add_argument("--out", default="figs/working_set.png") a = ap.parse_args() model = load_model(a.model_config) kv_tok = kv_bytes_per_token(model, a.kv_dtype_bytes) block_bytes = kv_tok * a.block_size gpus_per_replica = a.tp * a.pp total_hbm = gpus_per_replica * GPU_HBM_GB[a.gpu] kv_pool_gb = total_hbm - a.weight_gb - a.activation_gb hw = {"gpus_per_replica": gpus_per_replica, "kv_pool_gb": kv_pool_gb, "gpu": a.gpu} taus = [1, 2, 5, 10, 30, 60, 300, 600, 1800] series_taus = [2, 30, 300] # W(t) lines drawn in panel 2 n, ids, ts = load_trace(a.trace, a.min_ts, a.max_ts) ws = compute_working_set(ids, ts, taus, series_taus) label = a.label or f"{model['name']} {a.gpu} TP{a.tp}" + (f" EP{a.ep}" if a.ep else "") print("=" * 84) print(f" {label}") print("=" * 84) print(f" model {model['name']} L={model['L']} " + (f"MLA(kv_lora={model['kv_lora_rank']}+rope={model['qk_rope_head_dim']})" if model["mla"] else f"GQA(kv_heads={model['kv_heads']}xhd={model['head_dim']})")) print(f" KV / token {kv_tok:,} B ({kv_tok/1024:.1f} KiB) KV / block({a.block_size}) {block_bytes/1e6:.1f} MB") print(f" hardware 
{gpus_per_replica}x {a.gpu} ({GPU_HBM_GB[a.gpu]} GB) = {total_hbm:.0f} GB HBM/replica" + (f" EP={a.ep}" if a.ep else "")) print(f" weights {a.weight_gb:.0f} GB ({a.kv_dtype_bytes}B-KV) + act {a.activation_gb:.0f} GB" f" => KV pool/replica = {kv_pool_gb:.0f} GB") print() print(f" trace {n:,} reqs span {ws['span']:.0f}s ({ws['span']/3600:.2f}h) QPS~{n/ws['span']:.1f}") print(f" block appearances {ws['A']:,} distinct {ws['n_unique']:,} APC ceiling {ws['apc_ceiling']*100:.2f}%") bgb = block_bytes / GB print(f" W_all (retain forever) {ws['n_unique']*bgb:>10,.0f} GB" f" = {ws['n_unique']*bgb/kv_pool_gb:6.1f} replicas ({ws['n_unique']*bgb/kv_pool_gb*gpus_per_replica:,.0f} GPU)") print(f" W_oracle (full ceiling) {ws['oracle_peak_blocks']*bgb:>10,.0f} GB" f" = {ws['oracle_peak_blocks']*bgb/kv_pool_gb:6.1f} replicas ({ws['oracle_peak_blocks']*bgb/kv_pool_gb*gpus_per_replica:,.0f} GPU)") print() print(f" {'T':>7} | {'peak GB':>9} {'p50 GB':>8} | {'replicas':>8} {'GPUs':>6} | {'APC@T':>6}") print(" " + "-" * 60) for r in ws["taus"]: pg = r["peak_blocks"] * bgb rep = pg / kv_pool_gb print(f" {r['tau']:>6g}s | {pg:>9,.0f} {r['p50_blocks']*bgb:>8,.0f} | " f"{rep:>8.1f} {rep*gpus_per_replica:>6.0f} | {r['apc']*100:>5.1f}%") print() print(f" [ref] 1 replica = {gpus_per_replica} GPU = {kv_pool_gb:.0f} GB KV pool") import os os.makedirs(os.path.dirname(a.out) or ".", exist_ok=True) plot(ws, hw, block_bytes, label, a.out) if __name__ == "__main__": main()