gemm: GPU acceptance tests vs cuBLAS + finite-diff
Forward: hand-written tiled GEMM vs cuBLAS sgemm on random matrices (square / non-tile-aligned rect / 256³), max relative error < 1e-3, using the row-major⟺col-major identity to drive cuBLAS without explicit transposes. Backward: scalar loss L = sum(W∘C) (so dC = W), dA/dB from matmul_backward checked against the finite-diff harness. Gated behind not(no_cuda). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -7,3 +7,6 @@ edition.workspace = true
|
||||
xtrain-cuda = { path = "../xtrain-cuda" }
|
||||
half.workspace = true
|
||||
smallvec.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
xtrain-autodiff = { path = "../xtrain-autodiff" }
|
||||
|
||||
183
crates/xtrain-tensor/tests/gemm.rs
Normal file
183
crates/xtrain-tensor/tests/gemm.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
// GPU acceptance tests for the hand-written GEMM forward + backward (Phase T3).
|
||||
// Gated behind `not(no_cuda)`: on a GPU-less machine these compile out so host
|
||||
// `cargo check` stays green; they run on dash5.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::{GradCheckConfig, grad_check};
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
// Deterministic pseudo-random fill in [-0.5, 0.5), seeded by a linear
|
||||
// congruential generator so tests are reproducible without an RNG dep.
|
||||
fn fill(n: usize, seed: u64) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn require_gpu() {
|
||||
assert!(
|
||||
device::device_count().expect("device count") > 0,
|
||||
"no CUDA device"
|
||||
);
|
||||
device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
// --- cuBLAS reference (correctness oracle for the hand-written kernel) ---
|
||||
|
||||
/// Row-major `C = A @ B` via cuBLAS `Sgemm` (which is column-major).
|
||||
/// Identity: row-major C = A@B ⟺ column-major Cᵀ = Bᵀ @ Aᵀ. We hand cuBLAS
|
||||
/// our row-major B and A as-is (it reads them as the col-major transposes) with
|
||||
/// OP_N, swapped order, and m=N, n=M, k=K. Output lands row-major in `c`.
|
||||
fn cublas_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
let n = b.shape()[1];
|
||||
let c = Tensor::zeros(&[m, n], xtrain_tensor::DType::F32, a.device());
|
||||
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
unsafe {
|
||||
let mut handle = std::ptr::null_mut();
|
||||
assert_eq!(xtrain_cuda::ffi::cublasCreate_v2(&mut handle), 0);
|
||||
let status = xtrain_cuda::ffi::cublasSgemm_v2(
|
||||
handle,
|
||||
xtrain_cuda::ffi::CUBLAS_OP_N,
|
||||
xtrain_cuda::ffi::CUBLAS_OP_N,
|
||||
n as i32,
|
||||
m as i32,
|
||||
k as i32,
|
||||
&alpha,
|
||||
b.data_ptr() as *const f32,
|
||||
n as i32,
|
||||
a.data_ptr() as *const f32,
|
||||
k as i32,
|
||||
&beta,
|
||||
c.data_ptr() as *mut f32,
|
||||
n as i32,
|
||||
);
|
||||
assert_eq!(status, 0, "cublasSgemm failed: {status}");
|
||||
device::synchronize().unwrap();
|
||||
xtrain_cuda::ffi::cublasDestroy_v2(handle);
|
||||
}
|
||||
c
|
||||
}
|
||||
|
||||
fn max_rel_err(got: &[f32], reference: &[f32]) -> f32 {
|
||||
got.iter()
|
||||
.zip(reference)
|
||||
.map(|(g, r)| (g - r).abs() / (g.abs() + r.abs() + 1e-6))
|
||||
.fold(0.0f32, f32::max)
|
||||
}
|
||||
|
||||
// --- Forward: hand-written tiled GEMM vs cuBLAS sgemm ---
|
||||
|
||||
fn run_fwd(m: usize, k: usize, n: usize) {
|
||||
require_gpu();
|
||||
let a = Tensor::from_slice(&fill(m * k, 1), &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&fill(k * n, 2), &[k, n]).to_device(Device::Cuda(0));
|
||||
|
||||
let mine = a.matmul(&b).to_device(Device::Cpu);
|
||||
let reference = cublas_matmul(&a, &b).to_device(Device::Cpu);
|
||||
|
||||
let rel = max_rel_err(mine.as_slice::<f32>(), reference.as_slice::<f32>());
|
||||
println!("fwd GEMM [{m}x{k}]@[{k}x{n}] vs cuBLAS: max_rel_err = {rel:.3e}");
|
||||
assert!(rel < 1e-3, "fwd rel-err {rel} too high for {m}x{k}x{n}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fwd_square() {
|
||||
run_fwd(64, 64, 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fwd_rect() {
|
||||
run_fwd(65, 97, 33); // non-tile-aligned dims exercise the boundary masking
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fwd_large() {
|
||||
run_fwd(256, 256, 256);
|
||||
}
|
||||
|
||||
// --- Backward: dA, dB vs the finite-difference harness ---
|
||||
//
|
||||
// Scalar loss L = sum(W ∘ C) with C = A @ B and W fixed random weights.
|
||||
// Then dC = W, dA = dC @ Bᵀ, dB = Aᵀ @ dC (matmul_backward). We check each of
|
||||
// dA and dB against central differences of L w.r.t. that input.
|
||||
|
||||
fn run_bwd(m: usize, k: usize, n: usize) {
|
||||
require_gpu();
|
||||
|
||||
let a_host = fill(m * k, 11);
|
||||
let b_host = fill(k * n, 22);
|
||||
let w_host = fill(m * n, 33); // loss weights, == dC
|
||||
|
||||
let a = Tensor::from_slice(&a_host, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_host, &[k, n]).to_device(Device::Cuda(0));
|
||||
let dc = Tensor::from_slice(&w_host, &[m, n]).to_device(Device::Cuda(0));
|
||||
|
||||
let (da, db) = Tensor::matmul_backward(&a, &b, &dc);
|
||||
let da_host = da.to_device(Device::Cpu);
|
||||
let db_host = db.to_device(Device::Cpu);
|
||||
|
||||
let cfg = GradCheckConfig::default();
|
||||
|
||||
// Check dA: vary A, hold B fixed.
|
||||
let b_fixed = b_host.clone();
|
||||
let w_fixed = w_host.clone();
|
||||
let loss_a = move |a_vals: &[f32], a_shape: &[usize]| -> f32 {
|
||||
let av = Tensor::from_slice(a_vals, a_shape).to_device(Device::Cuda(0));
|
||||
let bv = Tensor::from_slice(&b_fixed, &[k, n]).to_device(Device::Cuda(0));
|
||||
let c = av.matmul(&bv).to_device(Device::Cpu);
|
||||
c.as_slice::<f32>()
|
||||
.iter()
|
||||
.zip(&w_fixed)
|
||||
.map(|(c, w)| c * w)
|
||||
.sum()
|
||||
};
|
||||
let res_a = grad_check(&a_host, &[m, k], &loss_a, da_host.as_slice::<f32>(), cfg);
|
||||
println!(
|
||||
"bwd dA [{m}x{k}]: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
|
||||
res_a.max_rel_err, res_a.worst_numeric, res_a.worst_analytic, res_a.worst_index
|
||||
);
|
||||
assert!(res_a.passed, "dA grad-check failed: {:?}", res_a);
|
||||
|
||||
// Check dB: vary B, hold A fixed.
|
||||
let a_fixed = a_host.clone();
|
||||
let w_fixed2 = w_host.clone();
|
||||
let loss_b = move |b_vals: &[f32], b_shape: &[usize]| -> f32 {
|
||||
let av = Tensor::from_slice(&a_fixed, &[m, k]).to_device(Device::Cuda(0));
|
||||
let bv = Tensor::from_slice(b_vals, b_shape).to_device(Device::Cuda(0));
|
||||
let c = av.matmul(&bv).to_device(Device::Cpu);
|
||||
c.as_slice::<f32>()
|
||||
.iter()
|
||||
.zip(&w_fixed2)
|
||||
.map(|(c, w)| c * w)
|
||||
.sum()
|
||||
};
|
||||
let res_b = grad_check(&b_host, &[k, n], &loss_b, db_host.as_slice::<f32>(), cfg);
|
||||
println!(
|
||||
"bwd dB [{k}x{n}]: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
|
||||
res_b.max_rel_err, res_b.worst_numeric, res_b.worst_analytic, res_b.worst_index
|
||||
);
|
||||
assert!(res_b.passed, "dB grad-check failed: {:?}", res_b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bwd_square() {
|
||||
run_bwd(16, 16, 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bwd_rect() {
|
||||
run_bwd(12, 20, 8);
|
||||
}
|
||||
Reference in New Issue
Block a user