Files
xtrain/crates/xtrain-train/src/train_loop.rs
Gahow Wang e625aa05dd dropout: wire into model (residual sites) + train/eval switch + flag (T18)
Config.dropout (default 0). TinyTransformer gets a Cell<bool> training switch
(train()/eval()/with_training, default eval = safe) + a Cell<u64> step_seed bumped
once per training forward. forward_batched derives a per-layer block_seed (pure fn
of step_seed×layer) and block_forward derives two per-site seeds, inserting
ops::dropout at the attn and ffn sub-block outputs (before each residual). The
seed is a pure function of (step_seed, layer, site) so the checkpoint (T13)
recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout
node → graph bit-identical to pre-T18.

train_loop: model.train() per step (restored after eval flips to eval); eval_loss
runs model.eval(). bin/train: --dropout flag → cfg.dropout. Export/sampling run in
eval (default), so exported weights are dropout-free (xserv closed loop unaffected).

Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads);
eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout
grads match non-recompute (fp32 + bf16).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 00:05:32 +08:00

205 lines
7.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! The training loop: sample a batch of sequences → ONE batched forward `loss` →
//! backward → grad clip → AdamW step → zero grads; with an LR schedule, periodic
//! loss logging, and periodic checkpointing.
//!
//! Since T10 the model is batched (`loss_batched`): `batch_size` sequences are
//! flattened to `[batch*seq]` and run as a SINGLE forward/backward, so the linear
//! projections become big `[batch*seq, dim]` GEMMs that fill the GPU. The
//! cross-entropy mean is over all `batch*seq` rows — already the batch-mean loss,
//! so backward yields the batch-mean gradient directly (clip pre-scale = 1.0; no
//! more "loop B times + SUM + ×1/batch" hack).
#![cfg(not(no_cuda))]
use std::path::PathBuf;
use std::time::Instant;
use xtrain_model::{TinyTransformer, batched_ids_tensor, ids_tensor};
use xtrain_optim::GpuAdamW;
use xtrain_tensor::Device;
use crate::checkpoint;
use crate::clip::clip_grad_norm_gpu;
use crate::data::Corpus;
use crate::schedule::LrSchedule;
/// Knobs for a training run (consumed by [`train`]).
pub struct TrainConfig {
    /// Tokens per sampled training sequence; also the window length used for
    /// held-out eval.
    pub seq_len: usize,
    /// Sequences sampled per step. They run as ONE batched forward/backward, so
    /// the loss is the mean over all `batch_size * seq_len` positions.
    pub batch_size: usize,
    /// Number of optimizer steps to run.
    pub steps: usize,
    /// Learning-rate schedule, queried with the step index at every step.
    pub schedule: LrSchedule,
    /// Weight-decay coefficient handed to the AdamW optimizer.
    pub weight_decay: f32,
    /// Max gradient norm passed to the clipping routine before each optimizer step.
    pub max_grad_norm: f32,
    /// Print a progress line every `log_every` steps (and on the final step).
    /// Use a nonzero value: it is used as a modulus in the training loop.
    pub log_every: usize,
    /// Optional checkpoint path written every `ckpt_every` steps (and at the end).
    /// When a validation corpus is passed to [`train`] and `eval_every > 0`, the
    /// checkpoint instead tracks the BEST val loss.
    pub ckpt_path: Option<PathBuf>,
    /// Fixed checkpoint cadence in steps (0 = no periodic saves; the end-of-run
    /// save still happens). Only used when NOT tracking the best val loss.
    pub ckpt_every: usize,
    /// Evaluate held-out val loss every `eval_every` steps (0 = never). Each eval
    /// averages cross-entropy over `eval_batches` fixed windows of the val corpus.
    pub eval_every: usize,
    /// Number of validation windows averaged per eval; clamped to at least 1 and
    /// at most the number of windows that fit in the val corpus.
    pub eval_batches: usize,
    /// Seed for reproducible sequence sampling (initial state of the sampler RNG).
    pub seed: u64,
}
/// Outcome of a run: per-step train losses and (step, val_loss) eval points.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Default` so callers can log, compare and
/// store results (e.g. in tests or run summaries).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TrainResult {
    /// Training loss of every step, in step order (one entry per step).
    pub train_losses: Vec<f32>,
    /// `(step, val_loss)` for each held-out evaluation, in the order performed.
    pub evals: Vec<(usize, f32)>,
    /// Best (lowest) validation loss recorded during the run, if any.
    pub best_val: Option<f32>,
}
/// Train `model` on `corpus` for `cfg.steps` AdamW steps and return the per-step
/// train-loss trace plus any (step, val_loss) eval points.
///
/// Each step samples `cfg.batch_size` sequences of `cfg.seq_len` tokens and runs
/// them as ONE batched forward/backward with the model in training mode (dropout
/// active when configured). It then clips the gradient norm, takes an AdamW step
/// and zeroes the gradients. Progress is printed every `cfg.log_every` steps and
/// on the final step (`log_every == 0` prints only the final step).
///
/// When `valid` is given and `cfg.eval_every > 0`, held-out val loss is evaluated
/// periodically (and on the final step), and the checkpoint tracks the BEST val
/// model. A NaN val loss never counts as an improvement. Otherwise the checkpoint
/// is written on a fixed `cfg.ckpt_every` cadence plus once at the end, as in T6.
/// If best-val tracking never records a best (no steps ran, or every val loss was
/// NaN), the final params are saved so that a checkpoint still exists.
///
/// The model is left in eval mode on return, so sampling/export after training
/// does not run with dropout active.
///
/// # Panics
///
/// Panics if a checkpoint cannot be written.
pub fn train(
    model: &TinyTransformer,
    device: Device,
    corpus: &Corpus,
    valid: Option<&Corpus>,
    cfg: &TrainConfig,
) -> TrainResult {
    let params = model.params();
    let mut opt = GpuAdamW::new(cfg.weight_decay);
    let mut rng = cfg.seed;
    let mut losses = Vec::with_capacity(cfg.steps);
    let mut evals = Vec::new();
    let mut best_val: Option<f32> = None;
    let start = Instant::now();
    let mut tokens_seen: u64 = 0;
    // Best-val checkpointing only kicks in when we actually evaluate.
    let track_best = valid.is_some() && cfg.eval_every > 0;
    for step in 0..cfg.steps {
        let is_last = step + 1 == cfg.steps;
        let lr = cfg.schedule.lr(step);
        // Sample `batch_size` sequences and run them as ONE batched forward/
        // backward. The CE mean over all batch*seq rows is the batch-mean loss, so
        // backward already yields the batch-mean gradient (clip pre-scale = 1.0).
        let mut inputs = Vec::with_capacity(cfg.batch_size);
        let mut targets_v = Vec::with_capacity(cfg.batch_size);
        for _ in 0..cfg.batch_size {
            let (input, target) = corpus.sample(cfg.seq_len, &mut rng);
            inputs.push(input);
            targets_v.push(target);
        }
        let ids = batched_ids_tensor(&inputs, device);
        let targets = batched_ids_tensor(&targets_v, device);
        // Training mode → dropout active (T18; no-op when the model's dropout is
        // 0). Set each step so it is restored after a periodic eval flips the
        // model to eval mode.
        model.train();
        let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
        let step_loss = read_scalar(&loss);
        loss.backward();
        tokens_seen += (cfg.batch_size * cfg.seq_len) as u64;
        losses.push(step_loss);
        // Backward already produced the batch-mean gradient — just clip it.
        let gnorm = clip_grad_norm_gpu(&params, cfg.max_grad_norm, 1.0);
        opt.step(lr, &params);
        for p in &params {
            p.zero_grad();
        }
        // `log_every == 0` would panic in `%`; treat it as "final step only",
        // matching the 0-means-off convention of the other cadence knobs.
        let log_now = is_last || (cfg.log_every > 0 && step % cfg.log_every == 0);
        if log_now {
            let elapsed = start.elapsed().as_secs_f32();
            let tps = tokens_seen as f32 / elapsed.max(1e-6);
            println!(
                "step {step:5}/{}: loss {step_loss:.4} lr {lr:.2e} gnorm {gnorm:.3} \
                 ({tps:.0} tok/s)",
                cfg.steps
            );
        }
        // Periodic held-out eval (deterministic windows, no grad).
        if let Some(v) = valid {
            if cfg.eval_every > 0 && ((step + 1) % cfg.eval_every == 0 || is_last) {
                let vl = eval_loss(model, device, v, cfg.seq_len, cfg.eval_batches);
                evals.push((step, vl));
                // A NaN val loss (diverged model, or a val corpus too short for
                // one window) must never become the best: `x < NaN` is always
                // false, so it would block every later improvement.
                let improved = !vl.is_nan() && best_val.map(|b| vl < b).unwrap_or(true);
                println!(
                    " eval @ step {step}: val loss {vl:.4}{}",
                    if improved { " (best)" } else { "" }
                );
                if improved {
                    best_val = Some(vl);
                    if let Some(path) = &cfg.ckpt_path {
                        checkpoint::save(path, &params).expect("best checkpoint save");
                    }
                }
            }
        }
        // Fixed-cadence checkpointing (only when not tracking best val).
        if !track_best {
            if let Some(path) = &cfg.ckpt_path {
                if cfg.ckpt_every > 0 && (step + 1) % cfg.ckpt_every == 0 {
                    checkpoint::save(path, &params).expect("checkpoint save");
                }
            }
        }
    }
    // Without periodic eval, still persist the final params (T6 behaviour). With
    // best-val tracking the checkpoint already holds the best model — don't
    // clobber it. If tracking never recorded a best, nothing has been saved yet,
    // so fall back to the final params rather than leaving no checkpoint.
    if !track_best || best_val.is_none() {
        if let Some(path) = &cfg.ckpt_path {
            checkpoint::save(path, &params).expect("final checkpoint save");
            println!("saved checkpoint → {}", path.display());
        }
    }
    // Leave the model in eval mode (the safe default) so callers that sample or
    // export after training don't run with dropout active. Without this, the
    // mode would depend on whether the last step happened to run an eval.
    model.eval();
    TrainResult {
        train_losses: losses,
        evals,
        best_val,
    }
}
/// Mean cross-entropy over `batches` deterministic, non-overlapping windows of
/// the validation corpus (no backward — eval only). Deterministic so val loss is
/// comparable across steps and runs (and across models — the v0-vs-v1 metric).
///
/// Windows are `seq` tokens long, start at multiples of `seq`, and are spread
/// evenly over the corpus. `batches` is clamped to `1..=n_windows`. Returns `NaN`
/// when `seq == 0` or when the corpus is shorter than one window. A window needs
/// `seq + 1` tokens, because the targets are the inputs shifted by one.
///
/// Side effect: switches `model` to eval mode (dropout is identity) and leaves it
/// there. Callers that keep training must call `model.train()` again.
pub fn eval_loss(
    model: &TinyTransformer,
    device: Device,
    valid: &Corpus,
    seq: usize,
    batches: usize,
) -> f32 {
    // One window needs `seq` inputs plus one extra token for the shifted target,
    // so a corpus of exactly `seq + 1` tokens is still usable. `seq == 0` would
    // divide by zero below. (`len <= seq` also avoids overflowing `seq + 1`.)
    if seq == 0 || valid.len() <= seq {
        return f32::NAN;
    }
    // Eval mode → dropout is identity (T18).
    model.eval();
    // Disjoint windows that fit; at least 1 thanks to the guard above.
    let n_win = (valid.len() - 1) / seq;
    let batches = batches.clamp(1, n_win);
    // Spread the windows evenly; >= 1 because batches <= n_win.
    let stride = n_win / batches;
    let mut sum = 0.0f32;
    let mut count = 0usize;
    for i in 0..batches {
        let s = (i * stride) * seq;
        // Defensive: the arithmetic above keeps every window in bounds.
        if s + seq + 1 > valid.len() {
            break;
        }
        let input: Vec<i32> = valid.tokens[s..s + seq].to_vec();
        let target: Vec<i32> = valid.tokens[s + 1..s + seq + 1].to_vec();
        let ids = ids_tensor(&input, device);
        let targets = ids_tensor(&target, device);
        let loss = model.loss(&ids, &targets);
        sum += read_scalar(&loss);
        count += 1;
    }
    if count == 0 {
        f32::NAN
    } else {
        sum / count as f32
    }
}
/// Read a scalar (e.g. a loss) back to the host: copy the autodiff value to the
/// CPU and return its first element, interpreted as `f32`.
fn read_scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
    let value = v.value();
    let on_host = value.to_device(Device::Cpu);
    on_host.as_slice::<f32>()[0]
}