From 0e20821633ae787b463ca1d372afa2e9d8d8a77b Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Wed, 17 Jun 2026 23:10:32 +0800 Subject: [PATCH] autodiff+model: flash-attention op + --flash opt-in wiring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ops::flash_attention autograd node (fwd caches O(N) logsumexp instead of O(N²) probs; bwd via Tensor::flash_attention_backward). Model gets a use_flash bool + with_flash(bool) builder; the SDPA core in attention() picks ops::flash_attention vs ops::attention. flash threads through block_forward so the recompute (T13) segment also runs flash. Default off = composed path, graph unchanged. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-autodiff/src/ops.rs | 26 +++++++++++++++ crates/xtrain-model/src/model.rs | 54 +++++++++++++++++++++++++++---- 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/crates/xtrain-autodiff/src/ops.rs b/crates/xtrain-autodiff/src/ops.rs index 0a5e489..0ddc394 100644 --- a/crates/xtrain-autodiff/src/ops.rs +++ b/crates/xtrain-autodiff/src/ops.rs @@ -325,6 +325,32 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var { ) } +/// Fused FLASH causal scaled-dot-product attention (Phase T14). Same interface as +/// [`attention`] (`q`,`k`,`v` each `[bh, seq, head_dim]`), but the forward is a +/// SINGLE fused kernel with an online softmax over KV tiles — the `[bh,seq,seq]` +/// score matrix is NEVER materialized, and backward keeps the output plus the +/// per-row logsumexp (O(N)) instead of the whole probs (O(N²)). Mathematically the +/// same SDPA, so it matches the composed [`attention`] within fp/bf16 tolerance. +/// Opt-in via the model's `--flash` flag; the composed path stays the default. 
+pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var { + let (out, lse) = q.value().flash_attention(&k.value(), &v.value(), scale); + let out_cache = out.clone(); + Var::from_op( + out, + vec![q.clone(), k.clone(), v.clone()], + Box::new(move |dout, parents| { + let q = parents[0].value(); + let k = parents[1].value(); + let v = parents[2].value(); + let (dq, dk, dv) = + Tensor::flash_attention_backward(&q, &k, &v, &out_cache, &lse, dout, scale); + Var::push_grad(&parents[0], dq); + Var::push_grad(&parents[1], dk); + Var::push_grad(&parents[2], dv); + }), + ) +} + /// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per /// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`, /// scaled by the upstream scalar grad. diff --git a/crates/xtrain-model/src/model.rs b/crates/xtrain-model/src/model.rs index 830f068..e8c0552 100644 --- a/crates/xtrain-model/src/model.rs +++ b/crates/xtrain-model/src/model.rs @@ -47,6 +47,14 @@ pub struct TinyTransformer { /// existing numerics are bit-identical; recompute is mathematically exact, so /// grads match the non-checkpointed path within fp tolerance. recompute: bool, + /// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through + /// the hand-written single fused kernel ([`ops::flash_attention`]): online + /// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized, + /// backward keeps the output and the O(N) logsumexp. Default `false` = the composed T10 + /// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs), + /// so the default graph is unchanged. Mathematically the same SDPA → grads/loss + /// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`. 
+ use_flash: bool, } impl TinyTransformer { @@ -90,6 +98,7 @@ impl TinyTransformer { lm_head, compute_dtype: DType::F32, recompute: false, + use_flash: false, } } @@ -127,6 +136,19 @@ impl TinyTransformer { self.recompute } + /// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and + /// opt-in; default off keeps the composed T10 path (so the default graph is + /// unchanged). When on, the SDPA runs through [`ops::flash_attention`] — same SDPA + /// math, online softmax, no materialized `[bh,seq,seq]` scores. Returns `self`. + pub fn with_flash(mut self, use_flash: bool) -> Self { + self.use_flash = use_flash; + self + } + + pub fn use_flash(&self) -> bool { + self.use_flash + } + /// All learnable parameters, in a stable order. The optimizer (a hand-written /// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after /// `backward()`. @@ -189,13 +211,16 @@ impl TinyTransformer { // the block forward is re-run in backward to recover the grads. The // segment fn captures only `Copy` config (no borrow of `self`) and // receives the block's params via the slice, in `block_params` order. - let (cfg, cdt) = (self.cfg, self.compute_dtype); - let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p); + // `flash` is captured too → the recompute segment also runs flash. + let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash); + let seg = + move |x: &Var, p: &[Var]| block_forward(cfg, cdt, flash, batch, seq, x, p); xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params()) } else { block_forward( self.cfg, self.compute_dtype, + self.use_flash, batch, seq, &h, @@ -279,7 +304,16 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var { /// seq, input, params)` (no `&self`) so it can be the segment fn of /// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is /// the block's leaves in [`Block::block_params`] order. 
-fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var { +#[allow(clippy::too_many_arguments)] +fn block_forward( + cfg: Config, + cdt: DType, + flash: bool, + batch: usize, + seq: usize, + h: &Var, + p: &[Var], +) -> Var { let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]); let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]); let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]); @@ -287,7 +321,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: // --- Attention sub-block (pre-norm + residual) --- let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps); let attn = attention( - cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo, + cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo, ); let h = ops::add(h, &attn); @@ -308,6 +342,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: fn attention( cfg: Config, cdt: DType, + flash: bool, batch: usize, seq: usize, x: &Var, @@ -351,9 +386,14 @@ fn attention( let k = to_bh(linear(cdt, x, wk), Some(k_norm)); let v = to_bh(linear(cdt, x, wv), None); - // Fused batched causal SDPA over all B*nh (sequence,head) blocks at once - // (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop). - let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd] + // Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the + // single fused flash kernel (online softmax, no materialized [bh,S,S] scores); + // otherwise the composed T10 path (2 batched GEMMs + 1 causal-softmax kernel). + let out = if flash { + ops::flash_attention(&q, &k, &v, scale) // [B*nh, S, hd] + } else { + ops::attention(&q, &k, &v, scale) // [B*nh, S, hd] + }; // Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) → // [B,S,nh,hd] → [B*S, dim].