Files
xtrain/crates/xtrain-autodiff/src/ops.rs
Gahow Wang 0e20821633 autodiff+model: flash-attention op + --flash opt-in wiring
ops::flash_attention autograd node (fwd caches O(N) logsumexp instead of
O(N²) probs; bwd via Tensor::flash_attention_backward). Model gets a
use_flash bool + with_flash(bool) builder; the SDPA core in attention()
picks ops::flash_attention vs ops::attention. flash threads through
block_forward so the recompute (T13) segment also runs flash. Default
off = composed path, graph unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:10:32 +08:00

383 lines
14 KiB
Rust

//! Differentiable ops as autograd nodes (Phase T4).
//!
//! Each function runs the forward [`Tensor`] kernel, then builds a [`Var`] whose
//! backward closure computes the analytic gradient (see
//! `docs/03-autograd-engine.md` for the math) and pushes it to each parent via
//! [`Var::push_grad`] (which SUMs — correct under fan-out). Forward outputs that
//! the backward needs (softmax `y`, rms `inv_rms`, cross-entropy `probs`) are
//! cached by moving them into the closure.
//!
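//! Attention can still be composed from `matmul` + `scale` + `softmax` in user
//! code (its backward then falls out of theirs), but this module also provides it
//! as fused nodes: [`attention`] (caches the softmax `probs`) and
//! [`flash_attention`] (caches the forward output and the per-row logsumexp).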
#![cfg(not(no_cuda))]
use crate::tape::Var;
use xtrain_tensor::{DType, Tensor};
/// Differentiable dtype conversion (Phase T12): the AMP bridge between the fp32
/// master weights / fp32 reductions and the bf16 compute stream.
///
/// The forward converts `x` to `target`. The backward converts the upstream grad
/// back to the dtype `x` had on entry, so an fp32 master-weight leaf routed
/// through `cast(w, BF16)` into a bf16 matmul still accumulates an **fp32** grad
/// (AdamW / clip / DDP all-reduce keep seeing fp32). When `x` already has the
/// requested dtype no node is created and a clone of `x` is returned.
pub fn cast(x: &Var, target: DType) -> Var {
    let origin = x.value().dtype();
    if origin != target {
        let converted = x.value().to_dtype(target);
        let sources = vec![x.clone()];
        Var::from_op(
            converted,
            sources,
            Box::new(move |upstream, inputs| {
                // Hand the grad back in the dtype the input was in before the cast.
                let restored = upstream.to_dtype(origin);
                Var::push_grad(&inputs[0], restored);
            }),
        )
    } else {
        x.clone()
    }
}
/// Autograd node for the 2D matrix product `C = A @ B`.
///
/// The backward re-reads both operands from the parents and passes them, with
/// the upstream grad, to [`Tensor::matmul_backward`]; the returned pair
/// (`dA = dC @ Bᵀ`, `dB = Aᵀ @ dC`) is pushed to `a` and `b` in that order.
pub fn matmul(a: &Var, b: &Var) -> Var {
    let product = a.value().matmul(&b.value());
    let operands = vec![a.clone(), b.clone()];
    Var::from_op(
        product,
        operands,
        Box::new(|upstream, inputs| {
            let (lhs_node, rhs_node) = (&inputs[0], &inputs[1]);
            let lhs = lhs_node.value();
            let rhs = rhs_node.value();
            let (grad_lhs, grad_rhs) = Tensor::matmul_backward(&lhs, &rhs, upstream);
            Var::push_grad(lhs_node, grad_lhs);
            Var::push_grad(rhs_node, grad_rhs);
        }),
    )
}
/// Elementwise sum `out = a + b` of two same-shape operands.
///
/// Addition passes the gradient through untouched, so the backward pushes a
/// copy of the upstream grad to each parent.
pub fn add(a: &Var, b: &Var) -> Var {
    let total = a.value().add(&b.value());
    let operands = vec![a.clone(), b.clone()];
    Var::from_op(
        total,
        operands,
        Box::new(|upstream, inputs| {
            let for_lhs = upstream.clone();
            Var::push_grad(&inputs[0], for_lhs);
            let for_rhs = upstream.clone();
            Var::push_grad(&inputs[1], for_rhs);
        }),
    )
}
/// Elementwise (Hadamard) product `out = a * b`.
///
/// Backward: each operand receives the upstream grad multiplied by the *other*
/// operand's forward value (`da = d∘b`, `db = d∘a`).
pub fn mul(a: &Var, b: &Var) -> Var {
    let product = a.value().mul(&b.value());
    let factors = vec![a.clone(), b.clone()];
    Var::from_op(
        product,
        factors,
        Box::new(|upstream, inputs| {
            let (lhs_node, rhs_node) = (&inputs[0], &inputs[1]);
            let lhs = lhs_node.value();
            let rhs = rhs_node.value();
            let grad_lhs = upstream.mul(&rhs);
            Var::push_grad(lhs_node, grad_lhs);
            let grad_rhs = upstream.mul(&lhs);
            Var::push_grad(rhs_node, grad_rhs);
        }),
    )
}
/// Adds a per-column bias to every row: `out[r,c] = x[r,c] + bias[c]`.
///
/// Backward: `x` receives the upstream grad unchanged, while `bias` receives it
/// reduced over the broadcast (row) dimension, `dbias[c] = sum_r d[r,c]`.
pub fn add_bias(x: &Var, bias: &Var) -> Var {
    let shifted = x.value().add_bias(&bias.value());
    let sources = vec![x.clone(), bias.clone()];
    Var::from_op(
        shifted,
        sources,
        Box::new(|upstream, inputs| {
            let (x_node, bias_node) = (&inputs[0], &inputs[1]);
            Var::push_grad(x_node, upstream.clone());
            let grad_bias = upstream.sum_rows();
            Var::push_grad(bias_node, grad_bias);
        }),
    )
}
/// Multiplies `x` by the constant `alpha`: `out = x * alpha`.
///
/// `alpha` is not differentiated; the backward scales the upstream grad by the
/// same factor (`dx = d * alpha`).
pub fn scale(x: &Var, alpha: f32) -> Var {
    let scaled = x.value().scale(alpha);
    let sources = vec![x.clone()];
    Var::from_op(
        scaled,
        sources,
        Box::new(move |upstream, inputs| {
            let grad_x = upstream.scale(alpha);
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// RMSNorm node: `y = x * rsqrt(mean(x²)+eps) * gamma`.
///
/// The forward kernel also returns `inv_rms`, which is moved into the backward
/// closure so [`Tensor::rms_norm_backward`] can reuse it. The backward pushes
/// `dx` to `x` and `dgamma` to `gamma`.
pub fn rms_norm(x: &Var, gamma: &Var, eps: f32) -> Var {
    let (normed, saved_inv_rms) = x.value().rms_norm(&gamma.value(), eps);
    let sources = vec![x.clone(), gamma.clone()];
    Var::from_op(
        normed,
        sources,
        Box::new(move |upstream, inputs| {
            let (x_node, gamma_node) = (&inputs[0], &inputs[1]);
            let x_val = x_node.value();
            let gamma_val = gamma_node.value();
            let (grad_x, grad_gamma) =
                Tensor::rms_norm_backward(&x_val, &gamma_val, upstream, &saved_inv_rms);
            Var::push_grad(x_node, grad_x);
            Var::push_grad(gamma_node, grad_gamma);
        }),
    )
}
/// SiLU activation: `y = x * sigmoid(x)`.
///
/// Nothing is cached: the backward re-reads the forward input from the parent
/// and delegates to [`Tensor::silu_backward`].
pub fn silu(x: &Var) -> Var {
    let activated = x.value().silu();
    let sources = vec![x.clone()];
    Var::from_op(
        activated,
        sources,
        Box::new(|upstream, inputs| {
            let x_node = &inputs[0];
            let forward_input = x_node.value();
            let grad_x = Tensor::silu_backward(&forward_input, upstream);
            Var::push_grad(x_node, grad_x);
        }),
    )
}
/// SwiGLU (SiLU-gated GLU): `out = silu(gate) ∘ up`.
///
/// Built from the existing [`silu`] and [`mul`] nodes, so the gradient is
/// produced by their backward closures — there is no dedicated kernel.
pub fn swiglu(gate: &Var, up: &Var) -> Var {
    let activated_gate = silu(gate);
    mul(&activated_gate, up)
}
/// RoPE (rotate_half) applied to `x:[tokens,heads,head_dim]`.
///
/// The position of a row is `row % period`, where `period` is the sequence
/// length (`period == tokens` for a single sequence). The rotation is an
/// orthogonal map, so the backward is simply the inverse rotation of the
/// upstream grad via [`Tensor::rope_backward`]; no forward values are cached.
pub fn rope(x: &Var, theta: f32, period: usize) -> Var {
    let rotated = x.value().rope(theta, period);
    let sources = vec![x.clone()];
    Var::from_op(
        rotated,
        sources,
        Box::new(move |upstream, inputs| {
            let grad_x = Tensor::rope_backward(upstream, theta, period);
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// Row-wise softmax node.
///
/// A clone of the forward output is kept inside the backward closure, because
/// [`Tensor::softmax_backward`] takes the output `y` (not the input) together
/// with the upstream grad.
pub fn softmax(x: &Var) -> Var {
    let probs = x.value().softmax();
    let saved_probs = probs.clone();
    let sources = vec![x.clone()];
    Var::from_op(
        probs,
        sources,
        Box::new(move |upstream, inputs| {
            let grad_x = Tensor::softmax_backward(&saved_probs, upstream);
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// Token embedding lookup: `out[s,:] = table[ids[s], :]`.
///
/// `table` is the learnable `[vocab,dim]` [`Var`]; `ids` is a constant `[seq]`
/// I32 index tensor and is not part of the graph. The closure keeps its own
/// clone of `ids` plus the vocabulary size (first dim of `table`), and the
/// backward scatter-adds the upstream grad into the table rows through
/// [`Tensor::embedding_backward`].
pub fn embedding(table: &Var, ids: &Tensor) -> Var {
    let weights = table.value();
    let vocab = weights.shape()[0];
    let gathered = weights.embedding(ids);
    let saved_ids = ids.clone();
    Var::from_op(
        gathered,
        vec![table.clone()],
        Box::new(move |upstream, inputs| {
            let grad_table = Tensor::embedding_backward(upstream, &saved_ids, vocab);
            Var::push_grad(&inputs[0], grad_table);
        }),
    )
}
/// Reshape node (contiguous, metadata-only) used for the multi-head layout swap
/// `[seq, h*hd] <-> [seq, h, hd]`.
///
/// The input's shape is recorded up front so the backward can reshape the
/// upstream grad back to it before pushing it to `x`.
pub fn reshape(x: &Var, new_shape: &[usize]) -> Var {
    let source = x.value();
    let original_dims: Vec<usize> = source.shape().to_vec();
    let reshaped = source.reshape(new_shape);
    Var::from_op(
        reshaped,
        vec![x.clone()],
        Box::new(move |upstream, inputs| {
            let grad_x = upstream.reshape(&original_dims);
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// Swaps the first two axes of a 3D tensor: `[a,b,c] -> [b,a,c]`.
///
/// The permutation has a self-inverse structure, so the backward applies the
/// very same transpose to the upstream grad.
pub fn transpose_3d01(x: &Var) -> Var {
    let swapped = x.value().transpose_3d01();
    let sources = vec![x.clone()];
    Var::from_op(
        swapped,
        sources,
        Box::new(|upstream, inputs| {
            let grad_x = upstream.transpose_3d01();
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// Swaps the two middle axes of a 4D tensor: `[a,b,c,d] -> [a,c,b,d]`.
///
/// This is the layout change for batched multi-head attention,
/// `[B,S,nh,hd] <-> [B,nh,S,hd]`. The permutation has a self-inverse structure,
/// so the backward applies the same transpose to the upstream grad.
pub fn transpose_4d12(x: &Var) -> Var {
    let swapped = x.value().transpose_4d12();
    let sources = vec![x.clone()];
    Var::from_op(
        swapped,
        sources,
        Box::new(|upstream, inputs| {
            let grad_x = upstream.transpose_4d12();
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// 2D transpose `[r,c] -> [c,r]` as an autograd node; used to form `Kᵀ` for the
/// attention scores.
///
/// The backward transposes the upstream grad back to the input layout.
pub fn transpose_2d(x: &Var) -> Var {
    let flipped = x.value().transpose_2d();
    let sources = vec![x.clone()];
    Var::from_op(
        flipped,
        sources,
        Box::new(|upstream, inputs| {
            let grad_x = upstream.transpose_2d();
            Var::push_grad(&inputs[0], grad_x);
        }),
    )
}
/// Breaks a `[heads, seq, head_dim]` tensor into one `[seq, head_dim]` [`Var`]
/// per head.
///
/// The tensor is staged on the CPU and read as `f32`; each head occupies a
/// contiguous run of `seq * head_dim` elements, which is copied into its own
/// tensor and moved back to the original device. Every returned node has `x` as
/// its only parent. Its backward writes the head's grad into the matching slot
/// of an otherwise-zero `[heads, seq, head_dim]` buffer and pushes that to `x`
/// (the engine SUMs the per-head contributions on the shared parent — fan-out).
///
/// # Panics
///
/// Panics if `x` is not 3-dimensional.
pub fn split_heads(x: &Var) -> Vec<Var> {
    let packed = x.value();
    assert_eq!(packed.ndim(), 3, "split_heads requires [heads,seq,head_dim]");
    let heads = packed.shape()[0];
    let seq = packed.shape()[1];
    let hd = packed.shape()[2];
    let dev = packed.device();
    let staged = packed.to_device(xtrain_tensor::Device::Cpu);
    let elems = staged.as_slice::<f32>();

    let mut per_head = Vec::with_capacity(heads);
    for h in 0..heads {
        let start = h * seq * hd;
        let stop = start + seq * hd;
        let block = Tensor::from_slice(&elems[start..stop], &[seq, hd]).to_device(dev);
        let node = Var::from_op(
            block,
            vec![x.clone()],
            Box::new(move |upstream, inputs| {
                // Zero everywhere except this head's slot.
                let mut full = vec![0.0f32; heads * seq * hd];
                let grad_host = upstream.to_device(xtrain_tensor::Device::Cpu);
                let offset = h * seq * hd;
                full[offset..offset + seq * hd].copy_from_slice(grad_host.as_slice::<f32>());
                let scattered = Tensor::from_slice(&full, &[heads, seq, hd]).to_device(dev);
                Var::push_grad(&inputs[0], scattered);
            }),
        );
        per_head.push(node);
    }
    per_head
}
/// Inverse of [`split_heads`]: stack per-head `[seq, head_dim]` outputs into a
/// single `[heads, seq, head_dim]` tensor.
///
/// The shape and device are taken from the first head. Each head is staged on
/// the CPU, read as `f32`, and copied into its contiguous slot of the packed
/// buffer, which is then moved to that device. The backward hands each parent
/// its own `[seq, head_dim]` slice of the upstream grad.
///
/// # Panics
///
/// Panics if `heads_v` is empty, if the first head is not 2-dimensional, or if
/// any head's shape differs from the first head's. Previously an empty slice
/// failed with a bare index-out-of-bounds, and a head with the same element
/// count but a different shape was packed silently.
pub fn merge_heads(heads_v: &[Var]) -> Var {
    assert!(!heads_v.is_empty(), "merge_heads requires at least one head");
    let heads = heads_v.len();
    let v0 = heads_v[0].value();
    assert_eq!(v0.ndim(), 2, "merge_heads requires [seq,head_dim] heads");
    let (seq, hd) = (v0.shape()[0], v0.shape()[1]);
    let dev = v0.device();
    let head_len = seq * hd;
    let mut host = vec![0.0f32; heads * head_len];
    for (h, hv) in heads_v.iter().enumerate() {
        let head = hv.value();
        // Same element count is not enough: a transposed head would be packed
        // with the wrong layout, so compare the full shape.
        assert!(
            head.ndim() == 2 && head.shape()[0] == seq && head.shape()[1] == hd,
            "merge_heads: head {} does not match the [seq,head_dim] shape of head 0",
            h
        );
        let block = head.to_device(xtrain_tensor::Device::Cpu);
        let base = h * head_len;
        host[base..base + head_len].copy_from_slice(block.as_slice::<f32>());
    }
    let out = Tensor::from_slice(&host, &[heads, seq, hd]).to_device(dev);
    Var::from_op(
        out,
        heads_v.to_vec(),
        Box::new(move |d, parents| {
            let dhost = d.to_device(xtrain_tensor::Device::Cpu);
            let dflat = dhost.as_slice::<f32>();
            for (h, parent) in parents.iter().enumerate() {
                let base = h * head_len;
                let g = Tensor::from_slice(&dflat[base..base + head_len], &[seq, hd])
                    .to_device(dev);
                Var::push_grad(parent, g);
            }
        }),
    )
}
/// Batched causal scaled-dot-product attention as one fused node.
///
/// `q`, `k` and `v` are each `[bh, seq, head_dim]` (bh = batch·n_heads) and the
/// result has the same shape. Forward is 2 batched GEMMs + 1 causal-softmax
/// kernel; backward is 4 batched GEMMs + 1 softmax-backward kernel. This
/// replaces the per-(batch,head) matmul/softmax loop, so the launch count does
/// not grow with bh. The softmax `probs` returned by the forward are moved into
/// the closure for [`Tensor::attention_backward`].
pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
    let (context, saved_probs) = q.value().attention(&k.value(), &v.value(), scale);
    let sources = vec![q.clone(), k.clone(), v.clone()];
    Var::from_op(
        context,
        sources,
        Box::new(move |upstream, inputs| {
            let (q_node, k_node, v_node) = (&inputs[0], &inputs[1], &inputs[2]);
            let q_val = q_node.value();
            let k_val = k_node.value();
            let v_val = v_node.value();
            let (grad_q, grad_k, grad_v) = Tensor::attention_backward(
                &q_val,
                &k_val,
                &v_val,
                &saved_probs,
                upstream,
                scale,
            );
            Var::push_grad(q_node, grad_q);
            Var::push_grad(k_node, grad_k);
            Var::push_grad(v_node, grad_v);
        }),
    )
}
/// Fused FLASH causal scaled-dot-product attention (Phase T14).
///
/// Takes the same arguments as [`attention`] (`q`, `k`, `v` each
/// `[bh, seq, head_dim]`). The forward is a SINGLE fused kernel running an
/// online softmax over KV tiles, so the `[bh,seq,seq]` score matrix is NEVER
/// materialized. Instead of the O(N²) probs, the closure keeps a clone of the
/// forward output plus the per-row logsumexp (O(N)) and passes both to
/// [`Tensor::flash_attention_backward`]. It is mathematically the same SDPA, so
/// it matches the composed [`attention`] within fp/bf16 tolerance. Opt-in via
/// the model's `--flash` flag; the composed path stays the default.
pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
    let (context, saved_lse) = q.value().flash_attention(&k.value(), &v.value(), scale);
    let saved_out = context.clone();
    let sources = vec![q.clone(), k.clone(), v.clone()];
    Var::from_op(
        context,
        sources,
        Box::new(move |upstream, inputs| {
            let (q_node, k_node, v_node) = (&inputs[0], &inputs[1], &inputs[2]);
            let q_val = q_node.value();
            let k_val = k_node.value();
            let v_val = v_node.value();
            let (grad_q, grad_k, grad_v) = Tensor::flash_attention_backward(
                &q_val,
                &k_val,
                &v_val,
                &saved_out,
                &saved_lse,
                upstream,
                scale,
            );
            Var::push_grad(q_node, grad_q);
            Var::push_grad(k_node, grad_k);
            Var::push_grad(v_node, grad_v);
        }),
    )
}
/// Mean cross-entropy loss over logits `x:[rows,cols]`, with one I32 target per
/// row. The result is a scalar (`[1]`) [`Var`].
///
/// Backward: `dx = (probs - onehot)/rows`, multiplied by the upstream scalar
/// grad and converted to the dtype of the incoming logits.
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
    // The CE math runs in fp32 (cross_entropy upcasts bf16 logits internally and
    // caches fp32 probs), but the grad has to come back in the logits' dtype so
    // it chains into a bf16 lm_head matmul backward. Leaving the logits in bf16
    // (no persistent fp32 logits buffer) saves real activation memory at large
    // vocab.
    let logits = x.value();
    let logit_dtype = logits.dtype();
    let rows = logits.shape()[0];
    let device = logits.device();
    let (probs, per_row) = logits.cross_entropy(target);

    // Reduce the per-row losses on the host, then wrap the mean in a [1] tensor.
    let per_row_host = per_row.to_device(xtrain_tensor::Device::Cpu);
    let total: f32 = per_row_host.as_slice::<f32>().iter().sum::<f32>();
    let mean_loss = total / rows as f32;
    let loss = Tensor::from_slice(&[mean_loss], &[1]).to_device(device);

    let labels = target.clone();
    Var::from_op(
        loss,
        vec![x.clone()],
        Box::new(move |upstream, inputs| {
            // The upstream grad is a scalar (1.0 when this node is the loss root).
            let upstream_host = upstream.to_device(xtrain_tensor::Device::Cpu);
            let seed = upstream_host.as_slice::<f32>()[0];
            let factor = seed / rows as f32;
            let grad_logits = Tensor::cross_entropy_backward(&probs, &labels, factor);
            Var::push_grad(&inputs[0], grad_logits.to_dtype(logit_dtype));
        }),
    )
}