Config.dropout (default 0). TinyTransformer gets a Cell<bool> training switch (train()/eval()/with_training, default eval = safe) + a Cell<u64> step_seed bumped once per training forward. forward_batched derives a per-layer block_seed (pure fn of step_seed×layer) and block_forward derives two per-site seeds, inserting ops::dropout at the attn and ffn sub-block outputs (before each residual). The seed is a pure function of (step_seed, layer, site) so the checkpoint (T13) recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout node → graph bit-identical to pre-T18. train_loop: model.train() per step (restored after eval flips to eval); eval_loss runs model.eval(). bin/train: --dropout flag → cfg.dropout. Export/sampling run in eval (default), so exported weights are dropout-free (xserv closed loop unaffected). Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads); eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout grads match non-recompute (fp32 + bf16). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
205 lines
7.6 KiB
Rust
205 lines
7.6 KiB
Rust
//! The training loop: sample a batch of sequences → ONE batched forward `loss` →
//! backward → grad clip → AdamW step → zero grads; with an LR schedule, periodic
//! loss logging, and periodic checkpointing.
//!
//! Since T10 the model is batched (`loss_batched`): `batch_size` sequences are
//! flattened to `[batch*seq]` and run as a SINGLE forward/backward, so the linear
//! projections become big `[batch*seq, dim]` GEMMs that fill the GPU. The
//! cross-entropy mean is over all `batch*seq` rows — already the batch-mean loss,
//! so backward yields the batch-mean gradient directly (clip pre-scale = 1.0; no
//! more "loop B times + SUM + ×1/batch" hack).
|
||
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use std::path::PathBuf;
|
||
use std::time::Instant;
|
||
|
||
use xtrain_model::{TinyTransformer, batched_ids_tensor, ids_tensor};
|
||
use xtrain_optim::GpuAdamW;
|
||
use xtrain_tensor::Device;
|
||
|
||
use crate::checkpoint;
|
||
use crate::clip::clip_grad_norm_gpu;
|
||
use crate::data::Corpus;
|
||
use crate::schedule::LrSchedule;
|
||
|
||
/// Knobs for a training run.
|
||
pub struct TrainConfig {
|
||
pub seq_len: usize,
|
||
pub batch_size: usize,
|
||
pub steps: usize,
|
||
pub schedule: LrSchedule,
|
||
pub weight_decay: f32,
|
||
pub max_grad_norm: f32,
|
||
pub log_every: usize,
|
||
/// Optional checkpoint path written every `ckpt_every` steps (and at the end).
|
||
/// When `eval_every > 0`, the checkpoint instead tracks the BEST val loss.
|
||
pub ckpt_path: Option<PathBuf>,
|
||
pub ckpt_every: usize,
|
||
/// Evaluate held-out val loss every `eval_every` steps (0 = never). Each eval
|
||
/// averages cross-entropy over `eval_batches` fixed windows of the val corpus.
|
||
pub eval_every: usize,
|
||
pub eval_batches: usize,
|
||
/// Seed for reproducible sequence sampling.
|
||
pub seed: u64,
|
||
}
|
||
|
||
/// Outcome of a training run.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Default` so results can be logged,
/// copied and compared in tests without hand-written glue.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TrainResult {
    /// Training loss recorded at every step, in step order.
    pub train_losses: Vec<f32>,
    /// `(step, val_loss)` pairs, one per held-out evaluation that was run.
    pub evals: Vec<(usize, f32)>,
    /// Lowest validation loss observed, or `None` if no evaluation ran.
    pub best_val: Option<f32>,
}
|
||
|
||
/// Train `model` on `corpus` for `cfg.steps` AdamW steps. Returns the per-step
|
||
/// train-loss trace plus any (step, val_loss) eval points. Logs progress, and —
|
||
/// when `valid` is given and `cfg.eval_every > 0` — evaluates held-out val loss
|
||
/// periodically and checkpoints the BEST val model (else checkpoints on a fixed
|
||
/// cadence, as in T6). Logs progress.
|
||
pub fn train(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
corpus: &Corpus,
|
||
valid: Option<&Corpus>,
|
||
cfg: &TrainConfig,
|
||
) -> TrainResult {
|
||
let params = model.params();
|
||
let mut opt = GpuAdamW::new(cfg.weight_decay);
|
||
let mut rng = cfg.seed;
|
||
let mut losses = Vec::with_capacity(cfg.steps);
|
||
let mut evals = Vec::new();
|
||
let mut best_val: Option<f32> = None;
|
||
let start = Instant::now();
|
||
let mut tokens_seen: u64 = 0;
|
||
// Best-val checkpointing only kicks in when we actually evaluate.
|
||
let track_best = valid.is_some() && cfg.eval_every > 0;
|
||
|
||
for step in 0..cfg.steps {
|
||
let lr = cfg.schedule.lr(step);
|
||
|
||
// Sample `batch_size` sequences and run them as ONE batched forward/
|
||
// backward. The CE mean over all batch*seq rows is the batch-mean loss, so
|
||
// backward already yields the batch-mean gradient (clip pre-scale = 1.0).
|
||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||
for _ in 0..cfg.batch_size {
|
||
let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
|
||
inputs.push(input);
|
||
targets_v.push(target);
|
||
}
|
||
let ids = batched_ids_tensor(&inputs, device);
|
||
let targets = batched_ids_tensor(&targets_v, device);
|
||
// Training mode → dropout active (T18; no-op when cfg.dropout == 0). Set
|
||
// each step so it is restored after a periodic eval flips to eval mode.
|
||
model.train();
|
||
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
|
||
let step_loss = read_scalar(&loss);
|
||
loss.backward();
|
||
tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
|
||
losses.push(step_loss);
|
||
|
||
// Backward already produced the batch-mean gradient — just clip it.
|
||
let gnorm = clip_grad_norm_gpu(¶ms, cfg.max_grad_norm, 1.0);
|
||
opt.step(lr, ¶ms);
|
||
for p in ¶ms {
|
||
p.zero_grad();
|
||
}
|
||
|
||
if step % cfg.log_every == 0 || step == cfg.steps - 1 {
|
||
let elapsed = start.elapsed().as_secs_f32();
|
||
let tps = tokens_seen as f32 / elapsed.max(1e-6);
|
||
println!(
|
||
"step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
|
||
({tps:.0} tok/s)",
|
||
cfg.steps
|
||
);
|
||
}
|
||
|
||
// Periodic held-out eval (deterministic windows, no grad).
|
||
if let Some(v) = valid {
|
||
if cfg.eval_every > 0 && ((step + 1) % cfg.eval_every == 0 || step == cfg.steps - 1) {
|
||
let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
|
||
evals.push((step, vl));
|
||
let improved = best_val.map(|b| vl < b).unwrap_or(true);
|
||
println!(
|
||
" eval @ step {step}: val loss {vl:.4}{}",
|
||
if improved { " (best)" } else { "" }
|
||
);
|
||
if improved {
|
||
best_val = Some(vl);
|
||
if let Some(path) = &cfg.ckpt_path {
|
||
checkpoint::save(path, ¶ms).expect("best checkpoint save");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Fixed-cadence checkpointing (only when not tracking best val).
|
||
if !track_best {
|
||
if let Some(path) = &cfg.ckpt_path {
|
||
if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
|
||
checkpoint::save(path, ¶ms).expect("checkpoint save");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Without periodic eval, still persist the final params (T6 behaviour). With
|
||
// best-val tracking the checkpoint already holds the best model — don't clobber.
|
||
if !track_best {
|
||
if let Some(path) = &cfg.ckpt_path {
|
||
checkpoint::save(path, ¶ms).expect("final checkpoint save");
|
||
println!("saved checkpoint → {}", path.display());
|
||
}
|
||
}
|
||
TrainResult {
|
||
train_losses: losses,
|
||
evals,
|
||
best_val,
|
||
}
|
||
}
|
||
|
||
/// Mean cross-entropy over `batches` deterministic, non-overlapping windows of
|
||
/// the validation corpus (no backward — eval only). Deterministic so val loss is
|
||
/// comparable across steps and runs (and across models — the v0-vs-v1 metric).
|
||
pub fn eval_loss(
|
||
model: &TinyTransformer,
|
||
device: Device,
|
||
valid: &Corpus,
|
||
seq: usize,
|
||
batches: usize,
|
||
) -> f32 {
|
||
if valid.len() <= seq + 1 {
|
||
return f32::NAN;
|
||
}
|
||
// Eval mode → dropout is identity (T18).
|
||
model.eval();
|
||
let n_win = (valid.len() - 1) / seq; // disjoint windows that fit
|
||
let batches = batches.max(1).min(n_win.max(1));
|
||
let stride = (n_win / batches).max(1);
|
||
let mut sum = 0.0f32;
|
||
let mut count = 0usize;
|
||
for i in 0..batches {
|
||
let s = (i * stride) * seq;
|
||
if s + seq + 1 > valid.len() {
|
||
break;
|
||
}
|
||
let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
|
||
let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
|
||
let ids = ids_tensor(&input, device);
|
||
let targets = ids_tensor(&target, device);
|
||
let loss = model.loss(&ids, &targets);
|
||
sum += read_scalar(&loss);
|
||
count += 1;
|
||
}
|
||
if count == 0 {
|
||
f32::NAN
|
||
} else {
|
||
sum / count as f32
|
||
}
|
||
}
|
||
|
||
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
|
||
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
|
||
}
|