test: eps=2e-3 for flash dQ/dK finite-diff (cuts f32 rounding term)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -666,29 +666,25 @@ fn flash_attention_batched_bwd() {
|
||||
let (o, _) = qv.flash_attention(&kv, &vv, scale);
|
||||
weighted_sum(&o, &w)
|
||||
};
|
||||
// Attention dQ/dK carry softmax curvature; for the small grad magnitudes here
|
||||
// a larger eps (2e-3) cuts the f32 rounding term (∝|L|/eps) that dominates the
|
||||
// O(eps²) truncation on a ~4e-4 grad. (dV is exactly linear → cfg_linear.)
|
||||
let cfg_attn = GradCheckConfig {
|
||||
eps: 2e-3,
|
||||
rel_tol: 3e-2,
|
||||
atol: 1e-3,
|
||||
};
|
||||
let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
|
||||
let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
|
||||
report(
|
||||
"flash dQ",
|
||||
&grad_check(
|
||||
&q_h,
|
||||
&[bh, seq, hd],
|
||||
&lq,
|
||||
dq.as_slice::<f32>(),
|
||||
cfg_nonlinear(),
|
||||
),
|
||||
&grad_check(&q_h, &[bh, seq, hd], &lq, dq.as_slice::<f32>(), cfg_attn),
|
||||
);
|
||||
let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
|
||||
let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
|
||||
report(
|
||||
"flash dK",
|
||||
&grad_check(
|
||||
&k_h,
|
||||
&[bh, seq, hd],
|
||||
&lk,
|
||||
dk.as_slice::<f32>(),
|
||||
cfg_nonlinear(),
|
||||
),
|
||||
&grad_check(&k_h, &[bh, seq, hd], &lk, dk.as_slice::<f32>(), cfg_attn),
|
||||
);
|
||||
let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
|
||||
let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
|
||||
|
||||
Reference in New Issue
Block a user