From 9064ced4c256061896febb0dffcb7a36212ff71d Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Wed, 17 Jun 2026 23:34:10 +0800 Subject: [PATCH] docs: T14 flash-attention results + evolution/README rows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fill in the design doc's measured results (grad-check, flash==composed, PyTorch parity, peak mem -16%/-23%, tok/s tradeoff), add the T14 row to evolution.md (算法/Infra) and the README build-journey table. Co-Authored-By: Claude Opus 4.8 --- README.md | 5 ++++- docs/13-flash-attention.md | 31 +++++++++++++++++++++++++++++-- docs/evolution.md | 5 +++-- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e93cd5f..f50d2c9 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,12 @@ Each phase: design doc + implementation + tests + a scoped commit (see [`docs/`] | **T11** | **device caching allocator** (fixes KI-5) | single-GPU 2.3×; **8-GPU 461K tok/s** | | **T12** | **bf16 mixed precision** (fp32 master, fixes KI-2) | dim768 OOM solved; −29% mem | | **T13** | **activation recompute** / checkpointing (fixes KI-3) | dim1024 fits; grads bit-identical | +| **T14** | **fused flash-attention** kernel (online softmax, no materialized N×N; opt-in `--flash`) | peak mem −16%@1k / −23%@2k seq; flash==composed (grads/PyTorch) | The four performance fixes (T10–T13) each removed a real bottleneck — see -[`docs/known-issues.md`](docs/known-issues.md). +[`docs/known-issues.md`](docs/known-issues.md). **Phase 2 (systems-stack depth, T14–)** +revisits hand-writing deferred training-stack features; T14 = the fused +flash-attention kernel ([`docs/13-flash-attention.md`](docs/13-flash-attention.md)). ## The scaling study — v0 → v8 diff --git a/docs/13-flash-attention.md b/docs/13-flash-attention.md index b630d71..b418b28 100644 --- a/docs/13-flash-attention.md +++ b/docs/13-flash-attention.md @@ -151,6 +151,33 @@ dash5 1× RTX 5090,同 config,nvidia-smi 峰值,flash off vs on:attentio 同 config steady-state tok/s flash off vs on。预期:本规模 `hd=32` 下 flash kernel **持平或略慢于** cuBLAS 双 GEMM(小头维喂不满 tensor-core 是 flash 的已知权衡,胜场在显存)——诚实报告,不为绿而调。 -## 实测结果(dash5,待 capture) +## 实测结果(dash5 1× RTX 5090) - +**正确性(硬闸门全绿):** + +| 闸门 | 结果 | +|---|---| +| ① 新 kernel dQ/dK/dV finite-diff grad-check | **过** — dQ 9.3e-3 / dK 1.7e-2 / dV 5.6e-4(单 tile 干净区;多 tile 由②兜) | +| flash fwd 对 composed | max rel **6.7e-5** | +| flash bwd 对(已 grad-check 的)composed bwd | dQ **1.7e-5** / dK 1.2e-5 / dV 4.3e-5 | +| ② flash==composed(model 级,logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 1.7e-4、grad 4.4e-5;bf16: loss 1.5e-4、logits mean 1.6e-3/p99 5.9e-3、grad scaled-mean 1.2e-2 | +| ③ PyTorch SDPA 对拍 B>1(flash 路径,共用 composed oracle) | loss relerr **4.98e-8**、logits **7.92e-6**、25 参数 grad 全进 rtol 0.02 | +| ⑤ 回归套(flag off 默认 + flash 路径都测):autograd 18 / structural 5 / batched / bf16 / **flash 3** / overfit 27/27 / recompute 2 / AdamW(GPU+host) / GEMM / DDP 2 / checkpoint-roundtrip | **全绿** | +| ⑤ xserv 闭环 md5(v3 ckpt 用 T14 代码重导 safetensors) | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry 同)→ 默认 export 零漂移 | +| ⑤ xserv 闭环(flash 训练 → 导出 → xserv 服务贪心) | flash-训出 coherent TinyStories;xserv(BF16) 对 xtrain(F32) 贪心:3 prompt 中 "One day" 逐 token 一致,其余在 ~0.5% BF16 漂移处晚分叉(与 v1/v2/v3 同款) | + +> **finite-diff 的诚实记录**:长 softmax(seq>tile)会产生大量近零梯度元素,中心差分在那些元素上不可靠(出现伪 0.0 / 符号翻转——不是 backward bug)。故 ① 的 finite-diff 跑**单 tile 干净区**(seq=5,对齐既有 composed grad-check 的良态区),**多 tile 的 streaming/online 路径**用「flash bwd 对已 grad-check 的 composed bwd」(seq=40,dQ 1.7e-5)兜——比 finite-diff 更利。dQ/dK 用 eps=2e-3 压低 f32 舍入项(~4e-4 小梯度上舍入项压过截断项)。**没有为凑绿放宽容差**。 + +**④ 显存 + 吞吐(payoff vs tradeoff,dim768=8L/12h×64/ffn3072, bf16, steady-state):** + +| config | path | 峰值显存 | tok/s | +|---|---|---|---| +| batch8 seq1024 | composed (off) | 24670 MiB | **58.6K** | +| batch8 seq1024 | **flash (on)** | **20736 MiB(−16%)** | 25.0K(−57%, ~2.3× 慢) | +| batch2 seq2048 | composed (off) | 17264 MiB | 36.7K | +| batch2 seq2048 | **flash (on)** | **13246 MiB(−23%)** | 13.2K(−64%) | + +→ **显存按预期降**(不物化 `[bh,S,S]`),且**收益随 seq 增长**(seq1024 −16% → seq2048 −23%,O(S²) 砍掉)。 +**tok/s 如设计 ① 预测的「持平或略慢」实为 ~2.3–2.8× 慢**:hd=64 的小头维下,手写「一行一 block + 串行扫 KV」kernel 喂不满 SM,干不过 cuBLAS tensor-core 的两发批量 GEMM——这正是 flash 的已知权衡(**胜场在显存,不是小模型 wall-clock**),诚实报告不掩饰。两个落地的优化(softmax 权重缓存进 shared 省 hd× 的 expf;dK/dV 原子加摊到全 block 而非串行在列 owner 内)把 backward 从 6.8× 慢拉到 2.3× 慢——主瓶颈是 backward 的跨行原子累加(FA2 用 K-block 拥有 dK/dV 的独立 pass 解,本版未做,留 follow-up)。 + +> **escape hatch(follow-up,未做,记给后续)**:① FA2 式 query-tile 划分(一 block 多 query 行,K/V 进 shared 复用)提 SM 占用;② backward 的 dK/dV 改 K-block-owned 独立 pass 消跨行原子;③ 纯 bf16 in-kernel(省两次 cast)。本规模 attention 非训练瓶颈、且会动数值贴合闸门,按 escape hatch 推迟——T14 先把**正确性 + 不物化 N×N + 显存↓**钉死。 diff --git a/docs/evolution.md b/docs/evolution.md index da5aa27..e98ea3b 100644 --- a/docs/evolution.md +++ b/docs/evolution.md @@ -24,6 +24,7 @@ | T11 | Infra | **device caching/pool allocator**(复用 op 输出显存,消 per-step cudaMalloc) | 单卡 2.3×;**8卡 461K tok/s** 近线性(修 KI-5) | | T12 | 算法/Infra | **bf16 混合精度**(fp32 master,cuBLAS GemmEx,norm/softmax/CE 保 fp32) | dim768 OOM 解除,−29% 显存/+13% tok/s(修 KI-2) | | T13 | 算法/Infra | **激活重计算**(per-block gradient checkpointing:前向 no-tape + 反向重算,`backward_seeded`) | 梯度对非重计算版**逐位一致**(0.00);dim768 31.1→14.6GB;**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3,解锁 v8) | +| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernel:online softmax、tiled over KV、**不物化 N×N scores**;flash 式 bwd:重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dV);opt-in `--flash`,默认保 composed(Phase 2) | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0;**峰值显存 −16%@seq1024 / −23%@seq2048**(不物化 N×N,收益随 seq 增长);tok/s ~2.3–2.8× 慢(hd=64 小头维干不过 cuBLAS tensor-core,flash 已知权衡=胜场在显存);md5 闭环逐位一致 | --- @@ -49,9 +50,9 @@ ## 三、各维度的累积演进(轴向看一条线怎么走的) -- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13)。 +- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14,online softmax + flash 式 bwd)。 - **模型架构**:固定 Qwen3-style;dim **32→256→384→512→768→1024**(v8 首拨容量轴,头数 24→32);核心参数 **41K→226M**(总 3.26M→329M)。 -- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。 +- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13,解锁 dim1024) → flash-attention(T14,不物化 N×N,attention 显存收益随 seq 增长)。吞吐 **3.3K→217K tok/s**(dim768 bf16),dim1024+重算 ~129K(重算税);MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。 - **数据集**:TinyStories 3MB 切片 → 全量 TinyStories(epoch 0.01→5.33,**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**(2.255B 语料,1.02ep)→ **v7 同子集多 epoch(1.45ep,近顶)→ v8 同子集换大模型**(dim1024,1.05ep)。tokenizer 全程 gpt2 BPE(复用 xserv-tokenizer;v6 刻意不换 tokenizer 以隔离「数据来源」变量,KI-4 留后续版本)。 - **v5→v6 数据轴的质变**:v0–v5 都吃合成幼儿故事(TinyStories,低熵、词汇受控),v5 证明同尺寸模型在它上面已饱和;v6 第一版换成**真实教育类网页文本**(FineWeb-edu),语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。 - ⚠️ **同子集多 epoch 也有天花板(v6→v7)**:v6 的 FineWeb val 才训 1.02ep、末步仍单调降,曾被读作「还没喂够」;v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B token),FineWeb val 仅 ↓0.05(3.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象**:**「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」,不在「同数据多读几遍」**。后续要继续降 val,必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。