BF16 greedy decode was sensitive to inter-block scheduling when logits were close, which broke speculative-decoding verify-vs-decode parity.

- gemv.cu: write per-K-block partials, then reduce them in fixed block order in a second kernel instead of using atomicAdd across K-blocks. The scratch buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems() exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/intermediate/vocab from it.
- paged_attention.cu: replace the atomicAdd merge of warp outputs with per-warp shared partials reduced in warp-id order, for both the base and sinks kernels.
108 lines
3.2 KiB
Plaintext
108 lines
3.2 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <cuda_runtime.h>
|
|
#include "../common.cuh"
|
|
|
|
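// K-split GEMV for M=1 BF16 decode.
//
//   y[n] = sum_k x[k] * W[k * N + n]
//
// Grid: (ceil(N / GEMV_TILE_N), ceil(K / GEMV_TILE_K)) blocks. Each block
// writes one FP32 partial per output column to partials[block_k * N + n];
// a second kernel then sums those partials in fixed, ascending K-block order
// and converts the result to BF16.
//
// The previous implementation used atomicAdd into y_fp32[col]. Because the
// order of the atomic additions depended on inter-block scheduling, the FP32
// rounding (and therefore BF16 greedy decode) could differ from run to run
// when logits were close.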
|
|
|
|
#define GEMV_TILE_N 128
|
|
#define GEMV_TILE_K 256
|
|
#define GEMV_BLOCK 128
|
|
|
|
// Stage 1 of the K-split GEMV: one FP32 partial dot product per
// (K-slab, output column) pair.
//
// Block (blockIdx.x, blockIdx.y) handles output columns
// [blockIdx.x * GEMV_TILE_N, +GEMV_TILE_N) and rows
// [blockIdx.y * GEMV_TILE_K, +GEMV_TILE_K) of W, clipped to K. Thread t owns
// column blockIdx.x * GEMV_TILE_N + t and writes its result to
// partials[blockIdx.y * N + col].
//
// Parameters:
//   x        - BF16 input vector, read at indices [0, K).
//   W        - BF16 weights, read as W[k * N + n].
//   partials - FP32 scratch, written at [blockIdx.y * N + col] for col < N.
//   K, N     - reduction length and number of output columns.
//
// Launch contract visible from the indexing below: a 1-D block whose
// cooperative load strides by GEMV_BLOCK, and one thread per column of a
// GEMV_TILE_N-wide tile (so the code relies on the block having
// GEMV_BLOCK == GEMV_TILE_N threads). Uses GEMV_TILE_K floats of static
// shared memory.
__global__ void gemv_bf16_partial_kernel(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ W,
    float* __restrict__ partials,
    int K, int N
) {
    const int lane = threadIdx.x;
    const int slab = blockIdx.y;
    const int out_col = blockIdx.x * GEMV_TILE_N + lane;

    // Rows of W (and elements of x) covered by this slab; the last slab may
    // be shorter than GEMV_TILE_K.
    const int row_begin = slab * GEMV_TILE_K;
    const int row_count = min(row_begin + GEMV_TILE_K, K) - row_begin;

    // Stage this slab of x in shared memory as FP32. Every thread of the
    // block takes part, including threads whose column is past N: the load
    // is indexed by thread id, not by column, and all threads must reach the
    // barrier. Bailing out on out_col >= N before this point would leave
    // part of x_shared unwritten and make __syncthreads() divergent (UB);
    // that ordering bug previously produced intermittent garbage outputs
    // whenever N % GEMV_TILE_N != 0 (e.g. gpt-oss decode o_proj, N=2880).
    __shared__ float x_shared[GEMV_TILE_K];
    int i = lane;
    while (i < row_count) {
        x_shared[i] = __bfloat162float(x[row_begin + i]);
        i += GEMV_BLOCK;
    }
    __syncthreads();

    // Only now is it safe to drop the threads that have no column to compute.
    if (out_col < N) {
        // Sequential accumulation in ascending k keeps the per-slab result
        // independent of scheduling.
        float acc = 0.0f;
        for (int r = 0; r < row_count; ++r) {
            acc += x_shared[r] * __bfloat162float(W[(long long)(row_begin + r) * N + out_col]);
        }

        // 64-bit index: slab * N can exceed the int range for large shapes.
        partials[(long long)slab * N + out_col] = acc;
    }
}
|
|
|
|
// Stage 2 of the K-split GEMV: collapse the per-K-block FP32 partials into
// the final BF16 output.
//
// One thread per output element (flat 1-D indexing). Thread idx adds
// partials[kb * n + idx] for kb = 0 .. num_k_blocks-1, always in ascending
// kb order, so the FP32 rounding sequence is the same on every run. The sum
// is then converted with __float2bfloat16 and stored to dst[idx].
//
// Parameters:
//   partials     - FP32 partial sums, read at [kb * n + idx].
//   dst          - BF16 output, written at [0, n).
//   n            - number of output elements.
//   num_k_blocks - number of partial rows to sum; 0 yields an output of 0.
__global__ void gemv_reduce_to_bf16_kernel(
    const float* __restrict__ partials,
    __nv_bfloat16* __restrict__ dst,
    int n,
    int num_k_blocks
) {
    const int out = blockIdx.x * blockDim.x + threadIdx.x;
    if (out >= n) {
        return;  // tail threads of the last block have nothing to do
    }

    // Walk down this element's column of the partials matrix with a 64-bit
    // running offset (kb * n can exceed the int range).
    float total = 0.0f;
    long long offset = out;
    for (int remaining = num_k_blocks; remaining > 0; --remaining) {
        total += partials[offset];
        offset += n;
    }

    dst[out] = __float2bfloat16(total);
}
|
|
|
|
extern "C" {
|
|
|
|
// Host entry point for the deterministic M=1 BF16 GEMV
//   y[n] = sum_k x[k] * W[k * N + n].
//
// Enqueues two kernels on `stream`:
//   1. gemv_bf16_partial_kernel writes one FP32 partial per
//      (K-block, column) into y_fp32_buf[kb * N + n];
//   2. gemv_reduce_to_bf16_kernel sums those partials in ascending K-block
//      order and stores the BF16 result in y_bf16.
// Both launches use the same stream, so the reduction is ordered after the
// partial kernel. This function does not synchronize.
//
// Parameters (all pointers are passed straight to the kernels):
//   x          - BF16 input vector, K elements.
//   W          - BF16 weights indexed as W[k * N + n].
//   y_bf16     - BF16 output, N elements.
//   y_fp32_buf - FP32 scratch; must hold at least
//                N * ceil(K / GEMV_TILE_K) floats.
//   K, N       - reduction length and number of output columns.
//   stream     - cudaStream_t passed as void*.
//
// Degenerate shapes: N <= 0 is a no-op. K <= 0 skips the partial kernel and
// writes zeros to y_bf16 (the empty sum). Previously either case launched a
// grid with a zero dimension, which CUDA rejects as an invalid configuration.
//
// NOTE(review): grid.y equals ceil(K / GEMV_TILE_K) and must stay within the
// device's grid y-dimension limit; confirm for very large K.
void launch_gemv_bf16(
    const void* x,
    const void* W,
    void* y_bf16,
    void* y_fp32_buf,
    int K, int N,
    void* stream
) {
    // No output columns: nothing to compute, and a zero-sized grid is not a
    // valid launch configuration.
    if (N <= 0) {
        return;
    }

    cudaStream_t s = (cudaStream_t)stream;

    // Ceil-divisions written so they cannot overflow int for values near
    // INT_MAX (unlike (v + tile - 1) / tile).
    const int num_k_blocks =
        (K > 0) ? (K / GEMV_TILE_K + (K % GEMV_TILE_K != 0 ? 1 : 0)) : 0;
    const int num_n_blocks = (N - 1) / GEMV_TILE_N + 1;

    if (num_k_blocks > 0) {
        dim3 grid(num_n_blocks, num_k_blocks);

        gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
            (const __nv_bfloat16*)x,
            (const __nv_bfloat16*)W,
            (float*)y_fp32_buf,
            K, N
        );
        CUDA_CHECK_LAST_ERROR();
    }

    // Fixed-order FP32 reduction over K blocks, then BF16 conversion. With
    // num_k_blocks == 0 this reads no partials and writes zeros.
    const int conv_block = 256;
    const int conv_grid = (N - 1) / conv_block + 1;
    gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
        (const float*)y_fp32_buf,
        (__nv_bfloat16*)y_bf16,
        N,
        num_k_blocks
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
} // extern "C"
|