Merge t16-grad-accum into main

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

# Conflicts:
#	README.md
#	docs/evolution.md
This commit is contained in:
2026-06-18 00:37:11 +08:00
10 changed files with 697 additions and 48 deletions

165
docs/15-grad-accum.md Normal file
View File

@@ -0,0 +1,165 @@
# Phase T16: Gradient Accumulation — Design Document
## Goal
在已有的训练 loopT6/T10与 DDPT8之上**micro-batch 梯度累积**:把 `accum_steps=N`
**micro-step** 的梯度在 tape 里累加起来,再做**一次** `AdamW.step` + `zero_grad`——得到
**有效 batch = N × micro_batch** 的更新,而显存只占**一个 micro-batch** 的激活峰值(不随 N 增长)。
两条硬约束:
1. **数值等效**`accum_steps=N`N 个 micro-step 后一次 step必须对住「一个 N× 大 batch
的单 step」——梯度/loss 在仓内既有容差内**逐位贴合**。这是核心等效性证明。
2. **DDP 只在累积边界通信**`world>1`N 个 micro-step 里**只在最后一个**做 all-reduce
(中间 micro-step **跳过跨卡通信**),最终喂给优化器的仍是 global 有效 batch 的均值梯度,
loss 对单卡。
并暴露 train 入口的 `--accum-steps` flag。`accum_steps=1` 必须对当前无累积路径**逐位一致**
(回归保护)。
**不做**micro-batch 间变 LR / 变 batch恒定 micro_batch累积里换 dropout RNGT18 才有
dropoutZeROT17。本 Phase 只动**优化器 step 的节奏**与 **DDP 通信门控**,复用 tape 既有
的 SUM 累加。
## Module Layout
```
crates/xtrain-train/src/
├── train_loop.rs # TrainConfig += accum_stepsinner micro-loop缩放 loss + tape SUM
└── bin/train.rs # 新 --accum-steps flag打印有效 batch
crates/xtrain-distributed/src/
└── ddp.rs # DdpConfig += accum_stepsall-reduce 门控到累积边界
crates/xtrain-train/tests/
└── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)
crates/xtrain-distributed/tests/
└── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)
docs/15-grad-accum.md # 本文
```
无新 crate、无新 kernel、无新 autograd op——梯度累积是**纯调度**tape 早已 SUM 累加,
缩放用既有 `ops::scale`DDP 通信用既有 `all_reduce_average_grads`,只是改**调用节奏与门控**。
## Key Design Decisions
### ① 等效性的数学:缩放每个 micro-loss 为 `1/N`
模型的 `loss_batched`**CE-mean over `batch*seq` 行**(见 `model.rs`)。设一个 micro-batch 有
`B` 序列、seq 长 `S`,记某 micro-step 那批 `B*S` 行的 per-row 梯度之和为 `Σ_micro`
- **大 batch 基线**(有效 batch `N·B`):一次 `loss_batched(N·B 序列)` = CE-mean over `N·B·S`
→ backward 给 `G_big = Σ_all / (N·B·S)`,其中 `Σ_all = Σ_n Σ_micro_n`
- **累积**N 个 micro-step每个 `B`micro-step n 的 `loss_batched(B)` = CE-mean over `B·S`
→ 若直接 backward 得 `Σ_micro_n / (B·S)`**N 个 backward 之间不 `zero_grad`**tape SUM 累加 →
`Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big`
差一个因子 N。修正**每个 micro-loss 先 `ops::scale(loss, 1/N)` 再 backward**——`scale`
backward 把上游梯度乘 `1/N`(见 `ops.rs`),于是每个 micro 贡献 `Σ_micro_n / (N·B·S)`
累积后 `Σ_all / (N·B·S) = G_big`**与大 batch 逐位等效**(仅 fp 求和顺序不同 → 进容差,和
T8 DDP-vs-单卡同性质)。
> 为什么不在 clip 里用 `pre_scale=1/N`clip 的 `pre_scale` 已被 batch-mean 占用(=1.0)。
> 在 loss 上 `scale(1/N)` 更内聚:缩放穿过既有 autograd不碰 clip/optimizer且 `N=1` 时
> `scale(1.0)` 的 backward 是恒等乘 1 —— 这正是 `accum_steps=1` 逐位回归的保证(见 ④)。
报告的 step-loss = N 个 micro 的**原始** loss未缩放值之和 / N = 有效 batch 的 mean loss
和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。
### ② 单卡 train loopinner micro-loop
每个 optimizer step
```text
for micro in 0..N:
抽 B 序列 → loss = loss_batched(B)
step_loss_acc += raw_loss(loss) # 累报告用的原始 loss
scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度
# —— 累积边界 ——
clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值
opt.step(lr); zero_grad()
losses.push(step_loss_acc / N)
tokens_seen += N * B * S # 有效 batch tok
```
`accum_steps` 默认 1 → micro-loop 跑一次、`scale(loss,1.0)`、不在 micro 间 zero_grad本就如此
→ 与现路径完全等价。**每个 micro-step 的计算图在它自己的 backward 后即可释放**Rust `Rc`
循环变量出作用域时 drop所以**显存峰值 = 单个 micro-batch 的激活**,不随 N 增长(③ 实测)。
抽样次序保持:单卡仍是连续从 RNG 抽 `N·B` 序列;与「大 batch 抽 `N·B`」逐序列对齐,只是分 N 组
forward——并集同序所以 `Σ_all` 的项一致。
### ③ 显存平 + 有效 batch 实测
「显存不随 N 增长」是 grad-accum 的卖点,要**实测**而非断言:固定有效 batch `E = N·B`,跑
`(N=1,B=E)`(大 batchvs `(N=E,B=1)`(极端累积),用 `nvidia-smi`/`cudaMemGetInfo` 量峰值显存——
后者应**显著低**(少 N× 激活。train 入口打印 `effective batch = accum_steps × batch`
### ④ `accum_steps=1` 逐位回归
`N=1` 时 inner loop 跑一次、`scale(loss, 1.0)``ops::scale(_, 1.0)` 的 fwd 是
`value.scale(1.0)`、bwd 是 `grad.scale(1.0)`——数学恒等。为**绝对**逐位(连一次 `×1.0` kernel
都不引入),实现里 `N==1` 直接 `loss.backward()`(跳过 scale与现路径**字节一致**。测试
`accum1_bit_identical_to_no_accum` 锁这条。
### ⑤ DDPall-reduce 门控到累积边界
T8 的 `all_reduce_average_grads(params)` 每 step 调一次。grad-accum 下**只在最后一个
micro-step 之后调一次**——中间 micro-step 的 backward 只在本卡 tape 里 SUM**不发 NCCL**。
均值的账(沿用 T8 的「通信里 /worldclip 里 /b_local」拆分再叠加 ① 的 /N
```text
每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world
→ 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b
clip pre_scale = 1.0
```
精确推导:每卡每 micro 的 `loss_batched(B_local)`**本卡 mean over `B_local·S` 行**
`scale(1/N)` 后 backward = `Σ_local_micro / (N · B_local · S)`。N 个 micro tape SUM →
`Σ_local_all / (N · B_local · S)`,其中 `Σ_local_all` = 本卡 `N·B_local` 行之和。
`all_reduce(sum)` 跨 world 卡 → `Σ_global_all / (N · B_local · S)``Σ_global_all` = 全
`world·N·B_local = N·B_global` 行之和);`/world``Σ_global_all / (N · B_local · S · world)`
`= Σ_global_all / (N · B_global · S)`(因 `B_global = world·B_local`)。这正是**有效 batch
`N·B_global` 的 mean 梯度**——与单卡「有效 batch `N·B_global` 的大 batch 单 step」逐位等效
(求和顺序差进容差)。
> 关键正确性点:`all_reduce_average_grads` 里的 `/world` 是按 **world** 缩放(与 N 无关N 的
> 那个 `1/N` 已由 ① 的 `scale` 在每个 micro 的 backward 里完成。两者正交,不会互相污染。
> 单卡(`world=1`退化all-reduce 是 no-op`/world=1`,只剩 ① 的 `1/N` → 与 ② 一致。
DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。
### ⑥ 不变量小结
| | 单卡基线(大 batch E | 单卡 accumN×B=E | DDP accumworld, N×B_local·world=E |
|---|---|---|---|
| loss 缩放 | 无CE-mean | 每 micro `×1/N` | 每 micro `×1/N` |
| grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 |
| 跨卡通信 | — | — | **仅累积边界 1 次** all-reduce + /world |
| clip pre_scale | 1.0 | 1.0 | 1.0 |
| 显存峰值 | E 的激活 | **B 的激活** | **B_local 的激活** |
## 验证方法(验收,全部 dash5 实跑 capture
GPU 测试 `#[cfg(not(no_cuda))]` 门控。
1. **等效性(核心硬闸门)** `grad_accum.rs::accum_equiv_big_batch`:同 init、同数据同序
跑「`accum_steps=N`, micro_batch=B」与「`accum_steps=1`, batch=N·B」各一 step断言
①loss、②**每个参数的 grad** rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP
闸门约定)。多步版(跑 K 个 optimizer step再断言**终参**贴合(误差不发散)。
2. **`accum_steps=1` 逐位回归** `grad_accum.rs::accum1_bit_identical``accum_steps=1` 与现
no-accum 路径同 init/同数据 → 每参数 grad `max|Δ| == 0.0`(④ 跳过 scale字节一致
3. **DDP+accum 对单卡** `ddp_correctness.rs`(扩既有 `ddp_matches_single_gpu…`):单卡
有效 batch `E` 的大 batch baseline vs `world=2 + accum_steps=N`(每卡每 micro `B_local`
`world·N·B_local=E`)→ loss 轨迹 `max_rel<1e-3`、跨 rank 参数一致、且 only-at-boundary 通信
micro 间不发 NCCL由实现保证 + 不变量推导)。
4. **显存平 + 有效 batch** :固定有效 batch`(N=1,大batch)` vs `(N=大,micro=1)` 峰值显存
后者显著低train 入口打印 effective batch。capture nvidia-smi。
5. **全回归套**autograd grad-check / structural / batched==looped / bf16 / recompute逐位/
overfit 27/27 / AdamWGPU bit-exact + host vs torch/ DDP loss-match + 跨 rank / **xserv
闭环 md5**——`accum_steps=1` 默认值保证全部不回归。

View File

@@ -7,7 +7,7 @@
---
## 一、基建 phaseT1T13—— 主要动「算法」与「Infra」
## 一、基建 phaseT1T13 + Phase 2 systems-depth)—— 主要动「算法」与「Infra」
| Phase | 维度 | 变化 | 结果 / 验证 |
|---|---|---|---|
@@ -25,6 +25,8 @@
| T12 | 算法/Infra | **bf16 混合精度**fp32 mastercuBLAS GemmExnorm/softmax/CE 保 fp32 | dim768 OOM 解除29% 显存/+13% tok/s修 KI-2 |
| T13 | 算法/Infra | **激活重计算**per-block gradient checkpointing前向 no-tape + 反向重算,`backward_seeded` | 梯度对非重计算版**逐位一致**(0.00)dim768 31.1→14.6GB**dim1024 batch32 OOM→16.6GB 装下**(修 KI-3解锁 v8 |
| T14 | 算法/Infra | **融合 flash-attention kernel**(手写单 kernelonline softmax、tiled over KV、**不物化 N×N scores**flash 式 bwd重算 scores + `D=ΣdO·O` 化简雅可比 + dQ/dK/dVopt-in `--flash`,默认保 composedPhase 2 | fwd 对 composed 6.7e-5、bwd 对 composed dQ 1.7e-5、PyTorch B>1 7.9e-6、flash==composed loss rel 0.0**峰值显存 16%@seq1024 / 23%@seq2048**(不物化 N×N收益随 seq 增长tok/s ~2.32.8×hd=64 小头维干不过 cuBLAS tensor-coreflash 已知权衡=胜场在显存md5 闭环逐位一致 |
| T16 | 算法/Infra | **梯度累积**N 个 micro-step每个 micro-loss `×1/N` 再 backwardtape SUM 累加 → 一次 AdamW step+zero`--accum-steps`**DDP 只在累积边界 all-reduce**(中间 micro-step 不发 NCCL`/world``1/N` 正交);显存随 micro 不随有效 batch | 等效大 batch**逐位贴合**loss rel 8.5e-8、grad rel 3.8e-5`accum=1` 逐位回归(0.00)DDP+accum 对单卡 loss 5.7e-7/跨 rank 一致;**显存平**:同有效 batch 64big-batch 27.7GB→accum(4×16) **7.2GB(74%)**big-batch OOM 而 accum 装下);全回归+xserv 闭环 md5 一致 |
| T18 | 算法 | **dropout**(手写 counter-based 设备 RNG → Bernoulli mask训练 inverted 1/(1-p) scaling、eval 恒等);新 autodiff `dropout` 算子fwd 生成+施加 maskbwd 用同 mask接 residual/ffn 两处;`--dropout` flag 默认 0 | 固定 seed grad-check 过E[out]≈input + keep≈1-p**p=0 与无 dropout 逐位一致**recompute(T13) 组合下梯度仍逐位一致counter-based seed 重算复现同 mask全回归 + xserv 闭环绿(导出/推理 dropout 关) |
---
@@ -50,9 +52,9 @@
## 三、各维度的累积演进(轴向看一条线怎么走的)
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14online softmax + flash 式 bwd)。
- **算法**:手写 autograd(tape)+扇出累加 → AdamW/LR-sched/grad-clip → +QK-norm(Qwen3) → batched forward → bf16 混合精度(fp32 master) → 激活重计算(T13) → 融合 flash-attention(T14online softmax + flash 式 bwd) → 梯度累积(T16复用 tape SUM等效大 batch 而显存随 micro) → dropout(T18counter-based 设备 RNG + inverted scalingtrain/eval 切换)
- **模型架构**:固定 Qwen3-styledim **32→256→384→512→768→1024**v8 首拨容量轴,头数 24→32核心参数 **41K→226M**(总 3.26M→329M
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13解锁 dim1024) → flash-attention(T14不物化 N×Nattention 显存收益随 seq 增长)。吞吐 **3.3K→217K tok/s**dim768 bf16dim1024+重算 ~129K重算税MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。
- **Infra**:单卡 fp32 → cuBLAS/GPU-optim(T7) → NCCL DDP(T8) → batched forward(T10) → caching allocator(T11) → bf16(T12) → 激活重计算(T13解锁 dim1024) → flash-attention(T14不物化 N×Nattention 显存收益随 seq 增长) → 梯度累积(T16DDP 只在累积边界通信,显存随 micro 不随有效 batch)。吞吐 **3.3K→217K tok/s**dim768 bf16dim1024+重算 ~129K重算税MFU **0.4%→17%**(每次提升都对应一块 perf 基建,详见 known-issues + MFU 分析)。T13/T14/T16 是三条**显存杠杆**重计算压激活峰值、flash 不物化 N×N attention scores、梯度累积解耦有效 batch 与激活显存),可叠加放大有效 batch。
- **数据集**TinyStories 3MB 切片 → 全量 TinyStoriesepoch 0.01→5.33**至饱和**)→ **v6 毕业到 FineWeb-edu 真实网页**2.255B 语料1.02ep)→ **v7 同子集多 epoch1.45ep,近顶)→ v8 同子集换大模型**dim10241.05ep。tokenizer 全程 gpt2 BPE复用 xserv-tokenizerv6 刻意不换 tokenizer 以隔离「数据来源」变量KI-4 留后续版本)。
- **v5→v6 数据轴的质变**v0v5 都吃合成幼儿故事TinyStories低熵、词汇受控v5 证明同尺寸模型在它上面已饱和v6 第一版换成**真实教育类网页文本**FineWeb-edu语言种类发生质变——采样从「只会写小故事」变成「能写历史/科学/说明文」。
- ⚠️ **同子集多 epoch 也有天花板v6→v7**v6 的 FineWeb val 才训 1.02ep、末步仍单调降曾被读作「还没喂够」v7 把**同一 2.255B 子集**喂到 1.45ep(多 ~1B tokenFineWeb val 仅 ↓0.053.07→3.01)且 ~step44000 后走平、采样无质变 ⇒ **该子集在 dim768 已近天花板**。这与 v5 的 TinyStories 数据量饱和是**同一类现象****「重复喂老数据」边际都薄,无论是 v5 的同语料多 epoch 还是 v7 的同子集多 epoch**。真正抬天花板的是 v6「换更广的新语料」那一步——**杠杆在「更多样的新 token」不在「同数据多读几遍」**。后续要继续降 val必须补**新 FineWeb shards**(更多样、不重复),不是同子集加 epoch。