Merge t18-dropout into main
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> # Conflicts: # README.md # crates/xtrain-autodiff/tests/autograd.rs # crates/xtrain-model/src/model.rs # crates/xtrain-train/src/bin/train.rs # crates/xtrain-train/src/train_loop.rs # docs/evolution.md
This commit is contained in:
@@ -140,6 +140,31 @@ pub fn swiglu(gate: &Var, up: &Var) -> Var {
|
||||
mul(&silu(gate), up)
|
||||
}
|
||||
|
||||
/// Dropout (Phase T18). With probability `p` zero each element, scale the kept
|
||||
/// ones by `1/(1-p)` (inverted dropout — `E[out] == x`). The keep/drop mask is
|
||||
/// drawn by a counter-based RNG from `(seed, element index)`, so it is fully
|
||||
/// determined by `seed` (same `seed` ⇒ same mask: stable across the T13 recompute
|
||||
/// re-run, and held fixed across the ± perturbation of a finite-diff grad-check).
|
||||
/// Forward caches the per-element scale `mask`; **backward applies the same mask**
|
||||
/// (`dx = d ⊙ mask`), making dropout a fixed elementwise linear map of `x`.
|
||||
///
|
||||
/// `p == 0` is a no-op: returns `x.clone()` (no node added) so the default graph
|
||||
/// is bit-identical to the no-dropout path. eval-time identity is handled by the
|
||||
/// caller simply not invoking dropout (the model's train/eval switch).
|
||||
pub fn dropout(x: &Var, p: f32, seed: u64) -> Var {
|
||||
if p == 0.0 {
|
||||
return x.clone();
|
||||
}
|
||||
let (out, mask) = x.value().dropout(p, seed);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
Var::push_grad(&parents[0], Tensor::dropout_backward(d, &mask));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]` with per-sequence position
|
||||
/// `row % period` (`period` = sequence length; `period == tokens` for a single
|
||||
/// sequence). Orthogonal map, so the backward is the inverse rotation of `dy` — no
|
||||
|
||||
@@ -776,6 +776,96 @@ fn flash_bwd_matches_composed_bwd() {
|
||||
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
|
||||
}
|
||||
|
||||
// ---- dropout (Phase T18) ----
|
||||
//
|
||||
// Fixed-seed finite-diff grad-check. Under a fixed `seed` the mask is constant
|
||||
// (it depends only on (seed, index), NOT on x), so dropout is a fixed elementwise
|
||||
// linear map `out_i = c_i·x_i` and the central difference of L is differentiable:
|
||||
// the ± perturbation of each x_i sees the SAME mask. The forward function in the
|
||||
// closure calls `ops::dropout(x, p, SEED)` with the same SEED, so it reproduces
|
||||
// the same mask both times.
|
||||
/// Fixed-seed finite-difference gradient check for `ops::dropout`.
///
/// The analytic gradient comes from one forward/backward through the tape; the
/// numeric reference re-runs the forward with the same `SEED` for every
/// perturbation, so both sides see the same mask.
#[test]
fn dropout_bwd() {
    require_gpu();
    const SEED: u64 = 0xD120_FE5E;
    let p = 0.3f32;
    let (rows, cols) = (16, 12);
    let x_host = fill(rows * cols, 71);
    let weights = fill(rows * cols, 72);

    // Analytic dX via the autodiff tape.
    let x = Var::leaf(cuda(&x_host, &[rows, cols]));
    let dropped = ops::dropout(&x, p, SEED);
    scalar_loss(&dropped, &weights).backward();
    let analytic_dx = x.grad().unwrap().to_device(Device::Cpu);

    // Scalar loss as a plain function of the host buffer, for the numeric side.
    let loss_weights = weights.clone();
    let loss_of_x = move |values: &[f32], shape: &[usize]| {
        let leaf = Var::leaf(cuda(values, shape));
        let out = ops::dropout(&leaf, p, SEED);
        weighted_sum(&out.value(), &loss_weights)
    };

    let outcome = grad_check(
        &x_host,
        &[rows, cols],
        &loss_of_x,
        analytic_dx.as_slice::<f32>(),
        cfg_linear(),
    );
    report("dropout dX", &outcome);
}
|
||||
|
||||
// Inverted-dropout expectation + keep-rate check. Over a large tensor and a sweep
|
||||
// of seeds, the mean of dropout(x) tracks the mean of x (E[out] ≈ x, the inverted
|
||||
// 1/(1-p) scaling), and the kept fraction tracks 1-p (the RNG is ~Bernoulli).
|
||||
/// Statistical sanity check of the dropout kernel: averaged over several seeds,
/// the output mean of an all-ones tensor stays near 1 (inverted scaling) and the
/// fraction of non-zero mask entries stays near `1 - p`.
#[test]
fn dropout_expectation_and_keep_rate() {
    require_gpu();
    let p = 0.25f32;
    let n = 200_000usize;
    // All-ones input: mean(x) = 1, so mean(out) should land near 1.
    let ones = vec![1.0f32; n];
    let x = cuda(&ones, &[n]);

    let trials = 8;
    // Accumulate (sum of per-trial output means, sum of per-trial keep fractions).
    let (mean_sum, keep_sum) =
        (0..trials).fold((0.0f64, 0.0f64), |(mean_acc, keep_acc), t| {
            let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64);
            let out_host = out.to_device(Device::Cpu);
            let mask_host = mask.to_device(Device::Cpu);
            let trial_mean: f64 = out_host
                .as_slice::<f32>()
                .iter()
                .map(|&v| v as f64)
                .sum::<f64>()
                / n as f64;
            // A kept element carries a non-zero scale in the mask.
            let kept = mask_host
                .as_slice::<f32>()
                .iter()
                .filter(|&&m| m != 0.0)
                .count();
            (mean_acc + trial_mean, keep_acc + kept as f64 / n as f64)
        });
    let mean_out = mean_sum / trials as f64;
    let keep_rate = keep_sum / trials as f64;

    println!(
        "dropout p={p}: E[out]={mean_out:.5} (input mean 1.0), keep_rate={keep_rate:.5} (1-p={:.3})",
        1.0 - p
    );
    let mean_err = (mean_out - 1.0).abs();
    assert!(
        mean_err < 0.01,
        "E[out] {mean_out} not ≈ input mean 1.0 (inverted scaling broken)"
    );
    let keep_err = (keep_rate - (1.0 - p) as f64).abs();
    assert!(
        keep_err < 0.01,
        "keep_rate {keep_rate} not ≈ 1-p {}",
        1.0 - p
    );
}
|
||||
|
||||
// p=0 is identity at the kernel level: `Tensor::dropout(0.0, seed)` keeps every
// element (u >= 0 always holds) with scale 1/(1-0) = 1, so the output is
// bit-identical to x. This is the op-level regression guard; `ops::dropout`
// additionally short-circuits p==0 by returning x.clone() with no node, and the
// model-level bit-identity check is in xtrain-model/tests/dropout.rs.
|
||||
/// With `p = 0` the tensor-level dropout must return its input unchanged,
/// element for element.
#[test]
fn dropout_p0_is_identity() {
    require_gpu();
    let (rows, cols) = (8, 5);
    let x_host = fill(rows * cols, 91);
    let x = cuda(&x_host, &[rows, cols]);

    // The mask is irrelevant here; only the forward output is compared.
    let (out, _mask) = x.dropout(0.0, 12345);
    let out_host = out.to_device(Device::Cpu);

    x_host
        .iter()
        .zip(out_host.as_slice::<f32>())
        .for_each(|(expected, actual)| {
            assert_eq!(*expected, *actual, "p=0 dropout must be identity");
        });
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||||
|
||||
@@ -38,6 +38,7 @@ fn main() {
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.file("../../csrc/ops/flash_attention.cu")
|
||||
.file("../../csrc/ops/cast.cu")
|
||||
.file("../../csrc/ops/dropout.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
|
||||
@@ -500,3 +500,48 @@ unsafe extern "C" {
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// Dropout (Phase T18, csrc/ops/dropout.cu). A counter-based (stateless) RNG: the
|
||||
// keep/drop decision for element `i` is `hash(seed, i)` — no global state, so a
|
||||
// re-run with the same `seed` reproduces the same mask (compatible with T13
|
||||
// activation recomputation). Forward writes `out = x ⊙ mask` and the fp32 `mask`
|
||||
// buffer (mask[i] = (1/(1-p)) if kept else 0, the inverted-dropout scale);
|
||||
// backward applies the SAME mask: dx = d ⊙ mask. fp32 + bf16 activation variants
|
||||
// (mask is fp32 in both; the uniform is computed in fp32, dtype-independent).
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
pub fn launch_dropout_fwd_f32(
|
||||
x: *const f32,
|
||||
out: *mut f32,
|
||||
mask: *mut f32,
|
||||
p: f32,
|
||||
scale: f32,
|
||||
seed: u64,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_bwd_f32(
|
||||
d: *const f32,
|
||||
mask: *const f32,
|
||||
dx: *mut f32,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_fwd_bf16(
|
||||
x: *const c_void,
|
||||
out: *mut c_void,
|
||||
mask: *mut f32,
|
||||
p: f32,
|
||||
scale: f32,
|
||||
seed: u64,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
pub fn launch_dropout_bwd_bf16(
|
||||
d: *const c_void,
|
||||
mask: *const f32,
|
||||
dx: *mut c_void,
|
||||
n: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,11 @@ pub struct Config {
|
||||
pub eps: f32,
|
||||
/// RoPE base frequency (theta).
|
||||
pub rope_theta: f32,
|
||||
/// Dropout probability `p` (Phase T18). Applied at the attention/MLP sub-block
|
||||
/// outputs (before each residual add) at TRAINING time, with inverted scaling
|
||||
/// `1/(1-p)`; disabled (identity) at eval. Default `0.0` = no dropout, and the
|
||||
/// forward graph is then bit-identical to the pre-T18 path.
|
||||
pub dropout: f32,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -36,6 +41,7 @@ impl Config {
|
||||
ffn_hidden: 64,
|
||||
eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
dropout: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +66,7 @@ impl Config {
|
||||
ffn_hidden,
|
||||
eps: 1e-5,
|
||||
rope_theta: 10000.0,
|
||||
dropout: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::cell::Cell;
|
||||
|
||||
use crate::config::Config;
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
@@ -55,6 +57,19 @@ pub struct TinyTransformer {
|
||||
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
|
||||
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
|
||||
use_flash: bool,
|
||||
/// Training mode for dropout (Phase T18). `true` → the attn/MLP sub-block
|
||||
/// outputs pass through `ops::dropout` (with `cfg.dropout` and a per-step,
|
||||
/// per-site seed); `false` (default) → dropout is identity (eval/sampling/
|
||||
/// export). `Cell` so `train()`/`eval()` flip it through `&self` (the forward
|
||||
/// takes `&self`). When `cfg.dropout == 0` this flag is irrelevant — the graph
|
||||
/// is bit-identical to the no-dropout path either way.
|
||||
training: Cell<bool>,
|
||||
/// Per-step dropout RNG seed (Phase T18). Bumped once at the start of each
|
||||
/// TRAINING forward so every step draws fresh masks; combined with the layer
|
||||
/// index + a per-site constant to give each dropout site its own seed. The RNG
|
||||
/// is counter-based, so re-running a checkpointed block's forward in backward
|
||||
/// (T13) reproduces the same seed → the same mask (recompute stays exact).
|
||||
step_seed: Cell<u64>,
|
||||
}
|
||||
|
||||
impl TinyTransformer {
|
||||
@@ -99,6 +114,8 @@ impl TinyTransformer {
|
||||
compute_dtype: DType::F32,
|
||||
recompute: false,
|
||||
use_flash: false,
|
||||
training: Cell::new(false),
|
||||
step_seed: Cell::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,6 +166,30 @@ impl TinyTransformer {
|
||||
self.use_flash
|
||||
}
|
||||
|
||||
/// Switch to training mode (Phase T18): dropout (if `cfg.dropout > 0`) is
|
||||
/// active in subsequent forwards. The training loop calls this before stepping.
|
||||
pub fn train(&self) {
|
||||
self.training.set(true);
|
||||
}
|
||||
|
||||
/// Switch to eval mode (Phase T18): dropout is identity. Held-out eval,
|
||||
/// autoregressive sampling, and weight export all run in this mode (default).
|
||||
pub fn eval(&self) {
|
||||
self.training.set(false);
|
||||
}
|
||||
|
||||
pub fn is_training(&self) -> bool {
|
||||
self.training.get()
|
||||
}
|
||||
|
||||
/// Builder-style train/eval toggle (Phase T18) — handy for tests that want a
|
||||
/// model fixed in one mode. Equivalent to [`train`](Self::train) /
|
||||
/// [`eval`](Self::eval) but chains off `new(..)`.
|
||||
pub fn with_training(self, training: bool) -> Self {
|
||||
self.training.set(training);
|
||||
self
|
||||
}
|
||||
|
||||
/// All learnable parameters, in a stable order. The optimizer (a hand-written
|
||||
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
|
||||
/// `backward()`.
|
||||
@@ -198,23 +239,47 @@ impl TinyTransformer {
|
||||
);
|
||||
let seq = total / batch;
|
||||
|
||||
// Dropout (T18) is active only in training mode with p>0; otherwise it is
|
||||
// identity (`ops::dropout` no-ops at p==0). Bump the per-step seed ONCE per
|
||||
// training forward so each step draws fresh masks (counter-based RNG, so a
|
||||
// checkpointed block's recompute reproduces the same seed → same mask).
|
||||
let dropout_p = if self.training.get() {
|
||||
self.cfg.dropout
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
if dropout_p > 0.0 {
|
||||
self.step_seed.set(self.step_seed.get().wrapping_add(1));
|
||||
}
|
||||
let base_seed = self.step_seed.get();
|
||||
|
||||
// Embedding gathers from the fp32 master table; in bf16 mode cast the
|
||||
// activation stream to bf16 here (norms are cast to bf16 gammas too).
|
||||
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
|
||||
if self.compute_dtype == DType::BF16 {
|
||||
h = ops::cast(&h, DType::BF16);
|
||||
}
|
||||
for b in &self.blocks {
|
||||
for (li, b) in self.blocks.iter().enumerate() {
|
||||
// Per-layer dropout seed: a deterministic function of (base_seed,
|
||||
// layer index) — NOT a mutable counter — so the checkpoint recompute
|
||||
// (which re-derives it from the captured base_seed/li) gets the same
|
||||
// masks. The block derives its two per-site seeds from this.
|
||||
let block_seed = base_seed
|
||||
.wrapping_mul(0x100000001B3)
|
||||
.wrapping_add(li as u64);
|
||||
h = if self.recompute {
|
||||
// Activation recomputation (T13): run the whole block forward inside
|
||||
// `checkpoint` so its internal activations aren't kept on the tape;
|
||||
// the block forward is re-run in backward to recover the grads. The
|
||||
// segment fn captures only `Copy` config (no borrow of `self`) and
|
||||
// receives the block's params via the slice, in `block_params` order.
|
||||
// `flash` is captured too → the recompute segment also runs flash.
|
||||
// `flash` is captured too → the recompute segment also runs flash;
|
||||
// `dropout_p`/`block_seed` are captured so the recompute re-derives
|
||||
// the same per-site dropout masks (counter-based RNG, exact).
|
||||
let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
|
||||
let seg =
|
||||
move |x: &Var, p: &[Var]| block_forward(cfg, cdt, flash, batch, seq, x, p);
|
||||
let seg = move |x: &Var, p: &[Var]| {
|
||||
block_forward(cfg, cdt, flash, batch, seq, dropout_p, block_seed, x, p)
|
||||
};
|
||||
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
|
||||
} else {
|
||||
block_forward(
|
||||
@@ -223,6 +288,8 @@ impl TinyTransformer {
|
||||
self.use_flash,
|
||||
batch,
|
||||
seq,
|
||||
dropout_p,
|
||||
block_seed,
|
||||
&h,
|
||||
&b.block_params(),
|
||||
)
|
||||
@@ -300,10 +367,15 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
|
||||
}
|
||||
|
||||
/// One transformer block's forward: pre-norm + multi-head causal attention +
|
||||
/// residual, then pre-norm + SwiGLU MLP + residual. Pure in `(cfg, cdt, batch,
|
||||
/// seq, input, params)` (no `&self`) so it can be the segment fn of
|
||||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
|
||||
/// the block's leaves in [`Block::block_params`] order.
|
||||
/// (T18) dropout + residual, then pre-norm + SwiGLU MLP + dropout + residual.
|
||||
/// Attention runs the composed or fused-flash (T14) SDPA per `flash`. Pure in
|
||||
/// `(cfg, cdt, flash, batch, seq, dropout_p, block_seed, input, params)` (no
|
||||
/// `&self`, all `Copy`) so it can be the segment fn of
|
||||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13) — the
|
||||
/// recompute re-derives the same per-site seeds, so the dropout masks are
|
||||
/// reproduced bit-for-bit. `dropout_p == 0` makes `ops::dropout` a no-op (the
|
||||
/// graph is then identical to the pre-T18 path). `params` is the block's leaves in
|
||||
/// [`Block::block_params`] order.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn block_forward(
|
||||
cfg: Config,
|
||||
@@ -311,6 +383,8 @@ fn block_forward(
|
||||
flash: bool,
|
||||
batch: usize,
|
||||
seq: usize,
|
||||
dropout_p: f32,
|
||||
block_seed: u64,
|
||||
h: &Var,
|
||||
p: &[Var],
|
||||
) -> Var {
|
||||
@@ -318,16 +392,23 @@ fn block_forward(
|
||||
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
|
||||
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
|
||||
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
// Per-site dropout seeds (XOR a site constant into the block seed) so the two
|
||||
// residual-path dropouts draw independent masks within the same step/layer.
|
||||
let attn_seed = block_seed ^ 0x0A7700;
|
||||
let ffn_seed = block_seed ^ 0x0FF700;
|
||||
|
||||
// --- Attention sub-block (pre-norm + dropout + residual) ---
|
||||
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
|
||||
let attn = attention(
|
||||
cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||||
);
|
||||
let attn = ops::dropout(&attn, dropout_p, attn_seed);
|
||||
let h = ops::add(h, &attn);
|
||||
|
||||
// --- MLP sub-block (pre-norm + residual) ---
|
||||
// --- MLP sub-block (pre-norm + dropout + residual) ---
|
||||
let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps);
|
||||
let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down);
|
||||
let mlp = ops::dropout(&mlp, dropout_p, ffn_seed);
|
||||
ops::add(&h, &mlp)
|
||||
}
|
||||
|
||||
|
||||
222
crates/xtrain-model/tests/dropout.rs
Normal file
222
crates/xtrain-model/tests/dropout.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
// T18 dropout model-level gates.
|
||||
//
|
||||
// 1. p=0 bit-identical: a model built with cfg.dropout=0 (in either train or
|
||||
// eval mode) produces logits/loss/grads bit-for-bit identical to the same
|
||||
// model with no dropout field touched — the default forward graph is
|
||||
// unchanged (the regression guard).
|
||||
// 2. eval identity: with p>0 but eval mode, the forward equals the p=0 forward
|
||||
// bit-for-bit (dropout is OFF at eval).
|
||||
// 3. train vs eval differ: with p>0 and train mode, the forward differs from
|
||||
// eval (dropout actually does something) and grads are still finite.
|
||||
// 4. recompute compatibility: with p>0 + train + recompute, grads match the
|
||||
// non-recompute path (the counter-based seed reproduces the same mask on the
|
||||
// backward re-run — T13 stays exact even with dropout in the block).
|
||||
//
|
||||
// (The fixed-seed grad-check of the dropout op and the E[out]≈x / keep-rate check
|
||||
// live in xtrain-autodiff/tests/autograd.rs; p>0 training convergence is the
|
||||
// dash5 short run noted in docs/17-dropout.md.)
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
/// Deterministic pseudo-random test data: `n` values of magnitude up to roughly
/// `scale`, fully determined by `seed`.
///
/// A 64-bit linear congruential generator is stepped once per element; the top 31
/// bits of the state are mapped to a unit-interval float, then shifted and scaled
/// to be centred on zero.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    // Scramble the seed once so nearby seeds start from distant states.
    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Top 31 bits divided by 2^31 gives the unit-interval sample.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
fn tiny_cfg(dropout: f32) -> Config {
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
cfg.dropout = dropout;
|
||||
cfg
|
||||
}
|
||||
|
||||
fn batch_data(cfg: &Config, device: Device) -> (xtrain_tensor::Tensor, xtrain_tensor::Tensor) {
|
||||
let (batch, seq) = (3usize, 6usize);
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| (0..seq).map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32).collect())
|
||||
.collect();
|
||||
(
|
||||
batched_ids_tensor(&seqs, device),
|
||||
batched_ids_tensor(&tgts, device),
|
||||
)
|
||||
}
|
||||
|
||||
fn require_gpu() -> Device {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
Device::Cuda(0)
|
||||
}
|
||||
|
||||
// Run forward+backward, return (logits, loss, per-param grads).
|
||||
fn fwd_bwd(
|
||||
m: &TinyTransformer,
|
||||
ids: &xtrain_tensor::Tensor,
|
||||
tgt: &xtrain_tensor::Tensor,
|
||||
batch: usize,
|
||||
) -> (Vec<f32>, f32, Vec<Vec<f32>>) {
|
||||
let logits = host(&m.forward_batched(ids, batch).value());
|
||||
let loss = m.loss_batched(ids, tgt, batch);
|
||||
let loss_val = host(&loss.value())[0];
|
||||
loss.backward();
|
||||
let grads: Vec<Vec<f32>> = m.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
(logits, loss_val, grads)
|
||||
}
|
||||
|
||||
// --- Gate 1: p=0 is bit-identical to the no-dropout path (default graph). ---
/// A `cfg.dropout == 0` model switched to training mode must produce exactly the
/// same logits, loss, and parameter grads as an identically initialised model
/// whose train/eval switch was never touched.
#[test]
fn dropout_p0_bit_identical() {
    let device = require_gpu();
    let batch = 3;

    // Reference: cfg.dropout default (0.0), never touched train/eval.
    let cfg0 = tiny_cfg(0.0);
    let (ids, tgt) = batch_data(&cfg0, device);
    let ref_m = build(cfg0, device);
    let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&ref_m, &ids, &tgt, batch);

    // p=0 in TRAINING mode: the seed bump is gated on p>0, the op no-ops at p==0,
    // so the graph must be byte-identical.
    let p0_train = build(tiny_cfg(0.0), device);
    p0_train.train();
    let (lt, lst, gt) = fwd_bwd(&p0_train, &ids, &tgt, batch);

    assert_eq!(ref_logits, lt, "p=0 train logits not bit-identical");
    assert_eq!(ref_loss, lst, "p=0 train loss not bit-identical");
    // Compare grads parameter by parameter so a failure names the offending index.
    for (i, (a, b)) in ref_grads.iter().zip(&gt).enumerate() {
        assert_eq!(a, b, "p=0 train grad[{i}] not bit-identical");
    }
    println!("p=0 (train) vs no-dropout: logits/loss/grads bit-identical ✅");
}
|
||||
|
||||
// --- Gate 2: eval is exact identity (p>0 but eval mode == p=0). ---
|
||||
/// A model configured with `p = 0.2` but held in eval mode must match a `p = 0`
/// model bit-for-bit in logits, loss, and parameter grads.
#[test]
fn dropout_eval_is_identity() {
    let device = require_gpu();
    let batch = 3;
    let cfg = tiny_cfg(0.2);
    let (ids, tgt) = batch_data(&cfg, device);

    // Baseline: the same architecture with dropout disabled in the config.
    let baseline = build(tiny_cfg(0.0), device);
    let (ref_logits, ref_loss, ref_grads) = fwd_bwd(&baseline, &ids, &tgt, batch);

    // Candidate: p = 0.2, explicitly placed in eval mode (which is also the default).
    let candidate = build(cfg, device);
    candidate.eval();
    let (eval_logits, eval_loss, eval_grads) = fwd_bwd(&candidate, &ids, &tgt, batch);

    assert_eq!(ref_logits, eval_logits, "eval (p>0) logits not identity");
    assert_eq!(ref_loss, eval_loss, "eval (p>0) loss not identity");
    ref_grads
        .iter()
        .zip(&eval_grads)
        .enumerate()
        .for_each(|(i, (a, b))| {
            assert_eq!(a, b, "eval (p>0) grad[{i}] not identity");
        });
    println!("eval (p=0.2) == no-dropout: bit-identical (eval is identity) ✅");
}
|
||||
|
||||
// --- Gate 3 (train vs eval differ): with p>0 + train, dropout actually fires. ---
|
||||
/// With `p = 0.3`, the same model must give different logits in training mode
/// than in eval mode, and the training-mode logits must all be finite.
#[test]
fn dropout_train_differs_from_eval() {
    let device = require_gpu();
    let batch = 3;
    let cfg = tiny_cfg(0.3);
    let (ids, _tgt) = batch_data(&cfg, device);

    let model = build(cfg, device);

    // Eval forward first, then flip the same model to training mode.
    model.eval();
    let eval_logits = host(&model.forward_batched(&ids, batch).value());
    model.train();
    let train_logits = host(&model.forward_batched(&ids, batch).value());

    // Largest absolute elementwise gap between the two runs.
    let mut max_diff = 0.0f32;
    for (e, t) in eval_logits.iter().zip(&train_logits) {
        max_diff = max_diff.max((e - t).abs());
    }

    assert!(
        max_diff > 1e-4 && train_logits.iter().all(|v| v.is_finite()),
        "train logits should differ from eval (dropout active) and be finite; max_diff={max_diff}"
    );
    println!("train vs eval logits max diff {max_diff:.4e} (dropout active in train) ✅");
}
|
||||
|
||||
// --- Gate 4: p>0 + recompute grads match non-recompute (T13 stays exact). ---
|
||||
// The counter-based seed is a pure function of (step_seed, layer, site); the
|
||||
// checkpoint backward re-runs block_forward and re-derives the SAME seeds, so the
|
||||
// recomputed dropout masks match the forward — grads agree within `grad_tol`.
|
||||
fn recompute_with_dropout(dtype: DType, grad_tol: f32) {
|
||||
let device = require_gpu();
|
||||
let batch = 3;
|
||||
let cfg = tiny_cfg(0.2);
|
||||
let (ids, tgt) = batch_data(&cfg, device);
|
||||
|
||||
// Both models: same init, train mode, p=0.2. step_seed starts at 0 and bumps
|
||||
// to 1 on the first training forward in BOTH, so they draw the same masks.
|
||||
let off = build(cfg, device).with_compute_dtype(dtype).with_training(true);
|
||||
let on = build(cfg, device)
|
||||
.with_compute_dtype(dtype)
|
||||
.with_recompute(true)
|
||||
.with_training(true);
|
||||
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on.params().iter().map(|p| host(&p.grad().unwrap())).collect();
|
||||
|
||||
let mut max_rel = 0.0f32;
|
||||
for (a, b) in off_grads.iter().flatten().zip(on_grads.iter().flatten()) {
|
||||
max_rel = max_rel.max((a - b).abs() / a.abs().max(1e-3));
|
||||
}
|
||||
println!("[{dtype:?}] dropout p=0.2 recompute on/off grad max rel = {max_rel:.3e}");
|
||||
assert!(
|
||||
max_rel < grad_tol,
|
||||
"[{dtype:?}] recompute grads diverged with dropout: {max_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
/// fp32 recompute-vs-plain gradient agreement with dropout active.
#[test]
fn dropout_recompute_matches_fp32() {
    let (dtype, grad_tol) = (DType::F32, 1e-4);
    recompute_with_dropout(dtype, grad_tol);
}
|
||||
|
||||
/// bf16 recompute-vs-plain gradient agreement with dropout active (looser
/// tolerance than fp32).
#[test]
fn dropout_recompute_matches_bf16() {
    let (dtype, grad_tol) = (DType::BF16, 5e-3);
    recompute_with_dropout(dtype, grad_tol);
}
|
||||
@@ -668,6 +668,92 @@ impl Tensor {
|
||||
dx
|
||||
}
|
||||
|
||||
/// Dropout forward (Phase T18). Returns `(out, mask)` where, for each element
|
||||
/// `i`, a counter-based RNG draws `u = hash(seed, i) ∈ [0,1)` and keeps the
|
||||
/// element iff `u >= p`; kept elements are scaled by `1/(1-p)` (inverted
|
||||
/// dropout, so `E[out] == x`). `mask[i]` stores that per-element factor
|
||||
/// (`1/(1-p)` if kept, else `0`) for the backward to reuse — the same mask, so
|
||||
/// the op is a fixed elementwise scale w.r.t. `x` (and finite-diff-checkable).
|
||||
///
|
||||
/// The mask depends only on `(seed, i)`, NOT on `self`'s values, so a re-run
|
||||
/// with the same `seed` reproduces the same mask (T13 recompute stays exact).
|
||||
/// `mask` is always fp32 (the uniform is computed in fp32, dtype-independent);
|
||||
/// `out` matches `self`'s dtype. Requires `0 <= p < 1`.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn dropout(&self, p: f32, seed: u64) -> (Self, Self) {
|
||||
assert!(
|
||||
matches!(self.dtype, DType::F32 | DType::BF16),
|
||||
"dropout supports F32/BF16"
|
||||
);
|
||||
assert!((0.0..1.0).contains(&p), "dropout p must be in [0,1)");
|
||||
assert!(self.is_contiguous(), "dropout requires contiguous tensor");
|
||||
let scale = 1.0 / (1.0 - p);
|
||||
let out = Tensor::zeros(&self.shape, self.dtype, self.device());
|
||||
let mask = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||
let n = self.numel() as i32;
|
||||
match self.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_fwd_f32(
|
||||
self.data_ptr() as *const f32,
|
||||
out.data_ptr() as *mut f32,
|
||||
mask.data_ptr() as *mut f32,
|
||||
p,
|
||||
scale,
|
||||
seed,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_fwd_bf16(
|
||||
self.data_ptr() as *const std::ffi::c_void,
|
||||
out.data_ptr() as *mut std::ffi::c_void,
|
||||
mask.data_ptr() as *mut f32,
|
||||
p,
|
||||
scale,
|
||||
seed,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
(out, mask)
|
||||
}
|
||||
|
||||
/// Dropout backward: `dx = d ⊙ mask` (the SAME `mask` the forward cached).
|
||||
/// `d` is the upstream grad (activation dtype); `mask` is the fp32 factor
|
||||
/// tensor from [`Self::dropout`]. Output matches `d`'s dtype.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn dropout_backward(d: &Tensor, mask: &Tensor) -> Self {
|
||||
assert_eq!(d.numel(), mask.numel(), "dropout_backward shape mismatch");
|
||||
assert_eq!(mask.dtype, DType::F32, "dropout mask must be F32");
|
||||
let dx = Tensor::zeros(&d.shape, d.dtype, d.device());
|
||||
let n = d.numel() as i32;
|
||||
match d.dtype {
|
||||
DType::F32 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_bwd_f32(
|
||||
d.data_ptr() as *const f32,
|
||||
mask.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut f32,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
DType::BF16 => unsafe {
|
||||
xtrain_cuda::ffi::launch_dropout_bwd_bf16(
|
||||
d.data_ptr() as *const std::ffi::c_void,
|
||||
mask.data_ptr() as *const f32,
|
||||
dx.data_ptr() as *mut std::ffi::c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
},
|
||||
_ => panic!("dropout_backward supports F32/BF16"),
|
||||
}
|
||||
dx
|
||||
}
|
||||
|
||||
/// RoPE forward (rotate_half). `self`:[tokens,heads,head_dim]; each token's
|
||||
/// position is `row % period`. `period` = sequence length, so a flattened
|
||||
/// batch `[B*S,heads,head_dim]` gets per-sequence positions (pass `period=S`);
|
||||
|
||||
@@ -113,6 +113,10 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
// Dropout (Phase T18): residual-path dropout prob, active at training time
|
||||
// only (inverted scaling), identity at eval/sampling/export. Default 0 = off
|
||||
// (forward graph bit-identical to the no-dropout path).
|
||||
let dropout: f32 = flag(&args, "--dropout", 0.0f32);
|
||||
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
|
||||
// activations. Opt-in; default fp32 reproduces v0–v4 numerics.
|
||||
let bf16 = args.iter().any(|a| a == "--bf16");
|
||||
@@ -156,7 +160,8 @@ fn main() {
|
||||
(corpus, None)
|
||||
};
|
||||
|
||||
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
let mut cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
|
||||
cfg.dropout = dropout;
|
||||
println!(
|
||||
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
|
||||
(+ embed/lm {:.2}M = {:.2}M total)",
|
||||
@@ -194,6 +199,9 @@ fn main() {
|
||||
model = model.with_flash(true);
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
if dropout > 0.0 {
|
||||
println!("dropout: ON (p={dropout}, residual-path, train-only inverted scaling)");
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
|
||||
@@ -92,6 +92,10 @@ pub fn train(
|
||||
// is scaled by 1/accum before backward (the tape SUM-accumulates the scaled
|
||||
// grads). `accum == 1` skips the scale entirely → bit-identical to pre-T16.
|
||||
let mut step_loss_sum = 0.0f32;
|
||||
// Training mode → dropout active (T18; no-op when cfg.dropout == 0). Set
|
||||
// each step so it is restored after a periodic eval flips to eval mode.
|
||||
// Each micro-step's forward bumps the per-step seed → fresh masks.
|
||||
model.train();
|
||||
for _ in 0..accum {
|
||||
let mut inputs = Vec::with_capacity(cfg.batch_size);
|
||||
let mut targets_v = Vec::with_capacity(cfg.batch_size);
|
||||
@@ -190,6 +194,8 @@ pub fn eval_loss(
|
||||
if valid.len() <= seq + 1 {
|
||||
return f32::NAN;
|
||||
}
|
||||
// Eval mode → dropout is identity (T18).
|
||||
model.eval();
|
||||
let n_win = (valid.len() - 1) / seq; // disjoint windows that fit
|
||||
let batches = batches.max(1).min(n_win.max(1));
|
||||
let stride = (n_win / batches).max(1);
|
||||
|
||||
Reference in New Issue
Block a user