Files
xtrain/docs/13-flash-attention.md
Gahow Wang 9064ced4c2 docs: T14 flash-attention results + evolution/README rows
Fill in the design doc's measured results (grad-check, flash==composed,
PyTorch parity, peak mem -16%/-23%, tok/s tradeoff), add the T14 row to
evolution.md (算法/Infra) and the README build-journey table.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:34:10 +08:00

184 lines
14 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T14: 融合 Flash-Attention Kernel — Design Document
## Goal
T10 把 attention 批量化了,但它的 SDPA 走的是 **「物化 N×N scores」** 的组合路径:
`cublasSgemmStridedBatched`Q·Kᵀ→ 一个 causal-softmax kernel写出整张 probs
`cublasSgemmStridedBatched`P·V**3 次 launch + 一张 `[bh, S, S]` 的 scores/probs 张量**
常驻显存(反向还要缓存这张 probs。S 一大,这张 N×N 就成了激活显存与带宽的主导项。
T14 的目标:手写一个**单 kernel 的 fused flash-attention**——streaming / online softmax、**tiled
over KV**、**绝不物化 N×N**。前向一发 kernel 直接吐出 `out[bh,S,hd]`(外加 `O(N)` 的 logsumexp
反向一发 kernelflash 式:重算 scores + dQ/dK/dV同样不物化 N×N。接进 model + autograd 作
**opt-in `--flash`**,默认保留 T10 的 composed 路径以便 A/B。
**硬闸门是诚实正确性**:新 kernel 的 dQ/dK/dV finite-diff grad-check 过fwd/bwd 对现有 composed-SDPA
路径数值贴合(进 bf16 容差PyTorch SDPA 对拍 B>1峰值显存↓不物化 scores+ tok/s before/after 实测;
全回归套(含 xserv 闭环 md5开/关 flag 都绿——默认flag off图不变 → 不回归。
## 什么是 flash-attention
标准 attention 是 `O = softmax(causal(Q·Kᵀ/√d)) · V`,朴素实现把 `S[i,j] = Qᵢ·Kⱼ/√d` 整张
`[S,S]` 算出来、softmax、再乘 V——显存 `O(S²)`、HBM 读写 `O(S²)`
**flash-attention** 的洞察softmax 可以 **onlinestreaming** 地算。把 K/V 切成若干 **tile**,对一个
query 行 `i`,依次扫过 KV tile**running max `m` + running sum `l`** 维护 softmax 的归一化,并把
部分加权的 `V` 累加进一个 `[hd]` 的 accumulator `acc`,每来一个新 tile 就用「新旧 max 的差」对旧 `acc`/`l`
做 rescale。扫完所有 tile`out = acc / l`。**整张 `[S,S]` 从不落地**——只有 `[hd]` 的 acc 和两个标量
在寄存器/共享内存里流动。峰值激活从 `O(S²)` 降到 `O(S·hd)`(就是 O 本身)。
online softmax 的核心递推block `j` 的部分 logits 行 `s_j`,旧状态 `m, l, acc`
```text
m_new = max(m, max_k s_j[k])
p = exp(s_j - m_new) # 本 tile 的未归一化权重
l = l * exp(m - m_new) + sum(p) # 旧 sum 先 rescale再加本 tile
acc = acc * exp(m - m_new) + p · V_tile # 旧 acc 同样 rescale再加本 tile 贡献
m = m_new
# 扫完所有 tile
out = acc / l
L = m + log(l) # logsumefpO(N) 存给反向
```
**因果 mask 内联**query 全局位置 = `i % S`(沿用 T10 的 per-seq 复位约定KV 位置 `j` 满足
`j > i%S` 的列直接当 `-inf``p=0`。tile 整块在对角线之上可**直接 skip**causal 的天然稀疏,省一半算力)。
**反向flash 式,[Dao 2022] 的标准做法)**:不缓存 probs从 Q/K/V + 前向存的 `L[bh,S]` **重算** scores。
关键预计算 `D[i] = Σ_d dOᵢ[d]·Oᵢ[d]`(每 query 一个标量,`O(N)`),则对每个 `(i,j)`
```text
s_ij = Qᵢ·Kⱼ * scale # 重算 logit
p_ij = exp(s_ij - L[i]) # 重算 softmax 权重L 是前向存的 logsumexp
dp_ij = dOᵢ · Vⱼ # 对 P 的梯度
ds_ij = p_ij * (dp_ij - D[i]) * scale # softmax 雅可比,化简掉了显式 N×N
dQᵢ += ds_ij * Kⱼ ; dKⱼ += ds_ij * Qᵢ ; dVⱼ += p_ij * dOᵢ
```
`ds = P ∘ (dP - D)` 是 softmax 反向用 `Σⱼ Pⱼ·dPⱼ = D`(因为 `D[i]=Σ dOᵢ·Oᵢ = Σⱼ Pᵢⱼ dPᵢⱼ`)化简的结果,
**不需要 N×N 的 softmax 雅可比矩阵**。同样 tiled、同样不物化 N×N。
## Module Layoutsurgicalcomposed 路径逐字节不动flash 全程新增并行路径)
```
csrc/ops/flash_attention.cu # 新fwd kernelonline softmaxtiled KV+ bwd kernel重算 + dQ/dK/dV
crates/xtrain-cuda/
├── src/ffi.rs # +launch_flash_attention_fwd_f32 / _bwd_f32 声明
└── build.rs # +flash_attention.cu
crates/xtrain-tensor/src/tensor.rs # +Tensor::flash_attention / flash_attention_backwardfwd 存 logsumexp Lbf16 upcast→f32 kernel→downcast
crates/xtrain-autodiff/
├── src/ops.rs # +ops::flash_attention 节点(前向调 fwd缓存 L反向调 bwd
└── tests/autograd.rs # +flash_attention(batched) dQ/dK/dV grad-check
crates/xtrain-model/
├── src/model.rs # attention() 按 use_flash 选 ops::attention | ops::flash_attention+with_flash(bool) builderflash 标志透传 block_forwardrecompute 段内也走 flash
└── tests/flash.rs # 新flash == composedfwd logits + 每参数梯度),参数化 fp32/bf16
crates/xtrain-train/src/bin/train.rs # +--flash flag → model.with_flash(true)
crates/xtrain-distributed/src/bin/train_ddp.rs # +--flash flagDDP 路径)
crates/xtrain-model/tests/parity_dump.rs # PyTorch B>1 对拍跑两遍composed 与 flash共用 PyTorch oracle
```
## Key Design Decisions
### ① 一个 block 负责一行 query先做对再谈快
最直接、最易验证正确的并行划分:**`grid = bh * S`,每个 block 算一整行 query 的 `out[bh, i, :]`**。
block 内 `hd` 个线程hd ≤ 128正好一个 warp 多一点),共享 `m/l` 标量 + `acc[hd]`。block 顺序扫
KV tiletile 宽 `BK`,沿 `j` 维),每个 tile线程并行算 `BK` 个 logit点积 over hd 用 block-reduce
求 tile max、online-rescale `m/l/acc`、累加 `p·V`。扫完写 `out = acc/l``L[i] = m + log(l)`
**为什么先这样而不是 FA2 的 query-tile 划分**:本项目的硬闸门是**正确性 + 不物化 N×N + 显存↓**,不是
打榜峰值 FLOPs。一行一 block 的版本:(a) online softmax 与 N×N skip 已经完全落地(显存与带宽收益拿到),
(b) 代码直白、逐 query 行可对拍,正确性风险最低。它**不会**比 cuBLAS 两发 GEMM 更快cuBLAS tensor-core
吃满),所以 tok/s 上 flash 在我们这种 `hd=32` 小头维下大概率**持平或略慢**——这正是 flash 的已知权衡
flash 的胜场是**显存**,不是小模型的 wall-clock。把这点诚实写进 perf 表,不掩饰。
### ② 前向只存 `L[bh,S]`logsumefp不存 probs
composed 路径反向要缓存整张 `probs[bh,S,S]``O(N²)`。flash 反向**只需要前向的 logsumexp
`L[i]=m_i+log(l_i)`**(每 query 一个 fp32`O(N)`)即可重算任意 `p_ij = exp(Qᵢ·Kⱼ·scale - L[i])`
所以 fwd kernel 顺手把 `L` 写出来autograd 节点缓存它(外加 Q/K/V/O parents 本就在)。**这就是显存闸门的来源**
attention 的反向缓存从 `[bh,S,S]` 砍到 `[bh,S]`
### ③ 反向用 `D[i]=Σ dOᵢ·Oᵢ` 化简 softmax 雅可比
softmax 反向通项 `ds_ij = p_ij·(dp_ij - Σ_k p_ik·dp_ik)`。注意 `Σ_k p_ik·dp_ik = Σ_k p_ik (dOᵢ·V_k)
= dOᵢ·(Σ_k p_ik V_k) = dOᵢ·Oᵢ = D[i]`。所以一趟先算 `D[bh,S]`(每行 `dO·O` 的点积,`O(N)`),反向
扫 KV tile 时直接 `ds = p·(dp - D)·scale`**不需要再算或物化整行的 `Σ p·dp`**。
dQ/dK/dV 三者dQ 由「该 query 行」累加block 私有无竞争dK/dV 跨 query 行累加同一个 `(j)`
→ 用 `atomicAdd` 到全局 dK/dVfp32 原子加,确定 race-free
### ④ bf16kernel 内 fp32边界 cast与 composed 路径一致的数值策略)
T10/T12 的 composed attention 对 bf16 也是 **softmax 用 fp32**scores 升 f32 → kernel → probs 降回 bf16
flash 沿用同策略最省心且数值最稳bf16 模式下 `flash_attention` 把 Q/K/V `to_dtype(F32)` 喂给 fp32 kernel
`out``to_dtype(BF16)`反向同理。kernel 本身只有一份 fp32 实现。这样 flash 的 bf16 数值与 composed 的
bf16 数值是**同一套 fp32 softmax 算的**,只差 GEMM roundingcuBLAS tensor-core vs kernel 内 fp32 FMA→ 落在
既有 bf16 容差内。`L` 始终 fp32。
> 备选不采纳bf16 全程 in-kernel half。收益是少两次 cast但 (a) 引入与 composed 不同的 softmax 累加路径,
> 威胁 on-vs-off 贴合闸门;(b) 本规模 attention 非瓶颈。escape hatch先 fp32-core 把正确性钉死,纯 half flash 留 follow-up。
### ⑤ opt-in 透传:`use_flash` 是运行时旗标,不是架构
`use_flash` 不进 `Config`(它不改模型尺寸、不改导出、不该污染 `num_params`),而是 `TinyTransformer` 的一个
`bool` 字段 + `with_flash(bool)` builder对齐 `with_recompute` / `with_compute_dtype`)。`block_forward` 已经
`(cfg, cdt, …)` 的自由函数T13 为 recompute 抽的),给它加一个 `flash: bool` 形参model 的 `attention()`
据此选 `ops::attention`composed`ops::flash_attention`。recompute 闭包捕获 `flash``Copy`)→ **重算段内也走
flash**flash×recompute 组合天然成立。默认 `false` = composed 路径**逐字节不变**(硬闸门:默认图不变 → 不回归)。
## 验证方法
**硬闸门全绿dash5 实跑 capture**
### 1. 正确性
- **新 kernel dQ/dK/dV finite-diff grad-check**`xtrain-autodiff/tests/autograd.rs::flash_attention_batched_bwd`
与既有 `attention_batched_bwd` 同构(`L = sum(W∘out)`,中心差分),断 dQ/dK/dV 在 `cfg_nonlinear`/`cfg_linear` 容差内。
- **flash == composed**`xtrain-model/tests/flash.rs`):同 init 两个模型flash on/off同一 batched
loss + backward断**前向 logits / loss / 每参数梯度**在紧容差内一致;参数化 fp32近逐位与 bf16bf16 舍入级)。
- **PyTorch SDPA 对拍 B>1**`parity_dump.rs` + `parity.py`):等价 PyTorch 模型per-seq RoPE、per-seq causal、
QK-norm、SwiGLU对拍 forward logits + 全部参数梯度——**composed 与 flash 两条都跑**,共用同一 PyTorch oracle。
- **全回归套开/关 `--flash`**autograd 15、structural、batched==looped、bf16、recompute逐位、overfit 27/27、
AdamWGPU bit-exact + host 对 torch、DDP loss-match + 跨 rank、**xserv 闭环(导出 safetensors → md5 对 registry →
xserv 贪心逐 token 一致)**。flag off 默认图不变 → composed 数值不回归。
### 2. 显存payoff—— 不物化 N×N 的直接收益
dash5 1× RTX 5090同 confignvidia-smi 峰值flash off vs onattention 反向缓存 `[bh,S,S]→[bh,S]`
峰值显存应↓(尤其 seq 大时。capture 实际数字进表。
### 3. 吞吐
同 config steady-state tok/s flash off vs on。预期本规模 `hd=32` 下 flash kernel **持平或略慢于** cuBLAS 双
GEMM小头维喂不满 tensor-core 是 flash 的已知权衡,胜场在显存)——诚实报告,不为绿而调。
## 实测结果dash5 1× RTX 5090
**正确性(硬闸门全绿):**
| 闸门 | 结果 |
|---|---|
| ① 新 kernel dQ/dK/dV finite-diff grad-check | **过** — dQ 9.3e-3 / dK 1.7e-2 / dV 5.6e-4单 tile 干净区;多 tile 由②兜) |
| flash fwd 对 composed | max rel **6.7e-5** |
| flash bwd 对(已 grad-check 的composed bwd | dQ **1.7e-5** / dK 1.2e-5 / dV 4.3e-5 |
| ② flash==composedmodel 级logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 1.7e-4、grad 4.4e-5bf16: loss 1.5e-4、logits mean 1.6e-3/p99 5.9e-3、grad scaled-mean 1.2e-2 |
| ③ PyTorch SDPA 对拍 B>1flash 路径,共用 composed oracle | loss relerr **4.98e-8**、logits **7.92e-6**、25 参数 grad 全进 rtol 0.02 |
| ⑤ 回归套flag off 默认 + flash 路径都测autograd 18 / structural 5 / batched / bf16 / **flash 3** / overfit 27/27 / recompute 2 / AdamW(GPU+host) / GEMM / DDP 2 / checkpoint-roundtrip | **全绿** |
| ⑤ xserv 闭环 md5v3 ckpt 用 T14 代码重导 safetensors | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry 同)→ 默认 export 零漂移 |
| ⑤ xserv 闭环flash 训练 → 导出 → xserv 服务贪心) | flash-训出 coherent TinyStoriesxserv(BF16) 对 xtrain(F32) 贪心3 prompt 中 "One day" 逐 token 一致,其余在 ~0.5% BF16 漂移处晚分叉(与 v1/v2/v3 同款) |
> **finite-diff 的诚实记录**:长 softmaxseq>tile会产生大量近零梯度元素中心差分在那些元素上不可靠出现伪 0.0 / 符号翻转——不是 backward bug。故 ① 的 finite-diff 跑**单 tile 干净区**seq=5对齐既有 composed grad-check 的良态区),**多 tile 的 streaming/online 路径**用「flash bwd 对已 grad-check 的 composed bwd」seq=40dQ 1.7e-5兜——比 finite-diff 更利。dQ/dK 用 eps=2e-3 压低 f32 舍入项(~4e-4 小梯度上舍入项压过截断项)。**没有为凑绿放宽容差**。
**④ 显存 + 吞吐payoff vs tradeoffdim768=8L/12h×64/ffn3072, bf16, steady-state**
| config | path | 峰值显存 | tok/s |
|---|---|---|---|
| batch8 seq1024 | composed (off) | 24670 MiB | **58.6K** |
| batch8 seq1024 | **flash (on)** | **20736 MiB16%** | 25.0K57%, ~2.3× 慢) |
| batch2 seq2048 | composed (off) | 17264 MiB | 36.7K |
| batch2 seq2048 | **flash (on)** | **13246 MiB23%** | 13.2K64% |
**显存按预期降**(不物化 `[bh,S,S]`),且**收益随 seq 增长**seq1024 16% → seq2048 23%O(S²) 砍掉)。
**tok/s 如设计 ① 预测的「持平或略慢」实为 ~2.32.8× 慢**hd=64 的小头维下,手写「一行一 block + 串行扫 KV」kernel 喂不满 SM干不过 cuBLAS tensor-core 的两发批量 GEMM——这正是 flash 的已知权衡(**胜场在显存,不是小模型 wall-clock**诚实报告不掩饰。两个落地的优化softmax 权重缓存进 shared 省 hd× 的 expfdK/dV 原子加摊到全 block 而非串行在列 owner 内)把 backward 从 6.8× 慢拉到 2.3× 慢——主瓶颈是 backward 的跨行原子累加FA2 用 K-block 拥有 dK/dV 的独立 pass 解,本版未做,留 follow-up
> **escape hatchfollow-up未做记给后续**:① FA2 式 query-tile 划分(一 block 多 query 行K/V 进 shared 复用)提 SM 占用;② backward 的 dK/dV 改 K-block-owned 独立 pass 消跨行原子;③ 纯 bf16 in-kernel省两次 cast。本规模 attention 非训练瓶颈、且会动数值贴合闸门,按 escape hatch 推迟——T14 先把**正确性 + 不物化 N×N + 显存↓**钉死。