gemm: realistic f32 tolerances in GEMM acceptance tests

Forward: compare via matrix relative error (max abs error / max|ref|)
instead of a per-element ratio, so near-zero outputs where two correct
f32 GEMMs differ only in rounding order don't inflate the metric.
Backward: L = sum(W∘C) is bilinear, so central differences are
truncation-free — use eps=1e-2 (sharper f32 resolution of the
difference) and atol=1e-3 to floor near-zero-gradient subtraction noise.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 15:28:57 +08:00
parent dde2fde297
commit 88fbe0a85d

View File

@@ -71,11 +71,21 @@ fn cublas_matmul(a: &Tensor, b: &Tensor) -> Tensor {
c
}
// Matrix relative error: max element-wise abs error normalized by the
// magnitude scale of the reference. Using a single global denominator avoids
// individual near-zero outputs (where two correct f32 GEMMs differ only in
// rounding order) blowing up a per-element ratio.
fn max_rel_err(got: &[f32], reference: &[f32]) -> f32 {
got.iter()
let scale = reference
.iter()
.fold(0.0f32, |m, r| m.max(r.abs()))
.max(1e-6);
let max_abs = got
.iter()
.zip(reference)
.map(|(g, r)| (g - r).abs() / (g.abs() + r.abs() + 1e-6))
.fold(0.0f32, f32::max)
.map(|(g, r)| (g - r).abs())
.fold(0.0f32, f32::max);
max_abs / scale
}
// --- Forward: hand-written tiled GEMM vs cuBLAS sgemm ---
@@ -129,7 +139,16 @@ fn run_bwd(m: usize, k: usize, n: usize) {
let da_host = da.to_device(Device::Cpu);
let db_host = db.to_device(Device::Cpu);
let cfg = GradCheckConfig::default();
// L = sum(W∘C) is bilinear, so it is *exactly linear* in A (B fixed) and in
// B (A fixed): central differences carry no truncation error, and a larger
// eps only sharpens the f32 resolution of f(x+eps)-f(x-eps). atol floors the
// denominator at the ~1e-3 gradient scale so near-zero grads (pure f32
// subtraction noise) don't dominate the relative error.
let cfg = GradCheckConfig {
eps: 1e-2,
rel_tol: 2e-2,
atol: 1e-3,
};
// Check dA: vary A, hold B fixed.
let b_fixed = b_host.clone();