Store expert gate_up_proj and down_proj weights in FP8 E4M3 (1 byte/elem) with per-expert FP32 scale factors. At inference, a fused CUDA kernel dequantizes them to BF16 before the existing cuBLAS batched GEMM.

Results on gpt-oss-20b (50-problem GSM8K subset):
- FP8, TP=1: 47/50 = 94.0% (single RTX 5090, ~25 GB VRAM)
- BF16, TP=2: 47/50 = 94.0% (requires 2× RTX 5090, ~39 GB total)

No measurable accuracy degradation on this subset. Model size: 41.8 GB → 22.7 GB (−46%).

New files:
- tools/quantize_fp8.py: offline BF16→FP8 conversion script
- csrc/quantization/dequant_fp8.cu: per-expert-scale dequant kernel
- crates/xserv-kernels/src/quantization.rs: Rust FFI wrapper
- tools/eval_gsm8k_batch.sh: GSM8K accuracy evaluation harness

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
38 lines · 1.4 KiB · Rust
use std::env;
|
|
|
|
fn main() {
|
|
let cuda_path = env::var("CUDA_HOME")
|
|
.or_else(|_| env::var("CUDA_PATH"))
|
|
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
|
|
|
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
|
println!("cargo:rustc-link-lib=dylib=cublas");
|
|
|
|
cc::Build::new()
|
|
.cuda(true)
|
|
.cudart("shared")
|
|
.flag("-gencode=arch=compute_120,code=sm_120")
|
|
.include("../../csrc")
|
|
.file("../../csrc/gemm/naive.cu")
|
|
.file("../../csrc/gemm/tiled.cu")
|
|
.file("../../csrc/gemm/gemv.cu")
|
|
.file("../../csrc/normalization/rmsnorm.cu")
|
|
.file("../../csrc/normalization/layernorm.cu")
|
|
.file("../../csrc/activation/activations.cu")
|
|
.file("../../csrc/reduce/softmax.cu")
|
|
.file("../../csrc/reduce/argmax.cu")
|
|
.file("../../csrc/embedding/embedding.cu")
|
|
.file("../../csrc/embedding/rope.cu")
|
|
.file("../../csrc/attention/causal_mask.cu")
|
|
.file("../../csrc/embedding/transpose.cu")
|
|
.file("../../csrc/attention/flash_attention.cu")
|
|
.file("../../csrc/attention/paged_attention.cu")
|
|
.file("../../csrc/attention/reshape_and_cache.cu")
|
|
.file("../../csrc/moe/moe_kernels.cu")
|
|
.file("../../csrc/quantization/dequant_fp8.cu")
|
|
.compile("xserv_kernels");
|
|
|
|
println!("cargo:rerun-if-changed=../../csrc/");
|
|
}
|