model: per-block activation recompute (--recompute)
Wrap each transformer block's forward in the checkpoint primitive when recompute is enabled (Phase T13 / KI-3). To make the block forward a pure segment fn (no `&self` borrow, so it can re-run in the backward closure), extract the block body + its helpers (linear / norm_gamma / attention / swiglu_mlp) into free functions parameterised by (cfg, compute_dtype) and add `Block::block_params()` (the 11 leaves in the params() per-block order). The non-recompute path calls `block_forward` directly — identical graph to before. - `TinyTransformer::with_recompute(bool)` builder (opt-in; default off keeps the unchanged tape / bit-identical numerics). - `--recompute` flag wired into bin/train and bin/train_ddp (DDP: each rank checkpoints independently). Correctness gate: tests/recompute.rs builds two identical models (recompute on/off), runs the same batched loss+backward, and asserts the forward logits, the loss, and EVERY parameter grad match within tight fp tol — parameterised over fp32 and bf16 (T12 composition). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,10 @@ fn main() {
|
||||
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
|
||||
// activations. Opt-in; default fp32 reproduces v0–v4 numerics.
|
||||
let bf16 = args.iter().any(|a| a == "--bf16");
|
||||
// Activation recomputation (Phase T13): per-block gradient checkpointing — each
|
||||
// rank checkpoints its own forward/backward; exact grads, lower peak activation
|
||||
// memory (lets dim1024 batch32 fit). Opt-in; default off.
|
||||
let recompute = args.iter().any(|a| a == "--recompute");
|
||||
let ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -167,18 +171,23 @@ fn main() {
|
||||
if bf16 {
|
||||
println!("bf16 mixed precision: ON (fp32 master weights)");
|
||||
}
|
||||
if recompute {
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
valid.as_ref(),
|
||||
&dcfg,
|
||||
move |device| {
|
||||
let m = build_model(cfg, device);
|
||||
let mut m = build_model(cfg, device);
|
||||
if bf16 {
|
||||
m.with_compute_dtype(xtrain_tensor::DType::BF16)
|
||||
} else {
|
||||
m
|
||||
m = m.with_compute_dtype(xtrain_tensor::DType::BF16);
|
||||
}
|
||||
if recompute {
|
||||
m = m.with_recompute(true);
|
||||
}
|
||||
m
|
||||
},
|
||||
);
|
||||
let r0 = &results[0];
|
||||
|
||||
@@ -37,6 +37,16 @@ pub struct TinyTransformer {
|
||||
/// `docs/11-bf16-mixed-precision.md`). The cast op's backward upcasts the bf16
|
||||
/// weight grad back to fp32, so AdamW/clip/DDP stay fp32 and unchanged.
|
||||
compute_dtype: DType,
|
||||
/// Activation recomputation / gradient checkpointing (Phase T13, KI-3). When
|
||||
/// `true`, each transformer block's forward runs through
|
||||
/// [`xtrain_autodiff::checkpoint`]: the block's internal activations are NOT
|
||||
/// kept on the tape during forward (only the block input is), and the block
|
||||
/// forward is re-run during backward to recover them. Trades ~one extra forward
|
||||
/// per block for a large drop in peak activation memory → lets dim1024 batch32
|
||||
/// fit. Default `false` = the unchanged path (every activation stored), so
|
||||
/// existing numerics are bit-identical; recompute is mathematically exact, so
|
||||
/// grads match the non-checkpointed path within fp tolerance.
|
||||
recompute: bool,
|
||||
}
|
||||
|
||||
impl TinyTransformer {
|
||||
@@ -79,6 +89,7 @@ impl TinyTransformer {
|
||||
final_norm,
|
||||
lm_head,
|
||||
compute_dtype: DType::F32,
|
||||
recompute: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,16 +114,17 @@ impl TinyTransformer {
|
||||
self.compute_dtype
|
||||
}
|
||||
|
||||
/// Project `x` (activation, in the compute dtype) by weight `w` (an fp32
|
||||
/// master leaf). In bf16 mode the weight is cast to bf16 via the autograd
|
||||
/// `cast` op (whose backward upcasts the grad to fp32); in fp32 mode this is
|
||||
/// just `matmul(x, w)`. The activation `x` already carries `compute_dtype`.
|
||||
fn linear(&self, x: &Var, w: &Var) -> Var {
|
||||
match self.compute_dtype {
|
||||
DType::F32 => ops::matmul(x, w),
|
||||
DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)),
|
||||
_ => unreachable!(),
|
||||
/// Enable per-block activation recomputation / gradient checkpointing (Phase
|
||||
/// T13). Builder-style and opt-in; default off keeps the unchanged tape (every
|
||||
/// activation stored). On, each block's forward is wrapped in
|
||||
/// [`xtrain_autodiff::checkpoint`] — exact grads, lower peak activation memory.
|
||||
pub fn with_recompute(mut self, recompute: bool) -> Self {
|
||||
self.recompute = recompute;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn recompute(&self) -> bool {
|
||||
self.recompute
|
||||
}
|
||||
|
||||
/// All learnable parameters, in a stable order. The optimizer (a hand-written
|
||||
@@ -171,32 +183,36 @@ impl TinyTransformer {
|
||||
h = ops::cast(&h, DType::BF16);
|
||||
}
|
||||
for b in &self.blocks {
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(&h, &self.norm_gamma(&b.attn_norm), self.cfg.eps);
|
||||
let attn = self.attention(b, &normed, batch, seq);
|
||||
h = ops::add(&h, &attn);
|
||||
|
||||
// --- MLP sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(&h, &self.norm_gamma(&b.ffn_norm), self.cfg.eps);
|
||||
let mlp = self.swiglu_mlp(b, &normed);
|
||||
h = ops::add(&h, &mlp);
|
||||
h = if self.recompute {
|
||||
// Activation recomputation (T13): run the whole block forward inside
|
||||
// `checkpoint` so its internal activations aren't kept on the tape;
|
||||
// the block forward is re-run in backward to recover the grads. The
|
||||
// segment fn captures only `Copy` config (no borrow of `self`) and
|
||||
// receives the block's params via the slice, in `block_params` order.
|
||||
let (cfg, cdt) = (self.cfg, self.compute_dtype);
|
||||
let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p);
|
||||
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
|
||||
} else {
|
||||
block_forward(
|
||||
self.cfg,
|
||||
self.compute_dtype,
|
||||
batch,
|
||||
seq,
|
||||
&h,
|
||||
&b.block_params(),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
|
||||
let h = ops::rms_norm(
|
||||
&h,
|
||||
&norm_gamma(self.compute_dtype, &self.final_norm),
|
||||
self.cfg.eps,
|
||||
);
|
||||
// lm_head matmul in compute dtype. Logits stay bf16 in bf16 mode — the
|
||||
// cross_entropy op upcasts to fp32 internally (no persistent fp32 logits
|
||||
// buffer, a real saving at vocab 50257), and its backward casts dx back.
|
||||
self.linear(&h, &self.lm_head) // [batch*seq, vocab]
|
||||
}
|
||||
|
||||
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast
|
||||
/// op, grad upcast) in bf16 mode; identity in fp32 mode.
|
||||
fn norm_gamma(&self, gamma: &Var) -> Var {
|
||||
match self.compute_dtype {
|
||||
DType::F32 => gamma.clone(),
|
||||
DType::BF16 => ops::cast(gamma, DType::BF16),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
linear(self.compute_dtype, &h, &self.lm_head) // [batch*seq, vocab]
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
|
||||
@@ -213,17 +229,96 @@ impl TinyTransformer {
|
||||
let logits = self.forward_batched(ids, batch);
|
||||
ops::cross_entropy(&logits, targets)
|
||||
}
|
||||
}
|
||||
|
||||
impl Block {
|
||||
/// The block's learnable leaves, in the fixed order the segment forward
|
||||
/// (`block_forward`) indexes them — matches the per-block slice in
|
||||
/// [`TinyTransformer::params`]. This is the param order `checkpoint` passes to
|
||||
/// the recompute closure.
|
||||
fn block_params(&self) -> Vec<Var> {
|
||||
vec![
|
||||
self.attn_norm.clone(),
|
||||
self.wq.clone(),
|
||||
self.wk.clone(),
|
||||
self.wv.clone(),
|
||||
self.q_norm.clone(),
|
||||
self.k_norm.clone(),
|
||||
self.wo.clone(),
|
||||
self.ffn_norm.clone(),
|
||||
self.w_gate.clone(),
|
||||
self.w_up.clone(),
|
||||
self.w_down.clone(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Project `x` (activation, in the compute dtype) by weight `w` (an fp32 master
|
||||
/// leaf). In bf16 mode the weight is cast to bf16 via the autograd `cast` op (whose
|
||||
/// backward upcasts the grad to fp32); in fp32 mode this is just `matmul(x, w)`.
|
||||
fn linear(cdt: DType, x: &Var, w: &Var) -> Var {
|
||||
match cdt {
|
||||
DType::F32 => ops::matmul(x, w),
|
||||
DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast op,
|
||||
/// grad upcast) in bf16 mode; identity in fp32 mode.
|
||||
fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
|
||||
match cdt {
|
||||
DType::F32 => gamma.clone(),
|
||||
DType::BF16 => ops::cast(gamma, DType::BF16),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// One transformer block's forward: pre-norm + multi-head causal attention +
|
||||
/// residual, then pre-norm + SwiGLU MLP + residual. Pure in `(cfg, cdt, batch,
|
||||
/// seq, input, params)` (no `&self`) so it can be the segment fn of
|
||||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
|
||||
/// the block's leaves in [`Block::block_params`] order.
|
||||
fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var {
|
||||
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
|
||||
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
|
||||
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
|
||||
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
|
||||
let attn = attention(
|
||||
cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||||
);
|
||||
let h = ops::add(h, &attn);
|
||||
|
||||
// --- MLP sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(&h, &norm_gamma(cdt, ffn_norm), cfg.eps);
|
||||
let mlp = swiglu_mlp(cdt, &normed, w_gate, w_up, w_down);
|
||||
ops::add(&h, &mlp)
|
||||
}
|
||||
|
||||
/// Multi-head causal self-attention over a flattened batch. `x`:[batch*seq,dim]
|
||||
/// (already normed), laid out sequence-major. The Q/K/V/O projections are big
|
||||
/// `[batch*seq, dim]` GEMMs; the scaled-dot-product attention itself runs as a
|
||||
/// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each
|
||||
/// attends within its own `[seq,seq]` causal window (NO cross-sequence
|
||||
/// attention), with RoPE positions reset per sequence (`period = seq`). Causal
|
||||
/// masking is applied inside the fused op's softmax kernel (no additive
|
||||
/// `[seq,seq]` mask tensor).
|
||||
fn attention(&self, b: &Block, x: &Var, batch: usize, seq: usize) -> Var {
|
||||
let (nh, hd) = (self.cfg.n_heads, self.cfg.head_dim);
|
||||
/// fused BATCHED op over the `batch·n_heads` (sequence,head) blocks — each attends
|
||||
/// within its own `[seq,seq]` causal window (NO cross-sequence attention), with
|
||||
/// RoPE positions reset per sequence (`period = seq`). Causal masking is applied
|
||||
/// inside the fused op's softmax kernel (no additive `[seq,seq]` mask tensor).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn attention(
|
||||
cfg: Config,
|
||||
cdt: DType,
|
||||
batch: usize,
|
||||
seq: usize,
|
||||
x: &Var,
|
||||
wq: &Var,
|
||||
wk: &Var,
|
||||
wv: &Var,
|
||||
q_norm: &Var,
|
||||
k_norm: &Var,
|
||||
wo: &Var,
|
||||
) -> Var {
|
||||
let (nh, hd) = (cfg.n_heads, cfg.head_dim);
|
||||
let total = batch * seq;
|
||||
let bh = batch * nh;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
@@ -241,9 +336,9 @@ impl TinyTransformer {
|
||||
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
|
||||
Some(gamma) => {
|
||||
let flat = ops::reshape(&r, &[total * nh, hd]);
|
||||
let normed = ops::rms_norm(&flat, &self.norm_gamma(gamma), self.cfg.eps);
|
||||
let normed = ops::rms_norm(&flat, &norm_gamma(cdt, gamma), cfg.eps);
|
||||
let r = ops::reshape(&normed, &[total, nh, hd]);
|
||||
ops::rope(&r, self.cfg.rope_theta, seq)
|
||||
ops::rope(&r, cfg.rope_theta, seq)
|
||||
}
|
||||
None => r,
|
||||
};
|
||||
@@ -252,9 +347,9 @@ impl TinyTransformer {
|
||||
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
|
||||
};
|
||||
|
||||
let q = to_bh(self.linear(x, &b.wq), Some(&b.q_norm));
|
||||
let k = to_bh(self.linear(x, &b.wk), Some(&b.k_norm));
|
||||
let v = to_bh(self.linear(x, &b.wv), None);
|
||||
let q = to_bh(linear(cdt, x, wq), Some(q_norm));
|
||||
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
|
||||
let v = to_bh(linear(cdt, x, wv), None);
|
||||
|
||||
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
|
||||
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
|
||||
@@ -265,16 +360,15 @@ impl TinyTransformer {
|
||||
let out = ops::reshape(&out, &[batch, nh, seq, hd]);
|
||||
let out = ops::transpose_4d12(&out); // [B, S, nh, hd]
|
||||
let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim]
|
||||
self.linear(&concat, &b.wo) // out projection
|
||||
linear(cdt, &concat, wo) // out projection
|
||||
}
|
||||
|
||||
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim].
|
||||
fn swiglu_mlp(&self, b: &Block, x: &Var) -> Var {
|
||||
let gate = self.linear(x, &b.w_gate); // [seq, ffn_hidden]
|
||||
let up = self.linear(x, &b.w_up); // [seq, ffn_hidden]
|
||||
fn swiglu_mlp(cdt: DType, x: &Var, w_gate: &Var, w_up: &Var, w_down: &Var) -> Var {
|
||||
let gate = linear(cdt, x, w_gate); // [seq, ffn_hidden]
|
||||
let up = linear(cdt, x, w_up); // [seq, ffn_hidden]
|
||||
let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up
|
||||
self.linear(&act, &b.w_down) // [seq, dim]
|
||||
}
|
||||
linear(cdt, &act, w_down) // [seq, dim]
|
||||
}
|
||||
|
||||
/// Materialise a parameter's value back to a host `Vec<f32>` (for the GD step
|
||||
|
||||
156
crates/xtrain-model/tests/recompute.rs
Normal file
156
crates/xtrain-model/tests/recompute.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
// T13 activation-recomputation correctness gate (the HARD gate).
|
||||
//
|
||||
// Gradient checkpointing is mathematically EXACT: the backward re-runs the same
|
||||
// `segment_fn` from the same saved input and the same (unchanged) parameter
|
||||
// values, so the recomputed activations equal the originals and the recovered
|
||||
// grads equal the non-checkpointed grads — checkpointing trades compute for
|
||||
// memory, never correctness. This test makes that a closed loop on-GPU:
|
||||
//
|
||||
// build two identical models (same init), one with `--recompute` on, one off,
|
||||
// run the SAME batched loss + backward on both, and assert
|
||||
// 1. the forward logits match (recompute doesn't touch forward output)
|
||||
// 2. the loss matches
|
||||
// 3. EVERY parameter's grad matches within a tight fp tolerance.
|
||||
//
|
||||
// Composition is covered by parameterising over fp32 AND bf16 (T12): the
|
||||
// recompute path is the unchanged block forward, so it runs the same dtype path.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType, recompute: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype).with_recompute(recompute)
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
// A few layers so checkpointing actually wraps multiple blocks.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
let batch = 3usize;
|
||||
let seq = 6usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
// --- recompute OFF (reference) ---
|
||||
let off = build(cfg, device, dtype, false);
|
||||
let off_logits = host(&off.forward_batched(&ids, batch).value());
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
let off_loss_val = host(&off_loss.value())[0];
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("off grad")))
|
||||
.collect();
|
||||
|
||||
// --- recompute ON ---
|
||||
let on = build(cfg, device, dtype, true);
|
||||
let on_logits = host(&on.forward_batched(&ids, batch).value());
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
let on_loss_val = host(&on_loss.value())[0];
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("on grad")))
|
||||
.collect();
|
||||
|
||||
// 1. Forward logits — recompute must not change the forward output.
|
||||
let logit_rel = off_logits
|
||||
.iter()
|
||||
.zip(&on_logits)
|
||||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
// 2. Loss.
|
||||
let loss_rel = (off_loss_val - on_loss_val).abs() / off_loss_val.abs().max(1e-4);
|
||||
println!(
|
||||
"[{dtype:?}] recompute on/off: loss {off_loss_val:.6}/{on_loss_val:.6} (rel {loss_rel:.2e}), \
|
||||
logits max rel {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
logit_rel < logit_tol,
|
||||
"[{dtype:?}] logits diverged: {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
loss_rel < logit_tol,
|
||||
"[{dtype:?}] loss diverged: {loss_rel:.2e}"
|
||||
);
|
||||
|
||||
// 3. Every parameter grad — the load-bearing gate.
|
||||
let mut max_grad_rel = 0.0f32;
|
||||
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
|
||||
for (a, b) in off_g.iter().zip(on_g) {
|
||||
let rel = (a - b).abs() / a.abs().max(1e-3);
|
||||
max_grad_rel = max_grad_rel.max(rel);
|
||||
}
|
||||
}
|
||||
println!("[{dtype:?}] recompute on/off: grad max rel err = {max_grad_rel:.3e}");
|
||||
assert!(
|
||||
max_grad_rel < grad_tol,
|
||||
"[{dtype:?}] recompute grads diverged from non-recompute: {max_grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recompute_matches_non_recompute_fp32() {
|
||||
// fp32: recompute runs the identical deterministic kernels → grads match to
|
||||
// (near) bit-exact; allow a hair for any nondeterministic GPU reduction.
|
||||
run(DType::F32, 1e-5, 1e-4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recompute_matches_non_recompute_bf16() {
|
||||
// bf16 (T12 composition): same bf16 path on recompute. The recompute is still
|
||||
// exact w.r.t. the bf16 forward, so on/off match tightly (looser tol only for
|
||||
// bf16 rounding, not for any recompute discrepancy).
|
||||
run(DType::BF16, 5e-3, 5e-3);
|
||||
}
|
||||
@@ -112,6 +112,10 @@ fn main() {
|
||||
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
|
||||
// activations. Opt-in; default fp32 reproduces v0–v4 numerics.
|
||||
let bf16 = args.iter().any(|a| a == "--bf16");
|
||||
// Activation recomputation (Phase T13): per-block gradient checkpointing —
|
||||
// exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in;
|
||||
// default off stores every activation (unchanged numerics).
|
||||
let recompute = args.iter().any(|a| a == "--recompute");
|
||||
let ckpt: PathBuf = PathBuf::from(
|
||||
args.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -175,6 +179,10 @@ fn main() {
|
||||
model = model.with_compute_dtype(DType::BF16);
|
||||
println!("bf16 mixed precision: ON (fp32 master weights)");
|
||||
}
|
||||
if recompute {
|
||||
model = model.with_recompute(true);
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
|
||||
Reference in New Issue
Block a user