Design doc covering the tiled forward, the dA/dB math + how transpose is handled (materialize + reuse forward), the cuBLAS row-major reference, and the finite-diff harness design + how T4 reuses it per-op. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
6.7 KiB
Phase: GEMM Forward/Backward + Finite-Diff — Design Document
Goal
在 T2 的 Tensor 之上交付手写 GEMM 的前向 + 反向,以及一个可复用的有限差分梯度检查 harness。
这是 autograd(T4)的地基:T4 的每个算子 backward 都会用同一个 harness 对拍。本 Phase 只做三件事:
- GEMM 前向:手写 tiled CUDA kernel,
C = A @ B,row-major F32,正确性优先; - GEMM 反向:给上游梯度
dC,算dA = dC · Bᵀ、dB = Aᵀ · dC(复用前向 + 一个 transpose kernel); - finite-diff harness:给标量 loss
f(x)和解析梯度g,逐元素用中心差分核对。
明确不做(留给 T4+):autograd tape / 计算图、其它算子、broadcast、非 contiguous 输入、半精度、性能调优(T7)。
Module Layout
csrc/ops/gemm.cu # gemm_tiled_f32 + transpose_f32 + launch_*(由 xtrain-cuda/build.rs 编)
crates/xtrain-cuda/
├── build.rs # 新增 gemm.cu;链接 cublas(仅作前向对拍参考)
└── src/ffi.rs # 新增 launch_gemm_tiled_f32 / launch_transpose_f32(no_cuda 门控)
# + cuBLAS sgemm FFI(no_cuda 门控,仅测试用)
crates/xtrain-tensor/
├── Cargo.toml # dev-dep: xtrain-autodiff
├── src/tensor.rs # matmul / transpose_2d / matmul_backward(no_cuda 门控)
└── tests/gemm.rs # 前向对 cuBLAS、反向对 finite-diff(#![cfg(not(no_cuda))])
crates/xtrain-autodiff/ # 新 crate(reusable harness)
├── Cargo.toml
└── src/
├── lib.rs
└── finite_diff.rs # grad_check + GradCheckConfig/Result
Key Design Decisions
前向:tiled GEMM(沿用 xserv 风格)
csrc/ops/gemm.cu 的 gemm_tiled_f32:TILE_SIZE=32,每个 block 协作把 A/B 的 tile 载入 shared
memory,FP32 累加。这就是 xserv csrc/gemm/tiled.cu 的 F32 版(去掉 BF16)——正确性已被 xserv 验证过,
边界 mask(row<M && col<N、tile 越界填 0)让非 32 对齐的 M/N/K 也对。
张量层 Tensor::matmul(&self, other) 镜像 T2 scale 的接线方式:校验 2D / contiguous / F32 / 同卡 →
zeros 输出 → launch_gemm_tiled_f32 → synchronize。
反向:dA / dB 的数学与 transpose 处理
C = A·B,A:[M,K]、B:[K,N]、C:[M,N]。给上游 dC:[M,N]:
dA = dC · Bᵀ # [M,N] · [N,K] = [M,K]
dB = Aᵀ · dC # [K,M] · [M,N] = [K,N]
推导(标量 loss L,C_ij = Σ_k A_ik B_kj):
∂L/∂A_ik = Σ_j (∂L/∂C_ij)(∂C_ij/∂A_ik) = Σ_j dC_ij · B_kj = (dC·Bᵀ)_ik;
∂L/∂B_kj = Σ_i dC_ij · A_ik = (Aᵀ·dC)_kj。
transpose 怎么处理:T2 没有 transpose/reshape。这里加一个最小的 out-of-place transpose
kernel(transpose_f32:out[j,i]=in[i,j],写出连续的 [cols,rows]),matmul_backward 先把
B/A 物化转置成连续张量,再复用 matmul:
let da = dc.matmul(&b.transpose_2d()); // [M,N] @ [N,K]
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N]
选「物化转置 + 复用前向」而非「带 transpose flag 的 GEMM 变体」:代码量最小、复用已对拍过的前向 kernel、 逻辑一眼能看懂。代价是多一次拷贝 + 临时显存——属于 T7 性能范畴(届时可改成 transpose-flag GEMM 或 non-contiguous GEMM 来省掉)。
cuBLAS 仅作前向参考(row-major vs col-major)
cuBLAS Sgemm 是 column-major,而我们 row-major。用恒等式
row-major C=A·B ⟺ col-major Cᵀ=Bᵀ·Aᵀ:把 row-major 的 B、A 原样喂给 cuBLAS(它当成各自的
col-major 转置),传 OP_N, OP_N,参数 m=N, n=M, k=K, lda=N, ldb=K, ldc=N,输出原地就是 row-major
的 C,不做任何显式转置。cuBLAS FFI(cublasSgemm_v2 等)声明在 xtrain-cuda/src/ffi.rs,和
kernel FFI 一样 #[cfg(not(no_cuda))] 门控,build.rs 加 -lcublas。它只在测试里当 oracle,不进
production 路径。
finite-diff harness(T4 的复用核心)
xtrain-autodiff::grad_check:
pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a; // (flat 参数, shape) -> 标量 loss
pub fn grad_check(x: &[f32], shape: &[usize], f: &ParamFn,
analytic_grad: &[f32], cfg: GradCheckConfig) -> GradCheckResult;
- 逐元素中心差分:
(f(x+ε·eᵢ) − f(x−ε·eᵢ)) / 2ε,共调f两次x.len()遍。 - 相对误差判据:
|num − ana| / (|num| + |ana| + atol),atol防近零梯度炸比值。 - 默认
eps=1e-3 / rel_tol=2e-2 / atol=1e-4:eps平衡截断误差(∝eps²)与 f32 在f(x±eps)上 ~1e-7 的舍入噪声;rel_tol=2e-2是 f32 GPU GEMM 对 f32 中心差分的常规裕度。 - host-only、不依赖 CUDA:harness 只碰
&[f32]+shape+闭包;GPU 前向藏在闭包里(from_slice → to_device → matmul → sum)。所以这个 crate 本地能 build/test,自带两个纯 host 单测(sum(x²)梯度2x通过;错误梯度被拒)。
T4 怎么用:每个算子 wrap forward 为标量 loss(L = sum(W∘out)),跑该算子 backward 拿解析梯度,
对每个输入张量调一次 grad_check。本 Phase 的 GEMM 反向测试就是这个模式的样板。
验证方法
GPU 测试 #![cfg(not(no_cuda))] 门控,dash5 实跑:
ssh dash5
export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
cd ~/projects/xtrain
cargo test -p xtrain-autodiff # host-only harness 自检
cargo test -p xtrain-tensor --test gemm -- --nocapture
- 前向:手写 tiled GEMM vs cuBLAS sgemm,随机矩阵(square / rect 非对齐 / 256³),
max_rel_err < 1e-3。 - 反向:标量 loss
L = sum(W∘C)(dC = W),matmul_backward出dA/dB,各自对 finite-diff harness 核对,rel_tol = 2e-2内通过。
本地(无 GPU):cargo check --workspace --all-targets + cargo fmt --all -- --check 绿;GPU 测试编译出局。
Takeaways
- 反向 = 前向 + 转置的复用:
dA=dC·Bᵀ、dB=Aᵀ·dC全靠已对拍过的前向 kernel + 一个 transpose kernel 拼出来,不必为 backward 单写 kernel。代价(多余拷贝)留给 T7。 - harness 与设备解耦:把 finite-diff 设计成「只吃
&[f32]+闭包」,GPU 细节进闭包,于是 harness 本地可测、且 T4 任何算子(不限 GEMM)都能直接复用。 - cuBLAS 收口在 FFI + 门控:参考实现也守
no_cuda约定,本地 check 不被 cublas 链接拖累。 - row-major/col-major 是 GEMM 最大坑:靠
C=AB ⟺ Cᵀ=BᵀAᵀ交换 A/B 传入顺序 + 调 ld,零显式转置。