#!/usr/bin/env python3
"""Quantize gpt-oss expert weights BF16 -> MXFP4 (W4A16 weight-only).

MXFP4 (OCP microscaling): blocks of 32 consecutive elements along the
reduction (K) dimension share one UE8M0 (power-of-two) scale; each element is
FP4 E2M1 (values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}). Effective ~4.25
bits/weight (4 bits per element plus one 8-bit scale per 32 elements).

The decode win is purely from reading 4-bit weights from HBM (half the FP8
traffic, a quarter of BF16); a fused kernel dequantizes on-chip to BF16.

Elements are rounded to the nearest E2M1 value with ties to even (the
IEEE 754 default), saturating at ±6. The NumPy reference path used by the
self-test and the torch path used for real checkpoints apply the same rule,
so they emit identical bytes.

Output layout (per expert weight, already transposed to [E, N, K] so K is
contiguous and block-of-32 friendly — matches the cuBLASLt-FP8 transpose):

    <name>        uint8 [E, N, K//2]   two E2M1 nibbles per byte (lo = even k)
    <name>_scale  uint8 [E, N, K//32]  UE8M0 per 32-element block (exponent + 127)

Stored in safetensors as F8_E4M3 byte containers (xserv loads them as raw
bytes). Every other tensor is passed through unchanged.

Usage:
    python quantize_mxfp4.py <input_dir> <output_dir>
    python quantize_mxfp4.py --selftest
"""
import argparse
import json
import shutil
import sys
from pathlib import Path

import numpy as np

# FP4 E2M1 representable magnitudes, indexed by the 3-bit code (sign bit is
# separate). The code's low bit is the mantissa bit, so "even code" == "even
# mantissa" for ties-to-even rounding.
FP4_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
# Midpoints between adjacent levels: the 7 round-to-nearest decision thresholds.
FP4_MIDS = (FP4_LEVELS[1:] + FP4_LEVELS[:-1]) / 2.0
BLOCK = 32  # elements sharing one UE8M0 scale
# Largest E2M1 exponent: floor(log2(6)) = 2. Subtracting it from the block's
# floor(log2(amax)) puts the scaled block max in [4, 8); values above 6
# saturate to the top code.
FP4_EMAX = 2
# Tensor-name suffixes of the expert weights that get quantized.
EXPERT_SUFFIXES = ("mlp.experts.gate_up_proj", "mlp.experts.down_proj")


def _check_k(K):
    """Raise ValueError unless the reduction dim K is a multiple of BLOCK."""
    if K % BLOCK != 0:
        raise ValueError(f"reduction dim K={K} must be a multiple of {BLOCK}")


def _require(cond, msg):
    """Raise AssertionError(msg) unless cond (unlike `assert`, survives python -O)."""
    if not cond:
        raise AssertionError(msg)


def _round_e2m1_np(mag):
    """Map non-negative magnitudes to E2M1 codes 0..7 (nearest, ties to even).

    A value exactly on midpoint j lies between codes j and j+1; it goes to
    whichever of the two is even. Values past the last midpoint saturate to 7.
    """
    below = np.digitize(mag, FP4_MIDS, right=True)  # count of midpoints < mag
    at_or_below = np.digitize(mag, FP4_MIDS)        # count of midpoints <= mag
    tie_up = (at_or_below > below) & (below % 2 == 1)
    return (below + tie_up).astype(np.uint8)


def _round_e2m1_torch(mag, mids):
    """Torch twin of _round_e2m1_np; `mids` is FP4_MIDS as a float32 tensor."""
    import torch

    below = torch.bucketize(mag, mids, out_int32=True)                    # midpoints < mag
    at_or_below = torch.bucketize(mag, mids, out_int32=True, right=True)  # midpoints <= mag
    tie_up = (at_or_below > below) & (below % 2 == 1)
    return (below + tie_up).to(torch.uint8)


def quant_block_mxfp4(w):
    """Quantize `w` to MXFP4 along its last axis (NumPy reference path).

    Args:
        w: float array of shape [..., K]; K must be a multiple of 32.

    Returns:
        (packed, scales): `packed` is uint8 [..., K//2], element k stored in
        the low nibble of byte k//2 when k is even and the high nibble when
        odd; `scales` is uint8 [..., K//32] holding UE8M0 (exponent + 127).

    Raises:
        ValueError: if K is not a multiple of 32.
    """
    *lead, K = w.shape
    _check_k(K)
    nblk = K // BLOCK
    wb = w.reshape(*lead, nblk, BLOCK)
    amax = np.abs(wb).max(axis=-1)  # [..., nblk]
    # Shared exponent = floor(log2(amax)) - emax; all-zero blocks use amax = 1.
    e = np.floor(np.log2(np.where(amax > 0, amax, 1.0))).astype(np.int32) - FP4_EMAX
    e = np.clip(e, -127, 127)  # keeps the biased byte in 0..254 (255 is never emitted)
    ue8m0 = (e + 127).astype(np.uint8)  # [..., nblk]
    scale = (2.0 ** e.astype(np.float32))[..., None]  # [..., nblk, 1]
    q = wb / scale  # exact: scale is a power of two
    sign = (q < 0).astype(np.uint8)
    code = _round_e2m1_np(np.abs(q))
    nib = ((sign << 3) | code).reshape(*lead, K)  # 4-bit value per element
    packed = (nib[..., 0::2] | (nib[..., 1::2] << 4)).astype(np.uint8)  # [..., K//2]
    return packed, ue8m0


def dequant_mxfp4(packed, scales):
    """Inverse of quant_block_mxfp4, for the self-test.

    Args:
        packed: uint8 [..., K//2] nibble pairs (low nibble = even k).
        scales: uint8 [..., K//32] UE8M0 biased exponents.

    Returns:
        Float array [..., K] of the decoded values.
    """
    *lead, Kh = packed.shape
    K = Kh * 2
    nib = np.empty((*lead, K), dtype=np.uint8)
    nib[..., 0::2] = packed & 0x0F
    nib[..., 1::2] = (packed >> 4) & 0x0F
    sign = np.where((nib >> 3) & 1 == 1, -1.0, 1.0)
    mag = FP4_LEVELS[nib & 0x7]
    e = scales.astype(np.int32) - 127
    scale = np.repeat(2.0 ** e.astype(np.float32), BLOCK, axis=-1)
    return (sign * mag) * scale


def _quant_mxfp4_torch(w, mids):
    """Torch twin of quant_block_mxfp4 for a float32 tensor w [..., K].

    Returns (packed uint8 [..., K//2], scales uint8 [..., K//32]) on w's device.
    K is assumed to have been validated by the caller.
    """
    import torch

    *lead, K = w.shape
    nblk = K // BLOCK
    wb = w.reshape(*lead, nblk, BLOCK)
    amax = wb.abs().amax(dim=-1)  # [..., nblk]
    amax_safe = torch.where(amax > 0, amax, torch.ones_like(amax))
    e = (torch.floor(torch.log2(amax_safe)) - FP4_EMAX).clamp(-127, 127)
    ue8m0 = (e + 127).to(torch.uint8)
    q = wb / torch.exp2(e).unsqueeze(-1)  # [..., nblk, BLOCK]
    sign = (q < 0).to(torch.uint8)
    code = _round_e2m1_torch(q.abs(), mids)
    nib = ((sign << 3) | code).reshape(*lead, K)
    packed = nib[..., 0::2] | (nib[..., 1::2] << 4)  # [..., K//2]
    return packed, ue8m0


def quant_expert_tensor(t, experts_per_chunk=8):
    """Quantize a stacked expert weight t [E, K, N] to MXFP4 stored as [E, N, K].

    Experts are processed `experts_per_chunk` at a time, so peak device memory
    holds a few experts' float32 intermediates rather than the whole tensor.

    Args:
        t: 3-D float tensor (e.g. BF16) shaped [E, K, N]; K % 32 == 0.
        experts_per_chunk: experts converted per device round-trip (>= 1).

    Returns:
        (packed, scales) CPU uint8 tensors of shape [E, N, K//2] and
        [E, N, K//32], byte-identical to quant_block_mxfp4 on t^T.

    Raises:
        ValueError: on a non-3-D input, bad K, or experts_per_chunk < 1.
    """
    # GPU path: numpy on 19G elements is minutes; torch-on-GPU is seconds.
    import torch

    if t.dim() != 3:
        raise ValueError(f"expected a 3-D [E, K, N] tensor, got shape {list(t.shape)}")
    if experts_per_chunk < 1:
        raise ValueError("experts_per_chunk must be >= 1")
    E, K, N = t.shape
    _check_k(K)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mids = torch.tensor(FP4_MIDS.tolist(), dtype=torch.float32, device=dev)
    packed = torch.empty((E, N, K // 2), dtype=torch.uint8)
    scales = torch.empty((E, N, K // BLOCK), dtype=torch.uint8)
    for e0 in range(0, E, experts_per_chunk):
        e1 = min(e0 + experts_per_chunk, E)
        # Ship the narrow-dtype slice to the device, then widen and lay out
        # as [c, N, K] with K contiguous there.
        w = t[e0:e1].to(dev).to(torch.float32).transpose(1, 2).contiguous()
        p, s = _quant_mxfp4_torch(w, mids)
        packed[e0:e1] = p.cpu()
        scales[e0:e1] = s.cpu()
    return packed, scales


def _selftest():
    """Round-trip and tie-rounding checks; also checks torch/numpy parity if torch exists."""
    rng = np.random.default_rng(0)
    w = (rng.standard_normal((2, 64)) * 0.3).astype(np.float32)
    p, s = quant_block_mxfp4(w)
    r = dequant_mxfp4(p, s)
    rel = np.abs(r - w).mean() / (np.abs(w).mean() + 1e-9)
    print(f"[selftest] mean rel err {rel:.4f} (expect ~0.05-0.12 for FP4)")
    _require(rel < 0.2, "MXFP4 roundtrip error too high")

    # amax = 6 gives a unit scale, so these are exact midpoints (plus 6 itself).
    ties = np.array([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, 6.0], dtype=np.float32)
    wt = np.zeros((1, BLOCK), dtype=np.float32)
    wt[0, : ties.size] = ties
    got = dequant_mxfp4(*quant_block_mxfp4(wt))[0, : ties.size]
    _require(
        np.array_equal(got, [0.0, 1.0, 1.0, 2.0, 2.0, 4.0, 4.0, 6.0]),
        f"ties-to-even rounding wrong: {got}",
    )
    print("[selftest] ties-to-even rounding ok")

    try:
        import torch
    except ImportError:
        print("[selftest] torch not installed; skipping torch/numpy parity check", file=sys.stderr)
        return
    # Multiples of 1/64 in BF16 land on many exact midpoints after scaling.
    x = torch.from_numpy(np.round(rng.standard_normal((3, 64, 5)) * 64) / 64).to(torch.bfloat16)
    pt, st = quant_expert_tensor(x, experts_per_chunk=2)
    ref_in = np.ascontiguousarray(x.float().numpy().transpose(0, 2, 1))  # [E, N, K]
    pn, sn = quant_block_mxfp4(ref_in)
    _require(
        np.array_equal(pt.numpy(), pn) and np.array_equal(st.numpy(), sn),
        "torch path disagrees with the numpy reference",
    )
    print("[selftest] torch path matches numpy reference")


def main():
    """CLI: quantize a gpt-oss checkpoint directory's expert weights, or run --selftest."""
    ap = argparse.ArgumentParser(description="Quantize gpt-oss expert weights BF16 -> MXFP4.")
    ap.add_argument("input_dir", type=Path, nargs="?")
    ap.add_argument("output_dir", type=Path, nargs="?")
    ap.add_argument("--selftest", action="store_true")
    args = ap.parse_args()
    if args.selftest:
        _selftest()
        return
    if args.input_dir is None or args.output_dir is None:
        ap.error("input_dir and output_dir are required unless --selftest is given")

    import torch
    from safetensors.torch import load_file, save_file

    src_dir, out = args.input_dir, args.output_dir
    if out.resolve() == src_dir.resolve():
        ap.error("output_dir must differ from input_dir (outputs would overwrite the source)")
    files = sorted(src_dir.glob("*.safetensors"))
    if not files:
        ap.error(f"no *.safetensors files found in {src_dir}")
    cfg = json.loads((src_dir / "config.json").read_text(encoding="utf-8"))
    out.mkdir(parents=True, exist_ok=True)

    tensors = {}
    for f in files:
        tensors.update(load_file(str(f), device="cpu"))
    print(f"loaded {len(tensors)} tensors")

    out_t = {}
    nq = 0
    for name in list(tensors):
        # Pop so a quantized BF16 original is freed once `t` is rebound.
        t = tensors.pop(name)
        if name.endswith(EXPERT_SUFFIXES):
            print(f" mxfp4 {name} {list(t.shape)} -> [E,N,K] packed")
            packed, scales = quant_expert_tensor(t)
            # Store as raw bytes via float8_e4m3 container (xserv reads raw bytes).
            out_t[name] = packed.view(torch.float8_e4m3fn)
            out_t[name + "_scale"] = scales.view(torch.float8_e4m3fn)
            nq += 1
        else:
            out_t[name] = t
    print(f"quantized {nq} expert tensors to MXFP4")

    save_file(out_t, str(out / "model.safetensors"))
    cfg["quantization"] = "mxfp4_w4a16"
    with open(out / "config.json", "w", encoding="utf-8") as fh:
        json.dump(cfg, fh, indent=2)

    for src in src_dir.iterdir():
        # Weights were rewritten into one model.safetensors; a copied shard
        # index would point loaders at shard files that do not exist in out.
        if (
            src.suffix == ".safetensors"
            or src.name.endswith(".safetensors.index.json")
            or src.name == "config.json"
        ):
            continue
        if src.is_file() and not (out / src.name).exists():
            shutil.copy2(src, out / src.name)
    print("done.")


if __name__ == "__main__":
    main()