Files
xtrain/crates/xtrain-cuda/build.rs
Gahow Wang 830d06ad01 gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:37:37 +08:00

52 lines
1.9 KiB
Rust

use std::env;
use std::path::Path;
use std::process::Command;
/// Build-script entry point.
///
/// Emits the cargo directives for this crate and, when an `nvcc` compiler can be
/// located, compiles the CUDA sources under `../../csrc/` into the
/// `xtrain_cuda_kernels` library and links the CUDA libraries. When `nvcc` cannot
/// be located, the `no_cuda` cfg is set instead and nothing is compiled.
fn main() {
    // Libraries linked dynamically, in emission order. cuBLAS is used only as a
    // correctness reference for the hand-written GEMM.
    const LINKED_LIBS: [&str; 3] = ["cudart", "cuda", "cublas"];

    // Every CUDA translation unit that goes into the kernel library.
    const KERNEL_SOURCES: [&str; 11] = [
        "../../csrc/test/vecadd.cu",
        "../../csrc/ops/elementwise.cu",
        "../../csrc/ops/gemm.cu",
        "../../csrc/ops/nn.cu",
        "../../csrc/ops/model.cu",
        "../../csrc/ops/optim.cu",
        "../../csrc/ops/attention.cu",
        "../../csrc/ops/flash_attention.cu",
        "../../csrc/ops/repeat_kv.cu",
        "../../csrc/ops/cast.cu",
        "../../csrc/ops/dropout.cu",
    ];

    // Declare the custom cfg so rustc's check-cfg lint accepts it, and re-run the
    // script whenever anything under the CUDA source tree changes.
    println!("cargo:rustc-check-cfg=cfg(no_cuda)");
    println!("cargo:rerun-if-changed=../../csrc/");

    // Toolkit root: CUDA_HOME takes precedence, then CUDA_PATH, then the default.
    let cuda_path = match env::var("CUDA_HOME").or_else(|_| env::var("CUDA_PATH")) {
        Ok(root) => root,
        Err(_) => String::from("/usr/local/cuda"),
    };

    // Machines without nvcc / a GPU still need `cargo check` and `cargo build` of
    // the host-side Rust to work. The `no_cuda` cfg lets the FFI bindings and the
    // smoke test compile (but not run) without nvcc.
    let have_nvcc = nvcc_available(&cuda_path);
    if !have_nvcc {
        println!("cargo:warning=nvcc not found — skipping CUDA compilation (host-only build).");
        println!("cargo:rustc-cfg=no_cuda");
        return;
    }

    println!("cargo:rustc-link-search=native={cuda_path}/lib64");
    for lib in LINKED_LIBS {
        println!("cargo:rustc-link-lib=dylib={lib}");
    }

    // Configure the nvcc-backed build, register each source, then compile.
    let mut kernels = cc::Build::new();
    kernels
        .cuda(true)
        .cudart("shared")
        .flag("-gencode=arch=compute_120,code=sm_120");
    for source in KERNEL_SOURCES {
        kernels.file(source);
    }
    kernels.compile("xtrain_cuda_kernels");
}
/// Reports whether an `nvcc` compiler can be located.
///
/// Two probes are tried in order:
/// 1. spawning `nvcc --version` through the normal executable lookup;
/// 2. checking that `{cuda_path}/bin/nvcc` exists on disk.
///
/// `cuda_path` is the CUDA toolkit root directory to inspect for the second probe.
/// Returns `true` as soon as either probe succeeds.
fn nvcc_available(cuda_path: &str) -> bool {
    // Only the ability to launch the process is tested; its exit status and
    // output are not inspected.
    let launches = Command::new("nvcc").arg("--version").output().is_ok();

    // `||` short-circuits, so the filesystem is consulted only when the launch failed.
    launches || {
        let candidate = format!("{cuda_path}/bin/nvcc");
        Path::new(&candidate).exists()
    }
}