diff --git a/Cargo.toml b/Cargo.toml index 34e3b83..bf73ab6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ resolver = "2" members = [ "crates/xserv-cuda", "crates/xserv-tensor", + "crates/xserv-kernels", ] [workspace.package] diff --git a/crates/xserv-cuda/src/error.rs b/crates/xserv-cuda/src/error.rs index 3329a7f..4d47fdf 100644 --- a/crates/xserv-cuda/src/error.rs +++ b/crates/xserv-cuda/src/error.rs @@ -23,7 +23,7 @@ impl std::error::Error for CudaError {} pub type Result = std::result::Result; -pub(crate) fn check(code: i32) -> Result<()> { +pub fn check(code: i32) -> Result<()> { if code == ffi::CUDA_SUCCESS { return Ok(()); } diff --git a/crates/xserv-kernels/Cargo.toml b/crates/xserv-kernels/Cargo.toml new file mode 100644 index 0000000..d96b3f2 --- /dev/null +++ b/crates/xserv-kernels/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "xserv-kernels" +version.workspace = true +edition.workspace = true + +[build-dependencies] +cc = "1" + +[dependencies] +xserv-cuda = { path = "../xserv-cuda" } +xserv-tensor = { path = "../xserv-tensor" } +half.workspace = true diff --git a/crates/xserv-kernels/build.rs b/crates/xserv-kernels/build.rs new file mode 100644 index 0000000..6ae1758 --- /dev/null +++ b/crates/xserv-kernels/build.rs @@ -0,0 +1,21 @@ +use std::env; + +fn main() { + let cuda_path = env::var("CUDA_HOME") + .or_else(|_| env::var("CUDA_PATH")) + .unwrap_or_else(|_| "/usr/local/cuda".to_string()); + + println!("cargo:rustc-link-search=native={cuda_path}/lib64"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=cublas"); + + cc::Build::new() + .cuda(true) + .cudart("shared") + .flag("-gencode=arch=compute_120,code=sm_120") + .file("../../csrc/gemm/naive.cu") + .file("../../csrc/gemm/tiled.cu") + .compile("xserv_gemm_kernels"); + + println!("cargo:rerun-if-changed=../../csrc/gemm/"); +} diff --git a/crates/xserv-kernels/src/gemm.rs b/crates/xserv-kernels/src/gemm.rs new file mode 100644 index 0000000..533b2bc --- /dev/null +++ b/crates/xserv-kernels/src/gemm.rs @@ -0,0 +1,151 @@ +use std::ffi::c_void; +use xserv_cuda::error::{self, Result}; +use xserv_tensor::{DType, Device, Tensor}; + +#[derive(Debug, Clone, Copy)] +pub enum GemmBackend { + Naive, + Tiled, + CuBlas, +} + +// --- FFI: custom CUDA kernels --- +unsafe extern "C" { + fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); + fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); + fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); + fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); +} + +// --- FFI: cuBLAS --- +type CublasHandle = *mut c_void; + +#[allow(non_upper_case_globals)] +const CUBLAS_OP_N: i32 = 0; + +// cudaDataType +const CUDA_R_32F: i32 = 0; +const CUDA_R_16BF: i32 = 14; + +// cublasComputeType +const CUBLAS_COMPUTE_32F: i32 = 68; + +unsafe extern "C" { + fn cublasCreate_v2(handle: *mut CublasHandle) -> i32; + fn cublasDestroy_v2(handle: CublasHandle) -> i32; + fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32; + fn cublasGemmEx( + handle: CublasHandle, + transa: i32, transb: i32, + m: i32, n: i32, k: i32, + alpha: *const c_void, + a: *const c_void, a_type: i32, lda: i32, + b: *const c_void, b_type: i32, ldb: i32, + beta: *const c_void, + c: *mut c_void, c_type: i32, ldc: i32, + compute_type: i32, + algo: i32, + ) -> i32; +} + +pub struct CublasContext { + handle: CublasHandle, +} + +impl CublasContext { + pub fn new() -> Result { + let mut handle = std::ptr::null_mut(); + error::check(unsafe { cublasCreate_v2(&mut handle) })?; + Ok(Self { handle }) + } +} + +impl Drop for CublasContext { + fn drop(&mut self) { + if !self.handle.is_null() { + unsafe { cublasDestroy_v2(self.handle) }; + } + } +} + +/// Matrix multiplication: C = A @ B +/// A: [M, K], B: [K, N], C: [M, N] +/// All tensors must be contiguous and on the same GPU. +pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { + assert_eq!(a.ndim(), 2); + assert_eq!(b.ndim(), 2); + assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch"); + assert_eq!(a.dtype(), b.dtype(), "dtype mismatch"); + assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors"); + assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors"); + + let m = a.shape()[0]; + let k = a.shape()[1]; + let n = b.shape()[1]; + let dtype = a.dtype(); + + let c = Tensor::zeros(&[m, n], dtype, a.device()); + + let a_ptr = a.data_ptr() as *const c_void; + let b_ptr = b.data_ptr() as *const c_void; + let c_ptr = c.data_ptr() as *mut c_void; + let null_stream = std::ptr::null_mut(); + + match backend { + GemmBackend::Naive => { + unsafe { + match dtype { + DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), + DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), + _ => panic!("unsupported dtype for naive GEMM"), + } + } + xserv_cuda::device::synchronize().unwrap(); + } + GemmBackend::Tiled => { + unsafe { + match dtype { + DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), + DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), + _ => panic!("unsupported dtype for tiled GEMM"), + } + } + xserv_cuda::device::synchronize().unwrap(); + } + GemmBackend::CuBlas => { + // cuBLAS uses column-major, but we have row-major tensors. + // Trick: compute C^T = B^T @ A^T, which gives us C in row-major. + // cuBLAS sees our row-major data as column-major transposed. + let ctx = CublasContext::new().unwrap(); + let alpha = 1.0f32; + let beta = 0.0f32; + + let (a_type, b_type, c_type) = match dtype { + DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F), + DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF), + _ => panic!("unsupported dtype for cuBLAS GEMM"), + }; + + unsafe { + cublasSetStream_v2(ctx.handle, null_stream); + // Row-major trick: swap A/B and transpose flags + // C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T + error::check(cublasGemmEx( + ctx.handle, + CUBLAS_OP_N, CUBLAS_OP_N, + n as i32, m as i32, k as i32, + &alpha as *const f32 as *const c_void, + b_ptr, b_type, n as i32, // B as col-major = B^T + a_ptr, a_type, k as i32, // A as col-major = A^T + &beta as *const f32 as *const c_void, + c_ptr, c_type, n as i32, // C as col-major = C^T + CUBLAS_COMPUTE_32F, + -1, // default algo + )).expect("cuBLAS GEMM failed"); + } + xserv_cuda::device::synchronize().unwrap(); + } + } + + c +} diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs new file mode 100644 index 0000000..e8083fe --- /dev/null +++ b/crates/xserv-kernels/src/lib.rs @@ -0,0 +1,3 @@ +pub mod gemm; + +pub use gemm::{GemmBackend, matmul}; diff --git a/crates/xserv-kernels/tests/gemm_test.rs b/crates/xserv-kernels/tests/gemm_test.rs new file mode 100644 index 0000000..3546d59 --- /dev/null +++ b/crates/xserv-kernels/tests/gemm_test.rs @@ -0,0 +1,152 @@ +use half::bf16; +use xserv_kernels::{matmul, GemmBackend}; +use xserv_tensor::{Device, Tensor}; + +fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec { + let mut c = vec![0.0f32; m * n]; + for i in 0..m { + for j in 0..n { + let mut sum = 0.0f32; + for kk in 0..k { + sum += a[i * k + kk] * b[kk * n + j]; + } + c[i * n + j] = sum; + } + } + c +} + +fn check_close_f32(result: &[f32], expected: &[f32], atol: f32) { + assert_eq!(result.len(), expected.len()); + for (i, (r, e)) in result.iter().zip(expected).enumerate() { + assert!( + (r - e).abs() <= atol, + "mismatch at index {i}: got {r}, expected {e}, diff {}", + (r - e).abs() + ); + } +} + +fn check_close_bf16(result: &[bf16], expected: &[f32], atol: f32) { + assert_eq!(result.len(), expected.len()); + for (i, (r, e)) in result.iter().zip(expected).enumerate() { + let rv = r.to_f32(); + assert!( + (rv - e).abs() <= atol, + "mismatch at index {i}: got {rv}, expected {e}, diff {}", + (rv - e).abs() + ); + } +} + +fn run_gemm_test_f32(backend: GemmBackend, m: usize, n: usize, k: usize) { + xserv_cuda::device::set_device(0).unwrap(); + + let a_data: Vec = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect(); + let b_data: Vec = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect(); + let expected = cpu_matmul_f32(&a_data, &b_data, m, n, k); + + let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0)); + let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0)); + let c = matmul(&a, &b, backend); + + let c_cpu = c.to_device(Device::Cpu); + check_close_f32(c_cpu.as_slice::(), &expected, 1e-4); +} + +fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) { + xserv_cuda::device::set_device(0).unwrap(); + + let a_f32: Vec = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect(); + let b_f32: Vec = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect(); + let expected = cpu_matmul_f32(&a_f32, &b_f32, m, n, k); + + let a_data: Vec = a_f32.iter().map(|&v| bf16::from_f32(v)).collect(); + let b_data: Vec = b_f32.iter().map(|&v| bf16::from_f32(v)).collect(); + + let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0)); + let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0)); + let c = matmul(&a, &b, backend); + + let c_cpu = c.to_device(Device::Cpu); + check_close_bf16(c_cpu.as_slice::(), &expected, 0.1); +} + +// --- F32 tests --- + +#[test] +fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); } + +#[test] +fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); } + +#[test] +fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); } + +#[test] +fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); } + +#[test] +fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); } + +#[test] +fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); } + +#[test] +fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); } + +#[test] +fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); } + +#[test] +fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); } + +// --- BF16 tests --- + +#[test] +fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); } + +#[test] +fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); } + +#[test] +fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); } + +#[test] +fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); } + +#[test] +fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); } + +#[test] +fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); } + +// --- Larger benchmark-style tests --- + +#[test] +fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); } + +#[test] +fn test_gemm_consistency_all_backends() { + xserv_cuda::device::set_device(0).unwrap(); + + let m = 64; + let n = 64; + let k = 64; + let a_data: Vec = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect(); + let b_data: Vec = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect(); + + let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0)); + let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0)); + + let c_naive = matmul(&a, &b, GemmBackend::Naive).to_device(Device::Cpu); + let c_tiled = matmul(&a, &b, GemmBackend::Tiled).to_device(Device::Cpu); + let c_cublas = matmul(&a, &b, GemmBackend::CuBlas).to_device(Device::Cpu); + + let naive = c_naive.as_slice::(); + let tiled = c_tiled.as_slice::(); + let cublas = c_cublas.as_slice::(); + + check_close_f32(naive, cublas, 1e-4); + check_close_f32(tiled, cublas, 1e-4); +} diff --git a/csrc/gemm/naive.cu b/csrc/gemm/naive.cu new file mode 100644 index 0000000..d917436 --- /dev/null +++ b/csrc/gemm/naive.cu @@ -0,0 +1,62 @@ +#include + +// Naive GEMM: each thread computes one element of C. +// C[i][j] = sum_k A[i][k] * B[k][j] +// All matrices are row-major. +__global__ void gemm_naive_bf16( + const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, + int M, int N, int K +) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < M && col < N) { + float sum = 0.0f; + for (int k = 0; k < K; k++) { + sum += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]); + } + C[row * N + col] = __float2bfloat16(sum); + } +} + +__global__ void gemm_naive_f32( + const float* A, const float* B, float* C, + int M, int N, int K +) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (row < M && col < N) { + float sum = 0.0f; + for (int k = 0; k < K; k++) { + sum += A[row * K + k] * B[k * N + col]; + } + C[row * N + col] = sum; + } +} + +extern "C" { + +void launch_gemm_naive_bf16( + const void* A, const void* B, void* C, + int M, int N, int K, void* stream +) { + dim3 block(16, 16); + dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); + gemm_naive_bf16<<>>( + (const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K + ); +} + +void launch_gemm_naive_f32( + const void* A, const void* B, void* C, + int M, int N, int K, void* stream +) { + dim3 block(16, 16); + dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y); + gemm_naive_f32<<>>( + (const float*)A, (const float*)B, (float*)C, M, N, K + ); +} + +} // extern "C" diff --git a/csrc/gemm/tiled.cu b/csrc/gemm/tiled.cu new file mode 100644 index 0000000..00ad1b1 --- /dev/null +++ b/csrc/gemm/tiled.cu @@ -0,0 +1,116 @@ +#include + +// Tiled GEMM using shared memory. +// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B +// into shared memory, then computes a partial dot product. +#define TILE_SIZE 32 + +__global__ void gemm_tiled_f32( + const float* A, const float* B, float* C, + int M, int N, int K +) { + __shared__ float As[TILE_SIZE][TILE_SIZE]; + __shared__ float Bs[TILE_SIZE][TILE_SIZE]; + + int row = blockIdx.y * TILE_SIZE + threadIdx.y; + int col = blockIdx.x * TILE_SIZE + threadIdx.x; + + float sum = 0.0f; + + for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) { + // Load tile of A + int a_col = t * TILE_SIZE + threadIdx.x; + if (row < M && a_col < K) { + As[threadIdx.y][threadIdx.x] = A[row * K + a_col]; + } else { + As[threadIdx.y][threadIdx.x] = 0.0f; + } + + // Load tile of B + int b_row = t * TILE_SIZE + threadIdx.y; + if (b_row < K && col < N) { + Bs[threadIdx.y][threadIdx.x] = B[b_row * N + col]; + } else { + Bs[threadIdx.y][threadIdx.x] = 0.0f; + } + + __syncthreads(); + + for (int k = 0; k < TILE_SIZE; k++) { + sum += As[threadIdx.y][k] * Bs[k][threadIdx.x]; + } + + __syncthreads(); + } + + if (row < M && col < N) { + C[row * N + col] = sum; + } +} + +__global__ void gemm_tiled_bf16( + const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C, + int M, int N, int K +) { + __shared__ float As[TILE_SIZE][TILE_SIZE]; + __shared__ float Bs[TILE_SIZE][TILE_SIZE]; + + int row = blockIdx.y * TILE_SIZE + threadIdx.y; + int col = blockIdx.x * TILE_SIZE + threadIdx.x; + + float sum = 0.0f; + + for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) { + int a_col = t * TILE_SIZE + threadIdx.x; + if (row < M && a_col < K) { + As[threadIdx.y][threadIdx.x] = __bfloat162float(A[row * K + a_col]); + } else { + As[threadIdx.y][threadIdx.x] = 0.0f; + } + + int b_row = t * TILE_SIZE + threadIdx.y; + if (b_row < K && col < N) { + Bs[threadIdx.y][threadIdx.x] = __bfloat162float(B[b_row * N + col]); + } else { + Bs[threadIdx.y][threadIdx.x] = 0.0f; + } + + __syncthreads(); + + for (int k = 0; k < TILE_SIZE; k++) { + sum += As[threadIdx.y][k] * Bs[k][threadIdx.x]; + } + + __syncthreads(); + } + + if (row < M && col < N) { + C[row * N + col] = __float2bfloat16(sum); + } +} + +extern "C" { + +void launch_gemm_tiled_f32( + const void* A, const void* B, void* C, + int M, int N, int K, void* stream +) { + dim3 block(TILE_SIZE, TILE_SIZE); + dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE); + gemm_tiled_f32<<>>( + (const float*)A, (const float*)B, (float*)C, M, N, K + ); +} + +void launch_gemm_tiled_bf16( + const void* A, const void* B, void* C, + int M, int N, int K, void* stream +) { + dim3 block(TILE_SIZE, TILE_SIZE); + dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE); + gemm_tiled_bf16<<>>( + (const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K + ); +} + +} // extern "C"