dropout: wire into model (residual sites) + train/eval switch + flag (T18)

Config.dropout (default 0). TinyTransformer gets a Cell<bool> training switch
(train()/eval()/with_training, default eval = safe) + a Cell<u64> step_seed bumped
once per training forward. forward_batched derives a per-layer block_seed (pure fn
of step_seed×layer) and block_forward derives two per-site seeds, inserting
ops::dropout at the attn and ffn sub-block outputs (before each residual). The
seed is a pure function of (step_seed, layer, site) so the checkpoint (T13)
recompute re-derives the same masks → grads stay exact. p=0 or eval → no dropout
node → graph bit-identical to pre-T18.

train_loop: model.train() per step (restored after eval flips to eval); eval_loss
runs model.eval(). bin/train: --dropout flag → cfg.dropout. Export/sampling run in
eval (default), so exported weights are dropout-free (xserv closed loop unaffected).

Model-level tests (dropout.rs): p=0 bit-identical to no-dropout (logits/loss/grads);
eval(p>0) == p=0 identity; train differs from eval + finite; recompute-with-dropout
grads match non-recompute (fp32 + bf16).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 00:05:32 +08:00
parent 5eb27783f8
commit e625aa05dd
5 changed files with 339 additions and 10 deletions

View File

@@ -20,6 +20,11 @@ pub struct Config {
pub eps: f32,
/// RoPE base frequency (theta).
pub rope_theta: f32,
/// Dropout probability `p` (Phase T18). Applied at the attention/MLP sub-block
/// outputs (before each residual add) at TRAINING time, with inverted scaling
/// `1/(1-p)`; disabled (identity) at eval. Default `0.0` = no dropout, and the
/// forward graph is then bit-identical to the pre-T18 path.
pub dropout: f32,
}
impl Config {
@@ -36,6 +41,7 @@ impl Config {
ffn_hidden: 64,
eps: 1e-5,
rope_theta: 10000.0,
dropout: 0.0,
}
}
@@ -60,6 +66,7 @@ impl Config {
ffn_hidden,
eps: 1e-5,
rope_theta: 10000.0,
dropout: 0.0,
}
}

View File

@@ -2,6 +2,8 @@
#![cfg(not(no_cuda))]
use std::cell::Cell;
use crate::config::Config;
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
@@ -47,6 +49,19 @@ pub struct TinyTransformer {
/// existing numerics are bit-identical; recompute is mathematically exact, so
/// grads match the non-checkpointed path within fp tolerance.
recompute: bool,
/// Training mode for dropout (Phase T18). `true` → the attn/MLP sub-block
/// outputs pass through `ops::dropout` (with `cfg.dropout` and a per-step,
/// per-site seed); `false` (default) → dropout is identity (eval/sampling/
/// export). `Cell` so `train()`/`eval()` flip it through `&self` (the forward
/// takes `&self`). When `cfg.dropout == 0` this flag is irrelevant — the graph
/// is bit-identical to the no-dropout path either way.
training: Cell<bool>,
/// Per-step dropout RNG seed (Phase T18). Bumped once at the start of each
/// TRAINING forward so every step draws fresh masks; combined with the layer
/// index + a per-site constant to give each dropout site its own seed. The RNG
/// is counter-based, so re-running a checkpointed block's forward in backward
/// (T13) reproduces the same seed → the same mask (recompute stays exact).
step_seed: Cell<u64>,
}
impl TinyTransformer {
@@ -90,6 +105,8 @@ impl TinyTransformer {
lm_head,
compute_dtype: DType::F32,
recompute: false,
training: Cell::new(false),
step_seed: Cell::new(0),
}
}
@@ -127,6 +144,30 @@ impl TinyTransformer {
self.recompute
}
/// Switch to training mode (Phase T18): dropout (if `cfg.dropout > 0`) is
/// active in subsequent forwards. The training loop calls this before stepping.
pub fn train(&self) {
self.training.set(true);
}
/// Switch to eval mode (Phase T18): dropout is identity. Held-out eval,
/// autoregressive sampling, and weight export all run in this mode (default).
pub fn eval(&self) {
self.training.set(false);
}
pub fn is_training(&self) -> bool {
self.training.get()
}
/// Builder-style train/eval toggle (Phase T18) — handy for tests that want a
/// model fixed in one mode. Equivalent to [`train`](Self::train) /
/// [`eval`](Self::eval) but chains off `new(..)`.
pub fn with_training(self, training: bool) -> Self {
self.training.set(training);
self
}
/// All learnable parameters, in a stable order. The optimizer (a hand-written
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
/// `backward()`.
@@ -176,13 +217,34 @@ impl TinyTransformer {
);
let seq = total / batch;
// Dropout (T18) is active only in training mode with p>0; otherwise it is
// identity (`ops::dropout` no-ops at p==0). Bump the per-step seed ONCE per
// training forward so each step draws fresh masks (counter-based RNG, so a
// checkpointed block's recompute reproduces the same seed → same mask).
let dropout_p = if self.training.get() {
self.cfg.dropout
} else {
0.0
};
if dropout_p > 0.0 {
self.step_seed.set(self.step_seed.get().wrapping_add(1));
}
let base_seed = self.step_seed.get();
// Embedding gathers from the fp32 master table; in bf16 mode cast the
// activation stream to bf16 here (norms are cast to bf16 gammas too).
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
if self.compute_dtype == DType::BF16 {
h = ops::cast(&h, DType::BF16);
}
for b in &self.blocks {
for (li, b) in self.blocks.iter().enumerate() {
// Per-layer dropout seed: a deterministic function of (base_seed,
// layer index) — NOT a mutable counter — so the checkpoint recompute
// (which re-derives it from the captured base_seed/li) gets the same
// masks. The block derives its two per-site seeds from this.
let block_seed = base_seed
.wrapping_mul(0x100000001B3)
.wrapping_add(li as u64);
h = if self.recompute {
// Activation recomputation (T13): run the whole block forward inside
// `checkpoint` so its internal activations aren't kept on the tape;
@@ -190,7 +252,9 @@ impl TinyTransformer {
// segment fn captures only `Copy` config (no borrow of `self`) and
// receives the block's params via the slice, in `block_params` order.
let (cfg, cdt) = (self.cfg, self.compute_dtype);
let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p);
let seg = move |x: &Var, p: &[Var]| {
block_forward(cfg, cdt, batch, seq, dropout_p, block_seed, x, p)
};
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
} else {
block_forward(
@@ -198,6 +262,8 @@ impl TinyTransformer {
self.compute_dtype,
batch,
seq,
dropout_p,
block_seed,
&h,
&b.block_params(),
)
@@ -275,25 +341,46 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
}
/// One transformer block's forward: pre-norm + multi-head causal attention +
/// residual, then pre-norm + SwiGLU MLP + residual. Pure in `(cfg, cdt, batch,
/// seq, input, params)` (no `&self`) so it can be the segment fn of
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
/// the block's leaves in [`Block::block_params`] order.
fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var {
/// (T18) dropout + residual, then pre-norm + SwiGLU MLP + dropout + residual.
/// Pure in `(cfg, cdt, batch, seq, dropout_p, block_seed, input, params)` (no
/// `&self`, all `Copy`) so it can be the segment fn of
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13) — the
/// recompute re-derives the same per-site seeds, so the dropout masks are
/// reproduced bit-for-bit. `dropout_p == 0` makes `ops::dropout` a no-op (the
/// graph is then identical to the pre-T18 path). `params` is the block's leaves in
/// [`Block::block_params`] order.
#[allow(clippy::too_many_arguments)]
fn block_forward(
cfg: Config,
cdt: DType,
batch: usize,
seq: usize,
dropout_p: f32,
block_seed: u64,
h: &Var,
p: &[Var],
) -> Var {
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
// --- Attention sub-block (pre-norm + residual) ---
// Per-site dropout seeds (XOR a site constant into the block seed) so the two
// residual-path dropouts draw independent masks within the same step/layer.
let attn_seed = block_seed ^ 0x0A7700;
let ffn_seed = block_seed ^ 0x0FF700;
// --- Attention sub-block (pre-norm + dropout + residual) ---
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
let attn = attention(
cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
);
let attn = ops::dropout(&attn, dropout_p, attn_seed);
let h = ops::add(h, &attn);
// --- MLP sub-block (pre-norm + residual) ---
// --- MLP sub-block (pre-norm + dropout + residual) ---
let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps);
let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down);
let mlp = ops::dropout(&mlp, dropout_p, ffn_seed);
ops::add(&h, &mlp)
}

View File

@@ -0,0 +1,222 @@
// T18 dropout model-level gates.
//
// 1. p=0 bit-identical: a model built with cfg.dropout=0 (in either train or
// eval mode) produces logits/loss/grads bit-for-bit identical to the same
// model with no dropout field touched — the default forward graph is
// unchanged (the regression guard).
// 2. eval identity: with p>0 but eval mode, the forward equals the p=0 forward
// bit-for-bit (dropout is OFF at eval).
// 3. train vs eval differ: with p>0 and train mode, the forward differs from
// eval (dropout actually does something) and grads are still finite.
// 4. recompute compatibility: with p>0 + train + recompute, grads match the
// non-recompute path (the counter-based seed reproduces the same mask on the
// backward re-run — T13 stays exact even with dropout in the block).
//
// (The fixed-seed grad-check of the dropout op and the E[out]≈x / keep-rate check
// live in xtrain-autodiff/tests/autograd.rs; p>0 training convergence is the
// dash5 short run noted in docs/17-dropout.md.)
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device) -> TinyTransformer {
let mut seed = 1u64;
TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
})
}
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
t.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
fn tiny_cfg(dropout: f32) -> Config {
let mut cfg = Config::tiny();
cfg.vocab = 16;
cfg.n_layers = 4;
cfg.dropout = dropout;
cfg
}
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
let (batch, seq) = (3usize, 6usize);
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
.collect();
(
batched_ids_tensor(&seqs, device),
batched_ids_tensor(&tgts, device),
)
}
fn require_gpu() -> Device {
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
Device::Cuda(0)
}
// Run forward+backward, return (logits, loss, per-param grads).
fn fwd_bwd(
m: &TinyTransformer,
ids: &xtrain_tensor::Tensor,
tgt: &xtrain_tensor::Tensor,
batch: usize,
) -> (Vec<f32>, f32, Vec<Vec<f32>>) {
let logits = host(&m.forward_batched(ids, batch).value());
let loss = m.loss_batched(ids, tgt, batch);
let loss_val = host(&loss.value())[0];
loss.backward();
let grads: Vec<Vec<f32>> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect();
(logits, loss_val, grads)
}
// --- Gate 3: p=0 is bit-identical to the no-dropout path (default graph). ---
#[test]
fn dropout_p0_bit_identical() {
let device = require_gpu();
let batch = 3;
// Reference: cfg.dropout default (0.0), never touched train/eval.
let cfg0 = tiny_cfg(0.0);
let (ids, tgt) = batch_data(&cfg0, device);
let ref_m = build(cfg0, device);
let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);
// p=0 in TRAINING mode: the seed bump is gated on p>0, the op no-ops at p==0,
// so the graph must be byte-identical.
let p0_train = build(tiny_cfg(0.0), device);
p0_train.train();
let (lt, lst, gt) = fwd_bwd(&p0_train, &ids, &tgt, batch);
assert_eq!(ref_logits, lt, "p=0 train logits not bit-identical");
assert_eq!(ref_loss, lst, "p=0 train loss not bit-identical");
for (i, (a, b)) in ref_grads.iter().zip(&gt).enumerate() {
assert_eq!(a, b, "p=0 train grad[{i}] not bit-identical");
}
println!("p=0 (train) vs no-dropout: logits/loss/grads bit-identical ✅");
}
// --- Gate 2: eval is exact identity (p>0 but eval mode == p=0). ---
#[test]
fn dropout_eval_is_identity() {
let device = require_gpu();
let batch = 3;
let cfg = tiny_cfg(0.2);
let (ids, tgt) = batch_data(&cfg, device);
// p=0 reference and a p=0.2 model held in eval — outputs must match bit-for-bit.
let ref_m = build(tiny_cfg(0.0), device);
let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);
let eval_m = build(cfg, device);
eval_m.eval(); // explicit; also the default
let (el, els, eg) = fwd_bwd(&eval_m, &ids, &tgt, batch);
assert_eq!(ref_logits, el, "eval (p>0) logits not identity");
assert_eq!(ref_loss, els, "eval (p>0) loss not identity");
for (i, (a, b)) in ref_grads.iter().zip(&eg).enumerate() {
assert_eq!(a, b, "eval (p>0) grad[{i}] not identity");
}
println!("eval (p=0.2) == no-dropout: bit-identical (eval is identity) ✅");
}
// --- Gate (train vs eval differ): with p>0 + train, dropout actually fires. ---
#[test]
fn dropout_train_differs_from_eval() {
let device = require_gpu();
let batch = 3;
let cfg = tiny_cfg(0.3);
let (ids, _tgt) = batch_data(&cfg, device);
let m = build(cfg, device);
m.eval();
let eval_logits = host(&m.forward_batched(&ids, batch).value());
m.train();
let train_logits = host(&m.forward_batched(&ids, batch).value());
let max_diff = eval_logits
.iter()
.zip(&train_logits)
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(
max_diff > 1e-4 && train_logits.iter().all(|v| v.is_finite()),
"train logits should differ from eval (dropout active) and be finite; max_diff={max_diff}"
);
println!("train vs eval logits max diff {max_diff:.4e} (dropout active in train) ✅");
}
// --- Gate 4: p>0 + recompute grads match non-recompute (T13 stays exact). ---
// The counter-based seed is a pure function of (step_seed, layer, site); the
// checkpoint backward re-runs block_forward and re-derives the SAME seeds, so the
// recomputed dropout masks match the forward — grads stay bit-identical.
fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
let device = require_gpu();
let batch = 3;
let cfg = tiny_cfg(0.2);
let (ids, tgt) = batch_data(&cfg, device);
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
// to 1 on the first training forward in BOTH, so they draw the same masks.
let off = build(cfg, device).with_compute_dtype(dtype).with_training(true);
let on = build(cfg, device)
.with_compute_dtype(dtype)
.with_recompute(true)
.with_training(true);
let off_loss = off.loss_batched(&ids, &tgt, batch);
off_loss.backward();
let off_grads: Vec<Vec<f32>> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect();
let on_loss = on.loss_batched(&ids, &tgt, batch);
on_loss.backward();
let on_grads: Vec<Vec<f32>> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect();
let mut max_rel = 0.0f32;
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
max_rel = max_rel.max((a - b).abs() / a.abs().max(1e-3));
}
println!("[{dtype:?}] dropout p=0.2 recompute on/off grad max rel = {max_rel:.3e}");
assert!(
max_rel < grad_tol,
"[{dtype:?}] recompute grads diverged with dropout: {max_rel:.3e}"
);
}
#[test]
fn dropout_recompute_matches_fp32() {
recompute_with_dropout(DType::F32, 1e-4);
}
#[test]
fn dropout_recompute_matches_bf16() {
recompute_with_dropout(DType::BF16, 5e-3);
}

View File

@@ -109,6 +109,10 @@ fn main() {
let val_tokens: usize = flag(&args, "--val-tokens", 0);
let eval_every: usize = flag(&args, "--eval-every", 0);
let eval_batches: usize = flag(&args, "--eval-batches", 64);
// Dropout (Phase T18): residual-path dropout prob, active at training time
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
// (forward graph bit-identical to the no-dropout path).
let dropout: f32 = flag(&args, "--dropout", 0.0f32);
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
// activations. Opt-in; default fp32 reproduces v0v4 numerics.
let bf16 = args.iter().any(|a| a == "--bf16");
@@ -149,7 +153,8 @@ fn main() {
(corpus, None)
};
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
let mut cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
cfg.dropout = dropout;
println!(
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
(+ embed/lm {:.2}M = {:.2}M total)",
@@ -183,6 +188,9 @@ fn main() {
model = model.with_recompute(true);
println!("activation recompute: ON (per-block gradient checkpointing)");
}
if dropout > 0.0 {
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
}
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same

View File

@@ -89,6 +89,9 @@ pub fn train(
}
let ids = batched_ids_tensor(&inputs, device);
let targets = batched_ids_tensor(&targets_v, device);
// Training mode → dropout active (T18; no-op when cfg.dropout == 0). Set
// each step so it is restored after a periodic eval flips to eval mode.
model.train();
let loss = model.loss_batched(&ids, &targets, cfg.batch_size);
let step_loss = read_scalar(&loss);
loss.backward();
@@ -169,6 +172,8 @@ pub fn eval_loss(
if valid.len() <= seq + 1 {
return f32::NAN;
}
// Eval mode → dropout is identity (T18).
model.eval();
let n_win = (valid.len() - 1) / seq; // disjoint windows that fit
let batches = batches.max(1).min(n_win.max(1));
let stride = (n_win / batches).max(1);