phase 3: GEMM kernels (naive, tiled, cuBLAS)

- Naive GEMM kernel: one thread per output element (F32 + BF16)
- Tiled GEMM kernel: 32x32 shared memory tiles (F32 + BF16)
- cuBLAS wrapper: cublasGemmEx with row-major trick
- GemmBackend enum for runtime backend selection
- CublasContext RAII handle
- Made error::check public for cross-crate use
- 17 GEMM tests: small/medium/rect sizes, all backends, F32+BF16
- Cross-backend consistency verified (naive vs tiled vs cuBLAS)
- All 44 tests pass across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-21 19:48:05 +08:00
parent a83971fa25
commit d77f921a12
9 changed files with 519 additions and 1 deletions

View File

@@ -23,7 +23,7 @@ impl std::error::Error for CudaError {}
pub type Result<T> = std::result::Result<T, CudaError>;
pub(crate) fn check(code: i32) -> Result<()> {
pub fn check(code: i32) -> Result<()> {
if code == ffi::CUDA_SUCCESS {
return Ok(());
}

View File

@@ -0,0 +1,12 @@
[package]
name = "xserv-kernels"
version.workspace = true
edition.workspace = true
[build-dependencies]
cc = "1"
[dependencies]
xserv-cuda = { path = "../xserv-cuda" }
xserv-tensor = { path = "../xserv-tensor" }
half.workspace = true

View File

@@ -0,0 +1,21 @@
use std::env;
fn main() {
let cuda_path = env::var("CUDA_HOME")
.or_else(|_| env::var("CUDA_PATH"))
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cublas");
cc::Build::new()
.cuda(true)
.cudart("shared")
.flag("-gencode=arch=compute_120,code=sm_120")
.file("../../csrc/gemm/naive.cu")
.file("../../csrc/gemm/tiled.cu")
.compile("xserv_gemm_kernels");
println!("cargo:rerun-if-changed=../../csrc/gemm/");
}

View File

@@ -0,0 +1,151 @@
use std::ffi::c_void;
use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
#[derive(Debug, Clone, Copy)]
pub enum GemmBackend {
Naive,
Tiled,
CuBlas,
}
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
}
// --- FFI: cuBLAS ---
type CublasHandle = *mut c_void;
#[allow(non_upper_case_globals)]
const CUBLAS_OP_N: i32 = 0;
// cudaDataType
const CUDA_R_32F: i32 = 0;
const CUDA_R_16BF: i32 = 14;
// cublasComputeType
const CUBLAS_COMPUTE_32F: i32 = 68;
unsafe extern "C" {
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
fn cublasGemmEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32,
b: *const c_void, b_type: i32, ldb: i32,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32,
compute_type: i32,
algo: i32,
) -> i32;
}
pub struct CublasContext {
handle: CublasHandle,
}
impl CublasContext {
pub fn new() -> Result<Self> {
let mut handle = std::ptr::null_mut();
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
Ok(Self { handle })
}
}
impl Drop for CublasContext {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe { cublasDestroy_v2(self.handle) };
}
}
}
/// Matrix multiplication: C = A @ B
/// A: [M, K], B: [K, N], C: [M, N]
/// All tensors must be contiguous and on the same GPU.
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
let dtype = a.dtype();
let c = Tensor::zeros(&[m, n], dtype, a.device());
let a_ptr = a.data_ptr() as *const c_void;
let b_ptr = b.data_ptr() as *const c_void;
let c_ptr = c.data_ptr() as *mut c_void;
let null_stream = std::ptr::null_mut();
match backend {
GemmBackend::Naive => {
unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for naive GEMM"),
}
}
xserv_cuda::device::synchronize().unwrap();
}
GemmBackend::Tiled => {
unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
_ => panic!("unsupported dtype for tiled GEMM"),
}
}
xserv_cuda::device::synchronize().unwrap();
}
GemmBackend::CuBlas => {
// cuBLAS uses column-major, but we have row-major tensors.
// Trick: compute C^T = B^T @ A^T, which gives us C in row-major.
// cuBLAS sees our row-major data as column-major transposed.
let ctx = CublasContext::new().unwrap();
let alpha = 1.0f32;
let beta = 0.0f32;
let (a_type, b_type, c_type) = match dtype {
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
_ => panic!("unsupported dtype for cuBLAS GEMM"),
};
unsafe {
cublasSetStream_v2(ctx.handle, null_stream);
// Row-major trick: swap A/B and transpose flags
// C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T
error::check(cublasGemmEx(
ctx.handle,
CUBLAS_OP_N, CUBLAS_OP_N,
n as i32, m as i32, k as i32,
&alpha as *const f32 as *const c_void,
b_ptr, b_type, n as i32, // B as col-major = B^T
a_ptr, a_type, k as i32, // A as col-major = A^T
&beta as *const f32 as *const c_void,
c_ptr, c_type, n as i32, // C as col-major = C^T
CUBLAS_COMPUTE_32F,
-1, // default algo
)).expect("cuBLAS GEMM failed");
}
xserv_cuda::device::synchronize().unwrap();
}
}
c
}

View File

@@ -0,0 +1,3 @@
pub mod gemm;
pub use gemm::{GemmBackend, matmul};

View File

@@ -0,0 +1,152 @@
use half::bf16;
use xserv_kernels::{matmul, GemmBackend};
use xserv_tensor::{Device, Tensor};
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
let mut c = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for kk in 0..k {
sum += a[i * k + kk] * b[kk * n + j];
}
c[i * n + j] = sum;
}
}
c
}
fn check_close_f32(result: &[f32], expected: &[f32], atol: f32) {
assert_eq!(result.len(), expected.len());
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
assert!(
(r - e).abs() <= atol,
"mismatch at index {i}: got {r}, expected {e}, diff {}",
(r - e).abs()
);
}
}
fn check_close_bf16(result: &[bf16], expected: &[f32], atol: f32) {
assert_eq!(result.len(), expected.len());
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
let rv = r.to_f32();
assert!(
(rv - e).abs() <= atol,
"mismatch at index {i}: got {rv}, expected {e}, diff {}",
(rv - e).abs()
);
}
}
fn run_gemm_test_f32(backend: GemmBackend, m: usize, n: usize, k: usize) {
xserv_cuda::device::set_device(0).unwrap();
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let expected = cpu_matmul_f32(&a_data, &b_data, m, n, k);
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c = matmul(&a, &b, backend);
let c_cpu = c.to_device(Device::Cpu);
check_close_f32(c_cpu.as_slice::<f32>(), &expected, 1e-4);
}
fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
xserv_cuda::device::set_device(0).unwrap();
let a_f32: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_f32: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let expected = cpu_matmul_f32(&a_f32, &b_f32, m, n, k);
let a_data: Vec<bf16> = a_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let b_data: Vec<bf16> = b_f32.iter().map(|&v| bf16::from_f32(v)).collect();
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c = matmul(&a, &b, backend);
let c_cpu = c.to_device(Device::Cpu);
check_close_bf16(c_cpu.as_slice::<bf16>(), &expected, 0.1);
}
// --- F32 tests ---
#[test]
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
#[test]
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
#[test]
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
#[test]
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
#[test]
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
#[test]
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
#[test]
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
#[test]
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
#[test]
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
// --- BF16 tests ---
#[test]
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
#[test]
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
#[test]
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
#[test]
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
#[test]
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
#[test]
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
// --- Larger benchmark-style tests ---
#[test]
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
#[test]
fn test_gemm_consistency_all_backends() {
xserv_cuda::device::set_device(0).unwrap();
let m = 64;
let n = 64;
let k = 64;
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
let c_naive = matmul(&a, &b, GemmBackend::Naive).to_device(Device::Cpu);
let c_tiled = matmul(&a, &b, GemmBackend::Tiled).to_device(Device::Cpu);
let c_cublas = matmul(&a, &b, GemmBackend::CuBlas).to_device(Device::Cpu);
let naive = c_naive.as_slice::<f32>();
let tiled = c_tiled.as_slice::<f32>();
let cublas = c_cublas.as_slice::<f32>();
check_close_f32(naive, cublas, 1e-4);
check_close_f32(tiled, cublas, 1e-4);
}