Files
xtrain/docs/12-activation-recompute.md
Gahow Wang 0150263055 perf: KI-3 fixed — dim1024 batch32 fits, mem 31.1→14.6GB, tok/s 39.7K→31.5K
Per-block activation recompute (T13) measured on dash5 (1× RTX 5090 32GB, bf16,
batch32 seq256, steady-state):

- Correctness (exact, hard gate): recompute on-vs-off grads are BIT-IDENTICAL —
  fp32 AND bf16: loss / logits / every param grad max rel = 0.00e0 (not "within
  tol", exactly equal). Full suite green with recompute on/off; DDP loss-match
  5.67e-7; DDP+recompute 2-rank descends 11.079→6.010.
- dim768 (18L/24h ffn2048, core 127M): peak mem 31144→14562 MiB (−53%), tok/s
  39.7K→31.5K (−20%, the extra-forward tradeoff, in the predicted 20–35% band).
- dim1024 (18L/32h ffn2730, core 226M): recompute OFF OOMs (hits 32100/32607
  MiB → OutOfMemory); recompute ON fits at 16596 MiB, ~23K tok/s, converges.
  → KI-3 payoff achieved: dim1024 batch32 unblocked, v8 can proceed.

Fill docs/12 bench table; mark KI-3 FIXED in docs/known-issues.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:50:29 +08:00

91 lines
9.3 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T13: 激活重计算gradient checkpointing— Design Document
> KI-3 的具体落地。autograd tape 为反向保存了所有中间激活dim768/bf16 在单卡 32GB 能跑 batch32T12 解 OOM但**容量轴放大到 dim1024 会再次 OOM**——激活显存随 dim 线性增长。激活重计算用「多一次前向」换显存:段内激活不在前向保存,反向时**重算该段前向**重建局部 tape 再回传。峰值激活从「所有 block 同时在显存」降到「~一个 block + 每 block 的输入」→ dim1024 batch32 装得下。
## Goal
在**不动非重计算路径任何数值**的前提下,新增一个 **opt-in 的 per-block 激活重计算**
1. **正确性硬闸门exact**:重计算是**数学精确**的——同一段前向、同一输入、同一(未变的)参数值、确定性 kernel ⟹ 重算出的激活与原激活逐位相同,回传的梯度与非重计算版一致。直接的 **on-vs-off 梯度对拍**(紧容差)+ 全套回归T4 grad-check、T5 overfit+PyTorch 对拍、T6 AdamW、T8 DDP loss-match+跨 rank、xserv 闭环)开 `--recompute` 全绿。**绝不提交一个改变梯度的重计算。**
2. **显存payoff**:测 dim768 batch32 峰值显存 on vs off应降确认 **dim1024 batch32 现在装得下**(不开重计算时 OOM开了 fit
3. **吞吐**:测 tok/s on vs off多一次前向预计慢 ~2035%)——报告 compute/memory 权衡。
## 什么是激活重计算
反向传播需要前向的中间激活(如 SwiGLU 的 `gate`、attention 的 `probs`来算梯度。define-by-run 的 tape 默认把它们全部留在显存,直到对应 op 的 backward 跑完。模型越深、激活越多,峰值显存越高。
**梯度检查点**把模型切成若干**段segment**。前向时,段内 op **不记到 tape**detached / no-grad只保留**段的输入**(参数作为 leaf 本就常驻,不算激活)。反向时,当段的 output-grad 到达,**从保存的输入重跑该段前向**(输入作为 require-grad 的 leaf用上游 grad seed 重算出的 output在局部 tape 上回传,得到输入梯度(并累加参数梯度),然后释放局部 tape。
代价:每段多一次前向(约 +1/3 的总 FLOPs因为反向本就 ~2× 前向)。收益:峰值激活从「所有段」降到「~一段 + 每段输入」。
**切粒度 = 每个 transformer block**:一个 blockattention 子块 + MLP 子块 + 两个残差)是天然的段边界,输入/输出都是 `[batch*seq, dim]` 的残差流张量,接口最干净。
## Module Layoutsurgical非重计算路径逐字节不动
### 1. `xtrain-autodiff::checkpoint` — `checkpoint` 高阶原语
新增 `checkpoint(segment_fn, input, params) -> Var`,类比 `torch.utils.checkpoint`
- `segment_fn: Fn(&Var, &[Var]) -> Var`——从单个输入 `x` 和参数 slice `p` 构建段前向、返回段输出。必须**确定性**、只依赖 `x``p`(这是重算精确的前提)。
- **前向(不 tape 内部)**:把 `input`/`params` detach 成新 leaf`Var::leaf(v.value())`),跑 `segment_fn` 得到 `out_local`**只取 `out_local.value()`**。局部 `Var` 出作用域即 drop → 段内激活立即释放。checkpoint 节点的 parents = `[input, ..params]`(参数梯度落进优化器读的同一批 leaf
- **反向(重算)**:闭包捕获 `Rc<segment_fn>`。给定 `dout`,从 `parents` 当前值重建 detached leaf重跑 `segment_fn` 重建局部 tape`out_local.backward_seeded(dout)`,再把 `x_det.grad()` push 给 `parents[0]`、各 `param_det.grad()` push 给对应参数 parent。闭包结束 → 局部 tape drop → 重算激活释放。
### 2. `xtrain-autodiff::tape` — `backward_seeded`
引擎原 `backward()` 只能从标量 root 出发seed `ones_like` + 断言 numel==1。段输出一般**非标量**,故新增 `Var::backward_seeded(seed)`:同样的 topo + 反向遍历,但用显式上游 grad seed不断言标量`backward()` 退化为「seed ones」的薄包装——标量 loss 路径逐字节不变。
### 3. `xtrain-model::TinyTransformer` — per-block 包裹 + `recompute` 开关
- 把 block 前向体抽成**自由函数** `block_forward(cfg, compute_dtype, batch, seq, input, params)`(不借 `&self`,才能在反向闭包里重跑);同时把 `linear / norm_gamma / attention / swiglu_mlp` 改成参数化 `(cfg, compute_dtype)` 的自由函数。`Block::block_params()` 给出 11 个 leaf 的固定序(与 `params()` 每 block 段一致)。
- `forward_batched` 的 block 循环:`recompute` 开 → `checkpoint(seg, &h, &b.block_params())`;关 → 直接 `block_forward(...)`**与之前完全同图**)。
- `with_recompute(bool)` builderopt-in默认关 = 原 tape数值逐字节同
### 4. `xtrain-train` / `xtrain-distributed` — `--recompute` flag
- `bin/train` / `bin/train_ddp``--recompute`,调 `model.with_recompute(true)`。AdamW / clip / checkpoint / DDP all-reduce **不改**(梯度语义与非重计算一致)。
## Key Design Decisions
- **切粒度 = 每个 block**:接口最干净(输入输出都是残差流 `[B*S, dim]`),且峰值激活降到 ~1 个 block。比「整模型一段」省得多比「逐 op」简单得多。
- **参数作为 checkpoint 节点的 parents**:参数 leaf 跨前向/反向不变(只 grad 槽变),重算用**当前**参数值即原值。把它们列为 parents重算恢复的参数梯度直接 push 到优化器读的同一批 leaf → **DDP/AdamW 零改动**
- **detached leaf 隔离局部 tape**:前向/反向都把输入和参数 detach 成新 leaf使 `segment_fn` 构建的图与外层 tape 断开。前向丢弃局部图(释放激活);反向局部图回传完即 drop释放重算激活
- **`backward_seeded` 而非改 `backward`**:段输出非标量,需要用上游 output-grad 作 seed 回传局部 tape。新增方法、原标量 `backward()` 不动。
- **重算精确 → 梯度逐位一致(硬闸门)**:同一 `segment_fn`、同一输入值、同一参数值、确定性前向 kernel ⟹ 重算 output 与原 output 相同,局部反向就是该段的普通解析反向。故输入/参数梯度与非重计算版一致——这是绝不能违反的闸门。
## 与 bf16 / DDP / batched 的组合
- **bf16T12**`segment_fn` 就是不变的 block 前向,重算跑**同一条 bf16 路径**`cast` 算子的 grad 升精度桥bf16→fp32照常。重计算节点的参数 parents 是 fp32 master leaf恢复的是 fp32 梯度。on-vs-off 对拍同时跑 fp32 和 bf16 两路。
- **DDPT8**:每个 rank 独立 checkpoint 自己的前向/反向;恢复的参数梯度落进各 rank 的 `.grad()` 槽,再被 all-reduce 取均值——分布式路径不感知重计算。
- **batchedT10**:段输入/输出透明带 `[batch*seq, …]` batch 维;`checkpoint` 与形状无关。
## 验证方法
### 1. 正确性exact硬闸门
- **on-vs-off 梯度对拍**`crates/xtrain-model/tests/recompute.rs`):同 init 建两个模型recompute on/off跑同一 batched loss+backward断言**前向 logits、loss、每个参数梯度**在紧容差内一致——参数化跑 **fp32 和 bf16** 两路。fp32 期望近逐位(容差 1e-4bf16 仅放松到 bf16 舍入级(非重计算误差)。
- **全套回归开 `--recompute`**T4 15 算子 grad-check、T5 overfit 27/27 + PyTorch 对拍、T6 AdamW、T8 DDP loss-match + 跨 rank、**xserv 闭环 md5**——全绿。
### 2. 显存payoff
- dash5 1× RTX 5090 32GBdim768/18L batch32 seq256bf16测峰值显存 recompute on vs off应降
- **dim1024 batch32**:先验证不开重计算 **OOM**,再验证开了 **fit**——capture 实际 `nvidia-smi` 峰值。
### 3. 吞吐
- 同 config 测 steady-state tok/s recompute on vs off报告慢多少预计 ~2035%,多一次前向)。
## 实测结果dash5 1× RTX 5090 32GB, bf16, batch32 seq256, steady-state
**正确性exact硬闸门**on-vs-off 梯度对拍 —— **fp32 与 bf16 双路都逐位一致**loss rel `0.00e0`、logits max rel `0.00e0`、**每个参数梯度 max rel `0.00e0`**(不是「在容差内」,是逐位相同——证实重算确实精确)。全套回归开/关重计算全绿T4 15 算子 grad-check、5 结构、batched、bf16、overfit、AdamWGPU+host、GEMM、checkpoint roundtrip、**T8 DDP loss 对单卡 5.67e-7 + 跨 rank 0.0**DDP+recompute 2 卡短训 loss 单调降11.079→6.010)。
**显存 + 吞吐**dim768 = 18L/24h×32/ffn2048 core 127Mdim1024 = 18L/32h×32/ffn2730 core 226M
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|---|---|---|---|---|
| dim768 recompute **off** | 32 | 31144 MiB | 39.7K | ✅ |
| **dim768 recompute on** | 32 | **14562 MiB53%** | **31.5K20%** | ✅ |
| **dim1024** recompute **off** | 32 | 32100 → **OOM** | — | ❌ **OOM** |
| **dim1024 recompute on** | 32 | **16596 MiB** | 23.1K | ✅ **解 OOM** |
→ dim768重计算把峰值显存 **31.1→14.6GB53%~砍半激活)**,代价 tok/s **20%**(多一次前向,落在预测 2035% 区间。dim1024 batch32**不开 OOM撞 32100/32607MiB 上限)→ 开了 16.6GB 稳稳装下**~23K tok/s 训练正常收敛 —— **KI-3 的目标达成dim1024 解锁**