Weight-only 4-bit (W4A16) for the gpt-oss MoE experts: expert weights are stored as MXFP4 (E2M1 values plus one UE8M0 block scale per 32 elements, produced by tools/quantize_mxfp4.py), and a fused kernel reads the 4-bit weights and dequantizes them on-chip to BF16.

Decode (M=1) uses a fused dequant-GEMV (batched_gemv_mxfp4) with shared-memory activation tiling; prefill (M>1) dequantizes to BF16 and then reuses the BF16 batched GEMM. MXFP4 is detected by the rank of the scale tensor: 3-D [E, N, K/32] for MXFP4 versus 1-D [E] for FP8.

Verified on dash5 (gpt-oss-20b, TP=2, 5090): greedy tokens are byte-identical to FP8/BF16, and the footprint is the smallest (13 GB vs 22 GB for FP8 and 39 GB for BF16), so the model fits on one 32 GB 5090 with room for KV cache.

This is NOT a decode speedup: the hand-written W4A16 GEMV (no tensor cores) is less efficient than cuBLASLt's FP8 tensor-core GEMM, so even with half the weight bytes, decode takes 17.0 ms vs 13.5 ms for FP8 (though it is faster than BF16 at 18.8 ms). Prefill regresses (350 ms vs 134 ms) because of the dequant fallback.

Committed as a correct memory-optimization foundation. Beating FP8 on speed needs FP4 tensor cores (W4A4, cuBLASLt block-scaled MXFP4) or a Marlin-class kernel; see docs/benchmarks/mxfp4-and-llama-decode.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
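As a reference for the math the fused MXFP4 kernel performs, here is a minimal CPU sketch in Rust. It assumes the OCP MX conventions for the format described above: a 4-bit E2M1 code (1 sign, 2 exponent, 1 mantissa bit) and one UE8M0 scale byte per 32 weights encoding 2^(s - 127), with 0xFF meaning NaN. The nibble packing order (element 2i in the low nibble) and the row-major [N, K/2] bytes / [N, K/32] scales layout are assumptions, not necessarily what tools/quantize_mxfp4.py emits, and every name here (`gemv_mxfp4`, `dequant_row`, `code_at`, `detect_format`, `WeightFormat`) is hypothetical. It accumulates in f32 rather than BF16 and leaves out the shared-memory activation tiling of batched_gemv_mxfp4.

```rust
/// E2M1 magnitudes for codes 0..=7; bit 3 of the code is the sign.
const E2M1_LUT: [f32; 16] = [
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
];
/// Number of elements that share one UE8M0 scale.
const BLOCK: usize = 32;

/// Decode a UE8M0 scale byte to 2^(s - 127); 0xFF encodes NaN.
fn ue8m0_to_f32(s: u8) -> f32 {
    if s == 0xFF { f32::NAN } else { 2f32.powi(s as i32 - 127) }
}

/// 4-bit code for element `k` (assumption: even k in the low nibble).
fn code_at(packed: &[u8], k: usize) -> usize {
    let byte = packed[k / 2];
    (if k % 2 == 0 { byte & 0x0F } else { byte >> 4 }) as usize
}

/// Unfused path: materialize one row of weights before a dense GEMM.
fn dequant_row(packed: &[u8], scales: &[u8], out: &mut [f32]) {
    for (k, w) in out.iter_mut().enumerate() {
        *w = E2M1_LUT[code_at(packed, k)] * ue8m0_to_f32(scales[k / BLOCK]);
    }
}

/// Fused dequant-GEMV: y[n] = sum_k W[n][k] * x[k], decoding weights on the fly.
/// Assumed layout: packed is [N, K/2] bytes and scales is [N, K/32], row-major.
fn gemv_mxfp4(packed: &[u8], scales: &[u8], x: &[f32], n: usize, k: usize) -> Vec<f32> {
    assert!(k % BLOCK == 0, "K must be a multiple of 32");
    let (row_bytes, row_blocks) = (k / 2, k / BLOCK);
    assert_eq!(packed.len(), n * row_bytes);
    assert_eq!(scales.len(), n * row_blocks);
    assert_eq!(x.len(), k);
    (0..n)
        .map(|row| {
            let p = &packed[row * row_bytes..(row + 1) * row_bytes];
            let s = &scales[row * row_blocks..(row + 1) * row_blocks];
            let mut acc = 0.0f32;
            for b in 0..row_blocks {
                // All 32 weights in the block share a scale: sum unscaled, scale once.
                let partial: f32 = (b * BLOCK..(b + 1) * BLOCK)
                    .map(|kk| E2M1_LUT[code_at(p, kk)] * x[kk])
                    .sum();
                acc += partial * ue8m0_to_f32(s[b]);
            }
            acc
        })
        .collect()
}

#[derive(Debug, PartialEq)]
enum WeightFormat {
    Fp8PerExpert, // scale shape [E]
    Mxfp4,        // scale shape [E, N, K/32]
}

/// Pick the expert weight format from the rank of its scale tensor.
fn detect_format(scale_shape: &[usize]) -> Option<WeightFormat> {
    match scale_shape.len() {
        1 => Some(WeightFormat::Fp8PerExpert),
        3 => Some(WeightFormat::Mxfp4),
        _ => None,
    }
}

fn main() {
    // Toy expert matrix: N = 2 rows, K = 32, so one scale block per row.
    let (n, k) = (2, 32);
    let mut packed = vec![0x22u8; 16]; // row 0: every code is 2 -> +1.0
    packed.extend([0x99u8; 16]);       // row 1: every code is 9 -> -0.5
    let scales = [127u8, 128];         // row 0: 2^0 = 1.0, row 1: 2^1 = 2.0
    let x = vec![1.0f32; k];

    let y = gemv_mxfp4(&packed, &scales, &x, n, k);
    println!("{y:?}"); // prints [32.0, -32.0]

    // The unfused path (dequantize a row, then dot) agrees with the fused one.
    let mut row1 = vec![0.0f32; k];
    dequant_row(&packed[16..], &scales[1..], &mut row1);
    let dot: f32 = row1.iter().zip(&x).map(|(w, a)| w * a).sum();
    assert_eq!(dot, y[1]);

    println!("{:?}", detect_format(&[4, 64, 2])); // prints Some(Mxfp4)
}
```

Applying the block scale once to each 32-element partial sum, rather than to every weight, is valid because all weights in a block share it. The same block structure sets the storage cost at 4.25 bits per expert weight (16 bytes of codes plus one scale byte per 32 weights).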
41 lines
1.6 KiB
Rust
use std::env;

fn main() {
    // Locate the CUDA toolkit: CUDA_HOME, then CUDA_PATH, then the default prefix.
    let cuda_path = env::var("CUDA_HOME")
        .or_else(|_| env::var("CUDA_PATH"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());
    // Re-run this script if either toolkit variable changes.
    println!("cargo:rerun-if-env-changed=CUDA_HOME");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");

    // Link the CUDA runtime, cuBLAS, and cuBLASLt as shared libraries.
    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
    println!("cargo:rustc-link-lib=dylib=cudart");
    println!("cargo:rustc-link-lib=dylib=cublas");
    println!("cargo:rustc-link-lib=dylib=cublasLt");

    // Compile every kernel with nvcc into one static library, targeting
    // sm_120 (compute capability 12.0, e.g. the RTX 5090).
    cc::Build::new()
        .cuda(true)
        .cudart("shared")
        .flag("-gencode=arch=compute_120,code=sm_120")
        .include("../../csrc")
        .file("../../csrc/gemm/naive.cu")
        .file("../../csrc/gemm/tiled.cu")
        .file("../../csrc/gemm/gemv.cu")
        .file("../../csrc/normalization/rmsnorm.cu")
        .file("../../csrc/normalization/layernorm.cu")
        .file("../../csrc/activation/activations.cu")
        .file("../../csrc/reduce/softmax.cu")
        .file("../../csrc/reduce/argmax.cu")
        .file("../../csrc/embedding/embedding.cu")
        .file("../../csrc/embedding/rope.cu")
        .file("../../csrc/attention/causal_mask.cu")
        .file("../../csrc/embedding/transpose.cu")
        .file("../../csrc/attention/flash_attention.cu")
        .file("../../csrc/attention/paged_attention.cu")
        .file("../../csrc/attention/reshape_and_cache.cu")
        .file("../../csrc/moe/moe_kernels.cu")
        .file("../../csrc/quantization/dequant_fp8.cu")
        .file("../../csrc/quantization/quantize_fp8.cu")
        .file("../../csrc/quantization/mxfp4_gemm.cu")
        .compile("xserv_kernels");

    // Rebuild whenever anything under csrc/ changes.
    println!("cargo:rerun-if-changed=../../csrc/");
}