- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each kv head sums its group of query-head grads; no atomics) + Tensor/ops node. - Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim; attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical. - --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export writes real num_key_value_heads (xserv repeat_kv grouping aligned). - Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape); parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
52 lines
1.9 KiB
Rust
52 lines
1.9 KiB
Rust
use std::env;
|
|
use std::path::Path;
|
|
use std::process::Command;
|
|
|
|
fn main() {
|
|
println!("cargo:rustc-check-cfg=cfg(no_cuda)");
|
|
println!("cargo:rerun-if-changed=../../csrc/");
|
|
|
|
let cuda_path = env::var("CUDA_HOME")
|
|
.or_else(|_| env::var("CUDA_PATH"))
|
|
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
|
|
|
// Locally there is no nvcc / GPU. Detect that and skip the CUDA build so
|
|
// `cargo check`/`cargo build` of host-side Rust still works. The `no_cuda`
|
|
// cfg makes the FFI bindings + smoke test compile (but not run) without nvcc.
|
|
if !nvcc_available(&cuda_path) {
|
|
println!("cargo:warning=nvcc not found — skipping CUDA compilation (host-only build).");
|
|
println!("cargo:rustc-cfg=no_cuda");
|
|
return;
|
|
}
|
|
|
|
println!("cargo:rustc-link-search=native={cuda_path}/lib64");
|
|
println!("cargo:rustc-link-lib=dylib=cudart");
|
|
println!("cargo:rustc-link-lib=dylib=cuda");
|
|
// cuBLAS is used only as a correctness reference for the hand-written GEMM.
|
|
println!("cargo:rustc-link-lib=dylib=cublas");
|
|
|
|
cc::Build::new()
|
|
.cuda(true)
|
|
.cudart("shared")
|
|
.flag("-gencode=arch=compute_120,code=sm_120")
|
|
.file("../../csrc/test/vecadd.cu")
|
|
.file("../../csrc/ops/elementwise.cu")
|
|
.file("../../csrc/ops/gemm.cu")
|
|
.file("../../csrc/ops/nn.cu")
|
|
.file("../../csrc/ops/model.cu")
|
|
.file("../../csrc/ops/optim.cu")
|
|
.file("../../csrc/ops/attention.cu")
|
|
.file("../../csrc/ops/flash_attention.cu")
|
|
.file("../../csrc/ops/repeat_kv.cu")
|
|
.file("../../csrc/ops/cast.cu")
|
|
.file("../../csrc/ops/dropout.cu")
|
|
.compile("xtrain_cuda_kernels");
|
|
}
|
|
|
|
/// Reports whether an `nvcc` compiler can be located.
///
/// Two probes are tried in order:
/// 1. run `nvcc --version` through the normal executable lookup and require
///    it to exit successfully;
/// 2. check for a file at `<cuda_path>/bin/nvcc`.
///
/// `cuda_path` is the CUDA toolkit root directory.
fn nvcc_available(cuda_path: &str) -> bool {
    // A process that merely spawns is not enough: a broken shim that exits
    // non-zero must not count as a working compiler, so check the status too.
    let runs_from_path = Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    if runs_from_path {
        return true;
    }

    // Fall back to the conventional location inside the toolkit root.
    Path::new(cuda_path).join("bin").join("nvcc").exists()
}
|