xtrain/crates/xtrain-optim/build.rs
Gahow Wang f22429f5b8 optim: hand-written AdamW (decoupled weight decay + bias correction)
New xtrain-optim crate. AdamW with per-param m/v moments keyed by params()
index, global bias correction, and decoupled weight decay (matches
torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers,
unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips
each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr
argument leaves room for an LR schedule.

Host unit test checks the update against an independent reference recurrence
over 20 steps and the pure-decay (g=0) boundary.
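
The update described above can be sketched on flat host buffers as follows. This is only a minimal sketch, not the crate's actual code: the names `AdamWConfig` and `adamw_step_host`, the argument order, and the 1-based step counter `t` are assumptions made for illustration. The recurrence itself (decoupled decay applied directly to the parameter, first/second moment updates, bias correction by the global step) follows the description in this commit message.

```rust
/// Hypothetical hyperparameter bundle (illustrative names, not the crate's API).
#[derive(Clone, Copy)]
struct AdamWConfig {
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
}

/// One AdamW step over flat f32 buffers. `t` is the 1-based global step count
/// used for bias correction; `lr` is passed per step so a schedule can vary it.
fn adamw_step_host(
    param: &mut [f32],
    grad: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    t: u32,
    cfg: AdamWConfig,
) {
    assert!(t >= 1, "step count is 1-based");
    assert!(param.len() == grad.len() && param.len() == m.len() && param.len() == v.len());

    // Global bias-correction terms: 1 - beta^t.
    let bc1 = 1.0 - cfg.beta1.powi(t as i32);
    let bc2 = 1.0 - cfg.beta2.powi(t as i32);

    for i in 0..param.len() {
        // Decoupled weight decay: shrink the parameter itself, not the gradient.
        param[i] *= 1.0 - lr * cfg.weight_decay;

        // Exponential moving averages of the gradient and its square.
        m[i] = cfg.beta1 * m[i] + (1.0 - cfg.beta1) * grad[i];
        v[i] = cfg.beta2 * v[i] + (1.0 - cfg.beta2) * grad[i] * grad[i];

        // Bias-corrected Adam update.
        let m_hat = m[i] / bc1;
        let v_hat = v[i] / bc2;
        param[i] -= lr * m_hat / (v_hat.sqrt() + cfg.eps);
    }
}
```

With a zero gradient the moments stay at zero, so a step reduces to multiplying each parameter by `1 - lr * weight_decay`, which is the pure-decay boundary mentioned above.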

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:28:23 +08:00


use std::env;
use std::path::Path;
use std::process::Command;

// Per-crate convention (see the other crates): the AdamW *math* is host-only and
// always compiles, but `AdamW::step(&[Var])` round-trips parameter values/grads
// through GPU tensors, so that call site is gated behind `not(no_cuda)`. cfg does
// not propagate across crates, so this crate re-detects nvcc. No CUDA is compiled.
fn main() {
    // Declare the custom cfg so rustc's check-cfg lint accepts `cfg(no_cuda)`.
    println!("cargo:rustc-check-cfg=cfg(no_cuda)");

    // Lookup order: CUDA_HOME, then CUDA_PATH, then the default install prefix.
    let cuda_path = env::var("CUDA_HOME")
        .or_else(|_| env::var("CUDA_PATH"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());

    if !nvcc_available(&cuda_path) {
        println!("cargo:rustc-cfg=no_cuda");
    }
}

/// Returns true if `nvcc` can be spawned from PATH or exists under `{cuda_path}/bin`.
fn nvcc_available(cuda_path: &str) -> bool {
    // `is_ok()` only means the process could be spawned; the exit status is not checked.
    if Command::new("nvcc").arg("--version").output().is_ok() {
        return true;
    }
    Path::new(&format!("{cuda_path}/bin/nvcc")).exists()
}
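
The `no_cuda` cfg emitted by the build script above is consumed in the crate's own source. A minimal sketch of that gating pattern is shown below. The function names `backend_name` and `step_on_device` are hypothetical and the body of the gated function is a placeholder; only the `cfg(no_cuda)` / `not(no_cuda)` mechanism comes from the build script.

```rust
/// Reports which path was compiled in. `cfg!(no_cuda)` is true only when the
/// build script printed `cargo:rustc-cfg=no_cuda`.
#[allow(unexpected_cfgs)]
fn backend_name() -> &'static str {
    if cfg!(no_cuda) {
        "host-only"
    } else {
        "cuda"
    }
}

/// Hypothetical GPU-touching wrapper. When `no_cuda` is set this item is removed
/// from the crate entirely, so nothing that needs CUDA is referenced.
#[cfg(not(no_cuda))]
#[allow(dead_code, unexpected_cfgs)]
fn step_on_device(values: &mut [f32], lr: f32) {
    // Placeholder body: a real wrapper would copy values and grads to and from
    // GPU tensors here and delegate the math to the host-only step.
    for x in values.iter_mut() {
        *x -= lr;
    }
}
```

Because the gated item does not exist at all in a `no_cuda` build, any caller of it needs the same `#[cfg(not(no_cuda))]` attribute, while the host-only math stays available in both configurations.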