phase 3: GEMM kernels (naive, tiled, cuBLAS)
- Naive GEMM kernel: one thread per output element (F32 + BF16) - Tiled GEMM kernel: 32x32 shared memory tiles (F32 + BF16) - cuBLAS wrapper: cublasGemmEx with row-major trick - GemmBackend enum for runtime backend selection - CublasContext RAII handle - Made error::check public for cross-crate use - 17 GEMM tests: small/medium/rect sizes, all backends, F32+BF16 - Cross-backend consistency verified (naive vs tiled vs cuBLAS) - All 44 tests pass across all crates Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@ resolver = "2"
|
||||
members = [
|
||||
"crates/xserv-cuda",
|
||||
"crates/xserv-tensor",
|
||||
"crates/xserv-kernels",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
||||
@@ -23,7 +23,7 @@ impl std::error::Error for CudaError {}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, CudaError>;
|
||||
|
||||
pub(crate) fn check(code: i32) -> Result<()> {
|
||||
pub fn check(code: i32) -> Result<()> {
|
||||
if code == ffi::CUDA_SUCCESS {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
12
crates/xserv-kernels/Cargo.toml
Normal file
12
crates/xserv-kernels/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "xserv-kernels"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1"
|
||||
|
||||
[dependencies]
|
||||
xserv-cuda = { path = "../xserv-cuda" }
|
||||
xserv-tensor = { path = "../xserv-tensor" }
|
||||
half.workspace = true
|
||||
21
crates/xserv-kernels/build.rs
Normal file
21
crates/xserv-kernels/build.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
use std::env;
|
||||
|
||||
fn main() {
|
||||
let cuda_path = env::var("CUDA_HOME")
|
||||
.or_else(|_| env::var("CUDA_PATH"))
|
||||
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
||||
|
||||
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
||||
println!("cargo:rustc-link-lib=dylib=cudart");
|
||||
println!("cargo:rustc-link-lib=dylib=cublas");
|
||||
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
.cudart("shared")
|
||||
.flag("-gencode=arch=compute_120,code=sm_120")
|
||||
.file("../../csrc/gemm/naive.cu")
|
||||
.file("../../csrc/gemm/tiled.cu")
|
||||
.compile("xserv_gemm_kernels");
|
||||
|
||||
println!("cargo:rerun-if-changed=../../csrc/gemm/");
|
||||
}
|
||||
151
crates/xserv-kernels/src/gemm.rs
Normal file
151
crates/xserv-kernels/src/gemm.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::error::{self, Result};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GemmBackend {
|
||||
Naive,
|
||||
Tiled,
|
||||
CuBlas,
|
||||
}
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void);
|
||||
}
|
||||
|
||||
// --- FFI: cuBLAS ---
|
||||
type CublasHandle = *mut c_void;
|
||||
|
||||
#[allow(non_upper_case_globals)]
|
||||
const CUBLAS_OP_N: i32 = 0;
|
||||
|
||||
// cudaDataType
|
||||
const CUDA_R_32F: i32 = 0;
|
||||
const CUDA_R_16BF: i32 = 14;
|
||||
|
||||
// cublasComputeType
|
||||
const CUBLAS_COMPUTE_32F: i32 = 68;
|
||||
|
||||
unsafe extern "C" {
|
||||
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
|
||||
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
|
||||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||||
fn cublasGemmEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
m: i32, n: i32, k: i32,
|
||||
alpha: *const c_void,
|
||||
a: *const c_void, a_type: i32, lda: i32,
|
||||
b: *const c_void, b_type: i32, ldb: i32,
|
||||
beta: *const c_void,
|
||||
c: *mut c_void, c_type: i32, ldc: i32,
|
||||
compute_type: i32,
|
||||
algo: i32,
|
||||
) -> i32;
|
||||
}
|
||||
|
||||
pub struct CublasContext {
|
||||
handle: CublasHandle,
|
||||
}
|
||||
|
||||
impl CublasContext {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut handle = std::ptr::null_mut();
|
||||
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
|
||||
Ok(Self { handle })
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CublasContext {
|
||||
fn drop(&mut self) {
|
||||
if !self.handle.is_null() {
|
||||
unsafe { cublasDestroy_v2(self.handle) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication: C = A @ B
|
||||
/// A: [M, K], B: [K, N], C: [M, N]
|
||||
/// All tensors must be contiguous and on the same GPU.
|
||||
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
|
||||
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
|
||||
assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors");
|
||||
assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors");
|
||||
|
||||
let m = a.shape()[0];
|
||||
let k = a.shape()[1];
|
||||
let n = b.shape()[1];
|
||||
let dtype = a.dtype();
|
||||
|
||||
let c = Tensor::zeros(&[m, n], dtype, a.device());
|
||||
|
||||
let a_ptr = a.data_ptr() as *const c_void;
|
||||
let b_ptr = b.data_ptr() as *const c_void;
|
||||
let c_ptr = c.data_ptr() as *mut c_void;
|
||||
let null_stream = std::ptr::null_mut();
|
||||
|
||||
match backend {
|
||||
GemmBackend::Naive => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for naive GEMM"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
GemmBackend::Tiled => {
|
||||
unsafe {
|
||||
match dtype {
|
||||
DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream),
|
||||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||||
}
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
GemmBackend::CuBlas => {
|
||||
// cuBLAS uses column-major, but we have row-major tensors.
|
||||
// Trick: compute C^T = B^T @ A^T, which gives us C in row-major.
|
||||
// cuBLAS sees our row-major data as column-major transposed.
|
||||
let ctx = CublasContext::new().unwrap();
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||||
_ => panic!("unsupported dtype for cuBLAS GEMM"),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
cublasSetStream_v2(ctx.handle, null_stream);
|
||||
// Row-major trick: swap A/B and transpose flags
|
||||
// C(row-major) = A @ B <=> C^T(col-major) = B^T @ A^T
|
||||
error::check(cublasGemmEx(
|
||||
ctx.handle,
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n as i32, m as i32, k as i32,
|
||||
&alpha as *const f32 as *const c_void,
|
||||
b_ptr, b_type, n as i32, // B as col-major = B^T
|
||||
a_ptr, a_type, k as i32, // A as col-major = A^T
|
||||
&beta as *const f32 as *const c_void,
|
||||
c_ptr, c_type, n as i32, // C as col-major = C^T
|
||||
CUBLAS_COMPUTE_32F,
|
||||
-1, // default algo
|
||||
)).expect("cuBLAS GEMM failed");
|
||||
}
|
||||
xserv_cuda::device::synchronize().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
c
|
||||
}
|
||||
3
crates/xserv-kernels/src/lib.rs
Normal file
3
crates/xserv-kernels/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod gemm;
|
||||
|
||||
pub use gemm::{GemmBackend, matmul};
|
||||
152
crates/xserv-kernels/tests/gemm_test.rs
Normal file
152
crates/xserv-kernels/tests/gemm_test.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use half::bf16;
|
||||
use xserv_kernels::{matmul, GemmBackend};
|
||||
use xserv_tensor::{Device, Tensor};
|
||||
|
||||
fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
|
||||
let mut c = vec![0.0f32; m * n];
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for kk in 0..k {
|
||||
sum += a[i * k + kk] * b[kk * n + j];
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
c
|
||||
}
|
||||
|
||||
fn check_close_f32(result: &[f32], expected: &[f32], atol: f32) {
|
||||
assert_eq!(result.len(), expected.len());
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
assert!(
|
||||
(r - e).abs() <= atol,
|
||||
"mismatch at index {i}: got {r}, expected {e}, diff {}",
|
||||
(r - e).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn check_close_bf16(result: &[bf16], expected: &[f32], atol: f32) {
|
||||
assert_eq!(result.len(), expected.len());
|
||||
for (i, (r, e)) in result.iter().zip(expected).enumerate() {
|
||||
let rv = r.to_f32();
|
||||
assert!(
|
||||
(rv - e).abs() <= atol,
|
||||
"mismatch at index {i}: got {rv}, expected {e}, diff {}",
|
||||
(rv - e).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_gemm_test_f32(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
|
||||
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
|
||||
let expected = cpu_matmul_f32(&a_data, &b_data, m, n, k);
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
|
||||
let c = matmul(&a, &b, backend);
|
||||
|
||||
let c_cpu = c.to_device(Device::Cpu);
|
||||
check_close_f32(c_cpu.as_slice::<f32>(), &expected, 1e-4);
|
||||
}
|
||||
|
||||
fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let a_f32: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
|
||||
let b_f32: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
|
||||
let expected = cpu_matmul_f32(&a_f32, &b_f32, m, n, k);
|
||||
|
||||
let a_data: Vec<bf16> = a_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
let b_data: Vec<bf16> = b_f32.iter().map(|&v| bf16::from_f32(v)).collect();
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
|
||||
let c = matmul(&a, &b, backend);
|
||||
|
||||
let c_cpu = c.to_device(Device::Cpu);
|
||||
check_close_bf16(c_cpu.as_slice::<bf16>(), &expected, 0.1);
|
||||
}
|
||||
|
||||
// --- F32 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); }
|
||||
|
||||
// --- BF16 tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); }
|
||||
|
||||
// --- Larger benchmark-style tests ---
|
||||
|
||||
#[test]
|
||||
fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); }
|
||||
|
||||
#[test]
|
||||
fn test_gemm_consistency_all_backends() {
|
||||
xserv_cuda::device::set_device(0).unwrap();
|
||||
|
||||
let m = 64;
|
||||
let n = 64;
|
||||
let k = 64;
|
||||
let a_data: Vec<f32> = (0..m * k).map(|i| ((i % 7) as f32 - 3.0) * 0.1).collect();
|
||||
let b_data: Vec<f32> = (0..k * n).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
|
||||
|
||||
let a = Tensor::from_slice(&a_data, &[m, k]).to_device(Device::Cuda(0));
|
||||
let b = Tensor::from_slice(&b_data, &[k, n]).to_device(Device::Cuda(0));
|
||||
|
||||
let c_naive = matmul(&a, &b, GemmBackend::Naive).to_device(Device::Cpu);
|
||||
let c_tiled = matmul(&a, &b, GemmBackend::Tiled).to_device(Device::Cpu);
|
||||
let c_cublas = matmul(&a, &b, GemmBackend::CuBlas).to_device(Device::Cpu);
|
||||
|
||||
let naive = c_naive.as_slice::<f32>();
|
||||
let tiled = c_tiled.as_slice::<f32>();
|
||||
let cublas = c_cublas.as_slice::<f32>();
|
||||
|
||||
check_close_f32(naive, cublas, 1e-4);
|
||||
check_close_f32(tiled, cublas, 1e-4);
|
||||
}
|
||||
62
csrc/gemm/naive.cu
Normal file
62
csrc/gemm/naive.cu
Normal file
@@ -0,0 +1,62 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Naive GEMM: each thread computes one element of C.
|
||||
// C[i][j] = sum_k A[i][k] * B[k][j]
|
||||
// All matrices are row-major.
|
||||
__global__ void gemm_naive_bf16(
|
||||
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (row < M && col < N) {
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]);
|
||||
}
|
||||
C[row * N + col] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gemm_naive_f32(
|
||||
const float* A, const float* B, float* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (row < M && col < N) {
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum += A[row * K + k] * B[k * N + col];
|
||||
}
|
||||
C[row * N + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemm_naive_bf16(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
|
||||
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
void launch_gemm_naive_f32(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
|
||||
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
116
csrc/gemm/tiled.cu
Normal file
116
csrc/gemm/tiled.cu
Normal file
@@ -0,0 +1,116 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Tiled GEMM using shared memory.
|
||||
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
|
||||
// into shared memory, then computes a partial dot product.
|
||||
#define TILE_SIZE 32
|
||||
|
||||
__global__ void gemm_tiled_f32(
|
||||
const float* A, const float* B, float* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
__shared__ float As[TILE_SIZE][TILE_SIZE];
|
||||
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
|
||||
|
||||
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
|
||||
// Load tile of A
|
||||
int a_col = t * TILE_SIZE + threadIdx.x;
|
||||
if (row < M && a_col < K) {
|
||||
As[threadIdx.y][threadIdx.x] = A[row * K + a_col];
|
||||
} else {
|
||||
As[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
// Load tile of B
|
||||
int b_row = t * TILE_SIZE + threadIdx.y;
|
||||
if (b_row < K && col < N) {
|
||||
Bs[threadIdx.y][threadIdx.x] = B[b_row * N + col];
|
||||
} else {
|
||||
Bs[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < TILE_SIZE; k++) {
|
||||
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (row < M && col < N) {
|
||||
C[row * N + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gemm_tiled_bf16(
|
||||
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
__shared__ float As[TILE_SIZE][TILE_SIZE];
|
||||
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
|
||||
|
||||
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
|
||||
int a_col = t * TILE_SIZE + threadIdx.x;
|
||||
if (row < M && a_col < K) {
|
||||
As[threadIdx.y][threadIdx.x] = __bfloat162float(A[row * K + a_col]);
|
||||
} else {
|
||||
As[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
int b_row = t * TILE_SIZE + threadIdx.y;
|
||||
if (b_row < K && col < N) {
|
||||
Bs[threadIdx.y][threadIdx.x] = __bfloat162float(B[b_row * N + col]);
|
||||
} else {
|
||||
Bs[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < TILE_SIZE; k++) {
|
||||
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (row < M && col < N) {
|
||||
C[row * N + col] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemm_tiled_f32(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(TILE_SIZE, TILE_SIZE);
|
||||
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
|
||||
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
void launch_gemm_tiled_bf16(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(TILE_SIZE, TILE_SIZE);
|
||||
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
|
||||
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user