Files
xtrain/docs/12-activation-recompute.md
Gahow Wang 0150263055 perf: KI-3 fixed — dim1024 batch32 fits, mem 31.1→14.6GB, tok/s 39.7K→31.5K
Per-block activation recompute (T13) measured on dash5 (1× RTX 5090 32GB, bf16,
batch32 seq256, steady-state):

- Correctness (exact, hard gate): recompute on-vs-off grads are BIT-IDENTICAL —
  fp32 AND bf16: loss / logits / every param grad max rel = 0.00e0 (not "within
  tol", exactly equal). Full suite green with recompute on/off; DDP loss-match
  5.67e-7; DDP+recompute 2-rank descends 11.079→6.010.
- dim768 (18L/24h ffn2048, core 127M): peak mem 31144→14562 MiB (−53%), tok/s
  39.7K→31.5K (−20%, the extra-forward tradeoff, in the predicted 20–35% band).
- dim1024 (18L/32h ffn2730, core 226M): recompute OFF OOMs (hits 32100/32607
  MiB → OutOfMemory); recompute ON fits at 16596 MiB, ~23K tok/s, converges.
  → KI-3 payoff achieved: dim1024 batch32 unblocked, v8 can proceed.

Fill docs/12 bench table; mark KI-3 FIXED in docs/known-issues.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:50:29 +08:00

9.3 KiB
Raw Permalink Blame History

Phase T13: 激活重计算gradient checkpointing— Design Document

KI-3 的具体落地。autograd tape 为反向保存了所有中间激活dim768/bf16 在单卡 32GB 能跑 batch32T12 解 OOM容量轴放大到 dim1024 会再次 OOM——激活显存随 dim 线性增长。激活重计算用「多一次前向」换显存:段内激活不在前向保存,反向时重算该段前向重建局部 tape 再回传。峰值激活从「所有 block 同时在显存」降到「~一个 block + 每 block 的输入」→ dim1024 batch32 装得下。

Goal

不动非重计算路径任何数值的前提下,新增一个 opt-in 的 per-block 激活重计算

  1. 正确性硬闸门exact:重计算是数学精确的——同一段前向、同一输入、同一(未变的)参数值、确定性 kernel ⟹ 重算出的激活与原激活逐位相同,回传的梯度与非重计算版一致。直接的 on-vs-off 梯度对拍(紧容差)+ 全套回归T4 grad-check、T5 overfit+PyTorch 对拍、T6 AdamW、T8 DDP loss-match+跨 rank、xserv 闭环)开 --recompute 全绿。绝不提交一个改变梯度的重计算。
  2. 显存payoff:测 dim768 batch32 峰值显存 on vs off应降确认 dim1024 batch32 现在装得下(不开重计算时 OOM开了 fit
  3. 吞吐:测 tok/s on vs off多一次前向预计慢 ~2035%)——报告 compute/memory 权衡。

什么是激活重计算

反向传播需要前向的中间激活(如 SwiGLU 的 gate、attention 的 probs来算梯度。define-by-run 的 tape 默认把它们全部留在显存,直到对应 op 的 backward 跑完。模型越深、激活越多,峰值显存越高。

梯度检查点把模型切成若干segment。前向时,段内 op 不记到 tapedetached / no-grad只保留段的输入(参数作为 leaf 本就常驻,不算激活)。反向时,当段的 output-grad 到达,从保存的输入重跑该段前向(输入作为 require-grad 的 leaf用上游 grad seed 重算出的 output在局部 tape 上回传,得到输入梯度(并累加参数梯度),然后释放局部 tape。

代价:每段多一次前向(约 +1/3 的总 FLOPs因为反向本就 ~2× 前向)。收益:峰值激活从「所有段」降到「~一段 + 每段输入」。

切粒度 = 每个 transformer block:一个 blockattention 子块 + MLP 子块 + 两个残差)是天然的段边界,输入/输出都是 [batch*seq, dim] 的残差流张量,接口最干净。

Module Layoutsurgical非重计算路径逐字节不动

1. xtrain-autodiff::checkpointcheckpoint 高阶原语

新增 checkpoint(segment_fn, input, params) -> Var,类比 torch.utils.checkpoint

  • segment_fn: Fn(&Var, &[Var]) -> Var——从单个输入 x 和参数 slice p 构建段前向、返回段输出。必须确定性、只依赖 xp(这是重算精确的前提)。
  • 前向(不 tape 内部):把 input/params detach 成新 leafVar::leaf(v.value())),跑 segment_fn 得到 out_local只取 out_local.value()。局部 Var 出作用域即 drop → 段内激活立即释放。checkpoint 节点的 parents = [input, ..params](参数梯度落进优化器读的同一批 leaf
  • 反向(重算):闭包捕获 Rc<segment_fn>。给定 dout,从 parents 当前值重建 detached leaf重跑 segment_fn 重建局部 tapeout_local.backward_seeded(dout),再把 x_det.grad() push 给 parents[0]、各 param_det.grad() push 给对应参数 parent。闭包结束 → 局部 tape drop → 重算激活释放。

2. xtrain-autodiff::tapebackward_seeded

引擎原 backward() 只能从标量 root 出发seed ones_like + 断言 numel==1。段输出一般非标量,故新增 Var::backward_seeded(seed):同样的 topo + 反向遍历,但用显式上游 grad seed不断言标量backward() 退化为「seed ones」的薄包装——标量 loss 路径逐字节不变。

3. xtrain-model::TinyTransformer — per-block 包裹 + recompute 开关

  • 把 block 前向体抽成自由函数 block_forward(cfg, compute_dtype, batch, seq, input, params)(不借 &self,才能在反向闭包里重跑);同时把 linear / norm_gamma / attention / swiglu_mlp 改成参数化 (cfg, compute_dtype) 的自由函数。Block::block_params() 给出 11 个 leaf 的固定序(与 params() 每 block 段一致)。
  • forward_batched 的 block 循环:recompute 开 → checkpoint(seg, &h, &b.block_params());关 → 直接 block_forward(...)与之前完全同图)。
  • with_recompute(bool) builderopt-in默认关 = 原 tape数值逐字节同

4. xtrain-train / xtrain-distributed--recompute flag

  • bin/train / bin/train_ddp--recompute,调 model.with_recompute(true)。AdamW / clip / checkpoint / DDP all-reduce 不改(梯度语义与非重计算一致)。

Key Design Decisions

  • 切粒度 = 每个 block:接口最干净(输入输出都是残差流 [B*S, dim]),且峰值激活降到 ~1 个 block。比「整模型一段」省得多比「逐 op」简单得多。
  • 参数作为 checkpoint 节点的 parents:参数 leaf 跨前向/反向不变(只 grad 槽变),重算用当前参数值即原值。把它们列为 parents重算恢复的参数梯度直接 push 到优化器读的同一批 leaf → DDP/AdamW 零改动
  • detached leaf 隔离局部 tape:前向/反向都把输入和参数 detach 成新 leaf使 segment_fn 构建的图与外层 tape 断开。前向丢弃局部图(释放激活);反向局部图回传完即 drop释放重算激活
  • backward_seeded 而非改 backward:段输出非标量,需要用上游 output-grad 作 seed 回传局部 tape。新增方法、原标量 backward() 不动。
  • 重算精确 → 梯度逐位一致(硬闸门):同一 segment_fn、同一输入值、同一参数值、确定性前向 kernel ⟹ 重算 output 与原 output 相同,局部反向就是该段的普通解析反向。故输入/参数梯度与非重计算版一致——这是绝不能违反的闸门。

与 bf16 / DDP / batched 的组合

  • bf16T12segment_fn 就是不变的 block 前向,重算跑同一条 bf16 路径cast 算子的 grad 升精度桥bf16→fp32照常。重计算节点的参数 parents 是 fp32 master leaf恢复的是 fp32 梯度。on-vs-off 对拍同时跑 fp32 和 bf16 两路。
  • DDPT8:每个 rank 独立 checkpoint 自己的前向/反向;恢复的参数梯度落进各 rank 的 .grad() 槽,再被 all-reduce 取均值——分布式路径不感知重计算。
  • batchedT10:段输入/输出透明带 [batch*seq, …] batch 维;checkpoint 与形状无关。

验证方法

1. 正确性exact硬闸门

  • on-vs-off 梯度对拍crates/xtrain-model/tests/recompute.rs):同 init 建两个模型recompute on/off跑同一 batched loss+backward断言前向 logits、loss、每个参数梯度在紧容差内一致——参数化跑 fp32 和 bf16 两路。fp32 期望近逐位(容差 1e-4bf16 仅放松到 bf16 舍入级(非重计算误差)。
  • 全套回归开 --recomputeT4 15 算子 grad-check、T5 overfit 27/27 + PyTorch 对拍、T6 AdamW、T8 DDP loss-match + 跨 rank、xserv 闭环 md5——全绿。

2. 显存payoff

  • dash5 1× RTX 5090 32GBdim768/18L batch32 seq256bf16测峰值显存 recompute on vs off应降
  • dim1024 batch32:先验证不开重计算 OOM,再验证开了 fit——capture 实际 nvidia-smi 峰值。

3. 吞吐

  • 同 config 测 steady-state tok/s recompute on vs off报告慢多少预计 ~2035%,多一次前向)。

实测结果dash5 1× RTX 5090 32GB, bf16, batch32 seq256, steady-state

正确性exact硬闸门on-vs-off 梯度对拍 —— fp32 与 bf16 双路都逐位一致loss rel 0.00e0、logits max rel 0.00e0每个参数梯度 max rel 0.00e0(不是「在容差内」,是逐位相同——证实重算确实精确)。全套回归开/关重计算全绿T4 15 算子 grad-check、5 结构、batched、bf16、overfit、AdamWGPU+host、GEMM、checkpoint roundtrip、T8 DDP loss 对单卡 5.67e-7 + 跨 rank 0.0DDP+recompute 2 卡短训 loss 单调降11.079→6.010)。

显存 + 吞吐dim768 = 18L/24h×32/ffn2048 core 127Mdim1024 = 18L/32h×32/ffn2730 core 226M

config per-rank batch 峰值显存 tok/s fits 32GB?
dim768 recompute off 32 31144 MiB 39.7K
dim768 recompute on 32 14562 MiB53% 31.5K20%
dim1024 recompute off 32 32100 → OOM OOM
dim1024 recompute on 32 16596 MiB 23.1K 解 OOM

→ dim768重计算把峰值显存 31.1→14.6GB53%~砍半激活),代价 tok/s 20%(多一次前向,落在预测 2035% 区间。dim1024 batch32不开 OOM撞 32100/32607MiB 上限)→ 开了 16.6GB 稳稳装下~23K tok/s 训练正常收敛 —— KI-3 的目标达成dim1024 解锁