xserv/crates/xserv-kernels/build.rs
Gahow Wang d33220498a quantization: MXFP4 W4A16 expert weights (memory-optimization foundation)
Weight-only 4-bit quantization for the gpt-oss MoE experts. Weights are stored
as MXFP4 (E2M1 values plus a per-32-element UE8M0 block scale; see
tools/quantize_mxfp4.py), and a fused kernel reads the 4-bit weights and
dequantizes them on-chip to BF16. Decode (M=1) uses a fused dequant-GEMV
(batched_gemv_mxfp4) with shared-memory activation tiling; prefill (M>1)
dequantizes to BF16 and then reuses the BF16 batched GEMM. MXFP4 is detected by
the rank of the scale tensor: 3-D [E,N,K/32] for MXFP4 vs 1-D [E] for FP8.
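As a host-side illustration of the storage format described above, the sketch below decodes one 32-element MXFP4 block and classifies a scale tensor by rank. It is not the CUDA kernel from this commit. The E2M1 value table and the UE8M0 rule (2^(e-127), with 0xFF as NaN) follow the OCP Microscaling format; the nibble order (low nibble first) and all names (`decode_e2m1`, `ue8m0_to_f32`, `dequant_block`, `ExpertQuant`, `detect_expert_quant`) are assumptions made for the example.

```rust
/// E2M1 magnitudes for the 3 low bits of a 4-bit code; bit 3 is the sign.
const E2M1_LUT: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];

/// Decode one 4-bit E2M1 code (passed in the low nibble of `nib`).
fn decode_e2m1(nib: u8) -> f32 {
    let mag = E2M1_LUT[(nib & 0x7) as usize];
    if nib & 0x8 != 0 { -mag } else { mag }
}

/// Decode a UE8M0 block scale: an unsigned power of two, 2^(e - 127).
fn ue8m0_to_f32(e: u8) -> f32 {
    match e {
        0xFF => f32::NAN,                      // reserved NaN encoding
        0 => f32::from_bits(0x0040_0000),      // 2^-127 is subnormal in f32
        _ => f32::from_bits((e as u32) << 23), // place e in the exponent field
    }
}

/// Dequantize one block: 16 packed bytes (32 codes) and one scale byte.
/// ASSUMPTION: element 2*i is the low nibble of byte i, element 2*i+1 the high.
fn dequant_block(packed: &[u8; 16], scale: u8) -> [f32; 32] {
    let s = ue8m0_to_f32(scale);
    let mut out = [0.0f32; 32];
    for (i, &byte) in packed.iter().enumerate() {
        out[2 * i] = decode_e2m1(byte & 0x0F) * s;
        out[2 * i + 1] = decode_e2m1(byte >> 4) * s;
    }
    out
}

/// Hypothetical tag for how an expert weight tensor is quantized.
#[derive(Debug, PartialEq, Eq)]
enum ExpertQuant {
    Fp8PerExpert, // scale shape [E]
    Mxfp4,        // scale shape [E, N, K/32]
}

/// Classify by the rank of the scale tensor, as the commit message describes.
fn detect_expert_quant(scale_shape: &[usize]) -> Option<ExpertQuant> {
    match scale_shape.len() {
        1 => Some(ExpertQuant::Fp8PerExpert),
        3 => Some(ExpertQuant::Mxfp4),
        _ => None,
    }
}
```

A fused kernel would do the same decode per thread in registers and convert to BF16 rather than f32; the f32 output here only keeps the example easy to check.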

Verified on dash5 (gpt-oss-20b, TP=2, 5090): greedy tokens are byte-identical to
FP8/BF16, and the footprint is the smallest of the three (13 GB vs 22 GB FP8 and
39 GB BF16), so the model fits on one 32 GB 5090 with room for KV cache.
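The per-tensor arithmetic behind the footprint numbers can be sketched as follows. Each MXFP4 block of 32 weights costs 16 bytes of packed codes plus 1 scale byte, i.e. 4.25 bits per weight. The function names and the [E, N, K] parameterization are assumptions for illustration; the FP8 figure ignores its per-expert scale, and the whole-model totals quoted above also include tensors this commit does not store as MXFP4, so they will not match the per-tensor ratio exactly.

```rust
/// MXFP4 block size: one UE8M0 scale byte per 32 weights.
const BLOCK: usize = 32;

/// Bytes for an [E, N, K] expert weight stored as MXFP4:
/// K/2 bytes of packed 4-bit codes plus K/32 scale bytes per row.
fn mxfp4_bytes(e: usize, n: usize, k: usize) -> usize {
    assert!(k % BLOCK == 0, "K must be a multiple of the block size");
    e * n * (k / 2 + k / BLOCK)
}

/// Bytes for the same tensor in FP8 (per-expert scales ignored).
fn fp8_bytes(e: usize, n: usize, k: usize) -> usize {
    e * n * k
}

/// Bytes for the same tensor in BF16.
fn bf16_bytes(e: usize, n: usize, k: usize) -> usize {
    e * n * k * 2
}

/// Effective bits per weight for MXFP4: 4 code bits + 8 scale bits / 32.
fn mxfp4_bits_per_weight() -> f64 {
    4.0 + 8.0 / BLOCK as f64
}
```

For any K that is a multiple of 32, BF16 is 64/17 (about 3.76) times larger than MXFP4 for the same tensor, and FP8 is 32/17 (about 1.88) times larger.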

NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less
efficient than cuBLASLt's FP8 tensor-core GEMM, so even at half the weight bytes
decode takes 17.0 ms vs 13.5 ms for FP8 (still faster than BF16 at 18.8 ms).
Prefill regresses (350 vs 134 ms, dequant fallback). Committed as a correct
memory-optimization foundation. Beating FP8 on speed needs FP4 tensor cores
(W4A4, cuBLASLt block-scaled MXFP4) or a Marlin-class kernel; see
docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-12 15:01:42 +08:00

41 lines · 1.6 KiB · Rust

use std::env;

fn main() {
    // Locate the CUDA toolkit: CUDA_HOME, then CUDA_PATH, then the default prefix.
    let cuda_path = env::var("CUDA_HOME")
        .or_else(|_| env::var("CUDA_PATH"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());

    // Link dynamically against the CUDA runtime and cuBLAS/cuBLASLt.
    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=cublas");
    println!("cargo:rustc-link-lib=dylib=cublasLt");

    // Compile every kernel with nvcc into one static library, targeting sm_120 only.
    cc::Build::new()
        .cuda(true)
        .cudart("shared")
        .flag("-gencode=arch=compute_120,code=sm_120")
        .include("../../csrc")
        .file("../../csrc/gemm/naive.cu")
        .file("../../csrc/gemm/tiled.cu")
        .file("../../csrc/gemm/gemv.cu")
        .file("../../csrc/normalization/rmsnorm.cu")
        .file("../../csrc/normalization/layernorm.cu")
        .file("../../csrc/activation/activations.cu")
        .file("../../csrc/reduce/softmax.cu")
        .file("../../csrc/reduce/argmax.cu")
        .file("../../csrc/embedding/embedding.cu")
        .file("../../csrc/embedding/rope.cu")
        .file("../../csrc/attention/causal_mask.cu")
        .file("../../csrc/embedding/transpose.cu")
        .file("../../csrc/attention/flash_attention.cu")
        .file("../../csrc/attention/paged_attention.cu")
        .file("../../csrc/attention/reshape_and_cache.cu")
        .file("../../csrc/moe/moe_kernels.cu")
        .file("../../csrc/quantization/dequant_fp8.cu")
        .file("../../csrc/quantization/quantize_fp8.cu")
        // MXFP4 W4A16 kernels added by this commit.
        .file("../../csrc/quantization/mxfp4_gemm.cu")
        .compile("xserv_kernels");

    // Rebuild whenever anything under csrc/ changes.
    println!("cargo:rerun-if-changed=../../csrc/");
}
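
One possible refinement, sketched here as an assumption and not part of the commit: because the script emits a `rerun-if-changed` directive, Cargo reruns it only for the listed triggers, so changing `CUDA_HOME` or `CUDA_PATH` alone would not cause a rebuild. The hypothetical helpers below factor the toolkit lookup into a function that takes the environment lookup as a parameter (so it can be exercised without touching the real environment) and produce the matching `cargo:rerun-if-env-changed` lines.

```rust
/// Resolve the CUDA toolkit root with the same precedence as the build script:
/// CUDA_HOME, then CUDA_PATH, then /usr/local/cuda.
/// `lookup` is injected so the logic can be checked without real env vars.
fn resolve_cuda_root(lookup: impl Fn(&str) -> Option<String>) -> String {
    lookup("CUDA_HOME")
        .or_else(|| lookup("CUDA_PATH"))
        .unwrap_or_else(|| "/usr/local/cuda".to_string())
}

/// Directives asking Cargo to rerun the build script when either variable changes.
fn rerun_env_directives() -> Vec<String> {
    ["CUDA_HOME", "CUDA_PATH"]
        .iter()
        .map(|var| format!("cargo:rerun-if-env-changed={var}"))
        .collect()
}
```

In a build script these would be used as `resolve_cuda_root(|k| std::env::var(k).ok())`, with each string from `rerun_env_directives()` passed to `println!`.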