test+bins: flash grad-check, flash==composed, PyTorch parity, --flash flag
autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile) + flash_matches_composed_fwd. model/tests/flash.rs: flash==composed on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump: XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle (PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -625,6 +625,104 @@ fn attention_batched_bwd() {
|
||||
);
|
||||
}
|
||||
|
||||
// ---- fused FLASH causal attention (the T14 op) ----
|
||||
// Same structure as attention_batched_bwd, but exercises ops::flash_attention.
|
||||
// q,k,v: [bh, seq, hd]. Grad-check dq/dk/dv against finite-diff of L=sum(W∘out).
|
||||
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised (not
|
||||
// just a single KV tile).
|
||||
#[test]
fn flash_attention_batched_bwd() {
    require_gpu();

    // seq (40) is larger than one FA_TILE (32), so more than a single KV tile is
    // visited and the online-softmax rescale between tiles is part of the check.
    let (bh, seq, hd) = (2, 40, 16);
    let numel = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();

    // Deterministic host inputs plus the fixed weights W of L = sum(W ∘ out).
    let q_host = fill(numel, 241);
    let k_host = fill(numel, 242);
    let v_host = fill(numel, 243);
    let weights = fill(numel, 244);

    // Analytic gradients: one autograd forward/backward through the flash op.
    let q = Var::leaf(cuda(&q_host, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_host, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_host, &[bh, seq, hd]));
    let out = ops::flash_attention(&q, &k, &v, scale);
    scalar_loss(&out, &weights).backward();

    let grad_q = q.grad().unwrap().to_device(Device::Cpu);
    let grad_k = k.grad().unwrap().to_device(Device::Cpu);
    let grad_v = v.grad().unwrap().to_device(Device::Cpu);

    // Plain (non-autograd) forward used as the finite-difference oracle.
    let loss_of = move |qs: &[f32], ks: &[f32], vs: &[f32]| -> f32 {
        let (attn, _) = cuda(qs, &[bh, seq, hd]).flash_attention(
            &cuda(ks, &[bh, seq, hd]),
            &cuda(vs, &[bh, seq, hd]),
            scale,
        );
        weighted_sum(&attn, &weights)
    };

    // dQ: perturb Q with K and V held fixed.
    let wrt_q = {
        let (k_fixed, v_fixed, f) = (k_host.clone(), v_host.clone(), loss_of.clone());
        move |x: &[f32], _shape: &[usize]| f(x, &k_fixed, &v_fixed)
    };
    report(
        "flash dQ",
        &grad_check(
            &q_host,
            &[bh, seq, hd],
            &wrt_q,
            grad_q.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );

    // dK: perturb K with Q and V held fixed.
    let wrt_k = {
        let (q_fixed, v_fixed, f) = (q_host.clone(), v_host.clone(), loss_of.clone());
        move |x: &[f32], _shape: &[usize]| f(&q_fixed, x, &v_fixed)
    };
    report(
        "flash dK",
        &grad_check(
            &k_host,
            &[bh, seq, hd],
            &wrt_k,
            grad_k.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );

    // dV: perturb V with Q and K held fixed. This check uses the linear config
    // where the Q and K checks use the nonlinear one.
    let wrt_v = {
        let (q_fixed, k_fixed) = (q_host.clone(), k_host.clone());
        move |x: &[f32], _shape: &[usize]| loss_of(&q_fixed, &k_fixed, x)
    };
    report(
        "flash dV",
        &grad_check(
            &v_host,
            &[bh, seq, hd],
            &wrt_v,
            grad_v.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
|
||||
|
||||
// flash forward must equal the composed attention forward (same SDPA math).
|
||||
#[test]
fn flash_matches_composed_fwd() {
    require_gpu();

    // Same problem size as the backward grad-check: seq=40 spans more than one
    // KV tile, so the tile-rescale path is part of the comparison.
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q = cuda(&fill(n, 341), &[bh, seq, hd]);
    let k = cuda(&fill(n, 342), &[bh, seq, hd]);
    let v = cuda(&fill(n, 343), &[bh, seq, hd]);

    // Composed reference vs fused flash forward on identical inputs.
    let (oc, _) = q.attention(&k, &v, scale);
    let (of, _) = q.flash_attention(&k, &v, scale);
    let oc = oc.to_device(Device::Cpu);
    let of = of.to_device(Device::Cpu);
    let composed = oc.as_slice::<f32>();
    let fused = of.as_slice::<f32>();

    // `zip` stops at the shorter side, so a size mismatch would otherwise leave
    // part of one output unchecked.
    assert_eq!(
        composed.len(),
        fused.len(),
        "flash and composed outputs differ in length"
    );

    // Max relative error. `f32::max` returns the non-NaN operand, so a plain
    // `fold(0.0, f32::max)` would hide NaN/inf outputs (inf - inf and inf / inf
    // are NaN) and let the assertion below pass. Propagate NaN explicitly so a
    // non-finite output fails the `<` comparison.
    let max_rel = composed
        .iter()
        .zip(fused)
        .map(|(c, f)| (c - f).abs() / (c.abs() + 1e-6))
        .fold(0.0f32, |acc, rel| {
            if acc.is_nan() || rel.is_nan() {
                f32::NAN
            } else {
                acc.max(rel)
            }
        });
    println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
    assert!(
        max_rel < 1e-4,
        "flash fwd diverges from composed: {max_rel:.3e}"
    );
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||||
|
||||
@@ -89,6 +89,9 @@ fn main() {
|
||||
// rank checkpoints its own forward/backward; exact grads, lower peak activation
|
||||
// memory (lets dim1024 batch32 fit). Opt-in; default off.
|
||||
let recompute = args.iter().any(|a| a == "--recompute");
|
||||
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
|
||||
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
|
||||
let flash = args.iter().any(|a| a == "--flash");
|
||||
let ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -174,6 +177,9 @@ fn main() {
|
||||
if recompute {
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
if flash {
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
@@ -187,6 +193,9 @@ fn main() {
|
||||
if recompute {
|
||||
m = m.with_recompute(true);
|
||||
}
|
||||
if flash {
|
||||
m = m.with_flash(true);
|
||||
}
|
||||
m
|
||||
},
|
||||
);
|
||||
|
||||
157
crates/xtrain-model/tests/flash.rs
Normal file
157
crates/xtrain-model/tests/flash.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
// T14 flash-attention correctness gate: the fused flash SDPA core must match the
|
||||
// composed T10 path (cublasSgemmStridedBatched×2 + causal-softmax kernel) in
|
||||
// forward logits, loss, AND every parameter gradient — flash is the SAME SDPA
|
||||
// math (online softmax never materializes the [bh,S,S] scores), so it differs
|
||||
// from composed only by reduction order (in-kernel fp32 FMA vs cuBLAS, and the
|
||||
// dK/dV atomicAdd order in backward). This test makes that a closed on-GPU loop:
|
||||
//
|
||||
// build two identical models (same init), one with `--flash` on, one off, run
|
||||
// the SAME batched loss + backward on both, and assert
|
||||
// 1. the forward logits match within tolerance
|
||||
// 2. the loss matches
|
||||
// 3. EVERY parameter's grad matches within tolerance
|
||||
//
|
||||
// Parameterised over fp32 AND bf16 (T12). bf16 just adds the bf16 rounding band on
|
||||
// top — flash's bf16 path upcasts Q/K/V to fp32 for the kernel exactly like the
|
||||
// composed path's fp32 softmax, so the two are still the same softmax numerics.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
/// Deterministic pseudo-random filler: returns `n` values in `[-scale, scale]`
/// produced by a 64-bit linear congruential generator seeded from `seed`.
///
/// The same `(n, seed, scale)` always yields the same vector, and a shorter
/// request is a prefix of a longer one with the same seed.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Scramble the caller's seed once so nearby seeds start from distant states.
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);

    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        // One LCG step; the top 31 bits are the usable output.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        // Re-centre around zero and stretch to the requested amplitude.
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
m.with_compute_dtype(dtype).with_flash(flash)
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_dtype(DType::F32)
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised.
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 16;
|
||||
cfg.n_layers = 4;
|
||||
let batch = 3usize;
|
||||
let seq = 40usize;
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
// --- flash OFF (composed reference) ---
|
||||
let off = build(cfg, device, dtype, false);
|
||||
let off_logits = host(&off.forward_batched(&ids, batch).value());
|
||||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||||
let off_loss_val = host(&off_loss.value())[0];
|
||||
off_loss.backward();
|
||||
let off_grads: Vec<Vec<f32>> = off
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("off grad")))
|
||||
.collect();
|
||||
|
||||
// --- flash ON ---
|
||||
let on = build(cfg, device, dtype, true);
|
||||
let on_logits = host(&on.forward_batched(&ids, batch).value());
|
||||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||||
let on_loss_val = host(&on_loss.value())[0];
|
||||
on_loss.backward();
|
||||
let on_grads: Vec<Vec<f32>> = on
|
||||
.params()
|
||||
.iter()
|
||||
.map(|p| host(&p.grad().expect("on grad")))
|
||||
.collect();
|
||||
|
||||
// 1. Forward logits.
|
||||
let logit_rel = off_logits
|
||||
.iter()
|
||||
.zip(&on_logits)
|
||||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||||
.fold(0.0f32, f32::max);
|
||||
// 2. Loss.
|
||||
let loss_rel = (off_loss_val - on_loss_val).abs() / off_loss_val.abs().max(1e-4);
|
||||
println!(
|
||||
"[{dtype:?}] flash on/off: loss {off_loss_val:.6}/{on_loss_val:.6} (rel {loss_rel:.2e}), \
|
||||
logits max rel {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
logit_rel < logit_tol,
|
||||
"[{dtype:?}] logits diverged: {logit_rel:.2e}"
|
||||
);
|
||||
assert!(
|
||||
loss_rel < logit_tol,
|
||||
"[{dtype:?}] loss diverged: {loss_rel:.2e}"
|
||||
);
|
||||
|
||||
// 3. Every parameter grad — the load-bearing gate.
|
||||
let mut max_grad_rel = 0.0f32;
|
||||
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
|
||||
for (a, b) in off_g.iter().zip(on_g) {
|
||||
let rel = (a - b).abs() / a.abs().max(1e-3);
|
||||
max_grad_rel = max_grad_rel.max(rel);
|
||||
}
|
||||
}
|
||||
println!("[{dtype:?}] flash on/off: grad max rel err = {max_grad_rel:.3e}");
|
||||
assert!(
|
||||
max_grad_rel < grad_tol,
|
||||
"[{dtype:?}] flash grads diverged from composed: {max_grad_rel:.3e}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn flash_matches_composed_fp32() {
    // In fp32 the two paths compute the same SDPA math and differ only in
    // reduction order (in-kernel fp32 FMA vs cuBLAS, dK/dV atomicAdd order), so
    // the bounds are tight but not bit-exact.
    let logit_tol = 1e-3;
    let grad_tol = 2e-2;
    run(DType::F32, logit_tol, grad_tol);
}
|
||||
|
||||
#[test]
fn flash_matches_composed_bf16() {
    // bf16 (T12 composition): the bounds are widened to absorb the bf16 rounding
    // band that sits on top of the fp32-softmax core.
    let logit_tol = 2e-2;
    let grad_tol = 5e-2;
    run(DType::BF16, logit_tol, grad_tol);
}
|
||||
@@ -67,7 +67,7 @@ fn dump_for_parity() {
|
||||
|
||||
// Same deterministic init as the overfit test.
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
let mut model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
@@ -76,6 +76,14 @@ fn dump_for_parity() {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
// T14: with XTRAIN_PARITY_FLASH set, dump from the fused flash-attention path.
|
||||
// flash is the SAME SDPA math, so the SAME parity.py PyTorch oracle is the
|
||||
// reference for both paths — running this once per path checks flash against
|
||||
// PyTorch at B>1 (forward logits + every parameter grad).
|
||||
if std::env::var("XTRAIN_PARITY_FLASH").is_ok() {
|
||||
model = model.with_flash(true);
|
||||
println!("parity: FLASH attention path");
|
||||
}
|
||||
|
||||
// config + ids
|
||||
{
|
||||
|
||||
@@ -116,6 +116,9 @@ fn main() {
|
||||
// exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in;
|
||||
// default off stores every activation (unchanged numerics).
|
||||
let recompute = args.iter().any(|a| a == "--recompute");
|
||||
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
|
||||
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
|
||||
let flash = args.iter().any(|a| a == "--flash");
|
||||
let ckpt: PathBuf = PathBuf::from(
|
||||
args.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -183,6 +186,10 @@ fn main() {
|
||||
model = model.with_recompute(true);
|
||||
println!("activation recompute: ON (per-block gradient checkpointing)");
|
||||
}
|
||||
if flash {
|
||||
model = model.with_flash(true);
|
||||
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
|
||||
Reference in New Issue
Block a user