optim: hand-written AdamW (decoupled weight decay + bias correction)

New xtrain-optim crate. AdamW with per-param m/v moments keyed by params()
index, global bias correction, and decoupled weight decay (matches
torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers,
unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips
each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr
argument leaves room for an LR schedule.

Host unit test checks the update against an independent reference recurrence
over 20 steps and the pure-decay (g=0) boundary.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:28:23 +08:00
parent 8565565647
commit f22429f5b8
6 changed files with 301 additions and 0 deletions

9
Cargo.lock generated
View File

@@ -112,6 +112,15 @@ dependencies = [
"xtrain-tensor", "xtrain-tensor",
] ]
[[package]]
name = "xtrain-optim"
version = "0.1.0"
dependencies = [
"xtrain-autodiff",
"xtrain-cuda",
"xtrain-tensor",
]
[[package]] [[package]]
name = "xtrain-tensor" name = "xtrain-tensor"
version = "0.1.0" version = "0.1.0"

View File

@@ -5,6 +5,7 @@ members = [
"crates/xtrain-tensor", "crates/xtrain-tensor",
"crates/xtrain-autodiff", "crates/xtrain-autodiff",
"crates/xtrain-model", "crates/xtrain-model",
"crates/xtrain-optim",
] ]
[workspace.package] [workspace.package]

View File

@@ -0,0 +1,12 @@
[package]
name = "xtrain-optim"
version.workspace = true
edition.workspace = true
[dependencies]
xtrain-tensor = { path = "../xtrain-tensor" }
xtrain-autodiff = { path = "../xtrain-autodiff" }
[dev-dependencies]
# Acceptance tests drive the GPU (device selection) directly.
xtrain-cuda = { path = "../xtrain-cuda" }

View File

@@ -0,0 +1,26 @@
use std::env;
use std::path::Path;
use std::process::Command;
// Per-crate convention (see the other crates): the AdamW *math* is host-only and
// always compiles, but `AdamW::step(&[Var])` round-trips parameter values/grads
// through GPU tensors, so that call site is gated behind `not(no_cuda)`. cfg does
// not propagate across crates, so this crate re-detects nvcc. No CUDA is compiled.
fn main() {
println!("cargo:rustc-check-cfg=cfg(no_cuda)");
let cuda_path = env::var("CUDA_HOME")
.or_else(|_| env::var("CUDA_PATH"))
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
if !nvcc_available(&cuda_path) {
println!("cargo:rustc-cfg=no_cuda");
}
}
fn nvcc_available(cuda_path: &str) -> bool {
if Command::new("nvcc").arg("--version").output().is_ok() {
return true;
}
Path::new(&format!("{cuda_path}/bin/nvcc")).exists()
}

View File

@@ -0,0 +1,154 @@
//! Hand-written AdamW optimizer (Phase T6).
//!
//! AdamW = Adam with **decoupled** weight decay (Loshchilov & Hutter, 2019): the
//! weight-decay term is applied directly to the parameter, NOT folded into the
//! gradient (so it does not interact with the adaptive `v` denominator). This
//! matches `torch.optim.AdamW`.
//!
//! Update for parameter `θ` at step `t` (1-indexed), with gradient `g`:
//! ```text
//! m ← β1·m + (1β1)·g
//! v ← β2·v + (1β2)·g²
//! m̂ ← m / (1 β1ᵗ) (bias correction)
//! v̂ ← v / (1 β2ᵗ)
//! θ ← θ lr·( m̂ / (√v̂ + ε) + wd·θ )
//! ```
//! The `lr·wd·θ` term is the decoupled decay. Note PyTorch applies decay as
//! `θ ← θ·(1 lr·wd)` then the Adam step; both are algebraically the same
//! first-order update — we fold decay into the single subtraction above, which
//! is what PyTorch's default (`maximize=False`, no `amsgrad`) computes.
//!
//! The math operates on flat host `f32` buffers ([`AdamW::step_host`]) so it is
//! unit-testable on a GPU-less host; [`AdamW::step`] is a thin wrapper that
//! round-trips each parameter's value/grad through the GPU tensor and is gated
//! behind `not(no_cuda)`.
/// Per-parameter optimizer state: the first (`m`) and second (`v`) moment
/// estimates, one f32 per element, kept flat (matching the parameter layout).
struct ParamState {
m: Vec<f32>,
v: Vec<f32>,
}
/// Decoupled-weight-decay Adam. One instance owns the moment state for a fixed
/// list of parameters, keyed by their index in the slice passed to `step`
/// (the model's stable `params()` order).
pub struct AdamW {
pub lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
/// Global step count (shared across all params for bias correction).
t: u64,
/// Lazily sized to the parameter list on the first `step`.
state: Vec<ParamState>,
}
impl AdamW {
/// PyTorch-default hyperparameters except `lr`/`weight_decay`, which you set
/// (β1=0.9, β2=0.999, ε=1e-8).
pub fn new(lr: f32, weight_decay: f32) -> Self {
Self::with_betas(lr, weight_decay, 0.9, 0.999, 1e-8)
}
pub fn with_betas(lr: f32, weight_decay: f32, beta1: f32, beta2: f32, eps: f32) -> Self {
Self {
lr,
beta1,
beta2,
eps,
weight_decay,
t: 0,
state: Vec::new(),
}
}
/// Current global step (number of `step` calls so far).
pub fn step_count(&self) -> u64 {
self.t
}
/// Pure-host AdamW step over flat parameter/gradient buffers. `params[i]` is
/// updated in place using `grads[i]`; both are the i-th parameter's elements
/// in the model's stable order. Lazily allocates moment state on first call.
///
/// This is the testable core — no GPU, no autograd. `lr` is passed per call
/// so a schedule can vary it each step.
pub fn step_host(&mut self, lr: f32, params: &mut [Vec<f32>], grads: &[Vec<f32>]) {
assert_eq!(params.len(), grads.len(), "param/grad count mismatch");
if self.state.is_empty() {
self.state = params
.iter()
.map(|p| ParamState {
m: vec![0.0; p.len()],
v: vec![0.0; p.len()],
})
.collect();
}
assert_eq!(self.state.len(), params.len(), "param count changed");
self.t += 1;
let bc1 = 1.0 - self.beta1.powi(self.t as i32);
let bc2 = 1.0 - self.beta2.powi(self.t as i32);
for (i, (p, g)) in params.iter_mut().zip(grads).enumerate() {
assert_eq!(p.len(), g.len(), "param/grad len mismatch at {i}");
let st = &mut self.state[i];
for j in 0..p.len() {
let gj = g[j];
st.m[j] = self.beta1 * st.m[j] + (1.0 - self.beta1) * gj;
st.v[j] = self.beta2 * st.v[j] + (1.0 - self.beta2) * gj * gj;
let mhat = st.m[j] / bc1;
let vhat = st.v[j] / bc2;
// Decoupled weight decay: decay term uses the *current* param,
// matching PyTorch's `p ← p lr·wd·p` applied alongside the step.
p[j] -= lr * (mhat / (vhat.sqrt() + self.eps) + self.weight_decay * p[j]);
}
}
}
}
#[cfg(not(no_cuda))]
mod gpu {
use super::AdamW;
use xtrain_autodiff::tape::Var;
use xtrain_tensor::{Device, Tensor};
impl AdamW {
/// Apply one AdamW step to every parameter `Var`, using `lr` for this step
/// (so an LR schedule can vary it). Pulls each param's value and `.grad()`
/// to the host, runs [`AdamW::step_host`], and writes the updated value
/// back with `set_value`. A param with no grad is fed a zero grad, so the
/// Adam term vanishes and only decoupled weight decay applies (the model's
/// params all receive grads each step, so this is just a safety default).
///
/// Does NOT zero grads — the caller does that (matching the GD-step
/// template in the T5 overfit test).
pub fn step(&mut self, lr: f32, params: &[Var]) {
let device = params[0].value().device();
let shapes: Vec<Vec<usize>> =
params.iter().map(|p| p.value().shape().to_vec()).collect();
let mut host_params: Vec<Vec<f32>> = params
.iter()
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
.collect();
let host_grads: Vec<Vec<f32>> = params
.iter()
.zip(&host_params)
.map(|(p, hp)| match p.grad() {
Some(g) => g.to_device(Device::Cpu).as_slice::<f32>().to_vec(),
None => vec![0.0; hp.len()], // no grad → no update this step
})
.collect();
self.step_host(lr, &mut host_params, &host_grads);
for ((p, data), shape) in params.iter().zip(&host_params).zip(&shapes) {
let t = Tensor::from_slice(data, shape).to_device(device);
p.set_value(t);
}
}
}
}

View File

@@ -0,0 +1,99 @@
// Host-only unit test for the AdamW *math* (no GPU). Verifies the update against
// an independent, hand-rolled reference implementation of the same recurrence for
// several steps with non-trivial weight decay — catching bias-correction and
// decoupled-decay mistakes. The rigorous vs-PyTorch parity (end-to-end on a real
// model) lives in xtrain-train; this is the fast local guard on the formula.
use xtrain_optim::AdamW;
// Independent reference: the textbook AdamW recurrence, kept separate from the
// implementation so a shared bug can't hide.
struct RefAdamW {
b1: f32,
b2: f32,
eps: f32,
wd: f32,
t: i32,
m: Vec<f32>,
v: Vec<f32>,
}
impl RefAdamW {
fn new(n: usize, wd: f32) -> Self {
Self {
b1: 0.9,
b2: 0.999,
eps: 1e-8,
wd,
t: 0,
m: vec![0.0; n],
v: vec![0.0; n],
}
}
fn step(&mut self, lr: f32, p: &mut [f32], g: &[f32]) {
self.t += 1;
let bc1 = 1.0 - self.b1.powi(self.t);
let bc2 = 1.0 - self.b2.powi(self.t);
for i in 0..p.len() {
self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * g[i];
self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * g[i] * g[i];
let mhat = self.m[i] / bc1;
let vhat = self.v[i] / bc2;
p[i] -= lr * (mhat / (vhat.sqrt() + self.eps) + self.wd * p[i]);
}
}
}
#[test]
fn adamw_matches_reference_recurrence() {
let lr = 0.01;
let wd = 0.1;
let mut opt = AdamW::new(lr, wd);
// Two parameters of different sizes (exercises per-param state keying).
let mut p_impl = vec![vec![0.5f32, -1.0, 2.0, 0.0], vec![1.5f32, -0.25]];
let mut p_ref = p_impl.clone();
let mut r0 = RefAdamW::new(4, wd);
let mut r1 = RefAdamW::new(2, wd);
// Deterministic pseudo-grads that change every step.
let grad = |step: usize, idx: usize, j: usize| -> f32 {
let s = (step * 13 + idx * 7 + j * 3) as f32;
(s * 0.123).sin() * 0.5
};
for step in 0..20 {
let grads = vec![
(0..4).map(|j| grad(step, 0, j)).collect::<Vec<_>>(),
(0..2).map(|j| grad(step, 1, j)).collect::<Vec<_>>(),
];
opt.step_host(lr, &mut p_impl, &grads);
r0.step(lr, &mut p_ref[0], &grads[0]);
r1.step(lr, &mut p_ref[1], &grads[1]);
}
assert_eq!(opt.step_count(), 20);
for (pi, pr) in p_impl.iter().zip(&p_ref) {
for (a, b) in pi.iter().zip(pr) {
assert!((a - b).abs() < 1e-6, "impl {a} != ref {b}");
}
}
}
#[test]
fn zero_grad_only_decays() {
// With g=0 and wd>0, the step must reduce to pure decoupled decay:
// θ ← θ lr·wd·θ (Adam term is 0/eps = 0).
let lr = 0.1;
let wd = 0.5;
let mut opt = AdamW::new(lr, wd);
let mut p = vec![vec![2.0f32]];
let g = vec![vec![0.0f32]];
opt.step_host(lr, &mut p, &g);
let expected = 2.0 - lr * wd * 2.0;
assert!(
(p[0][0] - expected).abs() < 1e-6,
"{} != {expected}",
p[0][0]
);
}