cuda: deterministic BF16 gemv + paged attention reductions
BF16 greedy decode was sensitive to inter-block scheduling when logits were close, which broke speculative-decoding verify-vs-decode parity. - gemv.cu: write per-K-block partials, then reduce in fixed block order in a second kernel instead of atomicAdd across K-blocks. Scratch buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems() exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/ intermediate/vocab from it. - paged_attention.cu: replace atomicAdd merge of warp outputs with per-warp shared partials reduced in warp-id order for both the base and sinks kernels.
This commit is contained in:
@@ -118,7 +118,7 @@ __global__ void paged_decode_attention_bf16_kernel(
|
||||
// ---- Block-level online softmax reduction ----
|
||||
__shared__ float smem_max[32];
|
||||
__shared__ float smem_sum[32];
|
||||
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
|
||||
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
|
||||
|
||||
int lane = tid & 31;
|
||||
int warp_id = tid >> 5;
|
||||
@@ -164,8 +164,12 @@ __global__ void paged_decode_attention_bf16_kernel(
|
||||
__syncthreads();
|
||||
global_sum = smem_sum[0];
|
||||
|
||||
// Step 4: reduce O across block, dim by dim
|
||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
|
||||
// Step 4: reduce O across block, dim by dim. Store one partial per warp
|
||||
// and sum in warp-id order; atomicAdd made greedy decode nondeterministic
|
||||
// when logits were close.
|
||||
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
|
||||
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
@@ -173,13 +177,15 @@ __global__ void paged_decode_attention_bf16_kernel(
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
if (lane == 0) atomicAdd(&smem_O[d], val);
|
||||
if (lane == 0) smem_O_warp[warp_id][d] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
|
||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
|
||||
O_ptr[d] = __float2bfloat16(out * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,7 +295,7 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
|
||||
// ---- Block-level online softmax reduction (same as base kernel) ----
|
||||
__shared__ float smem_max[32];
|
||||
__shared__ float smem_sum[32];
|
||||
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
|
||||
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
|
||||
|
||||
int lane = tid & 31;
|
||||
int warp_id = tid >> 5;
|
||||
@@ -332,7 +338,9 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
|
||||
__syncthreads();
|
||||
global_sum = smem_sum[0];
|
||||
|
||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
|
||||
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
|
||||
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int d = 0; d < head_dim; d++) {
|
||||
@@ -340,13 +348,15 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
if (lane == 0) atomicAdd(&smem_O[d], val);
|
||||
if (lane == 0) smem_O_warp[warp_id][d] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
|
||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
|
||||
O_ptr[d] = __float2bfloat16(out * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,22 +6,20 @@
|
||||
//
|
||||
// y[n] = sum_k x[k] * W[k * N + n]
|
||||
//
|
||||
// Grid: (N / TILE_N, K / TILE_K).
|
||||
// All blocks atomicAdd their partial sums into a pre-zeroed FP32 buffer.
|
||||
// A separate conversion kernel writes the final BF16 output.
|
||||
// Launch sequence: cudaMemsetAsync(fp32) → accumulation kernel → convert kernel.
|
||||
// Grid: (N / TILE_N, K / TILE_K) partials, followed by a deterministic
|
||||
// fixed-order reduction over K blocks. The previous implementation used
|
||||
// atomicAdd into y_fp32[col]; that made BF16 greedy decode sensitive to
|
||||
// inter-block scheduling when logits were close.
|
||||
|
||||
#define GEMV_TILE_N 128
|
||||
#define GEMV_TILE_K 256
|
||||
#define GEMV_BLOCK 128
|
||||
|
||||
__global__ void gemv_bf16_fused_kernel(
|
||||
__global__ void gemv_bf16_partial_kernel(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ W,
|
||||
__nv_bfloat16* __restrict__ y_bf16,
|
||||
float* __restrict__ y_fp32,
|
||||
int K, int N,
|
||||
int num_k_blocks
|
||||
float* __restrict__ partials,
|
||||
int K, int N
|
||||
) {
|
||||
const int block_n = blockIdx.x;
|
||||
const int block_k = blockIdx.y;
|
||||
@@ -52,18 +50,22 @@ __global__ void gemv_bf16_fused_kernel(
|
||||
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
||||
}
|
||||
|
||||
atomicAdd(&y_fp32[col], sum);
|
||||
partials[(long long)block_k * N + col] = sum;
|
||||
}
|
||||
|
||||
// Conversion kernel: FP32 accumulator -> BF16 output
|
||||
__global__ void gemv_fp32_to_bf16_kernel(
|
||||
const float* __restrict__ src,
|
||||
__global__ void gemv_reduce_to_bf16_kernel(
|
||||
const float* __restrict__ partials,
|
||||
__nv_bfloat16* __restrict__ dst,
|
||||
int n
|
||||
int n,
|
||||
int num_k_blocks
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) {
|
||||
dst[idx] = __float2bfloat16(src[idx]);
|
||||
float sum = 0.0f;
|
||||
for (int kb = 0; kb < num_k_blocks; kb++) {
|
||||
sum += partials[(long long)kb * n + idx];
|
||||
}
|
||||
dst[idx] = __float2bfloat16(sum);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,30 +81,25 @@ void launch_gemv_bf16(
|
||||
) {
|
||||
cudaStream_t s = (cudaStream_t)stream;
|
||||
|
||||
// Zero the FP32 accumulator BEFORE the kernel — the kernel uses atomicAdd
|
||||
// across K-blocks with no inter-block ordering, so the buffer must be
|
||||
// pre-zeroed to avoid accumulating on stale data.
|
||||
cudaMemsetAsync(y_fp32_buf, 0, (size_t)N * sizeof(float), s);
|
||||
|
||||
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
|
||||
|
||||
gemv_bf16_fused_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||
gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||
(const __nv_bfloat16*)x,
|
||||
(const __nv_bfloat16*)W,
|
||||
(__nv_bfloat16*)y_bf16,
|
||||
(float*)y_fp32_buf,
|
||||
K, N, num_k_blocks
|
||||
K, N
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
|
||||
// FP32 → BF16 conversion (must wait for all K-blocks to finish)
|
||||
// Fixed-order FP32 reduction over K blocks, then BF16 conversion.
|
||||
int conv_block = 256;
|
||||
int conv_grid = (N + conv_block - 1) / conv_block;
|
||||
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
|
||||
gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
|
||||
(const float*)y_fp32_buf,
|
||||
(__nv_bfloat16*)y_bf16,
|
||||
N
|
||||
N,
|
||||
num_k_blocks
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user