Files
xserv/docs/05-attention.md
Gahow Wang 6035ffdc0b phase 5: naive multi-head attention
- Batched GEMM via cublasGemmStridedBatchedEx
- Causal mask CUDA kernel (F32 + BF16)
- Element-wise scale CUDA kernel (F32 + BF16)
- attention() composing: batched_matmul + scale + causal_mask + softmax
- Fixed to_device/contiguous infinite recursion (GPU contiguous via CPU round-trip)
- 5 attention tests passing (max_err < 3e-7 F32)
- Total: 61 tests passing across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 21:17:23 +08:00

4.1 KiB
Raw Blame History

Phase 5: Naive Attention Kernel — Design Document

Goal

实现标准 Multi-Head Attention不做 Flash/Paged 优化用组合式方法GEMM + Softmax完成。这是理解 attention 计算流程的基础,也是后续 Flash Attention 的 baseline。

计算流程

Input: Q [B, H, S, D], K [B, H, S, D], V [B, H, S, D]
       B=batch, H=num_heads, S=seq_len, D=head_dim

1. scores = Q @ K^T / sqrt(D)           → [B, H, S, S]
2. scores += causal_mask                 → 上三角置为 -inf
3. weights = softmax(scores, dim=-1)     → [B, H, S, S]
4. output = weights @ V                  → [B, H, S, D]

设计选择

组合式实现Phase 3 GEMM + Phase 4 Softmax

不写新的 fused CUDA kernel而是复用已有的 matmul 和 softmax

  • scores = batched_matmul(Q, K^T) — 需要支持 batched GEMM
  • masked_fill(scores, causal_mask, -inf) — 新的逐元素 kernel
  • softmax(scores) — 复用 Phase 4
  • output = batched_matmul(weights, V) — 复用 batched GEMM

这意味着需要先扩展 matmul 支持 batched GEMMcublasGemmStridedBatchedEx

Causal Mask

不显式构造 mask 矩阵。写一个 kernel

if (col > row + offset) score = -infinity

其中 offset 用于支持 KV cache 场景decode 时 query 的 row 偏移)。

Batched GEMM via cuBLAS

cublasGemmStridedBatchedEx 在一个 batch 维度上并行执行多个 GEMM

C[b] = A[b] @ B[b]  for b = 0..batch_count
stride_a = M * K, stride_b = K * N, stride_c = M * N

Attention 中 batch 维度 = B * Hbatch_size × num_heads

文件布局

csrc/attention/
└── causal_mask.cu      # causal mask fill kernel

crates/xserv-kernels/src/
├── gemm.rs             # 扩展: batched_matmul
├── attention.rs        # NEW: multi_head_attention()
└── causal_mask.rs      # NEW: causal mask apply

API 设计

/// Multi-head attention (naive, materializes S×S scores).
/// q, k, v: [batch, num_heads, seq_len, head_dim]
/// Returns: [batch, num_heads, seq_len, head_dim]
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tensor;

/// Batched matmul: A[b] @ B[b] for all b.
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor;

Test Plan

  • batched_matmul: [4,8,32,64]×[4,8,64,32] → max_err 2.7e-7
  • attention (non-causal): B=1,H=2,S=8,D=16 → max_err 4.5e-8
  • attention (causal): B=1,H=2,S=16,D=32 → max_err 3.0e-8
  • attention (causal, larger): B=2,H=4,S=64,D=64 → max_err 6.0e-8
  • causal mask 语义: position 0 只能看到 token 0output[0] == V[0] → exact

Takeaways

  1. to_device 不应强制 contiguous:最初 to_device() 会先调 contiguous(),而 GPU 的 contiguous() 又调 to_device(Cpu),导致无限递归栈溢出。修复:to_device() 直接传输 raw storage保留 strides/offset用户需要时自己调 contiguous()。GPU contiguous() 现在走 GPU→CPU→CPU contiguous→CPU→GPU 路径——正确但低效Phase 15 需要写 GPU contiguous kernel。

  2. Batched GEMM via cublasGemmStridedBatchedExrow-major trick 同 Phase 3额外参数是 stride元素数不是字节。stride_a = M×K, stride_b = K×N, stride_c = M×N。注意初始版本错误地乘了 elem_sizecuBLAS 的 stride 单位是元素。

  3. Attention 的组合式实现足够验证正确性:没有写 fused kernel而是复用 batched_matmul + scale + causal_mask + softmax。精度极好max_err < 1e-7因为每步都在 FP32 中完成。缺点是 S×S score 矩阵完全 materializeO(S²) 显存Flash Attention 会解决。

  4. Scale kernel 的必要性:原本想在 CPU 上做 scaleround-trip但那太慢了。加了 scale_f32/bf16 逐元素 CUDA kernel。未来可以把 scale 合进 GEMM 的 alpha 参数,省一次 kernel launch。

  5. Causal mask 的 offset 设计col > row + offset 中的 offset 为 KV cache 场景预留。Decode 时 Q 只有 1 行但 KV cache 有前 S 行offset = kv_len - q_len 确保 decode query 能看到所有 cached tokens。