166 lines
9.6 KiB
Markdown
166 lines
9.6 KiB
Markdown
# Phase T16: Gradient Accumulation — Design Document
|
||
|
||
## Goal
|
||
|
||
在已有的训练 loop(T6/T10)与 DDP(T8)之上,加 **micro-batch 梯度累积**:把 `accum_steps=N`
|
||
个 **micro-step** 的梯度在 tape 里累加起来,再做**一次** `AdamW.step` + `zero_grad`——得到
|
||
**有效 batch = N × micro_batch** 的更新,而显存只占**一个 micro-batch** 的激活峰值(不随 N 增长)。
|
||
|
||
两条硬约束:
|
||
|
||
1. **数值等效**:`accum_steps=N`(N 个 micro-step 后一次 step)必须对住「一个 N× 大 batch
|
||
的单 step」——梯度/loss 在仓内既有容差内**逐位贴合**。这是核心等效性证明。
|
||
2. **DDP 只在累积边界通信**:`world>1` 下,N 个 micro-step 里**只在最后一个**做 all-reduce
|
||
(中间 micro-step **跳过跨卡通信**),最终喂给优化器的仍是 global 有效 batch 的均值梯度,
|
||
loss 对单卡。
|
||
|
||
并暴露 train 入口的 `--accum-steps` flag。`accum_steps=1` 必须对当前无累积路径**逐位一致**
|
||
(回归保护)。
|
||
|
||
**不做**:micro-batch 间变 LR / 变 batch(恒定 micro_batch);累积里换 dropout RNG(T18 才有
|
||
dropout);ZeRO(T17)。本 Phase 只动**优化器 step 的节奏**与 **DDP 通信门控**,复用 tape 既有
|
||
的 SUM 累加。
|
||
|
||
## Module Layout
|
||
|
||
```
|
||
crates/xtrain-train/src/
|
||
├── train_loop.rs # TrainConfig += accum_steps;inner micro-loop(缩放 loss + tape SUM)
|
||
└── bin/train.rs # 新 --accum-steps flag;打印有效 batch
|
||
|
||
crates/xtrain-distributed/src/
|
||
└── ddp.rs # DdpConfig += accum_steps;all-reduce 门控到累积边界
|
||
|
||
crates/xtrain-train/tests/
|
||
└── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)
|
||
|
||
crates/xtrain-distributed/tests/
|
||
└── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)
|
||
|
||
docs/15-grad-accum.md # 本文
|
||
```
|
||
|
||
无新 crate、无新 kernel、无新 autograd op——梯度累积是**纯调度**:tape 早已 SUM 累加,
|
||
缩放用既有 `ops::scale`,DDP 通信用既有 `all_reduce_average_grads`,只是改**调用节奏与门控**。
|
||
|
||
## Key Design Decisions
|
||
|
||
### ① 等效性的数学:缩放每个 micro-loss 为 `1/N`
|
||
|
||
模型的 `loss_batched` 是 **CE-mean over `batch*seq` 行**(见 `model.rs`)。设一个 micro-batch 有
|
||
`B` 序列、seq 长 `S`,记某 micro-step 那批 `B*S` 行的 per-row 梯度之和为 `Σ_micro`:
|
||
|
||
- **大 batch 基线**(有效 batch `N·B`):一次 `loss_batched(N·B 序列)` = CE-mean over `N·B·S` 行
|
||
→ backward 给 `G_big = Σ_all / (N·B·S)`,其中 `Σ_all = Σ_n Σ_micro_n`。
|
||
- **累积**(N 个 micro-step,每个 `B`):micro-step n 的 `loss_batched(B)` = CE-mean over `B·S` 行
|
||
→ 若直接 backward 得 `Σ_micro_n / (B·S)`;**N 个 backward 之间不 `zero_grad`**,tape SUM 累加 →
|
||
`Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big`。
|
||
|
||
差一个因子 N。修正:**每个 micro-loss 先 `ops::scale(loss, 1/N)` 再 backward**——`scale` 的
|
||
backward 把上游梯度乘 `1/N`(见 `ops.rs`),于是每个 micro 贡献 `Σ_micro_n / (N·B·S)`,
|
||
累积后 `Σ_all / (N·B·S) = G_big`,**与大 batch 逐位等效**(仅 fp 求和顺序不同 → 进容差,和
|
||
T8 DDP-vs-单卡同性质)。
|
||
|
||
> 为什么不在 clip 里用 `pre_scale=1/N`?clip 的 `pre_scale` 已被 batch-mean 占用(=1.0)。
|
||
> 在 loss 上 `scale(1/N)` 更内聚:缩放穿过既有 autograd,不碰 clip/optimizer,且 `N=1` 时
|
||
> `scale(1.0)` 的 backward 是恒等乘 1 —— 这正是 `accum_steps=1` 逐位回归的保证(见 ④)。
|
||
|
||
报告的 step-loss = N 个 micro 的**原始** loss(未缩放值)之和 / N = 有效 batch 的 mean loss,
|
||
和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。
|
||
|
||
### ② 单卡 train loop:inner micro-loop
|
||
|
||
每个 optimizer step:
|
||
|
||
```text
|
||
for micro in 0..N:
|
||
抽 B 序列 → loss = loss_batched(B)
|
||
step_loss_acc += raw_loss(loss) # 累报告用的原始 loss
|
||
scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度
|
||
# —— 累积边界 ——
|
||
clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值
|
||
opt.step(lr); zero_grad()
|
||
losses.push(step_loss_acc / N)
|
||
tokens_seen += N * B * S # 有效 batch tok
|
||
```
|
||
|
||
`accum_steps` 默认 1 → micro-loop 跑一次、`scale(loss,1.0)`、不在 micro 间 zero_grad(本就如此)
|
||
→ 与现路径完全等价。**每个 micro-step 的计算图在它自己的 backward 后即可释放**(Rust `Rc` 在
|
||
循环变量出作用域时 drop),所以**显存峰值 = 单个 micro-batch 的激活**,不随 N 增长(③ 实测)。
|
||
|
||
抽样次序保持:单卡仍是连续从 RNG 抽 `N·B` 序列;与「大 batch 抽 `N·B`」逐序列对齐,只是分 N 组
|
||
forward——并集同序,所以 `Σ_all` 的项一致。
|
||
|
||
### ③ 显存平 + 有效 batch 实测
|
||
|
||
「显存不随 N 增长」是 grad-accum 的卖点,要**实测**而非断言:固定有效 batch `E = N·B`,跑
|
||
`(N=1,B=E)`(大 batch)vs `(N=E,B=1)`(极端累积),用 `nvidia-smi`/`cudaMemGetInfo` 量峰值显存——
|
||
后者应**显著低**(少 N× 激活)。train 入口打印 `effective batch = accum_steps × batch`。
|
||
|
||
### ④ `accum_steps=1` 逐位回归
|
||
|
||
`N=1` 时 inner loop 跑一次、`scale(loss, 1.0)`。`ops::scale(_, 1.0)` 的 fwd 是
|
||
`value.scale(1.0)`、bwd 是 `grad.scale(1.0)`——数学恒等。为**绝对**逐位(连一次 `×1.0` kernel
|
||
都不引入),实现里 `N==1` 直接 `loss.backward()`(跳过 scale),与现路径**字节一致**。测试
|
||
`accum1_bit_identical_to_no_accum` 锁这条。
|
||
|
||
### ⑤ DDP:all-reduce 门控到累积边界
|
||
|
||
T8 的 `all_reduce_average_grads(params)` 每 step 调一次。grad-accum 下**只在最后一个
|
||
micro-step 之后调一次**——中间 micro-step 的 backward 只在本卡 tape 里 SUM,**不发 NCCL**。
|
||
|
||
均值的账(沿用 T8 的「通信里 /world,clip 里 /b_local」拆分,再叠加 ① 的 /N):
|
||
|
||
```text
|
||
每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
|
||
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
|
||
all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world
|
||
→ 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b
|
||
clip pre_scale = 1.0
|
||
```
|
||
|
||
精确推导:每卡每 micro 的 `loss_batched(B_local)` 是 **本卡 mean over `B_local·S` 行**。
|
||
`scale(1/N)` 后 backward = `Σ_local_micro / (N · B_local · S)`。N 个 micro tape SUM →
|
||
`Σ_local_all / (N · B_local · S)`,其中 `Σ_local_all` = 本卡 `N·B_local` 行之和。
|
||
`all_reduce(sum)` 跨 world 卡 → `Σ_global_all / (N · B_local · S)`(`Σ_global_all` = 全
|
||
`world·N·B_local = N·B_global` 行之和);`/world` → `Σ_global_all / (N · B_local · S · world)`
|
||
`= Σ_global_all / (N · B_global · S)`(因 `B_global = world·B_local`)。这正是**有效 batch
|
||
`N·B_global` 的 mean 梯度**——与单卡「有效 batch `N·B_global` 的大 batch 单 step」逐位等效
|
||
(求和顺序差进容差)。
|
||
|
||
> 关键正确性点:`all_reduce_average_grads` 里的 `/world` 是按 **world** 缩放(与 N 无关);N 的
|
||
> 那个 `1/N` 已由 ① 的 `scale` 在每个 micro 的 backward 里完成。两者正交,不会互相污染。
|
||
> 单卡(`world=1`)退化:all-reduce 是 no-op,`/world=1`,只剩 ① 的 `1/N` → 与 ② 一致。
|
||
|
||
DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。
|
||
|
||
### ⑥ 不变量小结
|
||
|
||
| | 单卡基线(大 batch E) | 单卡 accum(N×B=E) | DDP accum(world, N×B_local·world=E) |
|
||
|---|---|---|---|
|
||
| loss 缩放 | 无(CE-mean) | 每 micro `×1/N` | 每 micro `×1/N` |
|
||
| grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 |
|
||
| 跨卡通信 | — | — | **仅累积边界 1 次** all-reduce + /world |
|
||
| clip pre_scale | 1.0 | 1.0 | 1.0 |
|
||
| 显存峰值 | E 的激活 | **B 的激活** | **B_local 的激活** |
|
||
|
||
## 验证方法(验收,全部 dash5 实跑 capture)
|
||
|
||
GPU 测试 `#[cfg(not(no_cuda))]` 门控。
|
||
|
||
1. **等效性(核心硬闸门)** `grad_accum.rs::accum_equiv_big_batch`:同 init、同数据同序,
|
||
跑「`accum_steps=N`, micro_batch=B」与「`accum_steps=1`, batch=N·B」各一 step,断言
|
||
①loss、②**每个参数的 grad** rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP
|
||
闸门约定)。多步版(跑 K 个 optimizer step)再断言**终参**贴合(误差不发散)。
|
||
2. **`accum_steps=1` 逐位回归** `grad_accum.rs::accum1_bit_identical`:`accum_steps=1` 与现
|
||
no-accum 路径同 init/同数据 → 每参数 grad `max|Δ| == 0.0`(④ 跳过 scale,字节一致)。
|
||
3. **DDP+accum 对单卡** `ddp_correctness.rs`(扩既有 `ddp_matches_single_gpu…`):单卡
|
||
有效 batch `E` 的大 batch baseline vs `world=2 + accum_steps=N`(每卡每 micro `B_local`,
|
||
`world·N·B_local=E`)→ loss 轨迹 `max_rel<1e-3`、跨 rank 参数一致、且 only-at-boundary 通信
|
||
(micro 间不发 NCCL,由实现保证 + 不变量推导)。
|
||
4. **显存平 + 有效 batch** :固定有效 batch,量 `(N=1,大batch)` vs `(N=大,micro=1)` 峰值显存
|
||
(后者显著低),train 入口打印 effective batch。capture nvidia-smi。
|
||
5. **全回归套**:autograd grad-check / structural / batched==looped / bf16 / recompute(逐位)/
|
||
overfit 27/27 / AdamW(GPU bit-exact + host vs torch)/ DDP loss-match + 跨 rank / **xserv
|
||
闭环 md5**——`accum_steps=1` 默认值保证全部不回归。
|