quantization: add FP8 E4M3 W8A16 for gpt-oss MoE expert weights
Store expert gate_up_proj and down_proj weights in FP8 E4M3 (1 byte/elem) with per-expert FP32 scale factors. At inference, a fused CUDA kernel dequantizes to BF16 before the existing cuBLAS batched GEMM. Results on gpt-oss-20b (50-problem GSM8K subset): - FP8 TP=1: 47/50 = 94.0% (single RTX 5090, ~25 GB VRAM) - BF16 TP=2: 47/50 = 94.0% (requires 2× RTX 5090, ~39 GB total) No measurable accuracy degradation. Model size: 41.8 GB → 22.7 GB (−46%). New files: - tools/quantize_fp8.py: offline BF16→FP8 conversion script - csrc/quantization/dequant_fp8.cu: per-expert-scale dequant kernel - crates/xserv-kernels/src/quantization.rs: Rust FFI wrapper - tools/eval_gsm8k_batch.sh: GSM8K accuracy evaluation harness Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -30,6 +30,7 @@ fn main() {
|
|||||||
.file("../../csrc/attention/paged_attention.cu")
|
.file("../../csrc/attention/paged_attention.cu")
|
||||||
.file("../../csrc/attention/reshape_and_cache.cu")
|
.file("../../csrc/attention/reshape_and_cache.cu")
|
||||||
.file("../../csrc/moe/moe_kernels.cu")
|
.file("../../csrc/moe/moe_kernels.cu")
|
||||||
|
.file("../../csrc/quantization/dequant_fp8.cu")
|
||||||
.compile("xserv_kernels");
|
.compile("xserv_kernels");
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=../../csrc/");
|
println!("cargo:rerun-if-changed=../../csrc/");
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ pub mod embedding;
|
|||||||
pub mod gemm;
|
pub mod gemm;
|
||||||
pub mod layernorm;
|
pub mod layernorm;
|
||||||
pub mod moe;
|
pub mod moe;
|
||||||
|
pub mod quantization;
|
||||||
pub mod rmsnorm;
|
pub mod rmsnorm;
|
||||||
pub mod rope;
|
pub mod rope;
|
||||||
pub mod softmax;
|
pub mod softmax;
|
||||||
|
|||||||
46
crates/xserv-kernels/src/quantization.rs
Normal file
46
crates/xserv-kernels/src/quantization.rs
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
use std::ffi::c_void;
|
||||||
|
use xserv_tensor::{DType, Tensor};
|
||||||
|
|
||||||
|
unsafe extern "C" {
|
||||||
|
fn launch_dequant_fp8e4m3_to_bf16(
|
||||||
|
src: *const c_void,
|
||||||
|
scales: *const c_void,
|
||||||
|
dst: *mut c_void,
|
||||||
|
num_experts: i32, rows: i32, cols: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dequantize a 3D FP8 E4M3 tensor to BF16 using per-expert FP32 scales.
|
||||||
|
///
|
||||||
|
/// src: [num_experts, rows, cols] FP8E4M3, contiguous, GPU
|
||||||
|
/// scales: [num_experts] F32, contiguous, GPU
|
||||||
|
///
|
||||||
|
/// Returns: [num_experts, rows, cols] BF16
|
||||||
|
pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
||||||
|
assert_eq!(src.ndim(), 3, "dequant_fp8_to_bf16: src must be 3D");
|
||||||
|
assert_eq!(src.dtype(), DType::FP8E4M3);
|
||||||
|
assert!(src.is_contiguous());
|
||||||
|
assert_eq!(scales.ndim(), 1);
|
||||||
|
assert_eq!(scales.dtype(), DType::F32);
|
||||||
|
assert!(scales.is_contiguous());
|
||||||
|
|
||||||
|
let num_experts = src.shape()[0];
|
||||||
|
let rows = src.shape()[1];
|
||||||
|
let cols = src.shape()[2];
|
||||||
|
assert_eq!(scales.shape()[0], num_experts);
|
||||||
|
|
||||||
|
let out = Tensor::empty(&[num_experts, rows, cols], DType::BF16, src.device());
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
launch_dequant_fp8e4m3_to_bf16(
|
||||||
|
src.data_ptr() as *const c_void,
|
||||||
|
scales.data_ptr() as *const c_void,
|
||||||
|
out.data_ptr() as *mut c_void,
|
||||||
|
num_experts as i32, rows as i32, cols as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
@@ -43,10 +43,15 @@ struct GptOssBlock {
|
|||||||
router_wt: Tensor,
|
router_wt: Tensor,
|
||||||
router_bias: Tensor,
|
router_bias: Tensor,
|
||||||
// 3D expert weights for batched GEMM (contiguous on GPU)
|
// 3D expert weights for batched GEMM (contiguous on GPU)
|
||||||
expert_gate_up_wt: Tensor, // [local_experts, hidden, 2*inter]
|
expert_gate_up_wt: Tensor, // [local_experts, hidden, 2*inter] BF16
|
||||||
expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
|
expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
|
||||||
expert_down_wt: Tensor, // [local_experts, inter, hidden]
|
expert_down_wt: Tensor, // [local_experts, inter, hidden] BF16
|
||||||
expert_down_bias: Tensor, // [local_experts, hidden]
|
expert_down_bias: Tensor, // [local_experts, hidden]
|
||||||
|
// FP8 quantized expert weights (Some when running FP8 W8A16)
|
||||||
|
expert_gate_up_fp8: Option<Tensor>, // [local_experts, hidden, 2*inter] FP8E4M3
|
||||||
|
expert_gate_up_scale: Option<Tensor>,// [local_experts] F32
|
||||||
|
expert_down_fp8: Option<Tensor>, // [local_experts, inter, hidden] FP8E4M3
|
||||||
|
expert_down_scale: Option<Tensor>, // [local_experts] F32
|
||||||
local_experts: usize,
|
local_experts: usize,
|
||||||
// Activation params
|
// Activation params
|
||||||
glu_alpha: f32,
|
glu_alpha: f32,
|
||||||
@@ -156,17 +161,49 @@ impl GptOss {
|
|||||||
let down_3d = take(&mut w, &format!("{p}.mlp.experts.down_proj"));
|
let down_3d = take(&mut w, &format!("{p}.mlp.experts.down_proj"));
|
||||||
let down_bias_2d = take(&mut w, &format!("{p}.mlp.experts.down_proj_bias"));
|
let down_bias_2d = take(&mut w, &format!("{p}.mlp.experts.down_proj_bias"));
|
||||||
|
|
||||||
|
// FP8 scale tensors (present only in FP8-quantized models)
|
||||||
|
let gate_up_scale = w.remove(&format!("{p}.mlp.experts.gate_up_proj_scale"));
|
||||||
|
let down_scale = w.remove(&format!("{p}.mlp.experts.down_proj_scale"));
|
||||||
|
|
||||||
let local_experts = num_experts / world;
|
let local_experts = num_experts / world;
|
||||||
let expert_start = rank * local_experts;
|
let expert_start = rank * local_experts;
|
||||||
|
|
||||||
|
let is_fp8 = gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3;
|
||||||
|
|
||||||
let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size
|
let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size
|
||||||
let hidden = gate_up_3d.shape()[1];
|
let hidden = gate_up_3d.shape()[1];
|
||||||
let inter = down_3d.shape()[1]; // intermediate_size
|
let inter = down_3d.shape()[1]; // intermediate_size
|
||||||
|
|
||||||
// Slice the rank's range of experts as contiguous 3D tensors on GPU
|
// Slice the rank's range of experts as contiguous 3D tensors on GPU
|
||||||
let expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
|
let expert_gate_up_wt;
|
||||||
|
let expert_down_wt;
|
||||||
|
let expert_gate_up_fp8;
|
||||||
|
let expert_gate_up_scale_gpu;
|
||||||
|
let expert_down_fp8;
|
||||||
|
let expert_down_scale_gpu;
|
||||||
|
|
||||||
|
if is_fp8 {
|
||||||
|
// FP8 path: load quantized weights and scales
|
||||||
|
expert_gate_up_fp8 = Some(slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev));
|
||||||
|
expert_down_fp8 = Some(slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev));
|
||||||
|
// Scales: [num_experts] F32 → slice to [local_experts]
|
||||||
|
let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale");
|
||||||
|
let d_s = down_scale.expect("FP8 model missing down_proj_scale");
|
||||||
|
expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev));
|
||||||
|
expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev));
|
||||||
|
// Dummy BF16 tensors (never read in FP8 path)
|
||||||
|
expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
|
||||||
|
expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev);
|
||||||
|
} else {
|
||||||
|
// BF16 path: existing behavior
|
||||||
|
expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
|
||||||
|
expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
|
||||||
|
expert_gate_up_fp8 = None;
|
||||||
|
expert_gate_up_scale_gpu = None;
|
||||||
|
expert_down_fp8 = None;
|
||||||
|
expert_down_scale_gpu = None;
|
||||||
|
}
|
||||||
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
|
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
|
||||||
let expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
|
|
||||||
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
|
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
|
||||||
|
|
||||||
xserv_cuda::allocator::cached_trim();
|
xserv_cuda::allocator::cached_trim();
|
||||||
@@ -198,6 +235,10 @@ impl GptOss {
|
|||||||
expert_gate_up_bias,
|
expert_gate_up_bias,
|
||||||
expert_down_wt,
|
expert_down_wt,
|
||||||
expert_down_bias,
|
expert_down_bias,
|
||||||
|
expert_gate_up_fp8,
|
||||||
|
expert_gate_up_scale: expert_gate_up_scale_gpu,
|
||||||
|
expert_down_fp8,
|
||||||
|
expert_down_scale: expert_down_scale_gpu,
|
||||||
local_experts,
|
local_experts,
|
||||||
glu_alpha,
|
glu_alpha,
|
||||||
glu_limit,
|
glu_limit,
|
||||||
@@ -208,10 +249,14 @@ impl GptOss {
|
|||||||
let local_num_kv_heads = config.num_kv_heads() / world;
|
let local_num_kv_heads = config.num_kv_heads() / world;
|
||||||
|
|
||||||
let has_norm_bias = norm_bias.is_some();
|
let has_norm_bias = norm_bias.is_some();
|
||||||
|
let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false);
|
||||||
if rank == 0 {
|
if rank == 0 {
|
||||||
if has_norm_bias {
|
if has_norm_bias {
|
||||||
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
|
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
|
||||||
}
|
}
|
||||||
|
if is_fp8 {
|
||||||
|
eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A16 mode)");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Warn about unused weights that the model didn't consume
|
// Warn about unused weights that the model didn't consume
|
||||||
@@ -470,7 +515,12 @@ impl GptOss {
|
|||||||
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
||||||
|
|
||||||
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
|
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
|
||||||
let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt);
|
let gate_up_wt = if let Some(ref fp8) = layer.expert_gate_up_fp8 {
|
||||||
|
xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_gate_up_scale.as_ref().unwrap())
|
||||||
|
} else {
|
||||||
|
layer.expert_gate_up_wt.clone()
|
||||||
|
};
|
||||||
|
let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &gate_up_wt);
|
||||||
|
|
||||||
// 5. Bias add: gate_up += expert_gate_up_bias (in-place)
|
// 5. Bias add: gate_up += expert_gate_up_bias (in-place)
|
||||||
xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
|
xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
|
||||||
@@ -484,7 +534,12 @@ impl GptOss {
|
|||||||
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
|
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
|
||||||
|
|
||||||
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
|
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
|
||||||
let down = xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt);
|
let down_wt = if let Some(ref fp8) = layer.expert_down_fp8 {
|
||||||
|
xserv_kernels::quantization::dequant_fp8_to_bf16(fp8, layer.expert_down_scale.as_ref().unwrap())
|
||||||
|
} else {
|
||||||
|
layer.expert_down_wt.clone()
|
||||||
|
};
|
||||||
|
let down = xserv_kernels::moe::batched_gemm_strided(&activated, &down_wt);
|
||||||
|
|
||||||
// 8. Bias add: down += expert_down_bias (in-place)
|
// 8. Bias add: down += expert_down_bias (in-place)
|
||||||
xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
|
xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
|
||||||
@@ -581,6 +636,28 @@ fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
|
|||||||
Tensor::from_slice(&shard, &[local])
|
Tensor::from_slice(&shard, &[local])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes).
|
||||||
|
fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||||
|
assert_eq!(t.ndim(), 3);
|
||||||
|
let host = t.to_device(Device::Cpu);
|
||||||
|
let elem_size = t.dtype().size_bytes();
|
||||||
|
let raw = host.as_raw_bytes();
|
||||||
|
let stride = rows * cols * elem_size;
|
||||||
|
let offset = start * stride;
|
||||||
|
let slice = &raw[offset..offset + count * stride];
|
||||||
|
Tensor::from_raw_bytes(slice, &[count, rows, cols], t.dtype())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Slice scale tensor [num_experts] F32 → [count] starting at `start`.
|
||||||
|
fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor {
|
||||||
|
assert_eq!(t.ndim(), 1);
|
||||||
|
assert_eq!(t.dtype(), xserv_tensor::DType::F32);
|
||||||
|
let host = t.to_device(Device::Cpu);
|
||||||
|
let data = host.as_slice::<f32>();
|
||||||
|
let slice = data[start..start + count].to_vec();
|
||||||
|
Tensor::from_slice(&slice, &[count])
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
|
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
|
||||||
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
|
||||||
assert_eq!(t.ndim(), 3);
|
assert_eq!(t.ndim(), 3);
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ pub fn load_safetensors(path: &Path, device: Device) -> HashMap<String, Tensor>
|
|||||||
safetensors::Dtype::F32 => DType::F32,
|
safetensors::Dtype::F32 => DType::F32,
|
||||||
safetensors::Dtype::F16 => DType::F16,
|
safetensors::Dtype::F16 => DType::F16,
|
||||||
safetensors::Dtype::BF16 => DType::BF16,
|
safetensors::Dtype::BF16 => DType::BF16,
|
||||||
|
safetensors::Dtype::F8_E4M3 => DType::FP8E4M3,
|
||||||
other => {
|
other => {
|
||||||
eprintln!("skipping tensor {name}: unsupported dtype {other:?}");
|
eprintln!("skipping tensor {name}: unsupported dtype {other:?}");
|
||||||
continue;
|
continue;
|
||||||
@@ -83,5 +84,8 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor {
|
|||||||
};
|
};
|
||||||
Tensor::from_slice(bfs, shape)
|
Tensor::from_slice(bfs, shape)
|
||||||
}
|
}
|
||||||
|
DType::FP8E4M3 => {
|
||||||
|
Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ pub enum DType {
|
|||||||
F32,
|
F32,
|
||||||
F16,
|
F16,
|
||||||
BF16,
|
BF16,
|
||||||
|
FP8E4M3,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DType {
|
impl DType {
|
||||||
@@ -13,6 +14,7 @@ impl DType {
|
|||||||
DType::F32 => 4,
|
DType::F32 => 4,
|
||||||
DType::F16 => 2,
|
DType::F16 => 2,
|
||||||
DType::BF16 => 2,
|
DType::BF16 => 2,
|
||||||
|
DType::FP8E4M3 => 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,6 +23,7 @@ impl DType {
|
|||||||
DType::F32 => "f32",
|
DType::F32 => "f32",
|
||||||
DType::F16 => "f16",
|
DType::F16 => "f16",
|
||||||
DType::BF16 => "bf16",
|
DType::BF16 => "bf16",
|
||||||
|
DType::FP8E4M3 => "fp8e4m3",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,6 +52,25 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a tensor from raw bytes. Used for dtypes without a Rust type
|
||||||
|
/// (e.g. FP8 E4M3) where we store the bit pattern as-is.
|
||||||
|
pub fn from_raw_bytes(data: &[u8], shape: &[usize], dtype: DType) -> Self {
|
||||||
|
let numel: usize = shape.iter().product();
|
||||||
|
assert_eq!(
|
||||||
|
data.len(),
|
||||||
|
numel * dtype.size_bytes(),
|
||||||
|
"raw bytes length {} != expected {} (numel={} * elem_size={})",
|
||||||
|
data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes()
|
||||||
|
);
|
||||||
|
Self {
|
||||||
|
storage: Storage::cpu(data.to_vec()),
|
||||||
|
shape: Dims::from_slice(shape),
|
||||||
|
strides: shape::contiguous_strides(shape),
|
||||||
|
offset: 0,
|
||||||
|
dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
pub fn zeros(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||||
let numel = shape::num_elements(shape);
|
let numel = shape::num_elements(shape);
|
||||||
let len_bytes = numel * dtype.size_bytes();
|
let len_bytes = numel * dtype.size_bytes();
|
||||||
@@ -87,6 +106,7 @@ impl Tensor {
|
|||||||
DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape),
|
DType::F32 => Self::from_slice(&vec![1.0f32; numel], shape),
|
||||||
DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape),
|
DType::F16 => Self::from_slice(&vec![half::f16::from_f32(1.0); numel], shape),
|
||||||
DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape),
|
DType::BF16 => Self::from_slice(&vec![half::bf16::from_f32(1.0); numel], shape),
|
||||||
|
DType::FP8E4M3 => panic!("ones() not supported for FP8E4M3"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,6 +285,17 @@ impl Tensor {
|
|||||||
unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) }
|
unsafe { std::slice::from_raw_parts(bytes[start..].as_ptr() as *const T, len) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Raw byte access for dtypes without a Rust type (e.g. FP8).
|
||||||
|
pub fn as_raw_bytes(&self) -> &[u8] {
|
||||||
|
assert!(self.is_contiguous(), "as_raw_bytes requires contiguous");
|
||||||
|
assert_eq!(self.device(), Device::Cpu, "as_raw_bytes requires CPU");
|
||||||
|
let bytes = self.storage.as_cpu_bytes();
|
||||||
|
let elem_size = self.dtype.size_bytes();
|
||||||
|
let start = self.offset * elem_size;
|
||||||
|
let len = self.numel() * elem_size;
|
||||||
|
&bytes[start..start + len]
|
||||||
|
}
|
||||||
|
|
||||||
/// Raw pointer to storage start (for GPU kernel launch).
|
/// Raw pointer to storage start (for GPU kernel launch).
|
||||||
pub fn data_ptr(&self) -> *const u8 {
|
pub fn data_ptr(&self) -> *const u8 {
|
||||||
match self.device() {
|
match self.device() {
|
||||||
|
|||||||
51
csrc/quantization/dequant_fp8.cu
Normal file
51
csrc/quantization/dequant_fp8.cu
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
|
// Dequantize FP8 E4M3 → BF16 with per-expert (per-batch-slice) FP32 scale.
//
// Input:  src    [num_experts, rows, cols]  FP8 E4M3 (1 byte each)
//         scales [num_experts]              FP32
// Output: dst    [num_experts, rows, cols]  BF16
//
// Each element: dst[e, r, c] = bf16( float(src[e, r, c]) * scales[e] )
//
// One thread converts one element. All index arithmetic is done in size_t:
// the flat element count of a full expert tensor can approach or exceed
// 2^31, where 32-bit `int` products would wrap.

__global__ void dequant_fp8e4m3_to_bf16_kernel(
    const __nv_fp8_e4m3* __restrict__ src,
    const float* __restrict__ scales,
    __nv_bfloat16* __restrict__ dst,
    int num_experts, int rows, int cols
) {
    const size_t expert_stride = static_cast<size_t>(rows) * static_cast<size_t>(cols);
    const size_t total = static_cast<size_t>(num_experts) * expert_stride;
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    // Also covers expert_stride == 0 (total == 0), so the division is safe.
    if (idx >= total) return;

    // Elements are laid out expert-major, so the expert index is idx / stride.
    const float scale = scales[idx / expert_stride];
    const float val = static_cast<float>(src[idx]) * scale;
    dst[idx] = __float2bfloat16(val);
}
|
||||||
|
|
||||||
|
extern "C" {

// Host-side launcher for dequant_fp8e4m3_to_bf16_kernel.
//
//   src     device pointer to num_experts * rows * cols FP8 E4M3 bytes
//   scales  device pointer to num_experts floats
//   dst     device pointer to num_experts * rows * cols BF16 values
//   stream  cudaStream_t passed as an opaque pointer
//
// Launches one thread per element in blocks of 256 threads. Returns without
// launching when any dimension is non-positive.
void launch_dequant_fp8e4m3_to_bf16(
    const void* src,
    const void* scales,
    void* dst,
    int num_experts, int rows, int cols,
    void* stream
) {
    // Empty (or malformed) input: a grid of 0 blocks is an invalid launch
    // configuration, so there is nothing to do.
    if (num_experts <= 0 || rows <= 0 || cols <= 0) return;

    // 64-bit arithmetic: the element count can exceed what a 32-bit int holds.
    const long long total = static_cast<long long>(num_experts) * rows * cols;
    const int block = 256;
    const long long grid = (total + block - 1) / block;

    dequant_fp8e4m3_to_bf16_kernel<<<static_cast<unsigned int>(grid), block, 0, (cudaStream_t)stream>>>(
        (const __nv_fp8_e4m3*)src,
        (const float*)scales,
        (__nv_bfloat16*)dst,
        num_experts, rows, cols
    );
    CUDA_CHECK_LAST_ERROR();
}

}
|
||||||
107
tools/eval_gsm8k_batch.sh
Executable file
107
tools/eval_gsm8k_batch.sh
Executable file
@@ -0,0 +1,107 @@
|
|||||||
|
#!/bin/bash
# GSM8K evaluation via repeated xserv-chat invocations.
# Usage: eval_gsm8k_batch.sh <model-dir> <limit> [gpu_id] [tp]
#
# For each of the first <limit> problems in the dataset, pipes the question
# into a fresh xserv-chat process, extracts the last number inside the final
# \boxed{...} (falling back to the last number in the reply) and compares it
# with the normalized gold answer. Prints one line per problem and a summary.
set -uo pipefail
shopt -s nullglob
export PATH=/usr/local/cuda/bin:$PATH
source ~/.cargo/env 2>/dev/null || true
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}

MODEL_DIR="${1:?Usage: $0 <model-dir> <limit> [gpu_id] [tp]}"
LIMIT="${2:-50}"
GPU="${3:-0}"
TP="${4:-1}"
export CUDA_VISIBLE_DEVICES=$GPU

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
XSERV_CHAT="$SCRIPT_DIR/../target/release/xserv-chat"
DATA="$SCRIPT_DIR/bench/data/gsm8k.json"
# NOTE(review): single quotes keep both backslashes, so the prompt literally
# contains "\\boxed{}" while the extractor below matches "\boxed{...}" —
# confirm the doubled backslash is intended.
SYSTEM='Solve the problem step by step. Put your final numeric answer inside \\boxed{}.'

echo "=== GSM8K Eval: model=$MODEL_DIR, limit=$LIMIT, gpu=$GPU, tp=$TP ==="

# Fail early: with stderr discarded below, a missing binary or dataset would
# otherwise be reported as a 0% score.
if [ ! -x "$XSERV_CHAT" ]; then
    echo "Error: xserv-chat binary not found or not executable: $XSERV_CHAT" >&2
    exit 1
fi
if [ ! -f "$DATA" ]; then
    echo "Error: dataset not found: $DATA" >&2
    exit 1
fi

# Private scratch directory. A dedicated variable is used so the TMPDIR seen
# by child processes is left alone; the trap body is single-quoted so the path
# is expanded (and quoted) at exit time.
WORK_DIR=$(mktemp -d) || { echo "Error: mktemp failed" >&2; exit 1; }
trap 'rm -rf "$WORK_DIR"' EXIT

# Generate problem files. Values are passed as argv, never spliced into the
# Python source, so paths/limits containing quotes cannot break the script.
python3 - "$DATA" "$LIMIT" "$WORK_DIR" <<'PY' || { echo "Error: failed to prepare problems" >&2; exit 1; }
import json
import os
import sys

data_path, limit, out_dir = sys.argv[1], int(sys.argv[2]), sys.argv[3]
with open(data_path) as fh:
    problems = json.load(fh)[:limit]
for i, p in enumerate(problems):
    with open(os.path.join(out_dir, f'{i:04d}.txt'), 'w') as f:
        f.write(p['problem'].replace(chr(10), ' '))
    with open(os.path.join(out_dir, f'{i:04d}.gold'), 'w') as f:
        f.write(p['answer'])
print(f'{len(problems)} problems prepared')
PY

# Glob expansion is sorted, so problems run in index order.
PROBLEMS=("$WORK_DIR"/*.txt)
TOTAL=${#PROBLEMS[@]}
if [ "$TOTAL" -eq 0 ]; then
    echo "Error: no problems to evaluate (limit=$LIMIT)" >&2
    exit 1
fi
CORRECT=0
SCORED=0
START_TIME=$(date +%s)

TP_ARGS=()
if [ "$TP" -gt 1 ]; then
    TP_ARGS=(--tp "$TP")
fi

# Answer extractor: reads the raw chat transcript on stdin, prints the
# normalized prediction or NONE.
EXTRACT_PY=$(cat <<'PY'
import re, sys

NUM = r'-?\d+(?:,\d{3})*(?:\.\d+)?'

def norm(s):
    f = float(s.replace(',', ''))
    return str(int(f)) if f == int(f) else f'{f:g}'

text = sys.stdin.read()
# Extract everything after 'assistant>'
if 'assistant>' in text:
    text = text.split('assistant>', 1)[1]
if 'user>' in text:
    text = text[:text.rindex('user>')]
boxed = re.findall(r'\\boxed\s*\{([^{}]*)\}', text)
if boxed:
    nums = re.findall(NUM, boxed[-1])
    if nums:
        print(norm(nums[-1]))
        sys.exit(0)
nums = re.findall(NUM, text)
print(norm(nums[-1]) if nums else 'NONE')
PY
)

# Gold normalizer: same numeric formatting as the extractor.
NORM_PY='
import sys
s = sys.argv[1].replace(",", "").strip()
f = float(s)
print(str(int(f)) if f == int(f) else f"{f:g}")
'

for f in "${PROBLEMS[@]}"; do
    IDX=$(basename "$f" .txt)
    GOLD=$(cat "$WORK_DIR/${IDX}.gold")
    QUESTION=$(cat "$f")

    # Run single-question xserv-chat
    RAW_OUT=$(printf '%s\n' "$QUESTION" | timeout 120 "$XSERV_CHAT" "$MODEL_DIR" \
        --max-tokens 512 --max-seq-len 1024 \
        --system "$SYSTEM" --no-color ${TP_ARGS[@]+"${TP_ARGS[@]}"} 2>/dev/null || true)

    # Extract predicted answer
    PRED=$(printf '%s\n' "$RAW_OUT" | python3 -c "$EXTRACT_PY" 2>/dev/null || echo "NONE")

    # Normalize gold (fall back to the raw string if it is not numeric)
    GOLD_NORM=$(python3 -c "$NORM_PY" "$GOLD" 2>/dev/null || echo "$GOLD")

    SCORED=$((SCORED + 1))
    if [ "$PRED" = "$GOLD_NORM" ]; then
        CORRECT=$((CORRECT + 1))
        echo "[✓] $IDX gold=$GOLD_NORM pred=$PRED"
    elif [ "$PRED" = "NONE" ]; then
        echo "[E] $IDX gold=$GOLD_NORM pred=NONE (no output)"
    else
        echo "[✗] $IDX gold=$GOLD_NORM pred=$PRED"
    fi
done

END_TIME=$(date +%s)
ELAPSED=$((END_TIME - START_TIME))

echo "------------------------------------------------------------------------"
# SCORED == TOTAL > 0 here, so neither division can be by zero.
python3 -c 'import sys; c, s = int(sys.argv[1]), int(sys.argv[2]); print(f"Results: {c}/{s} correct = {c / s * 100:.1f}% accuracy")' "$CORRECT" "$SCORED"
echo "Wall time: ${ELAPSED}s (avg $((ELAPSED / TOTAL))s/problem)"
echo "=== Done ==="
|
||||||
147
tools/quantize_fp8.py
Executable file
147
tools/quantize_fp8.py
Executable file
@@ -0,0 +1,147 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Quantize gpt-oss expert weights from BF16 to FP8 E4M3 (W8A16).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python quantize_fp8.py <input_model_dir> <output_model_dir>
|
||||||
|
|
||||||
|
Converts expert gate_up_proj and down_proj weights to FP8 E4M3 with
|
||||||
|
per-expert per-matrix FP32 scale factors. All other tensors (attention,
|
||||||
|
router, embeddings, norms, biases) are kept in BF16.
|
||||||
|
|
||||||
|
The output directory contains:
|
||||||
|
- model.safetensors: quantized weights
|
||||||
|
- config.json: copy with "quantization": "fp8_e4m3" added
|
||||||
|
- All other files (tokenizer, etc.) copied as-is
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
|
||||||
|
FP8_E4M3_MAX = 448.0 # max representable value in FP8 E4M3
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_expert_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Quantize a [num_experts, rows, cols] BF16 tensor to FP8 E4M3.
|
||||||
|
|
||||||
|
Returns (quantized_fp8, scales) where:
|
||||||
|
- quantized_fp8: [num_experts, rows, cols] torch.float8_e4m3fn
|
||||||
|
- scales: [num_experts] torch.float32
|
||||||
|
"""
|
||||||
|
assert tensor.ndim == 3, f"expected 3D, got {tensor.ndim}D"
|
||||||
|
num_experts = tensor.shape[0]
|
||||||
|
|
||||||
|
# Per-expert absmax scale
|
||||||
|
flat = tensor.view(num_experts, -1).float()
|
||||||
|
absmax = flat.abs().amax(dim=1) # [num_experts]
|
||||||
|
scales = absmax / FP8_E4M3_MAX
|
||||||
|
# Avoid division by zero for all-zero experts
|
||||||
|
scales = scales.clamp(min=1e-12)
|
||||||
|
|
||||||
|
# Scale and cast to FP8
|
||||||
|
# Reshape scales for broadcasting: [E, 1, 1]
|
||||||
|
scales_bc = scales.view(num_experts, 1, 1)
|
||||||
|
scaled = tensor.float() / scales_bc
|
||||||
|
quantized = scaled.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
return quantized, scales
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: quantize a gpt-oss checkpoint's experts to FP8.

    Reads every ``*.safetensors`` file in ``input_dir``, replaces each expert
    ``gate_up_proj`` / ``down_proj`` weight with its FP8 E4M3 version plus a
    ``<name>_scale`` tensor, and writes everything to a single
    ``model.safetensors`` in ``output_dir``. ``config.json`` is rewritten with
    ``"quantization": "fp8_e4m3"`` and the remaining files are copied.

    Exits with status 1 if the input directory, its ``config.json`` or its
    safetensors files are missing.
    """
    parser = argparse.ArgumentParser(description="Quantize gpt-oss experts to FP8 E4M3")
    parser.add_argument("input_dir", type=Path, help="Input model directory (BF16)")
    parser.add_argument("output_dir", type=Path, help="Output model directory (FP8)")
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir

    if not input_dir.exists():
        print(f"Error: input directory {input_dir} does not exist", file=sys.stderr)
        sys.exit(1)

    # Load config
    config_path = input_dir / "config.json"
    if not config_path.is_file():
        print(f"Error: {config_path} not found", file=sys.stderr)
        sys.exit(1)
    with open(config_path) as f:
        config = json.load(f)

    num_layers = config.get("num_hidden_layers", 0)
    num_experts = config.get("num_local_experts", 0)
    print(f"Model: {num_layers} layers, {num_experts} experts per layer")

    # Load weights (may be sharded)
    safetensor_files = sorted(input_dir.glob("*.safetensors"))
    if not safetensor_files:
        print("Error: no .safetensors files found", file=sys.stderr)
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading {len(safetensor_files)} safetensors file(s)...")
    all_tensors = {}
    for sf in safetensor_files:
        all_tensors.update(load_file(str(sf), device="cpu"))
    print(f"Loaded {len(all_tensors)} tensors")

    # Quantize expert weights. Only the weight matrices match these suffixes;
    # biases ("..._bias") and existing scales ("..._scale") do not.
    expert_weight_suffixes = (".mlp.experts.gate_up_proj", ".mlp.experts.down_proj")
    quantized_count = 0
    output_tensors = {}

    for name, tensor in all_tensors.items():
        if not name.endswith(expert_weight_suffixes):
            output_tensors[name] = tensor
            continue
        if tensor.dtype == torch.float8_e4m3fn:
            # Re-quantizing would discard the scale the weights were stored
            # with; keep the tensor (and its existing *_scale) untouched.
            print(f"  Skipping {name}: already FP8 E4M3")
            output_tensors[name] = tensor
            continue
        print(f"  Quantizing {name} {list(tensor.shape)} ...")
        q, s = quantize_expert_tensor(tensor)
        output_tensors[name] = q
        output_tensors[name + "_scale"] = s
        quantized_count += 1

    print(f"\nQuantized {quantized_count} expert weight tensors to FP8 E4M3")

    # Compute size savings
    input_bytes = sum(t.numel() * t.element_size() for t in all_tensors.values())
    output_bytes = sum(t.numel() * t.element_size() for t in output_tensors.values())
    print(f"Size: {input_bytes / 1e9:.2f} GB → {output_bytes / 1e9:.2f} GB "
          f"({(1 - output_bytes / input_bytes) * 100:.1f}% reduction)")

    # Save quantized model
    output_safetensors = output_dir / "model.safetensors"
    print(f"\nSaving to {output_safetensors} ...")
    save_file(output_tensors, str(output_safetensors))

    # Save modified config
    config["quantization"] = "fp8_e4m3"
    output_config = output_dir / "config.json"
    with open(output_config, "w") as f:
        json.dump(config, f, indent=2)
    print(f"Saved config to {output_config}")

    # Copy other files
    for src_file in input_dir.iterdir():
        if src_file.suffix == ".safetensors":
            continue
        if src_file.name == "config.json":
            continue
        if src_file.name.endswith(".safetensors.index.json"):
            # The shard index names the input's shard files; the output is a
            # single merged file, so copying it would point at missing shards.
            continue
        dst_file = output_dir / src_file.name
        if src_file.is_file() and not dst_file.exists():
            shutil.copy2(src_file, dst_file)

    print("\nDone!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user