From d33220498a24e8d565518858be2321915b34c15e Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 12 Jun 2026 15:01:42 +0800 Subject: [PATCH] quantization: MXFP4 W4A16 expert weights (memory-optimization foundation) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 + per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill (M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E]. Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB 5090 with room for KV cache. NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses (350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt block-scaled MXFP4) or a Marlin-class kernel; see docs/benchmarks/mxfp4-and-llama-decode.md. Co-Authored-By: Claude Opus 4.8 --- crates/xserv-kernels/build.rs | 1 + crates/xserv-kernels/src/quantization.rs | 55 ++++++++ crates/xserv-model/src/gpt_oss.rs | 72 +++++++++- csrc/quantization/mxfp4_gemm.cu | 135 +++++++++++++++++++ docs/benchmarks/mxfp4-and-llama-decode.md | 71 ++++++++++ tools/quantize_mxfp4.py | 153 ++++++++++++++++++++++ 6 files changed, 480 insertions(+), 7 deletions(-) create mode 100644 csrc/quantization/mxfp4_gemm.cu create mode 100644 docs/benchmarks/mxfp4-and-llama-decode.md create mode 100644 tools/quantize_mxfp4.py diff --git a/crates/xserv-kernels/build.rs b/crates/xserv-kernels/build.rs index 738ee92..0c63b79 100644 --- a/crates/xserv-kernels/build.rs +++ b/crates/xserv-kernels/build.rs @@ -33,6 +33,7 @@ fn main() { .file("../../csrc/moe/moe_kernels.cu") .file("../../csrc/quantization/dequant_fp8.cu") .file("../../csrc/quantization/quantize_fp8.cu") + .file("../../csrc/quantization/mxfp4_gemm.cu") .compile("xserv_kernels"); println!("cargo:rerun-if-changed=../../csrc/"); diff --git a/crates/xserv-kernels/src/quantization.rs b/crates/xserv-kernels/src/quantization.rs index dde3ee1..292d24b 100644 --- a/crates/xserv-kernels/src/quantization.rs +++ b/crates/xserv-kernels/src/quantization.rs @@ -30,6 +30,14 @@ unsafe extern "C" { num_rows: i32, cols: i32, tokens: i32, stream: *mut c_void, ); + fn launch_batched_gemv_mxfp4_bf16( + x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void, + e: i32, n: i32, k: i32, stream: *mut c_void, + ); + fn launch_dequant_mxfp4_to_bf16_t( + w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void, + e: i32, n: i32, k: i32, stream: *mut c_void, + ); } // ============================================================ @@ -428,3 +436,50 @@ pub fn batched_gemm_fp8( c } + +// ============================================================ +// MXFP4 W4A16 (weight-only 4-bit) for MoE experts +// ============================================================ + +/// MXFP4 W4A16 batched GEMV for decode (M=1). +/// +/// x: [E, K] BF16 (per-expert activation; replicated across experts) +/// w_packed: [E, N, K/2] byte tensor — two E2M1 nibbles per byte (lo = even k) +/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block +/// +/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]). +pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor { + assert_eq!(x.dtype(), DType::BF16); + assert!(x.is_contiguous()); + let e = x.shape()[0]; + assert_eq!(x.shape()[x.ndim() - 1], k, "GEMV K mismatch"); + + let y = Tensor::empty(&[e, n], DType::BF16, x.device()); + unsafe { + launch_batched_gemv_mxfp4_bf16( + x.data_ptr() as *const c_void, + w_packed.data_ptr() as *const c_void, + w_scales.data_ptr() as *const c_void, + y.data_ptr() as *mut c_void, + e as i32, n as i32, k as i32, + std::ptr::null_mut(), + ); + } + y +} + +/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path +/// (the BF16 batched GEMM expects weights as [E, K, N]). +pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor { + let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device()); + unsafe { + launch_dequant_mxfp4_to_bf16_t( + w_packed.data_ptr() as *const c_void, + w_scales.data_ptr() as *const c_void, + out.data_ptr() as *mut c_void, + e as i32, n as i32, k as i32, + std::ptr::null_mut(), + ); + } + out +} diff --git a/crates/xserv-model/src/gpt_oss.rs b/crates/xserv-model/src/gpt_oss.rs index 3fa3842..fc0b8f0 100644 --- a/crates/xserv-model/src/gpt_oss.rs +++ b/crates/xserv-model/src/gpt_oss.rs @@ -53,6 +53,10 @@ struct GptOssBlock { expert_gate_up_scale: Option,// [local_experts] F32 expert_down_fp8: Option, // [local_experts, hidden, inter] FP8E4M3 expert_down_scale: Option, // [local_experts] F32 + // MXFP4 W4A16 expert weights (Some when running 4-bit weight-only). + // (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout. + expert_gate_up_mxfp4: Option<(Tensor, Tensor)>, + expert_down_mxfp4: Option<(Tensor, Tensor)>, local_experts: usize, // Activation params glu_alpha: f32, @@ -169,11 +173,18 @@ impl GptOss { let local_experts = num_experts / world; let expert_start = rank * local_experts; - let is_fp8 = gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3; + // MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype + // as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E], + // MXFP4 scale is 3-D [E, N, K/32]. + let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false); + let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3; - let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size - let hidden = gate_up_3d.shape()[1]; - let inter = down_3d.shape()[1]; // intermediate_size + let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None; + let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None; + + let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N) + let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] }; + let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] }; // Slice the rank's range of experts as contiguous 3D tensors on GPU let expert_gate_up_wt; @@ -183,7 +194,24 @@ impl GptOss { let expert_down_fp8; let expert_down_scale_gpu; - if is_fp8 { + if is_mxfp4 { + // MXFP4 W4A16: weights already [E, N, K] packed ([E, N, K/2] bytes) + // + scales [E, N, K/32]. Slice this rank's experts (raw bytes). + let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale"); + let d_s = down_scale.expect("MXFP4 model missing down_proj_scale"); + let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev); + let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev); + let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev); + let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev); + expert_gate_up_mxfp4 = Some((gu_packed, gu_scl)); + expert_down_mxfp4 = Some((dn_packed, dn_scl)); + expert_gate_up_fp8 = None; + expert_gate_up_scale_gpu = None; + expert_down_fp8 = None; + expert_down_scale_gpu = None; + expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev); + expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev); + } else if is_fp8 { // FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell). // Original: [E, K, N] → Transposed: [E, N, K] let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2); @@ -243,6 +271,8 @@ impl GptOss { expert_gate_up_scale: expert_gate_up_scale_gpu, expert_down_fp8, expert_down_scale: expert_down_scale_gpu, + expert_gate_up_mxfp4, + expert_down_mxfp4, local_experts, glu_alpha, glu_limit, @@ -254,6 +284,7 @@ impl GptOss { let has_norm_bias = norm_bias.is_some(); let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false); + let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false); if rank == 0 { if has_norm_bias { eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm"); @@ -261,6 +292,9 @@ impl GptOss { if is_fp8 { eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"); } + if is_mxfp4 { + eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)"); + } } // Warn about unused weights that the model didn't consume @@ -519,7 +553,20 @@ impl GptOss { let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts); // 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter] - let gate_up = if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 { + let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 { + // MXFP4 W4A16: decode (M=1) uses the fused 4-bit dequant GEMV; prefill + // dequantizes to BF16 then reuses the batched GEMM. + let n = packed.shape()[1]; + let k = packed.shape()[2] * 2; + if num_tokens == 1 { + let x2 = x_rep.reshape(&[local_experts, k]); + xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k) + .reshape(&[local_experts, 1, n]) + } else { + let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k); + xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16) + } + } else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 { // W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep); xserv_kernels::quantization::batched_gemm_fp8( @@ -541,7 +588,18 @@ impl GptOss { let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]); // 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden] - let down = if let Some(ref wt_fp8) = layer.expert_down_fp8 { + let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 { + let n = packed.shape()[1]; + let k = packed.shape()[2] * 2; + if num_tokens == 1 { + let a2 = activated.reshape(&[local_experts, k]); + xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k) + .reshape(&[local_experts, 1, n]) + } else { + let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k); + xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16) + } + } else if let Some(ref wt_fp8) = layer.expert_down_fp8 { // W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated); xserv_kernels::quantization::batched_gemm_fp8( diff --git a/csrc/quantization/mxfp4_gemm.cu b/csrc/quantization/mxfp4_gemm.cu new file mode 100644 index 0000000..2e4f9e8 --- /dev/null +++ b/csrc/quantization/mxfp4_gemm.cu @@ -0,0 +1,135 @@ +#include +#include +#include "../common.cuh" + +// MXFP4 W4A16 for MoE experts. Weights stored [E, N, K] with K (reduction) +// contiguous, blocked by 32: packed 4-bit E2M1 (two nibbles/byte, lo = even k) +// + one UE8M0 scale byte per 32 elements. The decode win is reading 4-bit +// weights from HBM (half of FP8) and dequantizing on-chip to BF16. + +#define MXFP4_BLOCK 32 + +// E2M1 magnitude by 3-bit code; bit 3 is the sign. +__device__ __constant__ float kFp4Levels[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f}; + +__device__ __forceinline__ float fp4_to_float(uint8_t code) { + float mag = kFp4Levels[code & 0x7]; + return (code & 0x8) ? -mag : mag; +} + +// Decode (M=1) fused GEMV, batched over experts. +// y[e, n] = sum_k x[e, k] * dequant(W[e, n, k]) +// Grid: (N/TILE_N, E). Each block loads the activation x[e, :] into shared +// memory ONCE and computes TILE_N output columns from it (one warp per column), +// so the activation is read from HBM once per TILE_N outputs instead of once +// per output. Weights are unique per output and read coalesced as uint4; the +// UE8M0 block scale is hoisted to once per 32-element block. +#define MXFP4_TILE_N 8 // output columns per block (= warps per block) + +__global__ void batched_gemv_mxfp4_bf16_kernel( + const __nv_bfloat16* __restrict__ x, // [E, K] + const uint8_t* __restrict__ w_packed, // [E, N, K/2] + const uint8_t* __restrict__ w_scales, // [E, N, K/32] + __nv_bfloat16* __restrict__ y, // [E, N] + int E, int N, int K +) { + extern __shared__ float xs[]; // [K] activation for this expert + int e = blockIdx.y; + int n_base = blockIdx.x * MXFP4_TILE_N; + int warp = threadIdx.x >> 5; // 0..TILE_N-1 + int lane = threadIdx.x & 31; + int nthreads = blockDim.x; // TILE_N * 32 + int nblk = K / MXFP4_BLOCK; + + // Cooperatively stage x[e, :] into shared memory (converted to float). + const __nv_bfloat16* xe = x + (long long)e * K; + for (int k = threadIdx.x; k < K; k += nthreads) { + xs[k] = __bfloat162float(xe[k]); + } + __syncthreads(); + + int n = n_base + warp; + if (n >= N) return; + + const uint8_t* wp = w_packed + ((long long)e * N + n) * (K >> 1); + const uint8_t* ws = w_scales + ((long long)e * N + n) * nblk; + + float acc = 0.0f; + for (int blk = lane; blk < nblk; blk += 32) { + float scale = exp2f((float)((int)ws[blk] - 127)); + uint4 packed = *(const uint4*)(wp + (long long)blk * 16); // 16 bytes = 32 nibbles + const uint8_t* pb = (const uint8_t*)&packed; + const float* xk = xs + blk * MXFP4_BLOCK; + #pragma unroll + for (int i = 0; i < 16; i++) { + uint8_t b = pb[i]; + acc += xk[2 * i] * (fp4_to_float(b & 0xF) * scale); + acc += xk[2 * i + 1] * (fp4_to_float(b >> 4) * scale); + } + } + + // Warp reduction. + #pragma unroll + for (int o = 16; o > 0; o >>= 1) { + acc += __shfl_down_sync(0xffffffffu, acc, o); + } + if (lane == 0) y[(long long)e * N + n] = __float2bfloat16(acc); +} + +// Prefill fallback: dequant MXFP4 [E, N, K] -> BF16 [E, K, N] (transposed back +// to the [E, K, N] layout the BF16 batched GEMM expects). Not bandwidth-optimal, +// but prefill is compute-bound so it is not the decode hot path. +__global__ void dequant_mxfp4_to_bf16_t_kernel( + const uint8_t* __restrict__ w_packed, // [E, N, K/2] + const uint8_t* __restrict__ w_scales, // [E, N, K/32] + __nv_bfloat16* __restrict__ out, // [E, K, N] + int E, int N, int K +) { + long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x; + long long total = (long long)E * N * K; + if (idx >= total) return; + int k = idx % K; + int n = (idx / K) % N; + int e = idx / ((long long)N * K); + + int Kh = K >> 1; + int Ks = K / MXFP4_BLOCK; + uint8_t byte = w_packed[((long long)e * N + n) * Kh + (k >> 1)]; + uint8_t code = (k & 1) ? (byte >> 4) : (byte & 0xF); + float scale = exp2f((float)((int)w_scales[((long long)e * N + n) * Ks + k / MXFP4_BLOCK] - 127)); + float val = fp4_to_float(code) * scale; + // write to out[e, k, n] + out[((long long)e * K + k) * N + n] = __float2bfloat16(val); +} + +extern "C" { + +void launch_batched_gemv_mxfp4_bf16( + const void* x, const void* w_packed, const void* w_scales, void* y, + int E, int N, int K, void* stream +) { + dim3 grid((N + MXFP4_TILE_N - 1) / MXFP4_TILE_N, E); + int block = MXFP4_TILE_N * 32; // one warp per output column + size_t smem = (size_t)K * sizeof(float); + batched_gemv_mxfp4_bf16_kernel<<>>( + (const __nv_bfloat16*)x, (const uint8_t*)w_packed, (const uint8_t*)w_scales, + (__nv_bfloat16*)y, E, N, K + ); + CUDA_CHECK_LAST_ERROR(); +} + +void launch_dequant_mxfp4_to_bf16_t( + const void* w_packed, const void* w_scales, void* out, + int E, int N, int K, void* stream +) { + long long total = (long long)E * N * K; + int block = 256; + long long grid = (total + block - 1) / block; + dequant_mxfp4_to_bf16_t_kernel<<<(unsigned)grid, block, 0, (cudaStream_t)stream>>>( + (const uint8_t*)w_packed, (const uint8_t*)w_scales, (__nv_bfloat16*)out, + E, N, K + ); + CUDA_CHECK_LAST_ERROR(); +} + +} diff --git a/docs/benchmarks/mxfp4-and-llama-decode.md b/docs/benchmarks/mxfp4-and-llama-decode.md new file mode 100644 index 0000000..b078d08 --- /dev/null +++ b/docs/benchmarks/mxfp4-and-llama-decode.md @@ -0,0 +1,71 @@ +# MXFP4 W4A16 + decode-speed vs llama.cpp (gpt-oss-20b, 2×RTX 5090) + +## xserv vs llama.cpp — single-stream decode (TP=2, same GPUs) + +`tools/xserv_vs_llama.py` streams identical prompts through each server's +OpenAI endpoint (counting llama's `reasoning_content` as real decode tokens). + +| metric | xserv FP8 | llama MXFP4 | +|---|---|---| +| Decode TPOT (medium) | 13.1 ms | **6.6 ms** (2.0× faster) | +| Throughput | 76 tok/s | **151 tok/s** | +| TTFT (short/medium) | 35–50 ms | 60–63 ms | +| TTFT (long, 1.6k tok) | 94 ms | **35 ms** | + +llama.cpp decodes ~2× faster; prefill is comparable-to-better. + +## Why — decode is memory/comm-bound, not launch-bound + +Traced + measured (not assumed): + +- The 24-layer decode loop is already fully async (no per-layer syncs), so kernel + launches hide behind GPU work — a CUDA graph would buy ~0.5–1.5 ms, not 2×. +- **TP=2→TP=4 probe**: TPOT 13.5→10.2 ms (FP8) with the *same* launch count and + *more* NCCL — confirms the bottleneck is **expert HBM traffic + all-reduce**, + not launch overhead. +- Even FP8 TP=4 (10.2 ms) can't catch llama TP=2 (6.6 ms): the gap is + *algorithmic*. llama is **sparse (top-4 of 32 experts) + 4-bit (MXFP4)**; + xserv is **dense (all 16 local experts) + 8-bit (FP8)** → ~8× the expert bytes + per token. Dense also makes xserv's long-prefill TTFT worse. + +The two levers that close it: **sparse top-k MoE** (≈4×, the bigger structural +change) and **4-bit weights** (≈2×). + +## MXFP4 W4A16 (this change) — correct, smallest, not yet faster than FP8 + +Weight-only 4-bit: expert weights are MXFP4 (E2M1 + per-32 UE8M0 scale, +`tools/quantize_mxfp4.py`); a fused kernel reads the 4-bit weights and +dequantizes on-chip to BF16. Decode uses `batched_gemv_mxfp4`; prefill (M>1) +dequantizes to BF16 then reuses the BF16 batched GEMM. + +| | MXFP4 W4A16 | FP8 W8A8 | BF16 | +|---|---|---|---| +| Model size | **13 GB** | 22 GB | 39 GB | +| Greedy tokens | identical | identical | baseline | +| Decode TPOT (TP=2) | 17.0 ms | **13.5 ms** | 18.8 ms | +| Decode TPOT (TP=4) | 11.8 ms | **10.2 ms** | — | +| Prefill TTFT | 350 ms | **134 ms** | 135 ms | + +- **Correct** (byte-identical greedy tokens to FP8/BF16) and **smallest + footprint** — fits one 32 GB 5090 with ample room for KV cache. +- **Not faster than FP8**: the hand-written W4A16 dequant-GEMV (no tensor cores) + is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even reading half + the bytes it stays ~2–3.5 ms behind FP8 at every TP. The TP=4 scaling + (17→11.8) shows it *is* partly memory-bound; a fixed per-GEMM inefficiency + dominates. Vectorized loads, hoisted scale, warp reduction, and shared-memory + activation tiling did not change it. +- **Prefill regresses** (350 vs 134 ms) — the dequant-to-BF16 fallback. + +Committed as a **memory-optimization foundation**, not a decode speedup. + +## To make 4-bit actually win + +- **FP4 tensor cores (W4A4)** — cuBLASLt block-scaled MXFP4 GEMM + (`CUDA_R_4F_E2M1` + `CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0`, available on + sm_120). Tensor-core throughput *at* 4-bit would beat FP8. Risk: the scale + swizzle layout. +- A **Marlin-class W4A16 kernel** (register-blocked, async-copy pipelined). +- **Sparse top-k MoE** for the larger, llama-matching win. + +FP8 (the plan-cache fix + strided-batched optimization, 1.41× over BF16) remains +xserv's best-performing quantization today. diff --git a/tools/quantize_mxfp4.py b/tools/quantize_mxfp4.py new file mode 100644 index 0000000..63f07a8 --- /dev/null +++ b/tools/quantize_mxfp4.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +"""Quantize gpt-oss expert weights BF16 -> MXFP4 (W4A16 weight-only). + +MXFP4 (OCP microscaling): blocks of 32 consecutive elements along the reduction +(K) dimension share one UE8M0 (power-of-two) scale; each element is FP4 E2M1 +(values {0,±0.5,±1,±1.5,±2,±3,±4,±6}). Effective ~4.25 bits/weight. + +The decode win is purely from reading 4-bit weights from HBM (half the FP8 +traffic, a quarter of BF16); a fused kernel dequantizes on-chip to BF16. + +Output layout (per expert weight, already transposed to [E, N, K] so K is +contiguous and block-of-32 friendly — matches the cuBLASLt-FP8 transpose): + : uint8 [E, N, K//2] two E2M1 nibbles per byte (lo=even k) + _scale : uint8 [E, N, K//32] UE8M0 per 32-element block +Stored in safetensors as F8_E4M3 byte containers (xserv loads them as raw bytes). + +Usage: python quantize_mxfp4.py +""" +import argparse, json, shutil, sys +from pathlib import Path +import numpy as np + +# FP4 E2M1 representable magnitudes by 3-bit code (sign handled separately). +FP4_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32) +# Midpoints for round-to-nearest into the 8 magnitude levels. +FP4_MIDS = (FP4_LEVELS[1:] + FP4_LEVELS[:-1]) / 2.0 # 7 thresholds +BLOCK = 32 + + +def quant_block_mxfp4(w): + """w: [..., K] float32, K % 32 == 0. Returns (packed uint8 [...,K//2], + scales uint8 [...,K//32]) using per-32 UE8M0 shared scale + E2M1 elements.""" + *lead, K = w.shape + nblk = K // BLOCK + wb = w.reshape(*lead, nblk, BLOCK) + amax = np.abs(wb).max(axis=-1) # [..., nblk] + # Shared scale exponent = floor(log2(amax)) - 2 (emax_fp4 = floor(log2 6) = 2). + with np.errstate(divide="ignore"): + e = np.floor(np.log2(np.where(amax > 0, amax, 1.0))).astype(np.int32) - 2 + e = np.clip(e, -127, 127) + ue8m0 = (e + 127).astype(np.uint8) # [..., nblk] + scale = (2.0 ** e.astype(np.float32))[..., None] # [..., nblk, 1] + q = wb / scale # [..., nblk, BLOCK] + sign = (q < 0).astype(np.uint8) + mag = np.abs(q) + code = np.digitize(mag, FP4_MIDS).astype(np.uint8) # 0..7 nearest level + nib = (sign << 3) | code # 4-bit value + nib = nib.reshape(*lead, K) + lo = nib[..., 0::2] + hi = nib[..., 1::2] + packed = (lo | (hi << 4)).astype(np.uint8) # [..., K//2] + return packed, ue8m0.astype(np.uint8) + + +def dequant_mxfp4(packed, scales): + """Inverse, for the self-test. packed [...,K//2] u8, scales [...,K//32] u8.""" + *lead, Kh = packed.shape + K = Kh * 2 + lo = packed & 0x0F + hi = (packed >> 4) & 0x0F + nib = np.empty((*lead, K), dtype=np.uint8) + nib[..., 0::2] = lo + nib[..., 1::2] = hi + sign = np.where((nib >> 3) & 1 == 1, -1.0, 1.0) + mag = FP4_LEVELS[nib & 0x7] + e = scales.astype(np.int32) - 127 + scale = (2.0 ** e.astype(np.float32)) + scale = np.repeat(scale, BLOCK, axis=-1) + return (sign * mag) * scale + + +def quant_expert_tensor(t): + # t: [E, K, N] bf16 -> store as [E, N, K] MXFP4 (packed, scales) u8. + # GPU path: numpy on 19G elements is minutes; torch-on-GPU is seconds. + import torch + dev = "cuda" if torch.cuda.is_available() else "cpu" + w = t.transpose(1, 2).contiguous().to(dev, torch.float32) # [E, N, K] + E, N, K = w.shape + nblk = K // BLOCK + wb = w.view(E, N, nblk, BLOCK) + amax = wb.abs().amax(-1) # [E, N, nblk] + amax_safe = torch.where(amax > 0, amax, torch.ones_like(amax)) + e = torch.floor(torch.log2(amax_safe)) - 2 + e = e.clamp(-127, 127) + ue8m0 = (e + 127).to(torch.uint8) # [E, N, nblk] + scale = torch.exp2(e).unsqueeze(-1) # [E, N, nblk, 1] + q = wb / scale + sign = (q < 0).to(torch.uint8) + mids = torch.tensor(FP4_MIDS.tolist(), device=dev) + code = torch.bucketize(q.abs(), mids).to(torch.uint8) # 0..7 nearest level + nib = ((sign << 3) | code).view(E, N, K) + lo = nib[..., 0::2] + hi = nib[..., 1::2] + packed = (lo | (hi << 4)).to(torch.uint8) # [E, N, K/2] + return packed.cpu(), ue8m0.view(E, N, nblk).cpu() + + +def _selftest(): + rng = np.random.default_rng(0) + w = (rng.standard_normal((2, 64)) * 0.3).astype(np.float32) + p, s = quant_block_mxfp4(w) + r = dequant_mxfp4(p, s) + rel = np.abs(r - w).mean() / (np.abs(w).mean() + 1e-9) + print(f"[selftest] mean rel err {rel:.4f} (expect ~0.05-0.12 for FP4)") + assert rel < 0.2, "MXFP4 roundtrip error too high" + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("input_dir", type=Path) + ap.add_argument("output_dir", type=Path) + ap.add_argument("--selftest", action="store_true") + args = ap.parse_args() + if args.selftest: + _selftest(); return + + import torch + from safetensors.torch import load_file, save_file + out = args.output_dir; out.mkdir(parents=True, exist_ok=True) + cfg = json.load(open(args.input_dir / "config.json")) + files = sorted(args.input_dir.glob("*.safetensors")) + tensors = {} + for f in files: + tensors.update(load_file(str(f), device="cpu")) + print(f"loaded {len(tensors)} tensors") + + out_t = {} + nq = 0 + for name, t in tensors.items(): + if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.down_proj"): + print(f" mxfp4 {name} {list(t.shape)} -> [E,N,K] packed") + packed, scales = quant_expert_tensor(t) + # Store as raw bytes via float8_e4m3 container (xserv reads raw bytes). + out_t[name] = packed.view(torch.float8_e4m3fn) + out_t[name + "_scale"] = scales.view(torch.float8_e4m3fn) + nq += 1 + else: + out_t[name] = t + print(f"quantized {nq} expert tensors to MXFP4") + + save_file(out_t, str(out / "model.safetensors")) + cfg["quantization"] = "mxfp4_w4a16" + json.dump(cfg, open(out / "config.json", "w"), indent=2) + for src in args.input_dir.iterdir(): + if src.suffix == ".safetensors" or src.name == "config.json": + continue + if src.is_file() and not (out / src.name).exists(): + shutil.copy2(src, out / src.name) + print("done.") + + +if __name__ == "__main__": + main()