// T18 dropout model-level gates.
//
// 1. p=0 bit-identical: a model built with cfg.dropout=0 (in either train or
//    eval mode) produces logits/loss/grads bit-for-bit identical to the same
//    model with no dropout field touched — the default forward graph is
//    unchanged (the regression guard).
// 2. eval identity: with p>0 but eval mode, the forward equals the p=0 forward
//    bit-for-bit (dropout is OFF at eval).
// 3. train vs eval differ: with p>0 and train mode, the forward differs from
//    eval (dropout actually does something) and grads are still finite.
// 4. recompute compatibility: with p>0 + train + recompute, grads match the
//    non-recompute path (the counter-based seed reproduces the same mask on the
//    backward re-run — T13 stays exact even with dropout in the block).
//
// (The fixed-seed grad-check of the dropout op and the E[out]≈x / keep-rate check
// live in xtrain-autodiff/tests/autograd.rs; p>0 training convergence is the
// dash5 short run noted in docs/17-dropout.md.)
#![cfg(not(no_cuda))] use xtrain_cuda::device; use xtrain_model::{Config, TinyTransformer, batched_ids_tensor}; use xtrain_tensor::{DType, Device}; fn fill(n: usize, seed: u64, scale: f32) -> Vec { let mut state = seed .wrapping_mul(2862933555777941757) .wrapping_add(3037000493); (0..n) .map(|_| { state = state .wrapping_mul(6364136223846793005) .wrapping_add(1442695040888963407); (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale }) .collect() } fn build(cfg: Config, device: Device) -> TinyTransformer { let mut seed = 1u64; TinyTransformer::new(cfg, device, |shape| { seed = seed.wrapping_add(1); let n: usize = shape.iter().product(); if shape.len() == 1 { fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() } else { fill(n, seed, 0.08) } }) } fn host(t: &xtrain_tensor::Tensor) -> Vec { t.to_dtype(DType::F32) .to_device(Device::Cpu) .as_slice::() .to_vec() } fn tiny_cfg(dropout: f32) -> Config { let mut cfg = Config::tiny(); cfg.vocab = 16; cfg.n_layers = 4; cfg.dropout = dropout; cfg } fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) { let (batch, seq) = (3usize, 6usize); let seqs: Vec> = (0..batch) .map(|b| { (0..seq) .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32) .collect() }) .collect(); let tgts: Vec> = (0..batch) .map(|b| { (0..seq) .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32) .collect() }) .collect(); ( batched_ids_tensor(&seqs, device), batched_ids_tensor(&tgts, device), ) } fn require_gpu() -> Device { assert!(device::device_count().unwrap() > 0, "no CUDA device"); device::set_device(0).unwrap(); Device::Cuda(0) } // Run forward+backward, return (logits, loss, per-param grads). 
fn fwd_bwd( m: &TinyTransformer, ids: &xtrain_tensor::Tensor, tgt: &xtrain_tensor::Tensor, batch: usize, ) -> (Vec, f32, Vec>) { let logits = host(&m.forward_batched(ids, batch).value()); let loss = m.loss_batched(ids, tgt, batch); let loss_val = host(&loss.value())[0]; loss.backward(); let grads: Vec> = m .params() .iter() .map(|p| host(&p.grad().unwrap())) .collect(); (logits, loss_val, grads) } // --- Gate 3: p=0 is bit-identical to the no-dropout path (default graph). --- #[test] fn dropout_p0_bit_identical() { let device = require_gpu(); let batch = 3; // Reference: cfg.dropout default (0.0), never touched train/eval. let cfg0 = tiny_cfg(0.0); let (ids, tgt) = batch_data(&cfg0, device); let ref_m = build(cfg0, device); let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch); // p=0 in TRAINING mode: the seed bump is gated on p>0, the op no-ops at p==0, // so the graph must be byte-identical. let p0_train = build(tiny_cfg(0.0), device); p0_train.train(); let (lt, lst, gt) = fwd_bwd(&p0_train, &ids, &tgt, batch); assert_eq!(ref_logits, lt, "p=0 train logits not bit-identical"); assert_eq!(ref_loss, lst, "p=0 train loss not bit-identical"); for (i, (a, b)) in ref_grads.iter().zip(>).enumerate() { assert_eq!(a, b, "p=0 train grad[{i}] not bit-identical"); } println!("p=0 (train) vs no-dropout: logits/loss/grads bit-identical ✅"); } // --- Gate 2: eval is exact identity (p>0 but eval mode == p=0). --- #[test] fn dropout_eval_is_identity() { let device = require_gpu(); let batch = 3; let cfg = tiny_cfg(0.2); let (ids, tgt) = batch_data(&cfg, device); // p=0 reference and a p=0.2 model held in eval — outputs must match bit-for-bit. 
let ref_m = build(tiny_cfg(0.0), device); let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch); let eval_m = build(cfg, device); eval_m.eval(); // explicit; also the default let (el, els, eg) = fwd_bwd(&eval_m, &ids, &tgt, batch); assert_eq!(ref_logits, el, "eval (p>0) logits not identity"); assert_eq!(ref_loss, els, "eval (p>0) loss not identity"); for (i, (a, b)) in ref_grads.iter().zip(&eg).enumerate() { assert_eq!(a, b, "eval (p>0) grad[{i}] not identity"); } println!("eval (p=0.2) == no-dropout: bit-identical (eval is identity) ✅"); } // --- Gate (train vs eval differ): with p>0 + train, dropout actually fires. --- #[test] fn dropout_train_differs_from_eval() { let device = require_gpu(); let batch = 3; let cfg = tiny_cfg(0.3); let (ids, _tgt) = batch_data(&cfg, device); let m = build(cfg, device); m.eval(); let eval_logits = host(&m.forward_batched(&ids, batch).value()); m.train(); let train_logits = host(&m.forward_batched(&ids, batch).value()); let max_diff = eval_logits .iter() .zip(&train_logits) .map(|(a, b)| (a - b).abs()) .fold(0.0f32, f32::max); assert!( max_diff > 1e-4 && train_logits.iter().all(|v| v.is_finite()), "train logits should differ from eval (dropout active) and be finite; max_diff={max_diff}" ); println!("train vs eval logits max diff {max_diff:.4e} (dropout active in train) ✅"); } // --- Gate 4: p>0 + recompute grads match non-recompute (T13 stays exact). --- // The counter-based seed is a pure function of (step_seed, layer, site); the // checkpoint backward re-runs block_forward and re-derives the SAME seeds, so the // recomputed dropout masks match the forward — grads stay bit-identical. fn recompute_with_dropout(dtype: DType, grad_tol: f32) { let device = require_gpu(); let batch = 3; let cfg = tiny_cfg(0.2); let (ids, tgt) = batch_data(&cfg, device); // Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps // to 1 on the first training forward in BOTH, so they draw the same masks. 
let off = build(cfg, device) .with_compute_dtype(dtype) .with_training(true); let on = build(cfg, device) .with_compute_dtype(dtype) .with_recompute(true) .with_training(true); let off_loss = off.loss_batched(&ids, &tgt, batch); off_loss.backward(); let off_grads: Vec> = off .params() .iter() .map(|p| host(&p.grad().unwrap())) .collect(); let on_loss = on.loss_batched(&ids, &tgt, batch); on_loss.backward(); let on_grads: Vec> = on .params() .iter() .map(|p| host(&p.grad().unwrap())) .collect(); let mut max_rel = 0.0f32; for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) { max_rel = max_rel.max((a - b).abs() / a.abs().max(1e-3)); } println!("[{dtype:?}] dropout p=0.2 recompute on/off grad max rel = {max_rel:.3e}"); assert!( max_rel < grad_tol, "[{dtype:?}] recompute grads diverged with dropout: {max_rel:.3e}" ); } #[test] fn dropout_recompute_matches_fp32() { recompute_with_dropout(DType::F32, 1e-4); } #[test] fn dropout_recompute_matches_bf16() { recompute_with_dropout(DType::BF16, 5e-3); } // --- Cross-feature gate (Phase-2 integration): flash (T14) + dropout (T18) // together in the SAME model still grad-checks. Build two identical models, both // in train mode with p=0.2 (so dropout fires), one with `--flash` on, one off. // The dropout site seeds are a pure function of (step_seed, layer, site) and are // INDEPENDENT of flash, so both models draw the SAME masks on their first training // forward → the only difference is the SDPA reduction order. Assert logits/loss/ // grads match within the flash-vs-composed tolerance and are finite. This is the // orthogonality check for the two Phase-2 features landing together. #[test] fn flash_plus_dropout_grad_check_fp32() { let device = require_gpu(); let batch = 3; // seq=40 > FA_TILE=32 exercises flash's online-softmax tile-rescale path. 
let mut cfg = Config::tiny(); cfg.vocab = 16; cfg.n_layers = 4; cfg.dropout = 0.2; let seq = 40usize; let seqs: Vec> = (0..batch) .map(|b| { (0..seq) .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32) .collect() }) .collect(); let tgts: Vec> = (0..batch) .map(|b| { (0..seq) .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32) .collect() }) .collect(); let ids = batched_ids_tensor(&seqs, device); let tgt = batched_ids_tensor(&tgts, device); // Both: same init, train mode (dropout active), same step_seed progression → // identical masks; one composed SDPA, one flash SDPA. let off = build(cfg, device).with_training(true); let on = build(cfg, device).with_flash(true).with_training(true); let (off_logits, off_loss, off_grads) = fwd_bwd(&off, &ids, &tgt, batch); let (on_logits, on_loss, on_grads) = fwd_bwd(&on, &ids, &tgt, batch); assert!( on_logits.iter().all(|v| v.is_finite()) && on_grads.iter().flatten().all(|v| v.is_finite()), "flash+dropout produced non-finite logits/grads" ); let logit_rel = off_logits .iter() .zip(&on_logits) .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4)) .fold(0.0f32, f32::max); let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4); let mut grad_rel = 0.0f32; for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) { grad_rel = grad_rel.max((a - b).abs() / a.abs().max(1e-3)); } println!( "[F32] flash+dropout vs composed+dropout: loss rel {loss_rel:.2e}, \ logits max rel {logit_rel:.2e}, grad max rel {grad_rel:.3e}" ); // Same tolerances as the flash-vs-composed gate (flash.rs run_fp32): flash // differs from composed only by reduction order; dropout masks are identical. assert!( logit_rel < 1e-3, "[F32] flash+dropout logits diverged: {logit_rel:.2e}" ); assert!( loss_rel < 1e-3, "[F32] flash+dropout loss diverged: {loss_rel:.2e}" ); assert!( grad_rel < 2e-2, "[F32] flash+dropout grads diverged: {grad_rel:.3e}" ); }