phase 3: GEMM kernels (naive, tiled, cuBLAS)
- Naive GEMM kernel: one thread per output element (F32 + BF16) - Tiled GEMM kernel: 32x32 shared memory tiles (F32 + BF16) - cuBLAS wrapper: cublasGemmEx with row-major trick - GemmBackend enum for runtime backend selection - CublasContext RAII handle - Made error::check public for cross-crate use - 17 GEMM tests: small/medium/rect sizes, all backends, F32+BF16 - Cross-backend consistency verified (naive vs tiled vs cuBLAS) - All 44 tests pass across all crates Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
62
csrc/gemm/naive.cu
Normal file
62
csrc/gemm/naive.cu
Normal file
@@ -0,0 +1,62 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Naive GEMM: each thread computes one element of C.
|
||||
// C[i][j] = sum_k A[i][k] * B[k][j]
|
||||
// All matrices are row-major.
|
||||
__global__ void gemm_naive_bf16(
|
||||
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (row < M && col < N) {
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum += __bfloat162float(A[row * K + k]) * __bfloat162float(B[k * N + col]);
|
||||
}
|
||||
C[row * N + col] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gemm_naive_f32(
|
||||
const float* A, const float* B, float* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (row < M && col < N) {
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < K; k++) {
|
||||
sum += A[row * K + k] * B[k * N + col];
|
||||
}
|
||||
C[row * N + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemm_naive_bf16(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
|
||||
gemm_naive_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
void launch_gemm_naive_f32(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
|
||||
gemm_naive_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
116
csrc/gemm/tiled.cu
Normal file
116
csrc/gemm/tiled.cu
Normal file
@@ -0,0 +1,116 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Tiled GEMM using shared memory.
|
||||
// Each thread block loads TILE_SIZE x TILE_SIZE tiles of A and B
|
||||
// into shared memory, then computes a partial dot product.
|
||||
#define TILE_SIZE 32
|
||||
|
||||
__global__ void gemm_tiled_f32(
|
||||
const float* A, const float* B, float* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
__shared__ float As[TILE_SIZE][TILE_SIZE];
|
||||
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
|
||||
|
||||
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
|
||||
// Load tile of A
|
||||
int a_col = t * TILE_SIZE + threadIdx.x;
|
||||
if (row < M && a_col < K) {
|
||||
As[threadIdx.y][threadIdx.x] = A[row * K + a_col];
|
||||
} else {
|
||||
As[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
// Load tile of B
|
||||
int b_row = t * TILE_SIZE + threadIdx.y;
|
||||
if (b_row < K && col < N) {
|
||||
Bs[threadIdx.y][threadIdx.x] = B[b_row * N + col];
|
||||
} else {
|
||||
Bs[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < TILE_SIZE; k++) {
|
||||
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (row < M && col < N) {
|
||||
C[row * N + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gemm_tiled_bf16(
|
||||
const __nv_bfloat16* A, const __nv_bfloat16* B, __nv_bfloat16* C,
|
||||
int M, int N, int K
|
||||
) {
|
||||
__shared__ float As[TILE_SIZE][TILE_SIZE];
|
||||
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
|
||||
|
||||
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
|
||||
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
|
||||
int a_col = t * TILE_SIZE + threadIdx.x;
|
||||
if (row < M && a_col < K) {
|
||||
As[threadIdx.y][threadIdx.x] = __bfloat162float(A[row * K + a_col]);
|
||||
} else {
|
||||
As[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
int b_row = t * TILE_SIZE + threadIdx.y;
|
||||
if (b_row < K && col < N) {
|
||||
Bs[threadIdx.y][threadIdx.x] = __bfloat162float(B[b_row * N + col]);
|
||||
} else {
|
||||
Bs[threadIdx.y][threadIdx.x] = 0.0f;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 0; k < TILE_SIZE; k++) {
|
||||
sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (row < M && col < N) {
|
||||
C[row * N + col] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemm_tiled_f32(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(TILE_SIZE, TILE_SIZE);
|
||||
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
|
||||
gemm_tiled_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)A, (const float*)B, (float*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
void launch_gemm_tiled_bf16(
|
||||
const void* A, const void* B, void* C,
|
||||
int M, int N, int K, void* stream
|
||||
) {
|
||||
dim3 block(TILE_SIZE, TILE_SIZE);
|
||||
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
|
||||
gemm_tiled_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)A, (const __nv_bfloat16*)B, (__nv_bfloat16*)C, M, N, K
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user