diff --git a/crates/xtrain-train/Cargo.toml b/crates/xtrain-train/Cargo.toml index fc1d1f4..07bcd42 100644 --- a/crates/xtrain-train/Cargo.toml +++ b/crates/xtrain-train/Cargo.toml @@ -14,3 +14,7 @@ xtrain-cuda = { path = "../xtrain-cuda" } # crate inherits xserv's workspace for its own deps (serde/regex) — Cargo reads # the target package's workspace, not ours. xserv-tokenizer = { path = "../../../xserv/crates/xserv-tokenizer" } + +[[bin]] +name = "train" +path = "src/bin/train.rs" diff --git a/crates/xtrain-train/src/bin/train.rs b/crates/xtrain-train/src/bin/train.rs new file mode 100644 index 0000000..f7656db --- /dev/null +++ b/crates/xtrain-train/src/bin/train.rs @@ -0,0 +1,166 @@ +//! End-to-end training entry point (Phase T6): load the GPT-2 BPE + TinyStories +//! corpus, train the tiny transformer with hand-written AdamW for a BOUNDED +//! budget, checkpoint it, and print a few generated samples. +//! +//! Run on dash5 (needs a GPU + the corpus + tokenizer.json): +//! export PATH=/usr/local/cuda/bin:/opt/wjh/.cargo/bin:$PATH +//! cargo run -p xtrain-train --release --bin train -- \ +//! /opt/wjh/models/gpt2/tokenizer.json \ +//! data/tinystories-valid-3mb.txt +//! +//! Optional 3rd/4th args: number of steps, checkpoint path. + +// On a GPU-less host (no_cuda) the whole training body is unavailable; keep a +// stub `main` so the crate still builds for `cargo check`. +#[cfg(no_cuda)] +fn main() { + eprintln!("xtrain train: built without CUDA (no_cuda); run on a GPU host (dash5)."); +} + +#[cfg(not(no_cuda))] +use std::path::{Path, PathBuf}; + +#[cfg(not(no_cuda))] +use xtrain_cuda::device; +#[cfg(not(no_cuda))] +use xtrain_model::{Config, TinyTransformer}; +#[cfg(not(no_cuda))] +use xtrain_tensor::Device; +#[cfg(not(no_cuda))] +use xtrain_train::data::Corpus; +#[cfg(not(no_cuda))] +use xtrain_train::sample::generate; +#[cfg(not(no_cuda))] +use xtrain_train::schedule::LrSchedule; +#[cfg(not(no_cuda))] +use xtrain_train::{TrainConfig, train}; + +// Deterministic LCG fill in [-scale, scale) — same init scheme as the T5 tests. +#[cfg(not(no_cuda))] +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} + +#[cfg(not(no_cuda))] +fn main() { + let args: Vec = std::env::args().collect(); + let tok_path = args + .get(1) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json")); + let corpus_path = args + .get(2) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("data/tinystories-valid-3mb.txt")); + let steps: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(2000); + let ckpt: PathBuf = args + .get(4) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("/tmp/xtrain_tinystories.ckpt")); + + assert!(device::device_count().unwrap() > 0, "no CUDA device"); + device::set_device(0).unwrap(); + let device = Device::Cuda(0); + + println!( + "loading tokenizer {} + corpus {}", + tok_path.display(), + corpus_path.display() + ); + let corpus = Corpus::load(&tok_path, &corpus_path); + println!( + "corpus: {} tokens, vocab {}", + corpus.len(), + corpus.vocab_size + ); + + // Tiny model sized to the BPE vocab. A real (but small) config: wider than + // the overfit test so it has capacity to learn English structure. + let mut cfg = Config::tiny(); + cfg.vocab = corpus.vocab_size; + cfg.n_layers = 4; + println!( + "model: dim {} layers {} heads {} ffn {} → {} params", + cfg.dim, + cfg.n_layers, + cfg.n_heads, + cfg.ffn_hidden, + cfg.num_params() + ); + + let mut seed = 1u64; + let model = TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + // RMSNorm gammas → ~1. + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + // Small fan-in-ish scale; keeps early logits tame. + fill(n, seed, 0.04) + } + }); + + let seq_len = 64; + let tcfg = TrainConfig { + seq_len, + batch_size: 8, + steps, + schedule: LrSchedule { + max_lr: 3e-3, + min_lr: 3e-4, + warmup: (steps / 20).max(20), + total: steps, + }, + weight_decay: 0.1, + max_grad_norm: 1.0, + log_every: 50, + ckpt_path: Some(ckpt.clone()), + ckpt_every: 500, + seed: 42, + }; + + println!( + "training: {} steps, seq {}, batch {}, lr {:.1e}→{:.1e}", + tcfg.steps, tcfg.seq_len, tcfg.batch_size, tcfg.schedule.max_lr, tcfg.schedule.min_lr + ); + let losses = train(&model, device, &corpus, &tcfg); + let start = losses.first().copied().unwrap_or(0.0); + let end = losses.last().copied().unwrap_or(0.0); + println!("loss: start {start:.4} → end {end:.4}"); + + sample_some(&model, device, &tok_path); +} + +#[cfg(not(no_cuda))] +fn sample_some(model: &TinyTransformer, device: Device, tok_path: &Path) { + use xserv_tokenizer::Tokenizer; + let tok = Tokenizer::from_file(tok_path); + let prompts = ["Once upon a time", "The little", "One day"]; + println!("\n--- samples (greedy) ---"); + for p in prompts { + let ids: Vec = tok.encode(p).into_iter().map(|t| t as i32).collect(); + let mut rng = 7u64; + let out = generate(model, device, &ids, 40, 0.0, &mut rng); + let text = tok.decode(&out.iter().map(|&t| t as u32).collect::>()); + println!("[{p}] → {text}"); + } + println!("\n--- samples (temperature 0.8) ---"); + for p in prompts { + let ids: Vec = tok.encode(p).into_iter().map(|t| t as i32).collect(); + let mut rng = 13u64; + let out = generate(model, device, &ids, 40, 0.8, &mut rng); + let text = tok.decode(&out.iter().map(|&t| t as u32).collect::>()); + println!("[{p}] → {text}"); + } +} diff --git a/crates/xtrain-train/src/checkpoint.rs b/crates/xtrain-train/src/checkpoint.rs new file mode 100644 index 0000000..4078a7e --- /dev/null +++ b/crates/xtrain-train/src/checkpoint.rs @@ -0,0 +1,90 @@ +//! Checkpoint save/load. Dumps the model's `params()` (in their stable order) to +//! a flat binary file and reloads them into a model with matching architecture. +//! +//! Format (little-endian): +//! ```text +//! magic : u32 = 0x58545254 ("XTRT") +//! version : u32 = 1 +//! n_params: u32 +//! repeat n_params times: +//! ndim : u32 +//! dims : [u32; ndim] +//! data : [f32; prod(dims)] +//! ``` +//! Architecture/config is NOT stored here — the caller rebuilds the model from +//! the same `Config` and `load_into`s the params (the round-trip and resume both +//! know their config). Gated behind `not(no_cuda)` (it round-trips GPU tensors). + +#![cfg(not(no_cuda))] + +use std::fs::File; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::Path; +use xtrain_autodiff::tape::Var; +use xtrain_tensor::{Device, Tensor}; + +const MAGIC: u32 = 0x5854_5254; +const VERSION: u32 = 1; + +/// Write every parameter (value) to `path` in `params()` order. +pub fn save(path: &Path, params: &[Var]) -> std::io::Result<()> { + let mut w = BufWriter::new(File::create(path)?); + w.write_all(&MAGIC.to_le_bytes())?; + w.write_all(&VERSION.to_le_bytes())?; + w.write_all(&(params.len() as u32).to_le_bytes())?; + for p in params { + let v = p.value().to_device(Device::Cpu); + let shape = v.shape(); + w.write_all(&(shape.len() as u32).to_le_bytes())?; + for &d in shape { + w.write_all(&(d as u32).to_le_bytes())?; + } + for &x in v.as_slice::() { + w.write_all(&x.to_le_bytes())?; + } + } + w.flush() +} + +/// Read a checkpoint and overwrite each parameter's value in `params` (in order). +/// Shapes must match the saved ones. Tensors are placed on each param's device. +pub fn load_into(path: &Path, params: &[Var]) -> std::io::Result<()> { + let mut r = BufReader::new(File::open(path)?); + assert_eq!(read_u32(&mut r)?, MAGIC, "bad checkpoint magic"); + assert_eq!(read_u32(&mut r)?, VERSION, "unsupported checkpoint version"); + let n = read_u32(&mut r)? as usize; + assert_eq!(n, params.len(), "checkpoint param count != model"); + + for p in params { + let ndim = read_u32(&mut r)? as usize; + let mut dims = Vec::with_capacity(ndim); + for _ in 0..ndim { + dims.push(read_u32(&mut r)? as usize); + } + let numel: usize = dims.iter().product(); + let mut data = vec![0.0f32; numel]; + for slot in data.iter_mut() { + *slot = read_f32(&mut r)?; + } + let device = p.value().device(); + assert_eq!( + p.value().shape(), + dims.as_slice(), + "checkpoint shape mismatch" + ); + p.set_value(Tensor::from_slice(&data, &dims).to_device(device)); + } + Ok(()) +} + +fn read_u32(r: &mut R) -> std::io::Result { + let mut b = [0u8; 4]; + r.read_exact(&mut b)?; + Ok(u32::from_le_bytes(b)) +} + +fn read_f32(r: &mut R) -> std::io::Result { + let mut b = [0u8; 4]; + r.read_exact(&mut b)?; + Ok(f32::from_le_bytes(b)) +} diff --git a/crates/xtrain-train/src/lib.rs b/crates/xtrain-train/src/lib.rs index 96032e7..31f3bcf 100644 --- a/crates/xtrain-train/src/lib.rs +++ b/crates/xtrain-train/src/lib.rs @@ -10,3 +10,13 @@ pub mod clip; pub mod data; pub mod schedule; + +#[cfg(not(no_cuda))] +pub mod checkpoint; +#[cfg(not(no_cuda))] +pub mod sample; +#[cfg(not(no_cuda))] +mod train_loop; + +#[cfg(not(no_cuda))] +pub use train_loop::{TrainConfig, train}; diff --git a/crates/xtrain-train/src/sample.rs b/crates/xtrain-train/src/sample.rs new file mode 100644 index 0000000..2f8e26f --- /dev/null +++ b/crates/xtrain-train/src/sample.rs @@ -0,0 +1,76 @@ +//! Autoregressive text sampling from the trained model. The model is +//! single-sequence with RoPE position = row index, so generation re-runs the +//! forward on the growing prefix each step and reads the last row's logits — the +//! simplest correct approach (no KV cache; that is an inference/perf concern). +//! +//! Greedy when `temperature == 0`, else temperature sampling over the softmax. + +#![cfg(not(no_cuda))] + +use xtrain_model::{TinyTransformer, ids_tensor}; +use xtrain_tensor::Device; + +/// Generate `max_new` tokens continuing `prompt`. `temperature == 0` → greedy +/// argmax; otherwise sample from softmax(logits / temperature). `rng_state` is a +/// reproducible LCG seed (only used when temperature > 0). +pub fn generate( + model: &TinyTransformer, + device: Device, + prompt: &[i32], + max_new: usize, + temperature: f32, + rng_state: &mut u64, +) -> Vec { + let vocab = model.config().vocab; + let mut ids: Vec = prompt.to_vec(); + + for _ in 0..max_new { + let ids_t = ids_tensor(&ids, device); + let logits = model.forward(&ids_t).value().to_device(Device::Cpu); + let lg = logits.as_slice::(); + // Last row = next-token distribution for the current prefix. + let last = &lg[(ids.len() - 1) * vocab..ids.len() * vocab]; + + let next = if temperature <= 0.0 { + argmax(last) + } else { + sample_temperature(last, temperature, rng_state) + }; + ids.push(next as i32); + } + ids +} + +fn argmax(row: &[f32]) -> usize { + row.iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .unwrap() + .0 +} + +fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize { + // Softmax with temperature (numerically stable). + let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + let exps: Vec = row + .iter() + .map(|&x| ((x - max) / temperature).exp()) + .collect(); + let sum: f32 = exps.iter().sum(); + let r = (next_rand(rng_state) as f32 / u32::MAX as f32) * sum; + let mut acc = 0.0; + for (i, &e) in exps.iter().enumerate() { + acc += e; + if acc >= r { + return i; + } + } + exps.len() - 1 +} + +fn next_rand(state: &mut u64) -> u32 { + *state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (*state >> 32) as u32 +} diff --git a/crates/xtrain-train/src/train_loop.rs b/crates/xtrain-train/src/train_loop.rs new file mode 100644 index 0000000..29f024d --- /dev/null +++ b/crates/xtrain-train/src/train_loop.rs @@ -0,0 +1,107 @@ +//! The training loop: sample sequences → forward `loss` → backward → grad clip +//! (with batch averaging) → AdamW step → zero grads; with an LR schedule, +//! periodic loss logging, and periodic checkpointing. +//! +//! The T5 model is single-sequence, so a "batch" of `batch_size` sequences is +//! handled by running forward+backward on each and letting the tape SUM their +//! grads (its fan-out rule); the clip pass then multiplies by `1/batch_size` to +//! recover the batch-mean gradient before clipping + the optimizer step. + +#![cfg(not(no_cuda))] + +use std::path::PathBuf; +use std::time::Instant; + +use xtrain_model::{TinyTransformer, ids_tensor}; +use xtrain_optim::AdamW; +use xtrain_tensor::Device; + +use crate::checkpoint; +use crate::clip::clip_grad_norm; +use crate::data::Corpus; +use crate::schedule::LrSchedule; + +/// Knobs for a training run. +pub struct TrainConfig { + pub seq_len: usize, + pub batch_size: usize, + pub steps: usize, + pub schedule: LrSchedule, + pub weight_decay: f32, + pub max_grad_norm: f32, + pub log_every: usize, + /// Optional checkpoint path written every `ckpt_every` steps (and at the end). + pub ckpt_path: Option, + pub ckpt_every: usize, + /// Seed for reproducible sequence sampling. + pub seed: u64, +} + +/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step +/// loss trace (one mean loss per step, read from the first sequence of the +/// batch — cheap and representative). Logs progress and checkpoints as configured. +pub fn train( + model: &TinyTransformer, + device: Device, + corpus: &Corpus, + cfg: &TrainConfig, +) -> Vec { + let params = model.params(); + let mut opt = AdamW::new(cfg.schedule.max_lr, cfg.weight_decay); + let mut rng = cfg.seed; + let mut losses = Vec::with_capacity(cfg.steps); + let inv_batch = 1.0 / cfg.batch_size as f32; + let start = Instant::now(); + let mut tokens_seen: u64 = 0; + + for step in 0..cfg.steps { + let lr = cfg.schedule.lr(step); + + // Accumulate grads over `batch_size` sequences (tape SUMs them). + let mut step_loss = 0.0f32; + for _ in 0..cfg.batch_size { + let (input, target) = corpus.sample(cfg.seq_len, &mut rng); + let ids = ids_tensor(&input, device); + let targets = ids_tensor(&target, device); + let loss = model.loss(&ids, &targets); + step_loss += read_scalar(&loss); + loss.backward(); + tokens_seen += cfg.seq_len as u64; + } + step_loss *= inv_batch; + losses.push(step_loss); + + // Average the summed grads (×1/batch) and clip to the global norm. + let gnorm = clip_grad_norm(¶ms, cfg.max_grad_norm, inv_batch); + opt.step(lr, ¶ms); + for p in ¶ms { + p.zero_grad(); + } + + if step % cfg.log_every == 0 || step == cfg.steps - 1 { + let elapsed = start.elapsed().as_secs_f32(); + let tps = tokens_seen as f32 / elapsed.max(1e-6); + println!( + "step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \ + ({tps:.0} tok/s)", + cfg.steps + ); + } + + if let Some(path) = &cfg.ckpt_path { + if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 { + checkpoint::save(path, ¶ms).expect("checkpoint save"); + } + } + } + + if let Some(path) = &cfg.ckpt_path { + checkpoint::save(path, ¶ms).expect("final checkpoint save"); + println!("saved checkpoint → {}", path.display()); + } + losses +} + +fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 { + v.value().to_device(Device::Cpu).as_slice::()[0] +}