docs: Phase T3 — GEMM fwd/bwd + finite-diff
Design doc covering the tiled forward, the dA/dB math + how transpose is handled (materialize + reuse forward), the cuBLAS row-major reference, and the finite-diff harness design + how T4 reuses it per-op. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
130
docs/02-gemm-autodiff.md
Normal file
130
docs/02-gemm-autodiff.md
Normal file
@@ -0,0 +1,130 @@
|
||||
# Phase: GEMM Forward/Backward + Finite-Diff — Design Document
|
||||
|
||||
## Goal
|
||||
|
||||
在 T2 的 `Tensor` 之上交付**手写 GEMM 的前向 + 反向**,以及一个**可复用的有限差分梯度检查 harness**。
|
||||
这是 autograd(T4)的地基:T4 的每个算子 backward 都会用同一个 harness 对拍。本 Phase 只做三件事:
|
||||
|
||||
1. **GEMM 前向**:手写 tiled CUDA kernel,`C = A @ B`,row-major F32,正确性优先;
|
||||
2. **GEMM 反向**:给上游梯度 `dC`,算 `dA = dC · Bᵀ`、`dB = Aᵀ · dC`(复用前向 + 一个 transpose kernel);
|
||||
3. **finite-diff harness**:给标量 loss `f(x)` 和解析梯度 `g`,逐元素用中心差分核对。
|
||||
|
||||
**明确不做**(留给 T4+):autograd tape / 计算图、其它算子、broadcast、非 contiguous 输入、半精度、性能调优(T7)。
|
||||
|
||||
## Module Layout
|
||||
|
||||
```
|
||||
csrc/ops/gemm.cu # gemm_tiled_f32 + transpose_f32 + launch_*(由 xtrain-cuda/build.rs 编)
|
||||
|
||||
crates/xtrain-cuda/
|
||||
├── build.rs # 新增 gemm.cu;链接 cublas(仅作前向对拍参考)
|
||||
└── src/ffi.rs # 新增 launch_gemm_tiled_f32 / launch_transpose_f32(no_cuda 门控)
|
||||
# + cuBLAS sgemm FFI(no_cuda 门控,仅测试用)
|
||||
|
||||
crates/xtrain-tensor/
|
||||
├── Cargo.toml # dev-dep: xtrain-autodiff
|
||||
├── src/tensor.rs # matmul / transpose_2d / matmul_backward(no_cuda 门控)
|
||||
└── tests/gemm.rs # 前向对 cuBLAS、反向对 finite-diff(#![cfg(not(no_cuda))])
|
||||
|
||||
crates/xtrain-autodiff/ # 新 crate(reusable harness)
|
||||
├── Cargo.toml
|
||||
└── src/
|
||||
├── lib.rs
|
||||
└── finite_diff.rs # grad_check + GradCheckConfig/Result
|
||||
```
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### 前向:tiled GEMM(沿用 xserv 风格)
|
||||
|
||||
`csrc/ops/gemm.cu` 的 `gemm_tiled_f32`:`TILE_SIZE=32`,每个 block 协作把 A/B 的 tile 载入 shared
|
||||
memory,FP32 累加。这就是 xserv `csrc/gemm/tiled.cu` 的 F32 版(去掉 BF16)——正确性已被 xserv 验证过,
|
||||
边界 mask(`row<M && col<N`、tile 越界填 0)让非 32 对齐的 M/N/K 也对。
|
||||
|
||||
张量层 `Tensor::matmul(&self, other)` 镜像 T2 `scale` 的接线方式:校验 2D / contiguous / F32 / 同卡 →
|
||||
`zeros` 输出 → `launch_gemm_tiled_f32` → `synchronize`。
|
||||
|
||||
### 反向:dA / dB 的数学与 transpose 处理
|
||||
|
||||
`C = A·B`,`A:[M,K]`、`B:[K,N]`、`C:[M,N]`。给上游 `dC:[M,N]`:
|
||||
|
||||
```
|
||||
dA = dC · Bᵀ # [M,N] · [N,K] = [M,K]
|
||||
dB = Aᵀ · dC # [K,M] · [M,N] = [K,N]
|
||||
```
|
||||
|
||||
推导(标量 loss `L`,`C_ij = Σ_k A_ik B_kj`):
|
||||
`∂L/∂A_ik = Σ_j (∂L/∂C_ij)(∂C_ij/∂A_ik) = Σ_j dC_ij · B_kj = (dC·Bᵀ)_ik`;
|
||||
`∂L/∂B_kj = Σ_i dC_ij · A_ik = (Aᵀ·dC)_kj`。
|
||||
|
||||
**transpose 怎么处理**:T2 没有 transpose/reshape。这里加一个**最小的 out-of-place transpose
|
||||
kernel**(`transpose_f32`:`out[j,i]=in[i,j]`,写出连续的 `[cols,rows]`),`matmul_backward` 先把
|
||||
`B`/`A` 物化转置成连续张量,再复用 `matmul`:
|
||||
|
||||
```rust
|
||||
let da = dc.matmul(&b.transpose_2d()); // [M,N] @ [N,K]
|
||||
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N]
|
||||
```
|
||||
|
||||
选「物化转置 + 复用前向」而非「带 transpose flag 的 GEMM 变体」:代码量最小、复用已对拍过的前向 kernel、
|
||||
逻辑一眼能看懂。代价是多一次拷贝 + 临时显存——属于 T7 性能范畴(届时可改成 transpose-flag GEMM 或
|
||||
non-contiguous GEMM 来省掉)。
|
||||
|
||||
### cuBLAS 仅作前向参考(row-major vs col-major)
|
||||
|
||||
cuBLAS `Sgemm` 是 column-major,而我们 row-major。用恒等式
|
||||
`row-major C=A·B ⟺ col-major Cᵀ=Bᵀ·Aᵀ`:把 row-major 的 B、A 原样喂给 cuBLAS(它当成各自的
|
||||
col-major 转置),传 `OP_N, OP_N`,参数 `m=N, n=M, k=K, lda=N, ldb=K, ldc=N`,输出原地就是 row-major
|
||||
的 `C`,**不做任何显式转置**。cuBLAS FFI(`cublasSgemm_v2` 等)声明在 `xtrain-cuda/src/ffi.rs`,和
|
||||
kernel FFI 一样 `#[cfg(not(no_cuda))]` 门控,`build.rs` 加 `-lcublas`。它只在测试里当 oracle,不进
|
||||
production 路径。
|
||||
|
||||
### finite-diff harness(T4 的复用核心)
|
||||
|
||||
`xtrain-autodiff::grad_check`:
|
||||
|
||||
```rust
|
||||
pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a; // (flat 参数, shape) -> 标量 loss
|
||||
|
||||
pub fn grad_check(x: &[f32], shape: &[usize], f: &ParamFn,
|
||||
analytic_grad: &[f32], cfg: GradCheckConfig) -> GradCheckResult;
|
||||
```
|
||||
|
||||
- 逐元素中心差分:`(f(x+ε·eᵢ) − f(x−ε·eᵢ)) / 2ε`,共调 `f` 两次 `x.len()` 遍。
|
||||
- 相对误差判据:`|num − ana| / (|num| + |ana| + atol)`,`atol` 防近零梯度炸比值。
|
||||
- 默认 `eps=1e-3 / rel_tol=2e-2 / atol=1e-4`:`eps` 平衡截断误差(∝eps²)与 f32 在 `f(x±eps)` 上
|
||||
~1e-7 的舍入噪声;`rel_tol=2e-2` 是 f32 GPU GEMM 对 f32 中心差分的常规裕度。
|
||||
- **host-only、不依赖 CUDA**:harness 只碰 `&[f32]`+shape+闭包;GPU 前向藏在闭包里(`from_slice →
|
||||
to_device → matmul → sum`)。所以这个 crate 本地能 build/test,自带两个纯 host 单测(`sum(x²)` 梯度
|
||||
`2x` 通过;错误梯度被拒)。
|
||||
|
||||
**T4 怎么用**:每个算子 wrap forward 为标量 loss(`L = sum(W∘out)`),跑该算子 backward 拿解析梯度,
|
||||
对每个输入张量调一次 `grad_check`。本 Phase 的 GEMM 反向测试就是这个模式的样板。
|
||||
|
||||
## 验证方法
|
||||
|
||||
GPU 测试 `#![cfg(not(no_cuda))]` 门控,dash5 实跑:
|
||||
|
||||
```sh
|
||||
ssh dash5
|
||||
export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
cd ~/projects/xtrain
|
||||
cargo test -p xtrain-autodiff # host-only harness 自检
|
||||
cargo test -p xtrain-tensor --test gemm -- --nocapture
|
||||
```
|
||||
|
||||
- **前向**:手写 tiled GEMM vs cuBLAS sgemm,随机矩阵(square / rect 非对齐 / 256³),`max_rel_err < 1e-3`。
|
||||
- **反向**:标量 loss `L = sum(W∘C)`(`dC = W`),`matmul_backward` 出 `dA`/`dB`,各自对 finite-diff
|
||||
harness 核对,`rel_tol = 2e-2` 内通过。
|
||||
|
||||
本地(无 GPU):`cargo check --workspace --all-targets` + `cargo fmt --all -- --check` 绿;GPU 测试编译出局。
|
||||
|
||||
## Takeaways
|
||||
|
||||
1. **反向 = 前向 + 转置的复用**:`dA=dC·Bᵀ`、`dB=Aᵀ·dC` 全靠已对拍过的前向 kernel + 一个 transpose
|
||||
kernel 拼出来,不必为 backward 单写 kernel。代价(多余拷贝)留给 T7。
|
||||
2. **harness 与设备解耦**:把 finite-diff 设计成「只吃 `&[f32]`+闭包」,GPU 细节进闭包,于是 harness
|
||||
本地可测、且 T4 任何算子(不限 GEMM)都能直接复用。
|
||||
3. **cuBLAS 收口在 FFI + 门控**:参考实现也守 `no_cuda` 约定,本地 check 不被 cublas 链接拖累。
|
||||
4. **row-major/col-major 是 GEMM 最大坑**:靠 `C=AB ⟺ Cᵀ=BᵀAᵀ` 交换 A/B 传入顺序 + 调 ld,零显式转置。
|
||||
```
|
||||
Reference in New Issue
Block a user