Files
xserv/tools/quantize_fp8.py
Gahow Wang 9f1fbbb98b quantization: add FP8 E4M3 W8A16 for gpt-oss MoE expert weights
Store expert gate_up_proj and down_proj weights in FP8 E4M3 (1 byte/elem)
with per-expert FP32 scale factors. At inference, a fused CUDA kernel
dequantizes to BF16 before the existing cuBLAS batched GEMM.

Results on gpt-oss-20b (50-problem GSM8K subset):
  - FP8 TP=1: 47/50 = 94.0% (single RTX 5090, ~25 GB VRAM)
  - BF16 TP=2: 47/50 = 94.0% (requires 2× RTX 5090, ~39 GB total)

No measurable accuracy degradation. Model size: 41.8 GB → 22.7 GB (−46%).

New files:
  - tools/quantize_fp8.py: offline BF16→FP8 conversion script
  - csrc/quantization/dequant_fp8.cu: per-expert-scale dequant kernel
  - crates/xserv-kernels/src/quantization.rs: Rust FFI wrapper
  - tools/eval_gsm8k_batch.sh: GSM8K accuracy evaluation harness

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-07 19:33:07 +08:00

148 lines
5.0 KiB
Python
Executable File

#!/usr/bin/env python3
"""Quantize gpt-oss expert weights from BF16 to FP8 E4M3 (W8A16).
Usage:
python quantize_fp8.py <input_model_dir> <output_model_dir>
Converts expert gate_up_proj and down_proj weights to FP8 E4M3 with
per-expert per-matrix FP32 scale factors. All other tensors (attention,
router, embeddings, norms, biases) are kept in BF16.
The output directory contains:
- model.safetensors: quantized weights
- config.json: copy with "quantization": "fp8_e4m3" added
- All other files (tokenizer, etc.) copied as-is
"""
import argparse
import json
import shutil
import sys
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file
FP8_E4M3_MAX = 448.0 # max representable value in FP8 E4M3
def quantize_expert_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize a [num_experts, rows, cols] BF16 tensor to FP8 E4M3.
Returns (quantized_fp8, scales) where:
- quantized_fp8: [num_experts, rows, cols] torch.float8_e4m3fn
- scales: [num_experts] torch.float32
"""
assert tensor.ndim == 3, f"expected 3D, got {tensor.ndim}D"
num_experts = tensor.shape[0]
# Per-expert absmax scale
flat = tensor.view(num_experts, -1).float()
absmax = flat.abs().amax(dim=1) # [num_experts]
scales = absmax / FP8_E4M3_MAX
# Avoid division by zero for all-zero experts
scales = scales.clamp(min=1e-12)
# Scale and cast to FP8
# Reshape scales for broadcasting: [E, 1, 1]
scales_bc = scales.view(num_experts, 1, 1)
scaled = tensor.float() / scales_bc
quantized = scaled.to(torch.float8_e4m3fn)
return quantized, scales
def _is_expert_weight(name: str) -> bool:
    """Return True if *name* is an MoE expert matrix that should become FP8."""
    # Suffix match leaves companion tensors such as ``..._bias`` in BF16.
    return name.endswith((".mlp.experts.gate_up_proj", ".mlp.experts.down_proj"))


def _quantize_state_dict(
    all_tensors: dict[str, torch.Tensor],
) -> tuple[dict[str, torch.Tensor], int]:
    """Quantize the expert weights in *all_tensors*, consuming the dict.

    Returns ``(output_tensors, quantized_count)``. Each quantized ``name`` is
    stored as FP8 next to a ``name + "_scale"`` FP32 tensor. Every other tensor
    is passed through unchanged. Entries are popped from *all_tensors* as they
    are processed, so the BF16 originals can be released early and peak host
    memory stays lower.
    """
    quantized_count = 0
    output_tensors: dict[str, torch.Tensor] = {}
    for name in list(all_tensors):
        tensor = all_tensors.pop(name)
        if _is_expert_weight(name):
            print(f" Quantizing {name} {list(tensor.shape)} ...")
            q, s = quantize_expert_tensor(tensor)
            output_tensors[name] = q
            output_tensors[name + "_scale"] = s
            quantized_count += 1
        else:
            output_tensors[name] = tensor
    return output_tensors, quantized_count


def _copy_aux_files(input_dir: Path, output_dir: Path) -> None:
    """Copy top-level non-weight files (tokenizer files, etc.) to *output_dir*.

    Skips the rewritten ``config.json``, every ``*.safetensors`` file, and any
    ``*.safetensors.index.json`` shard index. The output holds a single
    ``model.safetensors``, so a copied index would point loaders at shard files
    that do not exist there. Subdirectories are not copied, and files already
    present in *output_dir* are left untouched.
    """
    for src_file in input_dir.iterdir():
        name = src_file.name
        if (
            src_file.suffix == ".safetensors"
            or name.endswith(".safetensors.index.json")
            or name == "config.json"
        ):
            continue
        dst_file = output_dir / name
        if src_file.is_file() and not dst_file.exists():
            shutil.copy2(src_file, dst_file)


def main():
    """Command-line entry point: quantize a BF16 gpt-oss model directory.

    Reads ``<input_dir>/config.json`` and every ``*.safetensors`` file in
    ``input_dir``. It replaces each ``.mlp.experts.gate_up_proj`` and
    ``.mlp.experts.down_proj`` tensor with an FP8 E4M3 version plus a
    ``<name>_scale`` FP32 tensor. Everything is written to a single
    ``<output_dir>/model.safetensors``, alongside a config tagged
    ``"quantization": "fp8_e4m3"`` and copies of the auxiliary files.

    The script exits with status 1 if the input directory or its config is
    missing, if the output directory equals the input directory, if no
    safetensors files exist, or if no expert weights were found.
    """
    parser = argparse.ArgumentParser(description="Quantize gpt-oss experts to FP8 E4M3")
    parser.add_argument("input_dir", type=Path, help="Input model directory (BF16)")
    parser.add_argument("output_dir", type=Path, help="Output model directory (FP8)")
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    if not input_dir.is_dir():
        print(f"Error: input directory {input_dir} does not exist", file=sys.stderr)
        sys.exit(1)
    # Writing into the input directory would overwrite the source weights.
    if output_dir.resolve() == input_dir.resolve():
        print("Error: output directory must differ from the input directory", file=sys.stderr)
        sys.exit(1)

    # Load config
    config_path = input_dir / "config.json"
    if not config_path.is_file():
        print(f"Error: {config_path} not found", file=sys.stderr)
        sys.exit(1)
    with open(config_path) as f:
        config = json.load(f)
    num_layers = config.get("num_hidden_layers", 0)
    num_experts = config.get("num_local_experts", 0)
    print(f"Model: {num_layers} layers, {num_experts} experts per layer")

    # Load weights (may be sharded)
    safetensor_files = sorted(input_dir.glob("*.safetensors"))
    if not safetensor_files:
        print("Error: no .safetensors files found", file=sys.stderr)
        sys.exit(1)
    print(f"Loading {len(safetensor_files)} safetensors file(s)...")
    all_tensors = {}
    for sf in safetensor_files:
        all_tensors.update(load_file(str(sf), device="cpu"))
    print(f"Loaded {len(all_tensors)} tensors")

    # Measure before quantizing, because _quantize_state_dict empties all_tensors.
    input_bytes = sum(t.numel() * t.element_size() for t in all_tensors.values())
    output_tensors, quantized_count = _quantize_state_dict(all_tensors)
    print(f"\nQuantized {quantized_count} expert weight tensors to FP8 E4M3")
    if quantized_count == 0:
        # Without any FP8 tensors, tagging the config as fp8_e4m3 would be wrong.
        print(
            "Error: no expert gate_up_proj/down_proj tensors found; "
            "is the input a BF16 gpt-oss checkpoint?",
            file=sys.stderr,
        )
        sys.exit(1)
    if num_layers and quantized_count != 2 * num_layers:
        print(
            f"Warning: expected {2 * num_layers} expert tensors "
            f"(gate_up_proj + down_proj per layer), quantized {quantized_count}",
            file=sys.stderr,
        )

    # Compute size savings
    output_bytes = sum(t.numel() * t.element_size() for t in output_tensors.values())
    if input_bytes:
        print(f"Size: {input_bytes / 1e9:.2f} GB → {output_bytes / 1e9:.2f} GB "
              f"({(1 - output_bytes / input_bytes) * 100:.1f}% reduction)")

    # Save quantized model
    output_dir.mkdir(parents=True, exist_ok=True)
    output_safetensors = output_dir / "model.safetensors"
    print(f"\nSaving to {output_safetensors} ...")
    save_file(output_tensors, str(output_safetensors))

    # Save modified config
    config["quantization"] = "fp8_e4m3"
    output_config = output_dir / "config.json"
    with open(output_config, "w") as f:
        json.dump(config, f, indent=2)
    print(f"Saved config to {output_config}")

    # Copy other files
    _copy_aux_files(input_dir, output_dir)
    stale_index = output_dir / "model.safetensors.index.json"
    if stale_index.exists():
        print(
            f"Warning: {stale_index} references shard files that are not part "
            "of this output; delete it so loaders use model.safetensors.",
            file=sys.stderr,
        )
    print("\nDone!")


if __name__ == "__main__":
    main()