diff --git a/crates/xserv-kernels/src/quantization.rs b/crates/xserv-kernels/src/quantization.rs index afcd4e5..dde3ee1 100644 --- a/crates/xserv-kernels/src/quantization.rs +++ b/crates/xserv-kernels/src/quantization.rs @@ -23,10 +23,11 @@ unsafe extern "C" { num_rows: i32, cols: i32, stream: *mut c_void, ); - fn launch_rowwise_scale_bf16( + fn launch_rowwise_scale_moe_bf16( data: *mut c_void, - scales: *const c_void, - num_rows: i32, cols: i32, + a_scales: *const c_void, + b_scales: *const c_void, + num_rows: i32, cols: i32, tokens: i32, stream: *mut c_void, ); } @@ -136,11 +137,11 @@ struct Fp8Plan { struct CublasLtContext { handle: CublasLtHandle, workspace: GpuBuffer, - /// Persistent device scalar holding 1.0, used as the A/B scale pointer - /// placeholder. Allocated once instead of per-expert. + /// Persistent device scalar holding 1.0, used as the A/B scale pointer. + /// Scales are applied post-GEMM, so the in-GEMM scales stay 1.0. one_buf: GpuBuffer, - /// Cache of prepared matmul plans keyed by (M, N, K). - plans: HashMap<(usize, usize, usize), Fp8Plan>, + /// Cache of prepared matmul plans keyed by (M, N, K, batch). + plans: HashMap<(usize, usize, usize, usize), Fp8Plan>, } impl CublasLtContext { @@ -154,14 +155,15 @@ impl CublasLtContext { Self { handle, workspace, one_buf, plans: HashMap::new() } } - /// Get the cached plan for (m, n, k), building (and caching) it on first use. - fn plan(&mut self, m: usize, n: usize, k: usize) -> Fp8Plan { - if let Some(p) = self.plans.get(&(m, n, k)) { + /// Get the cached strided-batched plan for (m, n, k, batch), building it on + /// first use. + fn plan(&mut self, m: usize, n: usize, k: usize, batch: usize) -> Fp8Plan { + if let Some(p) = self.plans.get(&(m, n, k, batch)) { return *p; } let one_ptr = self.one_buf.as_ptr() as *const c_void; - let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k) }; - self.plans.insert((m, n, k), plan); + let plan = unsafe { build_fp8_plan(self.handle, one_ptr, m, n, k, batch) }; + self.plans.insert((m, n, k, batch), plan); plan } } @@ -184,16 +186,18 @@ impl Drop for CublasLtContext { } } -/// Build an FP8 matmul plan for one (m, n, k) shape. See `batched_gemm_fp8` -/// for the row-major → cuBLASLt col-major layout mapping (transA=T, transB=N, -/// m_lt=N, n_lt=M, k_lt=K). The B-scale pointer is initialised to `one_ptr` -/// and overwritten per-expert at call time. +/// Build a strided-batched FP8 matmul plan for `batch` experts of one +/// (m, n, k) shape. Row-major → cuBLASLt col-major mapping (transA=T, +/// transB=N, m_lt=N, n_lt=M, k_lt=K). A/B scale pointers stay at 1.0 — both +/// the per-expert weight scale and the per-token activation scale are applied +/// post-GEMM in a fused kernel, which lets all experts run in one matmul. unsafe fn build_fp8_plan( handle: CublasLtHandle, one_ptr: *const c_void, m: usize, n: usize, k: usize, + batch: usize, ) -> Fp8Plan { let m_lt = n as u64; let n_lt = m as u64; @@ -209,17 +213,33 @@ unsafe fn build_fp8_plan( cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz); cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz); + // Per-expert strides in ELEMENTS for the strided-batch layout. + let stride_a = (n * k) as i64; // weights [N, K] + let stride_b = (m * k) as i64; // activations [M, K] + let stride_c = (m * n) as i64; // output [M, N] + let bc = batch as i32; + let set_batch = |layout: CublasLtMatrixLayout, stride: i64| { + cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &bc as *const i32 as _, 4); + cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride as *const i64 as _, 8); + }; + // "A" layout (weights, transposed): physical (K, N) col-major, ld=K let mut a_layout: CublasLtMatrixLayout = std::ptr::null_mut(); cublasLtMatrixLayoutCreate(&mut a_layout, CUDA_R_8F_E4M3, k_lt, m_lt, k as i64); + set_batch(a_layout, stride_a); // "B" layout (activations): physical (K, M) col-major, ld=K let mut b_layout: CublasLtMatrixLayout = std::ptr::null_mut(); cublasLtMatrixLayoutCreate(&mut b_layout, CUDA_R_8F_E4M3, k_lt, n_lt, k as i64); + set_batch(b_layout, stride_b); // "C"/"D" layout (output): physical (N, M) col-major, ld=N let mut c_layout: CublasLtMatrixLayout = std::ptr::null_mut(); cublasLtMatrixLayoutCreate(&mut c_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64); + set_batch(c_layout, stride_c); let mut d_layout: CublasLtMatrixLayout = std::ptr::null_mut(); cublasLtMatrixLayoutCreate(&mut d_layout, CUDA_R_16BF, m_lt, n_lt, m_lt as i64); + set_batch(d_layout, stride_c); let mut pref: CublasLtMatmulPreference = std::ptr::null_mut(); cublasLtMatmulPreferenceCreate(&mut pref); @@ -233,7 +253,7 @@ unsafe fn build_fp8_plan( pref, 1, &mut heuristic, &mut found, ); assert!(status == 0 && found > 0, - "cublasLtMatmulAlgoGetHeuristic failed for FP8 GEMM (m={m}, n={n}, k={k}): status={status}, found={found}"); + "cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"); cublasLtMatmulPreferenceDestroy(pref); Fp8Plan { @@ -354,71 +374,54 @@ pub fn batched_gemm_fp8( let c = Tensor::empty(&[batch, m, n], DType::BF16, a_fp8.device()); - // Strides (in bytes) for one expert slice - let stride_a = m * k; // FP8: 1 byte per elem - let stride_b = n * k; // FP8: 1 byte per elem (transposed: [N, K]) - let stride_c = m * n * 2; // BF16: 2 bytes per elem - CUBLASLT_CTX.with(|cell| { let mut ctx = cell.borrow_mut(); let handle = ctx.handle; let ws_ptr = ctx.workspace.as_ptr() as *mut c_void; - // Build (or fetch) the cached plan for this shape — heuristic search and - // descriptor/layout creation happen once per (m, n, k), not per-expert. - let plan = ctx.plan(m, n, k); + // Cached strided-batched plan: heuristic + descriptor/layout creation + // happen once per (m, n, k, batch). All experts run in ONE matmul. + let plan = ctx.plan(m, n, k, batch); - // alpha=1, beta=0. Per-expert weight scale is supplied via the cuBLASLt - // B-scale pointer (device, scalar): cuBLASLt computes in the FP32 epilogue - // D = (1.0 * A_fp8) @ (b_scale[e] * B_fp8)^T = b_scale[e] * (A_fp8 @ B_fp8^T) - // Per-token activation scale (a_scale) is applied post-GEMM (per row). + // alpha=1, beta=0, in-GEMM scales=1.0. The unscaled result + // D_raw[e] = A_fp8[e] @ B_fp8[e]^T + // is recovered to the real value by the fused post-scale kernel below. let alpha: f32 = 1.0; let beta: f32 = 0.0; - let ptr_sz = std::mem::size_of::<*const c_void>(); - for e in 0..batch { - let a_ptr = unsafe { (a_fp8.data_ptr() as *const u8).add(e * stride_a) as *const c_void }; - let b_ptr = unsafe { (b_fp8_t.data_ptr() as *const u8).add(e * stride_b) as *const c_void }; - let c_ptr = unsafe { (c.data_ptr() as *mut u8).add(e * stride_c) as *mut c_void }; - // Device pointer to this expert's scalar weight scale (FP32, 4 bytes). - let b_scale_ptr = unsafe { (b_scales.data_ptr() as *const u8).add(e * 4) as *const c_void }; - - unsafe { - cublasLtMatmulDescSetAttribute( - plan.desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - &b_scale_ptr as *const _ as _, ptr_sz, - ); - let status = cublasLtMatmul( - handle, plan.desc, - &alpha as *const f32 as _, - b_ptr, // cuBLASLt "A" = weights - plan.a_layout, - a_ptr, // cuBLASLt "B" = activations - plan.b_layout, - &beta as *const f32 as _, - c_ptr, // C (unused with beta=0) - plan.c_layout, - c_ptr, // D = output - plan.d_layout, - &plan.algo, - ws_ptr, - plan.workspace_size, - std::ptr::null_mut(), - ); - assert_eq!(status, 0, "cublasLtMatmul FP8 failed for expert {e}: status={status}"); - } + unsafe { + let status = cublasLtMatmul( + handle, plan.desc, + &alpha as *const f32 as _, + b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights + plan.a_layout, + a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations + plan.b_layout, + &beta as *const f32 as _, + c.data_ptr() as *const c_void, // C (unused with beta=0) + plan.c_layout, + c.data_ptr() as *mut c_void, // D = output + plan.d_layout, + &plan.algo, + ws_ptr, + plan.workspace_size, + std::ptr::null_mut(), + ); + assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}"); } }); - // Post-GEMM: multiply each row of c by its activation scale. - // c is [batch, M, N] BF16. a_scales is [batch * M] F32. - // This recovers the per-token scale that was divided out during quantization. + // Post-GEMM: recover the real result in one pass. + // c[e, t, :] *= a_scales[e*M + t] * b_scales[e] + // (per-token activation scale × per-expert weight scale). BF16's relative + // error is scale-invariant, so applying the scale here is precision- + // equivalent to folding it into the GEMM epilogue. let total_rows = (batch * m) as i32; - let cols = n as i32; unsafe { - launch_rowwise_scale_bf16( + launch_rowwise_scale_moe_bf16( c.data_ptr() as *mut c_void, a_scales.data_ptr() as *const c_void, - total_rows, cols, + b_scales.data_ptr() as *const c_void, + total_rows, n as i32, m as i32, std::ptr::null_mut(), ); } diff --git a/csrc/quantization/quantize_fp8.cu b/csrc/quantization/quantize_fp8.cu index 1b52a2f..a0bf610 100644 --- a/csrc/quantization/quantize_fp8.cu +++ b/csrc/quantization/quantize_fp8.cu @@ -86,6 +86,29 @@ __global__ void rowwise_scale_bf16_kernel( } } +// Combined dequant scale for batched MoE FP8 GEMM output. +// data[row, :] *= a_scales[row] * b_scales[row / tokens] +// where row = expert * tokens + token. a_scales is the per-token activation +// scale; b_scales is the per-expert scalar weight scale. Lets a single +// strided-batched FP8 matmul (alpha=1, scales=1) recover the real result in +// one pass instead of folding the weight scale into a per-expert GEMM call. +__global__ void rowwise_scale_moe_bf16_kernel( + __nv_bfloat16* __restrict__ data, + const float* __restrict__ a_scales, + const float* __restrict__ b_scales, + int num_rows, int cols, int tokens +) { + int row = blockIdx.x; + if (row >= num_rows) return; + int tid = threadIdx.x; + float s = a_scales[row] * b_scales[row / tokens]; + __nv_bfloat16* row_data = data + (long long)row * cols; + for (int i = tid; i < cols; i += blockDim.x) { + float v = __bfloat162float(row_data[i]) * s; + row_data[i] = __float2bfloat16(v); + } +} + extern "C" { void launch_rowwise_scale_bf16( @@ -102,6 +125,20 @@ void launch_rowwise_scale_bf16( CUDA_CHECK_LAST_ERROR(); } +void launch_rowwise_scale_moe_bf16( + void* data, const void* a_scales, const void* b_scales, + int num_rows, int cols, int tokens, + void* stream +) { + int block = 256; + int grid = num_rows; + rowwise_scale_moe_bf16_kernel<<>>( + (__nv_bfloat16*)data, (const float*)a_scales, (const float*)b_scales, + num_rows, cols, tokens + ); + CUDA_CHECK_LAST_ERROR(); +} + void launch_quantize_bf16_to_fp8e4m3_rowwise( const void* src, void* dst, diff --git a/docs/benchmarks/fp8-quantization.md b/docs/benchmarks/fp8-quantization.md index 30da244..a1da2ab 100644 --- a/docs/benchmarks/fp8-quantization.md +++ b/docs/benchmarks/fp8-quantization.md @@ -14,9 +14,10 @@ stay BF16. - **Activations**: quantized dynamically at runtime, **per-token** (per-row absmax), recovered by a post-GEMM row scale. - **Compute**: `batched_gemm_fp8` (`crates/xserv-kernels/src/quantization.rs`) - runs one cuBLASLt FP8 matmul per expert; the per-expert weight scale is - supplied via the cuBLASLt B-scale device pointer (FP32 epilogue, so precision - matches folding it into `alpha`). + runs **one strided-batched cuBLASLt FP8 matmul for all experts** (`alpha=1`, + in-GEMM scales `1.0`); a fused kernel then applies `a_scale[token]·b_scale[expert]` + in a single pass. BF16's relative error is scale-invariant, so applying both + scales post-GEMM is precision-equivalent to folding them into the epilogue. - Model size: **22 GB** (FP8) vs **39 GB** (BF16). The FP8 model fits on a single 32 GB 5090; BF16 needs ≥ 2. @@ -34,34 +35,50 @@ decoded token. This made FP8 **slower than BF16**: | Throughput | 37 tok/s | **55.8 tok/s** | 53.2 tok/s | Fix: cache the cuBLASLt plan (descriptor + layouts + heuristically-chosen algo) -in a thread-local map keyed by `(M, N, K)` so the heuristic runs once per shape; -allocate the scale buffer once; pass per-expert weight scales by device pointer. -The per-expert loop now issues only `cublasLtMatmul`. +in a thread-local map keyed by `(M, N, K, batch)` so the heuristic runs once per +shape, and allocate the scale buffer once. -## Results — GSM8K (200 problems, greedy, TP=2 on the same 2 GPUs) +## Reducing launches: one strided-batched matmul + +The per-expert loop still issued one `cublasLtMatmul` per expert — ~768 tiny +launches per decoded token (16 local experts × 2 GEMMs × 24 layers). Collapsing +each MoE GEMM into a single **strided-batched** cuBLASLt FP8 matmul (BATCH_COUNT ++ strided-batch offsets) drops that to ~48, with a fused post-scale kernel +applying both scales. This required moving the per-expert weight scale out of the +GEMM epilogue (a single strided call can't carry a per-batch scalar) into the +post-scale kernel — precision-equivalent, as noted above. + +| (gpt-oss-20b, TP=2) | per-expert FP8 | batched FP8 | BF16 | +|---|---|---|---| +| Decode TPOT | 17.9 ms | **13.8 ms** | 18.8 ms | +| Throughput | 55.8 tok/s | **72.3 tok/s** | 53.2 tok/s | + +## Results — GSM8K (greedy, TP=2 on the same 2 GPUs) + +200-problem run is the per-expert plan-cache fix; 100-problem run is the +strided-batched version. BF16 is the unchanged baseline in both. Harness: `tools/fp8_compare.py` — a warm `xserv-server` per model, GSM8K streamed through `/v1/chat/completions`; TTFT = time to first token, TPOT = mean inter-token latency, per request. -| metric | FP8 W8A8 | BF16 | -|---|---|---| -| GSM8K accuracy | **93.0 %** | 90.5 % | -| TTFT median | 67.4 ms | 68.8 ms | -| TTFT p90 | 90.4 ms | 96.7 ms | -| TPOT median | **17.45 ms** | 18.26 ms | -| TPOT p90 | 17.65 ms | 18.38 ms | -| Throughput | **57.3 tok/s** | 54.8 tok/s | -| Mean output tokens | 288 | 293 | +| metric | FP8 per-expert (n=200) | FP8 batched (n=100) | BF16 | +|---|---|---|---| +| GSM8K accuracy | 93.0 % | 91.0 % | 90.5 / 90.0 % | +| TTFT median | 67.4 ms | 65.0 ms | 68.8 / 69.5 ms | +| TPOT median | 17.45 ms | **13.08 ms** | 18.26 / 18.39 ms | +| TPOT p90 | 17.65 ms | **13.28 ms** | 18.38 / 18.52 ms | +| Throughput | 57.3 tok/s | **76.4 tok/s** | 54.8 / 54.4 tok/s | +| Decode speedup vs BF16 | 1.05× | **1.41×** | 1.00× | -- **Accuracy: unchanged.** FP8 is nominally +2.5 pts, but with n=200 the - standard error is ~2.1 pts, so the two are statistically indistinguishable. - The takeaway is that FP8 did **not** degrade accuracy. -- **Decode: FP8 ~5 % faster** (TPOT 17.45 vs 18.26 ms), reproducible across - runs, with a tighter p90. Modest because the dense-MoE path loads *all* - experts every token and FP8 only halves the *expert* bytes; the per-expert - M=1 launches and M=1 tensor-core inefficiency absorb much of the bandwidth - saving. +- **Accuracy: unchanged.** FP8 is nominally +0.5 … +2.5 pts above BF16, but at + n=100–200 the standard error is ~2–3 pts, so they are statistically + indistinguishable. The takeaway is that neither FP8 quantization nor the + strided-batched rounding degrades accuracy. +- **Decode: FP8 1.41× faster** once batched (TPOT 13.08 vs 18.39 ms), with a + tight p90. The per-expert version was only ~1.05× — the ~768 tiny M=1 launches + per token dominated; batching them into ~48 unlocked most of the FP8 + expert-weight-bandwidth saving. - **Prefill (TTFT): comparable.** A multi-length sweep (113 / 561 / 1681 tokens) gave FP8 480 / 362 / 2451 ms vs BF16 558 / 282 / 2287 ms — non-monotonic, i.e. dominated by fixed overhead (cuBLAS lazy init + FP8's one-time per-shape @@ -75,9 +92,8 @@ that otherwise needs two GPUs onto one — is the largest practical win. ## Follow-ups (not done) -- Strided-batched FP8 (one call instead of ~768 per-expert launches per token) — - requires folding the per-expert weight scale into the post-scale kernel, at a - BF16-intermediate precision cost. - Per-channel (per-output-row) weight scales for better accuracy headroom than per-tensor. - Warm common prefill shapes at load to hide the first-request heuristic stall. +- Sparse (top-k only) MoE compute instead of dense — currently every token runs + all experts, so only ~top_k/num_experts of the FP8 GEMM work is used.