Files
xserv/csrc/gemm/gemv.cu
Gahow Wang 5b350ee5f0 cuda: deterministic BF16 gemv + paged attention reductions
BF16 greedy decode was sensitive to inter-block scheduling when logits
were close, which broke speculative-decoding verify-vs-decode parity.

- gemv.cu: write per-K-block partials, then reduce in fixed block order
  in a second kernel instead of atomicAdd across K-blocks. Scratch
  buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems()
  exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/
  intermediate/vocab from it.
- paged_attention.cu: replace atomicAdd merge of warp outputs with
  per-warp shared partials reduced in warp-id order for both the base
  and sinks kernels.
2026-07-01 14:16:28 +08:00

108 lines
3.2 KiB
Plaintext

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "../common.cuh"
// K-split GEMV for M=1 BF16 decode.
//
//   y[n] = sum_k x[k] * W[k * N + n]
//
// Pass 1 launches a (ceil(N / TILE_N), ceil(K / TILE_K)) grid; each block
// writes one FP32 partial per output column for its slice of K. Pass 2 then
// sums those partials per column in fixed, ascending K-block order, so the
// result does not depend on the order in which blocks happen to run. The
// previous implementation used atomicAdd into y_fp32[col]; that made BF16
// greedy decode sensitive to inter-block scheduling when logits were close.
//
// Output columns covered by one block of the partial kernel.
#define GEMV_TILE_N 128
// Rows of W (and elements of x) covered by one block of the partial kernel;
// also the number of floats in that kernel's shared-memory staging buffer.
#define GEMV_TILE_K 256
// Threads per block of the partial kernel. Each thread owns one column of the
// tile, so this is expected to equal GEMV_TILE_N.
#define GEMV_BLOCK 128
// Computes one K-slice of the M=1 product y = x * W for one tile of columns.
//
// Launch layout: blockIdx.x selects a tile of GEMV_TILE_N output columns,
// blockIdx.y selects a slice of up to GEMV_TILE_K rows of W, and each of the
// GEMV_BLOCK threads in the block owns one column of the tile.
//
//   x        : BF16 input vector, read at indices [0, K)
//   W        : BF16 weights, read as W[k * N + n]
//   partials : FP32 output; K-slice kb writes partials[kb * N + n]
//   K, N     : reduction length and number of output columns
//
// Uses GEMV_TILE_K floats of static shared memory.
__global__ void gemv_bf16_partial_kernel(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ W,
    float* __restrict__ partials,
    int K, int N
) {
    const int tid = threadIdx.x;
    const int k_slice = blockIdx.y;
    const int out_col = blockIdx.x * GEMV_TILE_N + tid;

    // Rows [row_begin, row_begin + rows) of W belong to this block; the last
    // slice is shorter when K is not a multiple of GEMV_TILE_K.
    const int row_begin = k_slice * GEMV_TILE_K;
    const int rows = min(GEMV_TILE_K, K - row_begin);

    // Stage this slice of x in shared memory as FP32. Every thread takes part,
    // including threads whose column is past N: the staging loop is indexed by
    // thread id, not by column, and __syncthreads() must be reached by the
    // whole block. Bailing out before the barrier would leave part of x_tile
    // unwritten and make the barrier divergent (undefined behaviour); an
    // earlier version did exactly that and produced garbage outputs whenever
    // N % GEMV_TILE_N != 0.
    __shared__ float x_tile[GEMV_TILE_K];
    const __nv_bfloat16* x_slice = x + row_begin;
    for (int i = tid; i < rows; i += GEMV_BLOCK) {
        x_tile[i] = __bfloat162float(x_slice[i]);
    }
    __syncthreads();

    // Only in-range columns accumulate and store. Rows are visited in
    // ascending order so the FP32 sum is reproducible from run to run.
    if (out_col < N) {
        // 64-bit offsets: row * N can exceed the range of int.
        const __nv_bfloat16* w_col = W + (long long)row_begin * N + out_col;
        float acc = 0.0f;
        for (int r = 0; r < rows; ++r) {
            acc += x_tile[r] * __bfloat162float(w_col[(long long)r * N]);
        }
        partials[(long long)k_slice * N + out_col] = acc;
    }
}
// Sums the per-K-slice FP32 partials of each output column in ascending slice
// order and stores the total as BF16.
//
//   partials     : FP32 values read as partials[kb * n + col]
//   dst          : BF16 output, one element per column
//   n            : number of output columns
//   num_k_blocks : number of K-slices to accumulate per column
//
// One thread per column on a 1-D grid; threads past n do nothing. With
// num_k_blocks == 0 nothing is read and 0 is stored.
__global__ void gemv_reduce_to_bf16_kernel(
    const float* __restrict__ partials,
    __nv_bfloat16* __restrict__ dst,
    int n,
    int num_k_blocks
) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= n) return;

    // Walk down this column one slice at a time; the offset is 64-bit because
    // kb * n can exceed the range of int.
    long long offset = col;
    float total = 0.0f;
    for (int remaining = num_k_blocks; remaining > 0; --remaining) {
        total += partials[offset];
        offset += n;
    }
    dst[col] = __float2bfloat16(total);
}
extern "C" {
// Host launcher for the deterministic M=1 BF16 GEMV
//
//   y[n] = sum_k x[k] * W[k * N + n]
//
// Enqueues two kernels on `stream`: a K-split pass that writes FP32 partials
// into y_fp32_buf, then a fixed-order reduction that writes BF16 into y_bf16.
// No host-side synchronization is performed here.
//
//   x          : device pointer, K BF16 values
//   W          : device pointer, BF16 weights indexed as W[k * N + n]
//   y_bf16     : device pointer, N BF16 outputs
//   y_fp32_buf : device scratch of at least N * ceil(K / GEMV_TILE_K) floats
//   K, N       : reduction length and number of output columns
//   stream     : cudaStream_t passed as an opaque pointer
//
// Degenerate sizes: N <= 0 is a no-op. K <= 0 is treated as an empty sum, so
// y_bf16 is filled with zeros and the scratch buffer is not touched. Without
// these guards a zero-sized grid dimension would be rejected at launch.
//
// NOTE: K-slices map to gridDim.y, which CUDA limits to 65535 blocks, so K
// must not exceed 65535 * GEMV_TILE_K.
void launch_gemv_bf16(
const void* x,
const void* W,
void* y_bf16,
void* y_fp32_buf,
int K, int N,
void* stream
) {
// The partial kernel assigns exactly one thread to each column of a tile.
static_assert(GEMV_TILE_N == GEMV_BLOCK,
"gemv_bf16_partial_kernel needs one thread per tile column");
if (N <= 0) {
return;
}
cudaStream_t s = (cudaStream_t)stream;
// Ceil-division written as (v - 1) / tile + 1 so it cannot overflow int.
const int num_k_blocks = (K > 0) ? (K - 1) / GEMV_TILE_K + 1 : 0;
if (num_k_blocks > 0) {
dim3 grid((N - 1) / GEMV_TILE_N + 1, num_k_blocks);
gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
(const __nv_bfloat16*)x,
(const __nv_bfloat16*)W,
(float*)y_fp32_buf,
K, N
);
CUDA_CHECK_LAST_ERROR();
}
// Fixed-order FP32 reduction over K blocks, then BF16 conversion. With
// num_k_blocks == 0 the kernel reads nothing and stores zeros.
const int conv_block = 256;
const int conv_grid = (N - 1) / conv_block + 1;
gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
(const float*)y_fp32_buf,
(__nv_bfloat16*)y_bf16,
N,
num_k_blocks
);
CUDA_CHECK_LAST_ERROR();
}
} // extern "C"