gemm: realistic f32 tolerances in GEMM acceptance tests
Forward: compare via matrix relative error (max abs error / max|ref|) instead of a per-element ratio, so that near-zero outputs, where two correct f32 GEMMs differ only in rounding order, don't inflate the metric.

Backward: L = sum(W∘C) is bilinear, so central differences are truncation-free — use eps=1e-2 (sharper f32 resolution of the difference) and atol=1e-3 to floor near-zero-gradient subtraction noise.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -71,11 +71,21 @@ fn cublas_matmul(a: &Tensor, b: &Tensor) -> Tensor {
     c
 }
 
+// Matrix relative error: max element-wise abs error normalized by the
+// magnitude scale of the reference. Using a single global denominator avoids
+// individual near-zero outputs (where two correct f32 GEMMs differ only in
+// rounding order) blowing up a per-element ratio.
 fn max_rel_err(got: &[f32], reference: &[f32]) -> f32 {
-    got.iter()
+    let scale = reference
+        .iter()
+        .fold(0.0f32, |m, r| m.max(r.abs()))
+        .max(1e-6);
+    let max_abs = got
+        .iter()
         .zip(reference)
-        .map(|(g, r)| (g - r).abs() / (g.abs() + r.abs() + 1e-6))
-        .fold(0.0f32, f32::max)
+        .map(|(g, r)| (g - r).abs())
+        .fold(0.0f32, f32::max);
+    max_abs / scale
 }
 
 // --- Forward: hand-written tiled GEMM vs cuBLAS sgemm ---
@@ -129,7 +139,16 @@ fn run_bwd(m: usize, k: usize, n: usize) {
     let da_host = da.to_device(Device::Cpu);
     let db_host = db.to_device(Device::Cpu);
 
-    let cfg = GradCheckConfig::default();
+    // L = sum(W∘C) is bilinear, so it is *exactly linear* in A (B fixed) and in
+    // B (A fixed): central differences carry no truncation error, and a larger
+    // eps only sharpens the f32 resolution of f(x+eps)-f(x-eps). atol floors the
+    // denominator at the ~1e-3 gradient scale so near-zero grads (pure f32
+    // subtraction noise) don't dominate the relative error.
+    let cfg = GradCheckConfig {
+        eps: 1e-2,
+        rel_tol: 2e-2,
+        atol: 1e-3,
+    };
 
     // Check dA: vary A, hold B fixed.
     let b_fixed = b_host.clone();
Reference in New Issue
Block a user