autodiff+model: flash-attention op + --flash opt-in wiring

ops::flash_attention autograd node (fwd caches O(N) logsumexp instead of
O(N²) probs; bwd via Tensor::flash_attention_backward). Model gets a
use_flash bool + with_flash(bool) builder; the SDPA core in attention()
picks ops::flash_attention vs ops::attention. flash threads through
block_forward so the recompute (T13) segment also runs flash. Default
off = composed path, graph unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:10:32 +08:00
parent 326a6fadfe
commit 0e20821633
2 changed files with 73 additions and 7 deletions

View File

@@ -325,6 +325,32 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
)
}
/// Fused FLASH causal scaled-dot-product attention (Phase T14). Same interface as
/// [`attention`] (`q`,`k`,`v` each `[bh, seq, head_dim]`), but the forward is a
/// SINGLE fused kernel with an online softmax over KV tiles — the `[bh,seq,seq]`
/// score matrix is NEVER materialized, and backward caches only the per-row
/// logsumexp (O(N)) instead of the whole probs (O(N²)). Mathematically the same
/// SDPA, so it matches the composed [`attention`] within fp/bf16 tolerance.
/// Opt-in via the model's `--flash` flag; the composed path stays the default.
pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
let (out, lse) = q.value().flash_attention(&k.value(), &v.value(), scale);
let out_cache = out.clone();
Var::from_op(
out,
vec![q.clone(), k.clone(), v.clone()],
Box::new(move |dout, parents| {
let q = parents[0].value();
let k = parents[1].value();
let v = parents[2].value();
let (dq, dk, dv) =
Tensor::flash_attention_backward(&q, &k, &v, &out_cache, &lse, dout, scale);
Var::push_grad(&parents[0], dq);
Var::push_grad(&parents[1], dk);
Var::push_grad(&parents[2], dv);
}),
)
}
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// scaled by the upstream scalar grad.

View File

@@ -47,6 +47,14 @@ pub struct TinyTransformer {
/// existing numerics are bit-identical; recompute is mathematically exact, so
/// grads match the non-checkpointed path within fp tolerance.
recompute: bool,
/// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through
/// the hand-written single fused kernel ([`ops::flash_attention`]): online
/// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized,
/// backward caches only the O(N) logsumexp. Default `false` = the composed T10
/// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs),
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
use_flash: bool,
}
impl TinyTransformer {
@@ -90,6 +98,7 @@ impl TinyTransformer {
lm_head,
compute_dtype: DType::F32,
recompute: false,
use_flash: false,
}
}
@@ -127,6 +136,19 @@ impl TinyTransformer {
self.recompute
}
/// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and
/// opt-in; default off keeps the composed T10 path (so the default graph is
/// unchanged). On, the SDPA runs through [`ops::flash_attention`] — same SDPA
/// math, online softmax, no materialized `[bh,seq,seq]` scores.
pub fn with_flash(mut self, use_flash: bool) -> Self {
self.use_flash = use_flash;
self
}
pub fn use_flash(&self) -> bool {
self.use_flash
}
/// All learnable parameters, in a stable order. The optimizer (a hand-written
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
/// `backward()`.
@@ -189,13 +211,16 @@ impl TinyTransformer {
// the block forward is re-run in backward to recover the grads. The
// segment fn captures only `Copy` config (no borrow of `self`) and
// receives the block's params via the slice, in `block_params` order.
let (cfg, cdt) = (self.cfg, self.compute_dtype);
let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p);
// `flash` is captured too → the recompute segment also runs flash.
let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
let seg =
move |x: &Var, p: &[Var]| block_forward(cfg, cdt, flash, batch, seq, x, p);
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
} else {
block_forward(
self.cfg,
self.compute_dtype,
self.use_flash,
batch,
seq,
&h,
@@ -279,7 +304,16 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
/// seq, input, params)` (no `&self`) so it can be the segment fn of
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
/// the block's leaves in [`Block::block_params`] order.
fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var {
#[allow(clippy::too_many_arguments)]
fn block_forward(
cfg: Config,
cdt: DType,
flash: bool,
batch: usize,
seq: usize,
h: &Var,
p: &[Var],
) -> Var {
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
@@ -287,7 +321,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
// --- Attention sub-block (pre-norm + residual) ---
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
let attn = attention(
cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
);
let h = ops::add(h, &attn);
@@ -308,6 +342,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
fn attention(
cfg: Config,
cdt: DType,
flash: bool,
batch: usize,
seq: usize,
x: &Var,
@@ -351,9 +386,14 @@ fn attention(
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
let v = to_bh(linear(cdt, x, wv), None);
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd]
// Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
// single fused flash kernel (online softmax, no materialized [bh,S,S] scores);
// otherwise the composed T10 path (2 batched GEMMs + 1 causal-softmax kernel).
let out = if flash {
ops::flash_attention(&q, &k, &v, scale) // [B*nh, S, hd]
} else {
ops::attention(&q, &k, &v, scale) // [B*nh, S, hd]
};
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
// [B,S,nh,hd] → [B*S, dim].