test: flash bwd vs composed bwd (sharper than finite-diff)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -723,6 +723,55 @@ fn flash_matches_composed_fwd() {
|
||||
);
|
||||
}
|
||||
|
||||
// flash backward must equal the (already grad-checked) composed backward. This is
|
||||
// a sharper test than finite-diff: both share the trusted composed forward as the
|
||||
// reference, so it isolates the flash bwd dQ/dK/dV math from finite-diff noise on
|
||||
// near-zero gradient elements.
|
||||
#[test]
|
||||
fn flash_bwd_matches_composed_bwd() {
|
||||
require_gpu();
|
||||
let (bh, seq, hd) = (2, 40, 16);
|
||||
let n = bh * seq * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let q_h = fill(n, 441);
|
||||
let k_h = fill(n, 442);
|
||||
let v_h = fill(n, 443);
|
||||
let w = fill(n, 444);
|
||||
|
||||
let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
|
||||
let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
|
||||
let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
|
||||
let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
|
||||
let out = if flash {
|
||||
ops::flash_attention(&q, &k, &v, scale)
|
||||
} else {
|
||||
ops::attention(&q, &k, &v, scale)
|
||||
};
|
||||
scalar_loss(&out, &w).backward();
|
||||
let g = |x: &Var| {
|
||||
x.grad()
|
||||
.unwrap()
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.to_vec()
|
||||
};
|
||||
(g(&q), g(&k), g(&v))
|
||||
};
|
||||
let (cq, ck, cv) = run(false);
|
||||
let (fq, fk, fv) = run(true);
|
||||
let maxrel = |a: &[f32], b: &[f32]| -> f32 {
|
||||
a.iter()
|
||||
.zip(b)
|
||||
.map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
|
||||
.fold(0.0f32, f32::max)
|
||||
};
|
||||
let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
|
||||
println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
|
||||
assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
|
||||
assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
|
||||
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||||
|
||||
Reference in New Issue
Block a user