docs: Phase T4 — autograd engine
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
136
docs/03-autograd-engine.md
Normal file
136
docs/03-autograd-engine.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# Phase: Autograd Engine + Op Backward — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
在 T3 的 `Tensor`(matmul/transpose/finite-diff harness)之上,交付 **tape-based 动态 autograd 引擎** + 一个 tiny 现代 transformer 所需算子的**前向 kernel + 解析 backward**,每个 backward 都用 T3 的有限差分 harness 对拍通过。
|
||||
|
||||
具体三件事:
|
||||
|
||||
1. **autograd 引擎**:define-by-run 的反向自动微分。`Var` 包一个 `Tensor` + 可选 grad;每个 op 在 tape 上记一个节点(父节点 + backward 闭包);`backward()` 按逆拓扑序遍历,把梯度推给父节点。**关键正确性点:梯度累加**——一个张量被多个 op 消费(扇出)时,各路梯度必须**求和**(T3 没有累加路径,在这里实现)。
|
||||
2. **算子节点**:`matmul` / `add` / `mul` / `add_bias`(broadcast) / `scale` / `rms_norm` / `silu` / `swiglu` / `rope` / `softmax` / `cross_entropy`,各带前向 CUDA kernel(需要时)+ 解析 backward。
|
||||
3. **Attention 用组合**:`attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V)`。一旦 matmul/softmax/scale 是 autograd 节点,attention 的 backward 自动成立——**不写 fused attention backward kernel**,只加一个端到端 grad-check 测试。
|
||||
|
||||
**明确不做**(留给 T5/T6):组装 transformer / 训练 loop / 优化器 / embedding / KV-cache / GQA 重复。本 Phase 只到「算子 backward 逐个对拍通过」。
|
||||
|
||||
## Module Layout
|
||||
|
||||
```
|
||||
csrc/ops/nn.cu # 所有 T4 算子的 fwd+bwd kernel + launch_*(含 inlined warp/block reduce)
|
||||
|
||||
crates/xtrain-cuda/
|
||||
├── build.rs # 新增 nn.cu
|
||||
└── src/ffi.rs # 新增 launch_* 声明(no_cuda 门控)
|
||||
|
||||
crates/xtrain-tensor/
|
||||
├── src/dtype.rs # 新增 I32(cross-entropy target 用)
|
||||
└── src/tensor.rs # add/mul/add_bias/sum_rows/rms_norm(+bwd)/silu(+bwd)/
|
||||
# rope(+bwd)/softmax(+bwd)/cross_entropy(+bwd)(no_cuda 门控)
|
||||
|
||||
crates/xtrain-autodiff/ # 引擎落在这里(已含 grad_check harness,自然归宿)
|
||||
├── build.rs # 新增:检测 nvcc → no_cuda cfg(cfg 不跨 crate 传播)
|
||||
├── src/
|
||||
│ ├── lib.rs # 导出 tape::Var + ops(no_cuda 门控)
|
||||
│ ├── finite_diff.rs # T3 既有 harness(不动)
|
||||
│ ├── tape.rs # Var / VarNode / backward / 梯度累加
|
||||
│ └── ops.rs # 各算子的 Var 节点构造器
|
||||
└── tests/autograd.rs # 每算子 grad-check + 扇出累加 + 组合 attention(#![cfg(not(no_cuda))])
|
||||
```
|
||||
|
||||
为什么引擎放 `xtrain-autodiff` 而不是新 crate:该 crate 本就是「自动微分」语义的归宿,且已持有 `grad_check`。前向 kernel/`Tensor` 方法仍按 T2/T3 约定落在 `xtrain-tensor`(与 `scale`/`matmul` 一致),引擎只是在其上叠 tape。
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### Tape 设计:`Rc<RefCell<VarNode>>` + 逆拓扑遍历
|
||||
|
||||
```rust
|
||||
pub struct VarNode {
|
||||
value: Tensor, // 前向输出
|
||||
grad: Option<Tensor>, // 反向累加的梯度
|
||||
parents: Vec<Var>, // 计算来源
|
||||
backward: Option<BackwardFn>, // None=叶子
|
||||
}
|
||||
pub struct Var(Rc<RefCell<VarNode>>);
|
||||
type BackwardFn = Box<dyn Fn(&Tensor, &[Var])>;
|
||||
```
|
||||
|
||||
- `Var` clone 只是 bump `Rc`,**clone 共享同一节点**——这正是「扇出」的识别方式(同一 `Rc::as_ptr` 在多处出现)。
|
||||
- `backward()`:① post-order DFS 建拓扑序(按指针去重);② 把 loss(必须是标量)的 grad 种子设为 1;③ 逆序遍历,每个节点把自己的 grad 传给父节点的 backward 闭包。
|
||||
- 闭包签名 `Fn(&grad, &parents)`:给本节点已累加的 grad 和父节点列表,闭包算出各父的梯度贡献并 `push_grad` 回去。前向需要 cache 的中间量(softmax 的 `y`、rms 的 `inv_rms`、ce 的 `probs`)用 `move` 闭包捕获。
|
||||
|
||||
### 梯度累加(扇出求和)——本 Phase 的正确性核心
|
||||
|
||||
`push_grad(parent, g)` 一律走 `accumulate`:
|
||||
|
||||
```rust
|
||||
fn accumulate(&self, g: Tensor) {
|
||||
match self.grad.take() {
|
||||
None => self.grad = Some(g), // 首次
|
||||
Some(prev) => self.grad = Some(prev.add(&g)),// 扇出:SUM
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
任何节点(叶子或中间)都累加:中间节点需要完整 grad 才能继续链式;叶子的累加结果就是输出。一个张量喂多个消费者时,多路 `push_grad` 自动求和。`mul(&x, &x)` 这类「同一 `Var` 进同一节点两次」也正确:`parents=[x,x]`(同指针),两次 `push_grad` 累加,拓扑去重保证 x 只遍历一次但收齐两路。测试 `fanout_grad_accumulation` 专门验证:`y=x∘x + x∘x`,`dL/dx` 须 = 4x(四处 x 全部求和)。
|
||||
|
||||
### 各算子 backward 数学
|
||||
|
||||
记上游梯度为 `d`(=本节点输出的梯度)。
|
||||
|
||||
| op | forward | backward |
|
||||
|----|---------|----------|
|
||||
| `matmul` | `C=A@B` | `dA=d@Bᵀ`, `dB=Aᵀ@d`(复用 T3 `matmul_backward`)|
|
||||
| `add` | `a+b` | `da=d`, `db=d` |
|
||||
| `mul` | `a∘b` | `da=d∘b`, `db=d∘a` |
|
||||
| `add_bias` | `x[r,c]+bias[c]` | `dx=d`, `dbias[c]=Σ_r d[r,c]`(沿广播维求和)|
|
||||
| `scale` | `x·α` | `dx=d·α` |
|
||||
| `silu` | `x·σ(x)` | `dx=d·(σ + x·σ·(1−σ))`, `σ=σ(x)` |
|
||||
| `swiglu` | `silu(g)∘u` | 由 `silu`+`mul` 组合自动得 |
|
||||
| `rope` | rotate_half 旋转 | RoPE 是正交变换,`dx` = 用**逆(转置)旋转**作用于 `d`(角度 +θ 的转置 ≡ −θ)|
|
||||
| `softmax` | row-wise safe softmax → `y` | Jacobian:`dx[r,c]=y[r,c]·(d[r,c] − Σ_c' d·y)` |
|
||||
| `cross_entropy` | mean NLL(softmax(x), tgt) | `dx = (probs − onehot)/rows`,再乘上游标量 grad |
|
||||
|
||||
**RMSNorm**(`y[r,c]=x[r,c]·ir·γ[c]`, `ir=rsqrt(mean(x²)+eps)`):
|
||||
设 `g[c]=d[r,c]·γ[c]`,`n=cols`,
|
||||
```
|
||||
dx[r,c] = ir·g[c] − x[r,c]·ir³/n·Σ_c'(g[c']·x[r,c'])
|
||||
dγ[c] = Σ_r d[r,c]·x[r,c]·ir
|
||||
```
|
||||
前向 cache 每行 `inv_rms[r]`,backward 直接复用,避免重算 reduce。
|
||||
|
||||
**RoPE 反向推导**:前向是 2×2 旋转矩阵 `R(θ)`,正交 ⇒ `Rᵀ = R(−θ)`。故
|
||||
```
|
||||
dx[i] = d[i]·cos + d[i+h]·sin
|
||||
dx[i+h] = d[i+h]·cos − d[i]·sin
|
||||
```
|
||||
position=0 时旋转是恒等,backward 也恒等(sanity check)。
|
||||
|
||||
**Softmax 反向推导**:`∂y_i/∂x_j = y_i(δ_ij − y_j)`,链式后
|
||||
`dx_i = Σ_j d_j·y_i(δ_ij − y_j) = y_i(d_i − Σ_j d_j y_j)`,即每行减去 `Σ(d∘y)` 后乘 `y`。
|
||||
|
||||
**Cross-entropy 反向推导**:`L=−log softmax(x)[t]`,softmax+NLL 的经典结果 `∂L/∂x_c = softmax_c − [c=t]`;取 batch 平均 ⇒ 除以 rows。kernel 把 `scale=upstream/rows` 折进去。
|
||||
|
||||
### Attention 用组合,不写 fused kernel
|
||||
|
||||
```
|
||||
Kᵀ = transpose(K)
|
||||
scores = scale(matmul(Q, Kᵀ), 1/√d) # [s,s]
|
||||
probs = softmax(scores)
|
||||
out = matmul(probs, V) # [s,d]
|
||||
```
|
||||
|
||||
每一步都是已有 autograd 节点,`backward()` 自动沿 matmul→softmax→scale→matmul 链回传,得到 `dQ/dK/dV`,无需手写 attention backward。测试 `attention_composed_bwd` 单头单 batch 端到端 grad-check Q/K/V 三者。(transpose 在测试里用一个临时 `Var::from_op` 节点包,因为引擎暂未把 transpose 列为 op——T5 若需要再补。)
|
||||
|
||||
### kernel 实现要点
|
||||
|
||||
- `nn.cu` 自带 inlined `warp/block_reduce_sum/max`(不引外部头文件,与现有 csrc/ 单文件风格一致);block-reduce 末尾广播到全 block,便于 softmax/rms 的「标量广播」模式。
|
||||
- 每个 op 各自 `cudaDeviceSynchronize()`(T3 约定,无 stream)。
|
||||
- 全 F32、row-major、contiguous;cross-entropy target 用新增的 `DType::I32`。
|
||||
|
||||
## 验证方法
|
||||
|
||||
模板沿用 T3 `gemm.rs::run_bwd`:标量 loss `L = sum(W∘out)`,`W` 固定随机 ⇒ 上游 `dOut = W`;跑 op 的 `backward()` 拿 `.grad()`,对每个输入用 `grad_check` 与中心差分对拍。
|
||||
|
||||
- **每算子**一个 grad-check(线性/双线性 op 用大 eps=1e-2、rel_tol=2e-2;非线性 op 用 eps=1e-3、rel_tol=3e-2、atol=1e-3 压住近零梯度)。
|
||||
- **扇出累加**:`fanout_grad_accumulation`,验证 `dL/dx=4x`。
|
||||
- **组合 attention**:`attention_composed_bwd`,端到端 grad-check `dQ/dK/dV`。
|
||||
- 全部 `#![cfg(not(no_cuda))]` 门控;本地只 `cargo check`/`fmt`,构建+实跑在 dash5(8× RTX 5090, sm_120),capture 每 op 的 pass + max rel-err。
|
||||
Reference in New Issue
Block a user