Files
xserv/docs/09-kv-cache.md
Gahow Wang 64084d3489 phase 9: KV cache + autoregressive generation
- KVCache: per-layer, per-head storage with append + reconstruct
- forward_with_cache: prefill (full prompt) + decode (single token) modes
- Fixed data layout bug: per-head vectors avoid cross-head interleaving
- CLI updated to use KV cache by default
- bench-gpt2 supports --no-cache flag for comparison

Benchmark results (50 prompts × 20 tokens):
- KV cache vs no-cache: 50/50 bit-identical (cache is correct)
- 18x speedup: TTFT 400→24ms, TBT 407→22ms, throughput 2.5→44 tok/s
- vs HF transformers: 40/50 match (10 are FP divergence, avg logit gap 0.20)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 23:39:41 +08:00

2.7 KiB
Raw Blame History

Phase 9: KV Cache + Autoregressive Generation — Design Document

Goal

实现 KV Cache将 decode 从每步 full forward (O(S²)) 降为增量计算 (O(S))。这是最大的单点性能提升。

核心变化

Before (no cache)

每生成一个 token:
  forward(all_tokens)  → 重新计算所有层的 Q/K/V/attention
  开销: O(S²) attention per step, S 递增

After (with cache)

Prefill:
  forward(prompt_tokens)  → 计算并缓存所有层的 K/V
  
Decode (per token):
  forward(last_token_only) → 只计算新 token 的 Q/K/V
  Q: [1, H, 1, D]          → 新 token 的 query
  K: append to cache        → cache 变为 [1, H, S+1, D]
  V: append to cache
  attention: Q @ K_cache^T  → [1, H, 1, S+1], O(S) not O(S²)

KVCache 数据结构

pub struct KVCache {
    k: Vec<Tensor>,  // per layer, shape [1, num_heads, current_len, head_dim]
    v: Vec<Tensor>,
    len: usize,      // current sequence length
}

Forward Pass 变化

模型需要两种 forward 模式:

  1. prefill(tokens): 处理完整 prompt填充 KV cache
  2. decode(token, cache): 处理单个 token读写 KV cache

实现策略

为了最小化改动,在 GPT-2 forward 中加入可选的 &mut KVCache 参数:

  • cache=None → 现有行为full forward
  • cache=Some → prefill 或 decode 模式

CPU round-trip 问题暂不修复Phase 15先让 KV cache 逻辑正确。

Test Plan

  • KV cache vs no-cache: 50/50 bit-identical output
  • Benchmark: 18x decode speedup (407ms → 22ms TBT)
  • 50 prompt validation: 40/50 vs HF (10 are FP divergence, gap 0.04-0.56)

Takeaways

  1. KV cache 数据布局是核心难点:初始实现直接 append flat bytes 导致 head 维度交错错误。正确做法per-head 独立存储reconstruct 时按 [1, H, S, D] layout 组装。这是一个非常容易犯的 layout bug调试时输出看起来"几乎对"但不完全对。

  2. 18x 提速 > 理论预期:理论上 KV cache 将 decode 从 O(S²) 降到 O(S),对 S=20-25 的序列预期 ~20x 提速。实测 18x 符合预期。TTFT 也从 400ms 降到 24ms因为 prefill 只跑一次而不是每步重跑。

  3. xserv vs HF 的 10 个 mismatch 不是 buglogit gap 仅 0.04-0.56(在 -80 到 -140 的 logit 值上),是不同 CUDA kernel 实现间的浮点累积误差导致 argmax 翻转。重要验证:xserv KV-cache vs xserv no-cache 是 50/50 完全一致的——证明 KV cache 实现本身无误。

  4. CPU round-trip 仍是主要瓶颈KV cache 的 per-head 数据存在 CPU Vec 中,每步 decode 都要重新组装成 GPU tensor。这意味着每步仍有 24 次 GPU→CPU→GPU 传输12 层 × 2 KV。Phase 15 需要将 KV cache 直接放在 GPU 上。