cuda: deterministic BF16 gemv + paged attention reductions
BF16 greedy decode was sensitive to inter-block scheduling when logits were close, which broke speculative-decoding verify-vs-decode parity. - gemv.cu: write per-K-block partials, then reduce in fixed block order in a second kernel instead of atomicAdd across K-blocks. Scratch buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems() exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/ intermediate/vocab from it. - paged_attention.cu: replace atomicAdd merge of warp outputs with per-warp shared partials reduced in warp-id order for both the base and sinks kernels.
This commit is contained in:
@@ -5,6 +5,7 @@ use xserv_cuda::error::{self, Result};
|
||||
use xserv_tensor::{DType, Device, Tensor};
|
||||
|
||||
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||
const GEMV_TILE_K: usize = 256;
|
||||
|
||||
// GEMV: single-kernel, no FP32 temp buffer needed
|
||||
unsafe extern "C" {
|
||||
@@ -26,6 +27,10 @@ pub enum GemmBackend {
|
||||
CuBlas,
|
||||
}
|
||||
|
||||
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
|
||||
n * k.div_ceil(GEMV_TILE_K)
|
||||
}
|
||||
|
||||
// --- FFI: custom CUDA kernels ---
|
||||
unsafe extern "C" {
|
||||
fn launch_gemm_naive_f32(
|
||||
@@ -274,7 +279,8 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
},
|
||||
GemmBackend::CuBlas => {
|
||||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
|
||||
let mut fp32_buf =
|
||||
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
|
||||
unsafe {
|
||||
launch_gemv_bf16(
|
||||
a_ptr,
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
|
||||
use xserv_kernels::dispatch;
|
||||
use xserv_kernels::gemm::cublas_handle;
|
||||
use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems};
|
||||
|
||||
use crate::config::ModelConfig;
|
||||
use crate::kv_cache::GpuKVCache;
|
||||
@@ -54,7 +54,7 @@ struct DecodeBuffers {
|
||||
up: GpuBuffer, // [1, intermediate]
|
||||
silu_out: GpuBuffer, // [1, intermediate]
|
||||
|
||||
// GEMV fp32 accumulators (separate per output dimension)
|
||||
// GEMV fp32 scratch for deterministic K-block partials.
|
||||
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
|
||||
fp32_q: GpuBuffer, // for Q projection
|
||||
fp32_kv: GpuBuffer, // for K/V projection
|
||||
@@ -140,11 +140,14 @@ impl DecodeGraphState {
|
||||
up: alloc(intermediate * es),
|
||||
silu_out: alloc(intermediate * es),
|
||||
|
||||
fp32_hidden: alloc(hidden * 4),
|
||||
fp32_q: alloc(num_heads * head_dim * 4),
|
||||
fp32_kv: alloc(num_kv_heads * head_dim * 4),
|
||||
fp32_intermediate: alloc(intermediate * 4),
|
||||
fp32_vocab: alloc(vocab_size * 4),
|
||||
fp32_hidden: alloc(
|
||||
gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden))
|
||||
* 4,
|
||||
),
|
||||
fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4),
|
||||
fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4),
|
||||
fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4),
|
||||
fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4),
|
||||
|
||||
token_id_gpu: alloc(4),
|
||||
position_gpu: alloc(4),
|
||||
|
||||
Reference in New Issue
Block a user