docs: Phase T13 — activation recompute

Design doc for per-block gradient checkpointing (KI-3): the no-tape forward +
recompute-on-backward design, the `checkpoint` primitive, per-block wrapping,
the exactness/correctness argument (same kernels + inputs → identical grads),
composition with bf16+DDP+batched, and the verification plan (on-vs-off grad
gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD
to fill after the dash5 run.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 09:43:56 +08:00
parent f202351be5
commit 69c5f07359
2 changed files with 94 additions and 1 deletions

View File

@@ -48,8 +48,13 @@ fn build(cfg: Config, device: Device, dtype: DType, recompute: bool) -> TinyTran
m.with_compute_dtype(dtype).with_recompute(recompute)
}
/// Upcast to fp32 then read to host — logits are bf16 in bf16 mode (grads are
/// always fp32 master, but this is uniform and harmless for fp32 tensors).
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
t.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {

View File

@@ -0,0 +1,88 @@
# Phase T13: 激活重计算gradient checkpointing— Design Document
> KI-3 的具体落地。autograd tape 为反向保存了所有中间激活dim768/bf16 在单卡 32GB 能跑 batch32T12 解 OOM但**容量轴放大到 dim1024 会再次 OOM**——激活显存随 dim 线性增长。激活重计算用「多一次前向」换显存:段内激活不在前向保存,反向时**重算该段前向**重建局部 tape 再回传。峰值激活从「所有 block 同时在显存」降到「~一个 block + 每 block 的输入」→ dim1024 batch32 装得下。
## Goal
在**不动非重计算路径任何数值**的前提下,新增一个 **opt-in 的 per-block 激活重计算**
1. **正确性硬闸门exact**:重计算是**数学精确**的——同一段前向、同一输入、同一(未变的)参数值、确定性 kernel ⟹ 重算出的激活与原激活逐位相同,回传的梯度与非重计算版一致。直接的 **on-vs-off 梯度对拍**(紧容差)+ 全套回归T4 grad-check、T5 overfit+PyTorch 对拍、T6 AdamW、T8 DDP loss-match+跨 rank、xserv 闭环)开 `--recompute` 全绿。**绝不提交一个改变梯度的重计算。**
2. **显存payoff**:测 dim768 batch32 峰值显存 on vs off应降确认 **dim1024 batch32 现在装得下**(不开重计算时 OOM开了 fit
3. **吞吐**:测 tok/s on vs off多一次前向预计慢 ~2035%)——报告 compute/memory 权衡。
## 什么是激活重计算
反向传播需要前向的中间激活(如 SwiGLU 的 `gate`、attention 的 `probs`来算梯度。define-by-run 的 tape 默认把它们全部留在显存,直到对应 op 的 backward 跑完。模型越深、激活越多,峰值显存越高。
**梯度检查点**把模型切成若干**段segment**。前向时,段内 op **不记到 tape**detached / no-grad只保留**段的输入**(参数作为 leaf 本就常驻,不算激活)。反向时,当段的 output-grad 到达,**从保存的输入重跑该段前向**(输入作为 require-grad 的 leaf用上游 grad seed 重算出的 output在局部 tape 上回传,得到输入梯度(并累加参数梯度),然后释放局部 tape。
代价:每段多一次前向(约 +1/3 的总 FLOPs因为反向本就 ~2× 前向)。收益:峰值激活从「所有段」降到「~一段 + 每段输入」。
**切粒度 = 每个 transformer block**:一个 blockattention 子块 + MLP 子块 + 两个残差)是天然的段边界,输入/输出都是 `[batch*seq, dim]` 的残差流张量,接口最干净。
## Module Layoutsurgical非重计算路径逐字节不动
### 1. `xtrain-autodiff::checkpoint` — `checkpoint` 高阶原语
新增 `checkpoint(segment_fn, input, params) -> Var`,类比 `torch.utils.checkpoint`
- `segment_fn: Fn(&Var, &[Var]) -> Var`——从单个输入 `x` 和参数 slice `p` 构建段前向、返回段输出。必须**确定性**、只依赖 `x``p`(这是重算精确的前提)。
- **前向(不 tape 内部)**:把 `input`/`params` detach 成新 leaf`Var::leaf(v.value())`),跑 `segment_fn` 得到 `out_local`**只取 `out_local.value()`**。局部 `Var` 出作用域即 drop → 段内激活立即释放。checkpoint 节点的 parents = `[input, ..params]`(参数梯度落进优化器读的同一批 leaf
- **反向(重算)**:闭包捕获 `Rc<segment_fn>`。给定 `dout`,从 `parents` 当前值重建 detached leaf重跑 `segment_fn` 重建局部 tape`out_local.backward_seeded(dout)`,再把 `x_det.grad()` push 给 `parents[0]`、各 `param_det.grad()` push 给对应参数 parent。闭包结束 → 局部 tape drop → 重算激活释放。
### 2. `xtrain-autodiff::tape` — `backward_seeded`
引擎原 `backward()` 只能从标量 root 出发seed `ones_like` + 断言 numel==1。段输出一般**非标量**,故新增 `Var::backward_seeded(seed)`:同样的 topo + 反向遍历,但用显式上游 grad seed不断言标量`backward()` 退化为「seed ones」的薄包装——标量 loss 路径逐字节不变。
### 3. `xtrain-model::TinyTransformer` — per-block 包裹 + `recompute` 开关
- 把 block 前向体抽成**自由函数** `block_forward(cfg, compute_dtype, batch, seq, input, params)`(不借 `&self`,才能在反向闭包里重跑);同时把 `linear / norm_gamma / attention / swiglu_mlp` 改成参数化 `(cfg, compute_dtype)` 的自由函数。`Block::block_params()` 给出 11 个 leaf 的固定序(与 `params()` 每 block 段一致)。
- `forward_batched` 的 block 循环:`recompute` 开 → `checkpoint(seg, &h, &b.block_params())`;关 → 直接 `block_forward(...)`**与之前完全同图**)。
- `with_recompute(bool)` builderopt-in默认关 = 原 tape数值逐字节同
### 4. `xtrain-train` / `xtrain-distributed` — `--recompute` flag
- `bin/train` / `bin/train_ddp``--recompute`,调 `model.with_recompute(true)`。AdamW / clip / checkpoint / DDP all-reduce **不改**(梯度语义与非重计算一致)。
## Key Design Decisions
- **切粒度 = 每个 block**:接口最干净(输入输出都是残差流 `[B*S, dim]`),且峰值激活降到 ~1 个 block。比「整模型一段」省得多比「逐 op」简单得多。
- **参数作为 checkpoint 节点的 parents**:参数 leaf 跨前向/反向不变(只 grad 槽变),重算用**当前**参数值即原值。把它们列为 parents重算恢复的参数梯度直接 push 到优化器读的同一批 leaf → **DDP/AdamW 零改动**
- **detached leaf 隔离局部 tape**:前向/反向都把输入和参数 detach 成新 leaf使 `segment_fn` 构建的图与外层 tape 断开。前向丢弃局部图(释放激活);反向局部图回传完即 drop释放重算激活
- **`backward_seeded` 而非改 `backward`**:段输出非标量,需要用上游 output-grad 作 seed 回传局部 tape。新增方法、原标量 `backward()` 不动。
- **重算精确 → 梯度逐位一致(硬闸门)**:同一 `segment_fn`、同一输入值、同一参数值、确定性前向 kernel ⟹ 重算 output 与原 output 相同,局部反向就是该段的普通解析反向。故输入/参数梯度与非重计算版一致——这是绝不能违反的闸门。
## 与 bf16 / DDP / batched 的组合
- **bf16T12**`segment_fn` 就是不变的 block 前向,重算跑**同一条 bf16 路径**`cast` 算子的 grad 升精度桥bf16→fp32照常。重计算节点的参数 parents 是 fp32 master leaf恢复的是 fp32 梯度。on-vs-off 对拍同时跑 fp32 和 bf16 两路。
- **DDPT8**:每个 rank 独立 checkpoint 自己的前向/反向;恢复的参数梯度落进各 rank 的 `.grad()` 槽,再被 all-reduce 取均值——分布式路径不感知重计算。
- **batchedT10**:段输入/输出透明带 `[batch*seq, …]` batch 维;`checkpoint` 与形状无关。
## 验证方法
### 1. 正确性exact硬闸门
- **on-vs-off 梯度对拍**`crates/xtrain-model/tests/recompute.rs`):同 init 建两个模型recompute on/off跑同一 batched loss+backward断言**前向 logits、loss、每个参数梯度**在紧容差内一致——参数化跑 **fp32 和 bf16** 两路。fp32 期望近逐位(容差 1e-4bf16 仅放松到 bf16 舍入级(非重计算误差)。
- **全套回归开 `--recompute`**T4 15 算子 grad-check、T5 overfit 27/27 + PyTorch 对拍、T6 AdamW、T8 DDP loss-match + 跨 rank、**xserv 闭环 md5**——全绿。
### 2. 显存payoff
- dash5 1× RTX 5090 32GBdim768/18L batch32 seq256bf16测峰值显存 recompute on vs off应降
- **dim1024 batch32**:先验证不开重计算 **OOM**,再验证开了 **fit**——capture 实际 `nvidia-smi` 峰值。
### 3. 吞吐
- 同 config 测 steady-state tok/s recompute on vs off报告慢多少预计 ~2035%,多一次前向)。
## 实测结果dash5 1× RTX 5090 32GB, dim768/18L/24h×32 ffn2048 seq256, bf16, steady-state
> 待 dash5 实跑回填。
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|---|---|---|---|---|
| dim768 recompute **off** | 32 | TBD | TBD | ✅ |
| dim768 recompute **on** | 32 | **TBD** | TBD↓~xx% | ✅ |
| **dim1024** recompute **off** | 32 | — | — | ❌ **OOM** |
| **dim1024** recompute **on** | 32 | **TBD** | TBD | ✅ **解 OOM** |
**正确性**on-vs-off 梯度 max rel = TBDfp32/ TBDbf16全套回归 + xserv 闭环全绿。