Files
xtrain/docs/15-grad-accum.md
2026-06-17 23:41:17 +08:00

9.6 KiB
Raw Permalink Blame History

Phase T16: Gradient Accumulation — Design Document

Goal

在已有的训练 loopT6/T10与 DDPT8之上micro-batch 梯度累积:把 accum_steps=Nmicro-step 的梯度在 tape 里累加起来,再做一次 AdamW.step + zero_grad——得到 有效 batch = N × micro_batch 的更新,而显存只占一个 micro-batch 的激活峰值(不随 N 增长)。

两条硬约束:

  1. 数值等效accum_steps=NN 个 micro-step 后一次 step必须对住「一个 N× 大 batch 的单 step」——梯度/loss 在仓内既有容差内逐位贴合。这是核心等效性证明。
  2. DDP 只在累积边界通信world>1N 个 micro-step 里只在最后一个做 all-reduce (中间 micro-step 跳过跨卡通信),最终喂给优化器的仍是 global 有效 batch 的均值梯度, loss 对单卡。

并暴露 train 入口的 --accum-steps flag。accum_steps=1 必须对当前无累积路径逐位一致 (回归保护)。

不做micro-batch 间变 LR / 变 batch恒定 micro_batch累积里换 dropout RNGT18 才有 dropoutZeROT17。本 Phase 只动优化器 step 的节奏DDP 通信门控,复用 tape 既有 的 SUM 累加。

Module Layout

crates/xtrain-train/src/
├── train_loop.rs        # TrainConfig += accum_stepsinner micro-loop缩放 loss + tape SUM
└── bin/train.rs         # 新 --accum-steps flag打印有效 batch

crates/xtrain-distributed/src/
└── ddp.rs               # DdpConfig += accum_stepsall-reduce 门控到累积边界

crates/xtrain-train/tests/
└── grad_accum.rs        # 等效性硬闸门 + accum_steps=1 逐位回归(单卡)

crates/xtrain-distributed/tests/
└── ddp_correctness.rs   # += DDP+accum 对单卡(复用既有 ddp_matches… 框架)

docs/15-grad-accum.md     # 本文

无新 crate、无新 kernel、无新 autograd op——梯度累积是纯调度tape 早已 SUM 累加, 缩放用既有 ops::scaleDDP 通信用既有 all_reduce_average_grads,只是改调用节奏与门控

Key Design Decisions

① 等效性的数学:缩放每个 micro-loss 为 1/N

模型的 loss_batchedCE-mean over batch*seq(见 model.rs)。设一个 micro-batch 有 B 序列、seq 长 S,记某 micro-step 那批 B*S 行的 per-row 梯度之和为 Σ_micro

  • 大 batch 基线(有效 batch N·B):一次 loss_batched(N·B 序列) = CE-mean over N·B·S 行 → backward 给 G_big = Σ_all / (N·B·S),其中 Σ_all = Σ_n Σ_micro_n
  • 累积N 个 micro-step每个 Bmicro-step n 的 loss_batched(B) = CE-mean over B·S 行 → 若直接 backward 得 Σ_micro_n / (B·S)N 个 backward 之间不 zero_gradtape SUM 累加 → Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big

差一个因子 N。修正每个 micro-loss 先 ops::scale(loss, 1/N) 再 backward——scale 的 backward 把上游梯度乘 1/N(见 ops.rs),于是每个 micro 贡献 Σ_micro_n / (N·B·S) 累积后 Σ_all / (N·B·S) = G_big与大 batch 逐位等效(仅 fp 求和顺序不同 → 进容差,和 T8 DDP-vs-单卡同性质)。

为什么不在 clip 里用 pre_scale=1/Nclip 的 pre_scale 已被 batch-mean 占用(=1.0)。 在 loss 上 scale(1/N) 更内聚:缩放穿过既有 autograd不碰 clip/optimizerN=1scale(1.0) 的 backward 是恒等乘 1 —— 这正是 accum_steps=1 逐位回归的保证(见 ④)。

报告的 step-loss = N 个 micro 的原始 loss未缩放值之和 / N = 有效 batch 的 mean loss 和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。

② 单卡 train loopinner micro-loop

每个 optimizer step

for micro in 0..N:
    抽 B 序列 → loss = loss_batched(B)
    step_loss_acc += raw_loss(loss)              # 累报告用的原始 loss
    scale(loss, 1/N).backward()                  # tape SUM 累加缩放后的梯度
# —— 累积边界 ——
clip_grad_norm_gpu(params, max_norm, 1.0)        # 梯度已是有效 batch 均值
opt.step(lr); zero_grad()
losses.push(step_loss_acc / N)
tokens_seen += N * B * S                         # 有效 batch tok

accum_steps 默认 1 → micro-loop 跑一次、scale(loss,1.0)、不在 micro 间 zero_grad本就如此 → 与现路径完全等价。每个 micro-step 的计算图在它自己的 backward 后即可释放Rust Rc 在 循环变量出作用域时 drop所以显存峰值 = 单个 micro-batch 的激活,不随 N 增长(③ 实测)。

抽样次序保持:单卡仍是连续从 RNG 抽 N·B 序列;与「大 batch 抽 N·B」逐序列对齐,只是分 N 组 forward——并集同序所以 Σ_all 的项一致。

③ 显存平 + 有效 batch 实测

「显存不随 N 增长」是 grad-accum 的卖点,要实测而非断言:固定有效 batch E = N·B,跑 (N=1,B=E)(大 batchvs (N=E,B=1)(极端累积),用 nvidia-smi/cudaMemGetInfo 量峰值显存—— 后者应显著低(少 N× 激活。train 入口打印 effective batch = accum_steps × batch

accum_steps=1 逐位回归

N=1 时 inner loop 跑一次、scale(loss, 1.0)ops::scale(_, 1.0) 的 fwd 是 value.scale(1.0)、bwd 是 grad.scale(1.0)——数学恒等。为绝对逐位(连一次 ×1.0 kernel 都不引入),实现里 N==1 直接 loss.backward()(跳过 scale与现路径字节一致。测试 accum1_bit_identical_to_no_accum 锁这条。

⑤ DDPall-reduce 门控到累积边界

T8 的 all_reduce_average_grads(params) 每 step 调一次。grad-accum 下只在最后一个 micro-step 之后调一次——中间 micro-step 的 backward 只在本卡 tape 里 SUM不发 NCCL

均值的账(沿用 T8 的「通信里 /worldclip 里 /b_local」拆分再叠加 ① 的 /N

每卡每 micro:  scale(loss, 1/N).backward()  → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/...
N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和
all-reduce(sum)+/world (累积边界一次):  跨卡求和后 /world
→ 每卡持有 Σ_global,(N·B) / (N · world · ?)   # 见下:用 1/N·scale 替代单卡的 1/b
clip pre_scale = 1.0

精确推导:每卡每 micro 的 loss_batched(B_local)本卡 mean over B_local·Sscale(1/N) 后 backward = Σ_local_micro / (N · B_local · S)。N 个 micro tape SUM → Σ_local_all / (N · B_local · S),其中 Σ_local_all = 本卡 N·B_local 行之和。 all_reduce(sum) 跨 world 卡 → Σ_global_all / (N · B_local · S)Σ_global_all = 全 world·N·B_local = N·B_global 行之和);/worldΣ_global_all / (N · B_local · S · world) = Σ_global_all / (N · B_global · S)(因 B_global = world·B_local)。这正是有效 batch N·B_global 的 mean 梯度——与单卡「有效 batch N·B_global 的大 batch 单 step」逐位等效 (求和顺序差进容差)。

关键正确性点:all_reduce_average_grads 里的 /world 是按 world 缩放(与 N 无关N 的 那个 1/N 已由 ① 的 scale 在每个 micro 的 backward 里完成。两者正交,不会互相污染。 单卡(world=1退化all-reduce 是 no-op/world=1,只剩 ① 的 1/N → 与 ② 一致。

DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。

⑥ 不变量小结

单卡基线(大 batch E 单卡 accumN×B=E DDP accumworld, N×B_local·world=E
loss 缩放 CE-mean 每 micro ×1/N 每 micro ×1/N
grad 累加 tape SUM 一批 tape SUM N 批 tape SUM N 批/卡
跨卡通信 仅累积边界 1 次 all-reduce + /world
clip pre_scale 1.0 1.0 1.0
显存峰值 E 的激活 B 的激活 B_local 的激活

验证方法(验收,全部 dash5 实跑 capture

GPU 测试 #[cfg(not(no_cuda))] 门控。

  1. 等效性(核心硬闸门) grad_accum.rs::accum_equiv_big_batch:同 init、同数据同序 跑「accum_steps=N, micro_batch=B」与「accum_steps=1, batch=N·B」各一 step断言 ①loss、②每个参数的 grad rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP 闸门约定)。多步版(跑 K 个 optimizer step再断言终参贴合(误差不发散)。
  2. accum_steps=1 逐位回归 grad_accum.rs::accum1_bit_identicalaccum_steps=1 与现 no-accum 路径同 init/同数据 → 每参数 grad max|Δ| == 0.0(④ 跳过 scale字节一致
  3. DDP+accum 对单卡 ddp_correctness.rs(扩既有 ddp_matches_single_gpu…):单卡 有效 batch E 的大 batch baseline vs world=2 + accum_steps=N(每卡每 micro B_local world·N·B_local=E)→ loss 轨迹 max_rel<1e-3、跨 rank 参数一致、且 only-at-boundary 通信 micro 间不发 NCCL由实现保证 + 不变量推导)。
  4. 显存平 + 有效 batch :固定有效 batch(N=1,大batch) vs (N=大,micro=1) 峰值显存 后者显著低train 入口打印 effective batch。capture nvidia-smi。
  5. 全回归套autograd grad-check / structural / batched==looped / bf16 / recompute逐位/ overfit 27/27 / AdamWGPU bit-exact + host vs torch/ DDP loss-match + 跨 rank / xserv 闭环 md5——accum_steps=1 默认值保证全部不回归。