perf: cuBLAS matmul fwd/bwd
Route Tensor::matmul and matmul_backward through cuBLAS Sgemm instead of the hand-written tiled kernel. fp32 → same GEMM up to rounding order, so the T3 cuBLAS tolerance and downstream grad-checks are preserved. - cublas.rs: thread-local persistent handle + row-major sgemm helper with transpose flags (col-major⟺row-major as the T3 oracle does). - matmul_backward: dA/dB via cuBLAS OP_T, dropping the two transpose kernels + their allocations the T3 version ran. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
95
crates/xtrain-cuda/src/cublas.rs
Normal file
95
crates/xtrain-cuda/src/cublas.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
//! cuBLAS GEMM backend (Phase T7).
|
||||
//!
|
||||
//! The hand-written tiled kernel (csrc/ops/gemm.cu) is kept as the T3 learning
|
||||
//! artifact + correctness oracle's counterpart, but the forward + both backward
|
||||
//! matmuls now route through cuBLAS `Sgemm` — fp32, so the result is numerically
|
||||
//! the same GEMM (only the rounding order changes), which is why the T3 tolerance
|
||||
//! against cuBLAS holds unchanged.
|
||||
//!
|
||||
//! **Layout.** cuBLAS is column-major; our tensors are row-major. A row-major
|
||||
//! `[r,c]` matrix handed to cuBLAS with leading dim `c` is read as its transpose
|
||||
//! (col-major `[c,r]`). To get a row-major result `C[m,n] = opA(A)·opB(B)` we
|
||||
//! compute the col-major transpose `Cᵀ[n,m] = opB(B)ᵀ·opA(A)ᵀ`; the bytes of
|
||||
//! col-major `Cᵀ` are exactly row-major `C`. See [`sgemm`] for the index algebra.
|
||||
//!
|
||||
//! **Handle.** cuBLAS handle creation is expensive (T3's oracle made one per
|
||||
//! call). We cache one handle per thread for the lifetime of the process.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use crate::ffi::{self, CublasHandle};
|
||||
use std::cell::RefCell;
|
||||
|
||||
thread_local! {
|
||||
static HANDLE: RefCell<Option<CublasHandle>> = const { RefCell::new(None) };
|
||||
}
|
||||
|
||||
/// Run `f` with the thread's cached cuBLAS handle, creating it on first use.
|
||||
fn with_handle<R>(f: impl FnOnce(CublasHandle) -> R) -> R {
|
||||
HANDLE.with(|h| {
|
||||
let mut slot = h.borrow_mut();
|
||||
if slot.is_none() {
|
||||
let mut handle: CublasHandle = std::ptr::null_mut();
|
||||
let status = unsafe { ffi::cublasCreate_v2(&mut handle) };
|
||||
assert_eq!(status, 0, "cublasCreate failed: {status}");
|
||||
*slot = Some(handle);
|
||||
}
|
||||
f(slot.unwrap())
|
||||
})
|
||||
}
|
||||
|
||||
/// Row-major single-precision GEMM: `C[m,n] = opA(A) · opB(B)` with
|
||||
/// `C = alpha·(…) + beta·C`. `A`/`B`/`C` are device pointers to row-major fp32
|
||||
/// matrices; `trans_a`/`trans_b` request the transpose of the *logical* operand.
|
||||
///
|
||||
/// `m,n,k` are the dims of the math (`opA(A)` is `[m,k]`, `opB(B)` is `[k,n]`).
|
||||
/// The stored, untransposed shapes are: `A` is `[m,k]` (or `[k,m]` if `trans_a`),
|
||||
/// `B` is `[k,n]` (or `[n,k]` if `trans_b`). Their row-major leading dims are the
|
||||
/// stored column counts, derived below.
|
||||
///
|
||||
/// We ask cuBLAS for col-major `Cᵀ[n,m] = opB(B)ᵀ · opA(A)ᵀ`. Since a row-major
|
||||
/// `[r,c]` buffer is col-major `[c,r]`, a row-major operand already *is* its own
|
||||
/// transpose to cuBLAS — so `opB(B)ᵀ` over the row-major bytes of `B` is obtained
|
||||
/// by passing `B` with the OPPOSITE op flag of what `opB` would suggest. Working
|
||||
/// it through: first cuBLAS arg = `B` with op `trans_b ? N : T`, second = `A` with
|
||||
/// op `trans_a ? N : T`, sizes (m=n, n=m, k=k).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn sgemm(
|
||||
trans_a: bool,
|
||||
trans_b: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
alpha: f32,
|
||||
a: *const f32,
|
||||
b: *const f32,
|
||||
beta: f32,
|
||||
c: *mut f32,
|
||||
) {
|
||||
// Leading dims = stored (row-major) column count of each untransposed matrix.
|
||||
let lda = if trans_a { m } else { k }; // A stored [m,k] or [k,m]
|
||||
let ldb = if trans_b { k } else { n }; // B stored [k,n] or [n,k]
|
||||
let ldc = n; // Cᵀ is [n,m] col-major with ld n (== row-major C[m,n])
|
||||
|
||||
let op_b = if trans_b {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
let op_a = if trans_a {
|
||||
ffi::CUBLAS_OP_T
|
||||
} else {
|
||||
ffi::CUBLAS_OP_N
|
||||
};
|
||||
|
||||
with_handle(|handle| {
|
||||
let status = unsafe {
|
||||
ffi::cublasSgemm_v2(
|
||||
handle, op_b, op_a, n as i32, // rows of Cᵀ
|
||||
m as i32, // cols of Cᵀ
|
||||
k as i32, &alpha, b, ldb as i32, a, lda as i32, &beta, c, ldc as i32,
|
||||
)
|
||||
};
|
||||
assert_eq!(status, 0, "cublasSgemm failed: {status}");
|
||||
});
|
||||
}
|
||||
@@ -212,8 +212,9 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// cuBLAS — used ONLY as a correctness reference for the hand-written GEMM in
|
||||
// tests. Declared (and linked, see build.rs) only when CUDA is compiled in.
|
||||
// cuBLAS — the production GEMM backend (Phase T7) and the correctness oracle the
|
||||
// T3 GEMM tests still compare against. Declared (and linked, see build.rs) only
|
||||
// when CUDA is compiled in.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub type CublasHandle = *mut c_void;
|
||||
|
||||
@@ -241,3 +242,5 @@ unsafe extern "C" {
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub const CUBLAS_OP_N: i32 = 0;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub const CUBLAS_OP_T: i32 = 1;
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod cublas;
|
||||
pub mod device;
|
||||
pub mod error;
|
||||
pub mod ffi;
|
||||
|
||||
@@ -161,8 +161,9 @@ impl Tensor {
|
||||
|
||||
/// Matrix multiply: `C = self @ other`. `self`:[M,K], `other`:[K,N] → [M,N].
|
||||
///
|
||||
/// Runs the tiled `gemm_tiled_f32` CUDA kernel. Requires contiguous F32
|
||||
/// tensors on the same GPU. Available only when CUDA is compiled in.
|
||||
/// Routes through cuBLAS `Sgemm` (Phase T7). fp32, so it is the same GEMM as
|
||||
/// the T3 tiled kernel up to rounding order. Requires contiguous F32 tensors
|
||||
/// on the same GPU. Available only when CUDA is compiled in.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn matmul(&self, other: &Tensor) -> Self {
|
||||
assert_eq!(self.dtype, DType::F32, "matmul only supports F32");
|
||||
@@ -188,17 +189,18 @@ impl Tensor {
|
||||
let k = self.shape[1];
|
||||
let n = other.shape[1];
|
||||
let out = Tensor::zeros(&[m, n], DType::F32, self.device());
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_gemm_tiled_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
m as i32,
|
||||
n as i32,
|
||||
k as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
false,
|
||||
false,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
1.0,
|
||||
self.data_ptr() as *const f32,
|
||||
other.data_ptr() as *const f32,
|
||||
0.0,
|
||||
out.data_ptr() as *mut f32,
|
||||
);
|
||||
xtrain_cuda::device::synchronize().expect("matmul kernel sync failed");
|
||||
out
|
||||
}
|
||||
@@ -234,6 +236,9 @@ impl Tensor {
|
||||
/// Backward of `C = A @ B` given the upstream gradient `dC` (shape [M,N]).
|
||||
/// Returns `(dA, dB)` where `dA = dC @ Bᵀ` ([M,K]) and `dB = Aᵀ @ dC`
|
||||
/// ([K,N]). All tensors contiguous F32 on the same GPU.
|
||||
///
|
||||
/// Phase T7: cuBLAS applies the transposes internally via its op flags, so we
|
||||
/// avoid the two transpose kernels (and their allocations) the T3 version ran.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn matmul_backward(a: &Tensor, b: &Tensor, dc: &Tensor) -> (Tensor, Tensor) {
|
||||
assert_eq!(a.ndim(), 2, "matmul_backward requires 2D A");
|
||||
@@ -243,8 +248,36 @@ impl Tensor {
|
||||
assert_eq!(dc.shape[0], a.shape[0], "dC rows != A rows (M)");
|
||||
assert_eq!(dc.shape[1], b.shape[1], "dC cols != B cols (N)");
|
||||
|
||||
let da = dc.matmul(&b.transpose_2d()); // [M,N] @ [N,K] = [M,K]
|
||||
let db = a.transpose_2d().matmul(dc); // [K,M] @ [M,N] = [K,N]
|
||||
let (m, k, n) = (a.shape[0], a.shape[1], b.shape[1]);
|
||||
// dA[M,K] = dC[M,N] · Bᵀ (B stored [K,N], transposed by cuBLAS)
|
||||
let da = Tensor::zeros(&[m, k], DType::F32, a.device());
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
false,
|
||||
true,
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
1.0,
|
||||
dc.data_ptr() as *const f32,
|
||||
b.data_ptr() as *const f32,
|
||||
0.0,
|
||||
da.data_ptr() as *mut f32,
|
||||
);
|
||||
// dB[K,N] = Aᵀ · dC[M,N] (A stored [M,K], transposed by cuBLAS)
|
||||
let db = Tensor::zeros(&[k, n], DType::F32, a.device());
|
||||
xtrain_cuda::cublas::sgemm(
|
||||
true,
|
||||
false,
|
||||
k,
|
||||
n,
|
||||
m,
|
||||
1.0,
|
||||
a.data_ptr() as *const f32,
|
||||
dc.data_ptr() as *const f32,
|
||||
0.0,
|
||||
db.data_ptr() as *mut f32,
|
||||
);
|
||||
xtrain_cuda::device::synchronize().expect("matmul_backward sync failed");
|
||||
(da, db)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user