Files
xserv/crates/xserv-kernels/build.rs
Gahow Wang 9f1fbbb98b quantization: add FP8 E4M3 W8A16 for gpt-oss MoE expert weights
Store expert gate_up_proj and down_proj weights in FP8 E4M3 (1 byte/elem)
with per-expert FP32 scale factors. At inference, a fused CUDA kernel
dequantizes to BF16 before the existing cuBLAS batched GEMM.

Results on gpt-oss-20b (50-problem GSM8K subset):
  - FP8 TP=1: 47/50 = 94.0% (single RTX 5090, ~25 GB VRAM)
  - BF16 TP=2: 47/50 = 94.0% (requires 2× RTX 5090, ~39 GB total)

No accuracy degradation observed on this subset (identical 47/50 scores). Model size: 41.8 GB → 22.7 GB (−46%).

New files:
  - tools/quantize_fp8.py: offline BF16→FP8 conversion script
  - csrc/quantization/dequant_fp8.cu: per-expert-scale dequant kernel
  - crates/xserv-kernels/src/quantization.rs: Rust FFI wrapper
  - tools/eval_gsm8k_batch.sh: GSM8K accuracy evaluation harness

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-06-07 19:33:07 +08:00

38 lines
1.4 KiB
Rust

use std::env;
/// Build script: compiles the project's CUDA kernels into the static library
/// `xserv_kernels` and tells Cargo how to link the CUDA runtime and cuBLAS.
///
/// Build-time environment variables:
/// - `CUDA_HOME`, then `CUDA_PATH`: CUDA toolkit root used for the link search
///   path (`<root>/lib64`); falls back to `/usr/local/cuda` when neither is set.
/// - `XSERV_CUDA_ARCH`: optional SM version to compile for (e.g. `90`); when
///   unset or empty the kernels target `sm_120`.
fn main() {
    /// Root of the CUDA/C++ kernel sources, relative to this crate's manifest dir.
    const CSRC_DIR: &str = "../../csrc";
    /// Translation units compiled into `libxserv_kernels`, relative to `CSRC_DIR`.
    const CUDA_SOURCES: &[&str] = &[
        "gemm/naive.cu",
        "gemm/tiled.cu",
        "gemm/gemv.cu",
        "normalization/rmsnorm.cu",
        "normalization/layernorm.cu",
        "activation/activations.cu",
        "reduce/softmax.cu",
        "reduce/argmax.cu",
        "embedding/embedding.cu",
        "embedding/rope.cu",
        "attention/causal_mask.cu",
        "embedding/transpose.cu",
        "attention/flash_attention.cu",
        "attention/paged_attention.cu",
        "attention/reshape_and_cache.cu",
        "moe/moe_kernels.cu",
        "quantization/dequant_fp8.cu",
    ];
    /// SM version used when `XSERV_CUDA_ARCH` is not provided.
    const DEFAULT_CUDA_ARCH: &str = "120";

    // Once any rerun-if-* directive is printed, Cargo drops its default
    // "rerun on any package change" behaviour, so every input this script
    // reads (source tree AND environment) must be declared explicitly.
    // Otherwise, changing CUDA_HOME would leave a stale link path cached.
    println!("cargo:rerun-if-changed={CSRC_DIR}/");
    for var in ["CUDA_HOME", "CUDA_PATH", "XSERV_CUDA_ARCH"] {
        println!("cargo:rerun-if-env-changed={var}");
    }

    let cuda_path = env::var("CUDA_HOME")
        .or_else(|_| env::var("CUDA_PATH"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());
    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=cublas");

    // Treat an empty/whitespace value like "unset" so a blank variable doesn't
    // produce a malformed `-gencode` flag.
    let arch = env::var("XSERV_CUDA_ARCH")
        .ok()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .unwrap_or_else(|| DEFAULT_CUDA_ARCH.to_string());

    let mut build = cc::Build::new();
    build
        .cuda(true)
        .cudart("shared")
        .flag(&format!("-gencode=arch=compute_{arch},code=sm_{arch}"))
        .include(CSRC_DIR);
    for src in CUDA_SOURCES {
        build.file(format!("{CSRC_DIR}/{src}"));
    }
    build.compile("xserv_kernels");
}