docs: Phase T13 — activation recompute
Design doc for per-block gradient checkpointing (KI-3): the no-tape forward + recompute-on-backward design, the `checkpoint` primitive, per-block wrapping, the exactness/correctness argument (same kernels + inputs → identical grads), composition with bf16+DDP+batched, and the verification plan (on-vs-off grad gate + memory/throughput before→after, dim1024-fits). Bench table left as TBD to fill after the dash5 run. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -48,8 +48,13 @@ fn build(cfg: Config, device: Device, dtype: DType, recompute: bool) -> TinyTran
|
||||
m.with_compute_dtype(dtype).with_recompute(recompute)
|
||||
}
|
||||
|
||||
/// Upcast to fp32 then read to host — logits are bf16 in bf16 mode (grads are
|
||||
/// always fp32 master, but this is uniform and harmless for fp32 tensors).
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {
|
||||
|
||||
Reference in New Issue
Block a user