autodiff: finite-diff gradient-check harness
New xtrain-autodiff crate with a reusable central finite-difference gradient check: grad_check(x, shape, f, analytic_grad, cfg) compares an analytic gradient against (f(x+ε)-f(x-ε))/2ε per element with a relative tolerance. Host-only (no CUDA): the loss closure owns any GPU work, so T4's per-op backward checks can reuse it directly. Includes host unit tests (sum(x²) grad 2x passes; a wrong grad is rejected). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
7
crates/xtrain-autodiff/Cargo.toml
Normal file
7
crates/xtrain-autodiff/Cargo.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[package]
|
||||
name = "xtrain-autodiff"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xtrain-tensor = { path = "../xtrain-tensor" }
|
||||
126
crates/xtrain-autodiff/src/finite_diff.rs
Normal file
126
crates/xtrain-autodiff/src/finite_diff.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
//! Central finite-difference gradient check.
|
||||
|
||||
/// A scalar loss as a function of a flat parameter vector with a given shape.
///
/// Signature: `(data, shape) -> loss`.
///
/// The harness only ever hands the closure a flat `&[f32]` slice and the
/// logical shape; the closure owns how `data` becomes a `Tensor`
/// (e.g. `Tensor::from_slice(data, shape).to_device(...)`) and how the forward
/// pass is run. The `'a` bound lets the closure borrow from its environment.
pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a;
|
||||
|
||||
/// Tunables for a finite-difference gradient check.
#[derive(Debug, Clone, Copy)]
pub struct GradCheckConfig {
    /// Perturbation magnitude per element: each element is evaluated at
    /// `x[i] + eps` and `x[i] - eps`.
    pub eps: f32,
    /// Max allowed relative error: `|num - ana| / (|num| + |ana| + atol)`.
    pub rel_tol: f32,
    /// Absolute floor in the denominator, so near-zero grads don't blow up the
    /// relative error.
    pub atol: f32,
}

impl Default for GradCheckConfig {
    /// Returns `eps = 1e-3`, `rel_tol = 2e-2`, `atol = 1e-4`.
    fn default() -> Self {
        // eps=1e-3 balances truncation error (∝ eps²) against the ~1e-7 f32
        // rounding noise on f(x±eps); rel_tol=2e-2 is the usual slack for an
        // f32 GPU GEMM checked against an f32 central difference.
        Self {
            eps: 1e-3,
            rel_tol: 2e-2,
            atol: 1e-4,
        }
    }
}
|
||||
|
||||
/// Outcome of a gradient check, summarised by its worst element.
///
/// When no element produces a strictly positive relative error, the `worst_*`
/// fields keep their initial values (`worst_index == 0`, both gradients `0.0`).
#[derive(Debug, Clone)]
pub struct GradCheckResult {
    /// `true` when `max_rel_err` does not exceed the configured `rel_tol`.
    pub passed: bool,
    /// Largest relative error recorded over all elements.
    pub max_rel_err: f32,
    /// Index of the worst element (largest relative error).
    pub worst_index: usize,
    /// Finite-difference gradient estimate at `worst_index`.
    pub worst_numeric: f32,
    /// Caller-supplied analytic gradient at `worst_index`.
    pub worst_analytic: f32,
}
|
||||
|
||||
/// Check `analytic_grad` against the central finite difference of `f` at `x`.
|
||||
///
|
||||
/// - `x`: flat parameter values (the point at which the gradient is taken).
|
||||
/// - `shape`: logical shape passed through to `f`.
|
||||
/// - `f`: scalar loss; called `2 * x.len()` times.
|
||||
/// - `analytic_grad`: candidate gradient, same length as `x`.
|
||||
pub fn grad_check(
|
||||
x: &[f32],
|
||||
shape: &[usize],
|
||||
f: &ParamFn,
|
||||
analytic_grad: &[f32],
|
||||
cfg: GradCheckConfig,
|
||||
) -> GradCheckResult {
|
||||
assert_eq!(
|
||||
x.len(),
|
||||
analytic_grad.len(),
|
||||
"param/grad length mismatch: {} vs {}",
|
||||
x.len(),
|
||||
analytic_grad.len()
|
||||
);
|
||||
|
||||
let mut perturbed = x.to_vec();
|
||||
let mut max_rel_err = 0.0f32;
|
||||
let mut worst_index = 0;
|
||||
let mut worst_numeric = 0.0f32;
|
||||
let mut worst_analytic = 0.0f32;
|
||||
|
||||
for i in 0..x.len() {
|
||||
let orig = x[i];
|
||||
|
||||
perturbed[i] = orig + cfg.eps;
|
||||
let f_plus = f(&perturbed, shape);
|
||||
|
||||
perturbed[i] = orig - cfg.eps;
|
||||
let f_minus = f(&perturbed, shape);
|
||||
|
||||
perturbed[i] = orig; // restore for the next element
|
||||
|
||||
let numeric = (f_plus - f_minus) / (2.0 * cfg.eps);
|
||||
let analytic = analytic_grad[i];
|
||||
|
||||
let rel_err = (numeric - analytic).abs() / (numeric.abs() + analytic.abs() + cfg.atol);
|
||||
if rel_err > max_rel_err {
|
||||
max_rel_err = rel_err;
|
||||
worst_index = i;
|
||||
worst_numeric = numeric;
|
||||
worst_analytic = analytic;
|
||||
}
|
||||
}
|
||||
|
||||
GradCheckResult {
|
||||
passed: max_rel_err <= cfg.rel_tol,
|
||||
max_rel_err,
|
||||
worst_index,
|
||||
worst_numeric,
|
||||
worst_analytic,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Host-only loss shared by the tests (no GPU): `sum(x²)`, whose exact
    /// gradient is `2x`. The shape is ignored.
    fn sum_of_squares(v: &[f32], _shape: &[usize]) -> f32 {
        v.iter().map(|t| t * t).sum::<f32>()
    }

    // Host-only sanity check (no GPU): loss = sum(x²), grad = 2x.
    #[test]
    fn quadratic_grad_check() {
        let x = vec![1.0f32, -2.0, 3.0, 0.5];
        let grad: Vec<f32> = x.iter().map(|t| 2.0 * t).collect();

        let res = grad_check(&x, &[4], &sum_of_squares, &grad, GradCheckConfig::default());
        assert!(res.passed, "max_rel_err = {}", res.max_rel_err);
    }

    // A deliberately wrong gradient must be rejected.
    #[test]
    fn wrong_grad_is_rejected() {
        let x = vec![1.0f32, 2.0, 3.0];
        let bad_grad = vec![0.0f32, 0.0, 0.0];

        let res = grad_check(&x, &[3], &sum_of_squares, &bad_grad, GradCheckConfig::default());
        assert!(!res.passed);
    }
}
|
||||
15
crates/xtrain-autodiff/src/lib.rs
Normal file
15
crates/xtrain-autodiff/src/lib.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
//! Reusable numerical-gradient checking for xtrain (Phase T3+).
|
||||
//!
|
||||
//! Given a scalar loss `f(x)` and an analytic gradient `g`, verify that `g`
|
||||
//! matches the central finite-difference estimate
|
||||
//! `(f(x+ε·eᵢ) - f(x-ε·eᵢ)) / 2ε` for every element `i`, within a relative
|
||||
//! tolerance. Later phases (T4 autograd) reuse this per-op: wrap each op's
|
||||
//! forward as the loss, run its backward to get `g`, and `grad_check`.
|
||||
//!
|
||||
//! The harness is host-only and dtype-agnostic at this layer: it works on a
|
||||
//! flat `&[f32]` parameter vector + shape and a closure. The closure is free to
|
||||
//! push the data to the GPU and run kernels — that detail stays out of here.
|
||||
|
||||
pub mod finite_diff;
|
||||
|
||||
pub use finite_diff::{GradCheckConfig, GradCheckResult, ParamFn, grad_check};
|
||||
Reference in New Issue
Block a user