optim: hand-written AdamW (decoupled weight decay + bias correction)
New xtrain-optim crate. AdamW with per-param m/v moments keyed by params() index, global bias correction, and decoupled weight decay (matches torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers, unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr argument leaves room for an LR schedule. Host unit test checks the update against an independent reference recurrence over 20 steps and the pure-decay (g=0) boundary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
9
Cargo.lock
generated
9
Cargo.lock
generated
@@ -112,6 +112,15 @@ dependencies = [
|
|||||||
"xtrain-tensor",
|
"xtrain-tensor",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "xtrain-optim"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"xtrain-autodiff",
|
||||||
|
"xtrain-cuda",
|
||||||
|
"xtrain-tensor",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xtrain-tensor"
|
name = "xtrain-tensor"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ members = [
|
|||||||
"crates/xtrain-tensor",
|
"crates/xtrain-tensor",
|
||||||
"crates/xtrain-autodiff",
|
"crates/xtrain-autodiff",
|
||||||
"crates/xtrain-model",
|
"crates/xtrain-model",
|
||||||
|
"crates/xtrain-optim",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
|
|||||||
12
crates/xtrain-optim/Cargo.toml
Normal file
12
crates/xtrain-optim/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "xtrain-optim"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
xtrain-tensor = { path = "../xtrain-tensor" }
|
||||||
|
xtrain-autodiff = { path = "../xtrain-autodiff" }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
# Acceptance tests drive the GPU (device selection) directly.
|
||||||
|
xtrain-cuda = { path = "../xtrain-cuda" }
|
||||||
26
crates/xtrain-optim/build.rs
Normal file
26
crates/xtrain-optim/build.rs
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
use std::env;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
// Per-crate convention (see the other crates): the AdamW *math* is host-only and
|
||||||
|
// always compiles, but `AdamW::step(&[Var])` round-trips parameter values/grads
|
||||||
|
// through GPU tensors, so that call site is gated behind `not(no_cuda)`. cfg does
|
||||||
|
// not propagate across crates, so this crate re-detects nvcc. No CUDA is compiled.
|
||||||
|
fn main() {
|
||||||
|
println!("cargo:rustc-check-cfg=cfg(no_cuda)");
|
||||||
|
|
||||||
|
let cuda_path = env::var("CUDA_HOME")
|
||||||
|
.or_else(|_| env::var("CUDA_PATH"))
|
||||||
|
.unwrap_or_else(|_| "/usr/local/cuda".to_string());
|
||||||
|
|
||||||
|
if !nvcc_available(&cuda_path) {
|
||||||
|
println!("cargo:rustc-cfg=no_cuda");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reports whether an `nvcc` compiler can be located.
///
/// Two probes are tried in order: launching `nvcc --version` (which succeeds
/// whenever the process can be spawned, regardless of its exit status), and, if
/// that fails, checking that `{cuda_path}/bin/nvcc` exists on disk.
fn nvcc_available(cuda_path: &str) -> bool {
    let spawnable = Command::new("nvcc").arg("--version").output().is_ok();
    // `||` short-circuits, so the filesystem probe only runs when spawning failed.
    spawnable || Path::new(&format!("{cuda_path}/bin/nvcc")).exists()
}
|
||||||
154
crates/xtrain-optim/src/lib.rs
Normal file
154
crates/xtrain-optim/src/lib.rs
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
//! Hand-written AdamW optimizer (Phase T6).
|
||||||
|
//!
|
||||||
|
//! AdamW = Adam with **decoupled** weight decay (Loshchilov & Hutter, 2019): the
|
||||||
|
//! weight-decay term is applied directly to the parameter, NOT folded into the
|
||||||
|
//! gradient (so it does not interact with the adaptive `v` denominator). This
|
||||||
|
//! matches `torch.optim.AdamW`.
|
||||||
|
//!
|
||||||
|
//! Update for parameter `θ` at step `t` (1-indexed), with gradient `g`:
|
||||||
|
//! ```text
|
||||||
|
//! m ← β1·m + (1−β1)·g
|
||||||
|
//! v ← β2·v + (1−β2)·g²
|
||||||
|
//! m̂ ← m / (1 − β1ᵗ) (bias correction)
|
||||||
|
//! v̂ ← v / (1 − β2ᵗ)
|
||||||
|
//! θ ← θ − lr·( m̂ / (√v̂ + ε) + wd·θ )
|
||||||
|
//! ```
|
||||||
|
//! The `lr·wd·θ` term is the decoupled decay. PyTorch applies it as a separate
//! multiplication, `θ ← θ·(1 − lr·wd)`, followed by the Adam step; expanding that
//! product gives exactly the single subtraction above, so the two forms are
//! algebraically identical and differ only in floating-point rounding. This is
//! the update PyTorch's default path (`maximize=False`, no `amsgrad`) computes.
|
||||||
|
//!
|
||||||
|
//! The math operates on flat host `f32` buffers ([`AdamW::step_host`]) so it is
|
||||||
|
//! unit-testable on a GPU-less host; [`AdamW::step`] is a thin wrapper that
|
||||||
|
//! round-trips each parameter's value/grad through the GPU tensor and is gated
|
||||||
|
//! behind `not(no_cuda)`.
|
||||||
|
|
||||||
|
/// Per-parameter optimizer state: the first (`m`) and second (`v`) moment
/// estimates, one f32 per element, kept flat (matching the parameter layout).
#[derive(Debug, Clone)]
struct ParamState {
    m: Vec<f32>,
    v: Vec<f32>,
}

impl ParamState {
    /// Zero-initialised moments for a parameter with `len` elements.
    fn zeros(len: usize) -> Self {
        Self {
            m: vec![0.0; len],
            v: vec![0.0; len],
        }
    }
}

/// Decoupled-weight-decay Adam. One instance owns the moment state for a fixed
/// list of parameters, keyed by their index in the slice passed to `step`
/// (the model's stable `params()` order).
#[derive(Debug, Clone)]
pub struct AdamW {
    /// Learning rate recorded at construction. [`AdamW::step_host`] applies the
    /// rate passed as its `lr` argument and does not read this field.
    pub lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    /// Global step count (shared across all params for bias correction).
    t: u64,
    /// Lazily sized to the parameter list on the first `step`.
    state: Vec<ParamState>,
}

impl AdamW {
    /// PyTorch-default hyperparameters except `lr`/`weight_decay`, which you set
    /// (β1=0.9, β2=0.999, ε=1e-8).
    pub fn new(lr: f32, weight_decay: f32) -> Self {
        Self::with_betas(lr, weight_decay, 0.9, 0.999, 1e-8)
    }

    /// Fully explicit constructor. The step counter starts at 0 and the moment
    /// state is left empty until the first step sizes it.
    pub fn with_betas(lr: f32, weight_decay: f32, beta1: f32, beta2: f32, eps: f32) -> Self {
        Self {
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            t: 0,
            state: Vec::new(),
        }
    }

    /// Current global step (number of `step` calls so far).
    pub fn step_count(&self) -> u64 {
        self.t
    }

    /// Pure-host AdamW step over flat parameter/gradient buffers. `params[i]` is
    /// updated in place using `grads[i]`; both are the i-th parameter's elements
    /// in the model's stable order. Lazily allocates moment state on first call.
    ///
    /// This is the testable core — no GPU, no autograd. `lr` is passed per call
    /// so a schedule can vary it each step.
    ///
    /// # Panics
    ///
    /// Panics if `params` and `grads` differ in count, if the parameter count
    /// differs from the one the moment state was sized for, or if any parameter's
    /// length differs from its gradient's or from its stored moments'. All of
    /// these are checked before the step counter, the moments, or any parameter
    /// is modified, so a rejected call leaves no partial update behind.
    pub fn step_host(&mut self, lr: f32, params: &mut [Vec<f32>], grads: &[Vec<f32>]) {
        assert_eq!(params.len(), grads.len(), "param/grad count mismatch");
        if self.state.is_empty() {
            self.state = params.iter().map(|p| ParamState::zeros(p.len())).collect();
        }
        assert_eq!(self.state.len(), params.len(), "param count changed");

        // Validate every shape up front. A resized parameter would otherwise be
        // updated against moments that belong to a different layout.
        for (i, ((p, g), st)) in params.iter().zip(grads).zip(&self.state).enumerate() {
            assert_eq!(p.len(), g.len(), "param/grad len mismatch at {i}");
            assert_eq!(st.m.len(), p.len(), "param len changed at {i}");
        }

        self.t += 1;
        // Clamp before narrowing: a plain `as i32` would wrap to a negative
        // exponent past i32::MAX steps and turn the bias corrections into -inf.
        // Once clamped the cast is lossless, and β^t has long since reached 0.
        let exponent = self.t.min(i32::MAX as u64) as i32;
        let bc1 = 1.0 - self.beta1.powi(exponent);
        let bc2 = 1.0 - self.beta2.powi(exponent);

        let (beta1, beta2, eps, weight_decay) =
            (self.beta1, self.beta2, self.eps, self.weight_decay);

        for ((p, g), st) in params.iter_mut().zip(grads).zip(&mut self.state) {
            let moments = st.m.iter_mut().zip(st.v.iter_mut());
            for ((pj, &gj), (mj, vj)) in p.iter_mut().zip(g).zip(moments) {
                *mj = beta1 * *mj + (1.0 - beta1) * gj;
                *vj = beta2 * *vj + (1.0 - beta2) * gj * gj;
                let mhat = *mj / bc1;
                let vhat = *vj / bc2;
                // Decoupled weight decay: the decay term uses the *current*
                // param value and is applied in the same subtraction as the
                // Adam step, never folded into the gradient.
                *pj -= lr * (mhat / (vhat.sqrt() + eps) + weight_decay * *pj);
            }
        }
    }
}
|
||||||
|
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
mod gpu {
|
||||||
|
use super::AdamW;
|
||||||
|
use xtrain_autodiff::tape::Var;
|
||||||
|
use xtrain_tensor::{Device, Tensor};
|
||||||
|
|
||||||
|
impl AdamW {
|
||||||
|
/// Apply one AdamW step to every parameter `Var`, using `lr` for this step
|
||||||
|
/// (so an LR schedule can vary it). Pulls each param's value and `.grad()`
|
||||||
|
/// to the host, runs [`AdamW::step_host`], and writes the updated value
|
||||||
|
/// back with `set_value`. A param with no grad is fed a zero grad, so the
|
||||||
|
/// Adam term vanishes and only decoupled weight decay applies (the model's
|
||||||
|
/// params all receive grads each step, so this is just a safety default).
|
||||||
|
///
|
||||||
|
/// Does NOT zero grads — the caller does that (matching the GD-step
|
||||||
|
/// template in the T5 overfit test).
|
||||||
|
pub fn step(&mut self, lr: f32, params: &[Var]) {
|
||||||
|
let device = params[0].value().device();
|
||||||
|
let shapes: Vec<Vec<usize>> =
|
||||||
|
params.iter().map(|p| p.value().shape().to_vec()).collect();
|
||||||
|
|
||||||
|
let mut host_params: Vec<Vec<f32>> = params
|
||||||
|
.iter()
|
||||||
|
.map(|p| p.value().to_device(Device::Cpu).as_slice::<f32>().to_vec())
|
||||||
|
.collect();
|
||||||
|
let host_grads: Vec<Vec<f32>> = params
|
||||||
|
.iter()
|
||||||
|
.zip(&host_params)
|
||||||
|
.map(|(p, hp)| match p.grad() {
|
||||||
|
Some(g) => g.to_device(Device::Cpu).as_slice::<f32>().to_vec(),
|
||||||
|
None => vec![0.0; hp.len()], // no grad → no update this step
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
self.step_host(lr, &mut host_params, &host_grads);
|
||||||
|
|
||||||
|
for ((p, data), shape) in params.iter().zip(&host_params).zip(&shapes) {
|
||||||
|
let t = Tensor::from_slice(data, shape).to_device(device);
|
||||||
|
p.set_value(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
99
crates/xtrain-optim/tests/adamw_host.rs
Normal file
99
crates/xtrain-optim/tests/adamw_host.rs
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
// Host-only unit test for the AdamW *math* (no GPU). Verifies the update against
|
||||||
|
// an independent, hand-rolled reference implementation of the same recurrence for
|
||||||
|
// several steps with non-trivial weight decay — catching bias-correction and
|
||||||
|
// decoupled-decay mistakes. The rigorous vs-PyTorch parity (end-to-end on a real
|
||||||
|
// model) lives in xtrain-train; this is the fast local guard on the formula.
|
||||||
|
|
||||||
|
use xtrain_optim::AdamW;
|
||||||
|
|
||||||
|
// Reference optimizer for a single flat parameter: the textbook AdamW
// recurrence written out on its own, so the test compares the crate against
// code that shares nothing with it.
struct RefAdamW {
    b1: f32,
    b2: f32,
    eps: f32,
    wd: f32,
    t: i32,
    m: Vec<f32>,
    v: Vec<f32>,
}

impl RefAdamW {
    // Default betas/eps (0.9, 0.999, 1e-8), zeroed moments for `n` elements.
    fn new(n: usize, wd: f32) -> Self {
        let zeros = vec![0.0; n];
        Self {
            b1: 0.9,
            b2: 0.999,
            eps: 1e-8,
            wd,
            t: 0,
            m: zeros.clone(),
            v: zeros,
        }
    }

    // One in-place update of `p` from gradient `g` at learning rate `lr`.
    fn step(&mut self, lr: f32, p: &mut [f32], g: &[f32]) {
        self.t += 1;
        let bc1 = 1.0 - self.b1.powi(self.t);
        let bc2 = 1.0 - self.b2.powi(self.t);
        for (i, theta) in p.iter_mut().enumerate() {
            let grad = g[i];
            let m = self.b1 * self.m[i] + (1.0 - self.b1) * grad;
            let v = self.b2 * self.v[i] + (1.0 - self.b2) * grad * grad;
            self.m[i] = m;
            self.v[i] = v;
            let mhat = m / bc1;
            let vhat = v / bc2;
            *theta -= lr * (mhat / (vhat.sqrt() + self.eps) + self.wd * *theta);
        }
    }
}
|
||||||
|
|
||||||
|
#[test]
fn adamw_matches_reference_recurrence() {
    let (lr, wd) = (0.01, 0.1);
    let mut opt = AdamW::new(lr, wd);

    // Two parameters with different element counts, so the per-parameter state
    // lookup is exercised; each gets its own reference optimizer.
    let mut p_impl = vec![vec![0.5f32, -1.0, 2.0, 0.0], vec![1.5f32, -0.25]];
    let mut p_ref = p_impl.clone();
    let mut ref_first = RefAdamW::new(4, wd);
    let mut ref_second = RefAdamW::new(2, wd);

    // Deterministic gradient that differs per step, parameter and element.
    fn pseudo_grad(step: usize, idx: usize, j: usize) -> f32 {
        let phase = (step * 13 + idx * 7 + j * 3) as f32;
        (phase * 0.123).sin() * 0.5
    }

    for step in 0..20 {
        let grads: Vec<Vec<f32>> = [4usize, 2]
            .iter()
            .enumerate()
            .map(|(idx, &len)| (0..len).map(|j| pseudo_grad(step, idx, j)).collect())
            .collect();
        opt.step_host(lr, &mut p_impl, &grads);
        ref_first.step(lr, &mut p_ref[0], &grads[0]);
        ref_second.step(lr, &mut p_ref[1], &grads[1]);
    }

    assert_eq!(opt.step_count(), 20);
    // Both sides have identical shapes, so flattening pairs the same elements.
    for (a, b) in p_impl.iter().flatten().zip(p_ref.iter().flatten()) {
        assert!((a - b).abs() < 1e-6, "impl {a} != ref {b}");
    }
}
|
||||||
|
|
||||||
|
#[test]
fn zero_grad_only_decays() {
    // A zero gradient on fresh state leaves both moments at zero, so the Adam
    // term is 0/eps = 0 and the step is pure decoupled decay:
    //   θ ← θ − lr·wd·θ
    let (lr, wd) = (0.1, 0.5);
    let mut opt = AdamW::new(lr, wd);
    let mut params = vec![vec![2.0f32]];
    let zero_grads = vec![vec![0.0f32]];

    opt.step_host(lr, &mut params, &zero_grads);

    let got = params[0][0];
    let expected = 2.0 - lr * wd * 2.0;
    assert!((got - expected).abs() < 1e-6, "{} != {expected}", got);
}
|
||||||
Reference in New Issue
Block a user