From e625aa05dd607e0e8797bfae6f48d58ff9dda021 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 18 Jun 2026 00:05:32 +0800 Subject: [PATCH] dropout: wire into model (residual sites) + train/eval switch + flag (T18) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Config.dropout (default 0). TinyTransformer gets a Cell training switch (train()/eval()/with_training, default eval = safe) + a Cell step_seed bumped once per training forward. forward_batched derives a per-layer block_seed (pure fn of step_seed×layer) and block_forward derives two per-site seeds, inserting ops::dropout at the attn and ffn sub-block outputs (before each residual). The seed is a pure function of (step_seed, layer, site) so the checkpoint (T13) recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout node → graph bit-identical to pre-T18. train_loop: model.train() per step (restored after eval flips to eval); eval_loss runs model.eval(). bin/train: --dropout flag → cfg.dropout. Export/sampling run in eval (default), so exported weights are dropout-free (xserv closed loop unaffected). Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads); eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout grads match non-recompute (fp32 + bf16). Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-model/src/config.rs | 7 + crates/xtrain-model/src/model.rs | 105 ++++++++++-- crates/xtrain-model/tests/dropout.rs | 222 ++++++++++++++++++++++++++ crates/xtrain-train/src/bin/train.rs | 10 +- crates/xtrain-train/src/train_loop.rs | 5 + 5 files changed, 339 insertions(+), 10 deletions(-) create mode 100644 crates/xtrain-model/tests/dropout.rs diff --git a/crates/xtrain-model/src/config.rs b/crates/xtrain-model/src/config.rs index 3eba07a..554f930 100644 --- a/crates/xtrain-model/src/config.rs +++ b/crates/xtrain-model/src/config.rs @@ -20,6 +20,11 @@ pub struct Config { pub eps: f32, /// RoPE base frequency (theta). pub rope_theta: f32, + /// Dropout probability `p` (Phase T18). Applied at the attention/MLP sub-block + /// outputs (before each residual add) at TRAINING time, with inverted scaling + /// `1/(1-p)`; disabled (identity) at eval. Default `0.0` = no dropout, and the + /// forward graph is then bit-identical to the pre-T18 path. + pub dropout: f32, } impl Config { @@ -36,6 +41,7 @@ impl Config { ffn_hidden: 64, eps: 1e-5, rope_theta: 10000.0, + dropout: 0.0, } } @@ -60,6 +66,7 @@ impl Config { ffn_hidden, eps: 1e-5, rope_theta: 10000.0, + dropout: 0.0, } } diff --git a/crates/xtrain-model/src/model.rs b/crates/xtrain-model/src/model.rs index 830f068..bb57a3b 100644 --- a/crates/xtrain-model/src/model.rs +++ b/crates/xtrain-model/src/model.rs @@ -2,6 +2,8 @@ #![cfg(not(no_cuda))] +use std::cell::Cell; + use crate::config::Config; use xtrain_autodiff::ops; use xtrain_autodiff::tape::Var; @@ -47,6 +49,19 @@ pub struct TinyTransformer { /// existing numerics are bit-identical; recompute is mathematically exact, so /// grads match the non-checkpointed path within fp tolerance. recompute: bool, + /// Training mode for dropout (Phase T18). `true` → the attn/MLP sub-block + /// outputs pass through `ops::dropout` (with `cfg.dropout` and a per-step, + /// per-site seed); `false` (default) → dropout is identity (eval/sampling/ + /// export). `Cell` so `train()`/`eval()` flip it through `&self` (the forward + /// takes `&self`). When `cfg.dropout == 0` this flag is irrelevant — the graph + /// is bit-identical to the no-dropout path either way. + training: Cell, + /// Per-step dropout RNG seed (Phase T18). Bumped once at the start of each + /// TRAINING forward so every step draws fresh masks; combined with the layer + /// index + a per-site constant to give each dropout site its own seed. The RNG + /// is counter-based, so re-running a checkpointed block's forward in backward + /// (T13) reproduces the same seed → the same mask (recompute stays exact). + step_seed: Cell, } impl TinyTransformer { @@ -90,6 +105,8 @@ impl TinyTransformer { lm_head, compute_dtype: DType::F32, recompute: false, + training: Cell::new(false), + step_seed: Cell::new(0), } } @@ -127,6 +144,30 @@ impl TinyTransformer { self.recompute } + /// Switch to training mode (Phase T18): dropout (if `cfg.dropout > 0`) is + /// active in subsequent forwards. The training loop calls this before stepping. + pub fn train(&self) { + self.training.set(true); + } + + /// Switch to eval mode (Phase T18): dropout is identity. Held-out eval, + /// autoregressive sampling, and weight export all run in this mode (default). + pub fn eval(&self) { + self.training.set(false); + } + + pub fn is_training(&self) -> bool { + self.training.get() + } + + /// Builder-style train/eval toggle (Phase T18) — handy for tests that want a + /// model fixed in one mode. Equivalent to [`train`](Self::train) / + /// [`eval`](Self::eval) but chains off `new(..)`. + pub fn with_training(self, training: bool) -> Self { + self.training.set(training); + self + } + /// All learnable parameters, in a stable order. The optimizer (a hand-written /// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after /// `backward()`. @@ -176,13 +217,34 @@ impl TinyTransformer { ); let seq = total / batch; + // Dropout (T18) is active only in training mode with p>0; otherwise it is + // identity (`ops::dropout` no-ops at p==0). Bump the per-step seed ONCE per + // training forward so each step draws fresh masks (counter-based RNG, so a + // checkpointed block's recompute reproduces the same seed → same mask). + let dropout_p = if self.training.get() { + self.cfg.dropout + } else { + 0.0 + }; + if dropout_p > 0.0 { + self.step_seed.set(self.step_seed.get().wrapping_add(1)); + } + let base_seed = self.step_seed.get(); + // Embedding gathers from the fp32 master table; in bf16 mode cast the // activation stream to bf16 here (norms are cast to bf16 gammas too). let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32 if self.compute_dtype == DType::BF16 { h = ops::cast(&h, DType::BF16); } - for b in &self.blocks { + for (li, b) in self.blocks.iter().enumerate() { + // Per-layer dropout seed: a deterministic function of (base_seed, + // layer index) — NOT a mutable counter — so the checkpoint recompute + // (which re-derives it from the captured base_seed/li) gets the same + // masks. The block derives its two per-site seeds from this. + let block_seed = base_seed + .wrapping_mul(0x100000001B3) + .wrapping_add(li as u64); h = if self.recompute { // Activation recomputation (T13): run the whole block forward inside // `checkpoint` so its internal activations aren't kept on the tape; @@ -190,7 +252,9 @@ impl TinyTransformer { // segment fn captures only `Copy` config (no borrow of `self`) and // receives the block's params via the slice, in `block_params` order. let (cfg, cdt) = (self.cfg, self.compute_dtype); - let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p); + let seg = move |x: &Var, p: &[Var]| { + block_forward(cfg, cdt, batch, seq, dropout_p, block_seed, x, p) + }; xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params()) } else { block_forward( @@ -198,6 +262,8 @@ impl TinyTransformer { self.compute_dtype, batch, seq, + dropout_p, + block_seed, &h, &b.block_params(), ) @@ -275,25 +341,46 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var { } /// One transformer block's forward: pre-norm + multi-head causal attention + -/// residual, then pre-norm + SwiGLU MLP + residual. Pure in `(cfg, cdt, batch, -/// seq, input, params)` (no `&self`) so it can be the segment fn of -/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is -/// the block's leaves in [`Block::block_params`] order. -fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var { +/// (T18) dropout + residual, then pre-norm + SwiGLU MLP + dropout + residual. +/// Pure in `(cfg, cdt, batch, seq, dropout_p, block_seed, input, params)` (no +/// `&self`, all `Copy`) so it can be the segment fn of +/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13) — the +/// recompute re-derives the same per-site seeds, so the dropout masks are +/// reproduced bit-for-bit. `dropout_p == 0` makes `ops::dropout` a no-op (the +/// graph is then identical to the pre-T18 path). `params` is the block's leaves in +/// [`Block::block_params`] order. +#[allow(clippy::too_many_arguments)] +fn block_forward( + cfg: Config, + cdt: DType, + batch: usize, + seq: usize, + dropout_p: f32, + block_seed: u64, + h: &Var, + p: &[Var], +) -> Var { let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]); let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]); let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]); - // --- Attention sub-block (pre-norm + residual) --- + // Per-site dropout seeds (XOR a site constant into the block seed) so the two + // residual-path dropouts draw independent masks within the same step/layer. + let attn_seed = block_seed ^ 0x0A7700; + let ffn_seed = block_seed ^ 0x0FF700; + + // --- Attention sub-block (pre-norm + dropout + residual) --- let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps); let attn = attention( cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo, ); + let attn = ops::dropout(&attn, dropout_p, attn_seed); let h = ops::add(h, &attn); - // --- MLP sub-block (pre-norm + residual) --- + // --- MLP sub-block (pre-norm + dropout + residual) --- let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps); let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down); + let mlp = ops::dropout(&mlp, dropout_p, ffn_seed); ops::add(&h, &mlp) } diff --git a/crates/xtrain-model/tests/dropout.rs b/crates/xtrain-model/tests/dropout.rs new file mode 100644 index 0000000..04d5d68 --- /dev/null +++ b/crates/xtrain-model/tests/dropout.rs @@ -0,0 +1,222 @@ +// T18 dropout model-level gates. +// +// 1. p=0 bit-identical: a model built with cfg.dropout=0 (in either train or +// eval mode) produces logits/loss/grads bit-for-bit identical to the same +// model with no dropout field touched — the default forward graph is +// unchanged (the regression guard). +// 2. eval identity: with p>0 but eval mode, the forward equals the p=0 forward +// bit-for-bit (dropout is OFF at eval). +// 3. train vs eval differ: with p>0 and train mode, the forward differs from +// eval (dropout actually does something) and grads are still finite. +// 4. recompute compatibility: with p>0 + train + recompute, grads match the +// non-recompute path (the counter-based seed reproduces the same mask on the +// backward re-run — T13 stays exact even with dropout in the block). +// +// (The fixed-seed grad-check of the dropout op and the E[out]≈x / keep-rate check +// live in xtrain-autodiff/tests/autograd.rs; p>0 training convergence is the +// dash5 short run noted in docs/17-dropout.md.) +#![cfg(not(no_cuda))] + +use xtrain_cuda::device; +use xtrain_model::{Config, TinyTransformer, batched_ids_tensor}; +use xtrain_tensor::{DType, Device}; + +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} + +fn build(cfg: Config, device: Device) -> TinyTransformer { + let mut seed = 1u64; + TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + fill(n, seed, 0.08) + } + }) +} + +fn host(t: &xtrain_tensor::Tensor) -> Vec { + t.to_dtype(DType::F32) + .to_device(Device::Cpu) + .as_slice::() + .to_vec() +} + +fn tiny_cfg(dropout: f32) -> Config { + let mut cfg = Config::tiny(); + cfg.vocab = 16; + cfg.n_layers = 4; + cfg.dropout = dropout; + cfg +} + +fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) { + let (batch, seq) = (3usize, 6usize); + let seqs: Vec> = (0..batch) + .map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect()) + .collect(); + let tgts: Vec> = (0..batch) + .map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect()) + .collect(); + ( + batched_ids_tensor(&seqs, device), + batched_ids_tensor(&tgts, device), + ) +} + +fn require_gpu() -> Device { + assert!(device::device_count().unwrap() > 0, "no CUDA device"); + device::set_device(0).unwrap(); + Device::Cuda(0) +} + +// Run forward+backward, return (logits, loss, per-param grads). +fn fwd_bwd( + m: &TinyTransformer, + ids: &xtrain_tensor::Tensor, + tgt: &xtrain_tensor::Tensor, + batch: usize, +) -> (Vec, f32, Vec>) { + let logits = host(&m.forward_batched(ids, batch).value()); + let loss = m.loss_batched(ids, tgt, batch); + let loss_val = host(&loss.value())[0]; + loss.backward(); + let grads: Vec> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect(); + (logits, loss_val, grads) +} + +// --- Gate 3: p=0 is bit-identical to the no-dropout path (default graph). --- +#[test] +fn dropout_p0_bit_identical() { + let device = require_gpu(); + let batch = 3; + + // Reference: cfg.dropout default (0.0), never touched train/eval. + let cfg0 = tiny_cfg(0.0); + let (ids, tgt) = batch_data(&cfg0, device); + let ref_m = build(cfg0, device); + let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch); + + // p=0 in TRAINING mode: the seed bump is gated on p>0, the op no-ops at p==0, + // so the graph must be byte-identical. + let p0_train = build(tiny_cfg(0.0), device); + p0_train.train(); + let (lt, lst, gt) = fwd_bwd(&p0_train, &ids, &tgt, batch); + + assert_eq!(ref_logits, lt, "p=0 train logits not bit-identical"); + assert_eq!(ref_loss, lst, "p=0 train loss not bit-identical"); + for (i, (a, b)) in ref_grads.iter().zip(>).enumerate() { + assert_eq!(a, b, "p=0 train grad[{i}] not bit-identical"); + } + println!("p=0 (train) vs no-dropout: logits/loss/grads bit-identical ✅"); +} + +// --- Gate 2: eval is exact identity (p>0 but eval mode == p=0). --- +#[test] +fn dropout_eval_is_identity() { + let device = require_gpu(); + let batch = 3; + let cfg = tiny_cfg(0.2); + let (ids, tgt) = batch_data(&cfg, device); + + // p=0 reference and a p=0.2 model held in eval — outputs must match bit-for-bit. + let ref_m = build(tiny_cfg(0.0), device); + let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch); + + let eval_m = build(cfg, device); + eval_m.eval(); // explicit; also the default + let (el, els, eg) = fwd_bwd(&eval_m, &ids, &tgt, batch); + + assert_eq!(ref_logits, el, "eval (p>0) logits not identity"); + assert_eq!(ref_loss, els, "eval (p>0) loss not identity"); + for (i, (a, b)) in ref_grads.iter().zip(&eg).enumerate() { + assert_eq!(a, b, "eval (p>0) grad[{i}] not identity"); + } + println!("eval (p=0.2) == no-dropout: bit-identical (eval is identity) ✅"); +} + +// --- Gate (train vs eval differ): with p>0 + train, dropout actually fires. --- +#[test] +fn dropout_train_differs_from_eval() { + let device = require_gpu(); + let batch = 3; + let cfg = tiny_cfg(0.3); + let (ids, _tgt) = batch_data(&cfg, device); + + let m = build(cfg, device); + m.eval(); + let eval_logits = host(&m.forward_batched(&ids, batch).value()); + m.train(); + let train_logits = host(&m.forward_batched(&ids, batch).value()); + + let max_diff = eval_logits + .iter() + .zip(&train_logits) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + assert!( + max_diff > 1e-4 && train_logits.iter().all(|v| v.is_finite()), + "train logits should differ from eval (dropout active) and be finite; max_diff={max_diff}" + ); + println!("train vs eval logits max diff {max_diff:.4e} (dropout active in train) ✅"); +} + +// --- Gate 4: p>0 + recompute grads match non-recompute (T13 stays exact). --- +// The counter-based seed is a pure function of (step_seed, layer, site); the +// checkpoint backward re-runs block_forward and re-derives the SAME seeds, so the +// recomputed dropout masks match the forward — grads stay bit-identical. +fn recompute_with_dropout(dtype: DType, grad_tol: f32) { + let device = require_gpu(); + let batch = 3; + let cfg = tiny_cfg(0.2); + let (ids, tgt) = batch_data(&cfg, device); + + // Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps + // to 1 on the first training forward in BOTH, so they draw the same masks. + let off = build(cfg, device).with_compute_dtype(dtype).with_training(true); + let on = build(cfg, device) + .with_compute_dtype(dtype) + .with_recompute(true) + .with_training(true); + + let off_loss = off.loss_batched(&ids, &tgt, batch); + off_loss.backward(); + let off_grads: Vec> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect(); + + let on_loss = on.loss_batched(&ids, &tgt, batch); + on_loss.backward(); + let on_grads: Vec> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect(); + + let mut max_rel = 0.0f32; + for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) { + max_rel = max_rel.max((a - b).abs() / a.abs().max(1e-3)); + } + println!("[{dtype:?}] dropout p=0.2 recompute on/off grad max rel = {max_rel:.3e}"); + assert!( + max_rel < grad_tol, + "[{dtype:?}] recompute grads diverged with dropout: {max_rel:.3e}" + ); +} + +#[test] +fn dropout_recompute_matches_fp32() { + recompute_with_dropout(DType::F32, 1e-4); +} + +#[test] +fn dropout_recompute_matches_bf16() { + recompute_with_dropout(DType::BF16, 5e-3); +} diff --git a/crates/xtrain-train/src/bin/train.rs b/crates/xtrain-train/src/bin/train.rs index b1d50cf..c11a12e 100644 --- a/crates/xtrain-train/src/bin/train.rs +++ b/crates/xtrain-train/src/bin/train.rs @@ -109,6 +109,10 @@ fn main() { let val_tokens: usize = flag(&args, "--val-tokens", 0); let eval_every: usize = flag(&args, "--eval-every", 0); let eval_batches: usize = flag(&args, "--eval-batches", 64); + // Dropout (Phase T18): residual-path dropout prob, active at training time + // only (inverted scaling), identity at eval/sampling/export. Default 0 = off + // (forward graph bit-identical to the no-dropout path). + let dropout: f32 = flag(&args, "--dropout", 0.0f32); // bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears + // activations. Opt-in; default fp32 reproduces v0–v4 numerics. let bf16 = args.iter().any(|a| a == "--bf16"); @@ -149,7 +153,8 @@ fn main() { (corpus, None) }; - let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn); + let mut cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn); + cfg.dropout = dropout; println!( "model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \ (+ embed/lm {:.2}M = {:.2}M total)", @@ -183,6 +188,9 @@ fn main() { model = model.with_recompute(true); println!("activation recompute: ON (per-block gradient checkpointing)"); } + if dropout > 0.0 { + println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)"); + } // Eval-only mode: load a checkpoint and score it on the held-out val set, then // exit. Used to put an EXISTING model (e.g. v0) and a new one on the same diff --git a/crates/xtrain-train/src/train_loop.rs b/crates/xtrain-train/src/train_loop.rs index 4d93b7f..8ba5f52 100644 --- a/crates/xtrain-train/src/train_loop.rs +++ b/crates/xtrain-train/src/train_loop.rs @@ -89,6 +89,9 @@ pub fn train( } let ids = batched_ids_tensor(&inputs, device); let targets = batched_ids_tensor(&targets_v, device); + // Training mode → dropout active (T18; no-op when cfg.dropout == 0). Set + // each step so it is restored after a periodic eval flips to eval mode. + model.train(); let loss = model.loss_batched(&ids, &targets, cfg.batch_size); let step_loss = read_scalar(&loss); loss.backward(); @@ -169,6 +172,8 @@ pub fn eval_loss( if valid.len() <= seq + 1 { return f32::NAN; } + // Eval mode → dropout is identity (T18). + model.eval(); let n_win = (valid.len() - 1) / seq; // disjoint windows that fit let batches = batches.max(1).min(n_win.max(1)); let stride = (n_win / batches).max(1);