autodiff+model: flash-attention op + --flash opt-in wiring
Adds an `ops::flash_attention` autograd node: the forward pass caches the O(N) logsumexp instead of the O(N²) probs, and the backward pass goes through `Tensor::flash_attention_backward`. The model gets a `use_flash` bool plus a `with_flash(bool)` builder; the SDPA core in `attention()` picks `ops::flash_attention` vs `ops::attention`. The `flash` flag threads through `block_forward`, so the recompute (T13) segment also runs flash. Default off = composed path, graph unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -325,6 +325,32 @@ pub fn attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
)
|
||||
}
|
||||
|
||||
/// Fused FLASH causal scaled-dot-product attention (Phase T14). Same interface as
|
||||
/// [`attention`] (`q`,`k`,`v` each `[bh, seq, head_dim]`), but the forward is a
|
||||
/// SINGLE fused kernel with an online softmax over KV tiles — the `[bh,seq,seq]`
|
||||
/// score matrix is NEVER materialized, and backward caches only the per-row
|
||||
/// logsumexp (O(N)) instead of the whole probs (O(N²)). Mathematically the same
|
||||
/// SDPA, so it matches the composed [`attention`] within fp/bf16 tolerance.
|
||||
/// Opt-in via the model's `--flash` flag; the composed path stays the default.
|
||||
pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
|
||||
let (out, lse) = q.value().flash_attention(&k.value(), &v.value(), scale);
|
||||
let out_cache = out.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![q.clone(), k.clone(), v.clone()],
|
||||
Box::new(move |dout, parents| {
|
||||
let q = parents[0].value();
|
||||
let k = parents[1].value();
|
||||
let v = parents[2].value();
|
||||
let (dq, dk, dv) =
|
||||
Tensor::flash_attention_backward(&q, &k, &v, &out_cache, &lse, dout, scale);
|
||||
Var::push_grad(&parents[0], dq);
|
||||
Var::push_grad(&parents[1], dk);
|
||||
Var::push_grad(&parents[2], dv);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
|
||||
@@ -47,6 +47,14 @@ pub struct TinyTransformer {
|
||||
/// existing numerics are bit-identical; recompute is mathematically exact, so
|
||||
/// grads match the non-checkpointed path within fp tolerance.
|
||||
recompute: bool,
|
||||
/// Fused flash-attention (Phase T14). When `true`, the SDPA core runs through
|
||||
/// the hand-written single fused kernel ([`ops::flash_attention`]): online
|
||||
/// softmax over KV tiles, the `[bh,seq,seq]` score matrix NEVER materialized,
|
||||
/// backward caches only the O(N) logsumexp. Default `false` = the composed T10
|
||||
/// path (`cublasSgemmStridedBatched`×2 + causal-softmax kernel, O(N²) probs),
|
||||
/// so the default graph is unchanged. Mathematically the same SDPA → grads/loss
|
||||
/// match the composed path within fp/bf16 tolerance. Opt-in via `--flash`.
|
||||
use_flash: bool,
|
||||
}
|
||||
|
||||
impl TinyTransformer {
|
||||
@@ -90,6 +98,7 @@ impl TinyTransformer {
|
||||
lm_head,
|
||||
compute_dtype: DType::F32,
|
||||
recompute: false,
|
||||
use_flash: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,6 +136,19 @@ impl TinyTransformer {
|
||||
self.recompute
|
||||
}
|
||||
|
||||
/// Enable the fused flash-attention SDPA core (Phase T14). Builder-style and
|
||||
/// opt-in; default off keeps the composed T10 path (so the default graph is
|
||||
/// unchanged). On, the SDPA runs through [`ops::flash_attention`] — same SDPA
|
||||
/// math, online softmax, no materialized `[bh,seq,seq]` scores.
|
||||
pub fn with_flash(mut self, use_flash: bool) -> Self {
|
||||
self.use_flash = use_flash;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn use_flash(&self) -> bool {
|
||||
self.use_flash
|
||||
}
|
||||
|
||||
/// All learnable parameters, in a stable order. The optimizer (a hand-written
|
||||
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
|
||||
/// `backward()`.
|
||||
@@ -189,13 +211,16 @@ impl TinyTransformer {
|
||||
// the block forward is re-run in backward to recover the grads. The
|
||||
// segment fn captures only `Copy` config (no borrow of `self`) and
|
||||
// receives the block's params via the slice, in `block_params` order.
|
||||
let (cfg, cdt) = (self.cfg, self.compute_dtype);
|
||||
let seg = move |x: &Var, p: &[Var]| block_forward(cfg, cdt, batch, seq, x, p);
|
||||
// `flash` is captured too → the recompute segment also runs flash.
|
||||
let (cfg, cdt, flash) = (self.cfg, self.compute_dtype, self.use_flash);
|
||||
let seg =
|
||||
move |x: &Var, p: &[Var]| block_forward(cfg, cdt, flash, batch, seq, x, p);
|
||||
xtrain_autodiff::checkpoint::checkpoint(seg, &h, &b.block_params())
|
||||
} else {
|
||||
block_forward(
|
||||
self.cfg,
|
||||
self.compute_dtype,
|
||||
self.use_flash,
|
||||
batch,
|
||||
seq,
|
||||
&h,
|
||||
@@ -279,7 +304,16 @@ fn norm_gamma(cdt: DType, gamma: &Var) -> Var {
|
||||
/// seq, input, params)` (no `&self`) so it can be the segment fn of
|
||||
/// [`xtrain_autodiff::checkpoint`] for activation recomputation (T13). `params` is
|
||||
/// the block's leaves in [`Block::block_params`] order.
|
||||
fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p: &[Var]) -> Var {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn block_forward(
|
||||
cfg: Config,
|
||||
cdt: DType,
|
||||
flash: bool,
|
||||
batch: usize,
|
||||
seq: usize,
|
||||
h: &Var,
|
||||
p: &[Var],
|
||||
) -> Var {
|
||||
let (attn_norm, wq, wk, wv) = (&p[0], &p[1], &p[2], &p[3]);
|
||||
let (q_norm, k_norm, wo) = (&p[4], &p[5], &p[6]);
|
||||
let (ffn_norm, w_gate, w_up, w_down) = (&p[7], &p[8], &p[9], &p[10]);
|
||||
@@ -287,7 +321,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(h, &norm_gamma(cdt, attn_norm), cfg.eps);
|
||||
let attn = attention(
|
||||
cfg, cdt, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||||
cfg, cdt, flash, batch, seq, &normed, wq, wk, wv, q_norm, k_norm, wo,
|
||||
);
|
||||
let h = ops::add(h, &attn);
|
||||
|
||||
@@ -308,6 +342,7 @@ fn block_forward(cfg: Config, cdt: DType, batch: usize, seq: usize, h: &Var, p:
|
||||
fn attention(
|
||||
cfg: Config,
|
||||
cdt: DType,
|
||||
flash: bool,
|
||||
batch: usize,
|
||||
seq: usize,
|
||||
x: &Var,
|
||||
@@ -351,9 +386,14 @@ fn attention(
|
||||
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
|
||||
let v = to_bh(linear(cdt, x, wv), None);
|
||||
|
||||
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
|
||||
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
|
||||
let out = ops::attention(&q, &k, &v, scale); // [B*nh, S, hd]
|
||||
// Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
|
||||
// single fused flash kernel (online softmax, no materialized [bh,S,S] scores);
|
||||
// otherwise the composed T10 path (2 batched GEMMs + 1 causal-softmax kernel).
|
||||
let out = if flash {
|
||||
ops::flash_attention(&q, &k, &v, scale) // [B*nh, S, hd]
|
||||
} else {
|
||||
ops::attention(&q, &k, &v, scale) // [B*nh, S, hd]
|
||||
};
|
||||
|
||||
// Back to [B*S, dim]: [B*nh,S,hd] → [B,nh,S,hd] → transpose(1,2) →
|
||||
// [B,S,nh,hd] → [B*S, dim].
|
||||
|
||||
Reference in New Issue
Block a user