Files
xtrain/docs/12-activation-recompute.md
Gahow Wang 69c5f07359 docs: Phase T13 — activation recompute
Design doc for per-block gradient checkpointing (KI-3): the no-tape forward +
recompute-on-backward design, the `checkpoint` primitive, per-block wrapping,
the exactness/correctness argument (same kernels + inputs → identical grads),
composition with bf16+DDP+batched, and the verification plan (on-vs-off grad
gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD
to fill after the dash5 run.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:45:16 +08:00

8.4 KiB
Raw Blame History

Phase T13: 激活重计算gradient checkpointing— Design Document

KI-3 的具体落地。autograd tape 为反向保存了所有中间激活dim768/bf16 在单卡 32GB 能跑 batch32T12 解 OOM容量轴放大到 dim1024 会再次 OOM——激活显存随 dim 线性增长。激活重计算用「多一次前向」换显存:段内激活不在前向保存,反向时重算该段前向重建局部 tape 再回传。峰值激活从「所有 block 同时在显存」降到「~一个 block + 每 block 的输入」→ dim1024 batch32 装得下。

Goal

不动非重计算路径任何数值的前提下,新增一个 opt-in 的 per-block 激活重计算

  1. 正确性硬闸门exact:重计算是数学精确的——同一段前向、同一输入、同一(未变的)参数值、确定性 kernel ⟹ 重算出的激活与原激活逐位相同,回传的梯度与非重计算版一致。直接的 on-vs-off 梯度对拍(紧容差)+ 全套回归T4 grad-check、T5 overfit+PyTorch 对拍、T6 AdamW、T8 DDP loss-match+跨 rank、xserv 闭环)开 --recompute 全绿。绝不提交一个改变梯度的重计算。
  2. 显存payoff:测 dim768 batch32 峰值显存 on vs off应降确认 dim1024 batch32 现在装得下(不开重计算时 OOM开了 fit
  3. 吞吐:测 tok/s on vs off多一次前向预计慢 ~2035%)——报告 compute/memory 权衡。

什么是激活重计算

反向传播需要前向的中间激活(如 SwiGLU 的 gate、attention 的 probs来算梯度。define-by-run 的 tape 默认把它们全部留在显存,直到对应 op 的 backward 跑完。模型越深、激活越多,峰值显存越高。

梯度检查点把模型切成若干segment。前向时,段内 op 不记到 tapedetached / no-grad只保留段的输入(参数作为 leaf 本就常驻,不算激活)。反向时,当段的 output-grad 到达,从保存的输入重跑该段前向(输入作为 require-grad 的 leaf用上游 grad seed 重算出的 output在局部 tape 上回传,得到输入梯度(并累加参数梯度),然后释放局部 tape。

代价:每段多一次前向(约 +1/3 的总 FLOPs因为反向本就 ~2× 前向)。收益:峰值激活从「所有段」降到「~一段 + 每段输入」。

切粒度 = 每个 transformer block:一个 blockattention 子块 + MLP 子块 + 两个残差)是天然的段边界,输入/输出都是 [batch*seq, dim] 的残差流张量,接口最干净。

Module Layoutsurgical非重计算路径逐字节不动

1. xtrain-autodiff::checkpointcheckpoint 高阶原语

新增 checkpoint(segment_fn, input, params) -> Var,类比 torch.utils.checkpoint

  • segment_fn: Fn(&Var, &[Var]) -> Var——从单个输入 x 和参数 slice p 构建段前向、返回段输出。必须确定性、只依赖 xp(这是重算精确的前提)。
  • 前向(不 tape 内部):把 input/params detach 成新 leafVar::leaf(v.value())),跑 segment_fn 得到 out_local只取 out_local.value()。局部 Var 出作用域即 drop → 段内激活立即释放。checkpoint 节点的 parents = [input, ..params](参数梯度落进优化器读的同一批 leaf
  • 反向(重算):闭包捕获 Rc<segment_fn>。给定 dout,从 parents 当前值重建 detached leaf重跑 segment_fn 重建局部 tapeout_local.backward_seeded(dout),再把 x_det.grad() push 给 parents[0]、各 param_det.grad() push 给对应参数 parent。闭包结束 → 局部 tape drop → 重算激活释放。

2. xtrain-autodiff::tapebackward_seeded

引擎原 backward() 只能从标量 root 出发seed ones_like + 断言 numel==1。段输出一般非标量,故新增 Var::backward_seeded(seed):同样的 topo + 反向遍历,但用显式上游 grad seed不断言标量backward() 退化为「seed ones」的薄包装——标量 loss 路径逐字节不变。

3. xtrain-model::TinyTransformer — per-block 包裹 + recompute 开关

  • 把 block 前向体抽成自由函数 block_forward(cfg, compute_dtype, batch, seq, input, params)(不借 &self,才能在反向闭包里重跑);同时把 linear / norm_gamma / attention / swiglu_mlp 改成参数化 (cfg, compute_dtype) 的自由函数。Block::block_params() 给出 11 个 leaf 的固定序(与 params() 每 block 段一致)。
  • forward_batched 的 block 循环:recompute 开 → checkpoint(seg, &h, &b.block_params());关 → 直接 block_forward(...)与之前完全同图)。
  • with_recompute(bool) builderopt-in默认关 = 原 tape数值逐字节同

4. xtrain-train / xtrain-distributed--recompute flag

  • bin/train / bin/train_ddp--recompute,调 model.with_recompute(true)。AdamW / clip / checkpoint / DDP all-reduce 不改(梯度语义与非重计算一致)。

Key Design Decisions

  • 切粒度 = 每个 block:接口最干净(输入输出都是残差流 [B*S, dim]),且峰值激活降到 ~1 个 block。比「整模型一段」省得多比「逐 op」简单得多。
  • 参数作为 checkpoint 节点的 parents:参数 leaf 跨前向/反向不变(只 grad 槽变),重算用当前参数值即原值。把它们列为 parents重算恢复的参数梯度直接 push 到优化器读的同一批 leaf → DDP/AdamW 零改动
  • detached leaf 隔离局部 tape:前向/反向都把输入和参数 detach 成新 leaf使 segment_fn 构建的图与外层 tape 断开。前向丢弃局部图(释放激活);反向局部图回传完即 drop释放重算激活
  • backward_seeded 而非改 backward:段输出非标量,需要用上游 output-grad 作 seed 回传局部 tape。新增方法、原标量 backward() 不动。
  • 重算精确 → 梯度逐位一致(硬闸门):同一 segment_fn、同一输入值、同一参数值、确定性前向 kernel ⟹ 重算 output 与原 output 相同,局部反向就是该段的普通解析反向。故输入/参数梯度与非重计算版一致——这是绝不能违反的闸门。

与 bf16 / DDP / batched 的组合

  • bf16T12segment_fn 就是不变的 block 前向,重算跑同一条 bf16 路径cast 算子的 grad 升精度桥bf16→fp32照常。重计算节点的参数 parents 是 fp32 master leaf恢复的是 fp32 梯度。on-vs-off 对拍同时跑 fp32 和 bf16 两路。
  • DDPT8:每个 rank 独立 checkpoint 自己的前向/反向;恢复的参数梯度落进各 rank 的 .grad() 槽,再被 all-reduce 取均值——分布式路径不感知重计算。
  • batchedT10:段输入/输出透明带 [batch*seq, …] batch 维;checkpoint 与形状无关。

验证方法

1. 正确性exact硬闸门

  • on-vs-off 梯度对拍crates/xtrain-model/tests/recompute.rs):同 init 建两个模型recompute on/off跑同一 batched loss+backward断言前向 logits、loss、每个参数梯度在紧容差内一致——参数化跑 fp32 和 bf16 两路。fp32 期望近逐位(容差 1e-4bf16 仅放松到 bf16 舍入级(非重计算误差)。
  • 全套回归开 --recomputeT4 15 算子 grad-check、T5 overfit 27/27 + PyTorch 对拍、T6 AdamW、T8 DDP loss-match + 跨 rank、xserv 闭环 md5——全绿。

2. 显存payoff

  • dash5 1× RTX 5090 32GBdim768/18L batch32 seq256bf16测峰值显存 recompute on vs off应降
  • dim1024 batch32:先验证不开重计算 OOM,再验证开了 fit——capture 实际 nvidia-smi 峰值。

3. 吞吐

  • 同 config 测 steady-state tok/s recompute on vs off报告慢多少预计 ~2035%,多一次前向)。

实测结果dash5 1× RTX 5090 32GB, dim768/18L/24h×32 ffn2048 seq256, bf16, steady-state

待 dash5 实跑回填。

config per-rank batch 峰值显存 tok/s fits 32GB?
dim768 recompute off 32 TBD TBD
dim768 recompute on 32 TBD TBD↓~xx%
dim1024 recompute off 32 OOM
dim1024 recompute on 32 TBD TBD 解 OOM

正确性on-vs-off 梯度 max rel = TBDfp32/ TBDbf16全套回归 + xserv 闭环全绿。