test: AdamW PyTorch parity + checkpoint round-trip + real training
Acceptance tests (GPU-gated not(no_cuda), run on dash5):
- adamw_parity_dump.rs + adamw_parity.py: build the tiny model with fixed init, run N AdamW steps on a fixed batch, dump the loss trajectory + final params; the Python side rebuilds the identical model and runs torch.optim.AdamW with matched lr/wd/betas/eps, comparing trajectory + final params within rtol.
- checkpoint_roundtrip.rs: train a few steps, save, load into a fresh model with a DIFFERENT init, assert identical logits/loss on a fixed input.
- real_training.rs (#[ignore], --release): train on TinyStories for a bounded budget; assert loss drops substantially and print greedy samples.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
180
crates/xtrain-train/tests/adamw_parity.py
Normal file
180
crates/xtrain-train/tests/adamw_parity.py
Normal file
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
"""AdamW-vs-PyTorch parity (Phase T6).
|
||||
|
||||
Loads the model dumped by tests/adamw_parity_dump.rs (config, ids, initial
|
||||
params, the loss trajectory, and final params), rebuilds the IDENTICAL tiny
|
||||
transformer in PyTorch from the same initial weights, and runs the SAME number
|
||||
of `torch.optim.AdamW` steps with matched hyperparameters (lr, weight_decay,
|
||||
betas, eps) on the same fixed batch. It then compares:
|
||||
|
||||
* the per-step loss trajectory (Rust AdamW vs torch AdamW), and
|
||||
* the final parameters,
|
||||
|
||||
within a relative tolerance. A correct hand-written AdamW (bias correction +
|
||||
decoupled weight decay) tracks torch's optimizer step-for-step.
|
||||
|
||||
Usage: python3 adamw_parity.py /tmp/xtrain_adamw
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
|
||||
DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/xtrain_adamw"
|
||||
|
||||
|
||||
def read_vec(name, base_dir=None):
    """Load a text-dumped tensor written by the Rust fixture generator.

    The file holds one float per line, optionally preceded by a
    ``# shape d0,d1,...`` header line.

    Args:
        name: File name, relative to the dump directory.
        base_dir: Directory to read from; defaults to the module-level ``DIR``.

    Returns:
        A float64 tensor. It is reshaped to the header's shape when the header
        lists at least one dimension; otherwise it stays 1-D.
    """
    path = os.path.join(DIR if base_dir is None else base_dir, name)
    shape = None
    vals = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith("# shape"):
                # An empty shape is written as "# shape " (nothing after the
                # keyword), so the dims field can be missing after strip().
                fields = line.split()
                dims = fields[2] if len(fields) > 2 else ""
                shape = [int(x) for x in dims.split(",") if x]
            elif line:
                vals.append(float(line))
    t = torch.tensor(vals, dtype=torch.float64)
    if shape:
        t = t.reshape(shape)
    return t
|
||||
|
||||
|
||||
def read_cfg(base_dir=None):
    """Parse ``config.txt`` into a dict mapping each key to its string value.

    Each non-blank line must hold exactly two whitespace-separated fields,
    ``key value``. Values are returned as strings; callers convert them.

    Args:
        base_dir: Directory containing ``config.txt``; defaults to the
            module-level ``DIR``.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields.
    """
    cfg = {}
    with open(os.path.join(DIR if base_dir is None else base_dir, "config.txt")) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            k, v = fields
            cfg[k] = v
    return cfg
|
||||
|
||||
|
||||
def read_ids(name, base_dir=None):
    """Read a whitespace-separated list of integers from a dump file.

    Args:
        name: File name, relative to the dump directory.
        base_dir: Directory to read from; defaults to the module-level ``DIR``.

    Returns:
        A list of ints in file order (empty if the file has no tokens).
    """
    with open(os.path.join(DIR if base_dir is None else base_dir, name)) as f:
        return [int(x) for x in f.read().split()]
|
||||
|
||||
|
||||
# Model / optimizer hyperparameters, as dumped by the Rust side in config.txt.
cfg = read_cfg()
DIM = int(cfg["dim"])
NL = int(cfg["n_layers"])
NH = int(cfg["n_heads"])
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])
LR = float(cfg["lr"])
WD = float(cfg["wd"])
N_STEPS = int(cfg["n_steps"])

# The fixed (input, target) batch; the sequence length is taken from the ids.
ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
SEQ = len(ids)

# Parameter names, in the same order the Rust dump uses for its w0_*/wN_* files.
NAMES = ["embed"]
for layer in range(NL):
    for p in ["attn_norm", "wq", "wk", "wv", "wo",
              "ffn_norm", "w_gate", "w_up", "w_down"]:
        NAMES.append(f"l{layer}_{p}")
NAMES += ["final_norm", "lm_head"]

# Load the IDENTICAL initial weights as leaf params (float64 reference).
P = {n: read_vec(f"w0_{n}.txt").clone().requires_grad_(True) for n in NAMES}
|
||||
|
||||
|
||||
def rms_norm(x, gamma, eps=None):
    """RMS-normalize ``x`` over its last dimension and scale by ``gamma``.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * gamma``.

    Args:
        x: Input tensor; normalization runs over the last dimension.
        gamma: Scale, broadcast against ``x``.
        eps: Stabilizer added to the mean square; defaults to the module-level
            ``EPS`` read from config.txt.
    """
    if eps is None:
        eps = EPS
    ms = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(ms + eps) * gamma
|
||||
|
||||
|
||||
def rope(x, theta=None):
    """Apply rotary position embedding to ``x`` of shape [seq, nh, hd].

    The position is the token index along dim 0. The head dimension is split
    into two halves (x0 = first half, x1 = second half) and each (x0[i], x1[i])
    pair is rotated by ``pos * theta ** (-2*i / hd)``.

    The sequence length and head dim are taken from ``x`` itself, so the
    function works for any [seq, nh, hd] input; for the script's
    ``reshape(SEQ, NH, HD)`` inputs this equals the module-level SEQ / HD.

    Args:
        x: Tensor of shape [seq, nh, hd].
        theta: RoPE base; defaults to the module-level ``THETA``.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    if theta is None:
        theta = THETA
    seq, hd = x.shape[0], x.shape[-1]
    half = hd // 2
    i = torch.arange(half, dtype=torch.float64)
    freq = theta ** (-(2.0 * i) / hd)
    pos = torch.arange(seq, dtype=torch.float64).reshape(seq, 1)
    ang = pos * freq
    # Reshape to [seq, 1, half] so cos/sin broadcast across the head axis.
    c = torch.cos(ang).reshape(seq, 1, half)
    s = torch.sin(ang).reshape(seq, 1, half)
    x0, x1 = x[..., :half], x[..., half:]
    out = torch.empty_like(x)
    out[..., :half] = x0 * c - x1 * s
    out[..., half:] = x1 * c + x0 * s
    return out
|
||||
|
||||
|
||||
# Fixed batch as index tensors, plus an additive causal mask: entries strictly
# above the diagonal are -1e9 (future positions), everything else is 0.
idx = torch.tensor(ids, dtype=torch.long)
tgt = torch.tensor(targets, dtype=torch.long)
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
|
||||
|
||||
|
||||
def forward():
    """Run the reference transformer on the fixed batch and return its logits.

    Reads the module-level parameters ``P`` and inputs ``idx`` / ``mask``. Each
    layer applies: RMS norm -> q/k/v projections reshaped to [SEQ, NH, HD] ->
    RoPE on q and k -> masked scaled dot-product attention -> output projection
    with residual add, then RMS norm -> SiLU-gated feed-forward with residual
    add. A final RMS norm and the ``lm_head`` projection produce the result.
    """
    # Loop-invariant attention scale.
    scale = 1.0 / math.sqrt(HD)
    h = P["embed"][idx]
    for layer in range(NL):
        # --- attention ---
        x = rms_norm(h, P[f"l{layer}_attn_norm"])
        q = (x @ P[f"l{layer}_wq"]).reshape(SEQ, NH, HD)
        k = (x @ P[f"l{layer}_wk"]).reshape(SEQ, NH, HD)
        v = (x @ P[f"l{layer}_wv"]).reshape(SEQ, NH, HD)
        # Heads-first layout ([NH, SEQ, HD]) for the batched matmuls below.
        q = rope(q).transpose(0, 1)
        k = rope(k).transpose(0, 1)
        v = v.transpose(0, 1)
        scores = (q @ k.transpose(-1, -2)) * scale + mask
        probs = torch.softmax(scores, dim=-1)
        out = (probs @ v).transpose(0, 1).reshape(SEQ, DIM)
        h = h + out @ P[f"l{layer}_wo"]
        # --- gated feed-forward ---
        x = rms_norm(h, P[f"l{layer}_ffn_norm"])
        act = torch.nn.functional.silu(x @ P[f"l{layer}_w_gate"]) * (x @ P[f"l{layer}_w_up"])
        h = h + act @ P[f"l{layer}_w_down"]
    h = rms_norm(h, P["final_norm"])
    return h @ P["lm_head"]
|
||||
|
||||
|
||||
# Match the Rust optimizer: torch.optim.AdamW with the same lr/wd/betas/eps.
opt = torch.optim.AdamW(list(P.values()), lr=LR, betas=(0.9, 0.999),
                        eps=1e-8, weight_decay=WD)

# Record the loss BEFORE each update, then step, for N_STEPS iterations on the
# same fixed batch.
torch_losses = []
for _ in range(N_STEPS):
    opt.zero_grad()
    logits = forward()
    loss = torch.nn.functional.cross_entropy(logits, tgt, reduction="mean")
    torch_losses.append(loss.item())
    loss.backward()
    opt.step()
|
||||
|
||||
|
||||
def relerr(a, b):
    """Return the worst elementwise relative error of ``a`` against ``b``.

    Both tensors are compared in float64. The denominator is ``|b|`` clamped
    to at least 1e-6, so near-zero reference values do not divide by zero.

    Returns:
        A Python float: ``max(|a - b| / max(|b|, 1e-6))``, or 0.0 when both
        tensors are empty. A NaN in either input yields NaN, so callers must
        not rely on ``result > tol`` alone to detect a failure.
    """
    a, b = a.double(), b.double()
    if a.numel() == 0 and b.numel() == 0:
        # .max() raises on an empty tensor; nothing to compare means no error.
        return 0.0
    denom = b.abs().clamp(min=1e-6)
    return ((a - b).abs() / denom).max().item()
|
||||
|
||||
|
||||
# --- Compare the Rust loss trajectory against torch's, step by step. ---
rust_losses = read_vec("losses.txt")
if rust_losses.numel() < N_STEPS:
    print(f"FAIL: losses.txt has {rust_losses.numel()} entries, expected {N_STEPS}")
    sys.exit(1)
print("step rust_loss torch_loss relerr")
worst_loss = 0.0
for i in range(N_STEPS):
    rl, tl = rust_losses[i].item(), torch_losses[i]
    e = abs(rl - tl) / max(abs(tl), 1e-6)
    # max(worst, nan) would silently keep `worst`; make NaN sticky instead so
    # a NaN loss cannot pass the check below.
    if math.isnan(e) or e > worst_loss:
        worst_loss = e
    if i < 5 or i == N_STEPS - 1:
        print(f"{i:4d} {rl:.6e} {tl:.6e} {e:.2e}")
print(f"loss trajectory: worst relerr = {worst_loss:.2e}")

# --- Compare the final parameters. ---
RTOL = 2e-2
worst_p, worst_name = 0.0, ""
fails = []
for n in NAMES:
    ref = read_vec(f"wN_{n}.txt")
    e = relerr(P[n].detach(), ref)
    if math.isnan(e) or e > worst_p:
        worst_p, worst_name = e, n
    # `not (e <= RTOL)` is also true for NaN, unlike `e > RTOL`.
    if not (e <= RTOL):
        fails.append((n, e))
print(f"final params: {len(NAMES)} checked, worst = {worst_name} @ {worst_p:.2e} (rtol={RTOL})")

loss_ok = worst_loss <= RTOL  # False for NaN
if not loss_ok or fails:
    print("FAIL:")
    if not loss_ok:
        print(f" loss trajectory relerr {worst_loss:.3e} > {RTOL}")
    for n, e in fails:
        print(f" param[{n}]: relerr={e:.3e}")
    sys.exit(1)
print("ADAMW PARITY OK: loss trajectory + final params match torch.optim.AdamW within rtol")
|
||||
165
crates/xtrain-train/tests/adamw_parity_dump.rs
Normal file
165
crates/xtrain-train/tests/adamw_parity_dump.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
// AdamW-vs-PyTorch parity, step 1 of 2: build the tiny transformer with a fixed
|
||||
// deterministic init, then run N steps of the hand-written AdamW on a FIXED
|
||||
// (input, target) batch — recording the loss at each step and the final
|
||||
// parameters. tests/adamw_parity.py rebuilds the identical model + torch.optim
|
||||
// .AdamW with matched hyperparameters and compares the loss trajectory and final
|
||||
// params within rtol. This is the rigorous correctness check for the optimizer.
|
||||
//
|
||||
// Run: XTRAIN_ADAMW_DIR=/tmp/xtrain_adamw cargo test -p xtrain-train \
|
||||
// --test adamw_parity_dump -- --nocapture --ignored
|
||||
// then: python3 crates/xtrain-train/tests/adamw_parity.py /tmp/xtrain_adamw
|
||||
//
|
||||
// Marked #[ignore] (fixture generator) and gated #![cfg(not(no_cuda))].
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
|
||||
use xtrain_optim::AdamW;
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
/// Deterministic pseudo-random fill used for reproducible weight init.
///
/// Produces `n` values in `[-scale, scale]` from a 64-bit linear congruential
/// generator whose starting state is derived from `seed`. The same
/// `(n, seed, scale)` always yields the same vector, and a shorter request is a
/// prefix of a longer one with the same seed and scale.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Scramble the seed so that nearby seeds start from distant states.
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 31 bits -> [0, 1], then recentre and scale to [-scale, scale].
            (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
        })
        .collect()
}
|
||||
|
||||
/// Write `data` to `dir/name` as text: a `# shape d0,d1,...` header line
/// followed by one value per line in `{:.8e}` scientific notation.
///
/// The directory must already exist.
///
/// # Panics
/// Panics if the file cannot be created or written.
fn write_vec(dir: &std::path::Path, name: &str, data: &[f32], shape: &[usize]) {
    // Buffer the writes: one unbuffered `writeln!` per value is a syscall each.
    let mut f = std::io::BufWriter::new(fs::File::create(dir.join(name)).unwrap());
    let shape_str: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    writeln!(f, "# shape {}", shape_str.join(",")).unwrap();
    for v in data {
        writeln!(f, "{v:.8e}").unwrap();
    }
    // Drop would swallow a flush error; surface it instead.
    f.flush().unwrap();
}
|
||||
|
||||
/// Constant learning rate used for every AdamW step (also written to config.txt).
const LR: f32 = 0.01;
/// Weight-decay coefficient passed to AdamW (also written to config.txt).
const WD: f32 = 0.1;
/// Number of optimizer steps recorded in the dumped loss trajectory.
const N_STEPS: usize = 30;
|
||||
|
||||
/// Fixture generator: trains the tiny model for `N_STEPS` AdamW steps on one
/// fixed batch and dumps config, ids/targets, initial params (`w0_*`), the
/// per-step losses, and final params (`wN_*`) for `adamw_parity.py`.
#[test]
#[ignore = "fixture generator for AdamW PyTorch parity; run with --ignored"]
fn dump_adamw_trajectory() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let dir = PathBuf::from(
        std::env::var("XTRAIN_ADAMW_DIR").unwrap_or_else(|_| "/tmp/xtrain_adamw".to_string()),
    );
    fs::create_dir_all(&dir).unwrap();

    let mut cfg = Config::tiny();
    cfg.vocab = 12;
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];

    // Deterministic init; the initial weights are dumped below (w0_*.txt) so the
    // torch side starts from exactly the same values.
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() == 1 {
            // 1-D params are initialised near 1.0.
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.08)
        }
    });

    // Dump config + ids + initial params (named for adamw_parity.py).
    {
        let mut f = fs::File::create(dir.join("config.txt")).unwrap();
        writeln!(f, "vocab {}", cfg.vocab).unwrap();
        writeln!(f, "dim {}", cfg.dim).unwrap();
        writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
        writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
        writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
        writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
        writeln!(f, "eps {:e}", cfg.eps).unwrap();
        writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
        writeln!(f, "lr {LR:e}").unwrap();
        writeln!(f, "wd {WD:e}").unwrap();
        writeln!(f, "n_steps {N_STEPS}").unwrap();
        let mut g = fs::File::create(dir.join("ids.txt")).unwrap();
        for v in &ids {
            writeln!(g, "{v}").unwrap();
        }
        let mut g = fs::File::create(dir.join("targets.txt")).unwrap();
        for v in &targets {
            writeln!(g, "{v}").unwrap();
        }
    }

    let names = param_names(&cfg);
    let params = model.params();
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("w0_{name}.txt"), &param_to_host(p), &shape);
    }

    // Train N steps of AdamW with a CONSTANT lr (no schedule) on the fixed batch.
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);
    let mut opt = AdamW::new(LR, WD);
    let mut losses = Vec::with_capacity(N_STEPS);
    for _ in 0..N_STEPS {
        // The recorded loss is the one computed BEFORE this step's update.
        let loss = model.loss(&ids_t, &targets_t);
        losses.push(param_to_host(&loss)[0]);
        loss.backward();
        opt.step(LR, &params);
        for p in &params {
            p.zero_grad();
        }
    }

    {
        let mut f = fs::File::create(dir.join("losses.txt")).unwrap();
        for l in &losses {
            writeln!(f, "{l:.8e}").unwrap();
        }
    }
    for (name, p) in names.iter().zip(&params) {
        let shape = p.value().shape().to_vec();
        write_vec(&dir, &format!("wN_{name}.txt"), &param_to_host(p), &shape);
    }

    println!(
        "adamw parity: dumped to {} (loss {:.6e} → {:.6e} over {N_STEPS} steps)",
        dir.display(),
        losses.first().unwrap(),
        losses.last().unwrap()
    );
}
|
||||
|
||||
fn param_names(cfg: &Config) -> Vec<String> {
|
||||
let mut names = vec!["embed".to_string()];
|
||||
for l in 0..cfg.n_layers {
|
||||
for p in [
|
||||
"attn_norm",
|
||||
"wq",
|
||||
"wk",
|
||||
"wv",
|
||||
"wo",
|
||||
"ffn_norm",
|
||||
"w_gate",
|
||||
"w_up",
|
||||
"w_down",
|
||||
] {
|
||||
names.push(format!("l{l}_{p}"));
|
||||
}
|
||||
}
|
||||
names.push("final_norm".to_string());
|
||||
names.push("lm_head".to_string());
|
||||
names
|
||||
}
|
||||
113
crates/xtrain-train/tests/checkpoint_roundtrip.rs
Normal file
113
crates/xtrain-train/tests/checkpoint_roundtrip.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
// Checkpoint round-trip acceptance (Phase T6): train a few AdamW steps on a fixed
|
||||
// batch, save the params, build a FRESH model (different init), load the
|
||||
// checkpoint into it, and assert it produces identical logits + loss on a fixed
|
||||
// input. This verifies the on-disk format dumps/reloads `params()` in order with
|
||||
// exact f32 fidelity. Gated #![cfg(not(no_cuda))] (runs on dash5).
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
|
||||
use xtrain_optim::AdamW;
|
||||
use xtrain_tensor::Device;
|
||||
use xtrain_train::checkpoint;
|
||||
|
||||
/// Deterministic pseudo-random fill used for reproducible weight init.
///
/// Produces `n` values in `[-scale, scale]` from a 64-bit linear congruential
/// generator whose starting state is derived from `seed`. The same
/// `(n, seed, scale)` always yields the same vector, and a shorter request is a
/// prefix of a longer one with the same seed and scale.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Scramble the seed so that nearby seeds start from distant states.
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 31 bits -> [0, 1], then recentre and scale to [-scale, scale].
            (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
        })
        .collect()
}
|
||||
|
||||
fn make_model(device: Device, vocab: usize, init_seed: u64) -> TinyTransformer {
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = vocab;
|
||||
let mut seed = init_seed;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Trains a few steps, saves the params, loads them into a differently
/// initialised model, and checks that logits and loss match on a fixed input.
#[test]
fn checkpoint_roundtrip_identical_logits() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let vocab = 12;
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);

    // --- Train a few steps so the params are non-trivial (not the init). ---
    let model = make_model(device, vocab, 1);
    let params = model.params();
    let mut opt = AdamW::new(0.01, 0.1);
    for _ in 0..5 {
        let loss = model.loss(&ids_t, &targets_t);
        loss.backward();
        opt.step(0.01, &params);
        for p in &params {
            p.zero_grad();
        }
    }

    // Include the process id so concurrent test runs do not share a file.
    let path = std::env::temp_dir().join(format!("xtrain_ckpt_{}.bin", std::process::id()));
    checkpoint::save(&path, &params).unwrap();

    let ref_logits = param_to_host(&model.forward(&ids_t));
    let ref_loss = param_to_host(&model.loss(&ids_t, &targets_t))[0];

    // --- Fresh model with a DIFFERENT init; loading must overwrite it exactly. ---
    let fresh = make_model(device, vocab, 999);
    let fresh_params = fresh.params();
    // Sanity: before load, the fresh model disagrees.
    let pre = param_to_host(&fresh.forward(&ids_t));
    let pre_diff: f32 = pre
        .iter()
        .zip(&ref_logits)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0, f32::max);
    assert!(
        pre_diff > 1e-4,
        "fresh model unexpectedly matched before load"
    );

    checkpoint::load_into(&path, &fresh_params).unwrap();

    let got_logits = param_to_host(&fresh.forward(&ids_t));
    let got_loss = param_to_host(&fresh.loss(&ids_t, &targets_t))[0];
    // Best-effort cleanup; a leftover temp file is not a test failure.
    let _ = std::fs::remove_file(&path);

    // Exact f32 round-trip → bit-for-bit identical forward (same kernels, same
    // inputs). Allow only float noise from re-running the forward.
    let max_logit_diff: f32 = got_logits
        .iter()
        .zip(&ref_logits)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0, f32::max);
    println!(
        "checkpoint round-trip: max logit diff = {max_logit_diff:.3e}, loss {ref_loss:.6} vs {got_loss:.6}"
    );
    assert!(
        max_logit_diff < 1e-5,
        "logits differ after reload: {max_logit_diff:e}"
    );
    assert!(
        (got_loss - ref_loss).abs() < 1e-5,
        "loss differs after reload: {ref_loss} vs {got_loss}"
    );
}
|
||||
121
crates/xtrain-train/tests/real_training.rs
Normal file
121
crates/xtrain-train/tests/real_training.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
// Real-training acceptance (Phase T6): train the tiny transformer on the
|
||||
// TinyStories corpus (tokenized with the reused GPT-2 BPE) for a BOUNDED budget
|
||||
// and assert the loss decreases substantially — the end-to-end signal that the
|
||||
// whole stack (data pipeline, AdamW, LR schedule, grad clip) learns. Prints the
|
||||
// loss curve and a couple of greedy samples.
|
||||
//
|
||||
// Needs the corpus + tokenizer present, so it is #[ignore] (run with --ignored)
|
||||
// and gated #![cfg(not(no_cuda))]. Paths are overridable via env vars.
|
||||
//
|
||||
// Run: cargo test -p xtrain-train --release --test real_training \
|
||||
// -- --ignored --nocapture
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
use xtrain_tensor::Device;
|
||||
use xtrain_train::data::Corpus;
|
||||
use xtrain_train::sample::generate;
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
use xtrain_train::{TrainConfig, train};
|
||||
|
||||
/// Deterministic pseudo-random fill used for reproducible weight init.
///
/// Produces `n` values in `[-scale, scale]` from a 64-bit linear congruential
/// generator whose starting state is derived from `seed`. The same
/// `(n, seed, scale)` always yields the same vector, and a shorter request is a
/// prefix of a longer one with the same seed and scale.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Scramble the seed so that nearby seeds start from distant states.
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    (0..n)
        .map(|_| {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 31 bits -> [0, 1], then recentre and scale to [-scale, scale].
            (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
        })
        .collect()
}
|
||||
|
||||
/// End-to-end training check: trains a 4-layer tiny transformer on the corpus
/// for a bounded number of steps, prints the averaged start/end loss and two
/// greedy samples, and asserts that the loss dropped substantially.
///
/// Environment overrides: `XTRAIN_TOKENIZER`, `XTRAIN_CORPUS`, `XTRAIN_STEPS`.
#[test]
#[ignore = "real training; needs corpus + tokenizer; run with --ignored --release"]
fn trains_on_tinystories() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let tok_path = PathBuf::from(
        std::env::var("XTRAIN_TOKENIZER")
            .unwrap_or_else(|_| "/opt/wjh/models/gpt2/tokenizer.json".into()),
    );
    let corpus_path = PathBuf::from(
        std::env::var("XTRAIN_CORPUS").unwrap_or_else(|_| "data/tinystories-valid-3mb.txt".into()),
    );

    let corpus = Corpus::load(&tok_path, &corpus_path);
    println!(
        "corpus: {} tokens, vocab {}",
        corpus.len(),
        corpus.vocab_size
    );

    let mut cfg = Config::tiny();
    cfg.vocab = corpus.vocab_size;
    cfg.n_layers = 4;
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let n: usize = shape.iter().product();
        if shape.len() == 1 {
            // 1-D params are initialised near 1.0.
            fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
        } else {
            fill(n, seed, 0.04)
        }
    });

    // An unset or unparsable XTRAIN_STEPS falls back to the default budget.
    let steps = std::env::var("XTRAIN_STEPS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(800usize);
    let tcfg = TrainConfig {
        seq_len: 64,
        batch_size: 8,
        steps,
        schedule: LrSchedule {
            max_lr: 3e-3,
            min_lr: 3e-4,
            warmup: (steps / 20).max(20),
            total: steps,
        },
        weight_decay: 0.1,
        max_grad_norm: 1.0,
        log_every: 50,
        ckpt_path: None,
        ckpt_every: 0,
        seed: 42,
    };

    let losses = train(&model, device, &corpus, &tcfg);
    // Without this, an empty run makes both averages NaN and the failure
    // message below would be misleading.
    assert!(!losses.is_empty(), "training produced no loss values");
    // Average the first/last few steps to smooth per-step noise.
    let window = 10.min(losses.len());
    let head: f32 = losses[..window].iter().sum::<f32>() / window as f32;
    let tail: f32 = losses[losses.len() - window..].iter().sum::<f32>() / window as f32;
    println!("loss: start(avg10) {head:.4} → end(avg10) {tail:.4}");

    // A couple of greedy samples (should show English structure, not gibberish).
    use xserv_tokenizer::Tokenizer;
    let tok = Tokenizer::from_file(&tok_path);
    for p in ["Once upon a time", "The little"] {
        let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
        let mut rng = 7u64;
        let out = generate(&model, device, &ids, 40, 0.0, &mut rng);
        let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
        println!("sample [{p}] → {text}");
    }

    // Bounded run: expect a substantial drop (not full convergence).
    assert!(
        tail < head - 0.5,
        "loss did not decrease substantially: {head:.4} → {tail:.4}"
    );
    assert!(tail < 6.5, "final loss implausibly high: {tail:.4}");
}
|
||||
Reference in New Issue
Block a user