optim: hand-written AdamW (decoupled weight decay + bias correction)
New xtrain-optim crate. AdamW with per-param m/v moments keyed by params() index, global bias correction, and decoupled weight decay (matches torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers, unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr argument leaves room for an LR schedule. Host unit test checks the update against an independent reference recurrence over 20 steps and the pure-decay (g=0) boundary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
12
crates/xtrain-optim/Cargo.toml
Normal file
12
crates/xtrain-optim/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "xtrain-optim"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xtrain-tensor = { path = "../xtrain-tensor" }
|
||||
xtrain-autodiff = { path = "../xtrain-autodiff" }
|
||||
|
||||
[dev-dependencies]
|
||||
# Acceptance tests drive the GPU (device selection) directly.
|
||||
xtrain-cuda = { path = "../xtrain-cuda" }
|
||||
26
crates/xtrain-optim/build.rs
Normal file
26
crates/xtrain-optim/build.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
// Per-crate convention (see the other crates): the AdamW *math* is host-only and
|
||||
// always compiles, but `AdamW::step(&[Var])` round-trips parameter values/grads
|
||||
// through GPU tensors, so that call site is gated behind `not(no_cuda)`. cfg does
|
||||
// not propagate across crates, so this crate re-detects nvcc. No CUDA is compiled.
|
||||
fn main() {
|
||||
println!("cargo:rustc-check-cfg=cfg(no_cuda)");
|
||||
|
||||
let cuda_path = env::var("CUDA_HOME")
|
||||
.or_else(|_| env::var("CUDA_PATH"))
|
||||
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
||||
|
||||
if !nvcc_available(&cuda_path) {
|
||||
println!("cargo:rustc-cfg=no_cuda");
|
||||
}
|
||||
}
|
||||
|
||||
fn nvcc_available(cuda_path: &str) -> bool {
|
||||
if Command::new("nvcc").arg("--version").output().is_ok() {
|
||||
return true;
|
||||
}
|
||||
Path::new(&format!("{cuda_path}/bin/nvcc")).exists()
|
||||
}
|
||||
154
crates/xtrain-optim/src/lib.rs
Normal file
154
crates/xtrain-optim/src/lib.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
//! Hand-written AdamW optimizer (Phase T6).
|
||||
//!
|
||||
//! AdamW = Adam with **decoupled** weight decay (Loshchilov & Hutter, 2019): the
|
||||
//! weight-decay term is applied directly to the parameter, NOT folded into the
|
||||
//! gradient (so it does not interact with the adaptive `v` denominator). This
|
||||
//! matches `torch.optim.AdamW`.
|
||||
//!
|
||||
//! Update for parameter `θ` at step `t` (1-indexed), with gradient `g`:
|
||||
//! ```text
|
||||
//! m ← β1·m + (1−β1)·g
|
||||
//! v ← β2·v + (1−β2)·g²
|
||||
//! m̂ ← m / (1 − β1ᵗ) (bias correction)
|
||||
//! v̂ ← v / (1 − β2ᵗ)
|
||||
//! θ ← θ − lr·( m̂ / (√v̂ + ε) + wd·θ )
|
||||
//! ```
|
||||
//! The `lr·wd·θ` term is the decoupled decay. Note PyTorch applies decay as
|
||||
//! `θ ← θ·(1 − lr·wd)` then the Adam step; both are algebraically the same
|
||||
//! first-order update — we fold decay into the single subtraction above, which
|
||||
//! is what PyTorch's default (`maximize=False`, no `amsgrad`) computes.
|
||||
//!
|
||||
//! The math operates on flat host `f32` buffers ([`AdamW::step_host`]) so it is
|
||||
//! unit-testable on a GPU-less host; [`AdamW::step`] is a thin wrapper that
|
||||
//! round-trips each parameter's value/grad through the GPU tensor and is gated
|
||||
//! behind `not(no_cuda)`.
|
||||
|
||||
/// Per-parameter optimizer state: the first (`m`) and second (`v`) moment
|
||||
/// estimates, one f32 per element, kept flat (matching the parameter layout).
|
||||
struct ParamState {
|
||||
m: Vec<f32>,
|
||||
v: Vec<f32>,
|
||||
}
|
||||
|
||||
/// Decoupled-weight-decay Adam. One instance owns the moment state for a fixed
|
||||
/// list of parameters, keyed by their index in the slice passed to `step`
|
||||
/// (the model's stable `params()` order).
|
||||
pub struct AdamW {
|
||||
pub lr: f32,
|
||||
beta1: f32,
|
||||
beta2: f32,
|
||||
eps: f32,
|
||||
weight_decay: f32,
|
||||
/// Global step count (shared across all params for bias correction).
|
||||
t: u64,
|
||||
/// Lazily sized to the parameter list on the first `step`.
|
||||
state: Vec<ParamState>,
|
||||
}
|
||||
|
||||
impl AdamW {
|
||||
/// PyTorch-default hyperparameters except `lr`/`weight_decay`, which you set
|
||||
/// (β1=0.9, β2=0.999, ε=1e-8).
|
||||
pub fn new(lr: f32, weight_decay: f32) -> Self {
|
||||
Self::with_betas(lr, weight_decay, 0.9, 0.999, 1e-8)
|
||||
}
|
||||
|
||||
pub fn with_betas(lr: f32, weight_decay: f32, beta1: f32, beta2: f32, eps: f32) -> Self {
|
||||
Self {
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
eps,
|
||||
weight_decay,
|
||||
t: 0,
|
||||
state: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Current global step (number of `step` calls so far).
|
||||
pub fn step_count(&self) -> u64 {
|
||||
self.t
|
||||
}
|
||||
|
||||
/// Pure-host AdamW step over flat parameter/gradient buffers. `params[i]` is
|
||||
/// updated in place using `grads[i]`; both are the i-th parameter's elements
|
||||
/// in the model's stable order. Lazily allocates moment state on first call.
|
||||
///
|
||||
/// This is the testable core — no GPU, no autograd. `lr` is passed per call
|
||||
/// so a schedule can vary it each step.
|
||||
pub fn step_host(&mut self, lr: f32, params: &mut [Vec<f32>], grads: &[Vec<f32>]) {
|
||||
assert_eq!(params.len(), grads.len(), "param/grad count mismatch");
|
||||
if self.state.is_empty() {
|
||||
self.state = params
|
||||
.iter()
|
||||
.map(|p| ParamState {
|
||||
m: vec![0.0; p.len()],
|
||||
v: vec![0.0; p.len()],
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
assert_eq!(self.state.len(), params.len(), "param count changed");
|
||||
|
||||
self.t += 1;
|
||||
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
|
||||
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
|
||||
|
||||
for (i, (p, g)) in params.iter_mut().zip(grads).enumerate() {
|
||||
assert_eq!(p.len(), g.len(), "param/grad len mismatch at {i}");
|
||||
let st = &mut self.state[i];
|
||||
for j in 0..p.len() {
|
||||
let gj = g[j];
|
||||
st.m[j] = self.beta1 * st.m[j] + (1.0 - self.beta1) * gj;
|
||||
st.v[j] = self.beta2 * st.v[j] + (1.0 - self.beta2) * gj * gj;
|
||||
let mhat = st.m[j] / bc1;
|
||||
let vhat = st.v[j] / bc2;
|
||||
// Decoupled weight decay: decay term uses the *current* param,
|
||||
// matching PyTorch's `p ← p − lr·wd·p` applied alongside the step.
|
||||
p[j] -= lr * (mhat / (vhat.sqrt() + self.eps) + self.weight_decay * p[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
mod gpu {
|
||||
use super::AdamW;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
impl AdamW {
|
||||
/// Apply one AdamW step to every parameter `Var`, using `lr` for this step
|
||||
/// (so an LR schedule can vary it). Pulls each param's value and `.grad()`
|
||||
/// to the host, runs [`AdamW::step_host`], and writes the updated value
|
||||
/// back with `set_value`. A param with no grad is fed a zero grad, so the
|
||||
/// Adam term vanishes and only decoupled weight decay applies (the model's
|
||||
/// params all receive grads each step, so this is just a safety default).
|
||||
///
|
||||
/// Does NOT zero grads — the caller does that (matching the GD-step
|
||||
/// template in the T5 overfit test).
|
||||
pub fn step(&mut self, lr: f32, params: &[Var]) {
|
||||
let device = params[0].value().device();
|
||||
let shapes: Vec<Vec<usize>> =
|
||||
params.iter().map(|p| p.value().shape().to_vec()).collect();
|
||||
|
||||
let mut host_params: Vec<Vec<f32>> = params
|
||||
.iter()
|
||||
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
|
||||
.collect();
|
||||
let host_grads: Vec<Vec<f32>> = params
|
||||
.iter()
|
||||
.zip(&host_params)
|
||||
.map(|(p, hp)| match p.grad() {
|
||||
Some(g) => g.to_device(Device::Cpu).as_slice::<f32>().to_vec(),
|
||||
None => vec![0.0; hp.len()], // no grad → no update this step
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.step_host(lr, &mut host_params, &host_grads);
|
||||
|
||||
for ((p, data), shape) in params.iter().zip(&host_params).zip(&shapes) {
|
||||
let t = Tensor::from_slice(data, shape).to_device(device);
|
||||
p.set_value(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
99
crates/xtrain-optim/tests/adamw_host.rs
Normal file
99
crates/xtrain-optim/tests/adamw_host.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
// Host-only unit test for the AdamW *math* (no GPU). Verifies the update against
|
||||
// an independent, hand-rolled reference implementation of the same recurrence for
|
||||
// several steps with non-trivial weight decay — catching bias-correction and
|
||||
// decoupled-decay mistakes. The rigorous vs-PyTorch parity (end-to-end on a real
|
||||
// model) lives in xtrain-train; this is the fast local guard on the formula.
|
||||
|
||||
use xtrain_optim::AdamW;
|
||||
|
||||
// Independent reference: the textbook AdamW recurrence, kept separate from the
|
||||
// implementation so a shared bug can't hide.
|
||||
struct RefAdamW {
|
||||
b1: f32,
|
||||
b2: f32,
|
||||
eps: f32,
|
||||
wd: f32,
|
||||
t: i32,
|
||||
m: Vec<f32>,
|
||||
v: Vec<f32>,
|
||||
}
|
||||
|
||||
impl RefAdamW {
|
||||
fn new(n: usize, wd: f32) -> Self {
|
||||
Self {
|
||||
b1: 0.9,
|
||||
b2: 0.999,
|
||||
eps: 1e-8,
|
||||
wd,
|
||||
t: 0,
|
||||
m: vec![0.0; n],
|
||||
v: vec![0.0; n],
|
||||
}
|
||||
}
|
||||
fn step(&mut self, lr: f32, p: &mut [f32], g: &[f32]) {
|
||||
self.t += 1;
|
||||
let bc1 = 1.0 - self.b1.powi(self.t);
|
||||
let bc2 = 1.0 - self.b2.powi(self.t);
|
||||
for i in 0..p.len() {
|
||||
self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * g[i];
|
||||
self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * g[i] * g[i];
|
||||
let mhat = self.m[i] / bc1;
|
||||
let vhat = self.v[i] / bc2;
|
||||
p[i] -= lr * (mhat / (vhat.sqrt() + self.eps) + self.wd * p[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adamw_matches_reference_recurrence() {
|
||||
let lr = 0.01;
|
||||
let wd = 0.1;
|
||||
let mut opt = AdamW::new(lr, wd);
|
||||
|
||||
// Two parameters of different sizes (exercises per-param state keying).
|
||||
let mut p_impl = vec![vec![0.5f32, -1.0, 2.0, 0.0], vec![1.5f32, -0.25]];
|
||||
let mut p_ref = p_impl.clone();
|
||||
let mut r0 = RefAdamW::new(4, wd);
|
||||
let mut r1 = RefAdamW::new(2, wd);
|
||||
|
||||
// Deterministic pseudo-grads that change every step.
|
||||
let grad = |step: usize, idx: usize, j: usize| -> f32 {
|
||||
let s = (step * 13 + idx * 7 + j * 3) as f32;
|
||||
(s * 0.123).sin() * 0.5
|
||||
};
|
||||
|
||||
for step in 0..20 {
|
||||
let grads = vec![
|
||||
(0..4).map(|j| grad(step, 0, j)).collect::<Vec<_>>(),
|
||||
(0..2).map(|j| grad(step, 1, j)).collect::<Vec<_>>(),
|
||||
];
|
||||
opt.step_host(lr, &mut p_impl, &grads);
|
||||
r0.step(lr, &mut p_ref[0], &grads[0]);
|
||||
r1.step(lr, &mut p_ref[1], &grads[1]);
|
||||
}
|
||||
|
||||
assert_eq!(opt.step_count(), 20);
|
||||
for (pi, pr) in p_impl.iter().zip(&p_ref) {
|
||||
for (a, b) in pi.iter().zip(pr) {
|
||||
assert!((a - b).abs() < 1e-6, "impl {a} != ref {b}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_grad_only_decays() {
|
||||
// With g=0 and wd>0, the step must reduce to pure decoupled decay:
|
||||
// θ ← θ − lr·wd·θ (Adam term is 0/eps = 0).
|
||||
let lr = 0.1;
|
||||
let wd = 0.5;
|
||||
let mut opt = AdamW::new(lr, wd);
|
||||
let mut p = vec![vec![2.0f32]];
|
||||
let g = vec![vec![0.0f32]];
|
||||
opt.step_host(lr, &mut p, &g);
|
||||
let expected = 2.0 - lr * wd * 2.0;
|
||||
assert!(
|
||||
(p[0][0] - expected).abs() < 1e-6,
|
||||
"{} != {expected}",
|
||||
p[0][0]
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user