test+bins: flash grad-check, flash==composed, PyTorch parity, --flash flag

autograd: flash_attention_batched_bwd (dQ/dK/dV finite-diff, seq>tile)
+ flash_matches_composed_fwd. model/tests/flash.rs: flash==composed
on-vs-off (logits/loss/every param grad), fp32 + bf16. parity_dump:
XTRAIN_PARITY_FLASH dumps the flash path for the same parity.py oracle
(PyTorch SDPA parity at B>1). train + train_ddp get the --flash flag.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:10:39 +08:00
parent 0e20821633
commit 5f3b81ac96
5 changed files with 280 additions and 1 deletions

View File

@@ -625,6 +625,104 @@ fn attention_batched_bwd() {
);
}
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure as attention_batched_bwd, but exercises ops::flash_attention.
// q,k,v: [bh, seq, hd]. Grad-check dq/dk/dv against finite-diff of L=sum(W∘out).
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised (not
// just a single KV tile).
#[test]
fn flash_attention_batched_bwd() {
require_gpu();
let (bh, seq, hd) = (2, 40, 16);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q_h = fill(n, 241);
let k_h = fill(n, 242);
let v_h = fill(n, 243);
let w = fill(n, 244);
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
let out = ops::flash_attention(&q, &k, &v, scale);
scalar_loss(&out, &w).backward();
let dq = q.grad().unwrap().to_device(Device::Cpu);
let dk = k.grad().unwrap().to_device(Device::Cpu);
let dv = v.grad().unwrap().to_device(Device::Cpu);
let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
let qv = cuda(qh, &[bh, seq, hd]);
let kv = cuda(kh, &[bh, seq, hd]);
let vv = cuda(vh, &[bh, seq, hd]);
let (o, _) = qv.flash_attention(&kv, &vv, scale);
weighted_sum(&o, &w)
};
let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
report(
"flash dQ",
&grad_check(
&q_h,
&[bh, seq, hd],
&lq,
dq.as_slice::<f32>(),
cfg_nonlinear(),
),
);
let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
report(
"flash dK",
&grad_check(
&k_h,
&[bh, seq, hd],
&lk,
dk.as_slice::<f32>(),
cfg_nonlinear(),
),
);
let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
report(
"flash dV",
&grad_check(
&v_h,
&[bh, seq, hd],
&lv,
dv.as_slice::<f32>(),
cfg_linear(),
),
);
}
// flash forward must equal the composed attention forward (same SDPA math).
#[test]
fn flash_matches_composed_fwd() {
require_gpu();
let (bh, seq, hd) = (2, 40, 16);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q = cuda(&fill(n, 341), &[bh, seq, hd]);
let k = cuda(&fill(n, 342), &[bh, seq, hd]);
let v = cuda(&fill(n, 343), &[bh, seq, hd]);
let (oc, _) = q.attention(&k, &v, scale);
let (of, _) = q.flash_attention(&k, &v, scale);
let oc = oc.to_device(Device::Cpu);
let of = of.to_device(Device::Cpu);
let max_rel = oc
.as_slice::<f32>()
.iter()
.zip(of.as_slice::<f32>())
.map(|(c, f)| (c - f).abs() / (c.abs() + 1e-6))
.fold(0.0f32, f32::max);
println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
assert!(
max_rel < 1e-4,
"flash fwd diverges from composed: {max_rel:.3e}"
);
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We

View File

@@ -89,6 +89,9 @@ fn main() {
// rank checkpoints its own forward/backward; exact grads, lower peak activation
// memory (lets dim1024 batch32 fit). Opt-in; default off.
let recompute = args.iter().any(|a| a == "--recompute");
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
let flash = args.iter().any(|a| a == "--flash");
let ckpt: Option<PathBuf> = args
.iter()
.position(|a| a == "--ckpt")
@@ -174,6 +177,9 @@ fn main() {
if recompute {
println!("activation recompute: ON (per-block gradient checkpointing)");
}
if flash {
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
}
let results = launch(
&devices,
&train_corpus,
@@ -187,6 +193,9 @@ fn main() {
if recompute {
m = m.with_recompute(true);
}
if flash {
m = m.with_flash(true);
}
m
},
);

View File

@@ -0,0 +1,157 @@
// T14 flash-attention correctness gate: the fused flash SDPA core must match the
// composed T10 path (cublasSgemmStridedBatched×2 + causal-softmax kernel) in
// forward logits, loss, AND every parameter gradient — flash is the SAME SDPA
// math (online softmax never materializes the [bh,S,S] scores), so it differs
// from composed only by reduction order (in-kernel fp32 FMA vs cuBLAS, and the
// dK/dV atomicAdd order in backward). This test makes that a closed on-GPU loop:
//
// build two identical models (same init), one with `--flash` on, one off, run
// the SAME batched loss + backward on both, and assert
// 1. the forward logits match within tolerance
// 2. the loss matches
// 3. EVERY parameter's grad matches within tolerance
//
// Parameterised over fp32 AND bf16 (T12). bf16 just adds the bf16 rounding band on
// top — flash's bf16 path upcasts Q/K/V to fp32 for the kernel exactly like the
// composed path's fp32 softmax, so the two are still the same softmax numerics.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
});
m.with_compute_dtype(dtype).with_flash(flash)
}
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
t.to_dtype(DType::F32)
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
}
fn run(dtype: DType, logit_tol: f32, grad_tol: f32) {
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised.
let mut cfg = Config::tiny();
cfg.vocab = 16;
cfg.n_layers = 4;
let batch = 3usize;
let seq = 40usize;
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| {
(0..seq)
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
.collect()
})
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| {
(0..seq)
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
.collect()
})
.collect();
let ids = batched_ids_tensor(&seqs, device);
let tgt = batched_ids_tensor(&tgts, device);
// --- flash OFF (composed reference) ---
let off = build(cfg, device, dtype, false);
let off_logits = host(&off.forward_batched(&ids, batch).value());
let off_loss = off.loss_batched(&ids, &tgt, batch);
let off_loss_val = host(&off_loss.value())[0];
off_loss.backward();
let off_grads: Vec<Vec<f32>> = off
.params()
.iter()
.map(|p| host(&p.grad().expect("off grad")))
.collect();
// --- flash ON ---
let on = build(cfg, device, dtype, true);
let on_logits = host(&on.forward_batched(&ids, batch).value());
let on_loss = on.loss_batched(&ids, &tgt, batch);
let on_loss_val = host(&on_loss.value())[0];
on_loss.backward();
let on_grads: Vec<Vec<f32>> = on
.params()
.iter()
.map(|p| host(&p.grad().expect("on grad")))
.collect();
// 1. Forward logits.
let logit_rel = off_logits
.iter()
.zip(&on_logits)
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
.fold(0.0f32, f32::max);
// 2. Loss.
let loss_rel = (off_loss_val - on_loss_val).abs() / off_loss_val.abs().max(1e-4);
println!(
"[{dtype:?}] flash on/off: loss {off_loss_val:.6}/{on_loss_val:.6} (rel {loss_rel:.2e}), \
logits max rel {logit_rel:.2e}"
);
assert!(
logit_rel < logit_tol,
"[{dtype:?}] logits diverged: {logit_rel:.2e}"
);
assert!(
loss_rel < logit_tol,
"[{dtype:?}] loss diverged: {loss_rel:.2e}"
);
// 3. Every parameter grad — the load-bearing gate.
let mut max_grad_rel = 0.0f32;
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
for (a, b) in off_g.iter().zip(on_g) {
let rel = (a - b).abs() / a.abs().max(1e-3);
max_grad_rel = max_grad_rel.max(rel);
}
}
println!("[{dtype:?}] flash on/off: grad max rel err = {max_grad_rel:.3e}");
assert!(
max_grad_rel < grad_tol,
"[{dtype:?}] flash grads diverged from composed: {max_grad_rel:.3e}"
);
}
#[test]
fn flash_matches_composed_fp32() {
// fp32: same SDPA math, differs only by reduction order (in-kernel fp32 FMA vs
// cuBLAS, dK/dV atomicAdd order). Tight but not bit-exact.
run(DType::F32, 1e-3, 2e-2);
}
#[test]
fn flash_matches_composed_bf16() {
// bf16 (T12 composition): bf16 rounding band on top of the fp32-softmax core.
run(DType::BF16, 2e-2, 5e-2);
}

View File

@@ -67,7 +67,7 @@ fn dump_for_parity() {
// Same deterministic init as the overfit test.
let mut seed = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| {
let mut model = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
@@ -76,6 +76,14 @@ fn dump_for_parity() {
fill(n, seed, 0.08)
}
});
// T14: with XTRAIN_PARITY_FLASH set, dump from the fused flash-attention path.
// flash is the SAME SDPA math, so the SAME parity.py PyTorch oracle is the
// reference for both paths — running this once per path checks flash against
// PyTorch at B>1 (forward logits + every parameter grad).
if std::env::var("XTRAIN_PARITY_FLASH").is_ok() {
model = model.with_flash(true);
println!("parity: FLASH attention path");
}
// config + ids
{

View File

@@ -116,6 +116,9 @@ fn main() {
// exact grads, lower peak activation memory (lets dim1024 batch32 fit). Opt-in;
// default off stores every activation (unchanged numerics).
let recompute = args.iter().any(|a| a == "--recompute");
// Fused flash-attention (Phase T14): single fused SDPA kernel, online softmax,
// no materialized [bh,S,S] scores. Opt-in; default off keeps the composed path.
let flash = args.iter().any(|a| a == "--flash");
let ckpt: PathBuf = PathBuf::from(
args.iter()
.position(|a| a == "--ckpt")
@@ -183,6 +186,10 @@ fn main() {
model = model.with_recompute(true);
println!("activation recompute: ON (per-block gradient checkpointing)");
}
if flash {
model = model.with_flash(true);
println!("flash-attention: ON (fused SDPA kernel, no materialized scores)");
}
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same