cuda: deterministic BF16 gemv + paged attention reductions

BF16 greedy decode was sensitive to inter-block scheduling when logits
were close, which broke speculative-decoding verify-vs-decode parity.

- gemv.cu: write per-K-block partials, then reduce in fixed block order
  in a second kernel instead of atomicAdd across K-blocks. Scratch
  buffer size is now n * ceil(k / GEMV_TILE_K); gemv_scratch_elems()
  exposes this to callers, and decode_graph.rs sizes fp32_hidden/q/kv/
  intermediate/vocab from it.
- paged_attention.cu: replace atomicAdd merge of warp outputs with
  per-warp shared partials reduced in warp-id order for both the base
  and sinks kernels.
This commit is contained in:
2026-07-01 14:14:55 +08:00
parent 0314b4f3ac
commit 5b350ee5f0
4 changed files with 59 additions and 43 deletions

View File

@@ -5,6 +5,7 @@ use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
const GEMV_TILE_K: usize = 256;
// GEMV: writes per-K-block FP32 partials into a scratch buffer, then reduces them in fixed block order in a second kernel
unsafe extern "C" {
@@ -26,6 +27,10 @@ pub enum GemmBackend {
CuBlas,
}
/// Returns how many scratch elements a GEMV with reduction depth `k` and
/// `n` outputs requires.
///
/// The reduction dimension is cut into slices of `GEMV_TILE_K`; a trailing
/// slice that is only partly filled still counts as a whole one. Each slice
/// contributes `n` elements, so the result is `n * ceil(k / GEMV_TILE_K)`.
/// The value is an element count, not a byte count: callers scale it by the
/// element size before allocating.
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
    let k_blocks = k.div_ceil(GEMV_TILE_K);
    k_blocks * n
}
// --- FFI: custom CUDA kernels ---
unsafe extern "C" {
fn launch_gemm_naive_f32(
@@ -274,7 +279,8 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
},
GemmBackend::CuBlas => {
if m == 1 && dtype == DType::BF16 && n >= 256 {
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap();
let mut fp32_buf =
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
unsafe {
launch_gemv_bf16(
a_ptr,

View File

@@ -9,7 +9,7 @@
use std::ffi::c_void;
use xserv_cuda::{CudaGraph, CudaStream, GpuBuffer};
use xserv_kernels::dispatch;
use xserv_kernels::gemm::cublas_handle;
use xserv_kernels::gemm::{cublas_handle, gemv_scratch_elems};
use crate::config::ModelConfig;
use crate::kv_cache::GpuKVCache;
@@ -54,7 +54,7 @@ struct DecodeBuffers {
up: GpuBuffer, // [1, intermediate]
silu_out: GpuBuffer, // [1, intermediate]
// GEMV fp32 accumulators (separate per output dimension)
// GEMV fp32 scratch for deterministic K-block partials.
fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs
fp32_q: GpuBuffer, // for Q projection
fp32_kv: GpuBuffer, // for K/V projection
@@ -140,11 +140,14 @@ impl DecodeGraphState {
up: alloc(intermediate * es),
silu_out: alloc(intermediate * es),
fp32_hidden: alloc(hidden * 4),
fp32_q: alloc(num_heads * head_dim * 4),
fp32_kv: alloc(num_kv_heads * head_dim * 4),
fp32_intermediate: alloc(intermediate * 4),
fp32_vocab: alloc(vocab_size * 4),
fp32_hidden: alloc(
gemv_scratch_elems(hidden, hidden).max(gemv_scratch_elems(intermediate, hidden))
* 4,
),
fp32_q: alloc(gemv_scratch_elems(hidden, num_heads * head_dim) * 4),
fp32_kv: alloc(gemv_scratch_elems(hidden, num_kv_heads * head_dim) * 4),
fp32_intermediate: alloc(gemv_scratch_elems(hidden, intermediate) * 4),
fp32_vocab: alloc(gemv_scratch_elems(hidden, vocab_size) * 4),
token_id_gpu: alloc(4),
position_gpu: alloc(4),