diff --git a/crates/xserv-kernels/build.rs b/crates/xserv-kernels/build.rs
index 4512550..e94f147 100644
--- a/crates/xserv-kernels/build.rs
+++ b/crates/xserv-kernels/build.rs
@@ -30,6 +30,7 @@ fn main() {
         .file("../../csrc/attention/paged_attention.cu")
         .file("../../csrc/attention/reshape_and_cache.cu")
         .file("../../csrc/moe/moe_kernels.cu")
+        .file("../../csrc/quantization/dequant_fp8.cu")
         .compile("xserv_kernels");
 
     println!("cargo:rerun-if-changed=../../csrc/");
diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs
index ee19356..adf2ca1 100644
--- a/crates/xserv-kernels/src/lib.rs
+++ b/crates/xserv-kernels/src/lib.rs
@@ -6,6 +6,7 @@ pub mod embedding;
 pub mod gemm;
 pub mod layernorm;
 pub mod moe;
+pub mod quantization;
 pub mod rmsnorm;
 pub mod rope;
 pub mod softmax;
diff --git a/crates/xserv-kernels/src/quantization.rs b/crates/xserv-kernels/src/quantization.rs
new file mode 100644
index 0000000..31db4ab
--- /dev/null
+++ b/crates/xserv-kernels/src/quantization.rs
@@ -0,0 +1,67 @@
+use std::ffi::c_void;
+use xserv_tensor::{DType, Tensor};
+
+unsafe extern "C" {
+    fn launch_dequant_fp8e4m3_to_bf16(
+        src: *const c_void,
+        scales: *const c_void,
+        dst: *mut c_void,
+        num_experts: i32, rows: i32, cols: i32,
+        stream: *mut c_void,
+    );
+}
+
+/// Dequantize a 3D FP8 E4M3 tensor to BF16 using per-expert FP32 scales.
+///
+/// src:    [num_experts, rows, cols] FP8E4M3, contiguous, GPU
+/// scales: [num_experts] F32, contiguous, GPU
+///
+/// Returns: [num_experts, rows, cols] BF16
+///
+/// # Panics
+///
+/// Panics if a rank, dtype or contiguity precondition above is violated, if
+/// `src` and `scales` are on different devices, or if a dimension does not
+/// fit in the `i32` arguments of the CUDA launcher.
+pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
+    assert_eq!(src.ndim(), 3, "dequant_fp8_to_bf16: src must be 3D");
+    assert_eq!(src.dtype(), DType::FP8E4M3);
+    assert!(src.is_contiguous());
+    assert_eq!(scales.ndim(), 1);
+    assert_eq!(scales.dtype(), DType::F32);
+    assert!(scales.is_contiguous());
+    assert_eq!(
+        src.device(),
+        scales.device(),
+        "dequant_fp8_to_bf16: src and scales must be on the same device"
+    );
+
+    let num_experts = src.shape()[0];
+    let rows = src.shape()[1];
+    let cols = src.shape()[2];
+    assert_eq!(scales.shape()[0], num_experts);
+
+    // The launcher takes `int` dimensions; fail loudly instead of letting an
+    // `as` cast silently truncate an oversized dimension.
+    let dim_i32 = |dim: usize| {
+        i32::try_from(dim).expect("dequant_fp8_to_bf16: dimension exceeds i32::MAX")
+    };
+
+    let out = Tensor::empty(&[num_experts, rows, cols], DType::BF16, src.device());
+
+    // SAFETY: `src` and `scales` have the rank, dtype and contiguity asserted
+    // above, and `out` was just allocated with the same shape as `src`.
+    // NOTE(review): this assumes `data_ptr()` addresses the first element of
+    // each tensor on the GPU — confirm for tensors with a non-zero offset.
+    unsafe {
+        launch_dequant_fp8e4m3_to_bf16(
+            src.data_ptr() as *const c_void,
+            scales.data_ptr() as *const c_void,
+            out.data_ptr() as *mut c_void,
+            dim_i32(num_experts), dim_i32(rows), dim_i32(cols),
+            std::ptr::null_mut(),
+        );
+    }
+
+    out
+}
diff --git a/crates/xserv-model/src/gpt_oss.rs b/crates/xserv-model/src/gpt_oss.rs
index ac15705..5b69eef 100644
--- a/crates/xserv-model/src/gpt_oss.rs
+++ b/crates/xserv-model/src/gpt_oss.rs
@@ -43,10 +43,15 @@ struct GptOssBlock {
     router_wt: Tensor,
     router_bias: Tensor,
     // 3D expert weights for batched GEMM (contiguous on GPU)
-    expert_gate_up_wt: Tensor, // [local_experts, hidden, 2*inter]
+    expert_gate_up_wt: Tensor, // [local_experts, hidden, 2*inter] BF16
     expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
-    expert_down_wt: Tensor, // [local_experts, inter, hidden]
+    expert_down_wt: Tensor, // [local_experts, inter, hidden] BF16
     expert_down_bias: Tensor, // [local_experts, hidden]
+    // FP8 quantized expert weights (Some when running FP8 W8A16)
+    expert_gate_up_fp8: Option<Tensor>,   // [local_experts, hidden, 2*inter] FP8E4M3
+    expert_gate_up_scale: Option<Tensor>, // [local_experts] F32
+    expert_down_fp8: Option<Tensor>,      // [local_experts, inter, hidden] FP8E4M3
+    expert_down_scale: Option<Tensor>,    // [local_experts] F32
     local_experts: usize,
     // Activation params
     glu_alpha: f32,
@@ -156,17 +161,49 @@ impl GptOss {
             let down_3d = take(&mut w, &format!("{p}.mlp.experts.down_proj"));
             let down_bias_2d = take(&mut w, &format!("{p}.mlp.experts.down_proj_bias"));
 
+            // FP8 scale tensors (present only in FP8-quantized models)
+            let gate_up_scale = w.remove(&format!("{p}.mlp.experts.gate_up_proj_scale"));
+            let down_scale = w.remove(&format!("{p}.mlp.experts.down_proj_scale"));
+
             let local_experts = num_experts / world;
             let expert_start = rank * local_experts;
 
+            let is_fp8 = gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
+
             let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size
             let hidden = gate_up_3d.shape()[1];
             let inter = down_3d.shape()[1]; // intermediate_size
 
             // Slice the rank's range of experts as contiguous 3D tensors on GPU
-            let expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
+            let expert_gate_up_wt;
+            let expert_down_wt;
+            let expert_gate_up_fp8;
+            let expert_gate_up_scale_gpu;
+            let expert_down_fp8;
+            let expert_down_scale_gpu;
+
+            if is_fp8 {
+                // FP8 path: load quantized weights and scales
+                expert_gate_up_fp8 = Some(slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev));
+                expert_down_fp8 = Some(slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev));
+                // Scales: [num_experts] F32 → slice to [local_experts]
+                let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
+                let d_s = down_scale.expect("FP8 model missing down_proj_scale");
+                expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
+                expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
+                // Dummy BF16 tensors (never read in FP8 path)
+                expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
+                expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
+            } else {
+                // BF16 path: existing behavior
+                expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
+                expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
+                expert_gate_up_fp8 = None;
+                expert_gate_up_scale_gpu = None;
+                expert_down_fp8 = None;
+                expert_down_scale_gpu = None;
+            }
             let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
-            let expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
             let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
 
             xserv_cuda::allocator::cached_trim();
@@ -198,6 +235,10 @@ impl GptOss {
                 expert_gate_up_bias,
                 expert_down_wt,
                 expert_down_bias,
+                expert_gate_up_fp8,
+                expert_gate_up_scale: expert_gate_up_scale_gpu,
+                expert_down_fp8,
+                expert_down_scale: expert_down_scale_gpu,
                 local_experts,
                 glu_alpha,
                 glu_limit,
@@ -208,10 +249,14 @@ impl GptOss {
         let local_num_kv_heads = config.num_kv_heads() / world;
 
         let has_norm_bias = norm_bias.is_some();
+        let is_fp8 = layers.first().is_some_and(|l| l.expert_gate_up_fp8.is_some());
         if rank == 0 {
             if has_norm_bias {
                 eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
             }
+            if is_fp8 {
+                eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A16 mode)");
+            }
         }
 
         // Warn about unused weights that the model didn't consume
@@ -470,7 +515,12 @@ impl GptOss {
         let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
 
         // 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
-        let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt);
+        let gate_up_wt = if let Some(ref fp8) = layer.expert_gate_up_fp8 {
+            xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_gate_up_scale.as_ref().expect("FP8 gate_up weights loaded without scales"))
+        } else {
+            layer.expert_gate_up_wt.clone()
+        };
+        let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &gate_up_wt);
 
         // 5. Bias add: gate_up += expert_gate_up_bias (in-place)
         xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
@@ -484,7 +534,12 @@ impl GptOss {
         let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
 
         // 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
-        let down = xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt);
+        let down_wt = if let Some(ref fp8) = layer.expert_down_fp8 {
+            xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_down_scale.as_ref().expect("FP8 down weights loaded without scales"))
+        } else {
+            layer.expert_down_wt.clone()
+        };
+        let down = xserv_kernels::moe::batched_gemm_strided(&activated, &down_wt);
 
         // 8. Bias add: down += expert_down_bias (in-place)
         xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
@@ -581,6 +636,28 @@ fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
     Tensor::from_slice(&shard, &[local])
 }
 
+/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
+fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
+    assert_eq!(t.ndim(), 3);
+    let host = t.to_device(Device::Cpu);
+    let elem_size = t.dtype().size_bytes();
+    let raw = host.as_raw_bytes();
+    let stride = rows * cols * elem_size;
+    let offset = start * stride;
+    let slice = &raw[offset..offset + count * stride];
+    Tensor::from_raw_bytes(slice, &[count, rows, cols], t.dtype())
+}
+
+/// Slice scale tensor [num_experts] F32 → [count] starting at `start`.
+fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor {
+    assert_eq!(t.ndim(), 1);
+    assert_eq!(t.dtype(), xserv_tensor::DType::F32);
+    let host = t.to_device(Device::Cpu);
+    let data = host.as_slice::<f32>();
+    let slice = data[start..start + count].to_vec();
+    Tensor::from_slice(&slice, &[count])
+}
+
 /// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
 fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
     assert_eq!(t.ndim(), 3);
diff --git a/crates/xserv-model/src/loader.rs b/crates/xserv-model/src/loader.rs
index 00c6815..77b0096 100644
--- a/crates/xserv-model/src/loader.rs
+++ b/crates/xserv-model/src/loader.rs
@@ -19,6 +19,7 @@ pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor>
             safetensors::Dtype::F32 => DType::F32,
             safetensors::Dtype::F16 => DType::F16,
             safetensors::Dtype::BF16 => DType::BF16,
+            safetensors::Dtype::F8_E4M3 => DType::FP8E4M3,
             other => {
                 eprintln!("skipping tensor {name}: unsupported dtype {other:?}");
                 continue;
@@ -83,5 +84,8 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
             };
             Tensor::from_slice(bfs, shape)
         }
+        DType::FP8E4M3 => {
+            Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3)
+        }
     }
 }
diff --git a/crates/xserv-tensor/src/dtype.rs b/crates/xserv-tensor/src/dtype.rs
index 058f81b..95d90ec 100644
--- a/crates/xserv-tensor/src/dtype.rs
+++ b/crates/xserv-tensor/src/dtype.rs
@@ -5,6 +5,7 @@ pub enum DType {
     F32,
     F16,
     BF16,
+    FP8E4M3,
 }
 
 impl DType {
@@ -13,6 +14,7 @@ impl DType {
             DType::F32 => 4,
             DType::F16 => 2,
             DType::BF16 => 2,
+            DType::FP8E4M3 => 1,
         }
     }
 
@@ -21,6 +23,7 @@ impl DType {
             DType::F32 => "f32",
             DType::F16 => "f16",
             DType::BF16 => "bf16",
+            DType::FP8E4M3 => "fp8e4m3",
         }
     }
 }
diff --git a/crates/xserv-tensor/src/tensor.rs b/crates/xserv-tensor/src/tensor.rs
index 33965f6..8c88cfb 100644
--- a/crates/xserv-tensor/src/tensor.rs
+++ b/crates/xserv-tensor/src/tensor.rs
@@ -52,6 +52,25 @@ impl Tensor {
         }
     }
 
+    /// Create a tensor from raw bytes. Used for dtypes without a Rust type
+    /// (e.g. FP8 E4M3) where we store the bit pattern as-is.
+    pub fn from_raw_bytes(data: &[u8], shape: &[usize], dtype: DType) -> Self {
+        let numel: usize = shape.iter().product();
+        assert_eq!(
+            data.len(),
+            numel * dtype.size_bytes(),
+            "raw bytes length {} != expected {} (numel={} * elem_size={})",
+            data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes()
+        );
+        Self {
+            storage: Storage::cpu(data.to_vec()),
+            shape: Dims::from_slice(shape),
+            strides: shape::contiguous_strides(shape),
+            offset: 0,
+            dtype,
+        }
+    }
+
     pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
         let numel = shape::num_elements(shape);
         let len_bytes = numel * dtype.size_bytes();
@@ -87,6 +106,7 @@ impl Tensor {
             DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape),
             DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape),
             DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape),
+            DType::FP8E4M3 => panic!("ones() not supported for FP8E4M3"),
         }
     }
 
@@ -265,6 +285,17 @@ impl Tensor {
         unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) }
     }
 
+    /// Raw byte access for dtypes without a Rust type (e.g. FP8).
+    pub fn as_raw_bytes(&self) -> &[u8] {
+        assert!(self.is_contiguous(), "as_raw_bytes requires contiguous");
+        assert_eq!(self.device(), Device::Cpu, "as_raw_bytes requires CPU");
+        let bytes = self.storage.as_cpu_bytes();
+        let elem_size = self.dtype.size_bytes();
+        let start = self.offset * elem_size;
+        let len = self.numel() * elem_size;
+        &bytes[start..start + len]
+    }
+
     /// Raw pointer to storage start (for GPU kernel launch).
     pub fn data_ptr(&self) -> *const u8 {
         match self.device() {
diff --git a/csrc/quantization/dequant_fp8.cu b/csrc/quantization/dequant_fp8.cu
new file mode 100644
index 0000000..cd98ca5
--- /dev/null
+++ b/csrc/quantization/dequant_fp8.cu
@@ -0,0 +1,53 @@
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+#include "../common.cuh"
+
+// Dequantize FP8 E4M3 → BF16 with per-expert (per-batch-slice) FP32 scale.
+//
+// Input:  src    [num_experts, rows, cols] FP8 E4M3 (1 byte each)
+//         scales [num_experts] FP32
+// Output: dst    [num_experts, rows, cols] BF16
+//
+// Each element: dst[e, r, c] = bf16( float(src[e, r, c]) * scales[e] )
+
+__global__ void dequant_fp8e4m3_to_bf16_kernel(
+    const __nv_fp8_e4m3* __restrict__ src,
+    const float* __restrict__ scales,
+    __nv_bfloat16* __restrict__ dst,
+    int num_experts, int rows, int cols
+) {
+    // 64-bit indexing: num_experts * rows * cols can exceed INT_MAX.
+    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
+    long long expert_stride = (long long)rows * cols;
+    long long total = (long long)num_experts * expert_stride;
+    if (idx >= total) return;
+
+    int expert = (int)(idx / expert_stride);
+    float scale = scales[expert];
+    float val = float(src[idx]) * scale;
+    dst[idx] = __float2bfloat16(val);
+}
+
+extern "C" {
+
+void launch_dequant_fp8e4m3_to_bf16(
+    const void* src,
+    const void* scales,
+    void* dst,
+    int num_experts, int rows, int cols,
+    void* stream
+) {
+    long long total = (long long)num_experts * rows * cols;
+    if (total <= 0) return;  // nothing to convert; a zero-sized grid is an invalid launch
+    int block = 256;
+    unsigned int grid = (unsigned int)((total + block - 1) / block);
+    dequant_fp8e4m3_to_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
+        (const __nv_fp8_e4m3*)src,
+        (const float*)scales,
+        (__nv_bfloat16*)dst,
+        num_experts, rows, cols
+    );
+    CUDA_CHECK_LAST_ERROR();
+}
+
+}
diff --git a/tools/eval_gsm8k_batch.sh b/tools/eval_gsm8k_batch.sh
new file mode 100755
index 0000000..6b1315e
--- /dev/null
+++ b/tools/eval_gsm8k_batch.sh
@@ -0,0 +1,112 @@
+#!/bin/bash
+# GSM8K evaluation via repeated xserv-chat invocations.
+# Usage: eval_gsm8k_batch.sh <model_dir> [limit] [gpu_id] [tp]
+set -uo pipefail
+export PATH=/usr/local/cuda/bin:$PATH
+source ~/.cargo/env 2>/dev/null || true
+export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}
+
+MODEL_DIR="${1:?Usage: $0 <model_dir> [limit] [gpu_id] [tp]}"
+LIMIT="${2:-50}"
+GPU="${3:-0}"
+TP="${4:-1}"
+export CUDA_VISIBLE_DEVICES=$GPU
+
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+XSERV_CHAT="$SCRIPT_DIR/../target/release/xserv-chat"
+DATA="$SCRIPT_DIR/bench/data/gsm8k.json"
+SYSTEM='Solve the problem step by step. Put your final numeric answer inside \\boxed{}.'
+
+echo "=== GSM8K Eval: model=$MODEL_DIR, limit=$LIMIT, gpu=$GPU, tp=$TP ==="
+
+TMPDIR=$(mktemp -d)
+trap 'rm -rf "$TMPDIR"' EXIT
+
+# Generate problem files
+python3 -c "
+import json
+problems = json.load(open('$DATA'))[:$LIMIT]
+for i, p in enumerate(problems):
+    with open(f'$TMPDIR/{i:04d}.txt', 'w') as f:
+        f.write(p['problem'].replace(chr(10), ' '))
+    with open(f'$TMPDIR/{i:04d}.gold', 'w') as f:
+        f.write(p['answer'])
+print(f'{len(problems)} problems prepared')
+"
+
+TOTAL=$(ls "$TMPDIR"/*.txt 2>/dev/null | wc -l)
+CORRECT=0
+SCORED=0
+START_TIME=$(date +%s)
+
+TP_FLAG=""
+if [ "$TP" -gt 1 ]; then
+    TP_FLAG="--tp $TP"
+fi
+
+for f in $(ls "$TMPDIR"/*.txt | sort); do
+    IDX=$(basename "$f" .txt)
+    GOLD=$(cat "$TMPDIR/${IDX}.gold")
+    QUESTION=$(cat "$f")
+
+    # Run single-question xserv-chat
+    RAW_OUT=$(echo "$QUESTION" | timeout 120 "$XSERV_CHAT" "$MODEL_DIR" \
+        --max-tokens 512 --max-seq-len 1024 \
+        --system "$SYSTEM" --no-color $TP_FLAG 2>/dev/null || true)
+
+    # Extract predicted answer
+    PRED=$(echo "$RAW_OUT" | python3 -c "
+import re, sys
+text = sys.stdin.read()
+# Extract everything after 'assistant>'
+if 'assistant>' in text:
+    text = text.split('assistant>', 1)[1]
+    if 'user>' in text:
+        text = text[:text.rindex('user>')]
+boxed = re.findall(r'\\\\boxed\s*\{([^{}]*)\}', text)
+if boxed:
+    nums = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', boxed[-1])
+    if nums:
+        s = nums[-1].replace(',','')
+        f = float(s)
+        print(str(int(f)) if f == int(f) else f'{f:g}')
+        sys.exit(0)
+nums = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
+if nums:
+    s = nums[-1].replace(',','')
+    f = float(s)
+    print(str(int(f)) if f == int(f) else f'{f:g}')
+else:
+    print('NONE')
+" 2>/dev/null || echo "NONE")
+
+    # Normalize gold (passed via argv so quotes in the answer cannot break the Python source)
+    GOLD_NORM=$(python3 -c "
+import sys
+s=sys.argv[1].replace(',','').strip()
+f=float(s)
+print(str(int(f)) if f==int(f) else f'{f:g}')
+" "$GOLD" 2>/dev/null || echo "$GOLD")
+
+    SCORED=$((SCORED + 1))
+    if [ "$PRED" = "$GOLD_NORM" ]; then
+        CORRECT=$((CORRECT + 1))
+        echo "[✓] $IDX gold=$GOLD_NORM pred=$PRED"
+    elif [ "$PRED" = "NONE" ]; then
+        echo "[E] $IDX gold=$GOLD_NORM pred=NONE (no output)"
+    else
+        echo "[✗] $IDX gold=$GOLD_NORM pred=$PRED"
+    fi
+done
+
+END_TIME=$(date +%s)
+ELAPSED=$((END_TIME - START_TIME))
+
+echo "------------------------------------------------------------------------"
+if [ "$SCORED" -gt 0 ]; then
+    python3 -c "print(f'Results: $CORRECT/$SCORED correct = {$CORRECT/$SCORED*100:.1f}% accuracy')"
+    echo "Wall time: ${ELAPSED}s (avg $((ELAPSED / TOTAL))s/problem)"
+else
+    echo "Results: 0/0 correct (no problems were scored)"
+fi
+echo "=== Done ==="
diff --git a/tools/quantize_fp8.py b/tools/quantize_fp8.py
new file mode 100755
index 0000000..a8faa90
--- /dev/null
+++ b/tools/quantize_fp8.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+"""Quantize gpt-oss expert weights from BF16 to FP8 E4M3 (W8A16).
+
+Usage:
+    python quantize_fp8.py <input_dir> <output_dir>
+
+Converts expert gate_up_proj and down_proj weights to FP8 E4M3 with
+per-expert per-matrix FP32 scale factors. All other tensors (attention,
+router, embeddings, norms, biases) are kept in BF16.
+
+The output directory contains:
+  - model.safetensors: quantized weights
+  - config.json: copy with "quantization": "fp8_e4m3" added
+  - All other files (tokenizer, etc.) copied as-is
+"""
+
+import argparse
+import json
+import shutil
+import sys
+from pathlib import Path
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+FP8_E4M3_MAX = 448.0  # max representable value in FP8 E4M3
+
+
+def quantize_expert_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a [num_experts, rows, cols] BF16 tensor to FP8 E4M3.
+
+    Returns (quantized_fp8, scales) where:
+    - quantized_fp8: [num_experts, rows, cols] torch.float8_e4m3fn
+    - scales: [num_experts] torch.float32
+    """
+    assert tensor.ndim == 3, f"expected 3D, got {tensor.ndim}D"
+    num_experts = tensor.shape[0]
+
+    # Per-expert absmax scale (reshape, not view: the input may be non-contiguous)
+    flat = tensor.reshape(num_experts, -1).float()
+    absmax = flat.abs().amax(dim=1)  # [num_experts]
+    scales = absmax / FP8_E4M3_MAX
+    # Avoid division by zero for all-zero experts
+    scales = scales.clamp(min=1e-12)
+
+    # Scale, clamp to the representable range, and cast to FP8
+    # Reshape scales for broadcasting: [E, 1, 1]
+    scales_bc = scales.view(num_experts, 1, 1)
+    scaled = (tensor.float() / scales_bc).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
+    quantized = scaled.to(torch.float8_e4m3fn)
+
+    return quantized, scales
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Quantize gpt-oss experts to FP8 E4M3")
+    parser.add_argument("input_dir", type=Path, help="Input model directory (BF16)")
+    parser.add_argument("output_dir", type=Path, help="Output model directory (FP8)")
+    args = parser.parse_args()
+
+    input_dir = args.input_dir
+    output_dir = args.output_dir
+
+    if not input_dir.exists():
+        print(f"Error: input directory {input_dir} does not exist", file=sys.stderr)
+        sys.exit(1)
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Load config
+    config_path = input_dir / "config.json"
+    with open(config_path) as f:
+        config = json.load(f)
+
+    num_layers = config.get("num_hidden_layers", 0)
+    num_experts = config.get("num_local_experts", 0)
+    print(f"Model: {num_layers} layers, {num_experts} experts per layer")
+
+    # Load weights (may be sharded)
+    safetensor_files = sorted(input_dir.glob("*.safetensors"))
+    if not safetensor_files:
+        print("Error: no .safetensors files found", file=sys.stderr)
+        sys.exit(1)
+
+    print(f"Loading {len(safetensor_files)} safetensors file(s)...")
+    all_tensors = {}
+    for sf in safetensor_files:
+        all_tensors.update(load_file(str(sf), device="cpu"))
+    print(f"Loaded {len(all_tensors)} tensors")
+
+    # Quantize expert weights
+    quantized_count = 0
+    output_tensors = {}
+    expert_suffixes = (".mlp.experts.gate_up_proj", ".mlp.experts.down_proj")
+
+    for name, tensor in all_tensors.items():
+        # Expert weight matrices are quantized; every other tensor is passed through.
+        if name.endswith(expert_suffixes):
+            print(f"  Quantizing {name} {list(tensor.shape)} ...")
+            q, s = quantize_expert_tensor(tensor)
+            output_tensors[name] = q
+            output_tensors[name + "_scale"] = s
+            quantized_count += 1
+        else:
+            output_tensors[name] = tensor
+
+    print(f"\nQuantized {quantized_count} expert weight tensors to FP8 E4M3")
+
+    # Compute size savings
+    input_bytes = sum(t.numel() * t.element_size() for t in all_tensors.values())
+    output_bytes = sum(t.numel() * t.element_size() for t in output_tensors.values())
+    print(f"Size: {input_bytes / 1e9:.2f} GB → {output_bytes / 1e9:.2f} GB "
+          f"({(1 - output_bytes / input_bytes) * 100:.1f}% reduction)")
+
+    # Save quantized model
+    output_safetensors = output_dir / "model.safetensors"
+    print(f"\nSaving to {output_safetensors} ...")
+    save_file(output_tensors, str(output_safetensors))
+
+    # Save modified config
+    config["quantization"] = "fp8_e4m3"
+    output_config = output_dir / "config.json"
+    with open(output_config, "w") as f:
+        json.dump(config, f, indent=2)
+    print(f"Saved config to {output_config}")
+
+    # Copy other files
+    for src_file in input_dir.iterdir():
+        # Weights were rewritten above into a single file. A shard index from a
+        # sharded input would reference shard files that are not written here.
+        if src_file.suffix == ".safetensors" or src_file.name.endswith(".safetensors.index.json"):
+            continue
+        if src_file.name == "config.json":
+            continue
+        dst_file = output_dir / src_file.name
+        if src_file.is_file() and not dst_file.exists():
+            shutil.copy2(src_file, dst_file)
+
+    print("\nDone!")
+
+
+if __name__ == "__main__":
+    main()