cuda: deterministic BF16 gemv + paged attention reductions

BF16 greedy decode was sensitive to inter-block scheduling when logits
were close, which broke speculative-decoding verify-vs-decode parity.

- gemv.cu: write per-K-block partials, then reduce in fixed block order
  in a second kernel instead of atomicAdd across K-blocks. Scratch
  buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems()
  exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/
  intermediate/vocab from it.
- paged_attention.cu: replace atomicAdd merge of warp outputs with
  per-warp shared partials reduced in warp-id order for both the base
  and sinks kernels.
This commit is contained in:
2026-07-01 14:14:55 +08:00
parent 0314b4f3ac
commit 5b350ee5f0
4 changed files with 59 additions and 43 deletions

View File

@@ -5,6 +5,7 @@ use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
const GEMV_TILE_K: usize = 256;
// GEMV: single-kernel, no FP32 temp buffer needed
unsafe extern "C" {
@@ -26,6 +27,10 @@ pub enum GemmBackend {
CuBlas,
}
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
n * k.div_ceil(GEMV_TILE_K)
}
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(
@@ -274,7 +279,8 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
},
GemmBackend::CuBlas => {
if m == 1 && dtype == DType::BF16 && n >= 256 {
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
let mut fp32_buf =
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
unsafe {
launch_gemv_bf16(
a_ptr,

View File

@@ -9,7 +9,7 @@
use std::ffi::c_void;
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
use xserv_kernels::dispatch;
use xserv_kernels::gemm::cublas_handle;
use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems};
use crate::config::ModelConfig;
use crate::kv_cache::GpuKVCache;
@@ -54,7 +54,7 @@ struct DecodeBuffers {
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
// GEMV fp32 accumulators (separate per output dimension)
// GEMV fp32 scratch for deterministic K-block partials.
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
@@ -140,11 +140,14 @@ impl DecodeGraphState {
up: alloc(intermediate * es),
silu_out: alloc(intermediate * es),
fp32_hidden: alloc(hidden * 4),
fp32_q: alloc(num_heads * head_dim * 4),
fp32_kv: alloc(num_kv_heads * head_dim * 4),
fp32_intermediate: alloc(intermediate * 4),
fp32_vocab: alloc(vocab_size * 4),
fp32_hidden: alloc(
gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden))
* 4,
),
fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4),
fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4),
fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4),
fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4),
token_id_gpu: alloc(4),
position_gpu: alloc(4),

View File

@@ -118,7 +118,7 @@ __global__ void paged_decode_attention_bf16_kernel(
// ---- Block-level online softmax reduction ----
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
@@ -164,8 +164,12 @@ __global__ void paged_decode_attention_bf16_kernel(
__syncthreads();
global_sum = smem_sum[0];
// Step 4: reduce O across block, dim by dim
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
// Step 4: reduce O across block, dim by dim. Store one partial per warp
// and sum in warp-id order; atomicAdd made greedy decode nondeterministic
// when logits were close.
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
}
__syncthreads();
for (int d = 0; d < head_dim; d++) {
@@ -173,13 +177,15 @@ __global__ void paged_decode_attention_bf16_kernel(
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) atomicAdd(&smem_O[d], val);
if (lane == 0) smem_O_warp[warp_id][d] = val;
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
float out = 0.0f;
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
O_ptr[d] = __float2bfloat16(out * inv_sum);
}
}
@@ -289,7 +295,7 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
// ---- Block-level online softmax reduction (same as base kernel) ----
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
@@ -332,7 +338,9 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
__syncthreads();
global_sum = smem_sum[0];
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
}
__syncthreads();
for (int d = 0; d < head_dim; d++) {
@@ -340,13 +348,15 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) atomicAdd(&smem_O[d], val);
if (lane == 0) smem_O_warp[warp_id][d] = val;
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
float out = 0.0f;
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
O_ptr[d] = __float2bfloat16(out * inv_sum);
}
}

View File

@@ -6,22 +6,20 @@
//
// y[n] = sum_k x[k] * W[k * N + n]
//
// Grid: (N / TILE_N, K / TILE_K).
// All blocks atomicAdd their partial sums into a pre-zeroed FP32 buffer.
// A separate conversion kernel writes the final BF16 output.
// Launch sequence: cudaMemsetAsync(fp32) → accumulation kernel → convert kernel.
// Grid: (N / TILE_N, K / TILE_K) partials, followed by a deterministic
// fixed-order reduction over K blocks. The previous implementation used
// atomicAdd into y_fp32[col]; that made BF16 greedy decode sensitive to
// inter-block scheduling when logits were close.
#define GEMV_TILE_N 128
#define GEMV_TILE_K 256
#define GEMV_BLOCK 128
__global__ void gemv_bf16_fused_kernel(
__global__ void gemv_bf16_partial_kernel(
const __nv_bfloat16* __restrict__ x,
const __nv_bfloat16* __restrict__ W,
__nv_bfloat16* __restrict__ y_bf16,
float* __restrict__ y_fp32,
int K, int N,
int num_k_blocks
float* __restrict__ partials,
int K, int N
) {
const int block_n = blockIdx.x;
const int block_k = blockIdx.y;
@@ -52,18 +50,22 @@ __global__ void gemv_bf16_fused_kernel(
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
}
atomicAdd(&y_fp32[col], sum);
partials[(long long)block_k * N + col] = sum;
}
// Conversion kernel: FP32 accumulator -> BF16 output
__global__ void gemv_fp32_to_bf16_kernel(
const float* __restrict__ src,
__global__ void gemv_reduce_to_bf16_kernel(
const float* __restrict__ partials,
__nv_bfloat16* __restrict__ dst,
int n
int n,
int num_k_blocks
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
dst[idx] = __float2bfloat16(src[idx]);
float sum = 0.0f;
for (int kb = 0; kb < num_k_blocks; kb++) {
sum += partials[(long long)kb * n + idx];
}
dst[idx] = __float2bfloat16(sum);
}
}
@@ -79,30 +81,25 @@ void launch_gemv_bf16(
) {
cudaStream_t s = (cudaStream_t)stream;
// Zero the FP32 accumulator BEFORE the kernel — the kernel uses atomicAdd
// across K-blocks with no inter-block ordering, so the buffer must be
// pre-zeroed to avoid accumulating on stale data.
cudaMemsetAsync(y_fp32_buf, 0, (size_t)N * sizeof(float), s);
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
gemv_bf16_fused_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
(const __nv_bfloat16*)x,
(const __nv_bfloat16*)W,
(__nv_bfloat16*)y_bf16,
(float*)y_fp32_buf,
K, N, num_k_blocks
K, N
);
CUDA_CHECK_LAST_ERROR();
// FP32 → BF16 conversion (must wait for all K-blocks to finish)
// Fixed-order FP32 reduction over K blocks, then BF16 conversion.
int conv_block = 256;
int conv_grid = (N + conv_block - 1) / conv_block;
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
(const float*)y_fp32_buf,
(__nv_bfloat16*)y_bf16,
N
N,
num_k_blocks
);
CUDA_CHECK_LAST_ERROR();
}