quantization: single strided-batched FP8 MoE GEMM — cut per-token launches ~768→48

The plan-cache fix removed the per-expert heuristic churn but still issued one
cublasLtMatmul per expert: ~768 tiny launches per decoded token (16 local
experts × 2 GEMMs × 24 layers), which capped the FP8 decode win at ~1.05× over
BF16. Collapse each MoE GEMM into ONE strided-batched cuBLASLt FP8 matmul
(BATCH_COUNT + strided-batch offsets on all four layouts) → ~48 launches/token.

A single strided call can't carry a per-batch scalar B-scale, so the per-expert
weight scale moves out of the GEMM epilogue into a fused post-scale kernel
(rowwise_scale_moe_bf16) that applies a_scale[token]·b_scale[expert] in one
pass. This is precision-equivalent: BF16's relative error is scale-invariant, so
scaling the unscaled GEMM output afterward loses nothing vs scaling in-epilogue.

Measured on dash5 (gpt-oss-20b, TP=2, 5090), warm-server GSM8K:
  decode TPOT 17.45 → 13.08 ms (FP8 now 1.41× vs BF16 18.39 ms),
  throughput 57.3 → 76.4 tok/s, accuracy unchanged (FP8 91.0% vs BF16 90.0%).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 01:23:29 +08:00
parent 24c49c31c2
commit e631a71b68
3 changed files with 150 additions and 94 deletions

View File

@@ -23,10 +23,11 @@ unsafe extern "C" {
num_rows: i32, cols: i32,
stream: *mut c_void,
);
fn launch_rowwise_scale_bf16(
fn launch_rowwise_scale_moe_bf16(
data: *mut c_void,
scales: *const c_void,
num_rows: i32, cols: i32,
a_scales: *const c_void,
b_scales: *const c_void,
num_rows: i32, cols: i32, tokens: i32,
stream: *mut c_void,
);
}
@@ -136,11 +137,11 @@ struct Fp8Plan {
struct CublasLtContext {
handle: CublasLtHandle,
workspace: GpuBuffer,
/// Persistent device scalar holding 1.0, used as the A/B scale pointer
/// placeholder. Allocated once instead of per-expert.
/// Persistent device scalar holding 1.0, used as the A/B scale pointer.
/// Scales are applied post-GEMM, so the in-GEMM scales stay 1.0.
one_buf: GpuBuffer,
/// Cache of prepared matmul plans keyed by (M, N, K).
plans: HashMap<(usize, usize, usize), Fp8Plan>,
/// Cache of prepared matmul plans keyed by (M, N, K, batch).
plans: HashMap<(usize, usize, usize, usize), Fp8Plan>,
}
impl CublasLtContext {
@@ -154,14 +155,15 @@ impl CublasLtContext {
Self { handle, workspace, one_buf, plans: HashMap::new() }
}
/// Get the cached plan for (m, n, k), building (and caching) it on first use.
fn plan(&mut self, m: usize, n: usize, k: usize) -> Fp8Plan {
if let Some(p) = self.plans.get(&(m, n, k)) {
/// Get the cached strided-batched plan for (m, n, k, batch), building it on
/// first use.
fn plan(&mut self, m: usize, n: usize, k: usize, batch: usize) -> Fp8Plan {
if let Some(p) = self.plans.get(&(m, n, k, batch)) {
return *p;
}
let one_ptr = self.one_buf.as_ptr() as *const c_void;
let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k) };
self.plans.insert((m, n, k), plan);
let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k, batch) };
self.plans.insert((m, n, k, batch), plan);
plan
}
}
@@ -184,16 +186,18 @@ impl Drop for CublasLtContext {
}
}
/// Build an FP8 matmul plan for one (m, n, k) shape. See `batched_gemm_fp8`
/// for the row-major → cuBLASLt col-major layout mapping (transA=T, transB=N,
/// m_lt=N, n_lt=M, k_lt=K). The B-scale pointer is initialised to `one_ptr`
/// and overwritten per-expert at call time.
/// Build a strided-batched FP8 matmul plan for `batch` experts of one
/// (m, n, k) shape. Row-major → cuBLASLt col-major mapping (transA=T,
/// transB=N, m_lt=N, n_lt=M, k_lt=K). A/B scale pointers stay at 1.0 — both
/// the per-expert weight scale and the per-token activation scale are applied
/// post-GEMM in a fused kernel, which lets all experts run in one matmul.
unsafe fn build_fp8_plan(
handle: CublasLtHandle,
one_ptr: *const c_void,
m: usize,
n: usize,
k: usize,
batch: usize,
) -> Fp8Plan {
let m_lt = n as u64;
let n_lt = m as u64;
@@ -209,17 +213,33 @@ unsafe fn build_fp8_plan(
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz);
// Per-expert strides in ELEMENTS for the strided-batch layout.
let stride_a = (n * k) as i64; // weights [N, K]
let stride_b = (m * k) as i64; // activations [M, K]
let stride_c = (m * n) as i64; // output [M, N]
let bc = batch as i32;
let set_batch = |layout: CublasLtMatrixLayout, stride: i64| {
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&bc as *const i32 as _, 4);
cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&stride as *const i64 as _, 8);
};
// "A" layout (weights, transposed): physical (K, N) col-major, ld=K
let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64);
set_batch(a_layout, stride_a);
// "B" layout (activations): physical (K, M) col-major, ld=K
let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64);
set_batch(b_layout, stride_b);
// "C"/"D" layout (output): physical (N, M) col-major, ld=N
let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
set_batch(c_layout, stride_c);
let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut();
cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64);
set_batch(d_layout, stride_c);
let mut pref: CublasLtMatmulPreference = std::ptr::null_mut();
cublasLtMatmulPreferenceCreate(&mut pref);
@@ -233,7 +253,7 @@ unsafe fn build_fp8_plan(
pref, 1, &mut heuristic, &mut found,
);
assert!(status == 0 && found > 0,
"cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM (m={m}, n={n}, k={k}): status={status}, found={found}");
"cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}");
cublasLtMatmulPreferenceDestroy(pref);
Fp8Plan {
@@ -354,71 +374,54 @@ pub fn batched_gemm_fp8(
let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device());
// Strides (in bytes) for one expert slice
let stride_a = m * k; // FP8: 1 byte per elem
let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K])
let stride_c = m * n * 2; // BF16: 2 bytes per elem
CUBLASLT_CTX.with(|cell| {
let mut ctx = cell.borrow_mut();
let handle = ctx.handle;
let ws_ptr = ctx.workspace.as_ptr() as *mut c_void;
// Build (or fetch) the cached plan for this shape — heuristic search and
// descriptor/layout creation happen once per (m, n, k), not per-expert.
let plan = ctx.plan(m, n, k);
// Cached strided-batched plan: heuristic + descriptor/layout creation
// happen once per (m, n, k, batch). All experts run in ONE matmul.
let plan = ctx.plan(m, n, k, batch);
// alpha=1, beta=0. Per-expert weight scale is supplied via the cuBLASLt
// B-scale pointer (device, scalar): cuBLASLt computes in the FP32 epilogue
// D = (1.0 * A_fp8) @ (b_scale[e] * B_fp8)^T = b_scale[e] * (A_fp8 @ B_fp8^T)
// Per-token activation scale (a_scale) is applied post-GEMM (per row).
// alpha=1, beta=0, in-GEMM scales=1.0. The unscaled result
// D_raw[e] = A_fp8[e] @ B_fp8[e]^T
// is recovered to the real value by the fused post-scale kernel below.
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let ptr_sz = std::mem::size_of::<*const c_void>();
for e in 0..batch {
let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void };
let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void };
let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void };
// Device pointer to this expert's scalar weight scale (FP32, 4 bytes).
let b_scale_ptr = unsafe { (b_scales.data_ptr() as *const u8).add(e * 4) as *const c_void };
unsafe {
cublasLtMatmulDescSetAttribute(
plan.desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&b_scale_ptr as *const _ as _, ptr_sz,
);
let status = cublasLtMatmul(
handle, plan.desc,
&alpha as *const f32 as _,
b_ptr, // cuBLASLt "A" = weights
plan.a_layout,
a_ptr, // cuBLASLt "B" = activations
plan.b_layout,
&beta as *const f32 as _,
c_ptr, // C (unused with beta=0)
plan.c_layout,
c_ptr, // D = output
plan.d_layout,
&plan.algo,
ws_ptr,
plan.workspace_size,
std::ptr::null_mut(),
);
assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}");
}
unsafe {
let status = cublasLtMatmul(
handle, plan.desc,
&alpha as *const f32 as _,
b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights
plan.a_layout,
a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations
plan.b_layout,
&beta as *const f32 as _,
c.data_ptr() as *const c_void, // C (unused with beta=0)
plan.c_layout,
c.data_ptr() as *mut c_void, // D = output
plan.d_layout,
&plan.algo,
ws_ptr,
plan.workspace_size,
std::ptr::null_mut(),
);
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
}
});
// Post-GEMM: multiply each row of c by its activation scale.
// c is [batch, M, N] BF16. a_scales is [batch * M] F32.
// This recovers the per-token scale that was divided out during quantization.
// Post-GEMM: recover the real result in one pass.
// c[e, t, :] *= a_scales[e*M + t] * b_scales[e]
// (per-token activation scale × per-expert weight scale). BF16's relative
// error is scale-invariant, so applying the scale here is precision-
// equivalent to folding it into the GEMM epilogue.
let total_rows = (batch * m) as i32;
let cols = n as i32;
unsafe {
launch_rowwise_scale_bf16(
launch_rowwise_scale_moe_bf16(
c.data_ptr() as *mut c_void,
a_scales.data_ptr() as *const c_void,
total_rows, cols,
b_scales.data_ptr() as *const c_void,
total_rows, n as i32, m as i32,
std::ptr::null_mut(),
);
}