test: flash bwd vs composed bwd (sharper than finite-diff)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:12:30 +08:00
parent 5f3b81ac96
commit 01fb22d114

View File

@@ -723,6 +723,55 @@ fn flash_matches_composed_fwd() {
);
}
// flash backward must equal the (already grad-checked) composed backward. This is
// a sharper test than finite-diff: both share the trusted composed forward as the
// reference, so it isolates the flash bwd dQ/dK/dV math from finite-diff noise on
// near-zero gradient elements.
#[test]
fn flash_bwd_matches_composed_bwd() {
require_gpu();
let (bh, seq, hd) = (2, 40, 16);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q_h = fill(n, 441);
let k_h = fill(n, 442);
let v_h = fill(n, 443);
let w = fill(n, 444);
let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
let out = if flash {
ops::flash_attention(&q, &k, &v, scale)
} else {
ops::attention(&q, &k, &v, scale)
};
scalar_loss(&out, &w).backward();
let g = |x: &Var| {
x.grad()
.unwrap()
.to_device(Device::Cpu)
.as_slice::<f32>()
.to_vec()
};
(g(&q), g(&k), g(&v))
};
let (cq, ck, cv) = run(false);
let (fq, fk, fv) = run(true);
let maxrel = |a: &[f32], b: &[f32]| -> f32 {
a.iter()
.zip(b)
.map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
.fold(0.0f32, f32::max)
};
let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We