New xtrain-optim crate. AdamW with per-param m/v moments keyed by params() index, global bias correction, and decoupled weight decay (matches torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers, unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr argument leaves room for an LR schedule. Host unit test checks the update against an independent reference recurrence over 20 steps and the pure-decay (g=0) boundary. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
100 lines
3.0 KiB
Rust
100 lines
3.0 KiB
Rust
// Host-only unit test for the AdamW *math* (no GPU). Verifies the update against
// an independent, hand-rolled reference implementation of the same recurrence
// over several steps with non-trivial weight decay, catching bias-correction and
// decoupled-decay mistakes. The rigorous vs-PyTorch parity check (end-to-end on a
// real model) lives in xtrain-train; this is the fast local guard on the formula.
|
||
|
||
use xtrain_optim::AdamW;
|
||
|
||
/// Independent reference: the textbook AdamW recurrence, kept separate from the
/// implementation so a shared bug can't hide.
///
/// One instance holds the optimizer state for a single flat parameter buffer.
struct RefAdamW {
    /// Decay rate of the first-moment (mean of gradients) running average.
    b1: f32,
    /// Decay rate of the second-moment (mean of squared gradients) running average.
    b2: f32,
    /// Added to `sqrt(vhat)` in the denominator to avoid dividing by zero.
    eps: f32,
    /// Weight-decay coefficient; multiplies the parameter value inside the update.
    wd: f32,
    /// Number of steps taken so far; the exponent used for bias correction.
    t: i32,
    /// First-moment estimate, one entry per parameter element.
    m: Vec<f32>,
    /// Second-moment estimate, one entry per parameter element.
    v: Vec<f32>,
}

impl RefAdamW {
    /// Creates zero-initialised state for a parameter buffer of `n` elements
    /// with weight decay `wd`, using `b1 = 0.9`, `b2 = 0.999`, `eps = 1e-8`.
    fn new(n: usize, wd: f32) -> Self {
        Self {
            b1: 0.9,
            b2: 0.999,
            eps: 1e-8,
            wd,
            t: 0,
            m: vec![0.0; n],
            v: vec![0.0; n],
        }
    }

    /// Applies one AdamW update to `p` in place, given gradients `g` and
    /// learning rate `lr`.
    ///
    /// Per element: update the moment estimates, bias-correct them with the
    /// current step count, then subtract `lr * (mhat / (sqrt(vhat) + eps) + wd * p)`.
    ///
    /// # Panics
    ///
    /// Panics if `p`, `g`, and the stored moment buffers do not all have the
    /// same length. A shape mismatch is a bug in the calling test, so it is
    /// reported loudly instead of silently updating only a prefix.
    fn step(&mut self, lr: f32, p: &mut [f32], g: &[f32]) {
        assert_eq!(
            p.len(),
            g.len(),
            "parameter and gradient lengths differ"
        );
        assert_eq!(
            p.len(),
            self.m.len(),
            "parameter length differs from optimizer state length"
        );

        self.t += 1;
        // Bias-correction denominators: 1 - beta^t.
        let bc1 = 1.0 - self.b1.powi(self.t);
        let bc2 = 1.0 - self.b2.powi(self.t);
        let (b1, b2, eps, wd) = (self.b1, self.b2, self.eps, self.wd);

        for (((pi, &gi), mi), vi) in p
            .iter_mut()
            .zip(g)
            .zip(&mut self.m)
            .zip(&mut self.v)
        {
            *mi = b1 * *mi + (1.0 - b1) * gi;
            *vi = b2 * *vi + (1.0 - b2) * gi * gi;
            let mhat = *mi / bc1;
            let vhat = *vi / bc2;
            // Decay uses the pre-update parameter value (decoupled from the Adam term).
            *pi -= lr * (mhat / (vhat.sqrt() + eps) + wd * *pi);
        }
    }
}
|
||
|
||
// Runs the implementation and the independent reference side by side for 20
// steps on two differently sized parameters, feeding both the same gradients,
// and requires every element to agree to within 1e-6.
#[test]
fn adamw_matches_reference_recurrence() {
    // Deterministic pseudo-gradient that varies with step, parameter index,
    // and element index.
    fn pseudo_grad(step: usize, idx: usize, j: usize) -> f32 {
        let s = (step * 13 + idx * 7 + j * 3) as f32;
        (s * 0.123).sin() * 0.5
    }

    let lr = 0.01;
    let wd = 0.1;

    // Two parameters of different sizes (exercises per-param state keying).
    let mut p_impl = vec![vec![0.5f32, -1.0, 2.0, 0.0], vec![1.5f32, -0.25]];
    let mut p_ref = p_impl.clone();

    let mut opt = AdamW::new(lr, wd);
    // One reference state per parameter, sized to match.
    let mut refs = [RefAdamW::new(4, wd), RefAdamW::new(2, wd)];

    for step in 0..20 {
        let grads: Vec<Vec<f32>> = [4usize, 2]
            .iter()
            .enumerate()
            .map(|(idx, &len)| (0..len).map(|j| pseudo_grad(step, idx, j)).collect())
            .collect();

        opt.step_host(lr, &mut p_impl, &grads);
        for ((reference, params), g) in refs.iter_mut().zip(p_ref.iter_mut()).zip(&grads) {
            reference.step(lr, params, g);
        }
    }

    assert_eq!(opt.step_count(), 20);
    for (pi, pr) in p_impl.iter().zip(p_ref.iter()) {
        pi.iter().zip(pr.iter()).for_each(|(a, b)| {
            assert!((a - b).abs() < 1e-6, "impl {a} != ref {b}");
        });
    }
}
|
||
|
||
// Boundary case: a zero gradient with positive weight decay.
#[test]
fn zero_grad_only_decays() {
    // With g=0 and wd>0 the Adam term vanishes (0/eps = 0), so a single step
    // must be exactly the decoupled decay: θ ← θ − lr·wd·θ.
    let start: f32 = 2.0;
    let lr = 0.1;
    let wd = 0.5;

    let mut p = vec![vec![start]];
    let g = vec![vec![0.0f32]];

    let mut opt = AdamW::new(lr, wd);
    opt.step_host(lr, &mut p, &g);

    let expected = start - lr * wd * start;
    let err = (p[0][0] - expected).abs();
    assert!(err < 1e-6, "{} != {expected}", p[0][0]);
}
|