test: AdamW PyTorch parity + checkpoint round-trip + real training

Acceptance tests (GPU-gated not(no_cuda), run on dash5):
- adamw_parity_dump.rs + adamw_parity.py: build the tiny model with fixed init,
  run N AdamW steps on a fixed batch, dump the loss trajectory + final params;
  the Python side rebuilds the identical model and runs torch.optim.AdamW with
  matched lr/wd/betas/eps, comparing trajectory + final params within rtol.
- checkpoint_roundtrip.rs: train a few steps, save, load into a fresh model with
  a DIFFERENT init, assert identical logits/loss on a fixed input.
- real_training.rs (#[ignore], --release): train on TinyStories for a bounded
  budget; assert loss drops substantially and print greedy samples.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:30:06 +08:00
parent 77a82bfeee
commit 22b7434b23
4 changed files with 579 additions and 0 deletions

View File

@@ -0,0 +1,180 @@
#!/usr/bin/env python3
"""AdamW-vs-PyTorch parity (Phase T6).
Loads the model dumped by tests/adamw_parity_dump.rs (config, ids, initial
params, the loss trajectory, and final params), rebuilds the IDENTICAL tiny
transformer in PyTorch from the same initial weights, and runs the SAME number
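of `torch.optim.AdamW` steps with matched hyperparameters on the same fixed
batch: lr and weight_decay are read from config.txt, while betas=(0.9, 0.999)
and eps=1e-8 are hard-coded below and must equal the Rust AdamW's values. It
then compares: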
* the per-step loss trajectory (Rust AdamW vs torch AdamW), and
* the final parameters,
within a relative tolerance. A correct hand-written AdamW (bias correction +
decoupled weight decay) tracks torch's optimizer step-for-step.
Usage: python3 adamw_parity.py /tmp/xtrain_adamw
"""
import sys
import os
import math
import torch
# Directory holding the dump files; first CLI argument, else the default path.
DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/xtrain_adamw"
def read_vec(name):
    """Read a dumped vector/tensor file from DIR as a float64 tensor.

    File format: an optional ``# shape d0,d1,...`` header line followed by one
    float per line. Blank lines are ignored.

    Args:
        name: file name relative to DIR.

    Returns:
        A float64 tensor, reshaped to the header shape when the header lists at
        least one dimension; otherwise a flat 1-D tensor.
    """
    header = "# shape"
    path = os.path.join(DIR, name)
    shape = None
    vals = []
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(header):
                # The dims may be absent ("# shape " for a 0-dim shape); slicing
                # off the prefix avoids indexing a token that is not there.
                dims = line[len(header):].strip()
                shape = [int(d) for d in dims.split(",") if d.strip()]
            else:
                vals.append(float(line))
    t = torch.tensor(vals, dtype=torch.float64)
    if shape:
        t = t.reshape(shape)
    return t
def read_cfg():
    """Parse DIR/config.txt into a dict of string keys to string values.

    Each non-blank line must be exactly ``key value`` (whitespace separated);
    values are left as strings for the caller to convert.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    cfg = {}
    with open(os.path.join(DIR, "config.txt")) as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Tolerate blank / whitespace-only lines (e.g. a trailing newline).
                continue
            k, v = parts
            cfg[k] = v
    return cfg
def read_ids(name):
    """Return the whitespace-separated integers stored in DIR/<name> as a list."""
    path = os.path.join(DIR, name)
    with open(path) as fh:
        tokens = fh.read().split()
    return list(map(int, tokens))
# Model dimensions and hyperparameters recorded in config.txt by the Rust dump.
cfg = read_cfg()
DIM, NL, NH, HD = (int(cfg[key]) for key in ("dim", "n_layers", "n_heads", "head_dim"))
EPS, THETA, LR, WD = (float(cfg[key]) for key in ("eps", "rope_theta", "lr", "wd"))
N_STEPS = int(cfg["n_steps"])

# The fixed (input, target) batch; the sequence length is the number of input ids.
ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
SEQ = len(ids)

# Parameter names: embed, then nine tensors per layer, then final_norm and
# lm_head. These names select the w0_*/wN_* dump files.
NAMES = ["embed"]
for l in range(NL):
    NAMES.extend(
        f"l{l}_{p}"
        for p in ("attn_norm", "wq", "wk", "wv", "wo",
                  "ffn_norm", "w_gate", "w_up", "w_down")
    )
NAMES.extend(["final_norm", "lm_head"])

# Load the IDENTICAL initial weights as leaf params (float64 reference).
P = dict(
    (n, read_vec(f"w0_{n}.txt").clone().requires_grad_(True))
    for n in NAMES
)
def rms_norm(x, gamma):
    """RMS-normalize x over its last dimension and scale by gamma.

    Computes x * rsqrt(mean(x^2, dim=-1) + EPS) * gamma, where EPS is the
    module-level epsilon read from config.txt.
    """
    mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + EPS)
    return x * inv_rms * gamma
def rope(x):  # x: [seq, nh, hd], position = token index
    """Apply rotary position embedding to x using the half-split pairing.

    Element j of the first half of the last dimension is rotated together with
    element j of the second half, by angle pos * THETA ** (-2j / HD). Uses the
    module-level HD, SEQ and THETA.
    """
    half = HD // 2
    pair = torch.arange(half, dtype=torch.float64)
    inv_freq = THETA ** (-(2.0 * pair) / HD)
    positions = torch.arange(SEQ, dtype=torch.float64).reshape(SEQ, 1)
    angles = positions * inv_freq
    cos = torch.cos(angles).reshape(SEQ, 1, half)
    sin = torch.sin(angles).reshape(SEQ, 1, half)
    lo, hi = x[..., :half], x[..., half:]
    rotated = torch.cat((lo * cos - hi * sin, hi * cos + lo * sin), dim=-1)
    # Keep the result in x's dtype, as writing into empty_like(x) would.
    return rotated.to(x.dtype)
# Input / target token ids as index tensors for embedding lookup and the loss.
idx = torch.tensor(data=ids, dtype=torch.long)
tgt = torch.tensor(data=targets, dtype=torch.long)
# Additive causal mask: -1e9 strictly above the diagonal, 0 elsewhere.
mask = torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64).triu(diagonal=1)
def forward():
    """Run the reference transformer on the fixed batch and return the logits.

    Reads the module-level parameter dict P, token indices idx and causal mask.
    Each layer is pre-norm attention (RoPE on q/k, masked softmax) with a
    residual add, followed by a pre-norm SiLU-gated feed-forward with a
    residual add. The output is rms_norm(h) @ lm_head.
    """
    def weight(layer, key):
        return P[f"l{layer}_{key}"]

    def split_heads(t):
        return t.reshape(SEQ, NH, HD)

    scale = 1.0 / math.sqrt(HD)
    h = P["embed"][idx]
    for l in range(NL):
        # Attention sub-block.
        x = rms_norm(h, weight(l, "attn_norm"))
        q = rope(split_heads(x @ weight(l, "wq"))).transpose(0, 1)
        k = rope(split_heads(x @ weight(l, "wk"))).transpose(0, 1)
        v = split_heads(x @ weight(l, "wv")).transpose(0, 1)
        scores = (q @ k.transpose(-1, -2)) * scale + mask
        probs = torch.softmax(scores, dim=-1)
        attn = (probs @ v).transpose(0, 1).reshape(SEQ, DIM)
        h = h + attn @ weight(l, "wo")
        # Gated feed-forward sub-block.
        x = rms_norm(h, weight(l, "ffn_norm"))
        gate = torch.nn.functional.silu(x @ weight(l, "w_gate"))
        act = gate * (x @ weight(l, "w_up"))
        h = h + act @ weight(l, "w_down")
    h = rms_norm(h, P["final_norm"])
    return h @ P["lm_head"]
# Match the Rust optimizer: torch.optim.AdamW with the same lr/wd/betas/eps.
opt = torch.optim.AdamW(
    list(P.values()),
    lr=LR,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=WD,
)
# Record each loss BEFORE the update it drives, so entry i is the loss seen at
# the start of step i.
torch_losses = []
for _ in range(N_STEPS):
    opt.zero_grad()
    step_loss = torch.nn.functional.cross_entropy(forward(), tgt, reduction="mean")
    torch_losses.append(step_loss.item())
    step_loss.backward()
    opt.step()
def relerr(a, b):
    """Return the largest element-wise relative error of a against reference b.

    Both tensors are compared in float64; the denominator |b| is clamped to at
    least 1e-6 so near-zero reference entries do not blow up the ratio.
    """
    got = a.to(torch.float64)
    ref = b.to(torch.float64)
    ratio = torch.abs(got - ref) / torch.clamp(torch.abs(ref), min=1e-6)
    return torch.max(ratio).item()
# Shared relative tolerance for both the loss trajectory and the final params.
RTOL = 2e-2

# --- Loss trajectory: Rust AdamW vs torch AdamW, step by step. ---
rust_losses = read_vec("losses.txt")
print("step rust_loss torch_loss relerr")
worst_loss = 0.0
for i in range(N_STEPS):
    rl = rust_losses[i].item()
    tl = torch_losses[i]
    e = abs(rl - tl) / max(abs(tl), 1e-6)
    if e > worst_loss:
        worst_loss = e
    # Print only the first five steps and the last one.
    show = i < 5 or i == N_STEPS - 1
    if show:
        print(f"{i:4d} {rl:.6e} {tl:.6e} {e:.2e}")
print(f"loss trajectory: worst relerr = {worst_loss:.2e}")

# --- Final parameters: torch result vs the Rust wN_* dump. ---
param_errs = [(n, relerr(P[n].detach(), read_vec(f"wN_{n}.txt"))) for n in NAMES]
worst_p, worst_name = 0.0, ""
for n, e in param_errs:
    if e > worst_p:
        worst_p, worst_name = e, n
fails = [(n, e) for n, e in param_errs if e > RTOL]
print(f"final params: {len(NAMES)} checked, worst = {worst_name} @ {worst_p:.2e} (rtol={RTOL})")

# --- Verdict: non-zero exit status on any tolerance violation. ---
loss_failed = worst_loss > RTOL
if loss_failed or fails:
    print("FAIL:")
    if loss_failed:
        print(f" loss trajectory relerr {worst_loss:.3e} > {RTOL}")
    for n, e in fails:
        print(f" param[{n}]: relerr={e:.3e}")
    sys.exit(1)
print("ADAMW PARITY OK: loss trajectory + final params match torch.optim.AdamW within rtol")

View File

@@ -0,0 +1,165 @@
// AdamW-vs-PyTorch parity, step 1 of 2: build the tiny transformer with a fixed
// deterministic init, then run N steps of the hand-written AdamW on a FIXED
// (input, target) batch — recording the loss at each step and the final
// parameters. tests/adamw_parity.py rebuilds the identical model + torch.optim
// .AdamW with matched hyperparameters and compares the loss trajectory and final
// params within rtol. This is the rigorous correctness check for the optimizer.
//
// Run: XTRAIN_ADAMW_DIR=/tmp/xtrain_adamw cargo test -p xtrain-train \
// --test adamw_parity_dump -- --nocapture --ignored
// then: python3 crates/xtrain-train/tests/adamw_parity.py /tmp/xtrain_adamw
//
// Marked #[ignore] (fixture generator) and gated #![cfg(not(no_cuda))].
#![cfg(not(no_cuda))]
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
use xtrain_optim::AdamW;
use xtrain_tensor::Device;
/// Deterministic pseudo-random fill: returns `n` values within `[-scale, scale]`.
///
/// A 64-bit linear congruential generator is seeded from `seed`; for every
/// element the state is advanced once and its top 31 bits are mapped to the
/// unit interval, then recentred and scaled.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
/// Write `data` to `dir/name` as text: a `# shape d0,d1,...` header line
/// followed by one value per line in `{:.8e}` scientific notation.
///
/// `dir` is taken as `&Path`, so existing `&PathBuf` call sites still coerce.
/// Output goes through a `BufWriter` and is flushed explicitly so a failed
/// write panics here instead of being dropped silently.
///
/// # Panics
/// Panics if the file cannot be created or written.
fn write_vec(dir: &std::path::Path, name: &str, data: &[f32], shape: &[usize]) {
    let file = fs::File::create(dir.join(name)).unwrap();
    let mut f = std::io::BufWriter::new(file);
    let shape_str: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    writeln!(f, "# shape {}", shape_str.join(",")).unwrap();
    for v in data {
        writeln!(f, "{v:.8e}").unwrap();
    }
    f.flush().unwrap();
}
// Constant learning rate: passed to AdamW::new and to every opt.step call, and
// written to config.txt as `lr`.
const LR: f32 = 0.01;
// Weight-decay coefficient: passed to AdamW::new and written to config.txt as `wd`.
const WD: f32 = 0.1;
// Number of optimizer steps to run; also written to config.txt as `n_steps`.
const N_STEPS: usize = 30;
/// Fixture generator for the AdamW-vs-PyTorch parity check.
///
/// Builds the tiny transformer with a deterministic init, then writes into
/// `$XTRAIN_ADAMW_DIR` (default `/tmp/xtrain_adamw`):
/// * `config.txt`, `ids.txt`, `targets.txt` — hyperparameters and the fixed batch,
/// * `w0_<name>.txt` — initial parameters,
/// * `losses.txt` — the loss recorded at the start of each of `N_STEPS` steps,
/// * `wN_<name>.txt` — parameters after the last step.
///
/// # Panics
/// Panics if no CUDA device is present or any file cannot be written.
#[test]
#[ignore = "fixture generator for AdamW PyTorch parity; run with --ignored"]
fn dump_adamw_trajectory() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let dir = PathBuf::from(
        std::env::var("XTRAIN_ADAMW_DIR").unwrap_or_else(|_| "/tmp/xtrain_adamw".to_string()),
    );
    fs::create_dir_all(&dir).unwrap();

    let mut cfg = Config::tiny();
    cfg.vocab = 12;
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];

    // Deterministic init: the seed advances once per tensor. 1-D tensors get
    // small noise around 1.0, all others small zero-centred noise.
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() == 1 {
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.08)
        }
    });

    // Dump config + ids + initial params (named for adamw_parity.py).
    {
        let mut f = fs::File::create(dir.join("config.txt")).unwrap();
        writeln!(f, "vocab {}", cfg.vocab).unwrap();
        writeln!(f, "dim {}", cfg.dim).unwrap();
        writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
        writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
        writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
        writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
        writeln!(f, "eps {:e}", cfg.eps).unwrap();
        writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
        writeln!(f, "lr {LR:e}").unwrap();
        writeln!(f, "wd {WD:e}").unwrap();
        writeln!(f, "n_steps {N_STEPS}").unwrap();
        let mut g = fs::File::create(dir.join("ids.txt")).unwrap();
        for v in &ids {
            writeln!(g, "{v}").unwrap();
        }
        let mut g = fs::File::create(dir.join("targets.txt")).unwrap();
        for v in &targets {
            writeln!(g, "{v}").unwrap();
        }
    }
    let names = param_names(&cfg);
    let params = model.params();
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("w0_{name}.txt"), &param_to_host(p), &shape);
    }

    // Train N steps of AdamW with a CONSTANT lr (no schedule) on the fixed batch.
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);
    let mut opt = AdamW::new(LR, WD);
    let mut losses = Vec::with_capacity(N_STEPS);
    for _ in 0..N_STEPS {
        // The loss is recorded before the update it drives.
        let loss = model.loss(&ids_t, &targets_t);
        losses.push(param_to_host(&loss)[0]);
        loss.backward();
        opt.step(LR, &params);
        for p in &params {
            p.zero_grad();
        }
    }
    {
        let mut f = fs::File::create(dir.join("losses.txt")).unwrap();
        for l in &losses {
            writeln!(f, "{l:.8e}").unwrap();
        }
    }
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("wN_{name}.txt"), &param_to_host(p), &shape);
    }

    // First and last loss are separated by an arrow; printed back-to-back the
    // two numbers ran together and were unreadable.
    println!(
        "adamw parity: dumped to {} (loss {:.6e} → {:.6e} over {N_STEPS} steps)",
        dir.display(),
        losses.first().unwrap(),
        losses.last().unwrap()
    );
}
/// Names for the model parameters, used to build the `w0_*` / `wN_*` file
/// names: `embed`, then nine `l<layer>_<tensor>` names per layer, then
/// `final_norm` and `lm_head`.
fn param_names(cfg: &Config) -> Vec<String> {
    const PER_LAYER: [&str; 9] = [
        "attn_norm",
        "wq",
        "wk",
        "wv",
        "wo",
        "ffn_norm",
        "w_gate",
        "w_up",
        "w_down",
    ];
    let layer_names =
        (0..cfg.n_layers).flat_map(|l| PER_LAYER.iter().map(move |p| format!("l{l}_{p}")));
    std::iter::once("embed".to_string())
        .chain(layer_names)
        .chain(["final_norm".to_string(), "lm_head".to_string()])
        .collect()
}

View File

@@ -0,0 +1,113 @@
// Checkpoint round-trip acceptance (Phase T6): train a few AdamW steps on a fixed
// batch, save the params, build a FRESH model (different init), load the
// checkpoint into it, and assert it produces identical logits + loss on a fixed
// input. This verifies the on-disk format dumps/reloads `params()` in order with
// exact f32 fidelity. Gated #![cfg(not(no_cuda))] (runs on dash5).
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
use xtrain_optim::AdamW;
use xtrain_tensor::Device;
use xtrain_train::checkpoint;
/// Deterministic pseudo-random fill: returns `n` values within `[-scale, scale]`.
///
/// Seeds a 64-bit linear congruential generator from `seed`, advances it once
/// per element, and maps the top 31 bits of each state to a centred, scaled
/// float.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    std::iter::repeat_with(move || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        (unit - 0.5) * 2.0 * scale
    })
    .take(n)
    .collect()
}
/// Build a tiny transformer with the given vocabulary size and a deterministic
/// init derived from `init_seed`.
///
/// The seed advances once per tensor, so different `init_seed` values give
/// different weights. 1-D tensors get small noise around 1.0; all other
/// tensors get zero-centred noise.
fn make_model(device: Device, vocab: usize, init_seed: u64) -> TinyTransformer {
    let cfg = {
        let mut base = Config::tiny();
        base.vocab = vocab;
        base
    };
    let mut seed = init_seed;
    TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() != 1 {
            fill(n, seed, 0.08)
        } else {
            fill(n, seed, 0.02).into_iter().map(|v| v + 1.0).collect()
        }
    })
}
/// Train briefly, save a checkpoint, load it into a differently-initialised
/// model, and check that logits and loss on the fixed input match the source
/// model to within 1e-5.
#[test]
fn checkpoint_roundtrip_identical_logits() {
    /// Largest element-wise absolute difference between two buffers.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0, f32::max)
    }

    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let vocab = 12;
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);

    // --- Train a few steps so the saved params differ from the init. ---
    let model = make_model(device, vocab, 1);
    let params = model.params();
    let mut opt = AdamW::new(0.01, 0.1);
    for _ in 0..5 {
        let loss = model.loss(&ids_t, &targets_t);
        loss.backward();
        opt.step(0.01, &params);
        for p in &params {
            p.zero_grad();
        }
    }

    // The file name includes the process id so concurrent runs do not collide.
    let path = std::env::temp_dir().join(format!("xtrain_ckpt_{}.bin", std::process::id()));
    checkpoint::save(&path, &params).unwrap();
    let ref_logits = param_to_host(&model.forward(&ids_t));
    let ref_loss = param_to_host(&model.loss(&ids_t, &targets_t))[0];

    // --- Fresh model with a DIFFERENT init; loading must overwrite it exactly. ---
    let fresh = make_model(device, vocab, 999);
    let fresh_params = fresh.params();

    // Sanity: before the load the fresh model must disagree, otherwise the
    // post-load comparison would prove nothing.
    let pre = param_to_host(&fresh.forward(&ids_t));
    let pre_diff = max_abs_diff(&pre, &ref_logits);
    assert!(
        pre_diff > 1e-4,
        "fresh model unexpectedly matched before load"
    );

    checkpoint::load_into(&path, &fresh_params).unwrap();
    let got_logits = param_to_host(&fresh.forward(&ids_t));
    let got_loss = param_to_host(&fresh.loss(&ids_t, &targets_t))[0];
    // Best-effort cleanup; a failed removal must not fail the test.
    let _ = std::fs::remove_file(&path);

    // An exact f32 round-trip should reproduce the forward pass; the threshold
    // only leaves room for float noise from re-running it.
    let max_logit_diff = max_abs_diff(&got_logits, &ref_logits);
    println!(
        "checkpoint round-trip: max logit diff = {max_logit_diff:.3e}, loss {ref_loss:.6} vs {got_loss:.6}"
    );
    assert!(
        max_logit_diff < 1e-5,
        "logits differ after reload: {max_logit_diff:e}"
    );
    assert!(
        (got_loss - ref_loss).abs() < 1e-5,
        "loss differs after reload: {ref_loss} vs {got_loss}"
    );
}

View File

@@ -0,0 +1,121 @@
// Real-training acceptance (Phase T6): train the tiny transformer on the
// TinyStories corpus (tokenized with the reused GPT-2 BPE) for a BOUNDED budget
// and assert the loss decreases substantially — the end-to-end signal that the
// whole stack (data pipeline, AdamW, LR schedule, grad clip) learns. Prints the
// loss curve and a couple of greedy samples.
//
// Needs the corpus + tokenizer present, so it is #[ignore] (run with --ignored)
// and gated #![cfg(not(no_cuda))]. Paths are overridable via env vars.
//
// Run: cargo test -p xtrain-train --release --test real_training \
// -- --ignored --nocapture
#![cfg(not(no_cuda))]
use std::path::PathBuf;
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer};
use xtrain_tensor::Device;
use xtrain_train::data::Corpus;
use xtrain_train::sample::generate;
use xtrain_train::schedule::LrSchedule;
use xtrain_train::{TrainConfig, train};
/// Deterministic pseudo-random fill: returns `n` values within `[-scale, scale]`.
///
/// Walks a 64-bit linear congruential sequence seeded from `seed`; each
/// successive state contributes its top 31 bits, mapped to the unit interval,
/// recentred and scaled.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let advance = |s: u64| {
        s.wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407)
    };
    let start = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    // The first emitted value comes from the state AFTER one advance.
    std::iter::successors(Some(advance(start)), |&s| Some(advance(s)))
        .take(n)
        .map(|s| (((s >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale)
        .collect()
}
/// End-to-end training acceptance: train a 4-layer tiny transformer on the
/// tokenized corpus for a bounded number of steps, print the averaged start and
/// end loss plus two greedy samples, and assert that the loss dropped.
///
/// Environment overrides: `XTRAIN_TOKENIZER`, `XTRAIN_CORPUS`, and
/// `XTRAIN_STEPS` (falls back to 800 when unset or unparsable).
///
/// # Panics
/// Panics if no CUDA device is present, if training returns no losses, or if
/// the loss assertions fail.
#[test]
#[ignore = "real training; needs corpus + tokenizer; run with --ignored --release"]
fn trains_on_tinystories() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let tok_path = PathBuf::from(
        std::env::var("XTRAIN_TOKENIZER")
            .unwrap_or_else(|_| "/opt/wjh/models/gpt2/tokenizer.json".into()),
    );
    let corpus_path = PathBuf::from(
        std::env::var("XTRAIN_CORPUS").unwrap_or_else(|_| "data/tinystories-valid-3mb.txt".into()),
    );
    let corpus = Corpus::load(&tok_path, &corpus_path);
    println!(
        "corpus: {} tokens, vocab {}",
        corpus.len(),
        corpus.vocab_size
    );

    let mut cfg = Config::tiny();
    cfg.vocab = corpus.vocab_size;
    cfg.n_layers = 4;
    // Deterministic init: the seed advances once per tensor.
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() == 1 {
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.04)
        }
    });

    let steps = std::env::var("XTRAIN_STEPS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(800usize);
    let tcfg = TrainConfig {
        seq_len: 64,
        batch_size: 8,
        steps,
        schedule: LrSchedule {
            max_lr: 3e-3,
            min_lr: 3e-4,
            warmup: (steps / 20).max(20),
            total: steps,
        },
        weight_decay: 0.1,
        max_grad_norm: 1.0,
        log_every: 50,
        ckpt_path: None,
        ckpt_every: 0,
        seed: 42,
    };
    let losses = train(&model, device, &corpus, &tcfg);

    // Without this guard an empty run would yield NaN averages and fail later
    // with a misleading "did not decrease" message.
    assert!(
        !losses.is_empty(),
        "training returned no losses; cannot compare start vs end"
    );
    // Average the first/last few steps to smooth per-step noise. Both ends use
    // the same window, which shrinks when the run is shorter than 10 steps.
    let window = losses.len().min(10);
    let mean = |xs: &[f32]| xs.iter().sum::<f32>() / xs.len() as f32;
    let head: f32 = mean(&losses[..window]);
    let tail: f32 = mean(&losses[losses.len() - window..]);
    println!("loss: start(avg10) {head:.4} → end(avg10) {tail:.4}");

    // A couple of greedy samples (should show English structure, not gibberish).
    use xserv_tokenizer::Tokenizer;
    let tok = Tokenizer::from_file(&tok_path);
    for p in ["Once upon a time", "The little"] {
        let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
        let mut rng = 7u64;
        let out = generate(&model, device, &ids, 40, 0.0, &mut rng);
        let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
        println!("sample [{p}] → {text}");
    }

    // Bounded run: expect a substantial drop (not full convergence).
    assert!(
        tail < head - 0.5,
        "loss did not decrease substantially: {head:.4} → {tail:.4}"
    );
    assert!(tail < 6.5, "final loss implausibly high: {tail:.4}");
}