test: scale Q/K in flash grad-check for well-conditioned grads
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -640,8 +640,12 @@ fn flash_attention_batched_bwd() {
|
||||
let (bh, seq, hd) = (2, 5, 6);
|
||||
let n = bh * seq * hd;
|
||||
let scale = 1.0 / (hd as f32).sqrt();
|
||||
let q_h = fill(n, 241);
|
||||
let k_h = fill(n, 242);
|
||||
// Scale Q/K up so the softmax is non-uniform (sharper attention) → the dQ/dK
|
||||
// gradients are well-conditioned, not the near-zero saddle values a uniform
|
||||
// softmax produces (those make central finite-diff give spurious 0.0 / sign
|
||||
// flips that aren't backward bugs — cf. flash_bwd_matches_composed_bwd).
|
||||
let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
|
||||
let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
|
||||
let v_h = fill(n, 243);
|
||||
let w = fill(n, 244);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user