diff --git a/crates/xserv-kernels/src/gemm.rs b/crates/xserv-kernels/src/gemm.rs index fc919c6..827c1d5 100644 --- a/crates/xserv-kernels/src/gemm.rs +++ b/crates/xserv-kernels/src/gemm.rs @@ -5,6 +5,7 @@ use xserv_cuda::error::{self, Result}; use xserv_tensor::{DType, Device, Tensor}; const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024; +const GEMV_TILE_K: usize = 256; // GEMV: single-kernel, no FP32 temp buffer needed unsafe extern "C" { @@ -26,6 +27,10 @@ pub enum GemmBackend { CuBlas, } +pub fn gemv_scratch_elems(k: usize, n: usize) -> usize { + n * k.div_ceil(GEMV_TILE_K) +} + // --- FFI: custom CUDA kernels --- unsafe extern "C" { fn launch_gemm_naive_f32( @@ -274,7 +279,8 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { }, GemmBackend::CuBlas => { if m == 1 && dtype == DType::BF16 && n >= 256 { - let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap(); + let mut fp32_buf = + xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap(); unsafe { launch_gemv_bf16( a_ptr, diff --git a/crates/xserv-model/src/decode_graph.rs b/crates/xserv-model/src/decode_graph.rs index ed518e4..74bad8f 100644 --- a/crates/xserv-model/src/decode_graph.rs +++ b/crates/xserv-model/src/decode_graph.rs @@ -9,7 +9,7 @@ use std::ffi::c_void; use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer}; use xserv_kernels::dispatch; -use xserv_kernels::gemm::cublas_handle; +use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems}; use crate::config::ModelConfig; use crate::kv_cache::GpuKVCache; @@ -54,7 +54,7 @@ struct DecodeBuffers { up: GpuBuffer, // [1, intermediate] silu_out: GpuBuffer, // [1, intermediate] - // GEMV fp32 accumulators (separate per output dimension) + // GEMV fp32 scratch for deterministic K-block partials. fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs fp32_q: GpuBuffer, // for Q projection fp32_kv: GpuBuffer, // for K/V projection @@ -140,11 +140,14 @@ impl DecodeGraphState { up: alloc(intermediate * es), silu_out: alloc(intermediate * es), - fp32_hidden: alloc(hidden * 4), - fp32_q: alloc(num_heads * head_dim * 4), - fp32_kv: alloc(num_kv_heads * head_dim * 4), - fp32_intermediate: alloc(intermediate * 4), - fp32_vocab: alloc(vocab_size * 4), + fp32_hidden: alloc( + gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden)) + * 4, + ), + fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4), + fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4), + fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4), + fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4), token_id_gpu: alloc(4), position_gpu: alloc(4), diff --git a/csrc/attention/paged_attention.cu b/csrc/attention/paged_attention.cu index a249665..29c3066 100644 --- a/csrc/attention/paged_attention.cu +++ b/csrc/attention/paged_attention.cu @@ -118,7 +118,7 @@ __global__ void paged_decode_attention_bf16_kernel( // ---- Block-level online softmax reduction ---- __shared__ float smem_max[32]; __shared__ float smem_sum[32]; - __shared__ float smem_O[PAGED_HEAD_DIM_MAX]; + __shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX]; int lane = tid & 31; int warp_id = tid >> 5; @@ -164,8 +164,12 @@ __global__ void paged_decode_attention_bf16_kernel( __syncthreads(); global_sum = smem_sum[0]; - // Step 4: reduce O across block, dim by dim - for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f; + // Step 4: reduce O across block, dim by dim. Store one partial per warp + // and sum in warp-id order; atomicAdd made greedy decode nondeterministic + // when logits were close. + for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) { + reinterpret_cast(smem_O_warp)[i] = 0.0f; + } __syncthreads(); for (int d = 0; d < head_dim; d++) { @@ -173,13 +177,15 @@ __global__ void paged_decode_attention_bf16_kernel( #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) val += __shfl_down_sync(0xffffffff, val, offset); - if (lane == 0) atomicAdd(&smem_O[d], val); + if (lane == 0) smem_O_warp[warp_id][d] = val; } __syncthreads(); float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f; for (int d = tid; d < head_dim; d += PAGED_THREADS) { - O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum); + float out = 0.0f; + for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d]; + O_ptr[d] = __float2bfloat16(out * inv_sum); } } @@ -289,7 +295,7 @@ __global__ void paged_decode_attention_sinks_bf16_kernel( // ---- Block-level online softmax reduction (same as base kernel) ---- __shared__ float smem_max[32]; __shared__ float smem_sum[32]; - __shared__ float smem_O[PAGED_HEAD_DIM_MAX]; + __shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX]; int lane = tid & 31; int warp_id = tid >> 5; @@ -332,7 +338,9 @@ __global__ void paged_decode_attention_sinks_bf16_kernel( __syncthreads(); global_sum = smem_sum[0]; - for (int d = tid; d < head_dim; d += PAGED_THREADS) smem_O[d] = 0.0f; + for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) { + reinterpret_cast(smem_O_warp)[i] = 0.0f; + } __syncthreads(); for (int d = 0; d < head_dim; d++) { @@ -340,13 +348,15 @@ __global__ void paged_decode_attention_sinks_bf16_kernel( #pragma unroll for (int offset = 16; offset > 0; offset >>= 1) val += __shfl_down_sync(0xffffffff, val, offset); - if (lane == 0) atomicAdd(&smem_O[d], val); + if (lane == 0) smem_O_warp[warp_id][d] = val; } __syncthreads(); float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f; for (int d = tid; d < head_dim; d += PAGED_THREADS) { - O_ptr[d] = __float2bfloat16(smem_O[d] * inv_sum); + float out = 0.0f; + for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d]; + O_ptr[d] = __float2bfloat16(out * inv_sum); } } diff --git a/csrc/gemm/gemv.cu b/csrc/gemm/gemv.cu index 4554f94..cb32433 100644 --- a/csrc/gemm/gemv.cu +++ b/csrc/gemm/gemv.cu @@ -6,22 +6,20 @@ // // y[n] = sum_k x[k] * W[k * N + n] // -// Grid: (N / TILE_N, K / TILE_K). -// All blocks atomicAdd their partial sums into a pre-zeroed FP32 buffer. -// A separate conversion kernel writes the final BF16 output. -// Launch sequence: cudaMemsetAsync(fp32) → accumulation kernel → convert kernel. +// Grid: (N / TILE_N, K / TILE_K) partials, followed by a deterministic +// fixed-order reduction over K blocks. The previous implementation used +// atomicAdd into y_fp32[col]; that made BF16 greedy decode sensitive to +// inter-block scheduling when logits were close. #define GEMV_TILE_N 128 #define GEMV_TILE_K 256 #define GEMV_BLOCK 128 -__global__ void gemv_bf16_fused_kernel( +__global__ void gemv_bf16_partial_kernel( const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ W, - __nv_bfloat16* __restrict__ y_bf16, - float* __restrict__ y_fp32, - int K, int N, - int num_k_blocks + float* __restrict__ partials, + int K, int N ) { const int block_n = blockIdx.x; const int block_k = blockIdx.y; @@ -52,18 +50,22 @@ __global__ void gemv_bf16_fused_kernel( sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]); } - atomicAdd(&y_fp32[col], sum); + partials[(long long)block_k * N + col] = sum; } -// Conversion kernel: FP32 accumulator -> BF16 output -__global__ void gemv_fp32_to_bf16_kernel( - const float* __restrict__ src, +__global__ void gemv_reduce_to_bf16_kernel( + const float* __restrict__ partials, __nv_bfloat16* __restrict__ dst, - int n + int n, + int num_k_blocks ) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { - dst[idx] = __float2bfloat16(src[idx]); + float sum = 0.0f; + for (int kb = 0; kb < num_k_blocks; kb++) { + sum += partials[(long long)kb * n + idx]; + } + dst[idx] = __float2bfloat16(sum); } } @@ -79,30 +81,25 @@ void launch_gemv_bf16( ) { cudaStream_t s = (cudaStream_t)stream; - // Zero the FP32 accumulator BEFORE the kernel — the kernel uses atomicAdd - // across K-blocks with no inter-block ordering, so the buffer must be - // pre-zeroed to avoid accumulating on stale data. - cudaMemsetAsync(y_fp32_buf, 0, (size_t)N * sizeof(float), s); - int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K; dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks); - gemv_bf16_fused_kernel<<>>( + gemv_bf16_partial_kernel<<>>( (const __nv_bfloat16*)x, (const __nv_bfloat16*)W, - (__nv_bfloat16*)y_bf16, (float*)y_fp32_buf, - K, N, num_k_blocks + K, N ); CUDA_CHECK_LAST_ERROR(); - // FP32 → BF16 conversion (must wait for all K-blocks to finish) + // Fixed-order FP32 reduction over K blocks, then BF16 conversion. int conv_block = 256; int conv_grid = (N + conv_block - 1) / conv_block; - gemv_fp32_to_bf16_kernel<<>>( + gemv_reduce_to_bf16_kernel<<>>( (const float*)y_fp32_buf, (__nv_bfloat16*)y_bf16, - N + N, + num_k_blocks ); CUDA_CHECK_LAST_ERROR(); }