diff --git a/crates/xserv-kernels/src/gemm.rs b/crates/xserv-kernels/src/gemm.rs index 827c1d5..eebc2af 100644 --- a/crates/xserv-kernels/src/gemm.rs +++ b/crates/xserv-kernels/src/gemm.rs @@ -18,6 +18,17 @@ unsafe extern "C" { n: i32, stream: *mut c_void, ); + + fn launch_gemv_bf16_batched( + x: *const c_void, + w: *const c_void, + y_bf16: *mut c_void, + y_fp32_buf: *mut c_void, + m: i32, + k: i32, + n: i32, + stream: *mut c_void, + ); } #[derive(Debug, Clone, Copy)] @@ -31,6 +42,55 @@ pub fn gemv_scratch_elems(k: usize, n: usize) -> usize { n * k.div_ceil(GEMV_TILE_K) } +/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16. +/// Bit-exact with calling matmul on each row individually (same K-block partial +/// + fixed-order reduction path), but in a single kernel launch per phase. +pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor { + assert_eq!(a.ndim(), 2); + assert_eq!(b.ndim(), 2); + assert!(a.is_contiguous()); + assert!(b.is_contiguous()); + assert_eq!(a.dtype(), DType::BF16); + assert_eq!(b.dtype(), DType::BF16); + let m = a.shape()[0]; + let k = a.shape()[1]; + let n = b.shape()[1]; + assert_eq!(b.shape()[0], k); + + let out = Tensor::empty(&[m, n], DType::BF16, a.device()); + let scratch_elems = m * gemv_scratch_elems(k, n); + let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap(); + + let null_stream = xserv_cuda::current_stream_raw(); + if m == 1 { + unsafe { + launch_gemv_bf16( + a.data_ptr() as *const c_void, + b.data_ptr() as *const c_void, + out.data_ptr() as *mut c_void, + fp32_buf.as_mut_ptr() as *mut c_void, + k as i32, + n as i32, + null_stream, + ); + } + } else { + unsafe { + launch_gemv_bf16_batched( + a.data_ptr() as *const c_void, + b.data_ptr() as *const c_void, + out.data_ptr() as *mut c_void, + fp32_buf.as_mut_ptr() as *mut c_void, + m as i32, + k as i32, + n as i32, + null_stream, + ); + } + } + out +} + // --- FFI: custom CUDA kernels --- unsafe extern "C" { fn launch_gemm_naive_f32( diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index 98e6150..7e62694 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -19,7 +19,7 @@ pub use attention::{ paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16, }; pub use embedding::{embedding, embedding_device_ids}; -pub use gemm::{GemmBackend, batched_matmul, matmul}; +pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv}; pub use layernorm::layernorm; pub use rmsnorm::{add_rmsnorm, rmsnorm}; pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos}; diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs index 376d89c..933eee4 100644 --- a/crates/xserv-model/src/qwen3.rs +++ b/crates/xserv-model/src/qwen3.rs @@ -923,7 +923,7 @@ impl Qwen3 { let residual = x.clone(); let normed = rmsnorm(&x, &layer.input_norm, eps); - let qkv = matmul_rows_gemv(&normed, &layer.qkv_proj_wt); + let qkv = matmul_batched_gemv(&normed, &layer.qkv_proj_wt); let q_dim = num_heads * head_dim; let kv_dim = num_kv_heads * head_dim; let q_all = qkv.narrow(1, 0, q_dim); @@ -966,25 +966,25 @@ impl Qwen3 { ); let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]); - let attn_proj = matmul_rows_gemv(&attn_merged, &layer.o_proj_wt); + let attn_proj = matmul_batched_gemv(&attn_merged, &layer.o_proj_wt); self.all_reduce(&attn_proj); let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); - let gate_up = matmul_rows_gemv(&normed, &layer.gate_up_proj_wt); + let gate_up = matmul_batched_gemv(&normed, &layer.gate_up_proj_wt); let ffn_dim = gate_up.shape()[1] / 2; let gate = gate_up.narrow(1, 0, ffn_dim).contiguous(); let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous(); let hidden_states = xserv_kernels::silu_mul(&gate, &up); - let down = matmul_rows_gemv(&hidden_states, &layer.down_proj_wt); + let down = matmul_batched_gemv(&hidden_states, &layer.down_proj_wt); self.all_reduce(&down); x = add_any(&residual, &down); } let x = rmsnorm(&x, &self.norm, eps); - matmul_rows_gemv(&x, &self.lm_head_t) + matmul_batched_gemv(&x, &self.lm_head_t) } /// Forward with GPU-resident KV cache and GPU transpose/reshape kernels. @@ -1261,20 +1261,6 @@ fn row_view(t: &Tensor, row: usize) -> Tensor { ) } -/// Run a 2D matmul row by row so each row uses the same GEMV kernel as -/// single-token decode. Used by speculative verify parity, where near-tie -/// logits must follow decode's BF16 rounding path. -fn matmul_rows_gemv(a: &Tensor, b: &Tensor) -> Tensor { - assert_eq!(a.ndim(), 2); - assert!(a.is_contiguous()); - let rows = a.shape()[0]; - if rows == 1 { - return matmul_2d(a, b); - } - let out_rows: Vec = (0..rows).map(|i| matmul_2d(&row_view(a, i), b)).collect(); - concat_rows(&out_rows) -} - /// Concatenate row tensors [1, cols] into a single [B, cols] tensor via D2D memcpy. fn concat_rows(rows: &[Tensor]) -> Tensor { assert!(!rows.is_empty()); diff --git a/csrc/gemm/gemv.cu b/csrc/gemm/gemv.cu index cb32433..1193fb3 100644 --- a/csrc/gemm/gemv.cu +++ b/csrc/gemm/gemv.cu @@ -69,6 +69,62 @@ __global__ void gemv_reduce_to_bf16_kernel( } } +// Batched variant: M rows, same W. Grid.z = batch row index. +// Numerically identical to calling launch_gemv_bf16 M times in sequence because +// each z-slice executes the same accumulation order on the same data. +// partials buffer must be [M * num_k_blocks * N] floats. +__global__ void gemv_bf16_batched_partial_kernel( + const __nv_bfloat16* __restrict__ x, // [M, K] + const __nv_bfloat16* __restrict__ W, // [K, N] + float* __restrict__ partials, // [M, num_k_blocks, N] + int K, int N +) { + const int block_n = blockIdx.x; + const int block_k = blockIdx.y; + const int row = blockIdx.z; + const int t = threadIdx.x; + const int col = block_n * GEMV_TILE_N + t; + + const int k_start = block_k * GEMV_TILE_K; + const int k_end = min(k_start + GEMV_TILE_K, K); + const int k_len = k_end - k_start; + + __shared__ float x_shared[GEMV_TILE_K]; + const __nv_bfloat16* x_row = x + (long long)row * K; + for (int i = t; i < k_len; i += GEMV_BLOCK) { + x_shared[i] = __bfloat162float(x_row[k_start + i]); + } + __syncthreads(); + + if (col >= N) return; + + float sum = 0.0f; + for (int ki = 0; ki < k_len; ki++) { + sum += x_shared[ki] * __bfloat162float(W[(long long)(k_start + ki) * N + col]); + } + + int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K; + partials[((long long)row * num_k_blocks + block_k) * N + col] = sum; +} + +__global__ void gemv_batched_reduce_to_bf16_kernel( + const float* __restrict__ partials, // [M, num_k_blocks, N] + __nv_bfloat16* __restrict__ dst, // [M, N] + int n, + int num_k_blocks +) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + int row = blockIdx.y; + if (col >= n) return; + + float sum = 0.0f; + const float* row_partials = partials + (long long)row * num_k_blocks * n; + for (int kb = 0; kb < num_k_blocks; kb++) { + sum += row_partials[(long long)kb * n + col]; + } + dst[(long long)row * n + col] = __float2bfloat16(sum); +} + extern "C" { void launch_gemv_bf16( @@ -104,4 +160,37 @@ void launch_gemv_bf16( CUDA_CHECK_LAST_ERROR(); } +void launch_gemv_bf16_batched( + const void* x, // [M, K] BF16 + const void* W, // [K, N] BF16 + void* y_bf16, // [M, N] BF16 + void* y_fp32_buf, // [M * num_k_blocks * N] FP32 + int M, int K, int N, + void* stream +) { + cudaStream_t s = (cudaStream_t)stream; + + int num_k_blocks = (K + GEMV_TILE_K - 1) / GEMV_TILE_K; + dim3 grid((N + GEMV_TILE_N - 1) / GEMV_TILE_N, num_k_blocks, M); + + gemv_bf16_batched_partial_kernel<<>>( + (const __nv_bfloat16*)x, + (const __nv_bfloat16*)W, + (float*)y_fp32_buf, + K, N + ); + CUDA_CHECK_LAST_ERROR(); + + int conv_block = 256; + int conv_grid_x = (N + conv_block - 1) / conv_block; + dim3 reduce_grid(conv_grid_x, M); + gemv_batched_reduce_to_bf16_kernel<<>>( + (const float*)y_fp32_buf, + (__nv_bfloat16*)y_bf16, + N, + num_k_blocks + ); + CUDA_CHECK_LAST_ERROR(); +} + } // extern "C"