9.6 KiB
Phase T16: Gradient Accumulation — Design Document
Goal
在已有的训练 loop(T6/T10)与 DDP(T8)之上,加 micro-batch 梯度累积:把 accum_steps=N
个 micro-step 的梯度在 tape 里累加起来,再做一次 AdamW.step + zero_grad——得到
有效 batch = N × micro_batch 的更新,而显存只占一个 micro-batch 的激活峰值(不随 N 增长)。
两条硬约束:
- 数值等效:
accum_steps=N(N 个 micro-step 后一次 step)必须对住「一个 N× 大 batch 的单 step」——梯度/loss 在仓内既有容差内逐位贴合。这是核心等效性证明。 - DDP 只在累积边界通信:
world>1下,N 个 micro-step 里只在最后一个做 all-reduce (中间 micro-step 跳过跨卡通信),最终喂给优化器的仍是 global 有效 batch 的均值梯度, loss 对单卡。
并暴露 train 入口的 --accum-steps flag。accum_steps=1 必须对当前无累积路径逐位一致
(回归保护)。
不做:micro-batch 间变 LR / 变 batch(恒定 micro_batch);累积里换 dropout RNG(T18 才有 dropout);ZeRO(T17)。本 Phase 只动优化器 step 的节奏与 DDP 通信门控,复用 tape 既有 的 SUM 累加。
Module Layout
crates/xtrain-train/src/
├── train_loop.rs # TrainConfig += accum_steps;inner micro-loop(缩放 loss + tape SUM)
└── bin/train.rs # 新 --accum-steps flag;打印有效 batch
crates/xtrain-distributed/src/
└── ddp.rs # DdpConfig += accum_steps;all-reduce 门控到累积边界
crates/xtrain-train/tests/
└── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)
crates/xtrain-distributed/tests/
└── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)
docs/15-grad-accum.md # 本文
无新 crate、无新 kernel、无新 autograd op——梯度累积是纯调度:tape 早已 SUM 累加,
缩放用既有 ops::scale,DDP 通信用既有 all_reduce_average_grads,只是改调用节奏与门控。
Key Design Decisions
① 等效性的数学:缩放每个 micro-loss 为 1/N
模型的 loss_batched 是 CE-mean over batch*seq 行(见 model.rs)。设一个 micro-batch 有
B 序列、seq 长 S,记某 micro-step 那批 B*S 行的 per-row 梯度之和为 Σ_micro:
- 大 batch 基线(有效 batch
N·B):一次loss_batched(N·B 序列)= CE-mean overN·B·S行 → backward 给G_big = Σ_all / (N·B·S),其中Σ_all = Σ_n Σ_micro_n。 - 累积(N 个 micro-step,每个
B):micro-step n 的loss_batched(B)= CE-mean overB·S行 → 若直接 backward 得Σ_micro_n / (B·S);N 个 backward 之间不zero_grad,tape SUM 累加 →Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big。
差一个因子 N。修正:每个 micro-loss 先 ops::scale(loss, 1/N) 再 backward——scale 的
backward 把上游梯度乘 1/N(见 ops.rs),于是每个 micro 贡献 Σ_micro_n / (N·B·S),
累积后 Σ_all / (N·B·S) = G_big,与大 batch 逐位等效(仅 fp 求和顺序不同 → 进容差,和
T8 DDP-vs-单卡同性质)。
为什么不在 clip 里用
pre_scale=1/N?clip 的pre_scale已被 batch-mean 占用(=1.0)。 在 loss 上scale(1/N)更内聚:缩放穿过既有 autograd,不碰 clip/optimizer,且N=1时scale(1.0)的 backward 是恒等乘 1 —— 这正是accum_steps=1逐位回归的保证(见 ④)。
报告的 step-loss = N 个 micro 的原始 loss(未缩放值)之和 / N = 有效 batch 的 mean loss, 和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。
② 单卡 train loop:inner micro-loop
每个 optimizer step:
for micro in 0..N:
抽 B 序列 → loss = loss_batched(B)
step_loss_acc += raw_loss(loss) # 累报告用的原始 loss
scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度
# —— 累积边界 ——
clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值
opt.step(lr); zero_grad()
losses.push(step_loss_acc / N)
tokens_seen += N * B * S # 有效 batch tok
accum_steps 默认 1 → micro-loop 跑一次、scale(loss,1.0)、不在 micro 间 zero_grad(本就如此)
→ 与现路径完全等价。每个 micro-step 的计算图在它自己的 backward 后即可释放(Rust Rc 在
循环变量出作用域时 drop),所以显存峰值 = 单个 micro-batch 的激活,不随 N 增长(③ 实测)。
抽样次序保持:单卡仍是连续从 RNG 抽 N·B 序列;与「大 batch 抽 N·B」逐序列对齐,只是分 N 组
forward——并集同序,所以 Σ_all 的项一致。
③ 显存平 + 有效 batch 实测
「显存不随 N 增长」是 grad-accum 的卖点,要实测而非断言:固定有效 batch E = N·B,跑
(N=1,B=E)(大 batch)vs (N=E,B=1)(极端累积),用 nvidia-smi/cudaMemGetInfo 量峰值显存——
后者应显著低(少 N× 激活)。train 入口打印 effective batch = accum_steps × batch。
④ accum_steps=1 逐位回归
N=1 时 inner loop 跑一次、scale(loss, 1.0)。ops::scale(_, 1.0) 的 fwd 是
value.scale(1.0)、bwd 是 grad.scale(1.0)——数学恒等。为绝对逐位(连一次 ×1.0 kernel
都不引入),实现里 N==1 直接 loss.backward()(跳过 scale),与现路径字节一致。测试
accum1_bit_identical_to_no_accum 锁这条。
⑤ DDP:all-reduce 门控到累积边界
T8 的 all_reduce_average_grads(params) 每 step 调一次。grad-accum 下只在最后一个
micro-step 之后调一次——中间 micro-step 的 backward 只在本卡 tape 里 SUM,不发 NCCL。
均值的账(沿用 T8 的「通信里 /world,clip 里 /b_local」拆分,再叠加 ① 的 /N):
每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world
→ 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b
clip pre_scale = 1.0
精确推导:每卡每 micro 的 loss_batched(B_local) 是 本卡 mean over B_local·S 行。
scale(1/N) 后 backward = Σ_local_micro / (N · B_local · S)。N 个 micro tape SUM →
Σ_local_all / (N · B_local · S),其中 Σ_local_all = 本卡 N·B_local 行之和。
all_reduce(sum) 跨 world 卡 → Σ_global_all / (N · B_local · S)(Σ_global_all = 全
world·N·B_local = N·B_global 行之和);/world → Σ_global_all / (N · B_local · S · world)
= Σ_global_all / (N · B_global · S)(因 B_global = world·B_local)。这正是有效 batch
N·B_global 的 mean 梯度——与单卡「有效 batch N·B_global 的大 batch 单 step」逐位等效
(求和顺序差进容差)。
关键正确性点:
all_reduce_average_grads里的/world是按 world 缩放(与 N 无关);N 的 那个1/N已由 ① 的scale在每个 micro 的 backward 里完成。两者正交,不会互相污染。 单卡(world=1)退化:all-reduce 是 no-op,/world=1,只剩 ① 的1/N→ 与 ② 一致。
DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。
⑥ 不变量小结
| 单卡基线(大 batch E) | 单卡 accum(N×B=E) | DDP accum(world, N×B_local·world=E) | |
|---|---|---|---|
| loss 缩放 | 无(CE-mean) | 每 micro ×1/N |
每 micro ×1/N |
| grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 |
| 跨卡通信 | — | — | 仅累积边界 1 次 all-reduce + /world |
| clip pre_scale | 1.0 | 1.0 | 1.0 |
| 显存峰值 | E 的激活 | B 的激活 | B_local 的激活 |
验证方法(验收,全部 dash5 实跑 capture)
GPU 测试 #[cfg(not(no_cuda))] 门控。
- 等效性(核心硬闸门)
grad_accum.rs::accum_equiv_big_batch:同 init、同数据同序, 跑「accum_steps=N, micro_batch=B」与「accum_steps=1, batch=N·B」各一 step,断言 ①loss、②每个参数的 grad rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP 闸门约定)。多步版(跑 K 个 optimizer step)再断言终参贴合(误差不发散)。 accum_steps=1逐位回归grad_accum.rs::accum1_bit_identical:accum_steps=1与现 no-accum 路径同 init/同数据 → 每参数 gradmax|Δ| == 0.0(④ 跳过 scale,字节一致)。- DDP+accum 对单卡
ddp_correctness.rs(扩既有ddp_matches_single_gpu…):单卡 有效 batchE的大 batch baseline vsworld=2 + accum_steps=N(每卡每 microB_local,world·N·B_local=E)→ loss 轨迹max_rel<1e-3、跨 rank 参数一致、且 only-at-boundary 通信 (micro 间不发 NCCL,由实现保证 + 不变量推导)。 - 显存平 + 有效 batch :固定有效 batch,量
(N=1,大batch)vs(N=大,micro=1)峰值显存 (后者显著低),train 入口打印 effective batch。capture nvidia-smi。 - 全回归套:autograd grad-check / structural / batched==looped / bf16 / recompute(逐位)/
overfit 27/27 / AdamW(GPU bit-exact + host vs torch)/ DDP loss-match + 跨 rank / xserv
闭环 md5——
accum_steps=1默认值保证全部不回归。