Files
xtrain/crates/xtrain-optim/tests/adamw_host.rs
Gahow Wang f22429f5b8 optim: hand-written AdamW (decoupled weight decay + bias correction)
New xtrain-optim crate. AdamW with per-param m/v moments keyed by params()
index, global bias correction, and decoupled weight decay (matches
torch.optim.AdamW). Split into a pure-host step_host (flat f32 buffers,
unit-testable on a GPU-less host) and a step(&[Var]) wrapper that round-trips
each param value/grad through the GPU tensor (gated not(no_cuda)). Per-step lr
argument leaves room for an LR schedule.

Host unit test checks the update against an independent reference recurrence
over 20 steps and the pure-decay (g=0) boundary.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:28:23 +08:00

100 lines
3.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Host-only unit test for the AdamW *math* (no GPU). Verifies the update against
// an independent, hand-rolled reference implementation of the same recurrence for
// several steps with non-trivial weight decay — catching bias-correction and
// decoupled-decay mistakes. The rigorous vs-PyTorch parity (end-to-end on a real
// model) lives in xtrain-train; this is the fast local guard on the formula.
use xtrain_optim::AdamW;
/// Independent reference optimiser: the textbook AdamW recurrence, written
/// apart from `xtrain_optim::AdamW` so that a bug shared by both cannot hide
/// behind an agreeing comparison.
///
/// β1 = 0.9, β2 = 0.999 and ε = 1e-8 are fixed; only the decoupled
/// weight-decay coefficient is chosen by the caller.
struct RefAdamW {
    /// First-moment decay rate β1.
    b1: f32,
    /// Second-moment decay rate β2.
    b2: f32,
    /// Added to √v̂ in the update denominator.
    eps: f32,
    /// Decoupled weight-decay coefficient (scaled by `lr` in `step`).
    wd: f32,
    /// Number of `step` calls so far; drives the bias correction.
    t: i32,
    /// Per-element exponential moving average of the gradient.
    m: Vec<f32>,
    /// Per-element exponential moving average of the squared gradient.
    v: Vec<f32>,
}

impl RefAdamW {
    /// Builds reference state for one parameter of `n` elements, with both
    /// moment buffers zeroed and the step counter at 0.
    fn new(n: usize, wd: f32) -> Self {
        let zeros = || vec![0.0_f32; n];
        RefAdamW {
            b1: 0.9,
            b2: 0.999,
            eps: 1e-8,
            wd,
            t: 0,
            m: zeros(),
            v: zeros(),
        }
    }

    /// Advances the step counter and applies one AdamW update to `p` in place.
    ///
    /// For every element, with `t` the new step count:
    /// m ← β1·m + (1−β1)·g, v ← β2·v + (1−β2)·g²,
    /// m̂ = m / (1−β1^t), v̂ = v / (1−β2^t),
    /// p ← p − lr·(m̂ / (√v̂ + ε) + wd·p), where the decay term uses the
    /// pre-update p.
    ///
    /// Panics if `p` has more elements than the state built by `new`, or if
    /// `g` has fewer elements than `p`.
    fn step(&mut self, lr: f32, p: &mut [f32], g: &[f32]) {
        self.t += 1;
        let (beta1, beta2) = (self.b1, self.b2);
        let correction1 = 1.0 - beta1.powi(self.t);
        let correction2 = 1.0 - beta2.powi(self.t);
        for (k, theta) in p.iter_mut().enumerate() {
            let grad = g[k];
            let first = &mut self.m[k];
            *first = beta1 * *first + (1.0 - beta1) * grad;
            let second = &mut self.v[k];
            *second = beta2 * *second + (1.0 - beta2) * grad * grad;
            let m_hat = *first / correction1;
            let v_hat = *second / correction2;
            *theta -= lr * (m_hat / (v_hat.sqrt() + self.eps) + self.wd * *theta);
        }
    }
}
/// Runs `AdamW::step_host` and the independent `RefAdamW` side by side for 20
/// steps with non-trivial weight decay, then checks every parameter element
/// agrees. Catches bias-correction and decoupled-decay mistakes.
#[test]
fn adamw_matches_reference_recurrence() {
    let lr = 0.01;
    let wd = 0.1;
    let mut opt = AdamW::new(lr, wd);
    // Two parameters of different sizes (exercises per-param state keying).
    let mut p_impl = vec![vec![0.5f32, -1.0, 2.0, 0.0], vec![1.5f32, -0.25]];
    let mut p_ref = p_impl.clone();
    // One reference state per parameter, sized from the parameter itself so
    // the shapes cannot drift from `p_impl`.
    let mut refs: Vec<RefAdamW> = p_ref.iter().map(|p| RefAdamW::new(p.len(), wd)).collect();
    // Deterministic pseudo-grads that change every step.
    let grad = |step: usize, idx: usize, j: usize| -> f32 {
        let s = (step * 13 + idx * 7 + j * 3) as f32;
        (s * 0.123).sin() * 0.5
    };
    for step in 0..20 {
        let grads: Vec<Vec<f32>> = p_ref
            .iter()
            .enumerate()
            .map(|(idx, p)| (0..p.len()).map(|j| grad(step, idx, j)).collect())
            .collect();
        opt.step_host(lr, &mut p_impl, &grads);
        for ((r, p), g) in refs.iter_mut().zip(p_ref.iter_mut()).zip(&grads) {
            r.step(lr, p, g);
        }
    }
    assert_eq!(opt.step_count(), 20);
    // `zip` alone would silently truncate on a shape change, letting a broken
    // `step_host` pass vacuously; check the shapes explicitly first.
    assert_eq!(p_impl.len(), p_ref.len(), "parameter count changed");
    for (k, (pi, pr)) in p_impl.iter().zip(&p_ref).enumerate() {
        assert_eq!(pi.len(), pr.len(), "param {k}: length changed");
        for (j, (a, b)) in pi.iter().zip(pr).enumerate() {
            assert!((a - b).abs() < 1e-6, "param {k}[{j}]: impl {a} != ref {b}");
        }
    }
}
/// With g = 0 the Adam moments stay zero, so one step must reduce to pure
/// decoupled decay: θ ← θ − lr·wd·θ (the Adam term is 0/(0 + ε) = 0).
#[test]
fn zero_grad_only_decays() {
    let lr = 0.1;
    let wd = 0.5;
    let mut opt = AdamW::new(lr, wd);
    // Positive, negative and zero entries: decay must scale each element
    // independently, preserve its sign, and leave 0 fixed.
    let init = [2.0f32, -3.0, 0.0];
    let mut p = vec![init.to_vec()];
    let g = vec![vec![0.0f32; init.len()]];
    opt.step_host(lr, &mut p, &g);
    assert_eq!(opt.step_count(), 1);
    assert_eq!(p[0].len(), init.len(), "parameter length changed");
    for (j, (&got, &theta)) in p[0].iter().zip(&init).enumerate() {
        let expected = theta - lr * wd * theta;
        assert!(
            (got - expected).abs() < 1e-6,
            "[{j}] {got} != {expected}"
        );
    }
}
}