perf: cuBLAS matmul fwd/bwd

Route Tensor::matmul and matmul_backward through cuBLAS Sgemm instead of
the hand-written tiled kernel. fp32 → same GEMM up to rounding order, so
the T3 cuBLAS tolerance and downstream grad-checks are preserved.

- cublas.rs: thread-local persistent handle + row-major sgemm helper with
  transpose flags (col-major⟺row-major as the T3 oracle does).
- matmul_backward: dA/dB via cuBLAS OP_T, dropping the two transpose
  kernels + their allocations the T3 version ran.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:48:35 +08:00
parent 5df1d4d57b
commit 0e5c7d22e2
4 changed files with 150 additions and 17 deletions

View File

@@ -0,0 +1,95 @@
//! cuBLAS GEMM backend (Phase T7).
//!
//! The hand-written tiled kernel (csrc/ops/gemm.cu) is kept as the T3 learning
//! artifact + correctness oracle's counterpart, but the forward + both backward
//! matmuls now route through cuBLAS `Sgemm` — fp32, so the result is numerically
//! the same GEMM (only the rounding order changes), which is why the T3 tolerance
//! against cuBLAS holds unchanged.
//!
//! **Layout.** cuBLAS is column-major; our tensors are row-major. A row-major
//! `[r,c]` matrix handed to cuBLAS with leading dim `c` is read as its transpose
//! (col-major `[c,r]`). To get a row-major result `C[m,n] = opA(A)·opB(B)` we
//! compute the col-major transpose `Cᵀ[n,m] = opB(B)ᵀ·opA(A)ᵀ`; the bytes of
//! col-major `Cᵀ` are exactly row-major `C`. See [`sgemm`] for the index algebra.
//!
//! **Handle.** cuBLAS handle creation is expensive (T3's oracle made one per
//! call). We cache one handle per thread: it is created on that thread's first
//! use and is never explicitly destroyed by this module.
#![cfg(not(no_cuda))]
use crate::ffi::{self, CublasHandle};
use std::cell::RefCell;
thread_local! {
/// Per-thread cuBLAS handle cache: `None` until `with_handle` creates the
/// handle on this thread's first use, `Some(handle)` afterwards.
static HANDLE: RefCell<Option<CublasHandle>> = const { RefCell::new(None) };
}
/// Run `f` with the thread's cached cuBLAS handle, creating it on first use.
///
/// The handle is copied out of the thread-local slot and the `RefCell` borrow
/// is released *before* `f` runs, so `f` may itself call `with_handle` (it
/// receives the same handle) without triggering a `BorrowMutError` panic.
///
/// # Panics
///
/// Panics if `cublasCreate_v2` reports a non-zero status. In that case the
/// slot stays empty, so a later call retries the creation.
fn with_handle<R>(f: impl FnOnce(CublasHandle) -> R) -> R {
    let handle = HANDLE.with(|cell| {
        *cell.borrow_mut().get_or_insert_with(|| {
            let mut handle: CublasHandle = std::ptr::null_mut();
            // SAFETY: `handle` is a live local used as the writable
            // out-parameter of the create call for the duration of that call.
            let status = unsafe { ffi::cublasCreate_v2(&mut handle) };
            assert_eq!(status, 0, "cublasCreate failed: {status}");
            handle
        })
    });
    f(handle)
}
/// Row-major single-precision GEMM: `C[m,n] = opA(A) · opB(B)` with
/// `C = alpha·(…) + beta·C`. `A`/`B`/`C` are device pointers to row-major fp32
/// matrices; `trans_a`/`trans_b` request the transpose of the *logical* operand.
///
/// `m,n,k` are the dims of the math (`opA(A)` is `[m,k]`, `opB(B)` is `[k,n]`).
/// The stored, untransposed shapes are: `A` is `[m,k]` (or `[k,m]` if `trans_a`),
/// `B` is `[k,n]` (or `[n,k]` if `trans_b`). Their row-major leading dims are the
/// stored column counts, derived below.
///
/// We ask cuBLAS for col-major `Cᵀ[n,m] = opB(B)ᵀ · opA(A)ᵀ`. Since a row-major
/// `[r,c]` buffer is col-major `[c,r]`, a row-major operand already *is* its own
/// transpose to cuBLAS — so `opB(B)ᵀ` over the row-major bytes of `B` is obtained
/// by passing `B` with the OPPOSITE op flag of what `opB` would suggest. Working
/// it through: first cuBLAS arg = `B` with op `trans_b ? N : T`, second = `A` with
/// op `trans_a ? N : T`, sizes (m=n, n=m, k=k).
#[allow(clippy::too_many_arguments)]
pub fn sgemm(
trans_a: bool,
trans_b: bool,
m: usize,
n: usize,
k: usize,
alpha: f32,
a: *const f32,
b: *const f32,
beta: f32,
c: *mut f32,
) {
// Leading dims = stored (row-major) column count of each untransposed matrix.
let lda = if trans_a { m } else { k }; // A stored [m,k] or [k,m]
let ldb = if trans_b { k } else { n }; // B stored [k,n] or [n,k]
let ldc = n; // Cᵀ is [n,m] col-major with ld n (== row-major C[m,n])
let op_b = if trans_b {
ffi::CUBLAS_OP_T
} else {
ffi::CUBLAS_OP_N
};
let op_a = if trans_a {
ffi::CUBLAS_OP_T
} else {
ffi::CUBLAS_OP_N
};
with_handle(|handle| {
let status = unsafe {
ffi::cublasSgemm_v2(
handle, op_b, op_a, n as i32, // rows of Cᵀ
m as i32, // cols of Cᵀ
k as i32, &alpha, b, ldb as i32, a, lda as i32, &beta, c, ldc as i32,
)
};
assert_eq!(status, 0, "cublasSgemm failed: {status}");
});
}

View File

@@ -212,8 +212,9 @@ unsafe extern "C" {
);
}
// cuBLAS — used ONLY as a correctness reference for the hand-written GEMM in
// tests. Declared (and linked, see build.rs) only when CUDA is compiled in.
// cuBLAS — the production GEMM backend (Phase T7) and the correctness oracle the
// T3 GEMM tests still compare against. Declared (and linked, see build.rs) only
// when CUDA is compiled in.
#[cfg(not(no_cuda))]
pub type CublasHandle = *mut c_void;
@@ -241,3 +242,5 @@ unsafe extern "C" {
/// cuBLAS operation flag: use the operand as stored (no transpose).
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_N: i32 = 0;
/// cuBLAS operation flag: use the transpose of the operand.
#[cfg(not(no_cuda))]
pub const CUBLAS_OP_T: i32 = 1;

View File

@@ -1,3 +1,5 @@
#[cfg(not(no_cuda))]
pub mod cublas;
pub mod device;
pub mod error;
pub mod ffi;

View File

@@ -161,8 +161,9 @@ impl Tensor {
/// Matrix multiply: `C = self @ other`. `self`:[M,K], `other`:[K,N] → [M,N].
///
/// Runs the tiled `gemm_tiled_f32` CUDA kernel. Requires contiguous F32
/// tensors on the same GPU. Available only when CUDA is compiled in.
/// Routes through cuBLAS `Sgemm` (Phase T7). fp32, so it is the same GEMM as
/// the T3 tiled kernel up to rounding order. Requires contiguous F32 tensors
/// on the same GPU. Available only when CUDA is compiled in.
#[cfg(not(no_cuda))]
pub fn matmul(&self, other: &Tensor) -> Self {
assert_eq!(self.dtype, DType::F32, "matmul only supports F32");
@@ -188,17 +189,18 @@ impl Tensor {
let k = self.shape[1];
let n = other.shape[1];
let out = Tensor::zeros(&[m, n], DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_gemm_tiled_f32(
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
m as i32,
n as i32,
k as i32,
std::ptr::null_mut(),
);
}
xtrain_cuda::cublas::sgemm(
false,
false,
m,
n,
k,
1.0,
self.data_ptr() as *const f32,
other.data_ptr() as *const f32,
0.0,
out.data_ptr() as *mut f32,
);
xtrain_cuda::device::synchronize().expect("matmul kernel sync failed");
out
}
@@ -234,6 +236,9 @@ impl Tensor {
/// Backward of `C = A @ B` given the upstream gradient `dC` (shape [M,N]).
/// Returns `(dA, dB)` where `dA = dC @ Bᵀ` ([M,K]) and `dB = Aᵀ @ dC`
/// ([K,N]). All tensors contiguous F32 on the same GPU.
///
/// Phase T7: cuBLAS applies the transposes internally via its op flags, so we
/// avoid the two transpose kernels (and their allocations) the T3 version ran.
#[cfg(not(no_cuda))]
pub fn matmul_backward(a: &Tensor, b: &Tensor, dc: &Tensor) -> (Tensor, Tensor) {
assert_eq!(a.ndim(), 2, "matmul_backward requires 2D A");
@@ -243,8 +248,36 @@ impl Tensor {
assert_eq!(dc.shape[0], a.shape[0], "dC rows != A rows (M)");
assert_eq!(dc.shape[1], b.shape[1], "dC cols != B cols (N)");
let da = dc.matmul(&b.transpose_2d()); // [M,N] @ [N,K] = [M,K]
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N] = [K,N]
let (m, k, n) = (a.shape[0], a.shape[1], b.shape[1]);
// dA[M,K] = dC[M,N] · Bᵀ (B stored [K,N], transposed by cuBLAS)
let da = Tensor::zeros(&[m, k], DType::F32, a.device());
xtrain_cuda::cublas::sgemm(
false,
true,
m,
k,
n,
1.0,
dc.data_ptr() as *const f32,
b.data_ptr() as *const f32,
0.0,
da.data_ptr() as *mut f32,
);
// dB[K,N] = Aᵀ · dC[M,N] (A stored [M,K], transposed by cuBLAS)
let db = Tensor::zeros(&[k, n], DType::F32, a.device());
xtrain_cuda::cublas::sgemm(
true,
false,
k,
n,
m,
1.0,
a.data_ptr() as *const f32,
dc.data_ptr() as *const f32,
0.0,
db.data_ptr() as *mut f32,
);
xtrain_cuda::device::synchronize().expect("matmul_backward sync failed");
(da, db)
}