quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)

Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 +
per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads
the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused
dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill
(M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected
by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].

Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to
FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB
5090 with room for KV cache.

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses
(350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization
foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt
block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 15:01:42 +08:00
parent e631a71b68
commit d33220498a
6 changed files with 480 additions and 7 deletions

View File

@@ -33,6 +33,7 @@ fn main() {
.file("../../csrc/moe/moe_kernels.cu") .file("../../csrc/moe/moe_kernels.cu")
.file("../../csrc/quantization/dequant_fp8.cu") .file("../../csrc/quantization/dequant_fp8.cu")
.file("../../csrc/quantization/quantize_fp8.cu") .file("../../csrc/quantization/quantize_fp8.cu")
.file("../../csrc/quantization/mxfp4_gemm.cu")
.compile("xserv_kernels"); .compile("xserv_kernels");
println!("cargo:rerun-if-changed=../../csrc/"); println!("cargo:rerun-if-changed=../../csrc/");

View File

@@ -30,6 +30,14 @@ unsafe extern "C" {
num_rows: i32, cols: i32, tokens: i32, num_rows: i32, cols: i32, tokens: i32,
stream: *mut c_void, stream: *mut c_void,
); );
fn launch_batched_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
);
fn launch_dequant_mxfp4_to_bf16_t(
w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void,
e: i32, n: i32, k: i32, stream: *mut c_void,
);
} }
// ============================================================ // ============================================================
@@ -428,3 +436,50 @@ pub fn batched_gemm_fp8(
c c
} }
// ============================================================
// MXFP4 W4A16 (weight-only 4-bit) for MoE experts
// ============================================================
/// MXFP4 W4A16 batched GEMV for decode (M=1).
///
/// x: [E, K] BF16 (per-expert activation; replicated across experts)
/// w_packed: [E, N, K/2] byte tensor — two E2M1 nibbles per byte (lo = even k)
/// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block
///
/// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]).
pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
let e = x.shape()[0];
assert_eq!(x.shape()[x.ndim() - 1], k, "GEMV K mismatch");
let y = Tensor::empty(&[e, n], DType::BF16, x.device());
unsafe {
launch_batched_gemv_mxfp4_bf16(
x.data_ptr() as *const c_void,
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
std::ptr::null_mut(),
);
}
y
}
/// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path
/// (the BF16 batched GEMM expects weights as [E, K, N]).
pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor {
let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device());
unsafe {
launch_dequant_mxfp4_to_bf16_t(
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32,
std::ptr::null_mut(),
);
}
out
}

View File

@@ -53,6 +53,10 @@ struct GptOssBlock {
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32 expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3 expert_down_fp8: Option<Tensor>, // [local_experts, hidden, inter] FP8E4M3
expert_down_scale: Option<Tensor>, // [local_experts] F32 expert_down_scale: Option<Tensor>, // [local_experts] F32
// MXFP4 W4A16 expert weights (Some when running 4-bit weight-only).
// (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout.
expert_gate_up_mxfp4: Option<(Tensor, Tensor)>,
expert_down_mxfp4: Option<(Tensor, Tensor)>,
local_experts: usize, local_experts: usize,
// Activation params // Activation params
glu_alpha: f32, glu_alpha: f32,
@@ -169,11 +173,18 @@ impl GptOss {
let local_experts = num_experts / world; let local_experts = num_experts / world;
let expert_start = rank * local_experts; let expert_start = rank * local_experts;
let is_fp8 = gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3; // MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype
// as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E],
// MXFP4 scale is 3-D [E, N, K/32].
let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false);
let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None;
let hidden = gate_up_3d.shape()[1]; let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None;
let inter = down_3d.shape()[1]; // intermediate_size
let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N)
let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] };
let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] };
// Slice the rank's range of experts as contiguous 3D tensors on GPU // Slice the rank's range of experts as contiguous 3D tensors on GPU
let expert_gate_up_wt; let expert_gate_up_wt;
@@ -183,7 +194,24 @@ impl GptOss {
let expert_down_fp8; let expert_down_fp8;
let expert_down_scale_gpu; let expert_down_scale_gpu;
if is_fp8 { if is_mxfp4 {
// MXFP4 W4A16: weights already [E, N, K] packed ([E, N, K/2] bytes)
// + scales [E, N, K/32]. Slice this rank's experts (raw bytes).
let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale");
let d_s = down_scale.expect("MXFP4 model missing down_proj_scale");
let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev);
let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev);
let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev);
let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev);
expert_gate_up_mxfp4 = Some((gu_packed, gu_scl));
expert_down_mxfp4 = Some((dn_packed, dn_scl));
expert_gate_up_fp8 = None;
expert_gate_up_scale_gpu = None;
expert_down_fp8 = None;
expert_down_scale_gpu = None;
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
} else if is_fp8 {
// FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell). // FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell).
// Original: [E, K, N] → Transposed: [E, N, K] // Original: [E, K, N] → Transposed: [E, N, K]
let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2); let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2);
@@ -243,6 +271,8 @@ impl GptOss {
expert_gate_up_scale: expert_gate_up_scale_gpu, expert_gate_up_scale: expert_gate_up_scale_gpu,
expert_down_fp8, expert_down_fp8,
expert_down_scale: expert_down_scale_gpu, expert_down_scale: expert_down_scale_gpu,
expert_gate_up_mxfp4,
expert_down_mxfp4,
local_experts, local_experts,
glu_alpha, glu_alpha,
glu_limit, glu_limit,
@@ -254,6 +284,7 @@ impl GptOss {
let has_norm_bias = norm_bias.is_some(); let has_norm_bias = norm_bias.is_some();
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false); let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false);
if rank == 0 { if rank == 0 {
if has_norm_bias { if has_norm_bias {
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm"); eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
@@ -261,6 +292,9 @@ impl GptOss {
if is_fp8 { if is_fp8 {
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"); eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)");
} }
if is_mxfp4 {
eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)");
}
} }
// Warn about unused weights that the model didn't consume // Warn about unused weights that the model didn't consume
@@ -519,7 +553,20 @@ impl GptOss {
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts); let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter] // 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
let gate_up = if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 { let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
// MXFP4 W4A16: decode (M=1) uses the fused 4-bit dequant GEMV; prefill
// dequantizes to BF16 then reuses the batched GEMM.
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
if num_tokens == 1 {
let x2 = x_rep.reshape(&[local_experts, k]);
xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16)
}
} else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 {
// W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM // W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM
let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep); let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep);
xserv_kernels::quantization::batched_gemm_fp8( xserv_kernels::quantization::batched_gemm_fp8(
@@ -541,7 +588,18 @@ impl GptOss {
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]); let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden] // 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
let down = if let Some(ref wt_fp8) = layer.expert_down_fp8 { let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
if num_tokens == 1 {
let a2 = activated.reshape(&[local_experts, k]);
xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k)
.reshape(&[local_experts, 1, n])
} else {
let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k);
xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16)
}
} else if let Some(ref wt_fp8) = layer.expert_down_fp8 {
// W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM // W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM
let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated); let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated);
xserv_kernels::quantization::batched_gemm_fp8( xserv_kernels::quantization::batched_gemm_fp8(

View File

@@ -0,0 +1,135 @@
#include <cuda_bf16.h>
#include <cstdint>
#include "../common.cuh"
// MXFP4 W4A16 for MoE experts. Weights stored [E, N, K] with K (reduction)
// contiguous, blocked by 32: packed 4-bit E2M1 (two nibbles/byte, lo = even k)
// + one UE8M0 scale byte per 32 elements. The decode win is reading 4-bit
// weights from HBM (half of FP8) and dequantizing on-chip to BF16.
#define MXFP4_BLOCK 32
// E2M1 magnitude by 3-bit code; bit 3 is the sign.
__device__ __constant__ float kFp4Levels[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
__device__ __forceinline__ float fp4_to_float(uint8_t code) {
float mag = kFp4Levels[code & 0x7];
return (code & 0x8) ? -mag : mag;
}
// Decode (M=1) fused GEMV, batched over experts.
// y[e, n] = sum_k x[e, k] * dequant(W[e, n, k])
// Grid: (N/TILE_N, E). Each block loads the activation x[e, :] into shared
// memory ONCE and computes TILE_N output columns from it (one warp per column),
// so the activation is read from HBM once per TILE_N outputs instead of once
// per output. Weights are unique per output and read coalesced as uint4; the
// UE8M0 block scale is hoisted to once per 32-element block.
#define MXFP4_TILE_N 8 // output columns per block (= warps per block)
__global__ void batched_gemv_mxfp4_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [E, K]
const uint8_t* __restrict__ w_packed, // [E, N, K/2]
const uint8_t* __restrict__ w_scales, // [E, N, K/32]
__nv_bfloat16* __restrict__ y, // [E, N]
int E, int N, int K
) {
extern __shared__ float xs[]; // [K] activation for this expert
int e = blockIdx.y;
int n_base = blockIdx.x * MXFP4_TILE_N;
int warp = threadIdx.x >> 5; // 0..TILE_N-1
int lane = threadIdx.x & 31;
int nthreads = blockDim.x; // TILE_N * 32
int nblk = K / MXFP4_BLOCK;
// Cooperatively stage x[e, :] into shared memory (converted to float).
const __nv_bfloat16* xe = x + (long long)e * K;
for (int k = threadIdx.x; k < K; k += nthreads) {
xs[k] = __bfloat162float(xe[k]);
}
__syncthreads();
int n = n_base + warp;
if (n >= N) return;
const uint8_t* wp = w_packed + ((long long)e * N + n) * (K >> 1);
const uint8_t* ws = w_scales + ((long long)e * N + n) * nblk;
float acc = 0.0f;
for (int blk = lane; blk < nblk; blk += 32) {
float scale = exp2f((float)((int)ws[blk] - 127));
uint4 packed = *(const uint4*)(wp + (long long)blk * 16); // 16 bytes = 32 nibbles
const uint8_t* pb = (const uint8_t*)&packed;
const float* xk = xs + blk * MXFP4_BLOCK;
#pragma unroll
for (int i = 0; i < 16; i++) {
uint8_t b = pb[i];
acc += xk[2 * i] * (fp4_to_float(b & 0xF) * scale);
acc += xk[2 * i + 1] * (fp4_to_float(b >> 4) * scale);
}
}
// Warp reduction.
#pragma unroll
for (int o = 16; o > 0; o >>= 1) {
acc += __shfl_down_sync(0xffffffffu, acc, o);
}
if (lane == 0) y[(long long)e * N + n] = __float2bfloat16(acc);
}
// Prefill fallback: dequant MXFP4 [E, N, K] -> BF16 [E, K, N] (transposed back
// to the [E, K, N] layout the BF16 batched GEMM expects). Not bandwidth-optimal,
// but prefill is compute-bound so it is not the decode hot path.
__global__ void dequant_mxfp4_to_bf16_t_kernel(
const uint8_t* __restrict__ w_packed, // [E, N, K/2]
const uint8_t* __restrict__ w_scales, // [E, N, K/32]
__nv_bfloat16* __restrict__ out, // [E, K, N]
int E, int N, int K
) {
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = (long long)E * N * K;
if (idx >= total) return;
int k = idx % K;
int n = (idx / K) % N;
int e = idx / ((long long)N * K);
int Kh = K >> 1;
int Ks = K / MXFP4_BLOCK;
uint8_t byte = w_packed[((long long)e * N + n) * Kh + (k >> 1)];
uint8_t code = (k & 1) ? (byte >> 4) : (byte & 0xF);
float scale = exp2f((float)((int)w_scales[((long long)e * N + n) * Ks + k / MXFP4_BLOCK] - 127));
float val = fp4_to_float(code) * scale;
// write to out[e, k, n]
out[((long long)e * K + k) * N + n] = __float2bfloat16(val);
}
extern "C" {
void launch_batched_gemv_mxfp4_bf16(
const void* x, const void* w_packed, const void* w_scales, void* y,
int E, int N, int K, void* stream
) {
dim3 grid((N + MXFP4_TILE_N - 1) / MXFP4_TILE_N, E);
int block = MXFP4_TILE_N * 32; // one warp per output column
size_t smem = (size_t)K * sizeof(float);
batched_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const uint8_t*)w_packed, (const uint8_t*)w_scales,
(__nv_bfloat16*)y, E, N, K
);
CUDA_CHECK_LAST_ERROR();
}
void launch_dequant_mxfp4_to_bf16_t(
const void* w_packed, const void* w_scales, void* out,
int E, int N, int K, void* stream
) {
long long total = (long long)E * N * K;
int block = 256;
long long grid = (total + block - 1) / block;
dequant_mxfp4_to_bf16_t_kernel<<<(unsigned)grid, block, 0, (cudaStream_t)stream>>>(
(const uint8_t*)w_packed, (const uint8_t*)w_scales, (__nv_bfloat16*)out,
E, N, K
);
CUDA_CHECK_LAST_ERROR();
}
}

View File

@@ -0,0 +1,71 @@
# MXFP4 W4A16 + decode-speed vs llama.cpp (gpt-oss-20b, 2×RTX 5090)
## xserv vs llama.cpp — single-stream decode (TP=2, same GPUs)
`tools/xserv_vs_llama.py` streams identical prompts through each server's
OpenAI endpoint (counting llama's `reasoning_content` as real decode tokens).
| metric | xserv FP8 | llama MXFP4 |
|---|---|---|
| Decode TPOT (medium) | 13.1 ms | **6.6 ms** (2.0× faster) |
| Throughput | 76 tok/s | **151 tok/s** |
| TTFT (short/medium) | 3550 ms | 6063 ms |
| TTFT (long, 1.6k tok) | 94 ms | **35 ms** |
llama.cpp decodes ~2× faster; prefill is comparable-to-better.
## Why — decode is memory/comm-bound, not launch-bound
Traced + measured (not assumed):
- The 24-layer decode loop is already fully async (no per-layer syncs), so kernel
launches hide behind GPU work — a CUDA graph would buy ~0.51.5 ms, not 2×.
- **TP=2→TP=4 probe**: TPOT 13.5→10.2 ms (FP8) with the *same* launch count and
*more* NCCL — confirms the bottleneck is **expert HBM traffic + all-reduce**,
not launch overhead.
- Even FP8 TP=4 (10.2 ms) can't catch llama TP=2 (6.6 ms): the gap is
*algorithmic*. llama is **sparse (top-4 of 32 experts) + 4-bit (MXFP4)**;
xserv is **dense (all 16 local experts) + 8-bit (FP8)** → ~8× the expert bytes
per token. Dense also makes xserv's long-prefill TTFT worse.
The two levers that close it: **sparse top-k MoE** (≈4×, the bigger structural
change) and **4-bit weights** (≈2×).
## MXFP4 W4A16 (this change) — correct, smallest, not yet faster than FP8
Weight-only 4-bit: expert weights are MXFP4 (E2M1 + per-32 UE8M0 scale,
`tools/quantize_mxfp4.py`); a fused kernel reads the 4-bit weights and
dequantizes on-chip to BF16. Decode uses `batched_gemv_mxfp4`; prefill (M>1)
dequantizes to BF16 then reuses the BF16 batched GEMM.
| | MXFP4 W4A16 | FP8 W8A8 | BF16 |
|---|---|---|---|
| Model size | **13 GB** | 22 GB | 39 GB |
| Greedy tokens | identical | identical | baseline |
| Decode TPOT (TP=2) | 17.0 ms | **13.5 ms** | 18.8 ms |
| Decode TPOT (TP=4) | 11.8 ms | **10.2 ms** | — |
| Prefill TTFT | 350 ms | **134 ms** | 135 ms |
- **Correct** (byte-identical greedy tokens to FP8/BF16) and **smallest
footprint** — fits one 32 GB 5090 with ample room for KV cache.
- **Not faster than FP8**: the hand-written W4A16 dequant-GEMV (no tensor cores)
is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even reading half
the bytes it stays ~23.5 ms behind FP8 at every TP. The TP=4 scaling
(17→11.8) shows it *is* partly memory-bound; a fixed per-GEMM inefficiency
dominates. Vectorized loads, hoisted scale, warp reduction, and shared-memory
activation tiling did not change it.
- **Prefill regresses** (350 vs 134 ms) — the dequant-to-BF16 fallback.
Committed as a **memory-optimization foundation**, not a decode speedup.
## To make 4-bit actually win
- **FP4 tensor cores (W4A4)** — cuBLASLt block-scaled MXFP4 GEMM
(`CUDA_R_4F_E2M1` + `CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0`, available on
sm_120). Tensor-core throughput *at* 4-bit would beat FP8. Risk: the scale
swizzle layout.
- A **Marlin-class W4A16 kernel** (register-blocked, async-copy pipelined).
- **Sparse top-k MoE** for the larger, llama-matching win.
FP8 (the plan-cache fix + strided-batched optimization, 1.41× over BF16) remains
xserv's best-performing quantization today.

153
tools/quantize_mxfp4.py Normal file
View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""Quantize gpt-oss expert weights BF16 -> MXFP4 (W4A16 weight-only).
MXFP4 (OCP microscaling): blocks of 32 consecutive elements along the reduction
(K) dimension share one UE8M0 (power-of-two) scale; each element is FP4 E2M1
(values {0,±0.5,±1,±1.5,±2,±3,±4,±6}). Effective ~4.25 bits/weight.
The decode win is purely from reading 4-bit weights from HBM (half the FP8
traffic, a quarter of BF16); a fused kernel dequantizes on-chip to BF16.
Output layout (per expert weight, already transposed to [E, N, K] so K is
contiguous and block-of-32 friendly — matches the cuBLASLt-FP8 transpose):
<name> : uint8 [E, N, K//2] two E2M1 nibbles per byte (lo=even k)
<name>_scale : uint8 [E, N, K//32] UE8M0 per 32-element block
Stored in safetensors as F8_E4M3 byte containers (xserv loads them as raw bytes).
Usage: python quantize_mxfp4.py <bf16_model_dir> <out_dir>
"""
import argparse, json, shutil, sys
from pathlib import Path
import numpy as np
# FP4 E2M1 representable magnitudes by 3-bit code (sign handled separately).
FP4_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
# Midpoints for round-to-nearest into the 8 magnitude levels.
FP4_MIDS = (FP4_LEVELS[1:] + FP4_LEVELS[:-1]) / 2.0 # 7 thresholds
BLOCK = 32
def quant_block_mxfp4(w):
"""w: [..., K] float32, K % 32 == 0. Returns (packed uint8 [...,K//2],
scales uint8 [...,K//32]) using per-32 UE8M0 shared scale + E2M1 elements."""
*lead, K = w.shape
nblk = K // BLOCK
wb = w.reshape(*lead, nblk, BLOCK)
amax = np.abs(wb).max(axis=-1) # [..., nblk]
# Shared scale exponent = floor(log2(amax)) - 2 (emax_fp4 = floor(log2 6) = 2).
with np.errstate(divide="ignore"):
e = np.floor(np.log2(np.where(amax > 0, amax, 1.0))).astype(np.int32) - 2
e = np.clip(e, -127, 127)
ue8m0 = (e + 127).astype(np.uint8) # [..., nblk]
scale = (2.0 ** e.astype(np.float32))[..., None] # [..., nblk, 1]
q = wb / scale # [..., nblk, BLOCK]
sign = (q < 0).astype(np.uint8)
mag = np.abs(q)
code = np.digitize(mag, FP4_MIDS).astype(np.uint8) # 0..7 nearest level
nib = (sign << 3) | code # 4-bit value
nib = nib.reshape(*lead, K)
lo = nib[..., 0::2]
hi = nib[..., 1::2]
packed = (lo | (hi << 4)).astype(np.uint8) # [..., K//2]
return packed, ue8m0.astype(np.uint8)
def dequant_mxfp4(packed, scales):
"""Inverse, for the self-test. packed [...,K//2] u8, scales [...,K//32] u8."""
*lead, Kh = packed.shape
K = Kh * 2
lo = packed & 0x0F
hi = (packed >> 4) & 0x0F
nib = np.empty((*lead, K), dtype=np.uint8)
nib[..., 0::2] = lo
nib[..., 1::2] = hi
sign = np.where((nib >> 3) & 1 == 1, -1.0, 1.0)
mag = FP4_LEVELS[nib & 0x7]
e = scales.astype(np.int32) - 127
scale = (2.0 ** e.astype(np.float32))
scale = np.repeat(scale, BLOCK, axis=-1)
return (sign * mag) * scale
def quant_expert_tensor(t):
# t: [E, K, N] bf16 -> store as [E, N, K] MXFP4 (packed, scales) u8.
# GPU path: numpy on 19G elements is minutes; torch-on-GPU is seconds.
import torch
dev = "cuda" if torch.cuda.is_available() else "cpu"
w = t.transpose(1, 2).contiguous().to(dev, torch.float32) # [E, N, K]
E, N, K = w.shape
nblk = K // BLOCK
wb = w.view(E, N, nblk, BLOCK)
amax = wb.abs().amax(-1) # [E, N, nblk]
amax_safe = torch.where(amax > 0, amax, torch.ones_like(amax))
e = torch.floor(torch.log2(amax_safe)) - 2
e = e.clamp(-127, 127)
ue8m0 = (e + 127).to(torch.uint8) # [E, N, nblk]
scale = torch.exp2(e).unsqueeze(-1) # [E, N, nblk, 1]
q = wb / scale
sign = (q < 0).to(torch.uint8)
mids = torch.tensor(FP4_MIDS.tolist(), device=dev)
code = torch.bucketize(q.abs(), mids).to(torch.uint8) # 0..7 nearest level
nib = ((sign << 3) | code).view(E, N, K)
lo = nib[..., 0::2]
hi = nib[..., 1::2]
packed = (lo | (hi << 4)).to(torch.uint8) # [E, N, K/2]
return packed.cpu(), ue8m0.view(E, N, nblk).cpu()
def _selftest():
rng = np.random.default_rng(0)
w = (rng.standard_normal((2, 64)) * 0.3).astype(np.float32)
p, s = quant_block_mxfp4(w)
r = dequant_mxfp4(p, s)
rel = np.abs(r - w).mean() / (np.abs(w).mean() + 1e-9)
print(f"[selftest] mean rel err {rel:.4f} (expect ~0.05-0.12 for FP4)")
assert rel < 0.2, "MXFP4 roundtrip error too high"
def main():
ap = argparse.ArgumentParser()
ap.add_argument("input_dir", type=Path)
ap.add_argument("output_dir", type=Path)
ap.add_argument("--selftest", action="store_true")
args = ap.parse_args()
if args.selftest:
_selftest(); return
import torch
from safetensors.torch import load_file, save_file
out = args.output_dir; out.mkdir(parents=True, exist_ok=True)
cfg = json.load(open(args.input_dir / "config.json"))
files = sorted(args.input_dir.glob("*.safetensors"))
tensors = {}
for f in files:
tensors.update(load_file(str(f), device="cpu"))
print(f"loaded {len(tensors)} tensors")
out_t = {}
nq = 0
for name, t in tensors.items():
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.down_proj"):
print(f" mxfp4 {name} {list(t.shape)} -> [E,N,K] packed")
packed, scales = quant_expert_tensor(t)
# Store as raw bytes via float8_e4m3 container (xserv reads raw bytes).
out_t[name] = packed.view(torch.float8_e4m3fn)
out_t[name + "_scale"] = scales.view(torch.float8_e4m3fn)
nq += 1
else:
out_t[name] = t
print(f"quantized {nq} expert tensors to MXFP4")
save_file(out_t, str(out / "model.safetensors"))
cfg["quantization"] = "mxfp4_w4a16"
json.dump(cfg, open(out / "config.json", "w"), indent=2)
for src in args.input_dir.iterdir():
if src.suffix == ".safetensors" or src.name == "config.json":
continue
if src.is_file() and not (out / src.name).exists():
shutil.copy2(src, out / src.name)
print("done.")
if __name__ == "__main__":
main()