Files
xtrain/docs/13-flash-attention.md
Gahow Wang 9064ced4c2 docs: T14 flash-attention results + evolution/README rows
Fill in the design doc's measured results (grad-check, flash==composed,
PyTorch parity, peak mem -16%/-23%, tok/s tradeoff), add the T14 row to
evolution.md (算法/Infra) and the README build-journey table.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:34:10 +08:00

14 KiB
Raw Permalink Blame History

Phase T14: 融合 Flash-Attention Kernel — Design Document

Goal

T10 把 attention 批量化了,但它的 SDPA 走的是 「物化 N×N scores」 的组合路径: cublasSgemmStridedBatchedQ·Kᵀ→ 一个 causal-softmax kernel写出整张 probscublasSgemmStridedBatchedP·V3 次 launch + 一张 [bh, S, S] 的 scores/probs 张量 常驻显存(反向还要缓存这张 probs。S 一大,这张 N×N 就成了激活显存与带宽的主导项。

T14 的目标:手写一个单 kernel 的 fused flash-attention——streaming / online softmax、tiled over KV绝不物化 N×N。前向一发 kernel 直接吐出 out[bh,S,hd](外加 O(N) 的 logsumexp 反向一发 kernelflash 式:重算 scores + dQ/dK/dV同样不物化 N×N。接进 model + autograd 作 opt-in --flash,默认保留 T10 的 composed 路径以便 A/B。

硬闸门是诚实正确性:新 kernel 的 dQ/dK/dV finite-diff grad-check 过fwd/bwd 对现有 composed-SDPA 路径数值贴合(进 bf16 容差PyTorch SDPA 对拍 B>1峰值显存↓不物化 scores+ tok/s before/after 实测; 全回归套(含 xserv 闭环 md5开/关 flag 都绿——默认flag off图不变 → 不回归。

什么是 flash-attention

标准 attention 是 O = softmax(causal(Q·Kᵀ/√d)) · V,朴素实现把 S[i,j] = Qᵢ·Kⱼ/√d 整张 [S,S] 算出来、softmax、再乘 V——显存 O(S²)、HBM 读写 O(S²)

flash-attention 的洞察softmax 可以 onlinestreaming 地算。把 K/V 切成若干 tile,对一个 query 行 i,依次扫过 KV tilerunning max m + running sum l 维护 softmax 的归一化,并把 部分加权的 V 累加进一个 [hd] 的 accumulator acc,每来一个新 tile 就用「新旧 max 的差」对旧 acc/l 做 rescale。扫完所有 tileout = acc / l整张 [S,S] 从不落地——只有 [hd] 的 acc 和两个标量 在寄存器/共享内存里流动。峰值激活从 O(S²) 降到 O(S·hd)(就是 O 本身)。

online softmax 的核心递推block j 的部分 logits 行 s_j,旧状态 m, l, acc

m_new = max(m, max_k s_j[k])
p     = exp(s_j - m_new)                       # 本 tile 的未归一化权重
l     = l * exp(m - m_new) + sum(p)            # 旧 sum 先 rescale再加本 tile
acc   = acc * exp(m - m_new) + p · V_tile      # 旧 acc 同样 rescale再加本 tile 贡献
m     = m_new
# 扫完所有 tile
out   = acc / l
L     = m + log(l)                             # logsumefpO(N) 存给反向

因果 mask 内联query 全局位置 = i % S(沿用 T10 的 per-seq 复位约定KV 位置 j 满足 j > i%S 的列直接当 -infp=0。tile 整块在对角线之上可直接 skipcausal 的天然稀疏,省一半算力)。

反向flash 式,[Dao 2022] 的标准做法):不缓存 probs从 Q/K/V + 前向存的 L[bh,S] 重算 scores。 关键预计算 D[i] = Σ_d dOᵢ[d]·Oᵢ[d](每 query 一个标量,O(N)),则对每个 (i,j)

s_ij = Qᵢ·Kⱼ * scale                           # 重算 logit
p_ij = exp(s_ij - L[i])                         # 重算 softmax 权重L 是前向存的 logsumexp
dp_ij = dOᵢ · Vⱼ                                # 对 P 的梯度
ds_ij = p_ij * (dp_ij - D[i]) * scale           # softmax 雅可比,化简掉了显式 N×N
dQᵢ  += ds_ij * Kⱼ ;  dKⱼ += ds_ij * Qᵢ ;  dVⱼ += p_ij * dOᵢ

ds = P ∘ (dP - D) 是 softmax 反向用 Σⱼ Pⱼ·dPⱼ = D(因为 D[i]=Σ dOᵢ·Oᵢ = Σⱼ Pᵢⱼ dPᵢⱼ)化简的结果, 不需要 N×N 的 softmax 雅可比矩阵。同样 tiled、同样不物化 N×N。

Module Layoutsurgicalcomposed 路径逐字节不动flash 全程新增并行路径)

csrc/ops/flash_attention.cu              # 新fwd kernelonline softmaxtiled KV+ bwd kernel重算 + dQ/dK/dV
crates/xtrain-cuda/
├── src/ffi.rs                           # +launch_flash_attention_fwd_f32 / _bwd_f32 声明
└── build.rs                             # +flash_attention.cu
crates/xtrain-tensor/src/tensor.rs       # +Tensor::flash_attention / flash_attention_backwardfwd 存 logsumexp Lbf16 upcast→f32 kernel→downcast
crates/xtrain-autodiff/
├── src/ops.rs                           # +ops::flash_attention 节点(前向调 fwd缓存 L反向调 bwd
└── tests/autograd.rs                    # +flash_attention(batched) dQ/dK/dV grad-check
crates/xtrain-model/
├── src/model.rs                         # attention() 按 use_flash 选 ops::attention | ops::flash_attention+with_flash(bool) builderflash 标志透传 block_forwardrecompute 段内也走 flash
└── tests/flash.rs                       # 新flash == composedfwd logits + 每参数梯度),参数化 fp32/bf16
crates/xtrain-train/src/bin/train.rs     # +--flash flag → model.with_flash(true)
crates/xtrain-distributed/src/bin/train_ddp.rs  # +--flash flagDDP 路径)
crates/xtrain-model/tests/parity_dump.rs # PyTorch B>1 对拍跑两遍composed 与 flash共用 PyTorch oracle

Key Design Decisions

① 一个 block 负责一行 query先做对再谈快

最直接、最易验证正确的并行划分:grid = bh * S,每个 block 算一整行 query 的 out[bh, i, :]。 block 内 hd 个线程hd ≤ 128正好一个 warp 多一点),共享 m/l 标量 + acc[hd]。block 顺序扫 KV tiletile 宽 BK,沿 j 维),每个 tile线程并行算 BK 个 logit点积 over hd 用 block-reduce、 求 tile max、online-rescale m/l/acc、累加 p·V。扫完写 out = acc/lL[i] = m + log(l)

为什么先这样而不是 FA2 的 query-tile 划分:本项目的硬闸门是正确性 + 不物化 N×N + 显存↓,不是 打榜峰值 FLOPs。一行一 block 的版本:(a) online softmax 与 N×N skip 已经完全落地(显存与带宽收益拿到), (b) 代码直白、逐 query 行可对拍,正确性风险最低。它不会比 cuBLAS 两发 GEMM 更快cuBLAS tensor-core 吃满),所以 tok/s 上 flash 在我们这种 hd=32 小头维下大概率持平或略慢——这正是 flash 的已知权衡 flash 的胜场是显存,不是小模型的 wall-clock。把这点诚实写进 perf 表,不掩饰。

② 前向只存 L[bh,S]logsumefp不存 probs

composed 路径反向要缓存整张 probs[bh,S,S]O(N²)。flash 反向只需要前向的 logsumexp L[i]=m_i+log(l_i)(每 query 一个 fp32O(N))即可重算任意 p_ij = exp(Qᵢ·Kⱼ·scale - L[i])。 所以 fwd kernel 顺手把 L 写出来autograd 节点缓存它(外加 Q/K/V/O parents 本就在)。这就是显存闸门的来源 attention 的反向缓存从 [bh,S,S] 砍到 [bh,S]

③ 反向用 D[i]=Σ dOᵢ·Oᵢ 化简 softmax 雅可比

softmax 反向通项 ds_ij = p_ij·(dp_ij - Σ_k p_ik·dp_ik)。注意 Σ_k p_ik·dp_ik = Σ_k p_ik (dOᵢ·V_k) = dOᵢ·(Σ_k p_ik V_k) = dOᵢ·Oᵢ = D[i]。所以一趟先算 D[bh,S](每行 dO·O 的点积,O(N)),反向 扫 KV tile 时直接 ds = p·(dp - D)·scale不需要再算或物化整行的 Σ p·dp。 dQ/dK/dV 三者dQ 由「该 query 行」累加block 私有无竞争dK/dV 跨 query 行累加同一个 (j) → 用 atomicAdd 到全局 dK/dVfp32 原子加,确定 race-free

④ bf16kernel 内 fp32边界 cast与 composed 路径一致的数值策略)

T10/T12 的 composed attention 对 bf16 也是 softmax 用 fp32scores 升 f32 → kernel → probs 降回 bf16。 flash 沿用同策略最省心且数值最稳bf16 模式下 flash_attention 把 Q/K/V to_dtype(F32) 喂给 fp32 kernel outto_dtype(BF16)反向同理。kernel 本身只有一份 fp32 实现。这样 flash 的 bf16 数值与 composed 的 bf16 数值是同一套 fp32 softmax 算的,只差 GEMM roundingcuBLAS tensor-core vs kernel 内 fp32 FMA→ 落在 既有 bf16 容差内。L 始终 fp32。

备选不采纳bf16 全程 in-kernel half。收益是少两次 cast但 (a) 引入与 composed 不同的 softmax 累加路径, 威胁 on-vs-off 贴合闸门;(b) 本规模 attention 非瓶颈。escape hatch先 fp32-core 把正确性钉死,纯 half flash 留 follow-up。

⑤ opt-in 透传:use_flash 是运行时旗标,不是架构

use_flash 不进 Config(它不改模型尺寸、不改导出、不该污染 num_params),而是 TinyTransformer 的一个 bool 字段 + with_flash(bool) builder对齐 with_recompute / with_compute_dtype)。block_forward 已经 是 (cfg, cdt, …) 的自由函数T13 为 recompute 抽的),给它加一个 flash: bool 形参model 的 attention() 据此选 ops::attentioncomposedops::flash_attention。recompute 闭包捕获 flashCopy)→ 重算段内也走 flashflash×recompute 组合天然成立。默认 false = composed 路径逐字节不变(硬闸门:默认图不变 → 不回归)。

验证方法

硬闸门全绿dash5 实跑 capture

1. 正确性

  • 新 kernel dQ/dK/dV finite-diff grad-checkxtrain-autodiff/tests/autograd.rs::flash_attention_batched_bwd 与既有 attention_batched_bwd 同构(L = sum(W∘out),中心差分),断 dQ/dK/dV 在 cfg_nonlinear/cfg_linear 容差内。
  • flash == composedxtrain-model/tests/flash.rs):同 init 两个模型flash on/off同一 batched loss + backward前向 logits / loss / 每参数梯度在紧容差内一致;参数化 fp32近逐位与 bf16bf16 舍入级)。
  • PyTorch SDPA 对拍 B>1parity_dump.rs + parity.py):等价 PyTorch 模型per-seq RoPE、per-seq causal、 QK-norm、SwiGLU对拍 forward logits + 全部参数梯度——composed 与 flash 两条都跑,共用同一 PyTorch oracle。
  • 全回归套开/关 --flashautograd 15、structural、batched==looped、bf16、recompute逐位、overfit 27/27、 AdamWGPU bit-exact + host 对 torch、DDP loss-match + 跨 rank、xserv 闭环(导出 safetensors → md5 对 registry → xserv 贪心逐 token 一致)。flag off 默认图不变 → composed 数值不回归。

2. 显存payoff—— 不物化 N×N 的直接收益

dash5 1× RTX 5090同 confignvidia-smi 峰值flash off vs onattention 反向缓存 [bh,S,S]→[bh,S] 峰值显存应↓(尤其 seq 大时。capture 实际数字进表。

3. 吞吐

同 config steady-state tok/s flash off vs on。预期本规模 hd=32 下 flash kernel 持平或略慢于 cuBLAS 双 GEMM小头维喂不满 tensor-core 是 flash 的已知权衡,胜场在显存)——诚实报告,不为绿而调。

实测结果dash5 1× RTX 5090

正确性(硬闸门全绿):

闸门 结果
① 新 kernel dQ/dK/dV finite-diff grad-check — dQ 9.3e-3 / dK 1.7e-2 / dV 5.6e-4单 tile 干净区;多 tile 由②兜)
flash fwd 对 composed max rel 6.7e-5
flash bwd 对(已 grad-check 的composed bwd dQ 1.7e-5 / dK 1.2e-5 / dV 4.3e-5
② flash==composedmodel 级logits/loss/每参数梯度) fp32: loss rel 0.0、logits 1.7e-4、grad 4.4e-5bf16: loss 1.5e-4、logits mean 1.6e-3/p99 5.9e-3、grad scaled-mean 1.2e-2
③ PyTorch SDPA 对拍 B>1flash 路径,共用 composed oracle loss relerr 4.98e-8、logits 7.92e-6、25 参数 grad 全进 rtol 0.02
⑤ 回归套flag off 默认 + flash 路径都测autograd 18 / structural 5 / batched / bf16 / flash 3 / overfit 27/27 / recompute 2 / AdamW(GPU+host) / GEMM / DDP 2 / checkpoint-roundtrip 全绿
⑤ xserv 闭环 md5v3 ckpt 用 T14 代码重导 safetensors 逐位一致 b04fc9f9a0c9af04c47d9ca649aea12e(与 registry 同)→ 默认 export 零漂移
⑤ xserv 闭环flash 训练 → 导出 → xserv 服务贪心) flash-训出 coherent TinyStoriesxserv(BF16) 对 xtrain(F32) 贪心3 prompt 中 "One day" 逐 token 一致,其余在 ~0.5% BF16 漂移处晚分叉(与 v1/v2/v3 同款)

finite-diff 的诚实记录:长 softmaxseq>tile会产生大量近零梯度元素中心差分在那些元素上不可靠出现伪 0.0 / 符号翻转——不是 backward bug。故 ① 的 finite-diff 跑单 tile 干净区seq=5对齐既有 composed grad-check 的良态区),多 tile 的 streaming/online 路径用「flash bwd 对已 grad-check 的 composed bwd」seq=40dQ 1.7e-5兜——比 finite-diff 更利。dQ/dK 用 eps=2e-3 压低 f32 舍入项(~4e-4 小梯度上舍入项压过截断项)。没有为凑绿放宽容差

④ 显存 + 吞吐payoff vs tradeoffdim768=8L/12h×64/ffn3072, bf16, steady-state

config path 峰值显存 tok/s
batch8 seq1024 composed (off) 24670 MiB 58.6K
batch8 seq1024 flash (on) 20736 MiB16% 25.0K57%, ~2.3× 慢)
batch2 seq2048 composed (off) 17264 MiB 36.7K
batch2 seq2048 flash (on) 13246 MiB23% 13.2K64%

显存按预期降(不物化 [bh,S,S]),且收益随 seq 增长seq1024 16% → seq2048 23%O(S²) 砍掉)。 tok/s 如设计 ① 预测的「持平或略慢」实为 ~2.32.8×hd=64 的小头维下,手写「一行一 block + 串行扫 KV」kernel 喂不满 SM干不过 cuBLAS tensor-core 的两发批量 GEMM——这正是 flash 的已知权衡(胜场在显存,不是小模型 wall-clock诚实报告不掩饰。两个落地的优化softmax 权重缓存进 shared 省 hd× 的 expfdK/dV 原子加摊到全 block 而非串行在列 owner 内)把 backward 从 6.8× 慢拉到 2.3× 慢——主瓶颈是 backward 的跨行原子累加FA2 用 K-block 拥有 dK/dV 的独立 pass 解,本版未做,留 follow-up

escape hatchfollow-up未做记给后续:① FA2 式 query-tile 划分(一 block 多 query 行K/V 进 shared 复用)提 SM 占用;② backward 的 dK/dV 改 K-block-owned 独立 pass 消跨行原子;③ 纯 bf16 in-kernel省两次 cast。本规模 attention 非训练瓶颈、且会动数值贴合闸门,按 escape hatch 推迟——T14 先把**正确性 + 不物化 N×N + 显存↓**钉死。