docs: Phase T14 — fused flash-attention design
Design doc for the hand-written single fused flash-attention kernel: online softmax tiled over KV, NEVER materializing the [bh,S,S] score matrix; flash-style backward (recompute scores from saved logsumexp + D=ΣdO·O, dQ/dK/dV). Opt-in --flash; composed T10 path stays default. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
156
docs/13-flash-attention.md
Normal file
156
docs/13-flash-attention.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# Phase T14: 融合 Flash-Attention Kernel — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
T10 把 attention 批量化了,但它的 SDPA 走的是 **「物化 N×N scores」** 的组合路径:
|
||||
`cublasSgemmStridedBatched`(Q·Kᵀ)→ 一个 causal-softmax kernel(写出整张 probs)→
|
||||
`cublasSgemmStridedBatched`(P·V),**3 次 launch + 一张 `[bh, S, S]` 的 scores/probs 张量**
|
||||
常驻显存(反向还要缓存这张 probs)。S 一大,这张 N×N 就成了激活显存与带宽的主导项。
|
||||
|
||||
T14 的目标:手写一个**单 kernel 的 fused flash-attention**——streaming / online softmax、**tiled
|
||||
over KV**、**绝不物化 N×N**。前向一发 kernel 直接吐出 `out[bh,S,hd]`(外加 `O(N)` 的 logsumexp);
|
||||
反向一发 kernel(flash 式:重算 scores + dQ/dK/dV,同样不物化 N×N)。接进 model + autograd 作
|
||||
**opt-in `--flash`**,默认保留 T10 的 composed 路径以便 A/B。
|
||||
|
||||
**硬闸门是诚实正确性**:新 kernel 的 dQ/dK/dV finite-diff grad-check 过;fwd/bwd 对现有 composed-SDPA
|
||||
路径数值贴合(进 bf16 容差);PyTorch SDPA 对拍 B>1;峰值显存↓(不物化 scores)+ tok/s before/after 实测;
|
||||
全回归套(含 xserv 闭环 md5)开/关 flag 都绿——默认(flag off)图不变 → 不回归。
|
||||
|
||||
## 什么是 flash-attention
|
||||
|
||||
标准 attention 是 `O = softmax(causal(Q·Kᵀ/√d)) · V`,朴素实现把 `S[i,j] = Qᵢ·Kⱼ/√d` 整张
|
||||
`[S,S]` 算出来、softmax、再乘 V——显存 `O(S²)`、HBM 读写 `O(S²)`。
|
||||
|
||||
**flash-attention** 的洞察:softmax 可以 **online(streaming)** 地算。把 K/V 切成若干 **tile**,对一个
|
||||
query 行 `i`,依次扫过 KV tile,用 **running max `m` + running sum `l`** 维护 softmax 的归一化,并把
|
||||
部分加权的 `V` 累加进一个 `[hd]` 的 accumulator `acc`,每来一个新 tile 就用「新旧 max 的差」对旧 `acc`/`l`
|
||||
做 rescale。扫完所有 tile,`out = acc / l`。**整张 `[S,S]` 从不落地**——只有 `[hd]` 的 acc 和两个标量
|
||||
在寄存器/共享内存里流动。峰值激活从 `O(S²)` 降到 `O(S·hd)`(就是 O 本身)。
|
||||
|
||||
online softmax 的核心递推(block `j` 的部分 logits 行 `s_j`,旧状态 `m, l, acc`):
|
||||
|
||||
```text
|
||||
m_new = max(m, max_k s_j[k])
|
||||
p = exp(s_j - m_new) # 本 tile 的未归一化权重
|
||||
l = l * exp(m - m_new) + sum(p) # 旧 sum 先 rescale,再加本 tile
|
||||
acc = acc * exp(m - m_new) + p · V_tile # 旧 acc 同样 rescale,再加本 tile 贡献
|
||||
m = m_new
|
||||
# 扫完所有 tile:
|
||||
out = acc / l
|
||||
L = m + log(l) # logsumefp,O(N) 存给反向
|
||||
```
|
||||
|
||||
**因果 mask 内联**:query 全局位置 = `i % S`(沿用 T10 的 per-seq 复位约定),KV 位置 `j` 满足
|
||||
`j > i%S` 的列直接当 `-inf`(`p=0`)。tile 整块在对角线之上可**直接 skip**(causal 的天然稀疏,省一半算力)。
|
||||
|
||||
**反向(flash 式,[Dao 2022] 的标准做法)**:不缓存 probs,从 Q/K/V + 前向存的 `L[bh,S]` **重算** scores。
|
||||
关键预计算 `D[i] = Σ_d dOᵢ[d]·Oᵢ[d]`(每 query 一个标量,`O(N)`),则对每个 `(i,j)`:
|
||||
|
||||
```text
|
||||
s_ij = Qᵢ·Kⱼ * scale # 重算 logit
|
||||
p_ij = exp(s_ij - L[i]) # 重算 softmax 权重(L 是前向存的 logsumexp)
|
||||
dp_ij = dOᵢ · Vⱼ # 对 P 的梯度
|
||||
ds_ij = p_ij * (dp_ij - D[i]) * scale # softmax 雅可比,化简掉了显式 N×N
|
||||
dQᵢ += ds_ij * Kⱼ ; dKⱼ += ds_ij * Qᵢ ; dVⱼ += p_ij * dOᵢ
|
||||
```
|
||||
|
||||
`ds = P ∘ (dP - D)` 是 softmax 反向用 `Σⱼ Pⱼ·dPⱼ = D`(因为 `D[i]=Σ dOᵢ·Oᵢ = Σⱼ Pᵢⱼ dPᵢⱼ`)化简的结果,
|
||||
**不需要 N×N 的 softmax 雅可比矩阵**。同样 tiled、同样不物化 N×N。
|
||||
|
||||
## Module Layout(surgical:composed 路径逐字节不动,flash 全程新增并行路径)
|
||||
|
||||
```
|
||||
csrc/ops/flash_attention.cu # 新:fwd kernel(online softmax,tiled KV)+ bwd kernel(重算 + dQ/dK/dV)
|
||||
crates/xtrain-cuda/
|
||||
├── src/ffi.rs # +launch_flash_attention_fwd_f32 / _bwd_f32 声明
|
||||
└── build.rs # +flash_attention.cu
|
||||
crates/xtrain-tensor/src/tensor.rs # +Tensor::flash_attention / flash_attention_backward(fwd 存 logsumexp L;bf16 upcast→f32 kernel→downcast)
|
||||
crates/xtrain-autodiff/
|
||||
├── src/ops.rs # +ops::flash_attention 节点(前向调 fwd,缓存 L,反向调 bwd)
|
||||
└── tests/autograd.rs # +flash_attention(batched) dQ/dK/dV grad-check
|
||||
crates/xtrain-model/
|
||||
├── src/model.rs # attention() 按 use_flash 选 ops::attention | ops::flash_attention;+with_flash(bool) builder;flash 标志透传 block_forward(recompute 段内也走 flash)
|
||||
└── tests/flash.rs # 新:flash == composed(fwd logits + 每参数梯度),参数化 fp32/bf16
|
||||
crates/xtrain-train/src/bin/train.rs # +--flash flag → model.with_flash(true)
|
||||
crates/xtrain-distributed/src/bin/train_ddp.rs # +--flash flag(DDP 路径)
|
||||
crates/xtrain-model/tests/parity_dump.rs # PyTorch B>1 对拍跑两遍:composed 与 flash(共用 PyTorch oracle)
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### ① 一个 block 负责一行 query(先做对,再谈快)
|
||||
|
||||
最直接、最易验证正确的并行划分:**`grid = bh * S`,每个 block 算一整行 query 的 `out[bh, i, :]`**。
|
||||
block 内 `hd` 个线程(hd ≤ 128,正好一个 warp 多一点),共享 `m/l` 标量 + `acc[hd]`。block 顺序扫
|
||||
KV tile(tile 宽 `BK`,沿 `j` 维),每个 tile:线程并行算 `BK` 个 logit(点积 over hd 用 block-reduce)、
|
||||
求 tile max、online-rescale `m/l/acc`、累加 `p·V`。扫完写 `out = acc/l` 与 `L[i] = m + log(l)`。
|
||||
|
||||
**为什么先这样而不是 FA2 的 query-tile 划分**:本项目的硬闸门是**正确性 + 不物化 N×N + 显存↓**,不是
|
||||
打榜峰值 FLOPs。一行一 block 的版本:(a) online softmax 与 N×N skip 已经完全落地(显存与带宽收益拿到),
|
||||
(b) 代码直白、逐 query 行可对拍,正确性风险最低。它**不会**比 cuBLAS 两发 GEMM 更快(cuBLAS tensor-core
|
||||
吃满),所以 tok/s 上 flash 在我们这种 `hd=32` 小头维下大概率**持平或略慢**——这正是 flash 的已知权衡
|
||||
(flash 的胜场是**显存**,不是小模型的 wall-clock)。把这点诚实写进 perf 表,不掩饰。
|
||||
|
||||
### ② 前向只存 `L[bh,S]`(logsumefp),不存 probs
|
||||
|
||||
composed 路径反向要缓存整张 `probs[bh,S,S]`(`O(N²)`)。flash 反向**只需要前向的 logsumexp
|
||||
`L[i]=m_i+log(l_i)`**(每 query 一个 fp32,`O(N)`)即可重算任意 `p_ij = exp(Qᵢ·Kⱼ·scale - L[i])`。
|
||||
所以 fwd kernel 顺手把 `L` 写出来,autograd 节点缓存它(外加 Q/K/V/O parents 本就在)。**这就是显存闸门的来源**:
|
||||
attention 的反向缓存从 `[bh,S,S]` 砍到 `[bh,S]`。
|
||||
|
||||
### ③ 反向用 `D[i]=Σ dOᵢ·Oᵢ` 化简 softmax 雅可比
|
||||
|
||||
softmax 反向通项 `ds_ij = p_ij·(dp_ij - Σ_k p_ik·dp_ik)`。注意 `Σ_k p_ik·dp_ik = Σ_k p_ik (dOᵢ·V_k)
|
||||
= dOᵢ·(Σ_k p_ik V_k) = dOᵢ·Oᵢ = D[i]`。所以一趟先算 `D[bh,S]`(每行 `dO·O` 的点积,`O(N)`),反向
|
||||
扫 KV tile 时直接 `ds = p·(dp - D)·scale`,**不需要再算或物化整行的 `Σ p·dp`**。
|
||||
dQ/dK/dV 三者:dQ 由「该 query 行」累加(block 私有,无竞争);dK/dV 跨 query 行累加同一个 `(j)`
|
||||
→ 用 `atomicAdd` 到全局 dK/dV(fp32 原子加,确定 race-free)。
|
||||
|
||||
### ④ bf16:kernel 内 fp32,边界 cast(与 composed 路径一致的数值策略)
|
||||
|
||||
T10/T12 的 composed attention 对 bf16 也是 **softmax 用 fp32**(scores 升 f32 → kernel → probs 降回 bf16)。
|
||||
flash 沿用同策略,最省心且数值最稳:bf16 模式下 `flash_attention` 把 Q/K/V `to_dtype(F32)` 喂给 fp32 kernel,
|
||||
`out` 再 `to_dtype(BF16)`;反向同理。kernel 本身只有一份 fp32 实现。这样 flash 的 bf16 数值与 composed 的
|
||||
bf16 数值是**同一套 fp32 softmax 算的**,只差 GEMM rounding(cuBLAS tensor-core vs kernel 内 fp32 FMA)→ 落在
|
||||
既有 bf16 容差内。`L` 始终 fp32。
|
||||
|
||||
> 备选(不采纳):bf16 全程 in-kernel half。收益是少两次 cast,但 (a) 引入与 composed 不同的 softmax 累加路径,
|
||||
> 威胁 on-vs-off 贴合闸门;(b) 本规模 attention 非瓶颈。escape hatch:先 fp32-core 把正确性钉死,纯 half flash 留 follow-up。
|
||||
|
||||
### ⑤ opt-in 透传:`use_flash` 是运行时旗标,不是架构
|
||||
|
||||
`use_flash` 不进 `Config`(它不改模型尺寸、不改导出、不该污染 `num_params`),而是 `TinyTransformer` 的一个
|
||||
`bool` 字段 + `with_flash(bool)` builder(对齐 `with_recompute` / `with_compute_dtype`)。`block_forward` 已经
|
||||
是 `(cfg, cdt, …)` 的自由函数(T13 为 recompute 抽的),给它加一个 `flash: bool` 形参,model 的 `attention()`
|
||||
据此选 `ops::attention`(composed)或 `ops::flash_attention`。recompute 闭包捕获 `flash`(`Copy`)→ **重算段内也走
|
||||
flash**,flash×recompute 组合天然成立。默认 `false` = composed 路径**逐字节不变**(硬闸门:默认图不变 → 不回归)。
|
||||
|
||||
## 验证方法
|
||||
|
||||
**硬闸门全绿(dash5 实跑 capture):**
|
||||
|
||||
### 1. 正确性
|
||||
|
||||
- **新 kernel dQ/dK/dV finite-diff grad-check**(`xtrain-autodiff/tests/autograd.rs::flash_attention_batched_bwd`):
|
||||
与既有 `attention_batched_bwd` 同构(`L = sum(W∘out)`,中心差分),断 dQ/dK/dV 在 `cfg_nonlinear`/`cfg_linear` 容差内。
|
||||
- **flash == composed**(`xtrain-model/tests/flash.rs`):同 init 两个模型(flash on/off),同一 batched
|
||||
loss + backward,断**前向 logits / loss / 每参数梯度**在紧容差内一致;参数化 fp32(近逐位)与 bf16(bf16 舍入级)。
|
||||
- **PyTorch SDPA 对拍 B>1**(`parity_dump.rs` + `parity.py`):等价 PyTorch 模型(per-seq RoPE、per-seq causal、
|
||||
QK-norm、SwiGLU)对拍 forward logits + 全部参数梯度——**composed 与 flash 两条都跑**,共用同一 PyTorch oracle。
|
||||
- **全回归套开/关 `--flash`**:autograd 15、structural、batched==looped、bf16、recompute(逐位)、overfit 27/27、
|
||||
AdamW(GPU bit-exact + host 对 torch)、DDP loss-match + 跨 rank、**xserv 闭环(导出 safetensors → md5 对 registry →
|
||||
xserv 贪心逐 token 一致)**。flag off 默认图不变 → composed 数值不回归。
|
||||
|
||||
### 2. 显存(payoff)—— 不物化 N×N 的直接收益
|
||||
|
||||
dash5 1× RTX 5090,同 config,nvidia-smi 峰值,flash off vs on:attention 反向缓存 `[bh,S,S]→[bh,S]`,
|
||||
峰值显存应↓(尤其 seq 大时)。capture 实际数字进表。
|
||||
|
||||
### 3. 吞吐
|
||||
|
||||
同 config steady-state tok/s flash off vs on。预期:本规模 `hd=32` 下 flash kernel **持平或略慢于** cuBLAS 双
|
||||
GEMM(小头维喂不满 tensor-core 是 flash 的已知权衡,胜场在显存)——诚实报告,不为绿而调。
|
||||
|
||||
## 实测结果(dash5,待 capture)
|
||||
|
||||
<!-- dash5 实跑后回填:grad-check 数字、flash-vs-composed rel-err、PyTorch 对拍、显存 before/after、tok/s before/after -->
|
||||
Reference in New Issue
Block a user