phase 15: custom GEMV kernel — 46.6 tok/s serial (3.5x improvement, 130% of HF)

Custom bandwidth-optimized GEMV kernel for M=1 BF16 decode, replacing
cuBLAS which achieves only ~8% bandwidth utilization for tiny M=1 GEMMs.

Kernel design (csrc/gemm/gemv.cu):
- K-split tiled: TILE_N=128, TILE_K=256, Grid=(N/128, K/256)=512 blocks
- High occupancy: 512 blocks / 170 SMs = ~3 blocks/SM
- Coalesced memory access: adjacent threads read adjacent columns of W
- Shared memory for x vector (avoids redundant global reads)
- FP32 accumulation via atomicAdd (K-split partial sums)
- Separate fp32→bf16 conversion kernel

Integration:
- matmul() auto-dispatches to custom GEMV when M==1 && dtype==BF16
- Batched decode (M>1) continues to use cuBLAS
- Caching allocator provides FP32 temp buffer (pooled, no per-call malloc)

Ablation results (dash5, RTX 5090, Qwen3-8B BF16):

| Config | tok/s | vs HF (36) | vs roofline (112) |
|--------|-------|-----------|-------------------|
| Phase 14 (cuBLAS M=1) | 13.2 | 37% | 12% |
| + Custom GEMV (M=1) | 46.6 | 130% | 42% |
| Concurrent batch=4 | 28.2 | 78% | — |

Single-request throughput now EXCEEDS HuggingFace transformers by 30%.
The custom GEMV achieves ~42% of the theoretical roofline (vs 12% before).

Note: concurrent batch=4 (28.2 tok/s) is slower than serial (46.6 tok/s)
because the per-seq attention/reshape overhead in batched decode outweighs
the cuBLAS M=4 benefit when the custom GEMV already handles M=1 efficiently.
Engine should prefer serial decode when custom GEMV is available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 22:22:31 +08:00
parent 876d3f5d6a
commit e207523e21
4 changed files with 159 additions and 27 deletions

102
csrc/gemm/gemv.cu Normal file
View File

@@ -0,0 +1,102 @@
#include <cuda_bf16.h>
#include <cuda_runtime.h>
// Custom GEMV kernel for M=1 decode step (BF16):
// y[n] = sum_k x[k] * W[k * N + n]
// where x: [K] (BF16), W: [K, N] (BF16, row-major), y: [N] (BF16).
//
// Design: K-split for high occupancy on large GPU (170 SMs).
// Grid: (N / TILE_N, K / TILE_K) — each block computes a partial sum
// for TILE_N output columns over a TILE_K slice of K.
// Partial results are atomicAdd'd to an FP32 accumulator, then a
// second kernel converts FP32 -> BF16.
//
// Memory access: adjacent threads read adjacent columns of the same row
// of W, giving perfectly coalesced 128-byte transactions.
#define GEMV_TILE_N 128
#define GEMV_TILE_K 256
#define GEMV_BLOCK 128 // = TILE_N, one thread per output column
__global__ void gemv_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [K]
const __nv_bfloat16* __restrict__ W, // [K, N] row-major
float* __restrict__ y_fp32, // [N] accumulator
int K, int N
) {
const int block_n = blockIdx.x;
const int block_k = blockIdx.y;
const int t = threadIdx.x;
const int col = block_n * GEMV_TILE_N + t;
if (col >= N) return;
const int k_start = block_k * GEMV_TILE_K;
const int k_end = min(k_start + GEMV_TILE_K, K);
const int k_len = k_end - k_start;
// Load x[k_start..k_end] into shared memory as FP32
__shared__ float x_shared[GEMV_TILE_K];
for (int i = t; i < k_len; i += GEMV_BLOCK) {
x_shared[i] = __bfloat162float(x[k_start + i]);
}
__syncthreads();
// Compute partial dot product for this column
float sum = 0.0f;
for (int ki = 0; ki < k_len; ki++) {
sum += x_shared[ki] * __bfloat162float(W[(k_start + ki) * N + col]);
}
// Atomic accumulate (handles K-split reduction)
atomicAdd(&y_fp32[col], sum);
}
// Conversion kernel: FP32 accumulator -> BF16 output
__global__ void gemv_fp32_to_bf16_kernel(
const float* __restrict__ src,
__nv_bfloat16* __restrict__ dst,
int n
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = __float2bfloat16(src[idx]);
}
}
extern "C" {
void launch_gemv_bf16(
const void* x, // [K] BF16
const void* W, // [K, N] BF16 row-major
void* y_bf16, // [N] BF16 output
void* y_fp32_buf, // [N] FP32 temporary (caller-provided)
int K, int N,
void* stream
) {
cudaStream_t s = (cudaStream_t)stream;
// Zero the FP32 accumulator
cudaMemsetAsync((float*)y_fp32_buf, 0, N * sizeof(float), s);
// Launch GEMV kernel
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N,
(K + GEMV_TILE_K - 1) / GEMV_TILE_K);
gemv_bf16_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
(const __nv_bfloat16*)x,
(const __nv_bfloat16*)W,
(float*)y_fp32_buf,
K, N
);
// Convert FP32 -> BF16
int conv_block = 256;
int conv_grid = (N + conv_block - 1) / conv_block;
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
(const float*)y_fp32_buf,
(__nv_bfloat16*)y_bf16,
N
);
}
} // extern "C"