phase 5: naive multi-head attention

- Batched GEMM via cublasGemmStridedBatchedEx
- Causal mask CUDA kernel (F32 + BF16)
- Element-wise scale CUDA kernel (F32 + BF16)
- attention() composing: batched_matmul + scale + causal_mask + softmax
- Fixed to_device/contiguous infinite recursion (GPU contiguous via CPU round-trip)
- 5 attention tests passing (max_err < 3e-7 F32)
- Total: 61 tests passing across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 21:17:23 +08:00
parent c8e8153702
commit 6035ffdc0b
10 changed files with 550 additions and 12 deletions

92
docs/05-attention.md Normal file
View File

@@ -0,0 +1,92 @@
# Phase 5: Naive Attention Kernel — Design Document
## Goal
实现标准 Multi-Head Attention不做 Flash/Paged 优化用组合式方法GEMM + Softmax完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。
## 计算流程
```
Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
B=batch, H=num_heads, S=seq_len, D=head_dim
1. scores = Q @ K^T / sqrt(D) → [B, H, S, S]
2. scores += causal_mask → 上三角置为 -inf
3. weights = softmax(scores, dim=-1) → [B, H, S, S]
4. output = weights @ V → [B, H, S, D]
```
## 设计选择
### 组合式实现Phase 3 GEMM + Phase 4 Softmax
不写新的 fused CUDA kernel而是复用已有的 matmul 和 softmax
- `scores = batched_matmul(Q, K^T)` — 需要支持 batched GEMM
- `masked_fill(scores, causal_mask, -inf)` — 新的逐元素 kernel
- `softmax(scores)` — 复用 Phase 4
- `output = batched_matmul(weights, V)` — 复用 batched GEMM
这意味着需要先扩展 matmul 支持 batched GEMMcublasGemmStridedBatchedEx
### Causal Mask
不显式构造 mask 矩阵。写一个 kernel
```
if (col > row + offset) score = -infinity
```
其中 offset 用于支持 KV cache 场景decode 时 query 的 row 偏移)。
### Batched GEMM via cuBLAS
`cublasGemmStridedBatchedEx` 在一个 batch 维度上并行执行多个 GEMM
```
C[b] = A[b] @ B[b] for b = 0..batch_count
stride_a = M * K, stride_b = K * N, stride_c = M * N
```
Attention 中 batch 维度 = B * Hbatch_size × num_heads
## 文件布局
```
csrc/attention/
└── causal_mask.cu # causal mask fill kernel
crates/xserv-kernels/src/
├── gemm.rs # 扩展: batched_matmul
├── attention.rs # NEW: multi_head_attention()
└── causal_mask.rs # NEW: causal mask apply
```
## API 设计
```rust
/// Multi-head attention (naive, materializes S×S scores).
/// q, k, v: [batch, num_heads, seq_len, head_dim]
/// Returns: [batch, num_heads, seq_len, head_dim]
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;
/// Batched matmul: A[b] @ B[b] for all b.
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;
```
## Test Plan
- [x] batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
- [x] attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
- [x] attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
- [x] attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
- [x] causal mask 语义: position 0 只能看到 token 0output[0] == V[0] → exact
## Takeaways
1. **`to_device` 不应强制 contiguous**:最初 `to_device()` 会先调 `contiguous()`,而 GPU 的 `contiguous()` 又调 `to_device(Cpu)`,导致无限递归栈溢出。修复:`to_device()` 直接传输 raw storage保留 strides/offset用户需要时自己调 `contiguous()`。GPU `contiguous()` 现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效Phase 15 需要写 GPU contiguous kernel。
2. **Batched GEMM via `cublasGemmStridedBatchedEx`**row-major trick 同 Phase 3额外参数是 stride元素数不是字节。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 `elem_size`cuBLAS 的 stride 单位是元素。
3. **Attention 的组合式实现足够验证正确性**:没有写 fused kernel而是复用 `batched_matmul` + `scale` + `causal_mask` + `softmax`。精度极好max_err < 1e-7因为每步都在 FP32 中完成缺点是 S×S score 矩阵完全 materializeO(S²) 显存Flash Attention 会解决
4. **Scale kernel 的必要性**原本想在 CPU 上做 scaleround-trip但那太慢了加了 `scale_f32/bf16` 逐元素 CUDA kernel未来可以把 scale 合进 GEMM alpha 参数省一次 kernel launch
5. **Causal mask 的 offset 设计**`col > row + offset` 中的 offset KV cache 场景预留Decode Q 只有 1 行但 KV cache 有前 S offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens