autodiff: finite-diff gradient-check harness

New xtrain-autodiff crate with a reusable central finite-difference
gradient check: grad_check(x, shape, f, analytic_grad, cfg) compares an
analytic gradient against (f(x+ε)-f(x-ε))/2ε per element with a relative
tolerance. Host-only (no CUDA): the loss closure owns any GPU work, so
T4's per-op backward checks can reuse it directly. Includes host unit
tests (sum(x²) grad 2x passes; a wrong grad is rejected).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 15:26:42 +08:00
parent fbd07a578c
commit 9ca98efd98
5 changed files with 157 additions and 0 deletions

8
Cargo.lock generated
View File

@@ -88,6 +88,13 @@ version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
[[package]]
name = "xtrain-autodiff"
version = "0.1.0"
dependencies = [
"xtrain-tensor",
]
[[package]]
name = "xtrain-cuda"
version = "0.1.0"
@@ -101,6 +108,7 @@ version = "0.1.0"
dependencies = [
"half",
"smallvec",
"xtrain-autodiff",
"xtrain-cuda",
]

View File

@@ -3,6 +3,7 @@ resolver = "2"
members = [
"crates/xtrain-cuda",
"crates/xtrain-tensor",
"crates/xtrain-autodiff",
]
[workspace.package]

View File

@@ -0,0 +1,7 @@
[package]
name = "xtrain-autodiff"
version.workspace = true
edition.workspace = true
[dependencies]
xtrain-tensor = { path = "../xtrain-tensor" }

View File

@@ -0,0 +1,126 @@
//! Central finite-difference gradient check.
/// A scalar loss as a function of a flat parameter vector with a given shape.
/// `(data, shape) -> loss`. The closure owns how `data` becomes a `Tensor`
/// (e.g. `Tensor::from_slice(data, shape).to_device(...)`) and runs forward.
pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a;
#[derive(Debug, Clone, Copy)]
pub struct GradCheckConfig {
/// Perturbation magnitude per element.
pub eps: f32,
/// Max allowed relative error: `|num - ana| / (|num| + |ana| + atol)`.
pub rel_tol: f32,
/// Absolute floor in the denominator, so near-zero grads don't blow up the
/// relative error.
pub atol: f32,
}
impl Default for GradCheckConfig {
fn default() -> Self {
// eps=1e-3 balances truncation error (∝ eps²) against the ~1e-7 f32
// rounding noise on f(x±eps); rel_tol=2e-2 is the usual slack for an
// f32 GPU GEMM checked against an f32 central difference.
Self {
eps: 1e-3,
rel_tol: 2e-2,
atol: 1e-4,
}
}
}
#[derive(Debug, Clone)]
pub struct GradCheckResult {
pub passed: bool,
pub max_rel_err: f32,
/// Index of the worst element (largest relative error).
pub worst_index: usize,
pub worst_numeric: f32,
pub worst_analytic: f32,
}
/// Check `analytic_grad` against the central finite difference of `f` at `x`.
///
/// - `x`: flat parameter values (the point at which the gradient is taken).
/// - `shape`: logical shape passed through to `f`.
/// - `f`: scalar loss; called `2 * x.len()` times.
/// - `analytic_grad`: candidate gradient, same length as `x`.
pub fn grad_check(
x: &[f32],
shape: &[usize],
f: &ParamFn,
analytic_grad: &[f32],
cfg: GradCheckConfig,
) -> GradCheckResult {
assert_eq!(
x.len(),
analytic_grad.len(),
"param/grad length mismatch: {} vs {}",
x.len(),
analytic_grad.len()
);
let mut perturbed = x.to_vec();
let mut max_rel_err = 0.0f32;
let mut worst_index = 0;
let mut worst_numeric = 0.0f32;
let mut worst_analytic = 0.0f32;
for i in 0..x.len() {
let orig = x[i];
perturbed[i] = orig + cfg.eps;
let f_plus = f(&perturbed, shape);
perturbed[i] = orig - cfg.eps;
let f_minus = f(&perturbed, shape);
perturbed[i] = orig; // restore for the next element
let numeric = (f_plus - f_minus) / (2.0 * cfg.eps);
let analytic = analytic_grad[i];
let rel_err = (numeric - analytic).abs() / (numeric.abs() + analytic.abs() + cfg.atol);
if rel_err > max_rel_err {
max_rel_err = rel_err;
worst_index = i;
worst_numeric = numeric;
worst_analytic = analytic;
}
}
GradCheckResult {
passed: max_rel_err <= cfg.rel_tol,
max_rel_err,
worst_index,
worst_numeric,
worst_analytic,
}
}
#[cfg(test)]
mod tests {
use super::*;
// Host-only sanity check (no GPU): loss = sum(x²), grad = 2x.
#[test]
fn quadratic_grad_check() {
let x = vec![1.0f32, -2.0, 3.0, 0.5];
let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
let grad: Vec<f32> = x.iter().map(|t| 2.0 * t).collect();
let res = grad_check(&x, &[4], &f, &grad, GradCheckConfig::default());
assert!(res.passed, "max_rel_err = {}", res.max_rel_err);
}
// A deliberately wrong gradient must be rejected.
#[test]
fn wrong_grad_is_rejected() {
let x = vec![1.0f32, 2.0, 3.0];
let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
let bad_grad = vec![0.0f32, 0.0, 0.0];
let res = grad_check(&x, &[3], &f, &bad_grad, GradCheckConfig::default());
assert!(!res.passed);
}
}

View File

@@ -0,0 +1,15 @@
//! Reusable numerical-gradient checking for xtrain (Phase T3+).
//!
//! Given a scalar loss `f(x)` and an analytic gradient `g`, verify that `g`
//! matches the central finite-difference estimate
//! `(f(x+ε·eᵢ) - f(x-ε·eᵢ)) / 2ε` for every element `i`, within a relative
//! tolerance. Later phases (T4 autograd) reuse this per-op: wrap each op's
//! forward as the loss, run its backward to get `g`, and `grad_check`.
//!
//! The harness is host-only and dtype-agnostic at this layer: it works on a
//! flat `&[f32]` parameter vector + shape and a closure. The closure is free to
//! push the data to the GPU and run kernels — that detail stays out of here.
pub mod finite_diff;
pub use finite_diff::{GradCheckConfig, GradCheckResult, ParamFn, grad_check};