autodiff: checkpoint primitive (recompute-on-backward)

Add `xtrain_autodiff::checkpoint::checkpoint(segment_fn, input, params)`, a
higher-order autograd node (à la torch.utils.checkpoint) for activation
recomputation (Phase T13 / KI-3):

- forward: run `segment_fn` on detached leaves so its internal ops are NOT
  recorded on the outer tape; keep only the output value (the local sub-tape —
  and thus the segment's intermediate activations — drops immediately). The
  checkpoint node's parents are [input, ..params].
- backward: re-run `segment_fn` from the saved input + (unchanged) param values
  into a fresh local tape, seed the recomputed output with the upstream grad,
  backprop, then push the recovered input/param grads to the real parents. Local
  tape drops at the end → recomputed activations freed.

Exact by construction (same deterministic kernels, same inputs) → grads match
the non-checkpointed path. Composes with bf16 (T12, same path on recompute) and
DDP (T8, per-rank).
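
The forward/backward split above can be illustrated without any tape. This is a minimal sketch, assuming a toy scalar segment `y = tanh(w[1] * tanh(w[0] * x))`. The names `segment_forward`, `segment_backward`, `grads_stored`, and `grads_recomputed` are hypothetical and not part of `xtrain_autodiff`.

```rust
/// Toy segment (an assumption for illustration): h = tanh(w[0] * x), y = tanh(w[1] * h).
/// Returns the output `y` and the intermediate activation `h`.
fn segment_forward(x: f64, w: [f64; 2]) -> (f64, f64) {
    let h = (w[0] * x).tanh();
    let y = (w[1] * h).tanh();
    (y, h)
}

/// Analytic backward of the toy segment, given its activations and the
/// upstream grad `dy`. Returns (dx, [dw0, dw1]).
fn segment_backward(x: f64, w: [f64; 2], h: f64, y: f64, dy: f64) -> (f64, [f64; 2]) {
    let dz1 = dy * (1.0 - y * y); // through the outer tanh
    let dw1 = dz1 * h;
    let dh = dz1 * w[1];
    let dz0 = dh * (1.0 - h * h); // through the inner tanh
    let dw0 = dz0 * x;
    let dx = dz0 * w[0];
    (dx, [dw0, dw1])
}

/// Non-checkpointed path: the activation `h` from forward is kept until backward.
fn grads_stored(x: f64, w: [f64; 2], dy: f64) -> (f64, [f64; 2]) {
    let (y, h) = segment_forward(x, w);
    segment_backward(x, w, h, y, dy)
}

/// Checkpointed path: forward keeps only the output; backward re-runs the
/// forward from the saved input to rebuild `h`.
fn grads_recomputed(x: f64, w: [f64; 2], dy: f64) -> (f64, [f64; 2]) {
    let (y_out, _dropped) = segment_forward(x, w); // forward: `h` is discarded
    let (y, h) = segment_forward(x, w); // backward: recompute
    debug_assert_eq!(y, y_out); // same inputs, same deterministic ops
    segment_backward(x, w, h, y, dy)
}
```

Because both paths run the same floating-point operations in the same order, `grads_stored(0.7, [1.3, -0.4], 0.5)` and `grads_recomputed(0.7, [1.3, -0.4], 0.5)` compare equal; the only difference is when `h` is alive.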

Supporting change: `Var::backward_seeded(seed)` — backward from an explicit
non-scalar upstream grad (the segment output is generally not a scalar);
`backward()` is now the scalar wrapper that seeds ones.
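
The relationship between the two entry points can be sketched on a toy elementwise op. This is a hedged illustration only: the functions below are free-standing stand-ins that reuse the names `backward_seeded` and `backward`, operate on plain slices, and are not the `Var` methods from this commit.

```rust
/// Toy op (an assumption for illustration): y_i = w_i * x_i.
/// Backward from an explicit upstream grad `seed`, which must have the
/// same length as the output: dx_i = seed_i * w_i.
fn backward_seeded(w: &[f64], seed: &[f64]) -> Vec<f64> {
    assert_eq!(w.len(), seed.len(), "seed must match the output shape");
    w.iter().zip(seed).map(|(wi, si)| wi * si).collect()
}

/// Scalar wrapper: only valid for a one-element output; seeds dL/dL = 1.
fn backward(w: &[f64]) -> Vec<f64> {
    assert_eq!(w.len(), 1, "backward() expects a scalar output");
    backward_seeded(w, &[1.0])
}
```

A checkpointed segment needs the seeded form because its output is an activation tensor, so the upstream grad is whatever the rest of the graph produced rather than a constant one.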

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:42:31 +08:00
parent 9c557f0609
commit c396b39483
3 changed files with 117 additions and 2 deletions


@@ -0,0 +1,103 @@
//! Activation recomputation / gradient checkpointing (Phase T13, KI-3).
//!
//! A higher-order autograd primitive — the analogue of `torch.utils.checkpoint`.
//! It runs a *segment* of the model (a transformer block, here) WITHOUT recording
//! the segment's internal ops on the surrounding tape, so the segment's
//! intermediate activations are freed right after the forward instead of being
//! kept alive until backward. When the segment's output-grad arrives in backward,
//! the segment forward is **re-run** from the saved input (into a throwaway local
//! tape), the recomputed output is seeded with the upstream grad, and the gradient
//! is backpropagated through the local tape to recover the input-grad and the
//! parameter-grads — which are then pushed to the real tape's parents. The local
//! tape is dropped at the end of the closure, freeing the recomputed activations.
//!
//! ## Why it is exact (the hard gate)
//! The recompute runs the *same* `segment_fn` from the *same* input value and the
//! *same* parameter values (parameters are leaves that persist across forward and
//! backward; only their grad slot changes). The forward kernels are deterministic,
//! so the recomputed output equals the original output bit-for-bit, and the local
//! backward is the ordinary analytic backward of that segment. Therefore the input-
//! and parameter-grads are identical to those a non-checkpointed forward would
//! produce — checkpointing trades compute (one extra forward per segment) for
//! memory, never correctness.
//!
//! ## Composition
//! - **bf16 (T12):** `segment_fn` is the unchanged block forward, so the recompute
//!   runs the same bf16 path; the `cast` op's grad upcast still bridges bf16→fp32.
//! - **DDP (T8):** each rank checkpoints its own forward/backward independently;
//!   the param-grads recovered here feed the same per-rank `.grad()` slots that the
//!   all-reduce averages — no change to the distributed path.
//! - **batched (T10):** the segment input/output carry the `[batch*seq, …]` batch
//!   dim transparently; `checkpoint` is shape-agnostic.
#![cfg(not(no_cuda))]
use crate::tape::Var;
use std::rc::Rc;
/// Run `segment_fn(input, params)` with activation recomputation.
///
/// `segment_fn(x, p)` must build the segment's forward graph from a single input
/// `x` and the parameter slice `p`, returning the single segment output. It is
/// called once now (forward, result detached from the outer tape) and once per
/// backward (recompute). It MUST be deterministic and depend only on `x` and `p`
/// (this is what makes the recompute exact).
///
/// `params` are the segment's learnable leaves; their grads are accumulated into
/// the SAME leaves the optimizer reads (so DDP / AdamW are unchanged).
///
/// Returns the segment output as a `Var` on the outer tape whose backward triggers
/// the recompute. Equivalent — grad-for-grad — to calling `segment_fn(input,
/// params)` directly, but without keeping the segment's internal activations alive.
pub fn checkpoint<F>(segment_fn: F, input: &Var, params: &[Var]) -> Var
where
    F: Fn(&Var, &[Var]) -> Var + 'static,
{
    let segment_fn = Rc::new(segment_fn);
    // --- Forward (no taping of internals) ---
    // Detach the input and params into fresh leaves so `segment_fn` builds a LOCAL
    // tape disconnected from the outer graph. We only keep the output's value; the
    // local `Var`s (and thus the segment's intermediate activations) are dropped
    // when this scope ends.
    let out_value = {
        let x_det = Var::leaf(input.value());
        let params_det: Vec<Var> = params.iter().map(|p| Var::leaf(p.value())).collect();
        let out_local = segment_fn(&x_det, &params_det);
        out_local.value()
    };
    // Parents on the OUTER tape: the segment input, then the params (so their grads
    // land in the leaves the optimizer reads).
    let mut parents = Vec::with_capacity(1 + params.len());
    parents.push(input.clone());
    parents.extend(params.iter().cloned());
    let segment_fn = segment_fn.clone();
    Var::from_op(
        out_value,
        parents,
        Box::new(move |dout, parents| {
            // --- Backward (recompute) ---
            // Rebuild fresh leaves from the CURRENT input/param values (params are
            // unchanged since forward; input is the saved segment input), re-run the
            // forward to rebuild the local tape, seed the recomputed output with the
            // upstream grad, and backprop through the local tape.
            let x_det = Var::leaf(parents[0].value());
            let params_det: Vec<Var> = parents[1..].iter().map(|p| Var::leaf(p.value())).collect();
            let out_local = segment_fn(&x_det, &params_det);
            out_local.backward_seeded(dout.clone());
            // Push the recovered grads to the real parents (engine SUMs on fan-out).
            if let Some(dx) = x_det.grad() {
                Var::push_grad(&parents[0], dx);
            }
            for (det, parent) in params_det.iter().zip(&parents[1..]) {
                if let Some(dp) = det.grad() {
                    Var::push_grad(parent, dp);
                }
            }
            // `out_local` / the local tape drop here → recomputed activations freed.
        }),
    )
}
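
The control flow of `checkpoint` can be exercised end to end on a toy scalar tape. The sketch below is an assumption-laden stand-in, not the crate's implementation: `Var`, `Node`, `BackFn`, `mul`, `tanh`, `detach`, `block`, and `run` are all hypothetical here, values are plain `f64` instead of tensors, and only the shape of the forward/backward logic mirrors the diff above.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Toy scalar tape. Every name is illustrative and only mirrors the diff above.
type BackFn = Box<dyn Fn(f64, &[Var])>;
struct Node { value: f64, grad: Option<f64>, parents: Vec<Var>, back: Option<BackFn> }
#[derive(Clone)]
struct Var(Rc<RefCell<Node>>);

impl Var {
    fn node(value: f64, parents: Vec<Var>, back: Option<BackFn>) -> Var {
        Var(Rc::new(RefCell::new(Node { value, grad: None, parents, back })))
    }
    fn leaf(value: f64) -> Var { Var::node(value, Vec::new(), None) }
    fn from_op(value: f64, parents: Vec<Var>, back: BackFn) -> Var {
        Var::node(value, parents, Some(back))
    }
    fn value(&self) -> f64 { self.0.borrow().value }
    fn grad(&self) -> Option<f64> { self.0.borrow().grad }
    // Sum a gradient contribution into this node (handles fan-out).
    fn push_grad(&self, g: f64) {
        let mut n = self.0.borrow_mut();
        n.grad = Some(n.grad.unwrap_or(0.0) + g);
    }
    // Backward from an explicit upstream grad `seed`.
    fn backward_seeded(&self, seed: f64) {
        let (mut topo, mut seen) = (Vec::new(), Vec::new());
        build_topo(self, &mut topo, &mut seen);
        self.push_grad(seed);
        for v in topo.iter().rev() {
            let n = v.0.borrow();
            if let (Some(g), Some(back)) = (n.grad, n.back.as_ref()) {
                back(g, &n.parents[..]);
            }
        }
    }
}

// Post-order DFS: parents are pushed before the nodes that consume them.
fn build_topo(v: &Var, topo: &mut Vec<Var>, seen: &mut Vec<*const RefCell<Node>>) {
    let ptr = Rc::as_ptr(&v.0);
    if seen.contains(&ptr) { return; }
    seen.push(ptr);
    for p in v.0.borrow().parents.iter() { build_topo(p, topo, seen); }
    topo.push(v.clone());
}

fn mul(a: &Var, b: &Var) -> Var {
    let (av, bv) = (a.value(), b.value());
    Var::from_op(av * bv, vec![a.clone(), b.clone()], Box::new(move |g: f64, ps: &[Var]| {
        ps[0].push_grad(g * bv);
        ps[1].push_grad(g * av);
    }))
}

fn tanh(a: &Var) -> Var {
    let y = a.value().tanh();
    Var::from_op(y, vec![a.clone()], Box::new(move |g: f64, ps: &[Var]| {
        ps[0].push_grad(g * (1.0 - y * y));
    }))
}

// Copy values into fresh leaves that are disconnected from any existing tape.
fn detach(vs: &[Var]) -> Vec<Var> { vs.iter().map(|v| Var::leaf(v.value())).collect() }

fn checkpoint<F>(f: F, input: &Var, params: &[Var]) -> Var
where
    F: Fn(&Var, &[Var]) -> Var + 'static,
{
    // Forward: run on detached leaves and keep only the output value.
    let out_value = f(&Var::leaf(input.value()), &detach(params)[..]).value();
    let mut parents = vec![input.clone()];
    parents.extend(params.iter().cloned());
    Var::from_op(out_value, parents, Box::new(move |g: f64, parents: &[Var]| {
        // Backward: recompute on fresh leaves, seed with the upstream grad.
        let (x, ps) = (Var::leaf(parents[0].value()), detach(&parents[1..]));
        f(&x, &ps[..]).backward_seeded(g);
        if let Some(dx) = x.grad() { parents[0].push_grad(dx); }
        for (d, p) in ps.iter().zip(&parents[1..]) {
            if let Some(dp) = d.grad() { p.push_grad(dp); }
        }
    }))
}

// Toy "block": y = tanh(p[1] * tanh(p[0] * x)).
fn block(x: &Var, p: &[Var]) -> Var { tanh(&mul(&p[1], &tanh(&mul(&p[0], x)))) }

// Returns (loss, dx, dw0, dw1) for loss = y * y, with or without checkpointing.
fn run(use_checkpoint: bool) -> (f64, f64, f64, f64) {
    let (x, w) = (Var::leaf(0.7), vec![Var::leaf(1.3), Var::leaf(-0.4)]);
    let y = if use_checkpoint { checkpoint(block, &x, &w) } else { block(&x, &w) };
    let loss = mul(&y, &y);
    loss.backward_seeded(1.0);
    (loss.value(), x.grad().unwrap(), w[0].grad().unwrap(), w[1].grad().unwrap())
}
```

In this toy, `run(false)` and `run(true)` return the same tuple: the checkpointed node hides the block's internals from the outer tape, yet the grads that land on `x` and `w` come from the same operations applied to the same values.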


@@ -18,6 +18,8 @@ pub use finite_diff::{GradCheckConfig, GradCheckResult, ParamFn, grad_check};
// kernels via xtrain-tensor, so they are gated behind `not(no_cuda)` (the
// per-crate convention); the grad_check harness above stays host-only.
#[cfg(not(no_cuda))]
pub mod checkpoint;
#[cfg(not(no_cuda))]
pub mod ops;
#[cfg(not(no_cuda))]
pub mod tape;


@@ -108,14 +108,24 @@ impl Var {
            "backward() expects a scalar loss; got shape {:?}",
            self.value().shape()
        );
        self.backward_seeded(ones_like(&self.value()));
    }
    /// Reverse-mode backward from this node seeded with an explicit upstream grad
    /// `seed` (same shape as this node's value), instead of the scalar `dL/dL = 1`.
    ///
    /// This is the entry point for **activation recomputation** (Phase T13): a
    /// checkpointed segment re-runs its forward into a fresh local tape, then
    /// backprops the upstream output-grad through it via this method (the segment
    /// output is generally NOT a scalar). For a scalar root, [`Self::backward`] is
    /// the thin wrapper that seeds ones.
    pub fn backward_seeded(&self, seed: Tensor) {
        // 1. Topological order (post-order DFS), parents before children.
        let mut topo: Vec<Var> = Vec::new();
        let mut visited: Vec<*const RefCell<VarNode>> = Vec::new();
        build_topo(self, &mut topo, &mut visited);
-        // 2. Seed the loss gradient with ones.
-        let seed = ones_like(&self.value());
+        // 2. Seed this node's gradient with the supplied upstream grad.
        self.accumulate(seed);
        // 3. Walk in reverse: each node hands its grad to its parents' closures.