quantization: add FP8 E4M3 W8A16 for gpt-oss MoE expert weights
Store expert gate_up_proj and down_proj weights in FP8 E4M3 (1 byte/elem) with per-expert FP32 scale factors. At inference, a fused CUDA kernel dequantizes to BF16 before the existing cuBLAS batched GEMM. Results on gpt-oss-20b (50-problem GSM8K subset): - FP8 TP=1: 47/50 = 94.0% (single RTX 5090, ~25 GB VRAM) - BF16 TP=2: 47/50 = 94.0% (requires 2× RTX 5090, ~39 GB total) No measurable accuracy degradation. Model size: 41.8 GB → 22.7 GB (−46%). New files: - tools/quantize_fp8.py: offline BF16→FP8 conversion script - csrc/quantization/dequant_fp8.cu: per-expert-scale dequant kernel - crates/xserv-kernels/src/quantization.rs: Rust FFI wrapper - tools/eval_gsm8k_batch.sh: GSM8K accuracy evaluation harness Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
107
tools/eval_gsm8k_batch.sh
Executable file
107
tools/eval_gsm8k_batch.sh
Executable file
@@ -0,0 +1,107 @@
|
||||
#!/bin/bash
|
||||
# GSM8K evaluation via repeated xserv-chat invocations.
|
||||
# Usage: eval_gsm8k_batch.sh <model-dir> <limit> [gpu_id] [tp]
|
||||
set -uo pipefail
|
||||
export PATH=/usr/local/cuda/bin:$PATH
|
||||
source ~/.cargo/env 2>/dev/null || true
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}
|
||||
|
||||
MODEL_DIR="${1:?Usage: $0 <model-dir> <limit> [gpu_id] [tp]}"
|
||||
LIMIT="${2:-50}"
|
||||
GPU="${3:-0}"
|
||||
TP="${4:-1}"
|
||||
export CUDA_VISIBLE_DEVICES=$GPU
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
XSERV_CHAT="$SCRIPT_DIR/../target/release/xserv-chat"
|
||||
DATA="$SCRIPT_DIR/bench/data/gsm8k.json"
|
||||
SYSTEM='Solve the problem step by step. Put your final numeric answer inside \\boxed{}.'
|
||||
|
||||
echo "=== GSM8K Eval: model=$MODEL_DIR, limit=$LIMIT, gpu=$GPU, tp=$TP ==="
|
||||
|
||||
TMPDIR=$(mktemp -d)
|
||||
trap "rm -rf $TMPDIR" EXIT
|
||||
|
||||
# Generate problem files
|
||||
python3 -c "
|
||||
import json
|
||||
problems = json.load(open('$DATA'))[:$LIMIT]
|
||||
for i, p in enumerate(problems):
|
||||
with open(f'$TMPDIR/{i:04d}.txt', 'w') as f:
|
||||
f.write(p['problem'].replace(chr(10), ' '))
|
||||
with open(f'$TMPDIR/{i:04d}.gold', 'w') as f:
|
||||
f.write(p['answer'])
|
||||
print(f'{len(problems)} problems prepared')
|
||||
"
|
||||
|
||||
TOTAL=$(ls "$TMPDIR"/*.txt 2>/dev/null | wc -l)
|
||||
CORRECT=0
|
||||
SCORED=0
|
||||
START_TIME=$(date +%s)
|
||||
|
||||
TP_FLAG=""
|
||||
if [ "$TP" -gt 1 ]; then
|
||||
TP_FLAG="--tp $TP"
|
||||
fi
|
||||
|
||||
for f in $(ls "$TMPDIR"/*.txt | sort); do
|
||||
IDX=$(basename "$f" .txt)
|
||||
GOLD=$(cat "$TMPDIR/${IDX}.gold")
|
||||
QUESTION=$(cat "$f")
|
||||
|
||||
# Run single-question xserv-chat
|
||||
RAW_OUT=$(echo "$QUESTION" | timeout 120 "$XSERV_CHAT" "$MODEL_DIR" \
|
||||
--max-tokens 512 --max-seq-len 1024 \
|
||||
--system "$SYSTEM" --no-color $TP_FLAG 2>/dev/null || true)
|
||||
|
||||
# Extract predicted answer
|
||||
PRED=$(echo "$RAW_OUT" | python3 -c "
|
||||
import re, sys
|
||||
text = sys.stdin.read()
|
||||
# Extract everything after 'assistant>'
|
||||
if 'assistant>' in text:
|
||||
text = text.split('assistant>', 1)[1]
|
||||
if 'user>' in text:
|
||||
text = text[:text.rindex('user>')]
|
||||
boxed = re.findall(r'\\\\boxed\s*\{([^{}]*)\}', text)
|
||||
if boxed:
|
||||
nums = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', boxed[-1])
|
||||
if nums:
|
||||
s = nums[-1].replace(',','')
|
||||
f = float(s)
|
||||
print(str(int(f)) if f == int(f) else f'{f:g}')
|
||||
sys.exit(0)
|
||||
nums = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
|
||||
if nums:
|
||||
s = nums[-1].replace(',','')
|
||||
f = float(s)
|
||||
print(str(int(f)) if f == int(f) else f'{f:g}')
|
||||
else:
|
||||
print('NONE')
|
||||
" 2>/dev/null || echo "NONE")
|
||||
|
||||
# Normalize gold
|
||||
GOLD_NORM=$(python3 -c "
|
||||
s='$GOLD'.replace(',','').strip()
|
||||
f=float(s)
|
||||
print(str(int(f)) if f==int(f) else f'{f:g}')
|
||||
" 2>/dev/null || echo "$GOLD")
|
||||
|
||||
SCORED=$((SCORED + 1))
|
||||
if [ "$PRED" = "$GOLD_NORM" ]; then
|
||||
CORRECT=$((CORRECT + 1))
|
||||
echo "[✓] $IDX gold=$GOLD_NORM pred=$PRED"
|
||||
elif [ "$PRED" = "NONE" ]; then
|
||||
echo "[E] $IDX gold=$GOLD_NORM pred=NONE (no output)"
|
||||
else
|
||||
echo "[✗] $IDX gold=$GOLD_NORM pred=$PRED"
|
||||
fi
|
||||
done
|
||||
|
||||
END_TIME=$(date +%s)
|
||||
ELAPSED=$((END_TIME - START_TIME))
|
||||
|
||||
echo "------------------------------------------------------------------------"
|
||||
python3 -c "print(f'Results: $CORRECT/$SCORED correct = {$CORRECT/$SCORED*100:.1f}% accuracy')"
|
||||
echo "Wall time: ${ELAPSED}s (avg $((ELAPSED / TOTAL))s/problem)"
|
||||
echo "=== Done ==="
|
||||
147
tools/quantize_fp8.py
Executable file
147
tools/quantize_fp8.py
Executable file
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Quantize gpt-oss expert weights from BF16 to FP8 E4M3 (W8A16).
|
||||
|
||||
Usage:
|
||||
python quantize_fp8.py <input_model_dir> <output_model_dir>
|
||||
|
||||
Converts expert gate_up_proj and down_proj weights to FP8 E4M3 with
|
||||
per-expert per-matrix FP32 scale factors. All other tensors (attention,
|
||||
router, embeddings, norms, biases) are kept in BF16.
|
||||
|
||||
The output directory contains:
|
||||
- model.safetensors: quantized weights
|
||||
- config.json: copy with "quantization": "fp8_e4m3" added
|
||||
- All other files (tokenizer, etc.) copied as-is
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
FP8_E4M3_MAX = 448.0 # max representable value in FP8 E4M3
|
||||
|
||||
|
||||
def quantize_expert_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Quantize a [num_experts, rows, cols] BF16 tensor to FP8 E4M3.
|
||||
|
||||
Returns (quantized_fp8, scales) where:
|
||||
- quantized_fp8: [num_experts, rows, cols] torch.float8_e4m3fn
|
||||
- scales: [num_experts] torch.float32
|
||||
"""
|
||||
assert tensor.ndim == 3, f"expected 3D, got {tensor.ndim}D"
|
||||
num_experts = tensor.shape[0]
|
||||
|
||||
# Per-expert absmax scale
|
||||
flat = tensor.view(num_experts, -1).float()
|
||||
absmax = flat.abs().amax(dim=1) # [num_experts]
|
||||
scales = absmax / FP8_E4M3_MAX
|
||||
# Avoid division by zero for all-zero experts
|
||||
scales = scales.clamp(min=1e-12)
|
||||
|
||||
# Scale and cast to FP8
|
||||
# Reshape scales for broadcasting: [E, 1, 1]
|
||||
scales_bc = scales.view(num_experts, 1, 1)
|
||||
scaled = tensor.float() / scales_bc
|
||||
quantized = scaled.to(torch.float8_e4m3fn)
|
||||
|
||||
return quantized, scales
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Quantize gpt-oss experts to FP8 E4M3")
|
||||
parser.add_argument("input_dir", type=Path, help="Input model directory (BF16)")
|
||||
parser.add_argument("output_dir", type=Path, help="Output model directory (FP8)")
|
||||
args = parser.parse_args()
|
||||
|
||||
input_dir = args.input_dir
|
||||
output_dir = args.output_dir
|
||||
|
||||
if not input_dir.exists():
|
||||
print(f"Error: input directory {input_dir} does not exist", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load config
|
||||
config_path = input_dir / "config.json"
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
num_layers = config.get("num_hidden_layers", 0)
|
||||
num_experts = config.get("num_local_experts", 0)
|
||||
print(f"Model: {num_layers} layers, {num_experts} experts per layer")
|
||||
|
||||
# Load weights (may be sharded)
|
||||
safetensor_files = sorted(input_dir.glob("*.safetensors"))
|
||||
if not safetensor_files:
|
||||
print("Error: no .safetensors files found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loading {len(safetensor_files)} safetensors file(s)...")
|
||||
all_tensors = {}
|
||||
for sf in safetensor_files:
|
||||
all_tensors.update(load_file(str(sf), device="cpu"))
|
||||
print(f"Loaded {len(all_tensors)} tensors")
|
||||
|
||||
# Quantize expert weights
|
||||
quantized_count = 0
|
||||
output_tensors = {}
|
||||
|
||||
for name, tensor in all_tensors.items():
|
||||
# Check if this is an expert weight to quantize
|
||||
if ".mlp.experts.gate_up_proj" in name and name.endswith("gate_up_proj"):
|
||||
print(f" Quantizing {name} {list(tensor.shape)} ...")
|
||||
q, s = quantize_expert_tensor(tensor)
|
||||
output_tensors[name] = q
|
||||
output_tensors[name + "_scale"] = s
|
||||
quantized_count += 1
|
||||
elif ".mlp.experts.down_proj" in name and name.endswith("down_proj"):
|
||||
print(f" Quantizing {name} {list(tensor.shape)} ...")
|
||||
q, s = quantize_expert_tensor(tensor)
|
||||
output_tensors[name] = q
|
||||
output_tensors[name + "_scale"] = s
|
||||
quantized_count += 1
|
||||
else:
|
||||
output_tensors[name] = tensor
|
||||
|
||||
print(f"\nQuantized {quantized_count} expert weight tensors to FP8 E4M3")
|
||||
|
||||
# Compute size savings
|
||||
input_bytes = sum(t.numel() * t.element_size() for t in all_tensors.values())
|
||||
output_bytes = sum(t.numel() * t.element_size() for t in output_tensors.values())
|
||||
print(f"Size: {input_bytes / 1e9:.2f} GB → {output_bytes / 1e9:.2f} GB "
|
||||
f"({(1 - output_bytes / input_bytes) * 100:.1f}% reduction)")
|
||||
|
||||
# Save quantized model
|
||||
output_safetensors = output_dir / "model.safetensors"
|
||||
print(f"\nSaving to {output_safetensors} ...")
|
||||
save_file(output_tensors, str(output_safetensors))
|
||||
|
||||
# Save modified config
|
||||
config["quantization"] = "fp8_e4m3"
|
||||
output_config = output_dir / "config.json"
|
||||
with open(output_config, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
print(f"Saved config to {output_config}")
|
||||
|
||||
# Copy other files
|
||||
for src_file in input_dir.iterdir():
|
||||
if src_file.suffix == ".safetensors":
|
||||
continue
|
||||
if src_file.name == "config.json":
|
||||
continue
|
||||
dst_file = output_dir / src_file.name
|
||||
if src_file.is_file() and not dst_file.exists():
|
||||
shutil.copy2(src_file, dst_file)
|
||||
|
||||
print("\nDone!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user