quantization: cache cuBLASLt FP8 plan per shape — fix per-expert heuristic churn

batched_gemm_fp8 rebuilt the cuBLASLt matmul descriptor, four matrix
layouts, a preference, and a 4-byte scale alloc, AND ran the algo
heuristic search — once per expert, per GEMM, per layer, on every
forward (~1500 heuristic searches per decoded token). FP8 decode ran at
27.0 ms/tok vs BF16 18.8 ms, i.e. slower than the path it was meant to
accelerate.

Cache the full plan (descriptor + layouts + heuristically-chosen algo)
in a thread-local map keyed by (M, N, K) so the heuristic runs once per
shape and is reused across experts and forwards; allocate the 1.0 scale
buffer once; pass each expert's weight scale via the cuBLASLt B-scale
device pointer instead of folding it into alpha (identical FP32-epilogue
precision, and no host readback of b_scales). The per-expert loop now
issues only cublasLtMatmul.

Measured on dash5 (gpt-oss-20b, TP=2, 5090): FP8 decode TPOT 27.0 -> 17.9
ms, now faster than BF16 (18.8 ms); GSM8K-200 accuracy unchanged
(FP8 93.0% vs BF16 90.5%, within noise).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 00:58:46 +08:00
parent 3a530956af
commit 5a16225c1f

View File

@@ -1,4 +1,5 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::ffi::c_void;
use xserv_cuda::GpuBuffer;
use xserv_tensor::{DType, Tensor};
@@ -113,9 +114,33 @@ const CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES: i32 = 1;
const WORKSPACE_BYTES: usize = 32 * 1024 * 1024;
const CUBLASLT_MATMUL_DESC_TRANSA: i32 = 3;
/// A fully-prepared FP8 matmul plan for one (M, N, K) shape: the matmul
/// descriptor, the four matrix layouts, and the heuristically-chosen algo.
/// Built once per shape and reused across every expert and every forward
/// pass — the heuristic search and descriptor/layout creation are the
/// expensive parts, so doing them once instead of per-expert-per-layer is
/// the difference between FP8 being faster or slower than BF16.
///
/// The struct is `Copy`, but copying it only duplicates the raw cuBLASLt
/// handles, not the objects behind them. The cached entry in
/// `CublasLtContext::plans` is what gets destroyed (in `Drop for
/// CublasLtContext`); copies handed out by `plan()` must not be destroyed
/// separately.
#[derive(Clone, Copy)]
struct Fp8Plan {
/// Matmul descriptor (FP32 compute, transA=T, A/B scale pointers set).
desc: CublasLtMatmulDesc,
/// cuBLASLt "A" operand layout: FP8 E4M3 weights, (K, N) col-major, ld=K.
a_layout: CublasLtMatrixLayout,
/// cuBLASLt "B" operand layout: FP8 E4M3 activations, (K, M) col-major, ld=K.
b_layout: CublasLtMatrixLayout,
/// "C" layout: BF16, (N, M) col-major, ld=N.
c_layout: CublasLtMatrixLayout,
/// "D" (output) layout: BF16, (N, M) col-major, ld=N.
d_layout: CublasLtMatrixLayout,
/// Algorithm selected by `cublasLtMatmulAlgoGetHeuristic` for this shape.
algo: CublasLtMatmulAlgo,
/// Workspace size the heuristic reported for `algo`; passed to `cublasLtMatmul`.
workspace_size: usize,
}
/// Per-thread cuBLASLt state: the library handle, a reusable workspace
/// buffer, the shared 1.0 scale scalar, and the per-shape FP8 plan cache.
/// One instance lives in the `CUBLASLT_CTX` thread-local.
struct CublasLtContext {
/// Handle obtained from `cublasLtCreate`; destroyed in `Drop`.
handle: CublasLtHandle,
/// Scratch buffer of `WORKSPACE_BYTES` bytes passed as the workspace to
/// `cublasLtMatmul`.
workspace: GpuBuffer,
/// Persistent device scalar holding 1.0, used as the A/B scale pointer
/// placeholder. Allocated once instead of per-expert.
one_buf: GpuBuffer,
/// Cache of prepared matmul plans keyed by (M, N, K).
plans: HashMap<(usize, usize, usize), Fp8Plan>,
}
impl CublasLtContext {
@@ -124,18 +149,100 @@ impl CublasLtContext {
let status = unsafe { cublasLtCreate(&mut handle) };
assert_eq!(status, 0, "cublasLtCreate failed: {status}");
let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace");
Self { handle, workspace }
let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale");
one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale");
Self { handle, workspace, one_buf, plans: HashMap::new() }
}
/// Return the FP8 plan for the `(m, n, k)` shape.
///
/// A cache hit hands back a copy of the stored plan. On a miss the plan is
/// built with `build_fp8_plan`, stored under `(m, n, k)`, and then returned,
/// so the heuristic search runs at most once per shape on this context.
fn plan(&mut self, m: usize, n: usize, k: usize) -> Fp8Plan {
    match self.plans.entry((m, n, k)) {
        std::collections::hash_map::Entry::Occupied(hit) => *hit.get(),
        std::collections::hash_map::Entry::Vacant(slot) => {
            // The plan's scale pointers start out aimed at the shared 1.0 scalar.
            let scale_one = self.one_buf.as_ptr() as *const c_void;
            let built = unsafe { build_fp8_plan(self.handle, scale_one, m, n, k) };
            *slot.insert(built)
        }
    }
}
}
impl Drop for CublasLtContext {
    /// Release every cached plan's cuBLASLt objects, then the handle itself.
    ///
    /// Plans go first so that no descriptor or layout outlives the handle
    /// it was created against. Destroy statuses are deliberately ignored:
    /// there is nothing useful to do about a failure during teardown.
    fn drop(&mut self) {
        self.plans.drain().for_each(|(_shape, cached)| unsafe {
            cublasLtMatrixLayoutDestroy(cached.a_layout);
            cublasLtMatrixLayoutDestroy(cached.b_layout);
            cublasLtMatrixLayoutDestroy(cached.c_layout);
            cublasLtMatrixLayoutDestroy(cached.d_layout);
            cublasLtMatmulDescDestroy(cached.desc);
        });

        // A null handle means there is nothing left to destroy.
        if self.handle.is_null() {
            return;
        }
        unsafe { cublasLtDestroy(self.handle) };
    }
}
/// Build an FP8 matmul plan for one (m, n, k) shape. See `batched_gemm_fp8`
/// for the row-major → cuBLASLt col-major layout mapping (transA=T, transB=N,
/// m_lt=N, n_lt=M, k_lt=K). The B-scale pointer is initialised to `one_ptr`
/// and overwritten per-expert at call time.
///
/// Creates the matmul descriptor and the four matrix layouts, then asks the
/// cuBLASLt heuristic for a single algorithm that fits within
/// `WORKSPACE_BYTES` of workspace. The temporary preference object is
/// destroyed before returning; the descriptor and layouts are returned inside
/// the plan and must eventually be destroyed by the owner of the plan.
///
/// # Safety
///
/// `handle` must be a live cuBLASLt handle. `one_ptr` is installed as the A
/// and B scale pointer on the descriptor, so it must stay valid for as long
/// as the descriptor may be used with it (the caller passes the context's
/// persistent 1.0 scalar).
///
/// # Panics
///
/// Panics if any cuBLASLt call reports a non-zero status, or if the
/// heuristic returns no algorithm for this shape. Failing here keeps a null
/// or half-configured descriptor/layout from being cached and reused.
unsafe fn build_fp8_plan(
    handle: CublasLtHandle,
    one_ptr: *const c_void,
    m: usize,
    n: usize,
    k: usize,
) -> Fp8Plan {
    // Assert that a cuBLASLt call succeeded, naming the call and the shape.
    // A macro (not a helper fn) so it stays agnostic of the status type
    // declared by the FFI bindings.
    macro_rules! lt_check {
        ($what:literal, $call:expr) => {{
            let status = $call;
            assert_eq!(
                status, 0,
                concat!($what, " failed for FP8 plan (m={}, n={}, k={}): status={}"),
                m, n, k, status
            );
        }};
    }

    let m_lt = n as u64;
    let n_lt = m as u64;
    let k_lt = k as u64;

    let mut desc: CublasLtMatmulDesc = std::ptr::null_mut();
    lt_check!(
        "cublasLtMatmulDescCreate",
        cublasLtMatmulDescCreate(&mut desc, CUBLAS_COMPUTE_32F, CUDA_R_32F)
    );

    // transA=T (required for FP8 on Blackwell)
    let trans_a: i32 = 1;
    lt_check!(
        "cublasLtMatmulDescSetAttribute(TRANSA)",
        cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4)
    );

    // The attribute value is the pointer itself, hence the address of
    // `one_ptr` with a size of one pointer.
    let ptr_sz = std::mem::size_of::<*const c_void>();
    lt_check!(
        "cublasLtMatmulDescSetAttribute(A_SCALE_POINTER)",
        cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz)
    );
    lt_check!(
        "cublasLtMatmulDescSetAttribute(B_SCALE_POINTER)",
        cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz)
    );

    // "A" layout (weights, transposed): physical (K, N) col-major, ld=K
    let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
    lt_check!(
        "cublasLtMatrixLayoutCreate(A)",
        cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64)
    );

    // "B" layout (activations): physical (K, M) col-major, ld=K
    let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
    lt_check!(
        "cublasLtMatrixLayoutCreate(B)",
        cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64)
    );

    // "C"/"D" layout (output): physical (N, M) col-major, ld=N
    let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
    lt_check!(
        "cublasLtMatrixLayoutCreate(C)",
        cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64)
    );
    let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
    lt_check!(
        "cublasLtMatrixLayoutCreate(D)",
        cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64)
    );

    let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
    lt_check!("cublasLtMatmulPreferenceCreate", cublasLtMatmulPreferenceCreate(&mut pref));
    let ws_bytes = WORKSPACE_BYTES as u64;
    lt_check!(
        "cublasLtMatmulPreferenceSetAttribute(MAX_WORKSPACE_BYTES)",
        cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8)
    );

    // Request exactly one candidate algorithm.
    let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
    let mut found: i32 = 0;
    let status = cublasLtMatmulAlgoGetHeuristic(
        handle, desc, a_layout, b_layout, c_layout, d_layout,
        pref, 1, &mut heuristic, &mut found,
    );

    // The preference is only needed for the search; release it before the
    // assert so it is not leaked when the heuristic fails.
    cublasLtMatmulPreferenceDestroy(pref);

    assert!(status == 0 && found > 0,
        "cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM (m={m}, n={n}, k={k}): status={status}, found={found}");

    Fp8Plan {
        desc, a_layout, b_layout, c_layout, d_layout,
        algo: heuristic.algo,
        workspace_size: heuristic.workspace_size,
    }
}
// One cuBLASLt context per thread, constructed via `CublasLtContext::new()`
// the first time a thread accesses `CUBLASLT_CTX`. The `RefCell` supplies the
// mutable access that `CublasLtContext::plan` needs to populate the plan cache.
thread_local! {
static CUBLASLT_CTX: RefCell<CublasLtContext> = RefCell::new(CublasLtContext::new());
}
@@ -215,9 +322,9 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
/// as [b, N, K] for cuBLASLt FP8 compatibility.
///
/// a_fp8: [batch, M, K] FP8E4M3 (activations, quantized per-row)
/// a_scales: [batch * M] F32 (per-token scales, collapsed to per-batch max)
/// a_scales: [batch * M] F32 (per-token activation scales, applied post-GEMM)
/// b_fp8_t: [batch, N, K] FP8E4M3 (weights, TRANSPOSED for cuBLASLt)
/// b_scales: [batch] F32 (per-expert scalar scales)
/// b_scales: [batch] F32 (per-expert scalar weight scales, applied in-GEMM)
///
/// Returns: [batch, M, N] BF16
pub fn batched_gemm_fp8(
@@ -240,127 +347,64 @@ pub fn batched_gemm_fp8(
let k = a_fp8.shape()[2]; // hidden
let n = b_fp8_t.shape()[1]; // out_dim (from transposed weight)
// Per-token scales → per-expert scales (max over tokens within each expert batch)
// a_scales: [batch * M] → we take the max per expert to get [batch] scalar scales
// This is a slight accuracy tradeoff vs per-token, but allows scalar GEMM scale mode.
// a_scales: [batch * M] per-token activation scales (applied post-GEMM, per row).
// b_scales: [batch] per-expert scalar weight scales (applied in-GEMM via B-scale ptr).
assert_eq!(a_scales.shape()[0], batch * m);
assert_eq!(b_scales.shape()[0], batch);
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
// Read weight scales to host for the per-expert loop
let b_scales_cpu = b_scales.to_device(xserv_tensor::Device::Cpu);
let b_s_data = b_scales_cpu.as_slice::<f32>();
CUBLASLT_CTX.with(|cell| {
let ctx = cell.borrow();
let handle = ctx.handle;
// Strides (in bytes) for one expert slice
let stride_a = m * k; // FP8: 1 byte per elem
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
let stride_c = m * n * 2; // BF16: 2 bytes per elem
CUBLASLT_CTX.with(|cell| {
let mut ctx = cell.borrow_mut();
let handle = ctx.handle;
let ws_ptr = ctx.workspace.as_ptr() as *mut c_void;
// Build (or fetch) the cached plan for this shape — heuristic search and
// descriptor/layout creation happen once per (m, n, k), not per-expert.
let plan = ctx.plan(m, n, k);
// alpha=1, beta=0. Per-expert weight scale is supplied via the cuBLASLt
// B-scale pointer (device, scalar): cuBLASLt computes in the FP32 epilogue
// D = (1.0 * A_fp8) @ (b_scale[e] * B_fp8)^T = b_scale[e] * (A_fp8 @ B_fp8^T)
// Per-token activation scale (a_scale) is applied post-GEMM (per row).
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let ptr_sz = std::mem::size_of::<*const c_void>();
for e in 0..batch {
let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void };
let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void };
let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void };
// alpha = b_scale (weight scale). Per-row activation scale applied post-GEMM.
// GEMM computes: D = alpha * (A_fp8 @ B_fp8_T)
// = b_scale * ((A_real / a_scale_row) @ (B_real / b_scale))
// = (A_real / a_scale_row) @ B_real
// Post-multiply row i by a_scale[i] to recover the correct result.
let alpha: f32 = b_s_data[e];
let beta: f32 = 0.0;
// Device pointer to this expert's scalar weight scale (FP32, 4 bytes).
let b_scale_ptr = unsafe { (b_scales.data_ptr() as *const u8).add(e * 4) as *const c_void };
unsafe {
// cuBLASLt FP8 on Blackwell requires transA=T, transB=N.
// cuBLASLt computes: D(m,n) = op(A)(m,k) * B(k,n) with transA=T
//
// We want: D_row[M,N] = A_act_row[M,K] @ B_wt_row[K,N]
// Map to cuBLASLt with m_lt=N, n_lt=M, k_lt=K:
// "A" (transA=T): stored as (K, N) col-major ld=K → transposed to (N, K)
// Our weights are stored TRANSPOSED as [E, N, K] row-major = col-major (K, N) ld=K ✓
// "B" (transB=N): stored as (K, M) col-major ld=K
// Our activations A_act_row[M,K] = col-major (K, M) ld=K ✓
// "D": stored as (N, M) col-major ld=N
// Our output D_row[M,N] = col-major (N, M) ld=N ✓
let m_lt = n as u64;
let n_lt = m as u64;
let k_lt = k as u64;
let mut matmul_desc: CublasLtMatmulDesc = std::ptr::null_mut();
cublasLtMatmulDescCreate(&mut matmul_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
// Set transA=T (required for FP8 on Blackwell)
let trans_a: i32 = 1; // CUBLAS_OP_T
cublasLtMatmulDescSetAttribute(matmul_desc, 3 /*TRANSA*/, &trans_a as *const i32 as _, 4);
// FP8 requires scale pointers. We fold the actual scales into alpha,
// so set dummy 1.0 scale pointers on device.
let one_val: f32 = 1.0;
let mut one_buf = xserv_cuda::allocator::cached_alloc(4).unwrap();
one_buf.copy_from_host(&one_val.to_le_bytes()).unwrap();
let one_ptr = one_buf.as_ptr() as *const c_void;
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
cublasLtMatmulDescSetAttribute(matmul_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, std::mem::size_of::<*const c_void>());
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
// "B" layout (activations): physical (K, M) col-major, ld=K
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
// Get algo heuristic
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
cublasLtMatmulPreferenceCreate(&mut pref);
let ws_bytes = WORKSPACE_BYTES as u64;
cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8);
let mut heuristic = std::mem::zeroed::<CublasLtMatmulHeuristicResult>();
let mut found: i32 = 0;
let status = cublasLtMatmulAlgoGetHeuristic(
handle, matmul_desc,
a_layout, b_layout, c_layout, d_layout,
pref, 1, &mut heuristic, &mut found,
cublasLtMatmulDescSetAttribute(
plan.desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_ptr as *const _ as _, ptr_sz,
);
assert!(status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM: status={status}, found={found}");
let status = cublasLtMatmul(
handle, matmul_desc,
handle, plan.desc,
&alpha as *const f32 as _,
b_ptr, // cuBLASLt "A" = weights
a_layout,
plan.a_layout,
a_ptr, // cuBLASLt "B" = activations
b_layout,
plan.b_layout,
&beta as *const f32 as _,
c_ptr, // C (unused with beta=0)
c_layout,
plan.c_layout,
c_ptr, // D = output
d_layout,
&heuristic.algo,
ctx.workspace.as_ptr() as *mut c_void,
heuristic.workspace_size,
plan.d_layout,
&plan.algo,
ws_ptr,
plan.workspace_size,
std::ptr::null_mut(),
);
assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}");
cublasLtMatmulPreferenceDestroy(pref);
cublasLtMatrixLayoutDestroy(a_layout);
cublasLtMatrixLayoutDestroy(b_layout);
cublasLtMatrixLayoutDestroy(c_layout);
cublasLtMatrixLayoutDestroy(d_layout);
cublasLtMatmulDescDestroy(matmul_desc);
}
}
});