New xtrain-autodiff crate with a reusable central finite-difference gradient check: grad_check(x, shape, f, analytic_grad, cfg) compares an analytic gradient against (f(x+ε) − f(x−ε)) / (2ε) per element with a relative tolerance. Host-only (no CUDA): the loss closure owns any GPU work, so T4's per-op backward checks can reuse it directly. Includes host unit tests (sum(x²) with gradient 2x passes; a wrong gradient is rejected). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
127 lines
3.8 KiB
Rust
127 lines
3.8 KiB
Rust
//! Central finite-difference gradient check.
|
|
|
|
/// Scalar objective evaluated over a flat parameter buffer.
///
/// Invoked as `f(data, shape)` and returns the loss value. Turning `data` into a
/// `Tensor` (for instance `Tensor::from_slice(data, shape).to_device(...)`) and
/// running the forward pass is entirely up to the closure.
pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a;
|
|
|
|
/// Tuning knobs for the finite-difference gradient check.
#[derive(Debug, Clone, Copy)]
pub struct GradCheckConfig {
    /// Amount added to and subtracted from each element in turn. Must be
    /// positive: a zero step leaves no finite difference to measure.
    pub eps: f32,
    /// Largest relative error, `|num - ana| / (|num| + |ana| + atol)`, that still
    /// counts as a pass.
    pub rel_tol: f32,
    /// Constant added to the relative-error denominator so that elements whose
    /// gradients are both close to zero are not judged on rounding noise alone.
    pub atol: f32,
}

impl Default for GradCheckConfig {
    /// `eps = 1e-3`, `rel_tol = 2e-2`, `atol = 1e-4`.
    fn default() -> Self {
        // Step size: central-difference truncation error shrinks like eps², while f32
        // rounding in f(x ± eps) is magnified by 1/eps; 1e-3 sits between the two.
        // Tolerance: 2e-2 is the customary slack when an f32 GPU GEMM gradient is
        // compared against an f32 central difference.
        let (eps, rel_tol, atol) = (1e-3, 2e-2, 1e-4);
        GradCheckConfig { eps, rel_tol, atol }
    }
}
|
|
|
|
/// Outcome of a gradient check, including details of the single worst element.
#[derive(Debug, Clone)]
pub struct GradCheckResult {
    /// `true` when `max_rel_err` does not exceed the configured `rel_tol`.
    pub passed: bool,
    /// Largest per-element relative error that was observed.
    pub max_rel_err: f32,
    /// Position, in the flat parameter vector, of the element with the largest
    /// relative error.
    pub worst_index: usize,
    /// Finite-difference derivative recorded for that worst element.
    pub worst_numeric: f32,
    /// Analytic derivative recorded for that worst element.
    pub worst_analytic: f32,
}
|
|
|
|
/// Check `analytic_grad` against the central finite difference of `f` at `x`.
|
|
///
|
|
/// - `x`: flat parameter values (the point at which the gradient is taken).
|
|
/// - `shape`: logical shape passed through to `f`.
|
|
/// - `f`: scalar loss; called `2 * x.len()` times.
|
|
/// - `analytic_grad`: candidate gradient, same length as `x`.
|
|
pub fn grad_check(
|
|
x: &[f32],
|
|
shape: &[usize],
|
|
f: &ParamFn,
|
|
analytic_grad: &[f32],
|
|
cfg: GradCheckConfig,
|
|
) -> GradCheckResult {
|
|
assert_eq!(
|
|
x.len(),
|
|
analytic_grad.len(),
|
|
"param/grad length mismatch: {} vs {}",
|
|
x.len(),
|
|
analytic_grad.len()
|
|
);
|
|
|
|
let mut perturbed = x.to_vec();
|
|
let mut max_rel_err = 0.0f32;
|
|
let mut worst_index = 0;
|
|
let mut worst_numeric = 0.0f32;
|
|
let mut worst_analytic = 0.0f32;
|
|
|
|
for i in 0..x.len() {
|
|
let orig = x[i];
|
|
|
|
perturbed[i] = orig + cfg.eps;
|
|
let f_plus = f(&perturbed, shape);
|
|
|
|
perturbed[i] = orig - cfg.eps;
|
|
let f_minus = f(&perturbed, shape);
|
|
|
|
perturbed[i] = orig; // restore for the next element
|
|
|
|
let numeric = (f_plus - f_minus) / (2.0 * cfg.eps);
|
|
let analytic = analytic_grad[i];
|
|
|
|
let rel_err = (numeric - analytic).abs() / (numeric.abs() + analytic.abs() + cfg.atol);
|
|
if rel_err > max_rel_err {
|
|
max_rel_err = rel_err;
|
|
worst_index = i;
|
|
worst_numeric = numeric;
|
|
worst_analytic = analytic;
|
|
}
|
|
}
|
|
|
|
GradCheckResult {
|
|
passed: max_rel_err <= cfg.rel_tol,
|
|
max_rel_err,
|
|
worst_index,
|
|
worst_numeric,
|
|
worst_analytic,
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
// Host-only sanity check (no GPU): loss = sum(x²), grad = 2x.
|
|
#[test]
|
|
fn quadratic_grad_check() {
|
|
let x = vec![1.0f32, -2.0, 3.0, 0.5];
|
|
let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
|
|
let grad: Vec<f32> = x.iter().map(|t| 2.0 * t).collect();
|
|
|
|
let res = grad_check(&x, &[4], &f, &grad, GradCheckConfig::default());
|
|
assert!(res.passed, "max_rel_err = {}", res.max_rel_err);
|
|
}
|
|
|
|
// A deliberately wrong gradient must be rejected.
|
|
#[test]
|
|
fn wrong_grad_is_rejected() {
|
|
let x = vec![1.0f32, 2.0, 3.0];
|
|
let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
|
|
let bad_grad = vec![0.0f32, 0.0, 0.0];
|
|
|
|
let res = grad_check(&x, &[3], &f, &bad_grad, GradCheckConfig::default());
|
|
assert!(!res.passed);
|
|
}
|
|
}
|