Files
xserv/crates/xserv-kernels/src/gemm.rs
Gahow Wang e5734b41fa speculative: batched-GEMV kernel for verify path (Phase 24 step 1)
Add launch_gemv_bf16_batched: runs M single-row (m=1) GEMVs in one 3D-grid
launch (grid z = batch row), producing output numerically identical to M
sequential launch_gemv_bf16 calls — the same K-block partial accumulation and
the same fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens:
matched=true, verify_decode_mismatches=0.

Expose as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in
xserv-kernels. Replace the old matmul_rows_gemv helper in qwen3
forward_verify_paged_decode_attention; the per-row loop over matmul_2d +
concat_rows is replaced by a single matmul_batched_gemv call that
allocates the partials buffer in one shot and launches 2 kernels instead
of 2*M.

Current speedup_e2e is 0.47× (same ballpark as Phase 23's 0.44×);
the batched launch saves ~3 ms of launch overhead, but that is small
relative to the total 28 ms speculative-decoding cost. The path forward
(per docs/24 §4) is a higher acceptance rate or a cheaper draft model,
not further kernel optimization.
2026-07-01 16:13:37 +08:00

473 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::cell::RefCell;
use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_cuda::error::{self, Result};
use xserv_tensor::{DType, Device, Tensor};
/// Size in bytes (32 MiB) of the dedicated workspace that
/// `CublasContext::new` allocates and attaches to each cuBLAS handle.
const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
/// Width of the K-dimension blocks used to size the GEMV FP32 partials
/// scratch buffer (see `gemv_scratch_elems`).
/// NOTE(review): presumably must equal the K tile of the CUDA GEMV kernels —
/// confirm against the kernel source before changing.
const GEMV_TILE_K: usize = 256;
// GEMV FFI (custom CUDA kernels). Note that these launchers DO take an FP32
// scratch buffer, `y_fp32_buf`, for the K-block partial sums: the callers in
// this file size it as `gemv_scratch_elems(k, n)` FP32 elements (times `m`
// for the batched variant).
unsafe extern "C" {
    /// Single-row BF16 GEMV: y[1, N] = x[1, K] @ w[K, N].
    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );
    /// `m` single-row BF16 GEMVs in one launch: y[M, N] = x[M, K] @ w[K, N].
    /// Documented (see `matmul_batched_gemv`) as bit-exact with `m`
    /// sequential `launch_gemv_bf16` calls.
    fn launch_gemv_bf16_batched(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        m: i32,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );
}
/// Kernel family used by `matmul`.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare or key on the
/// selected backend (e.g. in configs or caches) without matching by hand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GemmBackend {
    /// Custom naive CUDA GEMM kernel (F32 and BF16 only).
    Naive,
    /// Custom tiled CUDA GEMM kernel (F32 and BF16 only).
    Tiled,
    /// `cublasGemmEx`, except single-row BF16 products with `n >= 256`, which
    /// `matmul` routes to the custom GEMV kernel.
    CuBlas,
}
/// Number of FP32 scratch elements one GEMV row needs: `n` partial sums for
/// every `GEMV_TILE_K`-wide block of the K dimension, with the block count
/// rounded up so a trailing partial block still gets its own partials.
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
    let k_blocks = k.div_ceil(GEMV_TILE_K);
    k_blocks * n
}
/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16.
/// Bit-exact with calling matmul on each row individually (same K-block partial
/// + fixed-order reduction path), but in a single kernel launch per phase.
pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);
assert!(a.is_contiguous());
assert!(b.is_contiguous());
assert_eq!(a.dtype(), DType::BF16);
assert_eq!(b.dtype(), DType::BF16);
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
assert_eq!(b.shape()[0], k);
let out = Tensor::empty(&[m, n], DType::BF16, a.device());
let scratch_elems = m * gemv_scratch_elems(k, n);
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap();
let null_stream = xserv_cuda::current_stream_raw();
if m == 1 {
unsafe {
launch_gemv_bf16(
a.data_ptr() as *const c_void,
b.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
fp32_buf.as_mut_ptr() as *mut c_void,
k as i32,
n as i32,
null_stream,
);
}
} else {
unsafe {
launch_gemv_bf16_batched(
a.data_ptr() as *const c_void,
b.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
fp32_buf.as_mut_ptr() as *mut c_void,
m as i32,
k as i32,
n as i32,
null_stream,
);
}
}
out
}
// --- FFI: custom CUDA kernels ---
// GEMM launchers computing C[M, N] = A[M, K] @ B[K, N] on `stream`, with the
// same contiguous-tensor layout that `matmul` passes to every backend.
// NOTE: the dimension order here is (m, n, k), unlike the GEMV launchers,
// which take (k, n) / (m, k, n).
unsafe extern "C" {
    /// Naive GEMM kernel, F32 in/out.
    fn launch_gemm_naive_f32(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );
    /// Naive GEMM kernel, BF16 in/out.
    fn launch_gemm_naive_bf16(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );
    /// Tiled GEMM kernel, F32 in/out.
    fn launch_gemm_tiled_f32(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );
    /// Tiled GEMM kernel, BF16 in/out.
    fn launch_gemm_tiled_bf16(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );
}
// --- FFI: cuBLAS ---
/// Opaque cuBLAS library handle (`cublasHandle_t`).
pub type CublasHandle = *mut c_void;
/// `cublasOperation_t::CUBLAS_OP_N`: use the operand as-is (no transpose).
const CUBLAS_OP_N: i32 = 0;
// `cudaDataType_t` values.
/// `CUDA_R_32F`: real 32-bit IEEE float.
const CUDA_R_32F: i32 = 0;
/// `CUDA_R_16BF`: real bfloat16.
const CUDA_R_16BF: i32 = 14;
// `cublasComputeType_t` values.
/// `CUBLAS_COMPUTE_32F`: compute/accumulate in 32-bit float.
const CUBLAS_COMPUTE_32F: i32 = 68;
// Raw cuBLAS entry points. Each returns a `cublasStatus_t` as i32
// (0 = CUBLAS_STATUS_SUCCESS).
unsafe extern "C" {
    /// Create a cuBLAS handle, written through `handle`.
    fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
    /// Destroy a handle previously created by `cublasCreate_v2`.
    fn cublasDestroy_v2(handle: CublasHandle) -> i32;
    /// Make subsequent calls on `handle` run on the given CUDA stream.
    fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
    /// Attach a user-owned device workspace of `size` bytes to `handle`. The
    /// memory must stay valid until the handle is destroyed or a different
    /// workspace is set.
    fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
    /// Column-major `C = alpha * op(A) @ op(B) + beta * C` with per-operand
    /// data types (`*_type`) and a separate `compute_type`. `alpha`/`beta`
    /// point to scalars of the compute type. The callers in this file pass
    /// `algo = -1` (`CUBLAS_GEMM_DEFAULT`).
    fn cublasGemmEx(
        handle: CublasHandle,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: *const c_void,
        a: *const c_void,
        a_type: i32,
        lda: i32,
        b: *const c_void,
        b_type: i32,
        ldb: i32,
        beta: *const c_void,
        c: *mut c_void,
        c_type: i32,
        ldc: i32,
        compute_type: i32,
        algo: i32,
    ) -> i32;
    /// Strided-batched variant of `cublasGemmEx`: matrix `i` of each operand
    /// starts `i * stride_*` elements (not bytes) after its base pointer, for
    /// `batch_count` independent GEMMs.
    fn cublasGemmStridedBatchedEx(
        handle: CublasHandle,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: *const c_void,
        a: *const c_void,
        a_type: i32,
        lda: i32,
        stride_a: i64,
        b: *const c_void,
        b_type: i32,
        ldb: i32,
        stride_b: i64,
        beta: *const c_void,
        c: *mut c_void,
        c_type: i32,
        ldc: i32,
        stride_c: i64,
        batch_count: i32,
        compute_type: i32,
        algo: i32,
    ) -> i32;
}
/// Owns one cuBLAS handle together with the device workspace attached to it.
///
/// The handle is destroyed in `Drop::drop`; the workspace field is dropped only
/// afterwards (struct fields drop after `Drop::drop` returns), so cuBLAS never
/// holds a pointer to freed workspace memory.
pub struct CublasContext {
    handle: CublasHandle,
    /// Dedicated `CUBLAS_WORKSPACE_BYTES` (32 MiB) workspace owned by this
    /// handle. Held only to keep the GPU buffer alive while the handle may
    /// read/write it. `None` only transiently inside `new`.
    _workspace: Option<GpuBuffer>,
}

impl CublasContext {
    /// Create a cuBLAS handle and attach a dedicated workspace to it.
    ///
    /// # Errors
    /// Returns an error if handle creation, workspace allocation, or
    /// `cublasSetWorkspace_v2` fails. On every failure after the handle has
    /// been created, the handle is destroyed rather than leaked.
    pub fn new() -> Result<Self> {
        let mut handle = std::ptr::null_mut();
        error::check(unsafe { cublasCreate_v2(&mut handle) })?;
        // Take ownership right away so `Drop` destroys the handle if any of the
        // fallible steps below bails out through `?`.
        let mut ctx = Self {
            handle,
            _workspace: None,
        };
        // cublasSetWorkspace requires the pointer to remain valid until the
        // handle is destroyed or a new workspace is set, so park the buffer in
        // `ctx` *before* handing its pointer to cuBLAS.
        let workspace = ctx
            ._workspace
            .insert(GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?);
        let workspace_ptr = workspace.as_mut_ptr() as *mut c_void;
        error::check(unsafe {
            cublasSetWorkspace_v2(ctx.handle, workspace_ptr, CUBLAS_WORKSPACE_BYTES)
        })?;
        Ok(ctx)
    }
}

impl Drop for CublasContext {
    fn drop(&mut self) {
        if !self.handle.is_null() {
            // SAFETY: `handle` came from a successful `cublasCreate_v2` and is
            // destroyed exactly once, here. The returned status is ignored
            // because `drop` has no way to report it.
            unsafe { cublasDestroy_v2(self.handle) };
        }
        // `_workspace` drops after this function returns, i.e. only once the
        // handle that referenced it has been destroyed.
    }
}
thread_local! {
    /// This thread's cuBLAS context. The initializer runs lazily on the first
    /// access from each thread and panics if the context cannot be created;
    /// the context's `Drop` runs when the thread's thread-local storage is
    /// torn down.
    static CUBLAS_CTX: RefCell<CublasContext> = {
        let ctx = CublasContext::new().expect("failed to create thread-local cuBLAS handle");
        RefCell::new(ctx)
    };
}
/// Run `f` with this thread's cuBLAS handle, keeping the thread-local context
/// immutably borrowed for the whole call.
fn with_cublas<F, R>(f: F) -> R
where
    F: FnOnce(CublasHandle) -> R,
{
    CUBLAS_CTX.with_borrow(|ctx| f(ctx.handle))
}
/// Return this thread's cuBLAS handle, for use by the dispatch module.
///
/// The handle is owned by the thread-local context. Use it only on the calling
/// thread, and not after that thread's context has been dropped.
pub fn cublas_handle() -> CublasHandle {
    CUBLAS_CTX.with_borrow(|ctx| ctx.handle)
}
/// Matrix multiplication: C = A @ B
/// A: [M, K], B: [K, N], C: [M, N]
/// All tensors must be contiguous and on the same GPU.
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
assert!(
a.is_contiguous() && b.is_contiguous(),
"matmul requires contiguous tensors"
);
assert!(
matches!(a.device(), Device::Cuda(_)),
"matmul requires GPU tensors"
);
let m = a.shape()[0];
let k = a.shape()[1];
let n = b.shape()[1];
let dtype = a.dtype();
// All backends (naive, tiled, cuBLAS with beta=0, custom GEMV) fully
// overwrite every element of C, so we skip the cudaMemset.
let c = Tensor::empty(&[m, n], dtype, a.device());
let a_ptr = a.data_ptr() as *const c_void;
let b_ptr = b.data_ptr() as *const c_void;
let c_ptr = c.data_ptr() as *mut c_void;
let null_stream = xserv_cuda::current_stream_raw();
match backend {
GemmBackend::Naive => unsafe {
match dtype {
DType::F32 => launch_gemm_naive_f32(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
DType::BF16 => launch_gemm_naive_bf16(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
_ => panic!("unsupported dtype for naive GEMM"),
}
},
GemmBackend::Tiled => unsafe {
match dtype {
DType::F32 => launch_gemm_tiled_f32(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
DType::BF16 => launch_gemm_tiled_bf16(
a_ptr,
b_ptr,
c_ptr,
m as i32,
n as i32,
k as i32,
null_stream,
),
_ => panic!("unsupported dtype for tiled GEMM"),
}
},
GemmBackend::CuBlas => {
if m == 1 && dtype == DType::BF16 && n >= 256 {
let mut fp32_buf =
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
unsafe {
launch_gemv_bf16(
a_ptr,
b_ptr,
c_ptr,
fp32_buf.as_mut_ptr() as *mut c_void,
k as i32,
n as i32,
null_stream,
);
}
} else {
let alpha = 1.0f32;
let beta = 0.0f32;
let (a_type, b_type, c_type) = match dtype {
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
_ => panic!("unsupported dtype for cuBLAS GEMM"),
};
with_cublas(|handle| unsafe {
cublasSetStream_v2(handle, null_stream);
error::check(cublasGemmEx(
handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b_ptr,
b_type,
n as i32,
a_ptr,
a_type,
k as i32,
&beta as *const f32 as *const c_void,
c_ptr,
c_type,
n as i32,
CUBLAS_COMPUTE_32F,
-1,
))
.expect("cuBLAS GEMM failed");
});
}
}
}
c
}
/// Batched matrix multiplication via cuBLAS: C[b] = A[b] @ B[b]
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
/// Leading dimensions must match and tensors must be contiguous.
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
assert!(a.ndim() >= 2 && b.ndim() >= 2);
assert_eq!(a.ndim(), b.ndim());
assert!(a.is_contiguous() && b.is_contiguous());
assert!(matches!(a.device(), Device::Cuda(_)));
assert_eq!(a.dtype(), b.dtype());
let ndim = a.ndim();
let m = a.shape()[ndim - 2];
let k = a.shape()[ndim - 1];
let n = b.shape()[ndim - 1];
assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch");
// Compute batch count from leading dimensions
let batch: usize = a.shape()[..ndim - 2].iter().product();
assert_eq!(
b.shape()[..ndim - 2].iter().product::<usize>(),
batch,
"batch dimensions mismatch"
);
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
out_shape.push(m);
out_shape.push(n);
// cuBLAS with beta=0 fully overwrites every element of C.
let c = Tensor::empty(&out_shape, a.dtype(), a.device());
let dtype = a.dtype();
let (a_type, b_type, c_type) = match dtype {
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
_ => panic!("unsupported dtype for batched matmul"),
};
let alpha = 1.0f32;
let beta = 0.0f32;
// cuBLAS strides are in elements (not bytes)
let stride_a = (m * k) as i64;
let stride_b = (k * n) as i64;
let stride_c = (m * n) as i64;
with_cublas(|handle| unsafe {
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
error::check(cublasGemmStridedBatchedEx(
handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
n as i32,
m as i32,
k as i32,
&alpha as *const f32 as *const c_void,
b.data_ptr() as _,
b_type,
n as i32,
stride_b,
a.data_ptr() as _,
a_type,
k as i32,
stride_a,
&beta as *const f32 as *const c_void,
c.data_ptr() as *mut c_void,
c_type,
n as i32,
stride_c,
batch as i32,
CUBLAS_COMPUTE_32F,
-1,
))
.expect("cuBLAS batched GEMM failed");
});
c
}