quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)

Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 +
per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads
the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused
dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill
(M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected
by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E].

Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to
FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB
5090 with room for KV cache.

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses
(350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization
foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt
block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 15:01:42 +08:00
parent e631a71b68
commit d33220498a
6 changed files with 480 additions and 7 deletions

153
tools/quantize_mxfp4.py Normal file
View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""Quantize gpt-oss expert weights BF16 -> MXFP4 (W4A16 weight-only).
MXFP4 (OCP microscaling): blocks of 32 consecutive elements along the reduction
(K) dimension share one UE8M0 (power-of-two) scale; each element is FP4 E2M1
(values {0,±0.5,±1,±1.5,±2,±3,±4,±6}). Effective ~4.25 bits/weight.
The decode win is purely from reading 4-bit weights from HBM (half the FP8
traffic, a quarter of BF16); a fused kernel dequantizes on-chip to BF16.
Output layout (per expert weight, already transposed to [E, N, K] so K is
contiguous and block-of-32 friendly — matches the cuBLASLt-FP8 transpose):
<name> : uint8 [E, N, K//2] two E2M1 nibbles per byte (lo=even k)
<name>_scale : uint8 [E, N, K//32] UE8M0 per 32-element block
Stored in safetensors as F8_E4M3 byte containers (xserv loads them as raw bytes).
Usage: python quantize_mxfp4.py <bf16_model_dir> <out_dir>
"""
import argparse, json, shutil, sys
from pathlib import Path
import numpy as np
# FP4 E2M1 representable magnitudes by 3-bit code (sign handled separately).
FP4_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
# Midpoints for round-to-nearest into the 8 magnitude levels.
FP4_MIDS = (FP4_LEVELS[1:] + FP4_LEVELS[:-1]) / 2.0 # 7 thresholds
BLOCK = 32
def quant_block_mxfp4(w):
"""w: [..., K] float32, K % 32 == 0. Returns (packed uint8 [...,K//2],
scales uint8 [...,K//32]) using per-32 UE8M0 shared scale + E2M1 elements."""
*lead, K = w.shape
nblk = K // BLOCK
wb = w.reshape(*lead, nblk, BLOCK)
amax = np.abs(wb).max(axis=-1) # [..., nblk]
# Shared scale exponent = floor(log2(amax)) - 2 (emax_fp4 = floor(log2 6) = 2).
with np.errstate(divide="ignore"):
e = np.floor(np.log2(np.where(amax > 0, amax, 1.0))).astype(np.int32) - 2
e = np.clip(e, -127, 127)
ue8m0 = (e + 127).astype(np.uint8) # [..., nblk]
scale = (2.0 ** e.astype(np.float32))[..., None] # [..., nblk, 1]
q = wb / scale # [..., nblk, BLOCK]
sign = (q < 0).astype(np.uint8)
mag = np.abs(q)
code = np.digitize(mag, FP4_MIDS).astype(np.uint8) # 0..7 nearest level
nib = (sign << 3) | code # 4-bit value
nib = nib.reshape(*lead, K)
lo = nib[..., 0::2]
hi = nib[..., 1::2]
packed = (lo | (hi << 4)).astype(np.uint8) # [..., K//2]
return packed, ue8m0.astype(np.uint8)
def dequant_mxfp4(packed, scales):
"""Inverse, for the self-test. packed [...,K//2] u8, scales [...,K//32] u8."""
*lead, Kh = packed.shape
K = Kh * 2
lo = packed & 0x0F
hi = (packed >> 4) & 0x0F
nib = np.empty((*lead, K), dtype=np.uint8)
nib[..., 0::2] = lo
nib[..., 1::2] = hi
sign = np.where((nib >> 3) & 1 == 1, -1.0, 1.0)
mag = FP4_LEVELS[nib & 0x7]
e = scales.astype(np.int32) - 127
scale = (2.0 ** e.astype(np.float32))
scale = np.repeat(scale, BLOCK, axis=-1)
return (sign * mag) * scale
def quant_expert_tensor(t):
# t: [E, K, N] bf16 -> store as [E, N, K] MXFP4 (packed, scales) u8.
# GPU path: numpy on 19G elements is minutes; torch-on-GPU is seconds.
import torch
dev = "cuda" if torch.cuda.is_available() else "cpu"
w = t.transpose(1, 2).contiguous().to(dev, torch.float32) # [E, N, K]
E, N, K = w.shape
nblk = K // BLOCK
wb = w.view(E, N, nblk, BLOCK)
amax = wb.abs().amax(-1) # [E, N, nblk]
amax_safe = torch.where(amax > 0, amax, torch.ones_like(amax))
e = torch.floor(torch.log2(amax_safe)) - 2
e = e.clamp(-127, 127)
ue8m0 = (e + 127).to(torch.uint8) # [E, N, nblk]
scale = torch.exp2(e).unsqueeze(-1) # [E, N, nblk, 1]
q = wb / scale
sign = (q < 0).to(torch.uint8)
mids = torch.tensor(FP4_MIDS.tolist(), device=dev)
code = torch.bucketize(q.abs(), mids).to(torch.uint8) # 0..7 nearest level
nib = ((sign << 3) | code).view(E, N, K)
lo = nib[..., 0::2]
hi = nib[..., 1::2]
packed = (lo | (hi << 4)).to(torch.uint8) # [E, N, K/2]
return packed.cpu(), ue8m0.view(E, N, nblk).cpu()
def _selftest():
rng = np.random.default_rng(0)
w = (rng.standard_normal((2, 64)) * 0.3).astype(np.float32)
p, s = quant_block_mxfp4(w)
r = dequant_mxfp4(p, s)
rel = np.abs(r - w).mean() / (np.abs(w).mean() + 1e-9)
print(f"[selftest] mean rel err {rel:.4f} (expect ~0.05-0.12 for FP4)")
assert rel < 0.2, "MXFP4 roundtrip error too high"
def main():
ap = argparse.ArgumentParser()
ap.add_argument("input_dir", type=Path)
ap.add_argument("output_dir", type=Path)
ap.add_argument("--selftest", action="store_true")
args = ap.parse_args()
if args.selftest:
_selftest(); return
import torch
from safetensors.torch import load_file, save_file
out = args.output_dir; out.mkdir(parents=True, exist_ok=True)
cfg = json.load(open(args.input_dir / "config.json"))
files = sorted(args.input_dir.glob("*.safetensors"))
tensors = {}
for f in files:
tensors.update(load_file(str(f), device="cpu"))
print(f"loaded {len(tensors)} tensors")
out_t = {}
nq = 0
for name, t in tensors.items():
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.down_proj"):
print(f" mxfp4 {name} {list(t.shape)} -> [E,N,K] packed")
packed, scales = quant_expert_tensor(t)
# Store as raw bytes via float8_e4m3 container (xserv reads raw bytes).
out_t[name] = packed.view(torch.float8_e4m3fn)
out_t[name + "_scale"] = scales.view(torch.float8_e4m3fn)
nq += 1
else:
out_t[name] = t
print(f"quantized {nq} expert tensors to MXFP4")
save_file(out_t, str(out / "model.safetensors"))
cfg["quantization"] = "mxfp4_w4a16"
json.dump(cfg, open(out / "config.json", "w"), indent=2)
for src in args.input_dir.iterdir():
if src.suffix == ".safetensors" or src.name == "config.json":
continue
if src.is_file() and not (out / src.name).exists():
shutil.copy2(src, out / src.name)
print("done.")
if __name__ == "__main__":
main()