autodiff: checkpoint primitive (recompute-on-backward)

Add `xtrain_autodiff::checkpoint::checkpoint(segment_fn, input, params)`, a
higher-order autograd node (à la torch.utils.checkpoint) for activation
recomputation (Phase T13 / KI-3):

- forward: run `segment_fn` on detached leaves so its internal ops are NOT
  recorded on the outer tape; keep only the output value (the local sub-tape —
  and thus the segment's intermediate activations — drops immediately). The
  checkpoint node's parents are [input, ..params].
- backward: re-run `segment_fn` from the saved input + (unchanged) param
  values into a fresh local tape, seed the recomputed output with the upstream
  grad, backprop, then push the recovered input/param grads to the real
  parents. Local tape drops at the end → recomputed activations freed.

Exact by construction (same deterministic kernels, same inputs) → grads match
the non-checkpointed path. Composes with bf16 (T12, same path on recompute)
and DDP (T8, per-rank).

Supporting change: `Var::backward_seeded(seed)` — backward from an explicit
non-scalar upstream grad (the segment output is generally not a scalar);
`backward()` is now the scalar wrapper that seeds ones.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
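
The recompute-on-backward idea in the commit message can be illustrated on plain `f64` scalars. The sketch below is not the xtrain `Var`/tape API: the segment `y = tanh(w * x)` and all four function names are hypothetical, chosen only to show that re-running a deterministic forward from the saved inputs rebuilds the same intermediate and therefore the same grads.

```rust
// Toy model of recompute-on-backward (assumption: plain f64 scalars, no tape).
// The segment is y = tanh(w * x); its only intermediate activation is z = w * x.

/// Ordinary forward: returns the output and keeps the intermediate `z` alive.
fn forward_saving(x: f64, w: f64) -> (f64, f64) {
    let z = w * x;
    (z.tanh(), z)
}

/// Backward that consumes the saved intermediate `z`. Returns (dx, dw).
fn backward_saved(x: f64, w: f64, z: f64, dout: f64) -> (f64, f64) {
    let dz = dout * (1.0 - z.tanh().powi(2)); // d tanh(z) / dz
    (dz * w, dz * x)
}

/// Checkpointed forward: only the output value survives; `z` is dropped.
fn forward_checkpointed(x: f64, w: f64) -> f64 {
    let (y, _z) = forward_saving(x, w);
    y
}

/// Checkpointed backward: re-run the forward from the saved inputs to rebuild
/// `z`, then run the ordinary backward on the recomputed intermediate.
fn backward_recompute(x: f64, w: f64, dout: f64) -> (f64, f64) {
    let (_y, z) = forward_saving(x, w);
    backward_saved(x, w, z, dout)
}
```

Because both paths execute the same floating-point operations on the same inputs, the two backward functions return identical values, which is the property the commit message calls "exact by construction".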
103
crates/xtrain-autodiff/src/checkpoint.rs
Normal file
@@ -0,0 +1,103 @@
//! Activation recomputation / gradient checkpointing (Phase T13, KI-3).
//!
//! A higher-order autograd primitive — the analogue of `torch.utils.checkpoint`.
//! It runs a *segment* of the model (a transformer block, here) WITHOUT recording
//! the segment's internal ops on the surrounding tape, so the segment's
//! intermediate activations are freed right after the forward instead of being
//! kept alive until backward. When the segment's output-grad arrives in backward,
//! the segment forward is **re-run** from the saved input (into a throwaway local
//! tape), the recomputed output is seeded with the upstream grad, and the gradient
//! is backpropagated through the local tape to recover the input-grad and the
//! parameter-grads — which are then pushed to the real tape's parents. The local
//! tape is dropped at the end of the closure, freeing the recomputed activations.
//!
//! ## Why it is exact (the hard gate)
//! The recompute runs the *same* `segment_fn` from the *same* input value and the
//! *same* parameter values (parameters are leaves that persist across forward and
//! backward; only their grad slot changes). The forward kernels are deterministic,
//! so the recomputed output equals the original output bit-for-bit, and the local
//! backward is the ordinary analytic backward of that segment. Therefore the input-
//! and parameter-grads are identical to those a non-checkpointed forward would
//! produce — checkpointing trades compute (one extra forward per segment) for
//! memory, never correctness.
//!
//! ## Composition
//! - **bf16 (T12):** `segment_fn` is the unchanged block forward, so the recompute
//!   runs the same bf16 path; the `cast` op's grad upcast still bridges bf16→fp32.
//! - **DDP (T8):** each rank checkpoints its own forward/backward independently;
//!   the param-grads recovered here feed the same per-rank `.grad()` slots that the
//!   all-reduce averages — no change to the distributed path.
//! - **batched (T10):** the segment input/output carry the `[batch*seq, …]` batch
//!   dim transparently; `checkpoint` is shape-agnostic.

#![cfg(not(no_cuda))]

use crate::tape::Var;
use std::rc::Rc;

/// Run `segment_fn(input, params)` with activation recomputation.
///
/// `segment_fn(x, p)` must build the segment's forward graph from a single input
/// `x` and the parameter slice `p`, returning the single segment output. It is
/// called once now (forward, result detached from the outer tape) and once per
/// backward (recompute). It MUST be deterministic and depend only on `x` and `p`
/// (this is what makes the recompute exact).
///
/// `params` are the segment's learnable leaves; their grads are accumulated into
/// the SAME leaves the optimizer reads (so DDP / AdamW are unchanged).
///
/// Returns the segment output as a `Var` on the outer tape whose backward triggers
/// the recompute. Equivalent — grad-for-grad — to calling `segment_fn(input,
/// params)` directly, but without keeping the segment's internal activations alive.
pub fn checkpoint<F>(segment_fn: F, input: &Var, params: &[Var]) -> Var
where
    F: Fn(&Var, &[Var]) -> Var + 'static,
{
    let segment_fn = Rc::new(segment_fn);

    // --- Forward (no taping of internals) ---
    // Detach the input and params into fresh leaves so `segment_fn` builds a LOCAL
    // tape disconnected from the outer graph. We only keep the output's value; the
    // local `Var`s (and thus the segment's intermediate activations) are dropped
    // when this scope ends.
    let out_value = {
        let x_det = Var::leaf(input.value());
        let params_det: Vec<Var> = params.iter().map(|p| Var::leaf(p.value())).collect();
        let out_local = segment_fn(&x_det, &params_det);
        out_local.value()
    };

    // Parents on the OUTER tape: the segment input, then the params (so their grads
    // land in the leaves the optimizer reads).
    let mut parents = Vec::with_capacity(1 + params.len());
    parents.push(input.clone());
    parents.extend(params.iter().cloned());

    let segment_fn = segment_fn.clone();
    Var::from_op(
        out_value,
        parents,
        Box::new(move |dout, parents| {
            // --- Backward (recompute) ---
            // Rebuild fresh leaves from the CURRENT input/param values (params are
            // unchanged since forward; input is the saved segment input), re-run the
            // forward to rebuild the local tape, seed the recomputed output with the
            // upstream grad, and backprop through the local tape.
            let x_det = Var::leaf(parents[0].value());
            let params_det: Vec<Var> = parents[1..].iter().map(|p| Var::leaf(p.value())).collect();
            let out_local = segment_fn(&x_det, &params_det);
            out_local.backward_seeded(dout.clone());

            // Push the recovered grads to the real parents (engine SUMs on fan-out).
            if let Some(dx) = x_det.grad() {
                Var::push_grad(&parents[0], dx);
            }
            for (det, parent) in params_det.iter().zip(&parents[1..]) {
                if let Some(dp) = det.grad() {
                    Var::push_grad(parent, dp);
                }
            }
            // `out_local` / the local tape drop here → recomputed activations freed.
        }),
    )
}
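
The module docs state the cost as one extra forward per segment in exchange for keeping only each segment's input. The sketch below counts that cost on a chain of scalar segments. It is an illustration under assumptions, not the crate's code: `segment`, `chain_forward`, and `chain_backward` are hypothetical names, the segment is `y = tanh(w * x)`, and a `Cell<u32>` stands in for a forward-call counter.

```rust
use std::cell::Cell;

/// Hypothetical segment y = tanh(w * x). Returns (y, z) where z = w * x is the
/// intermediate activation; `calls` counts how many times the forward runs.
fn segment(x: f64, w: f64, calls: &Cell<u32>) -> (f64, f64) {
    calls.set(calls.get() + 1);
    let z = w * x;
    (z.tanh(), z)
}

/// Forward through the chain, saving only each segment's INPUT (no intermediates).
fn chain_forward(x0: f64, ws: &[f64], calls: &Cell<u32>) -> (f64, Vec<f64>) {
    let mut saved_inputs = Vec::with_capacity(ws.len());
    let mut x = x0;
    for &w in ws {
        saved_inputs.push(x);
        x = segment(x, w, calls).0; // intermediate z is dropped here
    }
    (x, saved_inputs)
}

/// Backward in reverse order: re-run each segment's forward from its saved
/// input to rebuild z, then apply the chain rule. Returns (dx0, per-weight grads).
fn chain_backward(
    saved_inputs: &[f64],
    ws: &[f64],
    dout: f64,
    calls: &Cell<u32>,
) -> (f64, Vec<f64>) {
    let mut grad = dout; // upstream grad w.r.t. the current segment's output
    let mut dws = vec![0.0; ws.len()];
    for i in (0..ws.len()).rev() {
        let (_y, z) = segment(saved_inputs[i], ws[i], calls); // recompute
        let dz = grad * (1.0 - z.tanh().powi(2));
        dws[i] = dz * saved_inputs[i];
        grad = dz * ws[i];
    }
    (grad, dws)
}
```

With `N` segments the counter reads `N` after the forward pass and `2 * N` after the backward pass, while only `N` saved inputs are held between the two passes.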
@@ -18,6 +18,8 @@ pub use finite_diff::{GradCheckConfig, GradCheckResult, ParamFn, grad_check};
// kernels via xtrain-tensor, so they are gated behind `not(no_cuda)` (the
// per-crate convention); the grad_check harness above stays host-only.
#[cfg(not(no_cuda))]
pub mod checkpoint;
#[cfg(not(no_cuda))]
pub mod ops;
#[cfg(not(no_cuda))]
pub mod tape;
@@ -108,14 +108,24 @@ impl Var {
             "backward() expects a scalar loss; got shape {:?}",
             self.value().shape()
         );
+        self.backward_seeded(ones_like(&self.value()));
+    }

+    /// Reverse-mode backward from this node seeded with an explicit upstream grad
+    /// `seed` (same shape as this node's value), instead of the scalar `dL/dL = 1`.
+    ///
+    /// This is the entry point for **activation recomputation** (Phase T13): a
+    /// checkpointed segment re-runs its forward into a fresh local tape, then
+    /// backprops the upstream output-grad through it via this method (the segment
+    /// output is generally NOT a scalar). For a scalar root, [`backward`] is the
+    /// thin wrapper that seeds ones.
+    pub fn backward_seeded(&self, seed: Tensor) {
         // 1. Topological order (post-order DFS), parents before children.
         let mut topo: Vec<Var> = Vec::new();
         let mut visited: Vec<*const RefCell<VarNode>> = Vec::new();
         build_topo(self, &mut topo, &mut visited);

-        // 2. Seed the loss gradient with ones.
-        let seed = ones_like(&self.value());
+        // 2. Seed this node's gradient with the supplied upstream grad.
         self.accumulate(seed);

         // 3. Walk in reverse: each node hands its grad to its parents' closures.
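
The relationship between `backward()` and `backward_seeded(seed)` in the hunk above amounts to a vector-Jacobian product with an explicit upstream grad, plus a scalar wrapper that seeds ones. A minimal sketch on slices, assuming a toy elementwise op `y_i = x_i^2` and hypothetical function names (this is not the `Var` API):

```rust
/// Toy elementwise forward y_i = x_i^2 (assumption: stands in for any op).
fn square_forward(x: &[f64]) -> Vec<f64> {
    x.iter().map(|v| v * v).collect()
}

/// Seeded backward: `seed` is the upstream grad dL/dy, same shape as y.
/// Returns dL/dx_i = 2 * x_i * seed_i.
fn square_backward_seeded(x: &[f64], seed: &[f64]) -> Vec<f64> {
    assert_eq!(x.len(), seed.len(), "seed must match the output shape");
    x.iter().zip(seed).map(|(v, s)| 2.0 * v * s).collect()
}

/// Scalar wrapper: only valid for a one-element output, seeds dL/dL = 1.
fn square_backward(x: &[f64]) -> Vec<f64> {
    assert_eq!(x.len(), 1, "backward() expects a scalar output");
    square_backward_seeded(x, &[1.0])
}
```

A checkpointed segment needs the seeded form because its output is usually a full activation tensor whose upstream grad arrives from the outer tape; the scalar wrapper covers the loss node only.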