Files
xserv/csrc/gemm/naive.cu
Gahow Wang d77f921a12 phase 3: GEMM kernels (naive, tiled, cuBLAS)
- Naive GEMM kernel: one thread per output element (F32 + BF16)
- Tiled GEMM kernel: 32x32 shared memory tiles (F32 + BF16)
- cuBLAS wrapper: cublasGemmEx with row-major trick
- GemmBackend enum for runtime backend selection
- CublasContext RAII handle
- Made error::check public for cross-crate use
- 17 GEMM tests: small/medium/rect sizes, all backends, F32+BF16
- Cross-backend consistency verified (naive vs tiled vs cuBLAS)
- All 44 tests pass across all crates

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 19:48:05 +08:00

63 lines
1.7 KiB
Plaintext

#include <cstdint>

#include <cuda_bf16.h>
// Naive GEMM: each thread computes one or more elements of C.
// C[i][j] = sum_k A[i][k] * B[k][j]
// All matrices are row-major and densely packed:
//   A is M x K, B is K x N, C is M x N.
//
// Launch layout: 2D grid of 2D blocks. The x dimension walks columns of C
// and the y dimension walks rows. Both dimensions use grid-stride loops, so
// the result is complete for any grid size; a grid that covers C assigns
// exactly one element per thread. No shared memory is used.
//
// Element offsets are computed in 64-bit so that matrices with more than
// 2^31 elements index correctly. Products are accumulated in fp32, and the
// sum is rounded to bf16 once per output element.
__global__ void gemm_naive_bf16(
    const __nv_bfloat16* __restrict__ A,
    const __nv_bfloat16* __restrict__ B,
    __nv_bfloat16* __restrict__ C,
    int M, int N, int K
) {
    const int64_t row_stride = static_cast<int64_t>(gridDim.y) * blockDim.y;
    const int64_t col_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t row = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
         row < M; row += row_stride) {
        const __nv_bfloat16* a_row = A + row * K;
        for (int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
             col < N; col += col_stride) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                // Threads adjacent in x read adjacent elements of B's row k.
                sum += __bfloat162float(a_row[k]) *
                       __bfloat162float(B[static_cast<int64_t>(k) * N + col]);
            }
            C[row * N + col] = __float2bfloat16(sum);
        }
    }
}
// Naive fp32 GEMM: C = A @ B with A (M x K), B (K x N), C (M x N), all
// row-major and densely packed. Each thread computes one or more elements
// of C.
//
// Launch layout: 2D grid of 2D blocks. The x dimension walks columns of C
// and the y dimension walks rows. Both use grid-stride loops, so any grid
// size produces the complete result. No shared memory is used.
//
// Offsets are computed in 64-bit so that matrices with more than 2^31
// elements do not overflow the index arithmetic.
__global__ void gemm_naive_f32(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K
) {
    const int64_t row_stride = static_cast<int64_t>(gridDim.y) * blockDim.y;
    const int64_t col_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t row = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
         row < M; row += row_stride) {
        const float* a_row = A + row * K;
        for (int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
             col < N; col += col_stride) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                // Threads adjacent in x read adjacent elements of B's row k.
                sum += a_row[k] * B[static_cast<int64_t>(k) * N + col];
            }
            C[row * N + col] = sum;
        }
    }
}
namespace {

// Hardware limit on the number of blocks in a grid's y dimension.
constexpr int64_t kMaxGridDimY = 65535;

// Enqueue `kernel` on `stream` to compute C = A @ B, where A is M x K,
// B is K x N and C is M x N. All three are row-major, and the thread
// blocks are 16x16.
//
// Rows of C are split into chunks of at most kMaxGridDimY * 16 rows, so the
// grid's y dimension never exceeds the hardware limit. Each chunk is a
// separate launch, with A and C advanced to the chunk's first row.
//
// Returns without launching when M <= 0 or N <= 0, because a zero-sized
// grid is an invalid launch configuration. Launches are asynchronous, and
// launch errors are not checked here; the caller is expected to query them.
template <typename T>
void launch_gemm_naive_chunked(
    void (*kernel)(const T*, const T*, T*, int, int, int),
    const void* A, const void* B, void* C,
    int M, int N, int K, cudaStream_t stream
) {
    if (M <= 0 || N <= 0) {
        return;
    }

    const dim3 block(16, 16);
    const unsigned grid_x = (static_cast<unsigned>(N) + block.x - 1) / block.x;
    const int64_t rows_per_launch = kMaxGridDimY * block.y;

    const T* a = static_cast<const T*>(A);
    const T* b = static_cast<const T*>(B);
    T* c = static_cast<T*>(C);

    for (int64_t row0 = 0; row0 < M; row0 += rows_per_launch) {
        const int64_t remaining = M - row0;
        const int rows = static_cast<int>(
            remaining < rows_per_launch ? remaining : rows_per_launch);
        const dim3 grid(grid_x, (static_cast<unsigned>(rows) + block.y - 1) / block.y);
        kernel<<<grid, block, 0, stream>>>(a + row0 * K, b, c + row0 * N, rows, N, K);
    }
}

}  // namespace

extern "C" {

// C = A @ B for row-major bf16 matrices (A: M x K, B: K x N, C: M x N).
// `stream` is a cudaStream_t passed as void*; a null handle selects the
// default stream. No-op when M <= 0 or N <= 0.
void launch_gemm_naive_bf16(
    const void* A, const void* B, void* C,
    int M, int N, int K, void* stream
) {
    launch_gemm_naive_chunked(gemm_naive_bf16, A, B, C, M, N, K,
                              static_cast<cudaStream_t>(stream));
}

// C = A @ B for row-major fp32 matrices (A: M x K, B: K x N, C: M x N).
// `stream` is a cudaStream_t passed as void*; a null handle selects the
// default stream. No-op when M <= 0 or N <= 0.
void launch_gemm_naive_f32(
    const void* A, const void* B, void* C,
    int M, int N, int K, void* stream
) {
    launch_gemm_naive_chunked(gemm_naive_f32, A, B, C, M, N, K,
                              static_cast<cudaStream_t>(stream));
}

}  // extern "C"