Weight-only 4-bit for the gpt-oss MoE experts: weights stored MXFP4 (E2M1 + per-32-element UE8M0 block scale, tools/quantize_mxfp4.py), a fused kernel reads the 4-bit weights and dequantizes on-chip to BF16. Decode (M=1) uses a fused dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill (M>1) dequantizes to BF16 then reuses the BF16 batched GEMM. MXFP4 is detected by the scale tensor's rank (3-D [E,N,K/32]) vs FP8's 1-D [E]. Verified on dash5 (gpt-oss-20b, TP=2, 5090): byte-identical greedy tokens to FP8/BF16, smallest footprint (13 GB vs 22 GB FP8, 39 GB BF16) — fits one 32 GB 5090 with room for KV cache. NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes decode is 17.0 ms vs FP8 13.5 ms (faster than BF16 18.8 ms); prefill regresses (350 vs 134 ms, dequant fallback). Committed as a correct memory-optimization foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt block-scaled MXFP4) or a Marlin-class kernel; see docs/benchmarks/mxfp4-and-llama-decode.md. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
154 lines
6.2 KiB
Python
154 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
"""Quantize gpt-oss expert weights BF16 -> MXFP4 (W4A16 weight-only).
|
|
|
|
MXFP4 (OCP microscaling): blocks of 32 consecutive elements along the reduction
|
|
(K) dimension share one UE8M0 (power-of-two) scale; each element is FP4 E2M1
|
|
(values {0,±0.5,±1,±1.5,±2,±3,±4,±6}). Effective ~4.25 bits/weight.
|
|
|
|
Any decode win (over BF16) comes purely from reading 4-bit weights from HBM
(half the FP8 traffic, a quarter of BF16); a fused kernel dequantizes on-chip to BF16.
|
|
|
|
Output layout (per expert weight, already transposed to [E, N, K] so K is
|
|
contiguous and block-of-32 friendly — matches the cuBLASLt-FP8 transpose):
|
|
<name> : uint8 [E, N, K//2] two E2M1 nibbles per byte (lo=even k)
|
|
<name>_scale : uint8 [E, N, K//32] UE8M0 per 32-element block
|
|
Stored in safetensors as F8_E4M3 byte containers (xserv loads them as raw bytes).
|
|
|
|
Usage: python quantize_mxfp4.py <bf16_model_dir> <out_dir>
|
|
"""
|
|
import argparse, json, shutil, sys
|
|
from pathlib import Path
|
|
import numpy as np
|
|
|
|
# Magnitudes an FP4 E2M1 element can take, indexed by its 3-bit code
# (the sign bit is stored separately as bit 3 of the nibble).
FP4_LEVELS = np.asarray(
    (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0),
    dtype=np.float32,
)
# Decision thresholds for round-to-nearest: the 7 midpoints between
# neighbouring magnitude levels.
FP4_MIDS = 0.5 * (FP4_LEVELS[:-1] + FP4_LEVELS[1:])
# Number of consecutive K-elements that share one UE8M0 scale.
BLOCK = 32
|
|
|
|
|
|
def quant_block_mxfp4(w):
    """Quantize ``w`` to MXFP4 along its last axis (NumPy reference path).

    Args:
        w: array of shape [..., K] (float32 per the original contract); K must
            be a multiple of ``BLOCK`` (32).

    Returns:
        ``(packed, scales)`` where ``packed`` is uint8 [..., K//2] holding two
        E2M1 nibbles per byte (low nibble = even index) and ``scales`` is
        uint8 [..., K//32] holding the biased (+127) power-of-two exponent
        shared by each 32-element block.

    Raises:
        ValueError: if K is not a multiple of ``BLOCK``.

    Magnitudes that fall exactly on a midpoint between two levels are assigned
    to the lower level (``right=True``). That is the rule ``torch.bucketize``
    applies by default in ``quant_expert_tensor``, so this reference breaks
    ties the same way as the GPU encoder.
    """
    *lead, K = w.shape
    if K % BLOCK != 0:
        raise ValueError(
            f"last dimension ({K}) must be a multiple of {BLOCK} for MXFP4 blocks"
        )
    nblk = K // BLOCK
    blocks = w.reshape(*lead, nblk, BLOCK)
    amax = np.abs(blocks).max(axis=-1)                  # [..., nblk]

    # Shared scale exponent = floor(log2(amax)) - 2 (emax_fp4 = floor(log2 6) = 2).
    # All-zero blocks substitute 1.0 so log2 is always taken of a positive value.
    safe_amax = np.where(amax > 0, amax, 1.0)
    exponent = np.floor(np.log2(safe_amax)).astype(np.int32) - 2
    exponent = np.clip(exponent, -127, 127)
    scale_bytes = (exponent + 127).astype(np.uint8)     # [..., nblk]
    scale = (2.0 ** exponent.astype(np.float32))[..., None]  # [..., nblk, 1]

    scaled = blocks / scale                             # [..., nblk, BLOCK]
    sign_bits = (scaled < 0).astype(np.uint8)
    # right=True: bins[i-1] < x <= bins[i], i.e. exact midpoints go to the
    # lower level; anything above the last midpoint saturates at code 7 (6.0).
    codes = np.digitize(np.abs(scaled), FP4_MIDS, right=True).astype(np.uint8)

    # Nibble layout: bit 3 = sign, bits 0..2 = magnitude code.
    nibbles = ((sign_bits << 3) | codes).reshape(*lead, K)
    low = nibbles[..., 0::2]
    high = nibbles[..., 1::2]
    packed = (low | (high << 4)).astype(np.uint8)       # [..., K//2]
    return packed, scale_bytes
|
|
|
|
|
|
def dequant_mxfp4(packed, scales):
    """Expand MXFP4 bytes back to real values (inverse of the quantizer; used by the self-test).

    Args:
        packed: uint8 [..., K//2], two 4-bit codes per byte (low nibble first).
        scales: uint8 [..., K//32], biased (+127) power-of-two exponent per block.

    Returns:
        Array of shape [..., K] with the reconstructed values.
    """
    *lead, half_k = packed.shape
    k_dim = 2 * half_k

    # Re-interleave the nibbles: element 2i is the low nibble of byte i,
    # element 2i+1 is the high nibble.
    pairs = np.stack((packed & 0x0F, (packed >> 4) & 0x0F), axis=-1)
    nibbles = pairs.reshape(*lead, k_dim).astype(np.uint8)

    # Bit 3 carries the sign, bits 0..2 index the magnitude table.
    sign = np.where((nibbles & 0x8) != 0, -1.0, 1.0)
    magnitude = FP4_LEVELS[nibbles & 0x7]

    # Undo the +127 bias and broadcast each block's scale over its 32 elements.
    exponents = scales.astype(np.int32) - 127
    block_scale = np.repeat(2.0 ** exponents.astype(np.float32), BLOCK, axis=-1)
    return (sign * magnitude) * block_scale
|
|
|
|
|
|
def quant_expert_tensor(t):
    """Quantize one stacked expert weight tensor to MXFP4 with torch.

    Args:
        t: torch tensor, expected [E, K, N] bf16 (per the original comment);
            dims 1 and 2 are swapped so the stored layout is [E, N, K].

    Returns:
        ``(packed, scales)`` as CPU uint8 tensors: ``packed`` is [E, N, K/2]
        (two nibbles per byte, low nibble = even k) and ``scales`` is
        [E, N, K/32] (exponent + 127 per 32-element block).

    Runs on CUDA when available, otherwise on CPU. The original note gives the
    motivation: numpy on 19G elements is minutes; torch-on-GPU is seconds.
    """
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weights = t.transpose(1, 2).contiguous().to(device, torch.float32)  # [E, N, K]
    n_experts, n_rows, k_dim = weights.shape
    n_blocks = k_dim // BLOCK
    blocks = weights.view(n_experts, n_rows, n_blocks, BLOCK)

    # Per-block shared exponent: floor(log2(amax)) - 2, with all-zero blocks
    # treated as amax == 1 so log2 stays finite.
    block_amax = blocks.abs().amax(-1)                       # [E, N, nblk]
    block_amax = torch.where(block_amax > 0, block_amax, torch.ones_like(block_amax))
    exponent = (torch.floor(torch.log2(block_amax)) - 2).clamp(-127, 127)
    scale_bytes = (exponent + 127).to(torch.uint8)           # [E, N, nblk]

    scaled = blocks / torch.exp2(exponent).unsqueeze(-1)     # [E, N, nblk, BLOCK]
    thresholds = torch.tensor(FP4_MIDS.tolist(), device=device)
    codes = torch.bucketize(scaled.abs(), thresholds).to(torch.uint8)  # 0..7
    sign_bits = (scaled < 0).to(torch.uint8)

    # Nibble = sign bit (bit 3) | magnitude code; then pack two per byte.
    nibbles = ((sign_bits << 3) | codes).view(n_experts, n_rows, k_dim)
    even_k = nibbles[..., 0::2]
    odd_k = nibbles[..., 1::2]
    packed = (even_k | (odd_k << 4)).to(torch.uint8)         # [E, N, K/2]
    return packed.cpu(), scale_bytes.view(n_experts, n_rows, n_blocks).cpu()
|
|
|
|
|
|
def _selftest():
    """Round-trip a small seeded random matrix through the NumPy quantizer.

    Prints the mean relative reconstruction error and asserts it is below 0.2.
    """
    gen = np.random.default_rng(0)
    original = (gen.standard_normal((2, 64)) * 0.3).astype(np.float32)

    packed, scales = quant_block_mxfp4(original)
    restored = dequant_mxfp4(packed, scales)

    mean_abs_err = np.abs(restored - original).mean()
    # The small epsilon guards the division if the reference mean were zero.
    rel = mean_abs_err / (np.abs(original).mean() + 1e-9)
    print(f"[selftest] mean rel err {rel:.4f} (expect ~0.05-0.12 for FP4)")
    assert rel < 0.2, "MXFP4 roundtrip error too high"
|
|
|
|
|
|
def main():
    """CLI entry point: convert a checkpoint's expert weights to MXFP4.

    Usage:
        quantize_mxfp4.py <input_dir> <output_dir>
        quantize_mxfp4.py --selftest

    Steps performed:
      1. Load every ``*.safetensors`` file in ``input_dir`` onto the CPU.
      2. Quantize tensors whose name ends in ``mlp.experts.gate_up_proj`` or
         ``mlp.experts.down_proj``; pass all other tensors through unchanged.
      3. Write everything to a single ``output_dir/model.safetensors`` and a
         ``config.json`` with ``"quantization": "mxfp4_w4a16"`` added.
      4. Copy the remaining regular files of ``input_dir`` that do not already
         exist in ``output_dir``.

    Exits through ``argparse`` (status 2) when the directories are missing and
    ``--selftest`` was not given, and with an error message when ``input_dir``
    contains no ``*.safetensors`` files.
    """
    ap = argparse.ArgumentParser()
    # Optional at the parser level so that `--selftest` can run without
    # directories; they are enforced below for a real conversion.
    ap.add_argument("input_dir", type=Path, nargs="?")
    ap.add_argument("output_dir", type=Path, nargs="?")
    ap.add_argument("--selftest", action="store_true")
    args = ap.parse_args()
    if args.selftest:
        _selftest()
        return
    if args.input_dir is None or args.output_dir is None:
        ap.error("the following arguments are required: input_dir, output_dir")

    # Imported lazily so the self-test needs only numpy.
    import torch
    from safetensors.torch import load_file, save_file

    out = args.output_dir
    out.mkdir(parents=True, exist_ok=True)
    with open(args.input_dir / "config.json", encoding="utf-8") as fh:
        cfg = json.load(fh)

    files = sorted(args.input_dir.glob("*.safetensors"))
    if not files:
        # Otherwise an empty model.safetensors would be written silently.
        raise SystemExit(f"no *.safetensors files found in {args.input_dir}")
    tensors = {}
    for f in files:
        tensors.update(load_file(str(f), device="cpu"))
    print(f"loaded {len(tensors)} tensors")

    expert_suffixes = ("mlp.experts.gate_up_proj", "mlp.experts.down_proj")
    out_t = {}
    nq = 0
    for name, t in tensors.items():
        if name.endswith(expert_suffixes):
            print(f"  mxfp4 {name} {list(t.shape)} -> [E,N,K] packed")
            packed, scales = quant_expert_tensor(t)
            # Store as raw bytes via float8_e4m3 container (xserv reads raw bytes).
            out_t[name] = packed.view(torch.float8_e4m3fn)
            out_t[name + "_scale"] = scales.view(torch.float8_e4m3fn)
            nq += 1
        else:
            out_t[name] = t
    print(f"quantized {nq} expert tensors to MXFP4")

    save_file(out_t, str(out / "model.safetensors"))
    cfg["quantization"] = "mxfp4_w4a16"
    with open(out / "config.json", "w", encoding="utf-8") as fh:
        json.dump(cfg, fh, indent=2)

    for src in args.input_dir.iterdir():
        if src.suffix == ".safetensors" or src.name == "config.json":
            continue
        if src.name.endswith(".safetensors.index.json"):
            # The shard index refers to the input shard files, which are not
            # copied (all tensors were merged into one model.safetensors), so
            # carrying it over would leave a stale index in the output.
            continue
        if src.is_file() and not (out / src.name).exists():
            shutil.copy2(src, out / src.name)
    print("done.")
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|