Add launch_gemv_bf16_batched: runs M independent m=1 GEMVs in a single 3D grid launch (z = batch row) and produces output numerically identical to M sequential launch_gemv_bf16 calls — same K-block partial accumulation, same fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens: matched=true, verify_decode_mismatches=0. Exposed as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in xserv-kernels. It replaces the old matmul_rows_gemv helper in qwen3 forward_verify_paged_decode_attention: the per-row loop over matmul_2d + concat_rows becomes a single matmul_batched_gemv call that allocates the partials buffer in one shot and launches 2 kernels instead of 2*M. Current speedup_e2e is 0.47× (same ballpark as Phase 23's 0.44×); the batched launch saves ~3 ms of overhead, which is small relative to the total 28 ms spec cost. The path forward (per docs/24 §4) is a higher acceptance rate or a cheaper draft, not further kernel optimization.
197 lines
6.1 KiB
Plaintext
197 lines
6.1 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <cuda_runtime.h>
|
|
#include "../common.cuh"
|
|
|
|
// K-split GEMV for M=1 BF16 decode.
//
// y[n] = sum_k x[k] * W[k * N + n]
//
// Stage 1 launches a (ceil(N / TILE_N), ceil(K / TILE_K)) grid; each block
// writes FP32 partial sums for one column tile over one K slice. Stage 2
// reduces those partials over the K slices in a fixed order. The previous
// implementation used atomicAdd into y_fp32[col]; because floating-point
// addition is not associative, that made BF16 greedy decode depend on
// inter-block scheduling whenever the top logits were close.
|
|
|
|
// Output columns covered by one partial-sum block (one thread per column).
#define GEMV_TILE_N 128
// Length of the K slice each partial-sum block accumulates; also the number
// of floats in that block's static shared-memory copy of x.
#define GEMV_TILE_K 256
// Threads per partial-sum block.
#define GEMV_BLOCK 128

// The partial kernels map thread t of column tile b to column
// b * GEMV_TILE_N + t, and stride the shared-memory load of x by GEMV_BLOCK.
// The block therefore needs exactly one thread per tile column: fewer threads
// would leave columns unwritten, more would make neighbouring tiles overlap.
static_assert(GEMV_BLOCK == GEMV_TILE_N, "GEMV_BLOCK must equal GEMV_TILE_N");
static_assert(GEMV_TILE_N > 0 && GEMV_TILE_K > 0, "GEMV tile sizes must be positive");
|
|
|
|
// Stage 1 of the M=1 K-split GEMV: per-(column tile, K slice) partial sums.
//
// Launch layout (as set up by launch_gemv_bf16):
//   gridDim.x  = ceil(N / GEMV_TILE_N)   column tiles
//   gridDim.y  = ceil(K / GEMV_TILE_K)   K slices
//   blockDim.x = GEMV_BLOCK              one thread per column of the tile
// Static shared memory: GEMV_TILE_K floats (this block's slice of x).
//
// Thread t of block (bx, by) handles column c = bx * GEMV_TILE_N + t and, if
// c < N, stores
//   partials[by * N + c] = sum over k in slice by of x[k] * W[k * N + c]
// accumulated in ascending k. partials must hold gridDim.y * N floats.
__global__ void gemv_bf16_partial_kernel(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ W,
    float* __restrict__ partials,
    int K, int N
) {
    const int lane = threadIdx.x;
    const int slice = blockIdx.y;
    const int out_col = blockIdx.x * GEMV_TILE_N + lane;

    const int slice_begin = slice * GEMV_TILE_K;
    const int slice_len = min(slice_begin + GEMV_TILE_K, K) - slice_begin;

    // Stage this block's slice of x into shared memory as FP32. Every thread
    // takes part, including threads whose column is >= N, so the whole slice
    // gets populated and every thread arrives at the barrier. Returning before
    // the barrier would leave part of xs uninitialized and make
    // __syncthreads divergent (UB); that bug produced intermittent garbage
    // whenever N % GEMV_TILE_N != 0 (e.g. gpt-oss decode o_proj, N=2880).
    __shared__ float xs[GEMV_TILE_K];
    int i = lane;
    while (i < slice_len) {
        xs[i] = __bfloat162float(x[slice_begin + i]);
        i += GEMV_BLOCK;
    }
    __syncthreads();

    // Only after the barrier may out-of-range columns drop out.
    if (out_col < N) {
        // Walk down column out_col of W, one row (stride N) per step.
        const __nv_bfloat16* w = W + (long long)slice_begin * N + out_col;
        float acc = 0.0f;
        for (int j = 0; j < slice_len; ++j, w += N) {
            acc += xs[j] * __bfloat162float(*w);
        }
        partials[(long long)slice * N + out_col] = acc;
    }
}
|
|
|
|
// Stage 2 of the M=1 K-split GEMV: fixed-order reduction over K slices.
//
// Expects a 1-D launch with at least n threads in total; surplus threads do
// nothing. For each column c < n:
//   dst[c] = bf16(partials[0 * n + c] + partials[1 * n + c] + ...)
// summed in FP32 in ascending slice order, so the result does not depend on
// how blocks are scheduled.
__global__ void gemv_reduce_to_bf16_kernel(
    const float* __restrict__ partials,
    __nv_bfloat16* __restrict__ dst,
    int n,
    int num_k_blocks
) {
    const int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= n) return;

    // Step through the slices of column c; consecutive slices are n apart.
    const float* p = partials + c;
    float total = 0.0f;
    for (int kb = 0; kb < num_k_blocks; ++kb, p += n) {
        total += *p;
    }
    dst[c] = __float2bfloat16(total);
}
|
|
|
|
// Batched stage 1: M rows of x against one shared W, with gridDim.z = row.
//
// Each z-slice runs exactly the same per-column, ascending-k accumulation as
// gemv_bf16_partial_kernel on its own row of x. Row m is therefore
// numerically identical to a standalone launch_gemv_bf16 call on x[m, :].
//
// Launch layout (as set up by launch_gemv_bf16_batched):
//   gridDim  = (ceil(N / GEMV_TILE_N), ceil(K / GEMV_TILE_K), rows)
//   blockDim = GEMV_BLOCK
// Static shared memory: GEMV_TILE_K floats.
// partials must hold rows * ceil(K / GEMV_TILE_K) * N floats, laid out as
// [row][k slice][column].
__global__ void gemv_bf16_batched_partial_kernel(
    const __nv_bfloat16* __restrict__ x,   // [M, K]
    const __nv_bfloat16* __restrict__ W,   // [K, N]
    float* __restrict__ partials,          // [M, num_k_blocks, N]
    int K, int N
) {
    const int lane = threadIdx.x;
    const int slice = blockIdx.y;
    const int batch_row = blockIdx.z;
    const int out_col = blockIdx.x * GEMV_TILE_N + lane;

    const int slice_begin = slice * GEMV_TILE_K;
    const int slice_len = min(slice_begin + GEMV_TILE_K, K) - slice_begin;

    // All threads help stage this row's K slice, then meet at the barrier;
    // the column bounds check must come afterwards (see the non-batched
    // kernel for the bug this avoids).
    const __nv_bfloat16* x_slice = x + (long long)batch_row * K + slice_begin;
    __shared__ float xs[GEMV_TILE_K];
    for (int i = lane; i < slice_len; i += GEMV_BLOCK) {
        xs[i] = __bfloat162float(x_slice[i]);
    }
    __syncthreads();

    if (out_col >= N) return;

    const __nv_bfloat16* w = W + (long long)slice_begin * N + out_col;
    float acc = 0.0f;
    for (int j = 0; j < slice_len; ++j, w += N) {
        acc += xs[j] * __bfloat162float(*w);
    }

    // Same slice count the host uses to size the grid and the scratch buffer.
    const int slices_per_row = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
    const long long dst_idx =
        ((long long)batch_row * slices_per_row + slice) * N + out_col;
    partials[dst_idx] = acc;
}
|
|
|
|
// Batched stage 2: per-row fixed-order reduction over K slices.
//
// Launch layout: gridDim.x * blockDim.x >= n threads across columns, and
// gridDim.y = number of rows. Column threads with c >= n do nothing.
// For row r and column c:
//   dst[r * n + c] = bf16(sum over kb of partials[(r * num_k_blocks + kb) * n + c])
// summed in FP32 in ascending kb order.
__global__ void gemv_batched_reduce_to_bf16_kernel(
    const float* __restrict__ partials,   // [M, num_k_blocks, N]
    __nv_bfloat16* __restrict__ dst,      // [M, N]
    int n,
    int num_k_blocks
) {
    const int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c < n) {
        const long long r = blockIdx.y;
        const float* p = partials + r * num_k_blocks * n + c;
        float total = 0.0f;
        // Count down, but p still walks the slices in ascending order.
        for (int remaining = num_k_blocks; remaining > 0; --remaining) {
            total += *p;
            p += n;
        }
        dst[r * n + c] = __float2bfloat16(total);
    }
}
|
|
|
|
extern "C" {
|
|
|
|
// M=1 BF16 GEMV: y = x @ W, computed with a deterministic K-split.
//
// x          : K BF16 values.
// W          : [K, N] BF16, row-major (W[k * N + n]).
// y_bf16     : N BF16 outputs.
// y_fp32_buf : FP32 scratch of at least ceil(K / GEMV_TILE_K) * N floats.
//              Despite the name, this is NOT an N-float buffer; it holds one
//              row of partial sums per K slice.
// stream     : cudaStream_t passed as void*; both kernels go on this stream.
//
// Degenerate shapes are handled without launching a zero-sized grid, which
// would be an invalid launch configuration. N <= 0 writes nothing. K <= 0
// skips the partial pass, and the reduction then writes exact zeros (the
// empty sum).
void launch_gemv_bf16(
    const void* x,
    const void* W,
    void* y_bf16,
    void* y_fp32_buf,
    int K, int N,
    void* stream
) {
    if (N <= 0) return;  // empty output: nothing to compute

    cudaStream_t s = (cudaStream_t)stream;

    const int num_k_blocks = K > 0 ? (K + GEMV_TILE_K - 1) / GEMV_TILE_K : 0;

    if (num_k_blocks > 0) {
        dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
        gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
            (const __nv_bfloat16*)x,
            (const __nv_bfloat16*)W,
            (float*)y_fp32_buf,
            K, N
        );
        CUDA_CHECK_LAST_ERROR();
    }

    // Fixed-order FP32 reduction over K blocks, then BF16 conversion.
    // With num_k_blocks == 0 this writes 0 to every output column.
    const int conv_block = 256;
    const int conv_grid = (N + conv_block - 1) / conv_block;
    gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
        (const float*)y_fp32_buf,
        (__nv_bfloat16*)y_bf16,
        N,
        num_k_blocks
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Batched BF16 GEMV: y[m, :] = x[m, :] @ W for every m in [0, M).
//
// x          : [M, K] BF16, row-major (row stride K).
// W          : [K, N] BF16, row-major, shared by all rows.
// y_bf16     : [M, N] BF16 output.
// y_fp32_buf : FP32 scratch of at least M * ceil(K / GEMV_TILE_K) * N floats.
// stream     : cudaStream_t passed as void*; all kernels go on this stream.
//
// Every row goes through the same partial + fixed-order reduction as
// launch_gemv_bf16, so each output row matches the corresponding
// single-row call.
//
// Degenerate shapes never launch a zero-sized grid:
//   - M <= 0 or N <= 0 writes nothing.
//   - K <= 0 writes zeros.
//
// Rows are processed in chunks of at most 65535, the hardware limit on
// gridDim.y and gridDim.z. For any realistic M that is a single chunk,
// i.e. two kernel launches.
void launch_gemv_bf16_batched(
    const void* x,          // [M, K] BF16
    const void* W,          // [K, N] BF16
    void* y_bf16,           // [M, N] BF16
    void* y_fp32_buf,       // [M * num_k_blocks * N] FP32
    int M, int K, int N,
    void* stream
) {
    if (M <= 0 || N <= 0) return;  // empty output: nothing to compute

    cudaStream_t s = (cudaStream_t)stream;
    const __nv_bfloat16* x_rows = (const __nv_bfloat16*)x;
    const __nv_bfloat16* w = (const __nv_bfloat16*)W;
    __nv_bfloat16* y_rows = (__nv_bfloat16*)y_bf16;
    float* partials = (float*)y_fp32_buf;

    // K <= 0: no K slices; the reduction below then writes exact zeros.
    const int num_k_blocks = K > 0 ? (K + GEMV_TILE_K - 1) / GEMV_TILE_K : 0;
    const int n_tiles = (N + GEMV_TILE_N - 1) / GEMV_TILE_N;
    const int conv_block = 256;
    const int conv_grid_x = (N + conv_block - 1) / conv_block;

    // Maximum gridDim.y / gridDim.z on every compute capability. The partial
    // kernel puts rows on z; the reduction kernel puts rows on y.
    const int kMaxRowsPerLaunch = 65535;

    int row0 = 0;
    while (row0 < M) {
        const int left = M - row0;
        const int rows = left < kMaxRowsPerLaunch ? left : kMaxRowsPerLaunch;
        // Both kernels index rows relative to blockIdx, so offset the base
        // pointers to this chunk's first row.
        float* chunk_partials = partials + (long long)row0 * num_k_blocks * N;

        if (num_k_blocks > 0) {
            dim3 grid(n_tiles, num_k_blocks, rows);
            gemv_bf16_batched_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
                x_rows + (long long)row0 * K,
                w,
                chunk_partials,
                K, N
            );
            CUDA_CHECK_LAST_ERROR();
        }

        dim3 reduce_grid(conv_grid_x, rows);
        gemv_batched_reduce_to_bf16_kernel<<<reduce_grid, conv_block, 0, s>>>(
            chunk_partials,
            y_rows + (long long)row0 * N,
            N,
            num_k_blocks
        );
        CUDA_CHECK_LAST_ERROR();

        row0 += rows;  // row0 + rows <= M, so this cannot overflow
    }
}
|
|
|
|
} // extern "C"
|