- KVCache: per-layer, per-head storage with append + reconstruct - forward_with_cache: prefill (full prompt) + decode (single token) modes - Fixed data layout bug: per-head vectors avoid cross-head interleaving - CLI updated to use KV cache by default - bench-gpt2 supports --no-cache flag for comparison Benchmark results (50 prompts × 20 tokens): - KV cache vs no-cache: 50/50 bit-identical (cache is correct) - 18x speedup: TTFT 400→24ms, TBT 407→22ms, throughput 2.5→44 tok/s - vs HF transformers: 40/50 match (10 are FP divergence, avg logit gap 0.20) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2.7 KiB
Phase 9: KV Cache + Autoregressive Generation — Design Document
Goal
实现 KV Cache,将 decode 从每步 full forward (O(S²)) 降为增量计算 (O(S))。这是最大的单点性能提升。
核心变化
Before (no cache)
每生成一个 token:
forward(all_tokens) → 重新计算所有层的 Q/K/V/attention
开销: O(S²) attention per step, S 递增
After (with cache)
Prefill:
forward(prompt_tokens) → 计算并缓存所有层的 K/V
Decode (per token):
forward(last_token_only) → 只计算新 token 的 Q/K/V
Q: [1, H, 1, D] → 新 token 的 query
K: append to cache → cache 变为 [1, H, S+1, D]
V: append to cache
attention: Q @ K_cache^T → [1, H, 1, S+1], O(S) not O(S²)
KVCache 数据结构
pub struct KVCache {
k: Vec<Tensor>, // per layer, shape [1, num_heads, current_len, head_dim]
v: Vec<Tensor>,
len: usize, // current sequence length
}
Forward Pass 变化
模型需要两种 forward 模式:
- prefill(tokens): 处理完整 prompt,填充 KV cache
- decode(token, cache): 处理单个 token,读写 KV cache
实现策略
为了最小化改动,在 GPT-2 forward 中加入可选的 &mut KVCache 参数:
- cache=None → 现有行为(full forward)
- cache=Some → prefill 或 decode 模式
CPU round-trip 问题暂不修复(Phase 15),先让 KV cache 逻辑正确。
Test Plan
- KV cache vs no-cache: 50/50 bit-identical output
- Benchmark: 18x decode speedup (407ms → 22ms TBT)
- 50 prompt validation: 40/50 vs HF (10 are FP divergence, gap 0.04-0.56)
Takeaways
-
KV cache 数据布局是核心难点:初始实现直接 append flat bytes 导致 head 维度交错错误。正确做法:per-head 独立存储,reconstruct 时按
[1, H, S, D]layout 组装。这是一个非常容易犯的 layout bug,调试时输出看起来"几乎对"但不完全对。 -
18x 提速 > 理论预期:理论上 KV cache 将 decode 从 O(S²) 降到 O(S),对 S=20-25 的序列预期 ~20x 提速。实测 18x 符合预期。TTFT 也从 400ms 降到 24ms,因为 prefill 只跑一次而不是每步重跑。
-
xserv vs HF 的 10 个 mismatch 不是 bug:logit gap 仅 0.04-0.56(在 -80 到 -140 的 logit 值上),是不同 CUDA kernel 实现间的浮点累积误差导致 argmax 翻转。重要验证:xserv KV-cache vs xserv no-cache 是 50/50 完全一致的——证明 KV cache 实现本身无误。
-
CPU round-trip 仍是主要瓶颈:KV cache 的 per-head 数据存在 CPU Vec 中,每步 decode 都要重新组装成 GPU tensor。这意味着每步仍有 24 次 GPU→CPU→GPU 传输(12 层 × 2 KV)。Phase 15 需要将 KV cache 直接放在 GPU 上。