# Phase T16: Gradient Accumulation — Design Document ## Goal 在已有的训练 loop(T6/T10)与 DDP(T8)之上,加 **micro-batch 梯度累积**:把 `accum_steps=N` 个 **micro-step** 的梯度在 tape 里累加起来,再做**一次** `AdamW.step` + `zero_grad`——得到 **有效 batch = N × micro_batch** 的更新,而显存只占**一个 micro-batch** 的激活峰值(不随 N 增长)。 两条硬约束: 1. **数值等效**:`accum_steps=N`(N 个 micro-step 后一次 step)必须对住「一个 N× 大 batch 的单 step」——梯度/loss 在仓内既有容差内**逐位贴合**。这是核心等效性证明。 2. **DDP 只在累积边界通信**:`world>1` 下,N 个 micro-step 里**只在最后一个**做 all-reduce (中间 micro-step **跳过跨卡通信**),最终喂给优化器的仍是 global 有效 batch 的均值梯度, loss 对单卡。 并暴露 train 入口的 `--accum-steps` flag。`accum_steps=1` 必须对当前无累积路径**逐位一致** (回归保护)。 **不做**:micro-batch 间变 LR / 变 batch(恒定 micro_batch);累积里换 dropout RNG(T18 才有 dropout);ZeRO(T17)。本 Phase 只动**优化器 step 的节奏**与 **DDP 通信门控**,复用 tape 既有 的 SUM 累加。 ## Module Layout ``` crates/xtrain-train/src/ ├── train_loop.rs # TrainConfig += accum_steps;inner micro-loop(缩放 loss + tape SUM) └── bin/train.rs # 新 --accum-steps flag;打印有效 batch crates/xtrain-distributed/src/ └── ddp.rs # DdpConfig += accum_steps;all-reduce 门控到累积边界 crates/xtrain-train/tests/ └── grad_accum.rs # 等效性硬闸门 + accum_steps=1 逐位回归(单卡) crates/xtrain-distributed/tests/ └── ddp_correctness.rs # += DDP+accum 对单卡(复用既有 ddp_matches… 框架) docs/15-grad-accum.md # 本文 ``` 无新 crate、无新 kernel、无新 autograd op——梯度累积是**纯调度**:tape 早已 SUM 累加, 缩放用既有 `ops::scale`,DDP 通信用既有 `all_reduce_average_grads`,只是改**调用节奏与门控**。 ## Key Design Decisions ### ① 等效性的数学:缩放每个 micro-loss 为 `1/N` 模型的 `loss_batched` 是 **CE-mean over `batch*seq` 行**(见 `model.rs`)。设一个 micro-batch 有 `B` 序列、seq 长 `S`,记某 micro-step 那批 `B*S` 行的 per-row 梯度之和为 `Σ_micro`: - **大 batch 基线**(有效 batch `N·B`):一次 `loss_batched(N·B 序列)` = CE-mean over `N·B·S` 行 → backward 给 `G_big = Σ_all / (N·B·S)`,其中 `Σ_all = Σ_n Σ_micro_n`。 - **累积**(N 个 micro-step,每个 `B`):micro-step n 的 `loss_batched(B)` = CE-mean over `B·S` 行 → 若直接 backward 得 `Σ_micro_n / (B·S)`;**N 个 backward 之间不 `zero_grad`**,tape SUM 累加 → `Σ_n Σ_micro_n / (B·S) = Σ_all / (B·S) = N · G_big`。 差一个因子 N。修正:**每个 micro-loss 先 `ops::scale(loss, 1/N)` 再 backward**——`scale` 的 backward 把上游梯度乘 `1/N`(见 `ops.rs`),于是每个 micro 贡献 `Σ_micro_n / (N·B·S)`, 累积后 `Σ_all / (N·B·S) = G_big`,**与大 batch 逐位等效**(仅 fp 求和顺序不同 → 进容差,和 T8 DDP-vs-单卡同性质)。 > 为什么不在 clip 里用 `pre_scale=1/N`?clip 的 `pre_scale` 已被 batch-mean 占用(=1.0)。 > 在 loss 上 `scale(1/N)` 更内聚:缩放穿过既有 autograd,不碰 clip/optimizer,且 `N=1` 时 > `scale(1.0)` 的 backward 是恒等乘 1 —— 这正是 `accum_steps=1` 逐位回归的保证(见 ④)。 报告的 step-loss = N 个 micro 的**原始** loss(未缩放值)之和 / N = 有效 batch 的 mean loss, 和大 batch 的单一 mean loss 一致(同样仅求和顺序差)。 ### ② 单卡 train loop:inner micro-loop 每个 optimizer step: ```text for micro in 0..N: 抽 B 序列 → loss = loss_batched(B) step_loss_acc += raw_loss(loss) # 累报告用的原始 loss scale(loss, 1/N).backward() # tape SUM 累加缩放后的梯度 # —— 累积边界 —— clip_grad_norm_gpu(params, max_norm, 1.0) # 梯度已是有效 batch 均值 opt.step(lr); zero_grad() losses.push(step_loss_acc / N) tokens_seen += N * B * S # 有效 batch tok ``` `accum_steps` 默认 1 → micro-loop 跑一次、`scale(loss,1.0)`、不在 micro 间 zero_grad(本就如此) → 与现路径完全等价。**每个 micro-step 的计算图在它自己的 backward 后即可释放**(Rust `Rc` 在 循环变量出作用域时 drop),所以**显存峰值 = 单个 micro-batch 的激活**,不随 N 增长(③ 实测)。 抽样次序保持:单卡仍是连续从 RNG 抽 `N·B` 序列;与「大 batch 抽 `N·B`」逐序列对齐,只是分 N 组 forward——并集同序,所以 `Σ_all` 的项一致。 ### ③ 显存平 + 有效 batch 实测 「显存不随 N 增长」是 grad-accum 的卖点,要**实测**而非断言:固定有效 batch `E = N·B`,跑 `(N=1,B=E)`(大 batch)vs `(N=E,B=1)`(极端累积),用 `nvidia-smi`/`cudaMemGetInfo` 量峰值显存—— 后者应**显著低**(少 N× 激活)。train 入口打印 `effective batch = accum_steps × batch`。 ### ④ `accum_steps=1` 逐位回归 `N=1` 时 inner loop 跑一次、`scale(loss, 1.0)`。`ops::scale(_, 1.0)` 的 fwd 是 `value.scale(1.0)`、bwd 是 `grad.scale(1.0)`——数学恒等。为**绝对**逐位(连一次 `×1.0` kernel 都不引入),实现里 `N==1` 直接 `loss.backward()`(跳过 scale),与现路径**字节一致**。测试 `accum1_bit_identical_to_no_accum` 锁这条。 ### ⑤ DDP:all-reduce 门控到累积边界 T8 的 `all_reduce_average_grads(params)` 每 step 调一次。grad-accum 下**只在最后一个 micro-step 之后调一次**——中间 micro-step 的 backward 只在本卡 tape 里 SUM,**不发 NCCL**。 均值的账(沿用 T8 的「通信里 /world,clip 里 /b_local」拆分,再叠加 ① 的 /N): ```text 每卡每 micro: scale(loss, 1/N).backward() → 本卡 tape SUM 该 micro 的 (Σ_micro / N)/... N 个 micro 后, 本卡 grad = Σ_{micro∈本卡所有micro} ... = 本卡 N·B_local 行的 (1/N) 缩放和 all-reduce(sum)+/world (累积边界一次): 跨卡求和后 /world → 每卡持有 Σ_global,(N·B) / (N · world · ?) # 见下:用 1/N·scale 替代单卡的 1/b clip pre_scale = 1.0 ``` 精确推导:每卡每 micro 的 `loss_batched(B_local)` 是 **本卡 mean over `B_local·S` 行**。 `scale(1/N)` 后 backward = `Σ_local_micro / (N · B_local · S)`。N 个 micro tape SUM → `Σ_local_all / (N · B_local · S)`,其中 `Σ_local_all` = 本卡 `N·B_local` 行之和。 `all_reduce(sum)` 跨 world 卡 → `Σ_global_all / (N · B_local · S)`(`Σ_global_all` = 全 `world·N·B_local = N·B_global` 行之和);`/world` → `Σ_global_all / (N · B_local · S · world)` `= Σ_global_all / (N · B_global · S)`(因 `B_global = world·B_local`)。这正是**有效 batch `N·B_global` 的 mean 梯度**——与单卡「有效 batch `N·B_global` 的大 batch 单 step」逐位等效 (求和顺序差进容差)。 > 关键正确性点:`all_reduce_average_grads` 里的 `/world` 是按 **world** 缩放(与 N 无关);N 的 > 那个 `1/N` 已由 ① 的 `scale` 在每个 micro 的 backward 里完成。两者正交,不会互相污染。 > 单卡(`world=1`)退化:all-reduce 是 no-op,`/world=1`,只剩 ① 的 `1/N` → 与 ② 一致。 DDP 报告 loss = N 个 micro 的本卡原始 loss·B_local 之和、跨卡 all-reduce(sum)、/(N·B_global)。 ### ⑥ 不变量小结 | | 单卡基线(大 batch E) | 单卡 accum(N×B=E) | DDP accum(world, N×B_local·world=E) | |---|---|---|---| | loss 缩放 | 无(CE-mean) | 每 micro `×1/N` | 每 micro `×1/N` | | grad 累加 | tape SUM 一批 | tape SUM N 批 | tape SUM N 批/卡 | | 跨卡通信 | — | — | **仅累积边界 1 次** all-reduce + /world | | clip pre_scale | 1.0 | 1.0 | 1.0 | | 显存峰值 | E 的激活 | **B 的激活** | **B_local 的激活** | ## 验证方法(验收,全部 dash5 实跑 capture) GPU 测试 `#[cfg(not(no_cuda))]` 门控。 1. **等效性(核心硬闸门)** `grad_accum.rs::accum_equiv_big_batch`:同 init、同数据同序, 跑「`accum_steps=N`, micro_batch=B」与「`accum_steps=1`, batch=N·B」各一 step,断言 ①loss、②**每个参数的 grad** rel-err 进 fp 容差(求和顺序差,~1e-4 量级,对齐 recompute/DDP 闸门约定)。多步版(跑 K 个 optimizer step)再断言**终参**贴合(误差不发散)。 2. **`accum_steps=1` 逐位回归** `grad_accum.rs::accum1_bit_identical`:`accum_steps=1` 与现 no-accum 路径同 init/同数据 → 每参数 grad `max|Δ| == 0.0`(④ 跳过 scale,字节一致)。 3. **DDP+accum 对单卡** `ddp_correctness.rs`(扩既有 `ddp_matches_single_gpu…`):单卡 有效 batch `E` 的大 batch baseline vs `world=2 + accum_steps=N`(每卡每 micro `B_local`, `world·N·B_local=E`)→ loss 轨迹 `max_rel<1e-3`、跨 rank 参数一致、且 only-at-boundary 通信 (micro 间不发 NCCL,由实现保证 + 不变量推导)。 4. **显存平 + 有效 batch** :固定有效 batch,量 `(N=1,大batch)` vs `(N=大,micro=1)` 峰值显存 (后者显著低),train 入口打印 effective batch。capture nvidia-smi。 5. **全回归套**:autograd grad-check / structural / batched==looped / bf16 / recompute(逐位)/ overfit 27/27 / AdamW(GPU bit-exact + host vs torch)/ DDP loss-match + 跨 rank / **xserv 闭环 md5**——`accum_steps=1` 默认值保证全部不回归。