diff --git a/Cargo.lock b/Cargo.lock index 48db26d..f2366f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,15 @@ dependencies = [ "xtrain-tensor", ] +[[package]] +name = "xtrain-optim" +version = "0.1.0" +dependencies = [ + "xtrain-autodiff", + "xtrain-cuda", + "xtrain-tensor", +] + [[package]] name = "xtrain-tensor" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 456b6d7..1f91829 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "crates/xtrain-tensor", "crates/xtrain-autodiff", "crates/xtrain-model", + "crates/xtrain-optim", ] [workspace.package] diff --git a/crates/xtrain-optim/Cargo.toml b/crates/xtrain-optim/Cargo.toml new file mode 100644 index 0000000..d2aff25 --- /dev/null +++ b/crates/xtrain-optim/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "xtrain-optim" +version.workspace = true +edition.workspace = true + +[dependencies] +xtrain-tensor = { path = "../xtrain-tensor" } +xtrain-autodiff = { path = "../xtrain-autodiff" } + +[dev-dependencies] +# Acceptance tests drive the GPU (device selection) directly. +xtrain-cuda = { path = "../xtrain-cuda" } diff --git a/crates/xtrain-optim/build.rs b/crates/xtrain-optim/build.rs new file mode 100644 index 0000000..43ac264 --- /dev/null +++ b/crates/xtrain-optim/build.rs @@ -0,0 +1,26 @@ +use std::env; +use std::path::Path; +use std::process::Command; + +// Per-crate convention (see the other crates): the AdamW *math* is host-only and +// always compiles, but `AdamW::step(&[Var])` round-trips parameter values/grads +// through GPU tensors, so that call site is gated behind `not(no_cuda)`. cfg does +// not propagate across crates, so this crate re-detects nvcc. No CUDA is compiled. +fn main() { + println!("cargo:rustc-check-cfg=cfg(no_cuda)"); + + let cuda_path = env::var("CUDA_HOME") + .or_else(|_| env::var("CUDA_PATH")) + .unwrap_or_else(|_| "/usr/local/cuda".to_string()); + + if !nvcc_available(&cuda_path) { + println!("cargo:rustc-cfg=no_cuda"); + } +} + +fn nvcc_available(cuda_path: &str) -> bool { + if Command::new("nvcc").arg("--version").output().is_ok() { + return true; + } + Path::new(&format!("{cuda_path}/bin/nvcc")).exists() +} diff --git a/crates/xtrain-optim/src/lib.rs b/crates/xtrain-optim/src/lib.rs new file mode 100644 index 0000000..edd37e7 --- /dev/null +++ b/crates/xtrain-optim/src/lib.rs @@ -0,0 +1,154 @@ +//! Hand-written AdamW optimizer (Phase T6). +//! +//! AdamW = Adam with **decoupled** weight decay (Loshchilov & Hutter, 2019): the +//! weight-decay term is applied directly to the parameter, NOT folded into the +//! gradient (so it does not interact with the adaptive `v` denominator). This +//! matches `torch.optim.AdamW`. +//! +//! Update for parameter `θ` at step `t` (1-indexed), with gradient `g`: +//! ```text +//! m ← β1·m + (1−β1)·g +//! v ← β2·v + (1−β2)·g² +//! m̂ ← m / (1 − β1ᵗ) (bias correction) +//! v̂ ← v / (1 − β2ᵗ) +//! θ ← θ − lr·( m̂ / (√v̂ + ε) + wd·θ ) +//! ``` +//! The `lr·wd·θ` term is the decoupled decay. Note PyTorch applies decay as +//! `θ ← θ·(1 − lr·wd)` then the Adam step; both are algebraically the same +//! first-order update — we fold decay into the single subtraction above, which +//! is what PyTorch's default (`maximize=False`, no `amsgrad`) computes. +//! +//! The math operates on flat host `f32` buffers ([`AdamW::step_host`]) so it is +//! unit-testable on a GPU-less host; [`AdamW::step`] is a thin wrapper that +//! round-trips each parameter's value/grad through the GPU tensor and is gated +//! behind `not(no_cuda)`. + +/// Per-parameter optimizer state: the first (`m`) and second (`v`) moment +/// estimates, one f32 per element, kept flat (matching the parameter layout). +struct ParamState { + m: Vec, + v: Vec, +} + +/// Decoupled-weight-decay Adam. One instance owns the moment state for a fixed +/// list of parameters, keyed by their index in the slice passed to `step` +/// (the model's stable `params()` order). +pub struct AdamW { + pub lr: f32, + beta1: f32, + beta2: f32, + eps: f32, + weight_decay: f32, + /// Global step count (shared across all params for bias correction). + t: u64, + /// Lazily sized to the parameter list on the first `step`. + state: Vec, +} + +impl AdamW { + /// PyTorch-default hyperparameters except `lr`/`weight_decay`, which you set + /// (β1=0.9, β2=0.999, ε=1e-8). + pub fn new(lr: f32, weight_decay: f32) -> Self { + Self::with_betas(lr, weight_decay, 0.9, 0.999, 1e-8) + } + + pub fn with_betas(lr: f32, weight_decay: f32, beta1: f32, beta2: f32, eps: f32) -> Self { + Self { + lr, + beta1, + beta2, + eps, + weight_decay, + t: 0, + state: Vec::new(), + } + } + + /// Current global step (number of `step` calls so far). + pub fn step_count(&self) -> u64 { + self.t + } + + /// Pure-host AdamW step over flat parameter/gradient buffers. `params[i]` is + /// updated in place using `grads[i]`; both are the i-th parameter's elements + /// in the model's stable order. Lazily allocates moment state on first call. + /// + /// This is the testable core — no GPU, no autograd. `lr` is passed per call + /// so a schedule can vary it each step. + pub fn step_host(&mut self, lr: f32, params: &mut [Vec], grads: &[Vec]) { + assert_eq!(params.len(), grads.len(), "param/grad count mismatch"); + if self.state.is_empty() { + self.state = params + .iter() + .map(|p| ParamState { + m: vec![0.0; p.len()], + v: vec![0.0; p.len()], + }) + .collect(); + } + assert_eq!(self.state.len(), params.len(), "param count changed"); + + self.t += 1; + let bc1 = 1.0 - self.beta1.powi(self.t as i32); + let bc2 = 1.0 - self.beta2.powi(self.t as i32); + + for (i, (p, g)) in params.iter_mut().zip(grads).enumerate() { + assert_eq!(p.len(), g.len(), "param/grad len mismatch at {i}"); + let st = &mut self.state[i]; + for j in 0..p.len() { + let gj = g[j]; + st.m[j] = self.beta1 * st.m[j] + (1.0 - self.beta1) * gj; + st.v[j] = self.beta2 * st.v[j] + (1.0 - self.beta2) * gj * gj; + let mhat = st.m[j] / bc1; + let vhat = st.v[j] / bc2; + // Decoupled weight decay: decay term uses the *current* param, + // matching PyTorch's `p ← p − lr·wd·p` applied alongside the step. + p[j] -= lr * (mhat / (vhat.sqrt() + self.eps) + self.weight_decay * p[j]); + } + } + } +} + +#[cfg(not(no_cuda))] +mod gpu { + use super::AdamW; + use xtrain_autodiff::tape::Var; + use xtrain_tensor::{Device, Tensor}; + + impl AdamW { + /// Apply one AdamW step to every parameter `Var`, using `lr` for this step + /// (so an LR schedule can vary it). Pulls each param's value and `.grad()` + /// to the host, runs [`AdamW::step_host`], and writes the updated value + /// back with `set_value`. A param with no grad is fed a zero grad, so the + /// Adam term vanishes and only decoupled weight decay applies (the model's + /// params all receive grads each step, so this is just a safety default). + /// + /// Does NOT zero grads — the caller does that (matching the GD-step + /// template in the T5 overfit test). + pub fn step(&mut self, lr: f32, params: &[Var]) { + let device = params[0].value().device(); + let shapes: Vec> = + params.iter().map(|p| p.value().shape().to_vec()).collect(); + + let mut host_params: Vec> = params + .iter() + .map(|p| p.value().to_device(Device::Cpu).as_slice::().to_vec()) + .collect(); + let host_grads: Vec> = params + .iter() + .zip(&host_params) + .map(|(p, hp)| match p.grad() { + Some(g) => g.to_device(Device::Cpu).as_slice::().to_vec(), + None => vec![0.0; hp.len()], // no grad → no update this step + }) + .collect(); + + self.step_host(lr, &mut host_params, &host_grads); + + for ((p, data), shape) in params.iter().zip(&host_params).zip(&shapes) { + let t = Tensor::from_slice(data, shape).to_device(device); + p.set_value(t); + } + } + } +} diff --git a/crates/xtrain-optim/tests/adamw_host.rs b/crates/xtrain-optim/tests/adamw_host.rs new file mode 100644 index 0000000..15e8d33 --- /dev/null +++ b/crates/xtrain-optim/tests/adamw_host.rs @@ -0,0 +1,99 @@ +// Host-only unit test for the AdamW *math* (no GPU). Verifies the update against +// an independent, hand-rolled reference implementation of the same recurrence for +// several steps with non-trivial weight decay — catching bias-correction and +// decoupled-decay mistakes. The rigorous vs-PyTorch parity (end-to-end on a real +// model) lives in xtrain-train; this is the fast local guard on the formula. + +use xtrain_optim::AdamW; + +// Independent reference: the textbook AdamW recurrence, kept separate from the +// implementation so a shared bug can't hide. +struct RefAdamW { + b1: f32, + b2: f32, + eps: f32, + wd: f32, + t: i32, + m: Vec, + v: Vec, +} + +impl RefAdamW { + fn new(n: usize, wd: f32) -> Self { + Self { + b1: 0.9, + b2: 0.999, + eps: 1e-8, + wd, + t: 0, + m: vec![0.0; n], + v: vec![0.0; n], + } + } + fn step(&mut self, lr: f32, p: &mut [f32], g: &[f32]) { + self.t += 1; + let bc1 = 1.0 - self.b1.powi(self.t); + let bc2 = 1.0 - self.b2.powi(self.t); + for i in 0..p.len() { + self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * g[i]; + self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * g[i] * g[i]; + let mhat = self.m[i] / bc1; + let vhat = self.v[i] / bc2; + p[i] -= lr * (mhat / (vhat.sqrt() + self.eps) + self.wd * p[i]); + } + } +} + +#[test] +fn adamw_matches_reference_recurrence() { + let lr = 0.01; + let wd = 0.1; + let mut opt = AdamW::new(lr, wd); + + // Two parameters of different sizes (exercises per-param state keying). + let mut p_impl = vec![vec![0.5f32, -1.0, 2.0, 0.0], vec![1.5f32, -0.25]]; + let mut p_ref = p_impl.clone(); + let mut r0 = RefAdamW::new(4, wd); + let mut r1 = RefAdamW::new(2, wd); + + // Deterministic pseudo-grads that change every step. + let grad = |step: usize, idx: usize, j: usize| -> f32 { + let s = (step * 13 + idx * 7 + j * 3) as f32; + (s * 0.123).sin() * 0.5 + }; + + for step in 0..20 { + let grads = vec![ + (0..4).map(|j| grad(step, 0, j)).collect::>(), + (0..2).map(|j| grad(step, 1, j)).collect::>(), + ]; + opt.step_host(lr, &mut p_impl, &grads); + r0.step(lr, &mut p_ref[0], &grads[0]); + r1.step(lr, &mut p_ref[1], &grads[1]); + } + + assert_eq!(opt.step_count(), 20); + for (pi, pr) in p_impl.iter().zip(&p_ref) { + for (a, b) in pi.iter().zip(pr) { + assert!((a - b).abs() < 1e-6, "impl {a} != ref {b}"); + } + } +} + +#[test] +fn zero_grad_only_decays() { + // With g=0 and wd>0, the step must reduce to pure decoupled decay: + // θ ← θ − lr·wd·θ (Adam term is 0/eps = 0). + let lr = 0.1; + let wd = 0.5; + let mut opt = AdamW::new(lr, wd); + let mut p = vec![vec![2.0f32]]; + let g = vec![vec![0.0f32]]; + opt.step_host(lr, &mut p, &g); + let expected = 2.0 - lr * wd * 2.0; + assert!( + (p[0][0] - expected).abs() < 1e-6, + "{} != {expected}", + p[0][0] + ); +}