Files
xtrain/crates/xtrain-autodiff/src/finite_diff.rs
Gahow Wang 9ca98efd98 autodiff: finite-diff gradient-check harness
New xtrain-autodiff crate with a reusable central finite-difference
gradient check: grad_check(x, shape, f, analytic_grad, cfg) compares an
analytic gradient against (f(x+ε)-f(x-ε))/2ε per element with a relative
tolerance. Host-only (no CUDA): the loss closure owns any GPU work, so
T4's per-op backward checks can reuse it directly. Includes host unit
tests (sum(x²) grad 2x passes; a wrong grad is rejected).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 15:26:42 +08:00

127 lines
3.8 KiB
Rust

//! Central finite-difference gradient check.
/// Scalar objective evaluated on a flat parameter buffer.
///
/// Invoked as `f(data, shape)` and returns the loss. Turning `data` into a
/// tensor (for instance `Tensor::from_slice(data, shape).to_device(...)`) and
/// running the forward pass are entirely the closure's responsibility.
pub type ParamFn<'a> = dyn 'a + Fn(&[f32], &[usize]) -> f32;
/// Step size and tolerances for a central finite-difference gradient check.
///
/// The per-element relative error is
/// `|numeric - analytic| / (|numeric| + |analytic| + atol)`; a check passes
/// when the largest such error is at most `rel_tol`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GradCheckConfig {
    /// Perturbation magnitude per element: each element is evaluated at
    /// `x - eps` and `x + eps`.
    pub eps: f32,
    /// Max allowed relative error: `|num - ana| / (|num| + |ana| + atol)`.
    pub rel_tol: f32,
    /// Absolute floor in the denominator, so near-zero grads don't blow up the
    /// relative error.
    pub atol: f32,
}

impl Default for GradCheckConfig {
    fn default() -> Self {
        // eps=1e-3 balances the central difference's O(eps²) truncation error
        // against f32 rounding noise in f(x±eps) (~1e-7 relative, amplified by
        // 1/(2·eps) in the quotient). rel_tol=2e-2 is the usual slack for an
        // f32 GPU GEMM checked against an f32 central difference; atol=1e-4
        // keeps the relative error bounded when both gradients are ~0.
        Self {
            eps: 1e-3,
            rel_tol: 2e-2,
            atol: 1e-4,
        }
    }
}
/// Outcome of a gradient check: the pass/fail verdict plus the single worst
/// element, to make a failing gradient easy to diagnose.
#[derive(Debug, Clone, PartialEq)]
#[must_use = "a gradient check has no effect unless `passed` is inspected"]
pub struct GradCheckResult {
    /// Whether `max_rel_err` is within the configured `rel_tol`.
    pub passed: bool,
    /// Largest per-element relative error observed.
    pub max_rel_err: f32,
    /// Index of the worst element (largest relative error).
    pub worst_index: usize,
    /// Finite-difference derivative recorded for the worst element.
    pub worst_numeric: f32,
    /// Caller-supplied analytic gradient recorded for the worst element.
    pub worst_analytic: f32,
}
/// Check `analytic_grad` against the central finite difference of `f` at `x`.
///
/// Each element `x[i]` is moved to `x[i] + eps` and `x[i] - eps` in turn (all
/// other elements held at their original values), and the numeric derivative
/// `(f(x+) - f(x-)) / (x+ - x-)` is compared with `analytic_grad[i]` using
/// `|num - ana| / (|num| + |ana| + atol)`. The check passes when the largest
/// such error is at most `cfg.rel_tol`.
///
/// - `x`: flat parameter values (the point at which the gradient is taken).
/// - `shape`: logical shape passed through to `f` unchanged.
/// - `f`: scalar loss; called `2 * x.len()` times.
/// - `analytic_grad`: candidate gradient, same length as `x`.
///
/// A NaN or infinite discrepancy is scored as an infinite relative error, so it
/// fails the check instead of being skipped. This covers a NaN/∞ entry in
/// `analytic_grad`, a loss that returns NaN, and an `eps` too small to change
/// `x[i]` in f32. For an empty `x` the result has `max_rel_err == 0.0` and
/// `worst_index == 0`.
///
/// # Panics
///
/// Panics if `x.len() != analytic_grad.len()`, or if `cfg.eps` is zero or not
/// finite.
pub fn grad_check(
    x: &[f32],
    shape: &[usize],
    f: &ParamFn,
    analytic_grad: &[f32],
    cfg: GradCheckConfig,
) -> GradCheckResult {
    assert_eq!(
        x.len(),
        analytic_grad.len(),
        "param/grad length mismatch: {} vs {}",
        x.len(),
        analytic_grad.len()
    );
    // A zero step makes every numeric derivative 0/0; reject it up front.
    assert!(
        cfg.eps.is_finite() && cfg.eps != 0.0,
        "grad_check: eps must be finite and nonzero, got {}",
        cfg.eps
    );
    let mut perturbed = x.to_vec();
    let mut max_rel_err = 0.0f32;
    let mut worst_index = 0;
    let mut worst_numeric = 0.0f32;
    let mut worst_analytic = 0.0f32;
    for (i, (&orig, &analytic)) in x.iter().zip(analytic_grad).enumerate() {
        let x_plus = orig + cfg.eps;
        let x_minus = orig - cfg.eps;
        perturbed[i] = x_plus;
        let f_plus = f(&perturbed, shape);
        perturbed[i] = x_minus;
        let f_minus = f(&perturbed, shape);
        perturbed[i] = orig; // restore for the next element
        // Divide by the step actually representable in f32 rather than the
        // nominal `2 * eps`: for large |orig| the rounded `orig ± eps` can differ
        // from the requested perturbation by several percent.
        let numeric = (f_plus - f_minus) / (x_plus - x_minus);
        let rel_err = relative_error(numeric, analytic, cfg.atol);
        // `i == 0` seeds the report with a real element, so the worst_* fields
        // are meaningful even when every error is exactly 0.
        if i == 0 || rel_err > max_rel_err {
            max_rel_err = rel_err;
            worst_index = i;
            worst_numeric = numeric;
            worst_analytic = analytic;
        }
    }
    GradCheckResult {
        passed: max_rel_err <= cfg.rel_tol,
        max_rel_err,
        worst_index,
        worst_numeric,
        worst_analytic,
    }
}

/// Relative error `|num - ana| / (|num| + |ana| + atol)`, made total: exact
/// agreement is 0 (even when the denominator would be 0), and any NaN result
/// (a NaN input, or ∞/∞) becomes `f32::INFINITY` so it can never be missed by
/// a `>` comparison.
fn relative_error(numeric: f32, analytic: f32, atol: f32) -> f32 {
    let diff = (numeric - analytic).abs();
    if diff == 0.0 {
        return 0.0;
    }
    let rel = diff / (numeric.abs() + analytic.abs() + atol);
    if rel.is_nan() {
        f32::INFINITY
    } else {
        rel
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// loss = sum(x²), whose exact gradient is 2x.
    fn sum_of_squares(v: &[f32], _shape: &[usize]) -> f32 {
        v.iter().map(|t| t * t).sum()
    }

    // Host-only sanity check (no GPU): loss = sum(x²), grad = 2x.
    #[test]
    fn quadratic_grad_check() {
        let x = vec![1.0f32, -2.0, 3.0, 0.5];
        let grad: Vec<f32> = x.iter().map(|t| 2.0 * t).collect();
        let res = grad_check(&x, &[4], &sum_of_squares, &grad, GradCheckConfig::default());
        assert!(res.passed, "max_rel_err = {}", res.max_rel_err);
    }

    // A deliberately wrong gradient must be rejected, and the report must point
    // at the element with the largest discrepancy.
    #[test]
    fn wrong_grad_is_rejected() {
        let x = vec![1.0f32, 2.0, 3.0];
        let bad_grad = vec![0.0f32, 0.0, 0.0];
        let res = grad_check(&x, &[3], &sum_of_squares, &bad_grad, GradCheckConfig::default());
        assert!(!res.passed);
        // With analytic = 0, rel_err = |2x| / (|2x| + atol), which grows with
        // |x|, so the last element is the worst.
        assert_eq!(res.worst_index, 2);
        assert_eq!(res.worst_analytic, 0.0);
        assert!((res.worst_numeric - 6.0).abs() < 1e-2, "numeric = {}", res.worst_numeric);
    }

    // Mismatched parameter/gradient lengths are a caller bug and must panic.
    #[test]
    #[should_panic(expected = "param/grad length mismatch")]
    fn length_mismatch_panics() {
        let _ = grad_check(&[1.0, 2.0], &[2], &sum_of_squares, &[2.0], GradCheckConfig::default());
    }
}