train: loop + checkpoint save/load + sampler + train binary
Training loop (train_loop.rs): sample batch_size sequences, forward loss + backward (tape SUMs grads), clip_grad_norm with ×1/batch averaging, AdamW step with scheduled lr, zero_grad; logs loss/lr/gnorm/tok-s and checkpoints periodically; returns the loss trace. Checkpoint (checkpoint.rs): flat little-endian dump of params() in order (magic/version/count + per-param ndim/dims/f32 data); load_into validates and overwrites a matching model's params via set_value (exact f32 round-trip). Sampler (sample.rs): autoregressive greedy / temperature generation — re-runs forward on the growing prefix (model is single-sequence, RoPE pos=row). bin/train.rs: end-to-end entry — load tokenizer+corpus, train a tiny 4-layer model for a bounded budget, checkpoint, print samples. no_cuda stub keeps it buildable on a GPU-less host. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -14,3 +14,7 @@ xtrain-cuda = { path = "../xtrain-cuda" }
|
||||
# crate inherits xserv's workspace for its own deps (serde/regex) — Cargo reads
|
||||
# the target package's workspace, not ours.
|
||||
xserv-tokenizer = { path = "../../../xserv/crates/xserv-tokenizer" }
|
||||
|
||||
[[bin]]
|
||||
name = "train"
|
||||
path = "src/bin/train.rs"
|
||||
|
||||
166
crates/xtrain-train/src/bin/train.rs
Normal file
166
crates/xtrain-train/src/bin/train.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
//! End-to-end training entry point (Phase T6): load the GPT-2 BPE + TinyStories
|
||||
//! corpus, train the tiny transformer with hand-written AdamW for a BOUNDED
|
||||
//! budget, checkpoint it, and print a few generated samples.
|
||||
//!
|
||||
//! Run on dash5 (needs a GPU + the corpus + tokenizer.json):
|
||||
//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH
|
||||
//! cargo run -p xtrain-train --release --bin train -- \
|
||||
//! /opt/wjh/models/gpt2/tokenizer.json \
|
||||
//! data/tinystories-valid-3mb.txt
|
||||
//!
|
||||
//! Optional 3rd/4th args: number of steps, checkpoint path.
|
||||
|
||||
// On a GPU-less host (no_cuda) the whole training body is unavailable; keep a
|
||||
// stub `main` so the crate still builds for `cargo check`.
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
eprintln!("xtrain train: built without CUDA (no_cuda); run on a GPU host (dash5).");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::data::Corpus;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::sample::generate;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::schedule::LrSchedule;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::{TrainConfig, train};
|
||||
|
||||
// Deterministic LCG fill in [-scale, scale) — same init scheme as the T5 tests.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let tok_path = args
|
||||
.get(1)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
let corpus_path = args
|
||||
.get(2)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt"));
|
||||
let steps: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(2000);
|
||||
let ckpt: PathBuf = args
|
||||
.get(4)
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt"));
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
println!(
|
||||
"loading tokenizer {} + corpus {}",
|
||||
tok_path.display(),
|
||||
corpus_path.display()
|
||||
);
|
||||
let corpus = Corpus::load(&tok_path, &corpus_path);
|
||||
println!(
|
||||
"corpus: {} tokens, vocab {}",
|
||||
corpus.len(),
|
||||
corpus.vocab_size
|
||||
);
|
||||
|
||||
// Tiny model sized to the BPE vocab. A real (but small) config: wider than
|
||||
// the overfit test so it has capacity to learn English structure.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = corpus.vocab_size;
|
||||
cfg.n_layers = 4;
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} ffn {} → {} params",
|
||||
cfg.dim,
|
||||
cfg.n_layers,
|
||||
cfg.n_heads,
|
||||
cfg.ffn_hidden,
|
||||
cfg.num_params()
|
||||
);
|
||||
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
// RMSNorm gammas → ~1.
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
// Small fan-in-ish scale; keeps early logits tame.
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
});
|
||||
|
||||
let seq_len = 64;
|
||||
let tcfg = TrainConfig {
|
||||
seq_len,
|
||||
batch_size: 8,
|
||||
steps,
|
||||
schedule: LrSchedule {
|
||||
max_lr: 3e-3,
|
||||
min_lr: 3e-4,
|
||||
warmup: (steps / 20).max(20),
|
||||
total: steps,
|
||||
},
|
||||
weight_decay: 0.1,
|
||||
max_grad_norm: 1.0,
|
||||
log_every: 50,
|
||||
ckpt_path: Some(ckpt.clone()),
|
||||
ckpt_every: 500,
|
||||
seed: 42,
|
||||
};
|
||||
|
||||
println!(
|
||||
"training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}",
|
||||
tcfg.steps, tcfg.seq_len, tcfg.batch_size, tcfg.schedule.max_lr, tcfg.schedule.min_lr
|
||||
);
|
||||
let losses = train(&model, device, &corpus, &tcfg);
|
||||
let start = losses.first().copied().unwrap_or(0.0);
|
||||
let end = losses.last().copied().unwrap_or(0.0);
|
||||
println!("loss: start {start:.4} → end {end:.4}");
|
||||
|
||||
sample_some(&model, device, &tok_path);
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn sample_some(model: &TinyTransformer, device: Device, tok_path: &Path) {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
let tok = Tokenizer::from_file(tok_path);
|
||||
let prompts = ["Once upon a time", "The little", "One day"];
|
||||
println!("\n--- samples (greedy) ---");
|
||||
for p in prompts {
|
||||
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 7u64;
|
||||
let out = generate(model, device, &ids, 40, 0.0, &mut rng);
|
||||
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
println!("[{p}] → {text}");
|
||||
}
|
||||
println!("\n--- samples (temperature 0.8) ---");
|
||||
for p in prompts {
|
||||
let ids: Vec<i32> = tok.encode(p).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 13u64;
|
||||
let out = generate(model, device, &ids, 40, 0.8, &mut rng);
|
||||
let text = tok.decode(&out.iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
println!("[{p}] → {text}");
|
||||
}
|
||||
}
|
||||
90
crates/xtrain-train/src/checkpoint.rs
Normal file
90
crates/xtrain-train/src/checkpoint.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
//! Checkpoint save/load. Dumps the model's `params()` (in their stable order) to
|
||||
//! a flat binary file and reloads them into a model with matching architecture.
|
||||
//!
|
||||
//! Format (little-endian):
|
||||
//! ```text
|
||||
//! magic : u32 = 0x58545254 ("XTRT")
|
||||
//! version : u32 = 1
|
||||
//! n_params: u32
|
||||
//! repeat n_params times:
|
||||
//! ndim : u32
|
||||
//! dims : [u32; ndim]
|
||||
//! data : [f32; prod(dims)]
|
||||
//! ```
|
||||
//! Architecture/config is NOT stored here — the caller rebuilds the model from
|
||||
//! the same `Config` and `load_into`s the params (the round-trip and resume both
|
||||
//! know their config). Gated behind `not(no_cuda)` (it round-trips GPU tensors).
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter, Read, Write};
|
||||
use std::path::Path;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
const MAGIC: u32 = 0x5854_5254;
|
||||
const VERSION: u32 = 1;
|
||||
|
||||
/// Write every parameter (value) to `path` in `params()` order.
|
||||
pub fn save(path: &Path, params: &[Var]) -> std::io::Result<()> {
|
||||
let mut w = BufWriter::new(File::create(path)?);
|
||||
w.write_all(&MAGIC.to_le_bytes())?;
|
||||
w.write_all(&VERSION.to_le_bytes())?;
|
||||
w.write_all(&(params.len() as u32).to_le_bytes())?;
|
||||
for p in params {
|
||||
let v = p.value().to_device(Device::Cpu);
|
||||
let shape = v.shape();
|
||||
w.write_all(&(shape.len() as u32).to_le_bytes())?;
|
||||
for &d in shape {
|
||||
w.write_all(&(d as u32).to_le_bytes())?;
|
||||
}
|
||||
for &x in v.as_slice::<f32>() {
|
||||
w.write_all(&x.to_le_bytes())?;
|
||||
}
|
||||
}
|
||||
w.flush()
|
||||
}
|
||||
|
||||
/// Read a checkpoint and overwrite each parameter's value in `params` (in order).
|
||||
/// Shapes must match the saved ones. Tensors are placed on each param's device.
|
||||
pub fn load_into(path: &Path, params: &[Var]) -> std::io::Result<()> {
|
||||
let mut r = BufReader::new(File::open(path)?);
|
||||
assert_eq!(read_u32(&mut r)?, MAGIC, "bad checkpoint magic");
|
||||
assert_eq!(read_u32(&mut r)?, VERSION, "unsupported checkpoint version");
|
||||
let n = read_u32(&mut r)? as usize;
|
||||
assert_eq!(n, params.len(), "checkpoint param count != model");
|
||||
|
||||
for p in params {
|
||||
let ndim = read_u32(&mut r)? as usize;
|
||||
let mut dims = Vec::with_capacity(ndim);
|
||||
for _ in 0..ndim {
|
||||
dims.push(read_u32(&mut r)? as usize);
|
||||
}
|
||||
let numel: usize = dims.iter().product();
|
||||
let mut data = vec![0.0f32; numel];
|
||||
for slot in data.iter_mut() {
|
||||
*slot = read_f32(&mut r)?;
|
||||
}
|
||||
let device = p.value().device();
|
||||
assert_eq!(
|
||||
p.value().shape(),
|
||||
dims.as_slice(),
|
||||
"checkpoint shape mismatch"
|
||||
);
|
||||
p.set_value(Tensor::from_slice(&data, &dims).to_device(device));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_u32<R: Read>(r: &mut R) -> std::io::Result<u32> {
|
||||
let mut b = [0u8; 4];
|
||||
r.read_exact(&mut b)?;
|
||||
Ok(u32::from_le_bytes(b))
|
||||
}
|
||||
|
||||
fn read_f32<R: Read>(r: &mut R) -> std::io::Result<f32> {
|
||||
let mut b = [0u8; 4];
|
||||
r.read_exact(&mut b)?;
|
||||
Ok(f32::from_le_bytes(b))
|
||||
}
|
||||
@@ -10,3 +10,13 @@
|
||||
pub mod clip;
|
||||
pub mod data;
|
||||
pub mod schedule;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod checkpoint;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod sample;
|
||||
#[cfg(not(no_cuda))]
|
||||
mod train_loop;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use train_loop::{TrainConfig, train};
|
||||
|
||||
76
crates/xtrain-train/src/sample.rs
Normal file
76
crates/xtrain-train/src/sample.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
//! Autoregressive text sampling from the trained model. The model is
|
||||
//! single-sequence with RoPE position = row index, so generation re-runs the
|
||||
//! forward on the growing prefix each step and reads the last row's logits — the
|
||||
//! simplest correct approach (no KV cache; that is an inference/perf concern).
|
||||
//!
|
||||
//! Greedy when `temperature == 0`, else temperature sampling over the softmax.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_model::{TinyTransformer, ids_tensor};
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
/// Generate `max_new` tokens continuing `prompt`. `temperature == 0` → greedy
|
||||
/// argmax; otherwise sample from softmax(logits / temperature). `rng_state` is a
|
||||
/// reproducible LCG seed (only used when temperature > 0).
|
||||
pub fn generate(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
prompt: &[i32],
|
||||
max_new: usize,
|
||||
temperature: f32,
|
||||
rng_state: &mut u64,
|
||||
) -> Vec<i32> {
|
||||
let vocab = model.config().vocab;
|
||||
let mut ids: Vec<i32> = prompt.to_vec();
|
||||
|
||||
for _ in 0..max_new {
|
||||
let ids_t = ids_tensor(&ids, device);
|
||||
let logits = model.forward(&ids_t).value().to_device(Device::Cpu);
|
||||
let lg = logits.as_slice::<f32>();
|
||||
// Last row = next-token distribution for the current prefix.
|
||||
let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab];
|
||||
|
||||
let next = if temperature <= 0.0 {
|
||||
argmax(last)
|
||||
} else {
|
||||
sample_temperature(last, temperature, rng_state)
|
||||
};
|
||||
ids.push(next as i32);
|
||||
}
|
||||
ids
|
||||
}
|
||||
|
||||
fn argmax(row: &[f32]) -> usize {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
|
||||
.unwrap()
|
||||
.0
|
||||
}
|
||||
|
||||
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
|
||||
// Softmax with temperature (numerically stable).
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = row
|
||||
.iter()
|
||||
.map(|&x| ((x - max) / temperature).exp())
|
||||
.collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
let r = (next_rand(rng_state) as f32 / u32::MAX as f32) * sum;
|
||||
let mut acc = 0.0;
|
||||
for (i, &e) in exps.iter().enumerate() {
|
||||
acc += e;
|
||||
if acc >= r {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
exps.len() - 1
|
||||
}
|
||||
|
||||
fn next_rand(state: &mut u64) -> u32 {
|
||||
*state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(*state >> 32) as u32
|
||||
}
|
||||
107
crates/xtrain-train/src/train_loop.rs
Normal file
107
crates/xtrain-train/src/train_loop.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! The training loop: sample sequences → forward `loss` → backward → grad clip
|
||||
//! (with batch averaging) → AdamW step → zero grads; with an LR schedule,
|
||||
//! periodic loss logging, and periodic checkpointing.
|
||||
//!
|
||||
//! The T5 model is single-sequence, so a "batch" of `batch_size` sequences is
|
||||
//! handled by running forward+backward on each and letting the tape SUM their
|
||||
//! grads (its fan-out rule); the clip pass then multiplies by `1/batch_size` to
|
||||
//! recover the batch-mean gradient before clipping + the optimizer step.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
|
||||
use xtrain_model::{TinyTransformer, ids_tensor};
|
||||
use xtrain_optim::AdamW;
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
use crate::checkpoint;
|
||||
use crate::clip::clip_grad_norm;
|
||||
use crate::data::Corpus;
|
||||
use crate::schedule::LrSchedule;
|
||||
|
||||
/// Knobs for a training run.
|
||||
pub struct TrainConfig {
|
||||
pub seq_len: usize,
|
||||
pub batch_size: usize,
|
||||
pub steps: usize,
|
||||
pub schedule: LrSchedule,
|
||||
pub weight_decay: f32,
|
||||
pub max_grad_norm: f32,
|
||||
pub log_every: usize,
|
||||
/// Optional checkpoint path written every `ckpt_every` steps (and at the end).
|
||||
pub ckpt_path: Option<PathBuf>,
|
||||
pub ckpt_every: usize,
|
||||
/// Seed for reproducible sequence sampling.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step
|
||||
/// loss trace (one mean loss per step, read from the first sequence of the
|
||||
/// batch — cheap and representative). Logs progress and checkpoints as configured.
|
||||
pub fn train(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
corpus: &Corpus,
|
||||
cfg: &TrainConfig,
|
||||
) -> Vec<f32> {
|
||||
let params = model.params();
|
||||
let mut opt = AdamW::new(cfg.schedule.max_lr, cfg.weight_decay);
|
||||
let mut rng = cfg.seed;
|
||||
let mut losses = Vec::with_capacity(cfg.steps);
|
||||
let inv_batch = 1.0 / cfg.batch_size as f32;
|
||||
let start = Instant::now();
|
||||
let mut tokens_seen: u64 = 0;
|
||||
|
||||
for step in 0..cfg.steps {
|
||||
let lr = cfg.schedule.lr(step);
|
||||
|
||||
// Accumulate grads over `batch_size` sequences (tape SUMs them).
|
||||
let mut step_loss = 0.0f32;
|
||||
for _ in 0..cfg.batch_size {
|
||||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||||
let ids = ids_tensor(&input, device);
|
||||
let targets = ids_tensor(&target, device);
|
||||
let loss = model.loss(&ids, &targets);
|
||||
step_loss += read_scalar(&loss);
|
||||
loss.backward();
|
||||
tokens_seen += cfg.seq_len as u64;
|
||||
}
|
||||
step_loss *= inv_batch;
|
||||
losses.push(step_loss);
|
||||
|
||||
// Average the summed grads (×1/batch) and clip to the global norm.
|
||||
let gnorm = clip_grad_norm(¶ms, cfg.max_grad_norm, inv_batch);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
|
||||
if step % cfg.log_every == 0 || step == cfg.steps - 1 {
|
||||
let elapsed = start.elapsed().as_secs_f32();
|
||||
let tps = tokens_seen as f32 / elapsed.max(1e-6);
|
||||
println!(
|
||||
"step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
|
||||
({tps:.0} tok/s)",
|
||||
cfg.steps
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
|
||||
checkpoint::save(path, ¶ms).expect("checkpoint save");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = &cfg.ckpt_path {
|
||||
checkpoint::save(path, ¶ms).expect("final checkpoint save");
|
||||
println!("saved checkpoint → {}", path.display());
|
||||
}
|
||||
losses
|
||||
}
|
||||
|
||||
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
|
||||
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
|
||||
}
|
||||
Reference in New Issue
Block a user