From dde2fde297bc2823da435c5d525b973391a4c5ab Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Mon, 15 Jun 2026 15:27:03 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20Phase=20T3=20=E2=80=94=20GEMM=20fwd/bwd?= =?UTF-8?q?=20+=20finite-diff?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Design doc covering the tiled forward, the dA/dB math + how transpose is handled (materialize + reuse forward), the cuBLAS row-major reference, and the finite-diff harness design + how T4 reuses it per-op. Co-Authored-By: Claude Opus 4.8 --- docs/02-gemm-autodiff.md | 130 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 docs/02-gemm-autodiff.md diff --git a/docs/02-gemm-autodiff.md b/docs/02-gemm-autodiff.md new file mode 100644 index 0000000..a91b494 --- /dev/null +++ b/docs/02-gemm-autodiff.md @@ -0,0 +1,130 @@ +# Phase: GEMM Forward/Backward + Finite-Diff — Design Document + +## Goal + +在 T2 的 `Tensor` 之上交付**手写 GEMM 的前向 + 反向**,以及一个**可复用的有限差分梯度检查 harness**。 +这是 autograd(T4)的地基:T4 的每个算子 backward 都会用同一个 harness 对拍。本 Phase 只做三件事: + +1. **GEMM 前向**:手写 tiled CUDA kernel,`C = A @ B`,row-major F32,正确性优先; +2. **GEMM 反向**:给上游梯度 `dC`,算 `dA = dC · Bᵀ`、`dB = Aᵀ · dC`(复用前向 + 一个 transpose kernel); +3. **finite-diff harness**:给标量 loss `f(x)` 和解析梯度 `g`,逐元素用中心差分核对。 + +**明确不做**(留给 T4+):autograd tape / 计算图、其它算子、broadcast、非 contiguous 输入、半精度、性能调优(T7)。 + +## Module Layout + +``` +csrc/ops/gemm.cu # gemm_tiled_f32 + transpose_f32 + launch_*(由 xtrain-cuda/build.rs 编) + +crates/xtrain-cuda/ +├── build.rs # 新增 gemm.cu;链接 cublas(仅作前向对拍参考) +└── src/ffi.rs # 新增 launch_gemm_tiled_f32 / launch_transpose_f32(no_cuda 门控) + # + cuBLAS sgemm FFI(no_cuda 门控,仅测试用) + +crates/xtrain-tensor/ +├── Cargo.toml # dev-dep: xtrain-autodiff +├── src/tensor.rs # matmul / transpose_2d / matmul_backward(no_cuda 门控) +└── tests/gemm.rs # 前向对 cuBLAS、反向对 finite-diff(#![cfg(not(no_cuda))]) + +crates/xtrain-autodiff/ # 新 crate(reusable harness) +├── Cargo.toml +└── src/ + ├── lib.rs + └── finite_diff.rs # grad_check + GradCheckConfig/Result +``` + +## Key Design Decisions + +### 前向:tiled GEMM(沿用 xserv 风格) + +`csrc/ops/gemm.cu` 的 `gemm_tiled_f32`:`TILE_SIZE=32`,每个 block 协作把 A/B 的 tile 载入 shared +memory,FP32 累加。这就是 xserv `csrc/gemm/tiled.cu` 的 F32 版(去掉 BF16)——正确性已被 xserv 验证过, +边界 mask(`row = dyn Fn(&[f32], &[usize]) -> f32 + 'a; // (flat 参数, shape) -> 标量 loss + +pub fn grad_check(x: &[f32], shape: &[usize], f: &ParamFn, + analytic_grad: &[f32], cfg: GradCheckConfig) -> GradCheckResult; +``` + +- 逐元素中心差分:`(f(x+ε·eᵢ) − f(x−ε·eᵢ)) / 2ε`,共调 `f` 两次 `x.len()` 遍。 +- 相对误差判据:`|num − ana| / (|num| + |ana| + atol)`,`atol` 防近零梯度炸比值。 +- 默认 `eps=1e-3 / rel_tol=2e-2 / atol=1e-4`:`eps` 平衡截断误差(∝eps²)与 f32 在 `f(x±eps)` 上 + ~1e-7 的舍入噪声;`rel_tol=2e-2` 是 f32 GPU GEMM 对 f32 中心差分的常规裕度。 +- **host-only、不依赖 CUDA**:harness 只碰 `&[f32]`+shape+闭包;GPU 前向藏在闭包里(`from_slice → + to_device → matmul → sum`)。所以这个 crate 本地能 build/test,自带两个纯 host 单测(`sum(x²)` 梯度 + `2x` 通过;错误梯度被拒)。 + +**T4 怎么用**:每个算子 wrap forward 为标量 loss(`L = sum(W∘out)`),跑该算子 backward 拿解析梯度, +对每个输入张量调一次 `grad_check`。本 Phase 的 GEMM 反向测试就是这个模式的样板。 + +## 验证方法 + +GPU 测试 `#![cfg(not(no_cuda))]` 门控,dash5 实跑: + +```sh +ssh dash5 +export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH +cd ~/projects/xtrain +cargo test -p xtrain-autodiff # host-only harness 自检 +cargo test -p xtrain-tensor --test gemm -- --nocapture +``` + +- **前向**:手写 tiled GEMM vs cuBLAS sgemm,随机矩阵(square / rect 非对齐 / 256³),`max_rel_err < 1e-3`。 +- **反向**:标量 loss `L = sum(W∘C)`(`dC = W`),`matmul_backward` 出 `dA`/`dB`,各自对 finite-diff + harness 核对,`rel_tol = 2e-2` 内通过。 + +本地(无 GPU):`cargo check --workspace --all-targets` + `cargo fmt --all -- --check` 绿;GPU 测试编译出局。 + +## Takeaways + +1. **反向 = 前向 + 转置的复用**:`dA=dC·Bᵀ`、`dB=Aᵀ·dC` 全靠已对拍过的前向 kernel + 一个 transpose + kernel 拼出来,不必为 backward 单写 kernel。代价(多余拷贝)留给 T7。 +2. **harness 与设备解耦**:把 finite-diff 设计成「只吃 `&[f32]`+闭包」,GPU 细节进闭包,于是 harness + 本地可测、且 T4 任何算子(不限 GEMM)都能直接复用。 +3. **cuBLAS 收口在 FFI + 门控**:参考实现也守 `no_cuda` 约定,本地 check 不被 cublas 链接拖累。 +4. **row-major/col-major 是 GEMM 最大坑**:靠 `C=AB ⟺ Cᵀ=BᵀAᵀ` 交换 A/B 传入顺序 + 调 ld,零显式转置。 +```