Files
xtrain/docs/02-gemm-autodiff.md
Gahow Wang dde2fde297 docs: Phase T3 — GEMM fwd/bwd + finite-diff
Design doc covering the tiled forward, the dA/dB math + how transpose is
handled (materialize + reuse forward), the cuBLAS row-major reference, and
the finite-diff harness design + how T4 reuses it per-op.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:27:03 +08:00

131 lines
6.7 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase: GEMM Forward/Backward + Finite-Diff — Design Document
## Goal
在 T2 的 `Tensor` 之上交付**手写 GEMM 的前向 + 反向**,以及一个**可复用的有限差分梯度检查 harness**。
这是 autogradT4的地基T4 的每个算子 backward 都会用同一个 harness 对拍。本 Phase 只做三件事:
1. **GEMM 前向**:手写 tiled CUDA kernel`C = A @ B`row-major F32正确性优先
2. **GEMM 反向**:给上游梯度 `dC`,算 `dA = dC · Bᵀ``dB = Aᵀ · dC`(复用前向 + 一个 transpose kernel
3. **finite-diff harness**:给标量 loss `f(x)` 和解析梯度 `g`,逐元素用中心差分核对。
**明确不做**(留给 T4+autograd tape / 计算图、其它算子、broadcast、非 contiguous 输入、半精度、性能调优T7
## Module Layout
```
csrc/ops/gemm.cu # gemm_tiled_f32 + transpose_f32 + launch_*(由 xtrain-cuda/build.rs 编)
crates/xtrain-cuda/
├── build.rs # 新增 gemm.cu链接 cublas仅作前向对拍参考
└── src/ffi.rs # 新增 launch_gemm_tiled_f32 / launch_transpose_f32no_cuda 门控)
# + cuBLAS sgemm FFIno_cuda 门控,仅测试用)
crates/xtrain-tensor/
├── Cargo.toml # dev-dep: xtrain-autodiff
├── src/tensor.rs # matmul / transpose_2d / matmul_backwardno_cuda 门控)
└── tests/gemm.rs # 前向对 cuBLAS、反向对 finite-diff#![cfg(not(no_cuda))]
crates/xtrain-autodiff/ # 新 cratereusable harness
├── Cargo.toml
└── src/
├── lib.rs
└── finite_diff.rs # grad_check + GradCheckConfig/Result
```
## Key Design Decisions
### 前向tiled GEMM沿用 xserv 风格)
`csrc/ops/gemm.cu``gemm_tiled_f32``TILE_SIZE=32`,每个 block 协作把 A/B 的 tile 载入 shared
memoryFP32 累加。这就是 xserv `csrc/gemm/tiled.cu` 的 F32 版(去掉 BF16——正确性已被 xserv 验证过,
边界 mask`row<M && col<N`、tile 越界填 0让非 32 对齐的 M/N/K 也对。
张量层 `Tensor::matmul(&self, other)` 镜像 T2 `scale` 的接线方式:校验 2D / contiguous / F32 / 同卡 →
`zeros` 输出 → `launch_gemm_tiled_f32``synchronize`
### 反向dA / dB 的数学与 transpose 处理
`C = A·B``A:[M,K]``B:[K,N]``C:[M,N]`。给上游 `dC:[M,N]`
```
dA = dC · Bᵀ # [M,N] · [N,K] = [M,K]
dB = Aᵀ · dC # [K,M] · [M,N] = [K,N]
```
推导(标量 loss `L``C_ij = Σ_k A_ik B_kj`
`∂L/∂A_ik = Σ_j (∂L/∂C_ij)(∂C_ij/∂A_ik) = Σ_j dC_ij · B_kj = (dC·Bᵀ)_ik`
`∂L/∂B_kj = Σ_i dC_ij · A_ik = (Aᵀ·dC)_kj`
**transpose 怎么处理**T2 没有 transpose/reshape。这里加一个**最小的 out-of-place transpose
kernel**`transpose_f32``out[j,i]=in[i,j]`,写出连续的 `[cols,rows]``matmul_backward` 先把
`B`/`A` 物化转置成连续张量,再复用 `matmul`
```rust
let da = dc.matmul(&b.transpose_2d()); // [M,N] @ [N,K]
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N]
```
选「物化转置 + 复用前向」而非「带 transpose flag 的 GEMM 变体」:代码量最小、复用已对拍过的前向 kernel、
逻辑一眼能看懂。代价是多一次拷贝 + 临时显存——属于 T7 性能范畴(届时可改成 transpose-flag GEMM 或
non-contiguous GEMM 来省掉)。
### cuBLAS 仅作前向参考row-major vs col-major
cuBLAS `Sgemm` 是 column-major而我们 row-major。用恒等式
`row-major C=A·B ⟺ col-major Cᵀ=Bᵀ·Aᵀ`:把 row-major 的 B、A 原样喂给 cuBLAS它当成各自的
col-major 转置),传 `OP_N, OP_N`,参数 `m=N, n=M, k=K, lda=N, ldb=K, ldc=N`,输出原地就是 row-major
`C`**不做任何显式转置**。cuBLAS FFI`cublasSgemm_v2` 等)声明在 `xtrain-cuda/src/ffi.rs`,和
kernel FFI 一样 `#[cfg(not(no_cuda))]` 门控,`build.rs``-lcublas`。它只在测试里当 oracle不进
production 路径。
### finite-diff harnessT4 的复用核心)
`xtrain-autodiff::grad_check`
```rust
pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a; // (flat 参数, shape) -> 标量 loss
pub fn grad_check(x: &[f32], shape: &[usize], f: &ParamFn,
analytic_grad: &[f32], cfg: GradCheckConfig) -> GradCheckResult;
```
- 逐元素中心差分:`(f(x+ε·eᵢ) f(xε·eᵢ)) / 2ε`,共调 `f` 两次 `x.len()` 遍。
- 相对误差判据:`|num ana| / (|num| + |ana| + atol)``atol` 防近零梯度炸比值。
- 默认 `eps=1e-3 / rel_tol=2e-2 / atol=1e-4``eps` 平衡截断误差(∝eps²)与 f32 在 `f(x±eps)`
~1e-7 的舍入噪声;`rel_tol=2e-2` 是 f32 GPU GEMM 对 f32 中心差分的常规裕度。
- **host-only、不依赖 CUDA**harness 只碰 `&[f32]`+shape+闭包GPU 前向藏在闭包里(`from_slice →
to_device → matmul → sum`)。所以这个 crate 本地能 build/test自带两个纯 host 单测(`sum(x²)` 梯度
`2x` 通过;错误梯度被拒)。
**T4 怎么用**:每个算子 wrap forward 为标量 loss`L = sum(W∘out)`),跑该算子 backward 拿解析梯度,
对每个输入张量调一次 `grad_check`。本 Phase 的 GEMM 反向测试就是这个模式的样板。
## 验证方法
GPU 测试 `#![cfg(not(no_cuda))]` 门控dash5 实跑:
```sh
ssh dash5
export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
cd ~/projects/xtrain
cargo test -p xtrain-autodiff # host-only harness 自检
cargo test -p xtrain-tensor --test gemm -- --nocapture
```
- **前向**:手写 tiled GEMM vs cuBLAS sgemm随机矩阵square / rect 非对齐 / 256³`max_rel_err < 1e-3`。
- **反向**:标量 loss `L = sum(WC)``dC = W``matmul_backward` 出 `dA`/`dB`,各自对 finite-diff
harness 核对,`rel_tol = 2e-2` 内通过。
本地(无 GPU`cargo check --workspace --all-targets` + `cargo fmt --all -- --check` 绿GPU 测试编译出局。
## Takeaways
1. **反向 = 前向 + 转置的复用**`dA=dC·Bᵀ`、`dB=Aᵀ·dC` 全靠已对拍过的前向 kernel + 一个 transpose
kernel 拼出来,不必为 backward 单写 kernel。代价多余拷贝留给 T7。
2. **harness 与设备解耦**:把 finite-diff 设计成「只吃 `&[f32]`+闭包」GPU 细节进闭包,于是 harness
本地可测、且 T4 任何算子(不限 GEMM都能直接复用。
3. **cuBLAS 收口在 FFI + 门控**:参考实现也守 `no_cuda` 约定,本地 check 不被 cublas 链接拖累。
4. **row-major/col-major 是 GEMM 最大坑**:靠 `C=AB Cᵀ=BᵀAᵀ` 交换 A/B 传入顺序 + 调 ld零显式转置。
```