train: loop + checkpoint save/load + sampler + train binary

Training loop (train_loop.rs): sample batch_size sequences, forward loss +
backward (tape SUMs grads), clip_grad_norm with ×1/batch averaging, AdamW step
with scheduled lr, zero_grad; logs loss/lr/gnorm/tok-s and checkpoints
periodically; returns the loss trace.

Checkpoint (checkpoint.rs): flat little-endian dump of params() in order
(magic/version/count + per-param ndim/dims/f32 data); load_into validates and
overwrites a matching model's params via set_value (exact f32 round-trip).

Sampler (sample.rs): autoregressive greedy / temperature generation — re-runs
forward on the growing prefix (model is single-sequence, RoPE pos=row).

bin/train.rs: end-to-end entry — load tokenizer+corpus, train a tiny 4-layer
model for a bounded budget, checkpoint, print samples. no_cuda stub keeps it
buildable on a GPU-less host.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:29:58 +08:00
parent 7d84a64f5c
commit 77a82bfeee
6 changed files with 453 additions and 0 deletions

View File

@@ -14,3 +14,7 @@ xtrain-cuda = { path = "../xtrain-cuda" }
# crate inherits xserv's workspace for its own deps (serde/regex) — Cargo reads
# the target package's workspace, not ours.
xserv-tokenizer = { path = "../../../xserv/crates/xserv-tokenizer" }
[[bin]]
name = "train"
path = "src/bin/train.rs"

View File

@@ -0,0 +1,166 @@
//! End-to-end training entry point (Phase T6): load the GPT-2 BPE + TinyStories
//! corpus, train the tiny transformer with hand-written AdamW for a BOUNDED
//! budget, checkpoint it, and print a few generated samples.
//!
//! Run on dash5 (needs a GPU + the corpus + tokenizer.json):
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
//! cargo run -p xtrain-train --release --bin train -- \
//! /opt/wjh/models/gpt2/tokenizer.json \
//! data/tinystories-valid-3mb.txt
//!
//! Optional 3rd/4th args: number of steps, checkpoint path.
// On a GPU-less host (no_cuda) the whole training body is unavailable; keep a
// stub `main` so the crate still builds for `cargo check`.
#[cfg(no_cuda)]
fn main() {
eprintln!("xtrain train: built without CUDA (no_cuda); run on a GPU host (dash5).");
}
#[cfg(not(no_cuda))]
use std::path::{Path, PathBuf};
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer};
#[cfg(not(no_cuda))]
use xtrain_tensor::Device;
#[cfg(not(no_cuda))]
use xtrain_train::data::Corpus;
#[cfg(not(no_cuda))]
use xtrain_train::sample::generate;
#[cfg(not(no_cuda))]
use xtrain_train::schedule::LrSchedule;
#[cfg(not(no_cuda))]
use xtrain_train::{TrainConfig, train};
// Deterministic LCG fill in [-scale, scale) — same init scheme as the T5 tests.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn main() {
let args: Vec<String> = std::env::args().collect();
let tok_path = args
.get(1)
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
let corpus_path = args
.get(2)
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
let steps: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(2000);
let ckpt: PathBuf = args
.get(4)
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
println!(
"loading tokenizer {} + corpus {}",
tok_path.display(),
corpus_path.display()
);
let corpus = Corpus::load(&tok_path, &corpus_path);
println!(
"corpus: {} tokens, vocab {}",
corpus.len(),
corpus.vocab_size
);
// Tiny model sized to the BPE vocab. A real (but small) config: wider than
// the overfit test so it has capacity to learn English structure.
let mut cfg = Config::tiny();
cfg.vocab = corpus.vocab_size;
cfg.n_layers = 4;
println!(
"model: dim {} layers {} heads {} ffn {}{} params",
cfg.dim,
cfg.n_layers,
cfg.n_heads,
cfg.ffn_hidden,
cfg.num_params()
);
let mut seed = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
// RMSNorm gammas → ~1.
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
// Small fan-in-ish scale; keeps early logits tame.
fill(n, seed, 0.04)
}
});
let seq_len = 64;
let tcfg = TrainConfig {
seq_len,
batch_size: 8,
steps,
schedule: LrSchedule {
max_lr: 3e-3,
min_lr: 3e-4,
warmup: (steps / 20).max(20),
total: steps,
},
weight_decay: 0.1,
max_grad_norm: 1.0,
log_every: 50,
ckpt_path: Some(ckpt.clone()),
ckpt_every: 500,
seed: 42,
};
println!(
"training: {} steps, seq {}, batch {}, lr {:.1e}{:.1e}",
tcfg.steps, tcfg.seq_len, tcfg.batch_size, tcfg.schedule.max_lr, tcfg.schedule.min_lr
);
let losses = train(&model, device, &corpus, &tcfg);
let start = losses.first().copied().unwrap_or(0.0);
let end = losses.last().copied().unwrap_or(0.0);
println!("loss: start {start:.4} → end {end:.4}");
sample_some(&model, device, &tok_path);
}
#[cfg(not(no_cuda))]
fn sample_some(model: &TinyTransformer, device: Device, tok_path: &Path) {
use xserv_tokenizer::Tokenizer;
let tok = Tokenizer::from_file(tok_path);
let prompts = ["Once upon a time", "The little", "One day"];
println!("\n--- samples (greedy) ---");
for p in prompts {
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
let mut rng = 7u64;
let out = generate(model, device, &ids, 40, 0.0, &mut rng);
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
println!("[{p}] → {text}");
}
println!("\n--- samples (temperature 0.8) ---");
for p in prompts {
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
let mut rng = 13u64;
let out = generate(model, device, &ids, 40, 0.8, &mut rng);
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
println!("[{p}] → {text}");
}
}

View File

@@ -0,0 +1,90 @@
//! Checkpoint save/load. Dumps the model's `params()` (in their stable order) to
//! a flat binary file and reloads them into a model with matching architecture.
//!
//! Format (little-endian):
//! ```text
//! magic : u32 = 0x58545254 ("XTRT")
//! version : u32 = 1
//! n_params: u32
//! repeat n_params times:
//! ndim : u32
//! dims : [u32; ndim]
//! data : [f32; prod(dims)]
//! ```
//! Architecture/config is NOT stored here — the caller rebuilds the model from
//! the same `Config` and `load_into`s the params (the round-trip and resume both
//! know their config). Gated behind `not(no_cuda)` (it round-trips GPU tensors).
#![cfg(not(no_cuda))]
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use xtrain_autodiff::tape::Var;
use xtrain_tensor::{Device, Tensor};
const MAGIC: u32 = 0x5854_5254;
const VERSION: u32 = 1;
/// Write every parameter (value) to `path` in `params()` order.
pub fn save(path: &Path, params: &[Var]) -> std::io::Result<()> {
let mut w = BufWriter::new(File::create(path)?);
w.write_all(&MAGIC.to_le_bytes())?;
w.write_all(&VERSION.to_le_bytes())?;
w.write_all(&(params.len() as u32).to_le_bytes())?;
for p in params {
let v = p.value().to_device(Device::Cpu);
let shape = v.shape();
w.write_all(&(shape.len() as u32).to_le_bytes())?;
for &d in shape {
w.write_all(&(d as u32).to_le_bytes())?;
}
for &x in v.as_slice::<f32>() {
w.write_all(&x.to_le_bytes())?;
}
}
w.flush()
}
/// Read a checkpoint and overwrite each parameter's value in `params` (in order).
/// Shapes must match the saved ones. Tensors are placed on each param's device.
pub fn load_into(path: &Path, params: &[Var]) -> std::io::Result<()> {
let mut r = BufReader::new(File::open(path)?);
assert_eq!(read_u32(&mut r)?, MAGIC, "bad checkpoint magic");
assert_eq!(read_u32(&mut r)?, VERSION, "unsupported checkpoint version");
let n = read_u32(&mut r)? as usize;
assert_eq!(n, params.len(), "checkpoint param count != model");
for p in params {
let ndim = read_u32(&mut r)? as usize;
let mut dims = Vec::with_capacity(ndim);
for _ in 0..ndim {
dims.push(read_u32(&mut r)? as usize);
}
let numel: usize = dims.iter().product();
let mut data = vec![0.0f32; numel];
for slot in data.iter_mut() {
*slot = read_f32(&mut r)?;
}
let device = p.value().device();
assert_eq!(
p.value().shape(),
dims.as_slice(),
"checkpoint shape mismatch"
);
p.set_value(Tensor::from_slice(&data, &dims).to_device(device));
}
Ok(())
}
fn read_u32<R: Read>(r: &mut R) -> std::io::Result<u32> {
let mut b = [0u8; 4];
r.read_exact(&mut b)?;
Ok(u32::from_le_bytes(b))
}
fn read_f32<R: Read>(r: &mut R) -> std::io::Result<f32> {
let mut b = [0u8; 4];
r.read_exact(&mut b)?;
Ok(f32::from_le_bytes(b))
}

View File

@@ -10,3 +10,13 @@
pub mod clip;
pub mod data;
pub mod schedule;
#[cfg(not(no_cuda))]
pub mod checkpoint;
#[cfg(not(no_cuda))]
pub mod sample;
#[cfg(not(no_cuda))]
mod train_loop;
#[cfg(not(no_cuda))]
pub use train_loop::{TrainConfig, train};

View File

@@ -0,0 +1,76 @@
//! Autoregressive text sampling from the trained model. The model is
//! single-sequence with RoPE position = row index, so generation re-runs the
//! forward on the growing prefix each step and reads the last row's logits — the
//! simplest correct approach (no KV cache; that is an inference/perf concern).
//!
//! Greedy when `temperature == 0`, else temperature sampling over the softmax.
#![cfg(not(no_cuda))]
use xtrain_model::{TinyTransformer, ids_tensor};
use xtrain_tensor::Device;
/// Generate `max_new` tokens continuing `prompt`. `temperature == 0` → greedy
/// argmax; otherwise sample from softmax(logits / temperature). `rng_state` is a
/// reproducible LCG seed (only used when temperature > 0).
pub fn generate(
model: &TinyTransformer,
device: Device,
prompt: &[i32],
max_new: usize,
temperature: f32,
rng_state: &mut u64,
) -> Vec<i32> {
let vocab = model.config().vocab;
let mut ids: Vec<i32> = prompt.to_vec();
for _ in 0..max_new {
let ids_t = ids_tensor(&ids, device);
let logits = model.forward(&ids_t).value().to_device(Device::Cpu);
let lg = logits.as_slice::<f32>();
// Last row = next-token distribution for the current prefix.
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];
let next = if temperature <= 0.0 {
argmax(last)
} else {
sample_temperature(last, temperature, rng_state)
};
ids.push(next as i32);
}
ids
}
fn argmax(row: &[f32]) -> usize {
row.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0
}
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
// Softmax with temperature (numerically stable).
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = row
.iter()
.map(|&x| ((x - max) / temperature).exp())
.collect();
let sum: f32 = exps.iter().sum();
let r = (next_rand(rng_state) as f32 / u32::MAX as f32) * sum;
let mut acc = 0.0;
for (i, &e) in exps.iter().enumerate() {
acc += e;
if acc >= r {
return i;
}
}
exps.len() - 1
}
fn next_rand(state: &mut u64) -> u32 {
*state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(*state >> 32) as u32
}

View File

@@ -0,0 +1,107 @@
//! The training loop: sample sequences → forward `loss` → backward → grad clip
//! (with batch averaging) → AdamW step → zero grads; with an LR schedule,
//! periodic loss logging, and periodic checkpointing.
//!
//! The T5 model is single-sequence, so a "batch" of `batch_size` sequences is
//! handled by running forward+backward on each and letting the tape SUM their
//! grads (its fan-out rule); the clip pass then multiplies by `1/batch_size` to
//! recover the batch-mean gradient before clipping + the optimizer step.
#![cfg(not(no_cuda))]
use std::path::PathBuf;
use std::time::Instant;
use xtrain_model::{TinyTransformer, ids_tensor};
use xtrain_optim::AdamW;
use xtrain_tensor::Device;
use crate::checkpoint;
use crate::clip::clip_grad_norm;
use crate::data::Corpus;
use crate::schedule::LrSchedule;
/// Knobs for a training run.
pub struct TrainConfig {
pub seq_len: usize,
pub batch_size: usize,
pub steps: usize,
pub schedule: LrSchedule,
pub weight_decay: f32,
pub max_grad_norm: f32,
pub log_every: usize,
/// Optional checkpoint path written every `ckpt_every` steps (and at the end).
pub ckpt_path: Option<PathBuf>,
pub ckpt_every: usize,
/// Seed for reproducible sequence sampling.
pub seed: u64,
}
/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step
/// loss trace (one mean loss per step, read from the first sequence of the
/// batch — cheap and representative). Logs progress and checkpoints as configured.
pub fn train(
model: &TinyTransformer,
device: Device,
corpus: &Corpus,
cfg: &TrainConfig,
) -> Vec<f32> {
let params = model.params();
let mut opt = AdamW::new(cfg.schedule.max_lr, cfg.weight_decay);
let mut rng = cfg.seed;
let mut losses = Vec::with_capacity(cfg.steps);
let inv_batch = 1.0 / cfg.batch_size as f32;
let start = Instant::now();
let mut tokens_seen: u64 = 0;
for step in 0..cfg.steps {
let lr = cfg.schedule.lr(step);
// Accumulate grads over `batch_size` sequences (tape SUMs them).
let mut step_loss = 0.0f32;
for _ in 0..cfg.batch_size {
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
let ids = ids_tensor(&input, device);
let targets = ids_tensor(&target, device);
let loss = model.loss(&ids, &targets);
step_loss += read_scalar(&loss);
loss.backward();
tokens_seen += cfg.seq_len as u64;
}
step_loss *= inv_batch;
losses.push(step_loss);
// Average the summed grads (×1/batch) and clip to the global norm.
let gnorm = clip_grad_norm(&params, cfg.max_grad_norm, inv_batch);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
if step % cfg.log_every == 0 || step == cfg.steps - 1 {
let elapsed = start.elapsed().as_secs_f32();
let tps = tokens_seen as f32 / elapsed.max(1e-6);
println!(
"step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
({tps:.0} tok/s)",
cfg.steps
);
}
if let Some(path) = &cfg.ckpt_path {
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
checkpoint::save(path, &params).expect("checkpoint save");
}
}
}
if let Some(path) = &cfg.ckpt_path {
checkpoint::save(path, &params).expect("final checkpoint save");
println!("saved checkpoint → {}", path.display());
}
losses
}
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
}