From f202351be59d9d0ae829ae6d1e67f277100ae7d2 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Wed, 17 Jun 2026 09:42:42 +0800 Subject: [PATCH] model: per-block activation recompute (--recompute) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wrap each transformer block's forward in the checkpoint primitive when recompute is enabled (Phase T13 / KI-3). To make the block forward a pure segment fn (no `&self` borrow, so it can re-run in the backward closure), extract the block body + its helpers (linear / norm_gamma / attention / swiglu_mlp) into free functions parameterised by (cfg, compute_dtype) and add `Block::block_params()` (the 11 leaves in the params() per-block order). The non-recompute path calls `block_forward` directly — identical graph to before. - `TinyTransformer::with_recompute(bool)` builder (opt-in; default off keeps the unchanged tape / bit-identical numerics). - `--recompute` flag wired into bin/train and bin/train_ddp (DDP: each rank checkpoints independently). Correctness gate: tests/recompute.rs builds two identical models (recompute on/off), runs the same batched loss+backward, and asserts the forward logits, the loss, and EVERY parameter grad match within tight fp tol — parameterised over fp32 and bf16 (T12 composition). Co-Authored-By: Claude Opus 4.8 --- .../xtrain-distributed/src/bin/train_ddp.rs | 17 +- crates/xtrain-model/src/model.rs | 266 ++++++++++++------ crates/xtrain-model/tests/recompute.rs | 156 ++++++++++ crates/xtrain-train/src/bin/train.rs | 8 + 4 files changed, 357 insertions(+), 90 deletions(-) create mode 100644 crates/xtrain-model/tests/recompute.rs diff --git a/crates/xtrain-distributed/src/bin/train_ddp.rs b/crates/xtrain-distributed/src/bin/train_ddp.rs index 2f362c5..0d0df9b 100644 --- a/crates/xtrain-distributed/src/bin/train_ddp.rs +++ b/crates/xtrain-distributed/src/bin/train_ddp.rs @@ -85,6 +85,10 @@ fn main() { // bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears + // activations. Opt-in; default fp32 reproduces v0–v4 numerics. let bf16 = args.iter().any(|a| a == "--bf16"); + // Activation recomputation (Phase T13): per-block gradient checkpointing — each + // rank checkpoints its own forward/backward; exact grads, lower peak activation + // memory (lets dim1024 batch32 fit). Opt-in; default off. + let recompute = args.iter().any(|a| a == "--recompute"); let ckpt: Option = args .iter() .position(|a| a == "--ckpt") @@ -167,18 +171,23 @@ fn main() { if bf16 { println!("bf16 mixed precision: ON (fp32 master weights)"); } + if recompute { + println!("activation recompute: ON (per-block gradient checkpointing)"); + } let results = launch( &devices, &train_corpus, valid.as_ref(), &dcfg, move |device| { - let m = build_model(cfg, device); + let mut m = build_model(cfg, device); if bf16 { - m.with_compute_dtype(xtrain_tensor::DType::BF16) - } else { - m + m = m.with_compute_dtype(xtrain_tensor::DType::BF16); } + if recompute { + m = m.with_recompute(true); + } + m }, ); let r0 = &results[0]; diff --git a/crates/xtrain-model/src/model.rs b/crates/xtrain-model/src/model.rs index fbf0499..830f068 100644 --- a/crates/xtrain-model/src/model.rs +++ b/crates/xtrain-model/src/model.rs @@ -37,6 +37,16 @@ pub struct TinyTransformer { /// `docs/11-bf16-mixed-precision.md`). The cast op's backward upcasts the bf16 /// weight grad back to fp32, so AdamW/clip/DDP stay fp32 and unchanged. compute_dtype: DType, + /// Activation recomputation / gradient checkpointing (Phase T13, KI-3). When + /// `true`, each transformer block's forward runs through + /// [`xtrain_autodiff::checkpoint`]: the block's internal activations are NOT + /// kept on the tape during forward (only the block input is), and the block + /// forward is re-run during backward to recover them. Trades ~one extra forward + /// per block for a large drop in peak activation memory → lets dim1024 batch32 + /// fit. Default `false` = the unchanged path (every activation stored), so + /// existing numerics are bit-identical; recompute is mathematically exact, so + /// grads match the non-checkpointed path within fp tolerance. + recompute: bool, } impl TinyTransformer { @@ -79,6 +89,7 @@ impl TinyTransformer { final_norm, lm_head, compute_dtype: DType::F32, + recompute: false, } } @@ -103,16 +114,17 @@ impl TinyTransformer { self.compute_dtype } - /// Project `x` (activation, in the compute dtype) by weight `w` (an fp32 - /// master leaf). In bf16 mode the weight is cast to bf16 via the autograd - /// `cast` op (whose backward upcasts the grad to fp32); in fp32 mode this is - /// just `matmul(x, w)`. The activation `x` already carries `compute_dtype`. - fn linear(&self, x: &Var, w: &Var) -> Var { - match self.compute_dtype { - DType::F32 => ops::matmul(x, w), - DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)), - _ => unreachable!(), - } + /// Enable per-block activation recomputation / gradient checkpointing (Phase + /// T13). Builder-style and opt-in; default off keeps the unchanged tape (every + /// activation stored). On, each block's forward is wrapped in + /// [`xtrain_autodiff::checkpoint`] — exact grads, lower peak activation memory. + pub fn with_recompute(mut self, recompute: bool) -> Self { + self.recompute = recompute; + self + } + + pub fn recompute(&self) -> bool { + self.recompute } /// All learnable parameters, in a stable order. The optimizer (a hand-written @@ -171,32 +183,36 @@ impl TinyTransformer { h = ops::cast(&h, DType::BF16); } for b in &self.blocks { - // --- Attention sub-block (pre-norm + residual) --- - let normed = ops::rms_norm(&h, &self.norm_gamma(&b.attn_norm), self.cfg.eps); - let attn = self.attention(b, &normed, batch, seq); - h = ops::add(&h, &attn); - - // --- MLP sub-block (pre-norm + residual) --- - let normed = ops::rms_norm(&h, &self.norm_gamma(&b.ffn_norm), self.cfg.eps); - let mlp = self.swiglu_mlp(b, &normed); - h = ops::add(&h, &mlp); + h = if self.recompute { + // Activation recomputation (T13): run the whole block forward inside + // `checkpoint` so its internal activations aren't kept on the tape; + // the block forward is re-run in backward to recover the grads. The + // segment fn captures only `Copy` config (no borrow of `self`) and + // receives the block's params via the slice, in `block_params` order. + let (cfg, cdt) = (self.cfg, self.compute_dtype); + let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p); + xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params()) + } else { + block_forward( + self.cfg, + self.compute_dtype, + batch, + seq, + &h, + &b.block_params(), + ) + }; } - let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps); + let h = ops::rms_norm( + &h, + &norm_gamma(self.compute_dtype, &self.final_norm), + self.cfg.eps, + ); // lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the // cross_entropy op upcasts to fp32 internally (no persistent fp32 logits // buffer, a real saving at vocab 50257), and its backward casts dx back. - self.linear(&h, &self.lm_head) // [batch*seq, vocab] - } - - /// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast - /// op, grad upcast) in bf16 mode; identity in fp32 mode. - fn norm_gamma(&self, gamma: &Var) -> Var { - match self.compute_dtype { - DType::F32 => gamma.clone(), - DType::BF16 => ops::cast(gamma, DType::BF16), - _ => unreachable!(), - } + linear(self.compute_dtype, &h, &self.lm_head) // [batch*seq, vocab] } /// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32). @@ -213,68 +229,146 @@ impl TinyTransformer { let logits = self.forward_batched(ids, batch); ops::cross_entropy(&logits, targets) } +} - /// Multi-head causal self-attention over a flattened batch. `x`:[batch*seq,dim] - /// (already normed), laid out sequence-major. The Q/K/V/O projections are big - /// `[batch*seq, dim]` GEMMs; the scaled-dot-product attention itself runs as a - /// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each - /// attends within its own `[seq,seq]` causal window (NO cross-sequence - /// attention), with RoPE positions reset per sequence (`period = seq`). Causal - /// masking is applied inside the fused op's softmax kernel (no additive - /// `[seq,seq]` mask tensor). - fn attention(&self, b: &Block, x: &Var, batch: usize, seq: usize) -> Var { - let (nh, hd) = (self.cfg.n_heads, self.cfg.head_dim); - let total = batch * seq; - let bh = batch * nh; - let scale = 1.0 / (hd as f32).sqrt(); +impl Block { + /// The block's learnable leaves, in the fixed order the segment forward + /// (`block_forward`) indexes them — matches the per-block slice in + /// [`TinyTransformer::params`]. This is the param order `checkpoint` passes to + /// the recompute closure. + fn block_params(&self) -> Vec { + vec![ + self.attn_norm.clone(), + self.wq.clone(), + self.wk.clone(), + self.wv.clone(), + self.q_norm.clone(), + self.k_norm.clone(), + self.wo.clone(), + self.ffn_norm.clone(), + self.w_gate.clone(), + self.w_up.clone(), + self.w_down.clone(), + ] + } +} - // Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor. - // [B*S,dim] @ [dim,dim] = [B*S,dim] - // reshape [B*S, nh, hd] - // qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE) - // rope [B*S, nh, hd] with per-sequence position (period = seq) - // reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd] - let to_bh = |proj: Var, norm: Option<&Var>| -> Var { - let r = ops::reshape(&proj, &[total, nh, hd]); - let r = match norm { - // Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd, - // restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs). - Some(gamma) => { - let flat = ops::reshape(&r, &[total * nh, hd]); - let normed = ops::rms_norm(&flat, &self.norm_gamma(gamma), self.cfg.eps); - let r = ops::reshape(&normed, &[total, nh, hd]); - ops::rope(&r, self.cfg.rope_theta, seq) - } - None => r, - }; - let r = ops::reshape(&r, &[batch, seq, nh, hd]); - let t = ops::transpose_4d12(&r); // [B, nh, S, hd] - ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd] +/// Project `x` (activation, in the compute dtype) by weight `w` (an fp32 master +/// leaf). In bf16 mode the weight is cast to bf16 via the autograd `cast` op (whose +/// backward upcasts the grad to fp32); in fp32 mode this is just `matmul(x, w)`. +fn linear(cdt: DType, x: &Var, w: &Var) -> Var { + match cdt { + DType::F32 => ops::matmul(x, w), + DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)), + _ => unreachable!(), + } +} + +/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast op, +/// grad upcast) in bf16 mode; identity in fp32 mode. +fn norm_gamma(cdt: DType, gamma: &Var) -> Var { + match cdt { + DType::F32 => gamma.clone(), + DType::BF16 => ops::cast(gamma, DType::BF16), + _ => unreachable!(), + } +} + +/// One transformer block's forward: pre-norm + multi-head causal attention + +/// residual, then pre-norm + SwiGLU MLP + residual. Pure in `(cfg, cdt, batch, +/// seq, input, params)` (no `&self`) so it can be the segment fn of +/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is +/// the block's leaves in [`Block::block_params`] order. +fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var { + let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]); + let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]); + let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]); + + // --- Attention sub-block (pre-norm + residual) --- + let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps); + let attn = attention( + cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo, + ); + let h = ops::add(h, &attn); + + // --- MLP sub-block (pre-norm + residual) --- + let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps); + let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down); + ops::add(&h, &mlp) +} + +/// Multi-head causal self-attention over a flattened batch. `x`:[batch*seq,dim] +/// (already normed), laid out sequence-major. The Q/K/V/O projections are big +/// `[batch*seq, dim]` GEMMs; the scaled-dot-product attention itself runs as a +/// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each attends +/// within its own `[seq,seq]` causal window (NO cross-sequence attention), with +/// RoPE positions reset per sequence (`period = seq`). Causal masking is applied +/// inside the fused op's softmax kernel (no additive `[seq,seq]` mask tensor). +#[allow(clippy::too_many_arguments)] +fn attention( + cfg: Config, + cdt: DType, + batch: usize, + seq: usize, + x: &Var, + wq: &Var, + wk: &Var, + wv: &Var, + q_norm: &Var, + k_norm: &Var, + wo: &Var, +) -> Var { + let (nh, hd) = (cfg.n_heads, cfg.head_dim); + let total = batch * seq; + let bh = batch * nh; + let scale = 1.0 / (hd as f32).sqrt(); + + // Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor. + // [B*S,dim] @ [dim,dim] = [B*S,dim] + // reshape [B*S, nh, hd] + // qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE) + // rope [B*S, nh, hd] with per-sequence position (period = seq) + // reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd] + let to_bh = |proj: Var, norm: Option<&Var>| -> Var { + let r = ops::reshape(&proj, &[total, nh, hd]); + let r = match norm { + // Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd, + // restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs). + Some(gamma) => { + let flat = ops::reshape(&r, &[total * nh, hd]); + let normed = ops::rms_norm(&flat, &norm_gamma(cdt, gamma), cfg.eps); + let r = ops::reshape(&normed, &[total, nh, hd]); + ops::rope(&r, cfg.rope_theta, seq) + } + None => r, }; + let r = ops::reshape(&r, &[batch, seq, nh, hd]); + let t = ops::transpose_4d12(&r); // [B, nh, S, hd] + ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd] + }; - let q = to_bh(self.linear(x, &b.wq), Some(&b.q_norm)); - let k = to_bh(self.linear(x, &b.wk), Some(&b.k_norm)); - let v = to_bh(self.linear(x, &b.wv), None); + let q = to_bh(linear(cdt, x, wq), Some(q_norm)); + let k = to_bh(linear(cdt, x, wk), Some(k_norm)); + let v = to_bh(linear(cdt, x, wv), None); - // Fused batched causal SDPA over all B*nh (sequence,head) blocks at once - // (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop). - let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd] + // Fused batched causal SDPA over all B*nh (sequence,head) blocks at once + // (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop). + let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd] - // Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) → - // [B,S,nh,hd] → [B*S, dim]. - let out = ops::reshape(&out, &[batch, nh, seq, hd]); - let out = ops::transpose_4d12(&out); // [B, S, nh, hd] - let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim] - self.linear(&concat, &b.wo) // out projection - } + // Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) → + // [B,S,nh,hd] → [B*S, dim]. + let out = ops::reshape(&out, &[batch, nh, seq, hd]); + let out = ops::transpose_4d12(&out); // [B, S, nh, hd] + let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim] + linear(cdt, &concat, wo) // out projection +} - /// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim]. - fn swiglu_mlp(&self, b: &Block, x: &Var) -> Var { - let gate = self.linear(x, &b.w_gate); // [seq, ffn_hidden] - let up = self.linear(x, &b.w_up); // [seq, ffn_hidden] - let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up - self.linear(&act, &b.w_down) // [seq, dim] - } +/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim]. +fn swiglu_mlp(cdt: DType, x: &Var, w_gate: &Var, w_up: &Var, w_down: &Var) -> Var { + let gate = linear(cdt, x, w_gate); // [seq, ffn_hidden] + let up = linear(cdt, x, w_up); // [seq, ffn_hidden] + let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up + linear(cdt, &act, w_down) // [seq, dim] } /// Materialise a parameter's value back to a host `Vec` (for the GD step diff --git a/crates/xtrain-model/tests/recompute.rs b/crates/xtrain-model/tests/recompute.rs new file mode 100644 index 0000000..dd8c600 --- /dev/null +++ b/crates/xtrain-model/tests/recompute.rs @@ -0,0 +1,156 @@ +// T13 activation-recomputation correctness gate (the HARD gate). +// +// Gradient checkpointing is mathematically EXACT: the backward re-runs the same +// `segment_fn` from the same saved input and the same (unchanged) parameter +// values, so the recomputed activations equal the originals and the recovered +// grads equal the non-checkpointed grads — checkpointing trades compute for +// memory, never correctness. This test makes that a closed loop on-GPU: +// +// build two identical models (same init), one with `--recompute` on, one off, +// run the SAME batched loss + backward on both, and assert +// 1. the forward logits match (recompute doesn't touch forward output) +// 2. the loss matches +// 3. EVERY parameter's grad matches within a tight fp tolerance. +// +// Composition is covered by parameterising over fp32 AND bf16 (T12): the +// recompute path is the unchanged block forward, so it runs the same dtype path. +#![cfg(not(no_cuda))] + +use xtrain_cuda::device; +use xtrain_model::{Config, TinyTransformer, batched_ids_tensor}; +use xtrain_tensor::{DType, Device}; + +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} + +fn build(cfg: Config, device: Device, dtype: DType, recompute: bool) -> TinyTransformer { + let mut seed = 1u64; + let m = TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + fill(n, seed, 0.08) + } + }); + m.with_compute_dtype(dtype).with_recompute(recompute) +} + +fn host(t: &xtrain_tensor::Tensor) -> Vec { + t.to_device(Device::Cpu).as_slice::().to_vec() +} + +fn run(dtype: DType, logit_tol: f32, grad_tol: f32) { + assert!(device::device_count().unwrap() > 0, "no CUDA device"); + device::set_device(0).unwrap(); + let device = Device::Cuda(0); + + // A few layers so checkpointing actually wraps multiple blocks. + let mut cfg = Config::tiny(); + cfg.vocab = 16; + cfg.n_layers = 4; + let batch = 3usize; + let seq = 6usize; + let seqs: Vec> = (0..batch) + .map(|b| { + (0..seq) + .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32) + .collect() + }) + .collect(); + let tgts: Vec> = (0..batch) + .map(|b| { + (0..seq) + .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32) + .collect() + }) + .collect(); + let ids = batched_ids_tensor(&seqs, device); + let tgt = batched_ids_tensor(&tgts, device); + + // --- recompute OFF (reference) --- + let off = build(cfg, device, dtype, false); + let off_logits = host(&off.forward_batched(&ids, batch).value()); + let off_loss = off.loss_batched(&ids, &tgt, batch); + let off_loss_val = host(&off_loss.value())[0]; + off_loss.backward(); + let off_grads: Vec> = off + .params() + .iter() + .map(|p| host(&p.grad().expect("off grad"))) + .collect(); + + // --- recompute ON --- + let on = build(cfg, device, dtype, true); + let on_logits = host(&on.forward_batched(&ids, batch).value()); + let on_loss = on.loss_batched(&ids, &tgt, batch); + let on_loss_val = host(&on_loss.value())[0]; + on_loss.backward(); + let on_grads: Vec> = on + .params() + .iter() + .map(|p| host(&p.grad().expect("on grad"))) + .collect(); + + // 1. Forward logits — recompute must not change the forward output. + let logit_rel = off_logits + .iter() + .zip(&on_logits) + .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4)) + .fold(0.0f32, f32::max); + // 2. Loss. + let loss_rel = (off_loss_val - on_loss_val).abs() / off_loss_val.abs().max(1e-4); + println!( + "[{dtype:?}] recompute on/off: loss {off_loss_val:.6}/{on_loss_val:.6} (rel {loss_rel:.2e}), \ + logits max rel {logit_rel:.2e}" + ); + assert!( + logit_rel < logit_tol, + "[{dtype:?}] logits diverged: {logit_rel:.2e}" + ); + assert!( + loss_rel < logit_tol, + "[{dtype:?}] loss diverged: {loss_rel:.2e}" + ); + + // 3. Every parameter grad — the load-bearing gate. + let mut max_grad_rel = 0.0f32; + for (off_g, on_g) in off_grads.iter().zip(&on_grads) { + for (a, b) in off_g.iter().zip(on_g) { + let rel = (a - b).abs() / a.abs().max(1e-3); + max_grad_rel = max_grad_rel.max(rel); + } + } + println!("[{dtype:?}] recompute on/off: grad max rel err = {max_grad_rel:.3e}"); + assert!( + max_grad_rel < grad_tol, + "[{dtype:?}] recompute grads diverged from non-recompute: {max_grad_rel:.3e}" + ); +} + +#[test] +fn recompute_matches_non_recompute_fp32() { + // fp32: recompute runs the identical deterministic kernels → grads match to + // (near) bit-exact; allow a hair for any nondeterministic GPU reduction. + run(DType::F32, 1e-5, 1e-4); +} + +#[test] +fn recompute_matches_non_recompute_bf16() { + // bf16 (T12 composition): same bf16 path on recompute. The recompute is still + // exact w.r.t. the bf16 forward, so on/off match tightly (looser tol only for + // bf16 rounding, not for any recompute discrepancy). + run(DType::BF16, 5e-3, 5e-3); +} diff --git a/crates/xtrain-train/src/bin/train.rs b/crates/xtrain-train/src/bin/train.rs index 1182348..b1d50cf 100644 --- a/crates/xtrain-train/src/bin/train.rs +++ b/crates/xtrain-train/src/bin/train.rs @@ -112,6 +112,10 @@ fn main() { // bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears + // activations. Opt-in; default fp32 reproduces v0–v4 numerics. let bf16 = args.iter().any(|a| a == "--bf16"); + // Activation recomputation (Phase T13): per-block gradient checkpointing — + // exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in; + // default off stores every activation (unchanged numerics). + let recompute = args.iter().any(|a| a == "--recompute"); let ckpt: PathBuf = PathBuf::from( args.iter() .position(|a| a == "--ckpt") @@ -175,6 +179,10 @@ fn main() { model = model.with_compute_dtype(DType::BF16); println!("bf16 mixed precision: ON (fp32 master weights)"); } + if recompute { + model = model.with_recompute(true); + println!("activation recompute: ON (per-block gradient checkpointing)"); + } // Eval-only mode: load a checkpoint and score it on the held-out val set, then // exit. Used to put an EXISTING model (e.g. v0) and a new one on the same