- Batched GEMM via cublasGemmStridedBatchedEx - Causal mask CUDA kernel (F32 + BF16) - Element-wise scale CUDA kernel (F32 + BF16) - attention() composing: batched_matmul + scale + causal_mask + softmax - Fixed to_device/contiguous infinite recursion (GPU contiguous via CPU round-trip) - 5 attention tests passing (max_err < 3e-7 F32) - Total: 61 tests passing across all crates Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
4.1 KiB
Phase 5: Naive Attention Kernel — Design Document
Goal
实现标准 Multi-Head Attention(不做 Flash/Paged 优化),用组合式方法(GEMM + Softmax)完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。
计算流程
Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
B=batch, H=num_heads, S=seq_len, D=head_dim
1. scores = Q @ K^T / sqrt(D) → [B, H, S, S]
2. scores += causal_mask → 上三角置为 -inf
3. weights = softmax(scores, dim=-1) → [B, H, S, S]
4. output = weights @ V → [B, H, S, D]
设计选择
组合式实现(Phase 3 GEMM + Phase 4 Softmax)
不写新的 fused CUDA kernel,而是复用已有的 matmul 和 softmax:
scores = batched_matmul(Q, K^T)— 需要支持 batched GEMMmasked_fill(scores, causal_mask, -inf)— 新的逐元素 kernelsoftmax(scores)— 复用 Phase 4output = batched_matmul(weights, V)— 复用 batched GEMM
这意味着需要先扩展 matmul 支持 batched GEMM(cublasGemmStridedBatchedEx)。
Causal Mask
不显式构造 mask 矩阵。写一个 kernel:
if (col > row + offset) score = -infinity
其中 offset 用于支持 KV cache 场景(decode 时 query 的 row 偏移)。
Batched GEMM via cuBLAS
cublasGemmStridedBatchedEx 在一个 batch 维度上并行执行多个 GEMM:
C[b] = A[b] @ B[b] for b = 0..batch_count
stride_a = M * K, stride_b = K * N, stride_c = M * N
Attention 中 batch 维度 = B * H(batch_size × num_heads)。
文件布局
csrc/attention/
└── causal_mask.cu # causal mask fill kernel
crates/xserv-kernels/src/
├── gemm.rs # 扩展: batched_matmul
├── attention.rs # NEW: multi_head_attention()
└── causal_mask.rs # NEW: causal mask apply
API 设计
/// Multi-head attention (naive, materializes S×S scores).
/// q, k, v: [batch, num_heads, seq_len, head_dim]
/// Returns: [batch, num_heads, seq_len, head_dim]
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;
/// Batched matmul: A[b] @ B[b] for all b.
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;
Test Plan
- batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
- attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
- attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
- attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
- causal mask 语义: position 0 只能看到 token 0,output[0] == V[0] → exact
Takeaways
-
to_device不应强制 contiguous:最初to_device()会先调contiguous(),而 GPU 的contiguous()又调to_device(Cpu),导致无限递归栈溢出。修复:to_device()直接传输 raw storage,保留 strides/offset,用户需要时自己调contiguous()。GPUcontiguous()现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效,Phase 15 需要写 GPU contiguous kernel。 -
Batched GEMM via
cublasGemmStridedBatchedEx:row-major trick 同 Phase 3,额外参数是 stride(元素数,不是字节)。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了elem_size,cuBLAS 的 stride 单位是元素。 -
Attention 的组合式实现足够验证正确性:没有写 fused kernel,而是复用
batched_matmul+scale+causal_mask+softmax。精度极好(max_err < 1e-7),因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materialize(O(S²) 显存),Flash Attention 会解决。 -
Scale kernel 的必要性:原本想在 CPU 上做 scale(round-trip),但那太慢了。加了
scale_f32/bf16逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。 -
Causal mask 的 offset 设计:
col > row + offset中的 offset 为 KV cache 场景预留。Decode 时 Q 只有 1 行但 KV cache 有前 S 行,offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。