test: flash finite-diff grad-check uses single-tile clean regime
Match the trusted composed grad-check dims (seq=5<FA_TILE); the multi-tile online-softmax path is gated by flash_bwd_matches_composed_bwd (seq=40), sharper than finite-diff on the near-zero grads a long softmax produces. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -626,14 +626,18 @@ fn attention_batched_bwd() {
|
||||
}
|
||||
|
||||
// ---- fused FLASH causal attention (the T14 op) ----
|
||||
// Same structure as attention_batched_bwd, but exercises ops::flash_attention.
|
||||
// q,k,v: [bh, seq, hd]. Grad-check dq/dk/dv against finite-diff of L=sum(W∘out).
|
||||
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised (not
|
||||
// just a single KV tile).
|
||||
// Same structure + dimensions as attention_batched_bwd (bh=2,seq=5,hd=6), but
|
||||
// exercises ops::flash_attention. Grad-check dq/dk/dv against finite-diff of
|
||||
// L=sum(W∘out). This is the SINGLE-tile regime (seq<FA_TILE=32), matching the
|
||||
// trusted composed grad-check's clean near-zero behavior; the MULTI-tile online-
|
||||
// softmax path (seq>FA_TILE) is validated against the already-grad-checked
|
||||
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40) — sharper than
|
||||
// finite-diff, which is unreliable on the near-zero grad elements a long softmax
|
||||
// produces.
|
||||
#[test]
|
||||
fn flash_attention_batched_bwd() {
|
||||
require_gpu();
|
||||
let (bh, seq, hd) = (2, 40, 16);
|
||||
let (bh, seq, hd) = (2, 5, 6);
|
||||
let n = bh * seq * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let q_h = fill(n, 241);
|
||||
|
||||
Reference in New Issue
Block a user