New xtrain-optim crate. AdamW with per-param m/v moments keyed by params() index, global bias correction, and decoupled weight decay (matches torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers, unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr argument leaves room for an LR schedule. Host unit test checks the update against an independent reference recurrence over 20 steps and the pure-decay (g=0) boundary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
27 lines
891 B
Rust
27 lines
891 B
Rust
use std::env;
|
|
use std::path::Path;
|
|
use std::process::Command;
|
|
|
|
// Per-crate convention (see the other crates): the AdamW *math* is host-only and
// always compiles, but `AdamW::step(&[Var])` round-trips parameter values/grads
// through GPU tensors, so that call site is gated behind `not(no_cuda)`. A cfg
// emitted by one crate's build script does not propagate to other crates, so this
// crate re-detects nvcc itself. No CUDA code is compiled here.
|
|
fn main() {
|
|
println!("cargo:rustc-check-cfg=cfg(no_cuda)");
|
|
|
|
let cuda_path = env::var("CUDA_HOME")
|
|
.or_else(|_| env::var("CUDA_PATH"))
|
|
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
|
|
|
if !nvcc_available(&cuda_path) {
|
|
println!("cargo:rustc-cfg=no_cuda");
|
|
}
|
|
}
|
|
|
|
/// Reports whether an `nvcc` compiler appears to be installed.
///
/// Two probes are tried in order:
/// 1. Launch `nvcc --version` through the normal executable search. Being able
///    to spawn it and collect its output counts as "available"; the exit status
///    is not inspected.
/// 2. If the launch fails, check whether `{cuda_path}/bin/nvcc` exists on disk.
fn nvcc_available(cuda_path: &str) -> bool {
    match Command::new("nvcc").arg("--version").output() {
        Ok(_) => true,
        Err(_) => {
            // Not launchable by name; fall back to the toolkit directory.
            let candidate = format!("{cuda_path}/bin/nvcc");
            Path::new(&candidate).exists()
        }
    }
}
|