test: flash finite-diff grad-check uses single-tile clean regime

Match the trusted composed grad-check dims (seq=5<FA_TILE); the multi-tile
online-softmax path is gated by flash_bwd_matches_composed_bwd (seq=40),
sharper than finite-diff on the near-zero grads a long softmax produces.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:16:20 +08:00
parent 01fb22d114
commit f38beb0346

View File

@@ -626,14 +626,18 @@ fn attention_batched_bwd() {
}
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure as attention_batched_bwd, but exercises ops::flash_attention.
// q,k,v: [bh, seq, hd]. Grad-check dq/dk/dv against finite-diff of L=sum(W∘out).
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised (not
// just a single KV tile).
// Same structure + dimensions as attention_batched_bwd (bh=2,seq=5,hd=6), but
// exercises ops::flash_attention. Grad-check dq/dk/dv against finite-diff of
// L=sum(W∘out). This is the SINGLE-tile regime (seq<FA_TILE=32), matching the
// trusted composed grad-check's clean near-zero behavior; the MULTI-tile online-
// softmax path (seq>FA_TILE) is validated against the already-grad-checked
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40) — sharper than
// finite-diff, which is unreliable on the near-zero grad elements a long softmax
// produces.
#[test]
fn flash_attention_batched_bwd() {
require_gpu();
let (bh, seq, hd) = (2, 40, 16);
let (bh, seq, hd) = (2, 5, 6);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q_h = fill(n, 241);