phase 15: custom GEMV kernel — 46.6 tok/s serial (3.5x improvement, 130% of HF)
Custom bandwidth-optimized GEMV kernel for M=1 BF16 decode, replacing cuBLAS which achieves only ~8% bandwidth utilization for tiny M=1 GEMMs. Kernel design (csrc/gemm/gemv.cu): - K-split tiled: TILE_N=128, TILE_K=256, Grid=(N/128, K/256)=512 blocks - High occupancy: 512 blocks / 170 SMs = ~3 blocks/SM - Coalesced memory access: adjacent threads read adjacent columns of W - Shared memory for x vector (avoids redundant global reads) - FP32 accumulation via atomicAdd (K-split partial sums) - Separate fp32→bf16 conversion kernel Integration: - matmul() auto-dispatches to custom GEMV when M==1 && dtype==BF16 - Batched decode (M>1) continues to use cuBLAS - Caching allocator provides FP32 temp buffer (pooled, no per-call malloc) Ablation results (dash5, RTX 5090, Qwen3-8B BF16): | Config | tok/s | vs HF (36) | vs roofline (112) | |--------|-------|-----------|-------------------| | Phase 14 (cuBLAS M=1) | 13.2 | 37% | 12% | | + Custom GEMV (M=1) | 46.6 | 130% | 42% | | Concurrent batch=4 | 28.2 | 78% | — | Single-request throughput now EXCEEDS HuggingFace transformers by 30%. The custom GEMV achieves ~42% of the theoretical roofline (vs 12% before). Note: concurrent batch=4 (28.2 tok/s) is slower than serial (46.6 tok/s) because the per-seq attention/reshape overhead in batched decode outweighs the cuBLAS M=4 benefit when the custom GEMV already handles M=1 efficiently. Engine should prefer serial decode when custom GEMV is available. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
102
csrc/gemm/gemv.cu
Normal file
102
csrc/gemm/gemv.cu
Normal file
@@ -0,0 +1,102 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Custom GEMV kernel for M=1 decode step (BF16):
|
||||
// y[n] = sum_k x[k] * W[k * N + n]
|
||||
// where x: [K] (BF16), W: [K, N] (BF16, row-major), y: [N] (BF16).
|
||||
//
|
||||
// Design: K-split for high occupancy on large GPU (170 SMs).
|
||||
// Grid: (N / TILE_N, K / TILE_K) — each block computes a partial sum
|
||||
// for TILE_N output columns over a TILE_K slice of K.
|
||||
// Partial results are atomicAdd'd to an FP32 accumulator, then a
|
||||
// second kernel converts FP32 -> BF16.
|
||||
//
|
||||
// Memory access: adjacent threads read adjacent columns of the same row
|
||||
// of W, giving perfectly coalesced 128-byte transactions.
|
||||
|
||||
#define GEMV_TILE_N 128
|
||||
#define GEMV_TILE_K 256
|
||||
#define GEMV_BLOCK 128 // = TILE_N, one thread per output column
|
||||
|
||||
__global__ void gemv_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ x, // [K]
|
||||
const __nv_bfloat16* __restrict__ W, // [K, N] row-major
|
||||
float* __restrict__ y_fp32, // [N] accumulator
|
||||
int K, int N
|
||||
) {
|
||||
const int block_n = blockIdx.x;
|
||||
const int block_k = blockIdx.y;
|
||||
const int t = threadIdx.x;
|
||||
const int col = block_n * GEMV_TILE_N + t;
|
||||
|
||||
if (col >= N) return;
|
||||
|
||||
const int k_start = block_k * GEMV_TILE_K;
|
||||
const int k_end = min(k_start + GEMV_TILE_K, K);
|
||||
const int k_len = k_end - k_start;
|
||||
|
||||
// Load x[k_start..k_end] into shared memory as FP32
|
||||
__shared__ float x_shared[GEMV_TILE_K];
|
||||
for (int i = t; i < k_len; i += GEMV_BLOCK) {
|
||||
x_shared[i] = __bfloat162float(x[k_start + i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute partial dot product for this column
|
||||
float sum = 0.0f;
|
||||
for (int ki = 0; ki < k_len; ki++) {
|
||||
sum += x_shared[ki] * __bfloat162float(W[(k_start + ki) * N + col]);
|
||||
}
|
||||
|
||||
// Atomic accumulate (handles K-split reduction)
|
||||
atomicAdd(&y_fp32[col], sum);
|
||||
}
|
||||
|
||||
// Conversion kernel: FP32 accumulator -> BF16 output
|
||||
__global__ void gemv_fp32_to_bf16_kernel(
|
||||
const float* __restrict__ src,
|
||||
__nv_bfloat16* __restrict__ dst,
|
||||
int n
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
dst[idx] = __float2bfloat16(src[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gemv_bf16(
|
||||
const void* x, // [K] BF16
|
||||
const void* W, // [K, N] BF16 row-major
|
||||
void* y_bf16, // [N] BF16 output
|
||||
void* y_fp32_buf, // [N] FP32 temporary (caller-provided)
|
||||
int K, int N,
|
||||
void* stream
|
||||
) {
|
||||
cudaStream_t s = (cudaStream_t)stream;
|
||||
|
||||
// Zero the FP32 accumulator
|
||||
cudaMemsetAsync((float*)y_fp32_buf, 0, N * sizeof(float), s);
|
||||
|
||||
// Launch GEMV kernel
|
||||
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N,
|
||||
(K + GEMV_TILE_K - 1) / GEMV_TILE_K);
|
||||
gemv_bf16_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||
(const __nv_bfloat16*)x,
|
||||
(const __nv_bfloat16*)W,
|
||||
(float*)y_fp32_buf,
|
||||
K, N
|
||||
);
|
||||
|
||||
// Convert FP32 -> BF16
|
||||
int conv_block = 256;
|
||||
int conv_grid = (N + conv_block - 1) / conv_block;
|
||||
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
|
||||
(const float*)y_fp32_buf,
|
||||
(__nv_bfloat16*)y_bf16,
|
||||
N
|
||||
);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user