quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)
Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 + per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill (M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E]. Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB 5090 with room for KV cache. NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses (350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt block-scaled MXFP4) or a Marlin-class kernel; see docs/benchmarks/mxfp4-and-llama-decode.md. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
135
csrc/quantization/mxfp4_gemm.cu
Normal file
135
csrc/quantization/mxfp4_gemm.cu
Normal file
@@ -0,0 +1,135 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cstdint>
|
||||
#include "../common.cuh"
|
||||
|
||||
// MXFP4 W4A16 for MoE experts. Weights stored [E, N, K] with K (reduction)
|
||||
// contiguous, blocked by 32: packed 4-bit E2M1 (two nibbles/byte, lo = even k)
|
||||
// + one UE8M0 scale byte per 32 elements. The decode win is reading 4-bit
|
||||
// weights from HBM (half of FP8) and dequantizing on-chip to BF16.
|
||||
|
||||
// Number of consecutive FP4 elements along K that share one UE8M0 scale byte.
#define MXFP4_BLOCK 32
|
||||
|
||||
// E2M1 magnitude by 3-bit code; bit 3 is the sign.
|
||||
// Indexed with (code & 0x7) by fp4_to_float(), which applies the sign separately.
// NOTE(review): a __constant__ read is only a broadcast when every lane of a
// warp reads the same entry; lanes decoding different codes may be serialized.
// Worth profiling against an arithmetic decode — TODO confirm.
__device__ __constant__ float kFp4Levels[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
|
||||
|
||||
// Decodes one 4-bit E2M1 code to float.
//   code: FP4 code in the low nibble; bits 0..2 pick the magnitude from
//         kFp4Levels and bit 3 is the sign. Higher bits are ignored.
//   returns: the signed magnitude (unscaled; the caller applies the block scale).
__device__ __forceinline__ float fp4_to_float(uint8_t code) {
    const bool negative = (code & 0x8) != 0;
    const float magnitude = kFp4Levels[code & 0x7];
    if (negative) {
        return -magnitude;
    }
    return magnitude;
}
|
||||
|
||||
// Decode (M=1) fused dequant-GEMV, batched over experts:
//     y[e, n] = sum_k x[e, k] * dequant(W[e, n, k])
//
// Launch layout (see launch_batched_gemv_mxfp4_bf16):
//   grid  = (ceil(N / MXFP4_TILE_N), E)
//   block = MXFP4_TILE_N * 32 threads, one warp per output column
//   dynamic shared memory = K * sizeof(float)
//
// Each block stages the activation x[e, :] in shared memory ONCE (converted to
// float) and computes MXFP4_TILE_N output columns from it, so the activation is
// read from HBM once per MXFP4_TILE_N outputs instead of once per output.
// Within a warp, lane l handles scale blocks l, l + 32, ...: it reads the
// block's 16 packed bytes as one uint4 (adjacent lanes read adjacent 16-byte
// chunks) and converts the UE8M0 scale once per 32-element block.
//
// Preconditions:
//   * K is a multiple of MXFP4_BLOCK. Only floor(K / MXFP4_BLOCK) blocks are
//     accumulated, and the uint4 loads need every weight row (K / 2 bytes) to
//     be a multiple of 16 bytes.
//   * w_packed is 16-byte aligned.
//   * blockDim.x is a multiple of 32 (the warp reduction uses a full mask).
#define MXFP4_TILE_N 8 // output columns per block (= warps per block)

__global__ void batched_gemv_mxfp4_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,   // [E, K]
    const uint8_t* __restrict__ w_packed,  // [E, N, K/2]
    const uint8_t* __restrict__ w_scales,  // [E, N, K/32]
    __nv_bfloat16* __restrict__ y,         // [E, N]
    int E, int N, int K
) {
    // The rotation below relies on one scale block spanning exactly one row of
    // 32 four-byte shared-memory banks (and one block per lane per step).
    static_assert(MXFP4_BLOCK == 32, "shared-memory rotation assumes 32-element blocks");

    extern __shared__ float xs[];  // [K] activation for this expert (rotated per block)
    const int e = blockIdx.y;
    const int n_base = blockIdx.x * MXFP4_TILE_N;
    const int warp = threadIdx.x >> 5;  // 0..TILE_N-1
    const int lane = threadIdx.x & 31;
    const int nthreads = blockDim.x;    // TILE_N * 32
    const int nblk = K / MXFP4_BLOCK;

    // Cooperatively stage x[e, :] into shared memory (converted to float).
    //
    // Element t of block b is stored at slot (t + b) mod 32 inside that block.
    // In the compute loop lane l owns blocks l, l + 32, ..., whose bases are 32
    // floats apart; without the rotation all 32 lanes would read the same bank
    // at every step (a 32-way conflict). With it, lane l reads bank
    // (t + l) mod 32, so the lanes of a warp hit 32 distinct banks.
    // Only whole blocks are staged: a K tail is never read by the loop below.
    const __nv_bfloat16* xe = x + (long long)e * K;
    const int k_staged = nblk * MXFP4_BLOCK;
    for (int k = threadIdx.x; k < k_staged; k += nthreads) {
        const int b = k / MXFP4_BLOCK;
        const int t = k % MXFP4_BLOCK;
        xs[b * MXFP4_BLOCK + ((t + b) & (MXFP4_BLOCK - 1))] = __bfloat162float(xe[k]);
    }
    __syncthreads();

    // One warp per output column; n is warp-uniform, so a whole warp exits
    // together and the full-mask shuffles below stay valid.
    const int n = n_base + warp;
    if (n >= N) return;

    const uint8_t* wp = w_packed + ((long long)e * N + n) * (K >> 1);
    const uint8_t* ws = w_scales + ((long long)e * N + n) * nblk;

    float acc = 0.0f;
    for (int blk = lane; blk < nblk; blk += 32) {
        // UE8M0: the scale byte is a biased power-of-two exponent (bias 127).
        const float scale = exp2f((float)((int)ws[blk] - 127));
        const uint4 packed = *(const uint4*)(wp + (long long)blk * 16);  // 16 bytes = 32 nibbles
        const uint8_t* pb = (const uint8_t*)&packed;
        const float* xk = xs + blk * MXFP4_BLOCK;
        const int rot = blk & (MXFP4_BLOCK - 1);  // rotation used when staging this block
#pragma unroll
        for (int i = 0; i < 16; i++) {
            const uint8_t b = pb[i];
            // Low nibble holds the even k, high nibble the odd k. The
            // accumulation order is unchanged, so results are bit-identical to
            // the unrotated layout.
            acc += xk[(2 * i + rot) & (MXFP4_BLOCK - 1)] * (fp4_to_float(b & 0xF) * scale);
            acc += xk[(2 * i + 1 + rot) & (MXFP4_BLOCK - 1)] * (fp4_to_float(b >> 4) * scale);
        }
    }

    // Warp reduction: lane 0 ends up with the sum over all lanes.
#pragma unroll
    for (int o = 16; o > 0; o >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, o);
    }
    if (lane == 0) y[(long long)e * N + n] = __float2bfloat16(acc);
}
|
||||
|
||||
// Prefill fallback: dequant MXFP4 [E, N, K] -> BF16 [E, K, N] (transposed back
// to the [E, K, N] layout the BF16 batched GEMM expects). Not bandwidth-optimal,
// but prefill is compute-bound so it is not the decode hot path.
//
// One output element per loop iteration, addressed by a flat index over
// [E, N, K] with k fastest. The grid-stride loop makes the result independent
// of the launch configuration: any 1-D grid covers all E * N * K elements.
//
// Preconditions: K is a multiple of MXFP4_BLOCK (row strides are K / 2 packed
// bytes and K / MXFP4_BLOCK scale bytes).
__global__ void dequant_mxfp4_to_bf16_t_kernel(
    const uint8_t* __restrict__ w_packed,  // [E, N, K/2]
    const uint8_t* __restrict__ w_scales,  // [E, N, K/32]
    __nv_bfloat16* __restrict__ out,       // [E, K, N]
    int E, int N, int K
) {
    const long long total = (long long)E * N * K;
    const long long stride = (long long)gridDim.x * blockDim.x;
    const int Kh = K >> 1;            // packed bytes per weight row
    const int Ks = K / MXFP4_BLOCK;   // scale bytes per weight row

    for (long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
         idx < total; idx += stride) {
        // Split the flat index: row = e * N + n, k = position along K.
        const long long row = idx / K;
        const int k = (int)(idx - row * K);
        const int e = (int)(row / N);
        const int n = (int)(row - (long long)e * N);

        // Two codes per byte: low nibble is the even k, high nibble the odd k.
        const uint8_t byte = w_packed[row * Kh + (k >> 1)];
        const uint8_t code = (k & 1) ? (byte >> 4) : (byte & 0xF);
        // UE8M0: the scale byte is a biased power-of-two exponent (bias 127).
        const float scale = exp2f((float)((int)w_scales[row * Ks + k / MXFP4_BLOCK] - 127));
        const float val = fp4_to_float(code) * scale;

        // write to out[e, k, n]
        out[((long long)e * K + k) * N + n] = __float2bfloat16(val);
    }
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Launches the decode (M=1) fused MXFP4 GEMV on `stream`.
//   x        : device BF16 activations, [E, K]
//   w_packed : device packed FP4 weights, [E, N, K/2]
//   w_scales : device UE8M0 scales, [E, N, K/32]
//   y        : device BF16 output, [E, N]
//   stream   : a cudaStream_t passed as void*
// The launch is asynchronous; only launch-time errors are checked here.
// Preconditions: K is a multiple of MXFP4_BLOCK; E must fit in gridDim.y.
void launch_batched_gemv_mxfp4_bf16(
    const void* x, const void* w_packed, const void* w_scales, void* y,
    int E, int N, int K, void* stream
) {
    // A grid with a zero extent is an invalid launch configuration; with no
    // experts or no output columns there is nothing to write.
    if (E <= 0 || N <= 0) return;

    const dim3 grid((N + MXFP4_TILE_N - 1) / MXFP4_TILE_N, E);
    const int block = MXFP4_TILE_N * 32;  // one warp per output column
    const size_t smem = (size_t)K * sizeof(float);

    // More than 48 KB of dynamic shared memory per block (K > 12288) must be
    // opted into per kernel, otherwise the launch is rejected.
    const size_t kDefaultDynamicSmemLimit = 48u * 1024u;
    if (smem > kDefaultDynamicSmemLimit) {
        const cudaError_t attr_status = cudaFuncSetAttribute(
            batched_gemv_mxfp4_bf16_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
        if (attr_status != cudaSuccess) {
            // The failed call is recorded as the runtime's last error; report it
            // through the file's usual check (presumably cudaGetLastError-based)
            // instead of attempting a launch that cannot succeed.
            CUDA_CHECK_LAST_ERROR();
            return;
        }
    }

    batched_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const uint8_t*)w_packed, (const uint8_t*)w_scales,
        (__nv_bfloat16*)y, E, N, K
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
// Launches the prefill-fallback dequantization on `stream`.
//   w_packed : device packed FP4 weights, [E, N, K/2]
//   w_scales : device UE8M0 scales, [E, N, K/32]
//   out      : device BF16 output, [E, K, N]
//   stream   : a cudaStream_t passed as void*
// The launch is asynchronous; only launch-time errors are checked here.
void launch_dequant_mxfp4_to_bf16_t(
    const void* w_packed, const void* w_scales, void* out,
    int E, int N, int K, void* stream
) {
    // An empty problem would produce a grid of 0 blocks, which is an invalid
    // launch configuration; there is nothing to write in that case.
    if (E <= 0 || N <= 0 || K <= 0) return;

    const long long total = (long long)E * N * K;
    const int block = 256;
    // One thread per output element, rounded up to whole blocks.
    const long long grid = (total + block - 1) / block;

    // NOTE(review): the block count is narrowed to unsigned for the launch, so
    // problems needing more blocks than a 1-D grid allows are not supported.
    dequant_mxfp4_to_bf16_t_kernel<<<(unsigned)grid, block, 0, (cudaStream_t)stream>>>(
        (const uint8_t*)w_packed, (const uint8_t*)w_scales, (__nv_bfloat16*)out,
        E, N, K
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user