xtrain/crates/xtrain-autodiff
Gahow Wang c396b39483 autodiff: checkpoint primitive (recompute-on-backward)
Add `xtrain_autodiff::checkpoint::checkpoint(segment_fn, input, params)`, a
higher-order autograd node (à la torch.utils.checkpoint) for activation
recomputation (Phase T13 / KI-3):

- forward: run `segment_fn` on detached leaves so its internal ops are NOT
  recorded on the outer tape; keep only the output value. The local sub-tape,
  and with it the segment's intermediate activations, drops immediately. The
  checkpoint node's parents are [input, ..params].
- backward: re-run `segment_fn` from the saved input + (unchanged) param values
  into a fresh local tape, seed the recomputed output with the upstream grad,
  backprop, then push the recovered input/param grads to the real parents. Local
  tape drops at the end → recomputed activations freed.
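
The two steps above can be sketched on a toy scalar tape. This is an
illustrative sketch only, not the crate's code: `Tape`, `segment`,
`Checkpoint`, `checkpoint_forward` and `checkpoint_backward` are hypothetical
names, and the real `checkpoint(segment_fn, input, params)` operates on tensors
and `Var`s rather than `f64` indices.

```rust
// Minimal scalar tape (hypothetical, unrelated to the real xtrain types).
// Each node stores its value and (parent index, local derivative) pairs.
struct Tape {
    vals: Vec<f64>,
    parents: Vec<Vec<(usize, f64)>>,
}

impl Tape {
    fn new() -> Self {
        Tape { vals: Vec::new(), parents: Vec::new() }
    }
    fn push(&mut self, v: f64, p: Vec<(usize, f64)>) -> usize {
        self.vals.push(v);
        self.parents.push(p);
        self.vals.len() - 1
    }
    fn leaf(&mut self, v: f64) -> usize {
        self.push(v, Vec::new())
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.vals[a], self.vals[b]);
        self.push(va * vb, vec![(a, vb), (b, va)])
    }
    fn tanh(&mut self, a: usize) -> usize {
        let t = self.vals[a].tanh();
        self.push(t, vec![(a, 1.0 - t * t)])
    }
    // Reverse pass from `out`, seeded with an explicit upstream gradient.
    // Nodes are appended in topological order, so a reverse scan suffices.
    fn backward_seeded(&self, out: usize, seed: f64) -> Vec<f64> {
        let mut g = vec![0.0; self.vals.len()];
        g[out] = seed;
        for i in (0..=out).rev() {
            let gi = g[i];
            for &(p, d) in &self.parents[i] {
                g[p] += gi * d;
            }
        }
        g
    }
}

// The segment being checkpointed: y = tanh(w * x).
fn segment(t: &mut Tape, x: usize, w: usize) -> usize {
    let h = t.mul(x, w);
    t.tanh(h)
}

// What the checkpoint node keeps: input and param values plus the output.
struct Checkpoint {
    x: f64,
    w: f64,
    y: f64,
}

// Forward: run the segment on a throwaway local tape and keep only the output.
fn checkpoint_forward(x: f64, w: f64) -> Checkpoint {
    let mut local = Tape::new();
    let (xi, wi) = (local.leaf(x), local.leaf(w));
    let out = segment(&mut local, xi, wi);
    let y = local.vals[out];
    Checkpoint { x, w, y } // `local` (and its intermediates) drops here
}

// Backward: recompute on a fresh local tape, seed with the upstream grad,
// and return the recovered (input, param) grads for the real parents.
fn checkpoint_backward(c: &Checkpoint, upstream: f64) -> (f64, f64) {
    let mut local = Tape::new();
    let (xi, wi) = (local.leaf(c.x), local.leaf(c.w));
    let out = segment(&mut local, xi, wi);
    let g = local.backward_seeded(out, upstream);
    (g[xi], g[wi])
}
```

In this sketch, only `x`, `w` and `y` outlive the forward call; the
intermediate `w * x` exists only while a local tape is alive.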

Exact by construction: the recompute runs the same deterministic kernels on the
same inputs, so grads match the non-checkpointed path. Composes with bf16 (T12;
the recompute takes the same bf16 path) and DDP (T8; checkpointing is per-rank).

Supporting change: `Var::backward_seeded(seed)` — backward from an explicit
non-scalar upstream grad (the segment output is generally not a scalar);
`backward()` is now the scalar wrapper that seeds ones.
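
A rough sketch of the seeded/unseeded split, assuming a single elementwise
node with a hand-written vector-Jacobian product. `SquareNode` is a
hypothetical stand-in; the real `Var::backward_seeded` walks a whole tape and
its exact signature is not shown here.

```rust
// Hypothetical elementwise node y = x * x, used only to illustrate seeding.
struct SquareNode {
    x: Vec<f64>,
}

impl SquareNode {
    fn forward(&self) -> Vec<f64> {
        self.x.iter().map(|v| v * v).collect()
    }

    // Backward from an explicit upstream gradient shaped like the output:
    // dL/dx[i] = seed[i] * 2 * x[i].
    fn backward_seeded(&self, seed: &[f64]) -> Vec<f64> {
        assert_eq!(seed.len(), self.x.len(), "seed must match output shape");
        self.x.iter().zip(seed).map(|(x, s)| 2.0 * x * s).collect()
    }

    // Convenience wrapper: seed with ones and delegate.
    fn backward(&self) -> Vec<f64> {
        let ones = vec![1.0; self.x.len()];
        self.backward_seeded(&ones)
    }
}
```

A checkpoint node would call the seeded form with whatever gradient arrives
from downstream, since a segment output is rarely a scalar loss.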

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 09:42:31 +08:00