quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)

Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 +
per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads
the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused
dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill
(M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected
by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].

Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to
FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB
5090 with room for KV cache.

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses
(350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization
foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt
block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 15:01:42 +08:00
parent e631a71b68
commit d33220498a
6 changed files with 480 additions and 7 deletions

View File

@@ -0,0 +1,135 @@
#include <cuda_bf16.h>
#include <cstdint>
#include "../common.cuh"
// MXFP4 W4A16 for MoE experts. Weights stored [E, N, K] with K (reduction)
// contiguous, blocked by 32: packed 4-bit E2M1 (two nibbles/byte, lo = even k)
// + one UE8M0 scale byte per 32 elements. The decode win is reading 4-bit
// weights from HBM (half of FP8) and dequantizing on-chip to BF16.
#define MXFP4_BLOCK 32
// E2M1 magnitude by 3-bit code; bit 3 is the sign.
__device__ __constant__ float kFp4Levels[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
__device__ __forceinline__ float fp4_to_float(uint8_t code) {
float mag = kFp4Levels[code & 0x7];
return (code & 0x8) ? -mag : mag;
}
// Decode (M=1) fused GEMV, batched over experts.
// y[e, n] = sum_k x[e, k] * dequant(W[e, n, k])
// Grid: (N/TILE_N, E). Each block loads the activation x[e, :] into shared
// memory ONCE and computes TILE_N output columns from it (one warp per column),
// so the activation is read from HBM once per TILE_N outputs instead of once
// per output. Weights are unique per output and read coalesced as uint4; the
// UE8M0 block scale is hoisted to once per 32-element block.
#define MXFP4_TILE_N 8 // output columns per block (= warps per block)
__global__ void batched_gemv_mxfp4_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [E, K]
const uint8_t* __restrict__ w_packed, // [E, N, K/2]
const uint8_t* __restrict__ w_scales, // [E, N, K/32]
__nv_bfloat16* __restrict__ y, // [E, N]
int E, int N, int K
) {
extern __shared__ float xs[]; // [K] activation for this expert
int e = blockIdx.y;
int n_base = blockIdx.x * MXFP4_TILE_N;
int warp = threadIdx.x >> 5; // 0..TILE_N-1
int lane = threadIdx.x & 31;
int nthreads = blockDim.x; // TILE_N * 32
int nblk = K / MXFP4_BLOCK;
// Cooperatively stage x[e, :] into shared memory (converted to float).
const __nv_bfloat16* xe = x + (long long)e * K;
for (int k = threadIdx.x; k < K; k += nthreads) {
xs[k] = __bfloat162float(xe[k]);
}
__syncthreads();
int n = n_base + warp;
if (n >= N) return;
const uint8_t* wp = w_packed + ((long long)e * N + n) * (K >> 1);
const uint8_t* ws = w_scales + ((long long)e * N + n) * nblk;
float acc = 0.0f;
for (int blk = lane; blk < nblk; blk += 32) {
float scale = exp2f((float)((int)ws[blk] - 127));
uint4 packed = *(const uint4*)(wp + (long long)blk * 16); // 16 bytes = 32 nibbles
const uint8_t* pb = (const uint8_t*)&packed;
const float* xk = xs + blk * MXFP4_BLOCK;
#pragma unroll
for (int i = 0; i < 16; i++) {
uint8_t b = pb[i];
acc += xk[2 * i] * (fp4_to_float(b & 0xF) * scale);
acc += xk[2 * i + 1] * (fp4_to_float(b >> 4) * scale);
}
}
// Warp reduction.
#pragma unroll
for (int o = 16; o > 0; o >>= 1) {
acc += __shfl_down_sync(0xffffffffu, acc, o);
}
if (lane == 0) y[(long long)e * N + n] = __float2bfloat16(acc);
}
// Prefill fallback: dequant MXFP4 [E, N, K] -> BF16 [E, K, N] (transposed back
// to the [E, K, N] layout the BF16 batched GEMM expects). Not bandwidth-optimal,
// but prefill is compute-bound so it is not the decode hot path.
__global__ void dequant_mxfp4_to_bf16_t_kernel(
const uint8_t* __restrict__ w_packed, // [E, N, K/2]
const uint8_t* __restrict__ w_scales, // [E, N, K/32]
__nv_bfloat16* __restrict__ out, // [E, K, N]
int E, int N, int K
) {
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = (long long)E * N * K;
if (idx >= total) return;
int k = idx % K;
int n = (idx / K) % N;
int e = idx / ((long long)N * K);
int Kh = K >> 1;
int Ks = K / MXFP4_BLOCK;
uint8_t byte = w_packed[((long long)e * N + n) * Kh + (k >> 1)];
uint8_t code = (k & 1) ? (byte >> 4) : (byte & 0xF);
float scale = exp2f((float)((int)w_scales[((long long)e * N + n) * Ks + k / MXFP4_BLOCK] - 127));
float val = fp4_to_float(code) * scale;
// write to out[e, k, n]
out[((long long)e * K + k) * N + n] = __float2bfloat16(val);
}
extern "C" {
void launch_batched_gemv_mxfp4_bf16(
const void* x, const void* w_packed, const void* w_scales, void* y,
int E, int N, int K, void* stream
) {
dim3 grid((N + MXFP4_TILE_N - 1) / MXFP4_TILE_N, E);
int block = MXFP4_TILE_N * 32; // one warp per output column
size_t smem = (size_t)K * sizeof(float);
batched_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const uint8_t*)w_packed, (const uint8_t*)w_scales,
(__nv_bfloat16*)y, E, N, K
);
CUDA_CHECK_LAST_ERROR();
}
void launch_dequant_mxfp4_to_bf16_t(
const void* w_packed, const void* w_scales, void* out,
int E, int N, int K, void* stream
) {
long long total = (long long)E * N * K;
int block = 256;
long long grid = (total + block - 1) / block;
dequant_mxfp4_to_bf16_t_kernel<<<(unsigned)grid, block, 0, (cudaStream_t)stream>>>(
(const uint8_t*)w_packed, (const uint8_t*)w_scales, (__nv_bfloat16*)out,
E, N, K
);
CUDA_CHECK_LAST_ERROR();
}
}