cuda: deterministic BF16 gemv + paged attention reductions
BF16 greedy decode was sensitive to inter-block scheduling when logits were close, which broke speculative-decoding verify-vs-decode parity. - gemv.cu: write per-K-block partials, then reduce in fixed block order in a second kernel instead of atomicAdd across K-blocks. Scratch buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems() exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/ intermediate/vocab from it. - paged_attention.cu: replace atomicAdd merge of warp outputs with per-warp shared partials reduced in warp-id order for both the base and sinks kernels.
This commit is contained in:
@@ -5,6 +5,7 @@ use xserv_cuda::error::{self, Result};
|
|||||||
use xserv_tensor::{DType, Device, Tensor};
|
use xserv_tensor::{DType, Device, Tensor};
|
||||||
|
|
||||||
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||||
|
const GEMV_TILE_K: usize = 256;
|
||||||
|
|
||||||
// GEMV: single-kernel, no FP32 temp buffer needed
|
// GEMV: single-kernel, no FP32 temp buffer needed
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
@@ -26,6 +27,10 @@ pub enum GemmBackend {
|
|||||||
CuBlas,
|
CuBlas,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
|
||||||
|
n * k.div_ceil(GEMV_TILE_K)
|
||||||
|
}
|
||||||
|
|
||||||
// --- FFI: custom CUDA kernels ---
|
// --- FFI: custom CUDA kernels ---
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
fn launch_gemm_naive_f32(
|
fn launch_gemm_naive_f32(
|
||||||
@@ -274,7 +279,8 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
|||||||
},
|
},
|
||||||
GemmBackend::CuBlas => {
|
GemmBackend::CuBlas => {
|
||||||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
|
let mut fp32_buf =
|
||||||
|
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
|
||||||
unsafe {
|
unsafe {
|
||||||
launch_gemv_bf16(
|
launch_gemv_bf16(
|
||||||
a_ptr,
|
a_ptr,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||||
use xserv_kernels::dispatch;
|
use xserv_kernels::dispatch;
|
||||||
use xserv_kernels::gemm::cublas_handle;
|
use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems};
|
||||||
|
|
||||||
use crate::config::ModelConfig;
|
use crate::config::ModelConfig;
|
||||||
use crate::kv_cache::GpuKVCache;
|
use crate::kv_cache::GpuKVCache;
|
||||||
@@ -54,7 +54,7 @@ struct DecodeBuffers {
|
|||||||
up: GpuBuffer, // [1, intermediate]
|
up: GpuBuffer, // [1, intermediate]
|
||||||
silu_out: GpuBuffer, // [1, intermediate]
|
silu_out: GpuBuffer, // [1, intermediate]
|
||||||
|
|
||||||
// GEMV fp32 accumulators (separate per output dimension)
|
// GEMV fp32 scratch for deterministic K-block partials.
|
||||||
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
||||||
fp32_q: GpuBuffer, // for Q projection
|
fp32_q: GpuBuffer, // for Q projection
|
||||||
fp32_kv: GpuBuffer, // for K/V projection
|
fp32_kv: GpuBuffer, // for K/V projection
|
||||||
@@ -140,11 +140,14 @@ impl DecodeGraphState {
|
|||||||
up: alloc(intermediate * es),
|
up: alloc(intermediate * es),
|
||||||
silu_out: alloc(intermediate * es),
|
silu_out: alloc(intermediate * es),
|
||||||
|
|
||||||
fp32_hidden: alloc(hidden * 4),
|
fp32_hidden: alloc(
|
||||||
fp32_q: alloc(num_heads * head_dim * 4),
|
gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden))
|
||||||
fp32_kv: alloc(num_kv_heads * head_dim * 4),
|
* 4,
|
||||||
fp32_intermediate: alloc(intermediate * 4),
|
),
|
||||||
fp32_vocab: alloc(vocab_size * 4),
|
fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4),
|
||||||
|
fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4),
|
||||||
|
fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4),
|
||||||
|
fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4),
|
||||||
|
|
||||||
token_id_gpu: alloc(4),
|
token_id_gpu: alloc(4),
|
||||||
position_gpu: alloc(4),
|
position_gpu: alloc(4),
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ __global__ void paged_decode_attention_bf16_kernel(
|
|||||||
// ---- Block-level online softmax reduction ----
|
// ---- Block-level online softmax reduction ----
|
||||||
__shared__ float smem_max[32];
|
__shared__ float smem_max[32];
|
||||||
__shared__ float smem_sum[32];
|
__shared__ float smem_sum[32];
|
||||||
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
|
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
|
||||||
|
|
||||||
int lane = tid & 31;
|
int lane = tid & 31;
|
||||||
int warp_id = tid >> 5;
|
int warp_id = tid >> 5;
|
||||||
@@ -164,8 +164,12 @@ __global__ void paged_decode_attention_bf16_kernel(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
global_sum = smem_sum[0];
|
global_sum = smem_sum[0];
|
||||||
|
|
||||||
// Step 4: reduce O across block, dim by dim
|
// Step 4: reduce O across block, dim by dim. Store one partial per warp
|
||||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
|
// and sum in warp-id order; atomicAdd made greedy decode nondeterministic
|
||||||
|
// when logits were close.
|
||||||
|
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
|
||||||
|
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
|
||||||
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int d = 0; d < head_dim; d++) {
|
for (int d = 0; d < head_dim; d++) {
|
||||||
@@ -173,13 +177,15 @@ __global__ void paged_decode_attention_bf16_kernel(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset >>= 1)
|
for (int offset = 16; offset > 0; offset >>= 1)
|
||||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||||
if (lane == 0) atomicAdd(&smem_O[d], val);
|
if (lane == 0) smem_O_warp[warp_id][d] = val;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
|
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
|
||||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
float out = 0.0f;
|
||||||
|
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
|
||||||
|
O_ptr[d] = __float2bfloat16(out * inv_sum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -289,7 +295,7 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
|
|||||||
// ---- Block-level online softmax reduction (same as base kernel) ----
|
// ---- Block-level online softmax reduction (same as base kernel) ----
|
||||||
__shared__ float smem_max[32];
|
__shared__ float smem_max[32];
|
||||||
__shared__ float smem_sum[32];
|
__shared__ float smem_sum[32];
|
||||||
__shared__ float smem_O[PAGED_HEAD_DIM_MAX];
|
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
|
||||||
|
|
||||||
int lane = tid & 31;
|
int lane = tid & 31;
|
||||||
int warp_id = tid >> 5;
|
int warp_id = tid >> 5;
|
||||||
@@ -332,7 +338,9 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
global_sum = smem_sum[0];
|
global_sum = smem_sum[0];
|
||||||
|
|
||||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f;
|
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
|
||||||
|
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
|
||||||
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int d = 0; d < head_dim; d++) {
|
for (int d = 0; d < head_dim; d++) {
|
||||||
@@ -340,13 +348,15 @@ __global__ void paged_decode_attention_sinks_bf16_kernel(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int offset = 16; offset > 0; offset >>= 1)
|
for (int offset = 16; offset > 0; offset >>= 1)
|
||||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||||
if (lane == 0) atomicAdd(&smem_O[d], val);
|
if (lane == 0) smem_O_warp[warp_id][d] = val;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
|
||||||
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
|
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
|
||||||
O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum);
|
float out = 0.0f;
|
||||||
|
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
|
||||||
|
O_ptr[d] = __float2bfloat16(out * inv_sum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,22 +6,20 @@
|
|||||||
//
|
//
|
||||||
// y[n] = sum_k x[k] * W[k * N + n]
|
// y[n] = sum_k x[k] * W[k * N + n]
|
||||||
//
|
//
|
||||||
// Grid: (N / TILE_N, K / TILE_K).
|
// Grid: (N / TILE_N, K / TILE_K) partials, followed by a deterministic
|
||||||
// All blocks atomicAdd their partial sums into a pre-zeroed FP32 buffer.
|
// fixed-order reduction over K blocks. The previous implementation used
|
||||||
// A separate conversion kernel writes the final BF16 output.
|
// atomicAdd into y_fp32[col]; that made BF16 greedy decode sensitive to
|
||||||
// Launch sequence: cudaMemsetAsync(fp32) → accumulation kernel → convert kernel.
|
// inter-block scheduling when logits were close.
|
||||||
|
|
||||||
#define GEMV_TILE_N 128
|
#define GEMV_TILE_N 128
|
||||||
#define GEMV_TILE_K 256
|
#define GEMV_TILE_K 256
|
||||||
#define GEMV_BLOCK 128
|
#define GEMV_BLOCK 128
|
||||||
|
|
||||||
__global__ void gemv_bf16_fused_kernel(
|
__global__ void gemv_bf16_partial_kernel(
|
||||||
const __nv_bfloat16* __restrict__ x,
|
const __nv_bfloat16* __restrict__ x,
|
||||||
const __nv_bfloat16* __restrict__ W,
|
const __nv_bfloat16* __restrict__ W,
|
||||||
__nv_bfloat16* __restrict__ y_bf16,
|
float* __restrict__ partials,
|
||||||
float* __restrict__ y_fp32,
|
int K, int N
|
||||||
int K, int N,
|
|
||||||
int num_k_blocks
|
|
||||||
) {
|
) {
|
||||||
const int block_n = blockIdx.x;
|
const int block_n = blockIdx.x;
|
||||||
const int block_k = blockIdx.y;
|
const int block_k = blockIdx.y;
|
||||||
@@ -52,18 +50,22 @@ __global__ void gemv_bf16_fused_kernel(
|
|||||||
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]);
|
||||||
}
|
}
|
||||||
|
|
||||||
atomicAdd(&y_fp32[col], sum);
|
partials[(long long)block_k * N + col] = sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conversion kernel: FP32 accumulator -> BF16 output
|
__global__ void gemv_reduce_to_bf16_kernel(
|
||||||
__global__ void gemv_fp32_to_bf16_kernel(
|
const float* __restrict__ partials,
|
||||||
const float* __restrict__ src,
|
|
||||||
__nv_bfloat16* __restrict__ dst,
|
__nv_bfloat16* __restrict__ dst,
|
||||||
int n
|
int n,
|
||||||
|
int num_k_blocks
|
||||||
) {
|
) {
|
||||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (idx < n) {
|
if (idx < n) {
|
||||||
dst[idx] = __float2bfloat16(src[idx]);
|
float sum = 0.0f;
|
||||||
|
for (int kb = 0; kb < num_k_blocks; kb++) {
|
||||||
|
sum += partials[(long long)kb * n + idx];
|
||||||
|
}
|
||||||
|
dst[idx] = __float2bfloat16(sum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,30 +81,25 @@ void launch_gemv_bf16(
|
|||||||
) {
|
) {
|
||||||
cudaStream_t s = (cudaStream_t)stream;
|
cudaStream_t s = (cudaStream_t)stream;
|
||||||
|
|
||||||
// Zero the FP32 accumulator BEFORE the kernel — the kernel uses atomicAdd
|
|
||||||
// across K-blocks with no inter-block ordering, so the buffer must be
|
|
||||||
// pre-zeroed to avoid accumulating on stale data.
|
|
||||||
cudaMemsetAsync(y_fp32_buf, 0, (size_t)N * sizeof(float), s);
|
|
||||||
|
|
||||||
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K;
|
||||||
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
|
dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks);
|
||||||
|
|
||||||
gemv_bf16_fused_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
gemv_bf16_partial_kernel<<<grid, GEMV_BLOCK, 0, s>>>(
|
||||||
(const __nv_bfloat16*)x,
|
(const __nv_bfloat16*)x,
|
||||||
(const __nv_bfloat16*)W,
|
(const __nv_bfloat16*)W,
|
||||||
(__nv_bfloat16*)y_bf16,
|
|
||||||
(float*)y_fp32_buf,
|
(float*)y_fp32_buf,
|
||||||
K, N, num_k_blocks
|
K, N
|
||||||
);
|
);
|
||||||
CUDA_CHECK_LAST_ERROR();
|
CUDA_CHECK_LAST_ERROR();
|
||||||
|
|
||||||
// FP32 → BF16 conversion (must wait for all K-blocks to finish)
|
// Fixed-order FP32 reduction over K blocks, then BF16 conversion.
|
||||||
int conv_block = 256;
|
int conv_block = 256;
|
||||||
int conv_grid = (N + conv_block - 1) / conv_block;
|
int conv_grid = (N + conv_block - 1) / conv_block;
|
||||||
gemv_fp32_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
|
gemv_reduce_to_bf16_kernel<<<conv_grid, conv_block, 0, s>>>(
|
||||||
(const float*)y_fp32_buf,
|
(const float*)y_fp32_buf,
|
||||||
(__nv_bfloat16*)y_bf16,
|
(__nv_bfloat16*)y_bf16,
|
||||||
N
|
N,
|
||||||
|
num_k_blocks
|
||||||
);
|
);
|
||||||
CUDA_CHECK_LAST_ERROR();
|
CUDA_CHECK_LAST_ERROR();
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user