Files
xserv/docs/14-flash-attention.md
Gahow Wang 6cc1c9332d docs: Phase 14 design doc + benchmark, fix Phase 11/12 honesty
Phase 14 (Flash Attention):
- Design doc: FA2 algorithm, SM120 hardware constraints (FA4 incompatible),
  kernel config (BR=BC=64, 32KB smem), GQA mapping, causal tile-skip,
  known limitations and optimization roadmap
- Benchmark doc: correctness (9/10 top-1 match, identical to pre-FA baseline),
  performance tracking (6.9→10.3→12.9 tok/s across phases), memory savings
  analysis, remaining bottleneck breakdown

Phase 11 doc: title corrected from "Paged Attention" to "GPU-Resident KV Cache"
with explicit note that paged allocation was not implemented.

Phase 12 doc: "当前状态" updated from "未实现" to reflect actual state —
iteration-level scheduling implemented + verified (6.0x concurrent speedup),
batched GPU forward explicitly marked as not yet implemented.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-22 18:51:29 +08:00

168 lines
7.5 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase 14: Flash Attention 2 for SM120 — Design Document
## Goal
用自写的 Flash Attention 2 CUDA kernel 替换 naive attention (Phase 5)。消除 O(S²) 显存分配,支持 GQA kernel 内部索引(消除 repeat_kv 开销)。
## 硬件约束: FA4 不适用于 RTX 5090
Flash Attention 已发展到第 4 代 (FA4, arxiv 2603.05451),但各版本有明确硬件依赖:
| 版本 | 目标架构 | 关键硬件特性 | RTX 5090 (SM120) |
|------|---------|------------|-----------------|
| FA2 | 通用 CUDA (SM75+) | shared memory + HMMA | **兼容** |
| FA3 | Hopper SM90 (H100) | TMA + WGMMA + warp specialization | 不兼容 |
| FA4 | Blackwell SM100 (B200/B300) | TMEM + async MMA + 2-CTA mode | 不兼容 |
RTX 5090 使用消费级 Blackwell (GB202, SM120),与数据中心 Blackwell (B200, SM100) 是不同硅片。SM120 **没有 TMEM (Tensor Memory)**,这是 FA4 kernel 设计的核心硬件依赖。这不是软件限制,是硬件级差异。
因此本项目实现 **FA2 算法**,使用标准 CUDA (shared memory + 标准 HMMA)。
## Naive Attention 的问题
Phase 5 的 naive attention 流程:
```
k_t = K.transpose(2,3).contiguous() ← 分配 K^T 显存
scores = batched_matmul(Q, k_t) ← 分配 [B,H,S,S] score 矩阵 (O(S²) 显存)
scores = scale(scores, 1/sqrt(d)) ← 逐元素 kernel
causal_mask(scores) ← 逐元素 kernel
weights = softmax(scores) ← 分配 [B,H,S,S] weight 矩阵
output = batched_matmul(weights, V) ← 最终结果
```
问题:
1. **显存 O(S²)**: score 和 weight 矩阵各需 `B × H × S × S × dtype_size`。S=2048, H=32, BF16 → 256 MB。S=8192 → 4 GB。
2. **GQA 预处理**: 在调用 attention 前需要 `repeat_kv_gpu` 将 K/V 从 8 heads 扩展到 32 heads每层额外分配和拷贝。
3. **多次 kernel launch**: scale, mask, softmax 各一次 kernel launch + global memory round-trip。
4. **K^T materialization**: `K.transpose().contiguous()` 需要分配和拷贝。
## FA2 算法
核心思想: **不 materialize S×S 矩阵**。将 Q, K, V 分成 tiles在 shared memory (SRAM) 中计算,使用 **online softmax trick** 边算边更新 running max 和 sum。
FA2 (Dao 2023) 相比 FA1 的改进: 外层循环遍历 Q tiles (而非 K/V),减少 HBM 读写次数,提高并行性。
```
scale = 1 / sqrt(head_dim)
for each Q tile (q_start..q_start + BR): // 外层: Q tiles
load Q_tile [BR, D] to shared memory (一次加载,内层复用)
init per-row: O[D] = 0, m = -inf, l = 0
for each K/V tile j (kv_start..kv_start + BC): // 内层: K/V tiles
// Causal tile-skip: 如果整个 K tile 在 Q tile "未来",跳过
if causal && kv_start > max_q_pos + kv_offset: skip
load K_tile [BC, D] to shared memory
S = Q_tile @ K_tile^T * scale // [BR, BC], in registers
if causal: mask S[r][c] = -inf where kv_pos > q_pos
// Online softmax update
m_new = max(m, rowmax(S))
P = exp(S - m_new)
l_new = exp(m - m_new) * l + rowsum(P)
O = exp(m - m_new) * O // rescale accumulator
load V_tile [BC, D] to shared memory (复用 K 的空间)
O += P @ V_tile // accumulate
m = m_new, l = l_new
O = O / l // final normalize
write O[BR, D] to HBM (convert FP32 → BF16)
```
## 实现细节
### Kernel 配置
| 参数 | 值 | 说明 |
|------|---|------|
| BR (Q tile rows) | 64 | Q tile 大小 |
| BC (K/V tile rows) | 64 | K/V tile 大小 |
| head_dim | 运行时参数 (≤128) | 支持 64 (GPT-2) 和 128 (Qwen3) |
| Block size | 128 threads | 64 线程各 own 一行 Q其余协助加载 |
| Grid | (q_tiles, batch × num_q_heads) | 每个 block 处理一个 Q tile + 一个 head |
### Shared Memory (BF16 存储)
```
smem_q [BR × head_dim] BF16 = 64 × 128 × 2 = 16 KB (加载一次,内层复用)
smem_kv[BC × head_dim] BF16 = 64 × 128 × 2 = 16 KB (K 和 V 交替使用)
────────────────────────────────────────────
Total: 32 KB (SM120 默认 48 KB余量充足)
```
### 线程映射
- Thread 0..63: 各 own Q_tile 的一行。负责该行的全部计算dot products、softmax、PV 累加。
- Thread 64..127: 协助 shared memory 加载 (K/V tile),不参与计算。
- 加载模式: 每个 thread 加载 `(BR × head_dim) / 128 = 64` 个 BF16 元素。
### Per-Thread Register 使用
```
O_acc[128] FP32 = 512 bytes (128 regs) — 输出累加器
P[64] FP32 = 256 bytes (64 regs) — 当前 tile 的 softmax 后权重
m, l FP32 = 8 bytes (2 regs) — online softmax running state
循环变量 + 临时 ≈ 16 regs
────────────────────────────────────────────
Total: ~210 regs/thread (max 255在限制内)
```
### GQA 支持
每个 thread block 处理一个 Q head通过 `kv_head = q_head / (num_q_heads / num_kv_heads)` 映射到对应的 KV head。K/V 的数据指针直接指向 KV head 的存储,无需 repeat_kv。
```
// 32 Q heads, 8 KV heads → heads_per_group = 4
// Q head 0,1,2,3 → KV head 0
// Q head 4,5,6,7 → KV head 1
// ...
kv_head = q_head / heads_per_group;
K_ptr = K + (batch * num_kv_heads + kv_head) * kv_len * head_dim;
```
### Causal Mask
两级优化:
1. **Tile-level skip**: 如果 `kv_tile_start > max_q_pos + kv_offset`,整个 K/V tile 都在未来,跳过(减少 ~50% 计算)。
2. **Element-level mask**: 在 tile 内部,`if kv_pos > q_pos + kv_offset: S = -inf`
`kv_offset = kv_len - q_len` 处理 decode 时 KV cache 长于 Q 的情况。
## 与 Naive Attention 的对比
| 特性 | Naive (Phase 5) | FA2 (Phase 14) |
|------|----------------|----------------|
| 显存 | O(B × H × S²) | O(B × H × S × D) |
| GQA | 需要 repeat_kv (分配+拷贝) | Kernel 内部索引 (零开销) |
| K^T | 需要 transpose+contiguous | Kernel 内部计算 |
| Kernel launches | 6 (matmul, scale, mask, softmax, matmul, ...) | 1 (单个 fused kernel) |
| S=8192 可行性 | OOM (~4 GB score matrix) | 可行 (32 KB shared memory) |
## 源码结构
```
csrc/attention/flash_attention.cu — FA2 kernel (BF16 in, FP32 accumulate, BF16 out)
crates/xserv-kernels/src/attention.rs — flash_attention() Rust wrapper + 原 attention() 保留
crates/xserv-model/src/qwen3.rs — forward_gpu_cache 调用 flash_attention
```
## 已知局限与后续优化方向
1. **Decode (Q_len=1) 效率低**: BR=64 线程中只有 1 个 activeowns_row。应写专用 decode attention kernel沿 KV 维度 parallel reduction。
2. **无向量化加载**: 当前逐元素 bf16→f32 转换,应改用 `float4``__nv_bfloat162` 批量加载。
3. **Register tiling**: 每个 thread 目前串行计算 dot product (128 MADs per K column)。可改为多线程协作。
4. **K/V double buffering**: 可在计算当前 tile 时预加载下一个 tile 到另一半 shared memory。
5. **Tile size 调优**: 更大的 tile (BR=128) 可能在长 sequence 时更优,需要 opt-in shared memory。
## Test Plan
- [x] 正确性: logits 与 HF transformers 对比 (top-1 match 9/10, top-5 overlap 4.0/5)
- [x] 生成质量: 52/52 prompt 生成连贯文本,中英文均可
- [x] SSE streaming 正常工作
- [x] 性能: 12.9 tok/s (vs naive 10.3 tok/s, +25%)
- [ ] 长 sequence (S=4096, S=8192): 验证 naive OOM 而 FA2 正常
- [ ] ncu profile: compute utilization, memory throughput