#!/usr/bin/env python3 """Quantize gpt-oss expert weights from BF16 to FP8 E4M3 (W8A16). Usage: python quantize_fp8.py Converts expert gate_up_proj and down_proj weights to FP8 E4M3 with per-expert per-matrix FP32 scale factors. All other tensors (attention, router, embeddings, norms, biases) are kept in BF16. The output directory contains: - model.safetensors: quantized weights - config.json: copy with "quantization": "fp8_e4m3" added - All other files (tokenizer, etc.) copied as-is """ import argparse import json import shutil import sys from pathlib import Path import torch from safetensors.torch import load_file, save_file FP8_E4M3_MAX = 448.0 # max representable value in FP8 E4M3 def quantize_expert_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Quantize a [num_experts, rows, cols] BF16 tensor to FP8 E4M3. Returns (quantized_fp8, scales) where: - quantized_fp8: [num_experts, rows, cols] torch.float8_e4m3fn - scales: [num_experts] torch.float32 """ assert tensor.ndim == 3, f"expected 3D, got {tensor.ndim}D" num_experts = tensor.shape[0] # Per-expert absmax scale flat = tensor.view(num_experts, -1).float() absmax = flat.abs().amax(dim=1) # [num_experts] scales = absmax / FP8_E4M3_MAX # Avoid division by zero for all-zero experts scales = scales.clamp(min=1e-12) # Scale and cast to FP8 # Reshape scales for broadcasting: [E, 1, 1] scales_bc = scales.view(num_experts, 1, 1) scaled = tensor.float() / scales_bc quantized = scaled.to(torch.float8_e4m3fn) return quantized, scales def main(): parser = argparse.ArgumentParser(description="Quantize gpt-oss experts to FP8 E4M3") parser.add_argument("input_dir", type=Path, help="Input model directory (BF16)") parser.add_argument("output_dir", type=Path, help="Output model directory (FP8)") args = parser.parse_args() input_dir = args.input_dir output_dir = args.output_dir if not input_dir.exists(): print(f"Error: input directory {input_dir} does not exist", file=sys.stderr) sys.exit(1) output_dir.mkdir(parents=True, exist_ok=True) # Load config config_path = input_dir / "config.json" with open(config_path) as f: config = json.load(f) num_layers = config.get("num_hidden_layers", 0) num_experts = config.get("num_local_experts", 0) print(f"Model: {num_layers} layers, {num_experts} experts per layer") # Load weights (may be sharded) safetensor_files = sorted(input_dir.glob("*.safetensors")) if not safetensor_files: print("Error: no .safetensors files found", file=sys.stderr) sys.exit(1) print(f"Loading {len(safetensor_files)} safetensors file(s)...") all_tensors = {} for sf in safetensor_files: all_tensors.update(load_file(str(sf), device="cpu")) print(f"Loaded {len(all_tensors)} tensors") # Quantize expert weights quantized_count = 0 output_tensors = {} for name, tensor in all_tensors.items(): # Check if this is an expert weight to quantize if ".mlp.experts.gate_up_proj" in name and name.endswith("gate_up_proj"): print(f" Quantizing {name} {list(tensor.shape)} ...") q, s = quantize_expert_tensor(tensor) output_tensors[name] = q output_tensors[name + "_scale"] = s quantized_count += 1 elif ".mlp.experts.down_proj" in name and name.endswith("down_proj"): print(f" Quantizing {name} {list(tensor.shape)} ...") q, s = quantize_expert_tensor(tensor) output_tensors[name] = q output_tensors[name + "_scale"] = s quantized_count += 1 else: output_tensors[name] = tensor print(f"\nQuantized {quantized_count} expert weight tensors to FP8 E4M3") # Compute size savings input_bytes = sum(t.numel() * t.element_size() for t in all_tensors.values()) output_bytes = sum(t.numel() * t.element_size() for t in output_tensors.values()) print(f"Size: {input_bytes / 1e9:.2f} GB → {output_bytes / 1e9:.2f} GB " f"({(1 - output_bytes / input_bytes) * 100:.1f}% reduction)") # Save quantized model output_safetensors = output_dir / "model.safetensors" print(f"\nSaving to {output_safetensors} ...") save_file(output_tensors, str(output_safetensors)) # Save modified config config["quantization"] = "fp8_e4m3" output_config = output_dir / "config.json" with open(output_config, "w") as f: json.dump(config, f, indent=2) print(f"Saved config to {output_config}") # Copy other files for src_file in input_dir.iterdir(): if src_file.suffix == ".safetensors": continue if src_file.name == "config.json": continue dst_file = output_dir / src_file.name if src_file.is_file() and not dst_file.exists(): shutil.copy2(src_file, dst_file) print("\nDone!") if __name__ == "__main__": main()