docs: Phase T13 — activation recompute

Design doc for per-block gradient checkpointing (KI-3): the no-tape forward +
recompute-on-backward design, the `checkpoint` primitive, per-block wrapping,
the exactness/correctness argument (same kernels + inputs → identical grads),
composition with bf16+DDP+batched, and the verification plan (on-vs-off grad
gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD
to fill after the dash5 run.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 09:43:56 +08:00
parent f202351be5
commit 69c5f07359
2 changed files with 94 additions and 1 deletions

View File

@@ -48,8 +48,13 @@ fn build(cfg: Config, device: Device, dtype: DType, recompute: bool) -> TinyTran
m.with_compute_dtype(dtype).with_recompute(recompute)
}
/// Upcast to fp32 then read to host — logits are bf16 in bf16 mode (grads are
/// always fp32 master, but this is uniform and harmless for fp32 tensors).
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
t.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {