quantization: cache cuBLASLt FP8 plan per shape — fix per-expert heuristic churn
batched_gemm_fp8 rebuilt the cuBLASLt matmul descriptor, four matrix layouts, a preference, and a 4-byte scale alloc, AND ran the algo heuristic search — once per expert, per GEMM, per layer, on every forward (~1500 heuristic searches per decoded token). FP8 decode ran at 27.0 ms/tok vs BF16 18.8 ms, i.e. slower than the path it was meant to accelerate. Cache the full plan (descriptor + layouts + heuristically-chosen algo) in a thread-local map keyed by (M, N, K) so the heuristic runs once per shape and is reused across experts and forwards; allocate the 1.0 scale buffer once; pass each expert's weight scale via the cuBLASLt B-scale device pointer instead of folding it into alpha (identical FP32-epilogue precision, and no host readback of b_scales). The per-expert loop now issues only cublasLtMatmul. Measured on dash5 (gpt-oss-20b, TP=2, 5090): FP8 decode TPOT 27.0 -> 17.9 ms, now faster than BF16 (18.8 ms); GSM8K-200 accuracy unchanged (FP8 93.0% vs BF16 90.5%, within noise). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use xserv_cuda::GpuBuffer;
|
||||
use xserv_tensor::{DType, Tensor};
|
||||
@@ -113,9 +114,33 @@ const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;
|
||||
|
||||
const WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
|
||||
|
||||
const CUBLASLT_MATMUL_DESC_TRANSA: i32 = 3;
|
||||
|
||||
/// A fully-prepared FP8 matmul plan for one (M, N, K) shape: the matmul
|
||||
/// descriptor, the four matrix layouts, and the heuristically-chosen algo.
|
||||
/// Built once per shape and reused across every expert and every forward
|
||||
/// pass — the heuristic search and descriptor/layout creation are the
|
||||
/// expensive parts, so doing them once instead of per-expert-per-layer is
|
||||
/// the difference between FP8 being faster or slower than BF16.
|
||||
#[derive(Clone, Copy)]
|
||||
struct Fp8Plan {
|
||||
desc: CublasLtMatmulDesc,
|
||||
a_layout: CublasLtMatrixLayout,
|
||||
b_layout: CublasLtMatrixLayout,
|
||||
c_layout: CublasLtMatrixLayout,
|
||||
d_layout: CublasLtMatrixLayout,
|
||||
algo: CublasLtMatmulAlgo,
|
||||
workspace_size: usize,
|
||||
}
|
||||
|
||||
struct CublasLtContext {
|
||||
handle: CublasLtHandle,
|
||||
workspace: GpuBuffer,
|
||||
/// Persistent device scalar holding 1.0, used as the A/B scale pointer
|
||||
/// placeholder. Allocated once instead of per-expert.
|
||||
one_buf: GpuBuffer,
|
||||
/// Cache of prepared matmul plans keyed by (M, N, K).
|
||||
plans: HashMap<(usize, usize, usize), Fp8Plan>,
|
||||
}
|
||||
|
||||
impl CublasLtContext {
|
||||
@@ -124,18 +149,100 @@ impl CublasLtContext {
|
||||
let status = unsafe { cublasLtCreate(&mut handle) };
|
||||
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
|
||||
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
|
||||
Self { handle, workspace }
|
||||
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
|
||||
one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale");
|
||||
Self { handle, workspace, one_buf, plans: HashMap::new() }
|
||||
}
|
||||
|
||||
/// Get the cached plan for (m, n, k), building (and caching) it on first use.
|
||||
fn plan(&mut self, m: usize, n: usize, k: usize) -> Fp8Plan {
|
||||
if let Some(p) = self.plans.get(&(m, n, k)) {
|
||||
return *p;
|
||||
}
|
||||
let one_ptr = self.one_buf.as_ptr() as *const c_void;
|
||||
let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k) };
|
||||
self.plans.insert((m, n, k), plan);
|
||||
plan
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CublasLtContext {
|
||||
fn drop(&mut self) {
|
||||
// Tear down cached plans before destroying the handle.
|
||||
for (_, p) in self.plans.drain() {
|
||||
unsafe {
|
||||
cublasLtMatrixLayoutDestroy(p.a_layout);
|
||||
cublasLtMatrixLayoutDestroy(p.b_layout);
|
||||
cublasLtMatrixLayoutDestroy(p.c_layout);
|
||||
cublasLtMatrixLayoutDestroy(p.d_layout);
|
||||
cublasLtMatmulDescDestroy(p.desc);
|
||||
}
|
||||
}
|
||||
if !self.handle.is_null() {
|
||||
unsafe { cublasLtDestroy(self.handle) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an FP8 matmul plan for one (m, n, k) shape. See `batched_gemm_fp8`
|
||||
/// for the row-major → cuBLASLt col-major layout mapping (transA=T, transB=N,
|
||||
/// m_lt=N, n_lt=M, k_lt=K). The B-scale pointer is initialised to `one_ptr`
|
||||
/// and overwritten per-expert at call time.
|
||||
unsafe fn build_fp8_plan(
|
||||
handle: CublasLtHandle,
|
||||
one_ptr: *const c_void,
|
||||
m: usize,
|
||||
n: usize,
|
||||
k: usize,
|
||||
) -> Fp8Plan {
|
||||
let m_lt = n as u64;
|
||||
let n_lt = m as u64;
|
||||
let k_lt = k as u64;
|
||||
|
||||
let mut desc: CublasLtMatmulDesc = std::ptr::null_mut();
|
||||
cublasLtMatmulDescCreate(&mut desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
|
||||
|
||||
// transA=T (required for FP8 on Blackwell)
|
||||
let trans_a: i32 = 1;
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4);
|
||||
let ptr_sz = std::mem::size_of::<*const c_void>();
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
|
||||
|
||||
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
|
||||
// "B" layout (activations): physical (K, M) col-major, ld=K
|
||||
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
|
||||
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
|
||||
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||
|
||||
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||
let ws_bytes = WORKSPACE_BYTES as u64;
|
||||
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
|
||||
|
||||
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
|
||||
let mut found: i32 = 0;
|
||||
let status = cublasLtMatmulAlgoGetHeuristic(
|
||||
handle, desc, a_layout, b_layout, c_layout, d_layout,
|
||||
pref, 1, &mut heuristic, &mut found,
|
||||
);
|
||||
assert!(status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM (m={m}, n={n}, k={k}): status={status}, found={found}");
|
||||
cublasLtMatmulPreferenceDestroy(pref);
|
||||
|
||||
Fp8Plan {
|
||||
desc, a_layout, b_layout, c_layout, d_layout,
|
||||
algo: heuristic.algo,
|
||||
workspace_size: heuristic.workspace_size,
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CUBLASLT_CTX: RefCell<CublasLtContext> = RefCell::new(CublasLtContext::new());
|
||||
}
|
||||
@@ -215,9 +322,9 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
||||
/// as [b, N, K] for cuBLASLt FP8 compatibility.
|
||||
///
|
||||
/// a_fp8: [batch, M, K] FP8E4M3 (activations, quantized per-row)
|
||||
/// a_scales: [batch * M] F32 (per-token scales, collapsed to per-batch max)
|
||||
/// a_scales: [batch * M] F32 (per-token activation scales, applied post-GEMM)
|
||||
/// b_fp8_t: [batch, N, K] FP8E4M3 (weights, TRANSPOSED for cuBLASLt)
|
||||
/// b_scales: [batch] F32 (per-expert scalar scales)
|
||||
/// b_scales: [batch] F32 (per-expert scalar weight scales, applied in-GEMM)
|
||||
///
|
||||
/// Returns: [batch, M, N] BF16
|
||||
pub fn batched_gemm_fp8(
|
||||
@@ -240,127 +347,64 @@ pub fn batched_gemm_fp8(
|
||||
let k = a_fp8.shape()[2]; // hidden
|
||||
let n = b_fp8_t.shape()[1]; // out_dim (from transposed weight)
|
||||
|
||||
// Per-token scales → per-expert scales (max over tokens within each expert batch)
|
||||
// a_scales: [batch * M] → we take the max per expert to get [batch] scalar scales
|
||||
// This is a slight accuracy tradeoff vs per-token, but allows scalar GEMM scale mode.
|
||||
// a_scales: [batch * M] per-token activation scales (applied post-GEMM, per row).
|
||||
// b_scales: [batch] per-expert scalar weight scales (applied in-GEMM via B-scale ptr).
|
||||
assert_eq!(a_scales.shape()[0], batch * m);
|
||||
assert_eq!(b_scales.shape()[0], batch);
|
||||
|
||||
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
|
||||
|
||||
// Read weight scales to host for the per-expert loop
|
||||
let b_scales_cpu = b_scales.to_device(xserv_tensor::Device::Cpu);
|
||||
let b_s_data = b_scales_cpu.as_slice::<f32>();
|
||||
// Strides (in bytes) for one expert slice
|
||||
let stride_a = m * k; // FP8: 1 byte per elem
|
||||
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
|
||||
let stride_c = m * n * 2; // BF16: 2 bytes per elem
|
||||
|
||||
CUBLASLT_CTX.with(|cell| {
|
||||
let ctx = cell.borrow();
|
||||
let mut ctx = cell.borrow_mut();
|
||||
let handle = ctx.handle;
|
||||
let ws_ptr = ctx.workspace.as_ptr() as *mut c_void;
|
||||
// Build (or fetch) the cached plan for this shape — heuristic search and
|
||||
// descriptor/layout creation happen once per (m, n, k), not per-expert.
|
||||
let plan = ctx.plan(m, n, k);
|
||||
|
||||
// Strides (in bytes) for one expert slice
|
||||
let stride_a = m * k; // FP8: 1 byte per elem
|
||||
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
|
||||
let stride_c = m * n * 2; // BF16: 2 bytes per elem
|
||||
// alpha=1, beta=0. Per-expert weight scale is supplied via the cuBLASLt
|
||||
// B-scale pointer (device, scalar): cuBLASLt computes in the FP32 epilogue
|
||||
// D = (1.0 * A_fp8) @ (b_scale[e] * B_fp8)^T = b_scale[e] * (A_fp8 @ B_fp8^T)
|
||||
// Per-token activation scale (a_scale) is applied post-GEMM (per row).
|
||||
let alpha: f32 = 1.0;
|
||||
let beta: f32 = 0.0;
|
||||
let ptr_sz = std::mem::size_of::<*const c_void>();
|
||||
|
||||
for e in 0..batch {
|
||||
let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void };
|
||||
let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void };
|
||||
let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void };
|
||||
|
||||
// alpha = b_scale (weight scale). Per-row activation scale applied post-GEMM.
|
||||
// GEMM computes: D = alpha * (A_fp8 @ B_fp8_T)
|
||||
// = b_scale * ((A_real / a_scale_row) @ (B_real / b_scale))
|
||||
// = (A_real / a_scale_row) @ B_real
|
||||
// Post-multiply row i by a_scale[i] to recover the correct result.
|
||||
let alpha: f32 = b_s_data[e];
|
||||
let beta: f32 = 0.0;
|
||||
// Device pointer to this expert's scalar weight scale (FP32, 4 bytes).
|
||||
let b_scale_ptr = unsafe { (b_scales.data_ptr() as *const u8).add(e * 4) as *const c_void };
|
||||
|
||||
unsafe {
|
||||
// cuBLASLt FP8 on Blackwell requires transA=T, transB=N.
|
||||
// cuBLASLt computes: D(m,n) = op(A)(m,k) * B(k,n) with transA=T
|
||||
//
|
||||
// We want: D_row[M,N] = A_act_row[M,K] @ B_wt_row[K,N]
|
||||
// Map to cuBLASLt with m_lt=N, n_lt=M, k_lt=K:
|
||||
// "A" (transA=T): stored as (K, N) col-major ld=K → transposed to (N, K)
|
||||
// Our weights are stored TRANSPOSED as [E, N, K] row-major = col-major (K, N) ld=K ✓
|
||||
// "B" (transB=N): stored as (K, M) col-major ld=K
|
||||
// Our activations A_act_row[M,K] = col-major (K, M) ld=K ✓
|
||||
// "D": stored as (N, M) col-major ld=N
|
||||
// Our output D_row[M,N] = col-major (N, M) ld=N ✓
|
||||
let m_lt = n as u64;
|
||||
let n_lt = m as u64;
|
||||
let k_lt = k as u64;
|
||||
|
||||
let mut matmul_desc: CublasLtMatmulDesc = std::ptr::null_mut();
|
||||
cublasLtMatmulDescCreate(&mut matmul_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
|
||||
|
||||
// Set transA=T (required for FP8 on Blackwell)
|
||||
let trans_a: i32 = 1; // CUBLAS_OP_T
|
||||
cublasLtMatmulDescSetAttribute(matmul_desc, 3 /*TRANSA*/, &trans_a as *const i32 as _, 4);
|
||||
|
||||
// FP8 requires scale pointers. We fold the actual scales into alpha,
|
||||
// so set dummy 1.0 scale pointers on device.
|
||||
let one_val: f32 = 1.0;
|
||||
let mut one_buf = xserv_cuda::allocator::cached_alloc(4).unwrap();
|
||||
one_buf.copy_from_host(&one_val.to_le_bytes()).unwrap();
|
||||
let one_ptr = one_buf.as_ptr() as *const c_void;
|
||||
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
|
||||
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
|
||||
|
||||
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
|
||||
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
|
||||
|
||||
// "B" layout (activations): physical (K, M) col-major, ld=K
|
||||
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
|
||||
|
||||
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
|
||||
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
|
||||
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
|
||||
|
||||
// Get algo heuristic
|
||||
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
|
||||
cublasLtMatmulPreferenceCreate(&mut pref);
|
||||
let ws_bytes = WORKSPACE_BYTES as u64;
|
||||
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
|
||||
|
||||
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
|
||||
let mut found: i32 = 0;
|
||||
let status = cublasLtMatmulAlgoGetHeuristic(
|
||||
handle, matmul_desc,
|
||||
a_layout, b_layout, c_layout, d_layout,
|
||||
pref, 1, &mut heuristic, &mut found,
|
||||
cublasLtMatmulDescSetAttribute(
|
||||
plan.desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
||||
&b_scale_ptr as *const _ as _, ptr_sz,
|
||||
);
|
||||
assert!(status == 0 && found > 0,
|
||||
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM: status={status}, found={found}");
|
||||
|
||||
let status = cublasLtMatmul(
|
||||
handle, matmul_desc,
|
||||
handle, plan.desc,
|
||||
&alpha as *const f32 as _,
|
||||
b_ptr, // cuBLASLt "A" = weights
|
||||
a_layout,
|
||||
plan.a_layout,
|
||||
a_ptr, // cuBLASLt "B" = activations
|
||||
b_layout,
|
||||
plan.b_layout,
|
||||
&beta as *const f32 as _,
|
||||
c_ptr, // C (unused with beta=0)
|
||||
c_layout,
|
||||
plan.c_layout,
|
||||
c_ptr, // D = output
|
||||
d_layout,
|
||||
&heuristic.algo,
|
||||
ctx.workspace.as_ptr() as *mut c_void,
|
||||
heuristic.workspace_size,
|
||||
plan.d_layout,
|
||||
&plan.algo,
|
||||
ws_ptr,
|
||||
plan.workspace_size,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}");
|
||||
|
||||
cublasLtMatmulPreferenceDestroy(pref);
|
||||
cublasLtMatrixLayoutDestroy(a_layout);
|
||||
cublasLtMatrixLayoutDestroy(b_layout);
|
||||
cublasLtMatrixLayoutDestroy(c_layout);
|
||||
cublasLtMatrixLayoutDestroy(d_layout);
|
||||
cublasLtMatmulDescDestroy(matmul_desc);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user