Design doc for per-block gradient checkpointing (KI-3): the no-tape forward + recompute-on-backward design, the `checkpoint` primitive, per-block wrapping, the exactness/correctness argument (same kernels + inputs → identical grads), composition with bf16+DDP+batched, and the verification plan (on-vs-off grad gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD to fill after the dash5 run. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
8.4 KiB
Phase T13: 激活重计算(gradient checkpointing)— Design Document
KI-3 的具体落地。autograd tape 为反向保存了所有中间激活;dim768/bf16 在单卡 32GB 能跑 batch32(T12 解 OOM),但容量轴放大到 dim1024 会再次 OOM——激活显存随 dim 线性增长。激活重计算用「多一次前向」换显存:段内激活不在前向保存,反向时重算该段前向重建局部 tape 再回传。峰值激活从「所有 block 同时在显存」降到「~一个 block + 每 block 的输入」→ dim1024 batch32 装得下。
Goal
在不动非重计算路径任何数值的前提下,新增一个 opt-in 的 per-block 激活重计算:
- 正确性硬闸门(exact):重计算是数学精确的——同一段前向、同一输入、同一(未变的)参数值、确定性 kernel ⟹ 重算出的激活与原激活逐位相同,回传的梯度与非重计算版一致。直接的 on-vs-off 梯度对拍(紧容差)+ 全套回归(T4 grad-check、T5 overfit+PyTorch 对拍、T6 AdamW、T8 DDP loss-match+跨 rank、xserv 闭环)开
--recompute全绿。绝不提交一个改变梯度的重计算。 - 显存(payoff):测 dim768 batch32 峰值显存 on vs off(应降);确认 dim1024 batch32 现在装得下(不开重计算时 OOM,开了 fit)。
- 吞吐:测 tok/s on vs off(多一次前向,预计慢 ~20–35%)——报告 compute/memory 权衡。
什么是激活重计算
反向传播需要前向的中间激活(如 SwiGLU 的 gate、attention 的 probs)来算梯度。define-by-run 的 tape 默认把它们全部留在显存,直到对应 op 的 backward 跑完。模型越深、激活越多,峰值显存越高。
梯度检查点把模型切成若干段(segment)。前向时,段内 op 不记到 tape(detached / no-grad),只保留段的输入(参数作为 leaf 本就常驻,不算激活)。反向时,当段的 output-grad 到达,从保存的输入重跑该段前向(输入作为 require-grad 的 leaf),用上游 grad seed 重算出的 output,在局部 tape 上回传,得到输入梯度(并累加参数梯度),然后释放局部 tape。
代价:每段多一次前向(约 +1/3 的总 FLOPs,因为反向本就 ~2× 前向)。收益:峰值激活从「所有段」降到「~一段 + 每段输入」。
切粒度 = 每个 transformer block:一个 block(attention 子块 + MLP 子块 + 两个残差)是天然的段边界,输入/输出都是 [batch*seq, dim] 的残差流张量,接口最干净。
Module Layout(surgical:非重计算路径逐字节不动)
1. xtrain-autodiff::checkpoint — checkpoint 高阶原语
新增 checkpoint(segment_fn, input, params) -> Var,类比 torch.utils.checkpoint:
segment_fn: Fn(&Var, &[Var]) -> Var——从单个输入x和参数 slicep构建段前向、返回段输出。必须确定性、只依赖x和p(这是重算精确的前提)。- 前向(不 tape 内部):把
input/paramsdetach 成新 leaf(Var::leaf(v.value())),跑segment_fn得到out_local,只取out_local.value()。局部Var出作用域即 drop → 段内激活立即释放。checkpoint 节点的 parents =[input, ..params](参数梯度落进优化器读的同一批 leaf)。 - 反向(重算):闭包捕获
Rc<segment_fn>。给定dout,从parents当前值重建 detached leaf,重跑segment_fn重建局部 tape,调out_local.backward_seeded(dout),再把x_det.grad()push 给parents[0]、各param_det.grad()push 给对应参数 parent。闭包结束 → 局部 tape drop → 重算激活释放。
2. xtrain-autodiff::tape — backward_seeded
引擎原 backward() 只能从标量 root 出发(seed ones_like + 断言 numel==1)。段输出一般非标量,故新增 Var::backward_seeded(seed):同样的 topo + 反向遍历,但用显式上游 grad seed(不断言标量)。backward() 退化为「seed ones」的薄包装——标量 loss 路径逐字节不变。
3. xtrain-model::TinyTransformer — per-block 包裹 + recompute 开关
- 把 block 前向体抽成自由函数
block_forward(cfg, compute_dtype, batch, seq, input, params)(不借&self,才能在反向闭包里重跑);同时把linear / norm_gamma / attention / swiglu_mlp改成参数化(cfg, compute_dtype)的自由函数。Block::block_params()给出 11 个 leaf 的固定序(与params()每 block 段一致)。 forward_batched的 block 循环:recompute开 →checkpoint(seg, &h, &b.block_params());关 → 直接block_forward(...)(与之前完全同图)。with_recompute(bool)builder(opt-in,默认关 = 原 tape,数值逐字节同)。
4. xtrain-train / xtrain-distributed — --recompute flag
bin/train/bin/train_ddp加--recompute,调model.with_recompute(true)。AdamW / clip / checkpoint / DDP all-reduce 不改(梯度语义与非重计算一致)。
Key Design Decisions
- 切粒度 = 每个 block:接口最干净(输入输出都是残差流
[B*S, dim]),且峰值激活降到 ~1 个 block。比「整模型一段」省得多,比「逐 op」简单得多。 - 参数作为 checkpoint 节点的 parents:参数 leaf 跨前向/反向不变(只 grad 槽变),重算用当前参数值即原值。把它们列为 parents,重算恢复的参数梯度直接 push 到优化器读的同一批 leaf → DDP/AdamW 零改动。
- detached leaf 隔离局部 tape:前向/反向都把输入和参数 detach 成新 leaf,使
segment_fn构建的图与外层 tape 断开。前向丢弃局部图(释放激活);反向局部图回传完即 drop(释放重算激活)。 backward_seeded而非改backward:段输出非标量,需要用上游 output-grad 作 seed 回传局部 tape。新增方法、原标量backward()不动。- 重算精确 → 梯度逐位一致(硬闸门):同一
segment_fn、同一输入值、同一参数值、确定性前向 kernel ⟹ 重算 output 与原 output 相同,局部反向就是该段的普通解析反向。故输入/参数梯度与非重计算版一致——这是绝不能违反的闸门。
与 bf16 / DDP / batched 的组合
- bf16(T12):
segment_fn就是不变的 block 前向,重算跑同一条 bf16 路径;cast算子的 grad 升精度桥(bf16→fp32)照常。重计算节点的参数 parents 是 fp32 master leaf,恢复的是 fp32 梯度。on-vs-off 对拍同时跑 fp32 和 bf16 两路。 - DDP(T8):每个 rank 独立 checkpoint 自己的前向/反向;恢复的参数梯度落进各 rank 的
.grad()槽,再被 all-reduce 取均值——分布式路径不感知重计算。 - batched(T10):段输入/输出透明带
[batch*seq, …]batch 维;checkpoint与形状无关。
验证方法
1. 正确性(exact,硬闸门)
- on-vs-off 梯度对拍(
crates/xtrain-model/tests/recompute.rs):同 init 建两个模型(recompute on/off),跑同一 batched loss+backward,断言前向 logits、loss、每个参数梯度在紧容差内一致——参数化跑 fp32 和 bf16 两路。fp32 期望近逐位(容差 1e-4),bf16 仅放松到 bf16 舍入级(非重计算误差)。 - 全套回归开
--recompute:T4 15 算子 grad-check、T5 overfit 27/27 + PyTorch 对拍、T6 AdamW、T8 DDP loss-match + 跨 rank、xserv 闭环 md5——全绿。
2. 显存(payoff)
- dash5 1× RTX 5090 32GB,dim768/18L batch32 seq256,bf16:测峰值显存 recompute on vs off(应降)。
- dim1024 batch32:先验证不开重计算 OOM,再验证开了 fit——capture 实际
nvidia-smi峰值。
3. 吞吐
- 同 config 测 steady-state tok/s recompute on vs off,报告慢多少(预计 ~20–35%,多一次前向)。
实测结果(dash5 1× RTX 5090 32GB, dim768/18L/24h×32 ffn2048 seq256, bf16, steady-state)
待 dash5 实跑回填。
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|---|---|---|---|---|
| dim768 recompute off | 32 | TBD | TBD | ✅ |
| dim768 recompute on | 32 | TBD(↓) | TBD(↓~xx%) | ✅ |
| dim1024 recompute off | 32 | — | — | ❌ OOM |
| dim1024 recompute on | 32 | TBD | TBD | ✅ 解 OOM |
正确性:on-vs-off 梯度 max rel = TBD(fp32)/ TBD(bf16);全套回归 + xserv 闭环全绿。