Files
xtrain/docs/15-grad-accum.md
2026-06-17 23:41:17 +08:00

166 lines
9.6 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T16: Gradient Accumulation — Design Document
## Goal
在已有的训练 loopT6/T10与 DDPT8之上**micro-batch 梯度累积**:把 `accum_steps=N`
**micro-step** 的梯度在 tape 里累加起来,再做**一次** `AdamW.step` + `zero_grad`——得到
**有效 batch = N × micro_batch** 的更新,而显存只占**一个 micro-batch** 的激活峰值(不随 N 增长)。
两条硬约束:
1. **数值等效**`accum_steps=N`N 个 micro-step 后一次 step必须对住「一个 N× 大 batch
的单 step」——梯度/loss 在仓内既有容差内**逐位贴合**。这是核心等效性证明。
2. **DDP 只在累积边界通信**`world>1`N 个 micro-step 里**只在最后一个**做 all-reduce
(中间 micro-step **跳过跨卡通信**),最终喂给优化器的仍是 global 有效 batch 的均值梯度,
loss 对单卡。
并暴露 train 入口的 `--accum-steps` flag。`accum_steps=1` 必须对当前无累积路径**逐位一致**
(回归保护)。
**不做**micro-batch 间变 LR / 变 batch恒定 micro_batch累积里换 dropout RNGT18 才有
dropoutZeROT17。本 Phase 只动**优化器 step 的节奏**与 **DDP 通信门控**,复用 tape 既有
的 SUM 累加。
## Module Layout
```
crates/xtrain-train/src/
├── train_loop.rs # TrainConfig += accum_stepsinner micro-loop缩放 loss + tape SUM
└── bin/train.rs # 新 --accum-steps flag打印有效 batch
crates/xtrain-distributed/src/
└── ddp.rs # DdpConfig += accum_stepsall-reduce 门控到累积边界
crates/xtrain-train/tests/
└── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)
crates/xtrain-distributed/tests/
└── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)
docs/15-grad-accum.md # 本文
```
无新 crate、无新 kernel、无新 autograd op——梯度累积是**纯调度**tape 早已 SUM 累加,
缩放用既有 `ops::scale`DDP 通信用既有 `all_reduce_average_grads`,只是改**调用节奏与门控**。
## Key Design Decisions
### ① 等效性的数学:缩放每个 micro-loss 为 `1/N`
模型的 `loss_batched`**CE-mean over `batch*seq` 行**(见 `model.rs`)。设一个 micro-batch 有
`B` 序列、seq 长 `S`,记某 micro-step 那批 `B*S` 行的 per-row 梯度之和为 `Σ_micro`
- **大 batch 基线**(有效 batch `N·B`):一次 `loss_batched(N·B 序列)` = CE-mean over `N·B·S`
→ backward 给 `G_big = Σ_all / (N·B·S)`,其中 `Σ_all = Σ_n Σ_micro_n`
- **累积**N 个 micro-step每个 `B`micro-step n 的 `loss_batched(B)` = CE-mean over `B·S`
→ 若直接 backward 得 `Σ_micro_n / (B·S)`**N 个 backward 之间不 `zero_grad`**tape SUM 累加 →
`Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big`
差一个因子 N。修正**每个 micro-loss 先 `ops::scale(loss, 1/N)` 再 backward**——`scale`
backward 把上游梯度乘 `1/N`(见 `ops.rs`),于是每个 micro 贡献 `Σ_micro_n / (N·B·S)`
累积后 `Σ_all / (N·B·S) = G_big`**与大 batch 逐位等效**(仅 fp 求和顺序不同 → 进容差,和
T8 DDP-vs-单卡同性质)。
> 为什么不在 clip 里用 `pre_scale=1/N`clip 的 `pre_scale` 已被 batch-mean 占用(=1.0)。
> 在 loss 上 `scale(1/N)` 更内聚:缩放穿过既有 autograd不碰 clip/optimizer且 `N=1` 时
> `scale(1.0)` 的 backward 是恒等乘 1 —— 这正是 `accum_steps=1` 逐位回归的保证(见 ④)。
报告的 step-loss = N 个 micro 的**原始** loss未缩放值之和 / N = 有效 batch 的 mean loss
和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。
### ② 单卡 train loopinner micro-loop
每个 optimizer step
```text
for micro in 0..N:
抽 B 序列 → loss = loss_batched(B)
step_loss_acc += raw_loss(loss) # 累报告用的原始 loss
scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度
# —— 累积边界 ——
clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值
opt.step(lr); zero_grad()
losses.push(step_loss_acc / N)
tokens_seen += N * B * S # 有效 batch tok
```
`accum_steps` 默认 1 → micro-loop 跑一次、`scale(loss,1.0)`、不在 micro 间 zero_grad本就如此
→ 与现路径完全等价。**每个 micro-step 的计算图在它自己的 backward 后即可释放**Rust `Rc`
循环变量出作用域时 drop所以**显存峰值 = 单个 micro-batch 的激活**,不随 N 增长(③ 实测)。
抽样次序保持:单卡仍是连续从 RNG 抽 `N·B` 序列;与「大 batch 抽 `N·B`」逐序列对齐,只是分 N 组
forward——并集同序所以 `Σ_all` 的项一致。
### ③ 显存平 + 有效 batch 实测
「显存不随 N 增长」是 grad-accum 的卖点,要**实测**而非断言:固定有效 batch `E = N·B`,跑
`(N=1,B=E)`(大 batchvs `(N=E,B=1)`(极端累积),用 `nvidia-smi`/`cudaMemGetInfo` 量峰值显存——
后者应**显著低**(少 N× 激活。train 入口打印 `effective batch = accum_steps × batch`
### ④ `accum_steps=1` 逐位回归
`N=1` 时 inner loop 跑一次、`scale(loss, 1.0)``ops::scale(_, 1.0)` 的 fwd 是
`value.scale(1.0)`、bwd 是 `grad.scale(1.0)`——数学恒等。为**绝对**逐位(连一次 `×1.0` kernel
都不引入),实现里 `N==1` 直接 `loss.backward()`(跳过 scale与现路径**字节一致**。测试
`accum1_bit_identical_to_no_accum` 锁这条。
### ⑤ DDPall-reduce 门控到累积边界
T8 的 `all_reduce_average_grads(params)` 每 step 调一次。grad-accum 下**只在最后一个
micro-step 之后调一次**——中间 micro-step 的 backward 只在本卡 tape 里 SUM**不发 NCCL**。
均值的账(沿用 T8 的「通信里 /worldclip 里 /b_local」拆分再叠加 ① 的 /N
```text
每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world
→ 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b
clip pre_scale = 1.0
```
精确推导:每卡每 micro 的 `loss_batched(B_local)`**本卡 mean over `B_local·S` 行**
`scale(1/N)` 后 backward = `Σ_local_micro / (N · B_local · S)`。N 个 micro tape SUM →
`Σ_local_all / (N · B_local · S)`,其中 `Σ_local_all` = 本卡 `N·B_local` 行之和。
`all_reduce(sum)` 跨 world 卡 → `Σ_global_all / (N · B_local · S)``Σ_global_all` = 全
`world·N·B_local = N·B_global` 行之和);`/world``Σ_global_all / (N · B_local · S · world)`
`= Σ_global_all / (N · B_global · S)`(因 `B_global = world·B_local`)。这正是**有效 batch
`N·B_global` 的 mean 梯度**——与单卡「有效 batch `N·B_global` 的大 batch 单 step」逐位等效
(求和顺序差进容差)。
> 关键正确性点:`all_reduce_average_grads` 里的 `/world` 是按 **world** 缩放(与 N 无关N 的
> 那个 `1/N` 已由 ① 的 `scale` 在每个 micro 的 backward 里完成。两者正交,不会互相污染。
> 单卡(`world=1`退化all-reduce 是 no-op`/world=1`,只剩 ① 的 `1/N` → 与 ② 一致。
DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。
### ⑥ 不变量小结
| | 单卡基线(大 batch E | 单卡 accumN×B=E | DDP accumworld, N×B_local·world=E |
|---|---|---|---|
| loss 缩放 | 无CE-mean | 每 micro `×1/N` | 每 micro `×1/N` |
| grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 |
| 跨卡通信 | — | — | **仅累积边界 1 次** all-reduce + /world |
| clip pre_scale | 1.0 | 1.0 | 1.0 |
| 显存峰值 | E 的激活 | **B 的激活** | **B_local 的激活** |
## 验证方法(验收,全部 dash5 实跑 capture
GPU 测试 `#[cfg(not(no_cuda))]` 门控。
1. **等效性(核心硬闸门)** `grad_accum.rs::accum_equiv_big_batch`:同 init、同数据同序
跑「`accum_steps=N`, micro_batch=B」与「`accum_steps=1`, batch=N·B」各一 step断言
①loss、②**每个参数的 grad** rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP
闸门约定)。多步版(跑 K 个 optimizer step再断言**终参**贴合(误差不发散)。
2. **`accum_steps=1` 逐位回归** `grad_accum.rs::accum1_bit_identical``accum_steps=1` 与现
no-accum 路径同 init/同数据 → 每参数 grad `max|Δ| == 0.0`(④ 跳过 scale字节一致
3. **DDP+accum 对单卡** `ddp_correctness.rs`(扩既有 `ddp_matches_single_gpu…`):单卡
有效 batch `E` 的大 batch baseline vs `world=2 + accum_steps=N`(每卡每 micro `B_local`
`world·N·B_local=E`)→ loss 轨迹 `max_rel<1e-3`、跨 rank 参数一致、且 only-at-boundary 通信
micro 间不发 NCCL由实现保证 + 不变量推导)。
4. **显存平 + 有效 batch** :固定有效 batch`(N=1,大batch)` vs `(N=大,micro=1)` 峰值显存
后者显著低train 入口打印 effective batch。capture nvidia-smi。
5. **全回归套**autograd grad-check / structural / batched==looped / bf16 / recompute逐位/
overfit 27/27 / AdamWGPU bit-exact + host vs torch/ DDP loss-match + 跨 rank / **xserv
闭环 md5**——`accum_steps=1` 默认值保证全部不回归。