quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)

Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 +
per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads
the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused
dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill
(M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected
by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].

Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to
FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB
5090 with room for KV cache.

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses
(350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization
foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt
block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 15:01:42 +08:00
parent e631a71b68
commit d33220498a
6 changed files with 480 additions and 7 deletions

View File

@@ -33,6 +33,7 @@ fn main() {
.file("../../csrc/moe/moe_kernels.cu")
.file("../../csrc/quantization/dequant_fp8.cu")
.file("../../csrc/quantization/quantize_fp8.cu")
.file("../../csrc/quantization/mxfp4_gemm.cu")
.compile("xserv_kernels");
println!("cargo:rerun-if-changed=../../csrc/");

View File

@@ -30,6 +30,14 @@ unsafe extern "C" {
num_rows: i32, cols: i32, tokens: i32,
stream: *mut c_void,
);
fn launch_batched_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
);
fn launch_dequant_mxfp4_to_bf16_t(
w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
);
}
// ============================================================
@@ -428,3 +436,50 @@ pub fn batched_gemm_fp8(
c
}
// ============================================================
// MXFP4 W4A16 (weight-only 4-bit) for MoE experts
// ============================================================
/// MXFP4 W4A16 batched GEMV for decode (M=1).
///
/// x: [E, K] BF16 (per-expert activation; replicated across experts)
/// w_packed: [E, N, K/2] byte tensor — two E2M1 nibbles per byte (lo = even k)
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
///
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
let e = x.shape()[0];
assert_eq!(x.shape()[x.ndim() - 1], k, "GEMV K mismatch");
let y = Tensor::empty(&[e, n], DType::BF16, x.device());
unsafe {
launch_batched_gemv_mxfp4_bf16(
x.data_ptr() as *const c_void,
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
std::ptr::null_mut(),
);
}
y
}
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
/// (the BF16 batched GEMM expects weights as [E, K, N]).
pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor {
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
unsafe {
launch_dequant_mxfp4_to_bf16_t(
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
std::ptr::null_mut(),
);
}
out
}

View File

@@ -53,6 +53,10 @@ struct GptOssBlock {
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32
// MXFP4 W4A16 expert weights (Some when running 4-bit weight-only).
// (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout.
expert_gate_up_mxfp4: Option<(Tensor, Tensor)>,
expert_down_mxfp4: Option<(Tensor, Tensor)>,
local_experts: usize,
// Activation params
glu_alpha: f32,
@@ -169,11 +173,18 @@ impl GptOss {
let local_experts = num_experts / world;
let expert_start = rank * local_experts;
let is_fp8 = gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
// MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype
// as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E],
// MXFP4 scale is 3-D [E, N, K/32].
let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false);
let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size
let hidden = gate_up_3d.shape()[1];
let inter = down_3d.shape()[1]; // intermediate_size
let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None;
let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None;
let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N)
let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] };
let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] };
// Slice the rank's range of experts as contiguous 3D tensors on GPU
let expert_gate_up_wt;
@@ -183,7 +194,24 @@ impl GptOss {
let expert_down_fp8;
let expert_down_scale_gpu;
if is_fp8 {
if is_mxfp4 {
// MXFP4 W4A16: weights already [E, N, K] packed ([E, N, K/2] bytes)
// + scales [E, N, K/32]. Slice this rank's experts (raw bytes).
let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale");
let d_s = down_scale.expect("MXFP4 model missing down_proj_scale");
let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev);
let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev);
let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev);
let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev);
expert_gate_up_mxfp4 = Some((gu_packed, gu_scl));
expert_down_mxfp4 = Some((dn_packed, dn_scl));
expert_gate_up_fp8 = None;
expert_gate_up_scale_gpu = None;
expert_down_fp8 = None;
expert_down_scale_gpu = None;
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
} else if is_fp8 {
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
// Original: [E, K, N] → Transposed: [E, N, K]
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
@@ -243,6 +271,8 @@ impl GptOss {
expert_gate_up_scale: expert_gate_up_scale_gpu,
expert_down_fp8,
expert_down_scale: expert_down_scale_gpu,
expert_gate_up_mxfp4,
expert_down_mxfp4,
local_experts,
glu_alpha,
glu_limit,
@@ -254,6 +284,7 @@ impl GptOss {
let has_norm_bias = norm_bias.is_some();
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false);
if rank == 0 {
if has_norm_bias {
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
@@ -261,6 +292,9 @@ impl GptOss {
if is_fp8 {
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
}
if is_mxfp4 {
eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)");
}
}
// Warn about unused weights that the model didn't consume
@@ -519,7 +553,20 @@ impl GptOss {
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
let gate_up = if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
// MXFP4 W4A16: decode (M=1) uses the fused 4-bit dequant GEMV; prefill
// dequantizes to BF16 then reuses the batched GEMM.
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
if num_tokens == 1 {
let x2 = x_rep.reshape(&[local_experts, k]);
xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16)
}
} else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
xserv_kernels::quantization::batched_gemm_fp8(
@@ -541,7 +588,18 @@ impl GptOss {
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
let down = if let Some(ref wt_fp8) = layer.expert_down_fp8 {
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
if num_tokens == 1 {
let a2 = activated.reshape(&[local_experts, k]);
xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16)
}
} else if let Some(ref wt_fp8) = layer.expert_down_fp8 {
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
xserv_kernels::quantization::batched_gemm_fp8(