Add launch_gemv_bf16_batched: runs M single-row (m = 1) GEMVs in a single 3D-grid launch (z = batch row), with output numerically identical to M sequential launch_gemv_bf16 calls — same K-block partial accumulation, same fixed-order reduction. Verified on dash5 with 10 prompts × 32 tokens: matched=true, verify_decode_mismatches=0. Expose it as matmul_batched_gemv(a: [M,K], b: [K,N]) → [M,N] in xserv-kernels. Replace the old matmul_rows_gemv helper in qwen3 forward_verify_paged_decode_attention: the per-row loop over matmul_2d + concat_rows becomes a single matmul_batched_gemv call that allocates the partials buffer in one shot and launches 2 kernels instead of 2×M. Current speedup_e2e is 0.47× (same ballpark as Phase 23's 0.44×); the batched launch saves ~3 ms of overhead, but that is small relative to the 28 ms total spec cost. The path forward (per docs/24 §4) is a higher acceptance rate or a cheaper draft, not further kernel optimization.
473 lines
14 KiB
Rust
473 lines
14 KiB
Rust
use std::cell::RefCell;
|
||
use std::ffi::c_void;
|
||
use xserv_cuda::GpuBuffer;
|
||
use xserv_cuda::error::{self, Result};
|
||
use xserv_tensor::{DType, Device, Tensor};
|
||
|
||
/// Bytes of workspace attached to every cuBLAS handle created by
/// `CublasContext::new` (32 MiB).
const CUBLAS_WORKSPACE_BYTES: usize = 32 << 20;

/// K-dimension tile width used to size the GEMV FP32 partials buffer
/// (see `gemv_scratch_elems`).
/// NOTE(review): must agree with the tile size compiled into the CUDA GEMV
/// kernels — confirm against the kernel source when changing either side.
const GEMV_TILE_K: usize = 256;
|
||
|
||
// Custom BF16 GEMV launchers. Both take a caller-allocated FP32 scratch
// buffer (`y_fp32_buf`) for the per-K-block partial sums; callers in this
// module size it as `gemv_scratch_elems(k, n)` elements per output row.
unsafe extern "C" {
    /// Single-row GEMV: `y_bf16[1, n] = x[1, k] × w[k, n]`, launched on
    /// `stream`. `y_fp32_buf` must hold `gemv_scratch_elems(k, n)` FP32
    /// elements.
    fn launch_gemv_bf16(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );

    /// Batched GEMV over `m` rows: `y_bf16[m, n] = x[m, k] × w[k, n]`,
    /// launched on `stream`. `y_fp32_buf` must hold
    /// `m * gemv_scratch_elems(k, n)` FP32 elements.
    fn launch_gemv_bf16_batched(
        x: *const c_void,
        w: *const c_void,
        y_bf16: *mut c_void,
        y_fp32_buf: *mut c_void,
        m: i32,
        k: i32,
        n: i32,
        stream: *mut c_void,
    );
}
|
||
|
||
/// Kernel family used by [`matmul`] to compute `C = A @ B`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare backends and use
/// them as map keys (e.g. when caching per-backend benchmark results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GemmBackend {
    /// Custom naive kernels (`launch_gemm_naive_f32` / `launch_gemm_naive_bf16`).
    Naive,
    /// Custom tiled kernels (`launch_gemm_tiled_f32` / `launch_gemm_tiled_bf16`).
    Tiled,
    /// cuBLAS `cublasGemmEx` with FP32 accumulation. Single-row BF16 products
    /// with `N >= 256` are routed to the custom GEMV kernel instead.
    CuBlas,
}
|
||
|
||
pub fn gemv_scratch_elems(k: usize, n: usize) -> usize {
|
||
n * k.div_ceil(GEMV_TILE_K)
|
||
}
|
||
|
||
/// Batched GEMV: [M, K] × [K, N] → [M, N], all BF16.
|
||
/// Bit-exact with calling matmul on each row individually (same K-block partial
|
||
/// + fixed-order reduction path), but in a single kernel launch per phase.
|
||
pub fn matmul_batched_gemv(a: &Tensor, b: &Tensor) -> Tensor {
|
||
assert_eq!(a.ndim(), 2);
|
||
assert_eq!(b.ndim(), 2);
|
||
assert!(a.is_contiguous());
|
||
assert!(b.is_contiguous());
|
||
assert_eq!(a.dtype(), DType::BF16);
|
||
assert_eq!(b.dtype(), DType::BF16);
|
||
let m = a.shape()[0];
|
||
let k = a.shape()[1];
|
||
let n = b.shape()[1];
|
||
assert_eq!(b.shape()[0], k);
|
||
|
||
let out = Tensor::empty(&[m, n], DType::BF16, a.device());
|
||
let scratch_elems = m * gemv_scratch_elems(k, n);
|
||
let mut fp32_buf = xserv_cuda::allocator::cached_alloc(scratch_elems * 4).unwrap();
|
||
|
||
let null_stream = xserv_cuda::current_stream_raw();
|
||
if m == 1 {
|
||
unsafe {
|
||
launch_gemv_bf16(
|
||
a.data_ptr() as *const c_void,
|
||
b.data_ptr() as *const c_void,
|
||
out.data_ptr() as *mut c_void,
|
||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||
k as i32,
|
||
n as i32,
|
||
null_stream,
|
||
);
|
||
}
|
||
} else {
|
||
unsafe {
|
||
launch_gemv_bf16_batched(
|
||
a.data_ptr() as *const c_void,
|
||
b.data_ptr() as *const c_void,
|
||
out.data_ptr() as *mut c_void,
|
||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||
m as i32,
|
||
k as i32,
|
||
n as i32,
|
||
null_stream,
|
||
);
|
||
}
|
||
}
|
||
out
|
||
}
|
||
|
||
// --- FFI: custom CUDA GEMM kernels ---
//
// Every launcher computes the row-major product C[m, n] = A[m, k] × B[k, n]
// on `stream`. Note the argument order (m, n, k), which differs from the
// GEMV launchers' (k, n).
unsafe extern "C" {
    /// Naive FP32 GEMM.
    fn launch_gemm_naive_f32(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );

    /// Naive BF16 GEMM.
    fn launch_gemm_naive_bf16(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );

    /// Tiled FP32 GEMM.
    fn launch_gemm_tiled_f32(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );

    /// Tiled BF16 GEMM.
    fn launch_gemm_tiled_bf16(
        a: *const c_void,
        b: *const c_void,
        c: *mut c_void,
        m: i32,
        n: i32,
        k: i32,
        stream: *mut c_void,
    );
}
|
||
|
||
// --- FFI: cuBLAS ---

/// Opaque cuBLAS library handle (`cublasHandle_t`).
pub type CublasHandle = *mut c_void;

/// `cublasOperation_t::CUBLAS_OP_N`: use the operand as stored (no transpose).
const CUBLAS_OP_N: i32 = 0;

// `cudaDataType` values passed as the A/B/C element-type arguments.
/// `CUDA_R_32F`: 32-bit IEEE float.
const CUDA_R_32F: i32 = 0;
/// `CUDA_R_16BF`: 16-bit bfloat16.
const CUDA_R_16BF: i32 = 14;

/// `cublasComputeType_t::CUBLAS_COMPUTE_32F`: accumulate in FP32.
const CUBLAS_COMPUTE_32F: i32 = 68;
|
||
|
||
unsafe extern "C" {
|
||
fn cublasCreate_v2(handle: *mut CublasHandle) -> i32;
|
||
fn cublasDestroy_v2(handle: CublasHandle) -> i32;
|
||
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
|
||
fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32;
|
||
fn cublasGemmEx(
|
||
handle: CublasHandle,
|
||
transa: i32,
|
||
transb: i32,
|
||
m: i32,
|
||
n: i32,
|
||
k: i32,
|
||
alpha: *const c_void,
|
||
a: *const c_void,
|
||
a_type: i32,
|
||
lda: i32,
|
||
b: *const c_void,
|
||
b_type: i32,
|
||
ldb: i32,
|
||
beta: *const c_void,
|
||
c: *mut c_void,
|
||
c_type: i32,
|
||
ldc: i32,
|
||
compute_type: i32,
|
||
algo: i32,
|
||
) -> i32;
|
||
fn cublasGemmStridedBatchedEx(
|
||
handle: CublasHandle,
|
||
transa: i32,
|
||
transb: i32,
|
||
m: i32,
|
||
n: i32,
|
||
k: i32,
|
||
alpha: *const c_void,
|
||
a: *const c_void,
|
||
a_type: i32,
|
||
lda: i32,
|
||
stride_a: i64,
|
||
b: *const c_void,
|
||
b_type: i32,
|
||
ldb: i32,
|
||
stride_b: i64,
|
||
beta: *const c_void,
|
||
c: *mut c_void,
|
||
c_type: i32,
|
||
ldc: i32,
|
||
stride_c: i64,
|
||
batch_count: i32,
|
||
compute_type: i32,
|
||
algo: i32,
|
||
) -> i32;
|
||
}
|
||
|
||
pub struct CublasContext {
|
||
handle: CublasHandle,
|
||
/// Dedicated 32 MiB workspace owned by this handle. Held to keep the GPU
|
||
/// buffer alive for the lifetime of the handle; cuBLAS reads/writes into
|
||
/// it during GEMM. Dropped after `cublasDestroy_v2` so cuBLAS can't touch
|
||
/// freed memory.
|
||
_workspace: Option<GpuBuffer>,
|
||
}
|
||
|
||
impl CublasContext {
|
||
pub fn new() -> Result<Self> {
|
||
let mut handle = std::ptr::null_mut();
|
||
error::check(unsafe { cublasCreate_v2(&mut handle) })?;
|
||
// Attach a per-handle workspace. cublasSetWorkspace requires the
|
||
// pointer to remain valid until destroy or until a new workspace is
|
||
// set, so we keep the GpuBuffer in this struct.
|
||
let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?;
|
||
error::check(unsafe {
|
||
cublasSetWorkspace_v2(
|
||
handle,
|
||
workspace.as_mut_ptr() as *mut c_void,
|
||
CUBLAS_WORKSPACE_BYTES,
|
||
)
|
||
})?;
|
||
Ok(Self {
|
||
handle,
|
||
_workspace: Some(workspace),
|
||
})
|
||
}
|
||
}
|
||
|
||
impl Drop for CublasContext {
|
||
fn drop(&mut self) {
|
||
if !self.handle.is_null() {
|
||
unsafe { cublasDestroy_v2(self.handle) };
|
||
}
|
||
// _workspace drops here, after cublasDestroy_v2 has released the handle.
|
||
}
|
||
}
|
||
|
||
thread_local! {
|
||
static CUBLAS_CTX: RefCell<CublasContext> = RefCell::new(
|
||
CublasContext::new().expect("failed to create thread-local cuBLAS handle")
|
||
);
|
||
}
|
||
|
||
/// Borrow the thread-local cuBLAS handle for the duration of a closure.
|
||
fn with_cublas<F, R>(f: F) -> R
|
||
where
|
||
F: FnOnce(CublasHandle) -> R,
|
||
{
|
||
CUBLAS_CTX.with(|cell| {
|
||
let ctx = cell.borrow();
|
||
f(ctx.handle)
|
||
})
|
||
}
|
||
|
||
/// Get the thread-local cuBLAS handle for use with dispatch module.
|
||
pub fn cublas_handle() -> CublasHandle {
|
||
CUBLAS_CTX.with(|cell| cell.borrow().handle)
|
||
}
|
||
|
||
/// Matrix multiplication: C = A @ B
|
||
/// A: [M, K], B: [K, N], C: [M, N]
|
||
/// All tensors must be contiguous and on the same GPU.
|
||
pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||
assert_eq!(a.ndim(), 2);
|
||
assert_eq!(b.ndim(), 2);
|
||
assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch");
|
||
assert_eq!(a.dtype(), b.dtype(), "dtype mismatch");
|
||
assert!(
|
||
a.is_contiguous() && b.is_contiguous(),
|
||
"matmul requires contiguous tensors"
|
||
);
|
||
assert!(
|
||
matches!(a.device(), Device::Cuda(_)),
|
||
"matmul requires GPU tensors"
|
||
);
|
||
|
||
let m = a.shape()[0];
|
||
let k = a.shape()[1];
|
||
let n = b.shape()[1];
|
||
let dtype = a.dtype();
|
||
|
||
// All backends (naive, tiled, cuBLAS with beta=0, custom GEMV) fully
|
||
// overwrite every element of C, so we skip the cudaMemset.
|
||
let c = Tensor::empty(&[m, n], dtype, a.device());
|
||
|
||
let a_ptr = a.data_ptr() as *const c_void;
|
||
let b_ptr = b.data_ptr() as *const c_void;
|
||
let c_ptr = c.data_ptr() as *mut c_void;
|
||
let null_stream = xserv_cuda::current_stream_raw();
|
||
|
||
match backend {
|
||
GemmBackend::Naive => unsafe {
|
||
match dtype {
|
||
DType::F32 => launch_gemm_naive_f32(
|
||
a_ptr,
|
||
b_ptr,
|
||
c_ptr,
|
||
m as i32,
|
||
n as i32,
|
||
k as i32,
|
||
null_stream,
|
||
),
|
||
DType::BF16 => launch_gemm_naive_bf16(
|
||
a_ptr,
|
||
b_ptr,
|
||
c_ptr,
|
||
m as i32,
|
||
n as i32,
|
||
k as i32,
|
||
null_stream,
|
||
),
|
||
_ => panic!("unsupported dtype for naive GEMM"),
|
||
}
|
||
},
|
||
GemmBackend::Tiled => unsafe {
|
||
match dtype {
|
||
DType::F32 => launch_gemm_tiled_f32(
|
||
a_ptr,
|
||
b_ptr,
|
||
c_ptr,
|
||
m as i32,
|
||
n as i32,
|
||
k as i32,
|
||
null_stream,
|
||
),
|
||
DType::BF16 => launch_gemm_tiled_bf16(
|
||
a_ptr,
|
||
b_ptr,
|
||
c_ptr,
|
||
m as i32,
|
||
n as i32,
|
||
k as i32,
|
||
null_stream,
|
||
),
|
||
_ => panic!("unsupported dtype for tiled GEMM"),
|
||
}
|
||
},
|
||
GemmBackend::CuBlas => {
|
||
if m == 1 && dtype == DType::BF16 && n >= 256 {
|
||
let mut fp32_buf =
|
||
xserv_cuda::allocator::cached_alloc(gemv_scratch_elems(k, n) * 4).unwrap();
|
||
unsafe {
|
||
launch_gemv_bf16(
|
||
a_ptr,
|
||
b_ptr,
|
||
c_ptr,
|
||
fp32_buf.as_mut_ptr() as *mut c_void,
|
||
k as i32,
|
||
n as i32,
|
||
null_stream,
|
||
);
|
||
}
|
||
} else {
|
||
let alpha = 1.0f32;
|
||
let beta = 0.0f32;
|
||
|
||
let (a_type, b_type, c_type) = match dtype {
|
||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||
_ => panic!("unsupported dtype for cuBLAS GEMM"),
|
||
};
|
||
|
||
with_cublas(|handle| unsafe {
|
||
cublasSetStream_v2(handle, null_stream);
|
||
error::check(cublasGemmEx(
|
||
handle,
|
||
CUBLAS_OP_N,
|
||
CUBLAS_OP_N,
|
||
n as i32,
|
||
m as i32,
|
||
k as i32,
|
||
&alpha as *const f32 as *const c_void,
|
||
b_ptr,
|
||
b_type,
|
||
n as i32,
|
||
a_ptr,
|
||
a_type,
|
||
k as i32,
|
||
&beta as *const f32 as *const c_void,
|
||
c_ptr,
|
||
c_type,
|
||
n as i32,
|
||
CUBLAS_COMPUTE_32F,
|
||
-1,
|
||
))
|
||
.expect("cuBLAS GEMM failed");
|
||
});
|
||
}
|
||
}
|
||
}
|
||
|
||
c
|
||
}
|
||
|
||
/// Batched matrix multiplication via cuBLAS: C[b] = A[b] @ B[b]
|
||
/// a: [..., M, K], b: [..., K, N] → [..., M, N]
|
||
/// Leading dimensions must match and tensors must be contiguous.
|
||
pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||
assert!(a.ndim() >= 2 && b.ndim() >= 2);
|
||
assert_eq!(a.ndim(), b.ndim());
|
||
assert!(a.is_contiguous() && b.is_contiguous());
|
||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||
assert_eq!(a.dtype(), b.dtype());
|
||
|
||
let ndim = a.ndim();
|
||
let m = a.shape()[ndim - 2];
|
||
let k = a.shape()[ndim - 1];
|
||
let n = b.shape()[ndim - 1];
|
||
assert_eq!(b.shape()[ndim - 2], k, "inner dimension mismatch");
|
||
|
||
// Compute batch count from leading dimensions
|
||
let batch: usize = a.shape()[..ndim - 2].iter().product();
|
||
assert_eq!(
|
||
b.shape()[..ndim - 2].iter().product::<usize>(),
|
||
batch,
|
||
"batch dimensions mismatch"
|
||
);
|
||
|
||
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
|
||
out_shape.push(m);
|
||
out_shape.push(n);
|
||
// cuBLAS with beta=0 fully overwrites every element of C.
|
||
let c = Tensor::empty(&out_shape, a.dtype(), a.device());
|
||
|
||
let dtype = a.dtype();
|
||
let (a_type, b_type, c_type) = match dtype {
|
||
DType::F32 => (CUDA_R_32F, CUDA_R_32F, CUDA_R_32F),
|
||
DType::BF16 => (CUDA_R_16BF, CUDA_R_16BF, CUDA_R_16BF),
|
||
_ => panic!("unsupported dtype for batched matmul"),
|
||
};
|
||
|
||
let alpha = 1.0f32;
|
||
let beta = 0.0f32;
|
||
// cuBLAS strides are in elements (not bytes)
|
||
let stride_a = (m * k) as i64;
|
||
let stride_b = (k * n) as i64;
|
||
let stride_c = (m * n) as i64;
|
||
|
||
with_cublas(|handle| unsafe {
|
||
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||
error::check(cublasGemmStridedBatchedEx(
|
||
handle,
|
||
CUBLAS_OP_N,
|
||
CUBLAS_OP_N,
|
||
n as i32,
|
||
m as i32,
|
||
k as i32,
|
||
&alpha as *const f32 as *const c_void,
|
||
b.data_ptr() as _,
|
||
b_type,
|
||
n as i32,
|
||
stride_b,
|
||
a.data_ptr() as _,
|
||
a_type,
|
||
k as i32,
|
||
stride_a,
|
||
&beta as *const f32 as *const c_void,
|
||
c.data_ptr() as *mut c_void,
|
||
c_type,
|
||
n as i32,
|
||
stride_c,
|
||
batch as i32,
|
||
CUBLAS_COMPUTE_32F,
|
||
-1,
|
||
))
|
||
.expect("cuBLAS batched GEMM failed");
|
||
});
|
||
c
|
||
}
|