test: scale Q/K in flash grad-check for well-conditioned grads

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:17:04 +08:00
parent f38beb0346
commit 80602099dc

View File

@@ -640,8 +640,12 @@ fn flash_attention_batched_bwd() {
let (bh, seq, hd) = (2, 5, 6);
let n = bh * seq * hd;
let scale = 1.0 / (hd as f32).sqrt();
let q_h = fill(n, 241);
let k_h = fill(n, 242);
// Scale Q/K up so the softmax is non-uniform (sharper attention). This keeps the
// dQ/dK gradients well-conditioned, rather than the near-zero saddle values that a
// uniform softmax produces. Those near-zero values make the central finite-difference
// check report spurious 0.0 results or sign flips that are not backward bugs
// (cf. flash_bwd_matches_composed_bwd).
let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
let v_h = fill(n, 243);
let w = fill(n, 244);