docs: Phase T13 — activation recompute
Design doc for per-block gradient checkpointing (KI-3): the no-tape forward + recompute-on-backward design, the `checkpoint` primitive, per-block wrapping, the exactness/correctness argument (same kernels + inputs → identical grads), composition with bf16+DDP+batched, and the verification plan (on-vs-off grad gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD to fill after the dash5 run. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -48,8 +48,13 @@ fn build(cfg: Config, device: Device, dtype: DType, recompute: bool) -> TinyTran
|
||||
m.with_compute_dtype(dtype).with_recompute(recompute)
|
||||
}
|
||||
|
||||
/// Upcast to fp32 then read to host — logits are bf16 in bf16 mode (grads are
|
||||
/// always fp32 master, but this is uniform and harmless for fp32 tensors).
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {
|
||||
|
||||
88
docs/12-activation-recompute.md
Normal file
88
docs/12-activation-recompute.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# Phase T13: 激活重计算(gradient checkpointing)— Design Document
|
||||
|
||||
> KI-3 的具体落地。autograd tape 为反向保存了所有中间激活;dim768/bf16 在单卡 32GB 能跑 batch32(T12 解 OOM),但**容量轴放大到 dim1024 会再次 OOM**——激活显存随 dim 线性增长。激活重计算用「多一次前向」换显存:段内激活不在前向保存,反向时**重算该段前向**重建局部 tape 再回传。峰值激活从「所有 block 同时在显存」降到「~一个 block + 每 block 的输入」→ dim1024 batch32 装得下。
|
||||
|
||||
## Goal
|
||||
|
||||
在**不动非重计算路径任何数值**的前提下,新增一个 **opt-in 的 per-block 激活重计算**:
|
||||
|
||||
1. **正确性硬闸门(exact)**:重计算是**数学精确**的——同一段前向、同一输入、同一(未变的)参数值、确定性 kernel ⟹ 重算出的激活与原激活逐位相同,回传的梯度与非重计算版一致。直接的 **on-vs-off 梯度对拍**(紧容差)+ 全套回归(T4 grad-check、T5 overfit+PyTorch 对拍、T6 AdamW、T8 DDP loss-match+跨 rank、xserv 闭环)开 `--recompute` 全绿。**绝不提交一个改变梯度的重计算。**
|
||||
2. **显存(payoff)**:测 dim768 batch32 峰值显存 on vs off(应降);确认 **dim1024 batch32 现在装得下**(不开重计算时 OOM,开了 fit)。
|
||||
3. **吞吐**:测 tok/s on vs off(多一次前向,预计慢 ~20–35%)——报告 compute/memory 权衡。
|
||||
|
||||
## 什么是激活重计算
|
||||
|
||||
反向传播需要前向的中间激活(如 SwiGLU 的 `gate`、attention 的 `probs`)来算梯度。define-by-run 的 tape 默认把它们全部留在显存,直到对应 op 的 backward 跑完。模型越深、激活越多,峰值显存越高。
|
||||
|
||||
**梯度检查点**把模型切成若干**段(segment)**。前向时,段内 op **不记到 tape**(detached / no-grad),只保留**段的输入**(参数作为 leaf 本就常驻,不算激活)。反向时,当段的 output-grad 到达,**从保存的输入重跑该段前向**(输入作为 require-grad 的 leaf),用上游 grad seed 重算出的 output,在局部 tape 上回传,得到输入梯度(并累加参数梯度),然后释放局部 tape。
|
||||
|
||||
代价:每段多一次前向(约 +1/3 的总 FLOPs,因为反向本就 ~2× 前向)。收益:峰值激活从「所有段」降到「~一段 + 每段输入」。
|
||||
|
||||
**切粒度 = 每个 transformer block**:一个 block(attention 子块 + MLP 子块 + 两个残差)是天然的段边界,输入/输出都是 `[batch*seq, dim]` 的残差流张量,接口最干净。
|
||||
|
||||
## Module Layout(surgical:非重计算路径逐字节不动)
|
||||
|
||||
### 1. `xtrain-autodiff::checkpoint` — `checkpoint` 高阶原语
|
||||
|
||||
新增 `checkpoint(segment_fn, input, params) -> Var`,类比 `torch.utils.checkpoint`:
|
||||
|
||||
- `segment_fn: Fn(&Var, &[Var]) -> Var`——从单个输入 `x` 和参数 slice `p` 构建段前向、返回段输出。必须**确定性**、只依赖 `x` 和 `p`(这是重算精确的前提)。
|
||||
- **前向(不 tape 内部)**:把 `input`/`params` detach 成新 leaf(`Var::leaf(v.value())`),跑 `segment_fn` 得到 `out_local`,**只取 `out_local.value()`**。局部 `Var` 出作用域即 drop → 段内激活立即释放。checkpoint 节点的 parents = `[input, ..params]`(参数梯度落进优化器读的同一批 leaf)。
|
||||
- **反向(重算)**:闭包捕获 `Rc<segment_fn>`。给定 `dout`,从 `parents` 当前值重建 detached leaf,重跑 `segment_fn` 重建局部 tape,调 `out_local.backward_seeded(dout)`,再把 `x_det.grad()` push 给 `parents[0]`、各 `param_det.grad()` push 给对应参数 parent。闭包结束 → 局部 tape drop → 重算激活释放。
|
||||
|
||||
### 2. `xtrain-autodiff::tape` — `backward_seeded`
|
||||
|
||||
引擎原 `backward()` 只能从标量 root 出发(seed `ones_like` + 断言 numel==1)。段输出一般**非标量**,故新增 `Var::backward_seeded(seed)`:同样的 topo + 反向遍历,但用显式上游 grad seed(不断言标量)。`backward()` 退化为「seed ones」的薄包装——标量 loss 路径逐字节不变。
|
||||
|
||||
### 3. `xtrain-model::TinyTransformer` — per-block 包裹 + `recompute` 开关
|
||||
|
||||
- 把 block 前向体抽成**自由函数** `block_forward(cfg, compute_dtype, batch, seq, input, params)`(不借 `&self`,才能在反向闭包里重跑);同时把 `linear / norm_gamma / attention / swiglu_mlp` 改成参数化 `(cfg, compute_dtype)` 的自由函数。`Block::block_params()` 给出 11 个 leaf 的固定序(与 `params()` 每 block 段一致)。
|
||||
- `forward_batched` 的 block 循环:`recompute` 开 → `checkpoint(seg, &h, &b.block_params())`;关 → 直接 `block_forward(...)`(**与之前完全同图**)。
|
||||
- `with_recompute(bool)` builder(opt-in,默认关 = 原 tape,数值逐字节同)。
|
||||
|
||||
### 4. `xtrain-train` / `xtrain-distributed` — `--recompute` flag
|
||||
|
||||
- `bin/train` / `bin/train_ddp` 加 `--recompute`,调 `model.with_recompute(true)`。AdamW / clip / checkpoint / DDP all-reduce **不改**(梯度语义与非重计算一致)。
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
- **切粒度 = 每个 block**:接口最干净(输入输出都是残差流 `[B*S, dim]`),且峰值激活降到 ~1 个 block。比「整模型一段」省得多,比「逐 op」简单得多。
|
||||
- **参数作为 checkpoint 节点的 parents**:参数 leaf 跨前向/反向不变(只 grad 槽变),重算用**当前**参数值即原值。把它们列为 parents,重算恢复的参数梯度直接 push 到优化器读的同一批 leaf → **DDP/AdamW 零改动**。
|
||||
- **detached leaf 隔离局部 tape**:前向/反向都把输入和参数 detach 成新 leaf,使 `segment_fn` 构建的图与外层 tape 断开。前向丢弃局部图(释放激活);反向局部图回传完即 drop(释放重算激活)。
|
||||
- **`backward_seeded` 而非改 `backward`**:段输出非标量,需要用上游 output-grad 作 seed 回传局部 tape。新增方法、原标量 `backward()` 不动。
|
||||
- **重算精确 → 梯度逐位一致(硬闸门)**:同一 `segment_fn`、同一输入值、同一参数值、确定性前向 kernel ⟹ 重算 output 与原 output 相同,局部反向就是该段的普通解析反向。故输入/参数梯度与非重计算版一致——这是绝不能违反的闸门。
|
||||
|
||||
## 与 bf16 / DDP / batched 的组合
|
||||
|
||||
- **bf16(T12)**:`segment_fn` 就是不变的 block 前向,重算跑**同一条 bf16 路径**;`cast` 算子的 grad 升精度桥(bf16→fp32)照常。重计算节点的参数 parents 是 fp32 master leaf,恢复的是 fp32 梯度。on-vs-off 对拍同时跑 fp32 和 bf16 两路。
|
||||
- **DDP(T8)**:每个 rank 独立 checkpoint 自己的前向/反向;恢复的参数梯度落进各 rank 的 `.grad()` 槽,再被 all-reduce 取均值——分布式路径不感知重计算。
|
||||
- **batched(T10)**:段输入/输出透明带 `[batch*seq, …]` batch 维;`checkpoint` 与形状无关。
|
||||
|
||||
## 验证方法
|
||||
|
||||
### 1. 正确性(exact,硬闸门)
|
||||
|
||||
- **on-vs-off 梯度对拍**(`crates/xtrain-model/tests/recompute.rs`):同 init 建两个模型(recompute on/off),跑同一 batched loss+backward,断言**前向 logits、loss、每个参数梯度**在紧容差内一致——参数化跑 **fp32 和 bf16** 两路。fp32 期望近逐位(容差 1e-4),bf16 仅放松到 bf16 舍入级(非重计算误差)。
|
||||
- **全套回归开 `--recompute`**:T4 15 算子 grad-check、T5 overfit 27/27 + PyTorch 对拍、T6 AdamW、T8 DDP loss-match + 跨 rank、**xserv 闭环 md5**——全绿。
|
||||
|
||||
### 2. 显存(payoff)
|
||||
|
||||
- dash5 1× RTX 5090 32GB,dim768/18L batch32 seq256,bf16:测峰值显存 recompute on vs off(应降)。
|
||||
- **dim1024 batch32**:先验证不开重计算 **OOM**,再验证开了 **fit**——capture 实际 `nvidia-smi` 峰值。
|
||||
|
||||
### 3. 吞吐
|
||||
|
||||
- 同 config 测 steady-state tok/s recompute on vs off,报告慢多少(预计 ~20–35%,多一次前向)。
|
||||
|
||||
## 实测结果(dash5 1× RTX 5090 32GB, dim768/18L/24h×32 ffn2048 seq256, bf16, steady-state)
|
||||
|
||||
> 待 dash5 实跑回填。
|
||||
|
||||
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|
||||
|---|---|---|---|---|
|
||||
| dim768 recompute **off** | 32 | TBD | TBD | ✅ |
|
||||
| dim768 recompute **on** | 32 | **TBD(↓)** | TBD(↓~xx%) | ✅ |
|
||||
| **dim1024** recompute **off** | 32 | — | — | ❌ **OOM** |
|
||||
| **dim1024** recompute **on** | 32 | **TBD** | TBD | ✅ **解 OOM** |
|
||||
|
||||
**正确性**:on-vs-off 梯度 max rel = TBD(fp32)/ TBD(bf16);全套回归 + xserv 闭环全绿。
|
||||
Reference in New Issue
Block a user