diff --git a/Cargo.lock b/Cargo.lock
index 7b18280..31f0ebb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -88,6 +88,13 @@ version = "1.0.24"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
 
+[[package]]
+name = "xtrain-autodiff"
+version = "0.1.0"
+dependencies = [
+ "xtrain-tensor",
+]
+
 [[package]]
 name = "xtrain-cuda"
 version = "0.1.0"
@@ -101,6 +108,7 @@ version = "0.1.0"
 dependencies = [
  "half",
  "smallvec",
+ "xtrain-autodiff",
  "xtrain-cuda",
 ]
 
diff --git a/Cargo.toml b/Cargo.toml
index b6f0129..63d1380 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,7 @@ resolver = "2"
 members = [
     "crates/xtrain-cuda",
     "crates/xtrain-tensor",
+    "crates/xtrain-autodiff",
 ]
 
 [workspace.package]
diff --git a/crates/xtrain-autodiff/Cargo.toml b/crates/xtrain-autodiff/Cargo.toml
new file mode 100644
index 0000000..5ecc5c4
--- /dev/null
+++ b/crates/xtrain-autodiff/Cargo.toml
@@ -0,0 +1,7 @@
+[package]
+name = "xtrain-autodiff"
+version.workspace = true
+edition.workspace = true
+
+[dependencies]
+xtrain-tensor = { path = "../xtrain-tensor" }
diff --git a/crates/xtrain-autodiff/src/finite_diff.rs b/crates/xtrain-autodiff/src/finite_diff.rs
new file mode 100644
index 0000000..8af7c9e
--- /dev/null
+++ b/crates/xtrain-autodiff/src/finite_diff.rs
@@ -0,0 +1,155 @@
+//! Central finite-difference gradient check.
+
+/// A scalar loss as a function of a flat parameter vector with a given shape.
+/// `(data, shape) -> loss`. The closure owns how `data` becomes a `Tensor`
+/// (e.g. `Tensor::from_slice(data, shape).to_device(...)`) and runs forward.
+pub type ParamFn<'a> = dyn Fn(&[f32], &[usize]) -> f32 + 'a;
+
+/// Tolerances for [`grad_check`].
+#[derive(Debug, Clone, Copy)]
+pub struct GradCheckConfig {
+    /// Perturbation magnitude per element.
+    pub eps: f32,
+    /// Max allowed relative error: `|num - ana| / (|num| + |ana| + atol)`.
+    pub rel_tol: f32,
+    /// Absolute floor in the denominator, so near-zero grads don't blow up the
+    /// relative error.
+    pub atol: f32,
+}
+
+impl Default for GradCheckConfig {
+    fn default() -> Self {
+        // eps=1e-3 balances truncation error (∝ eps²) against the ~1e-7 f32
+        // rounding noise on f(x±eps); rel_tol=2e-2 is the usual slack for an
+        // f32 GPU GEMM checked against an f32 central difference.
+        Self {
+            eps: 1e-3,
+            rel_tol: 2e-2,
+            atol: 1e-4,
+        }
+    }
+}
+
+/// Outcome of [`grad_check`]; the `worst_*` fields describe the worst element.
+#[derive(Debug, Clone)]
+pub struct GradCheckResult {
+    /// `true` iff every element's relative error is within `rel_tol`.
+    pub passed: bool,
+    /// Largest relative error seen; NaN if any element's error was NaN.
+    pub max_rel_err: f32,
+    /// Index of the worst element (largest relative error, or the first NaN).
+    pub worst_index: usize,
+    /// Finite-difference estimate at `worst_index`.
+    pub worst_numeric: f32,
+    /// Analytic gradient at `worst_index`.
+    pub worst_analytic: f32,
+}
+
+/// Check `analytic_grad` against the central finite difference of `f` at `x`.
+///
+/// - `x`: flat parameter values (the point at which the gradient is taken).
+/// - `shape`: logical shape passed through to `f`.
+/// - `f`: scalar loss; called `2 * x.len()` times.
+/// - `analytic_grad`: candidate gradient, same length as `x`.
+/// - `cfg`: perturbation size and tolerances.
+///
+/// A NaN relative error (NaN/inf from `f` or in `analytic_grad`, or `eps == 0`)
+/// fails the check. An empty `x` passes vacuously.
+///
+/// # Panics
+///
+/// Panics if `x` and `analytic_grad` differ in length.
+pub fn grad_check(
+    x: &[f32],
+    shape: &[usize],
+    f: &ParamFn,
+    analytic_grad: &[f32],
+    cfg: GradCheckConfig,
+) -> GradCheckResult {
+    assert_eq!(
+        x.len(),
+        analytic_grad.len(),
+        "param/grad length mismatch: {} vs {}",
+        x.len(),
+        analytic_grad.len()
+    );
+
+    let mut perturbed = x.to_vec();
+    let mut max_rel_err = 0.0f32;
+    let mut worst_index = 0;
+    let mut worst_numeric = 0.0f32;
+    let mut worst_analytic = 0.0f32;
+
+    for i in 0..x.len() {
+        let orig = x[i];
+
+        perturbed[i] = orig + cfg.eps;
+        let f_plus = f(&perturbed, shape);
+
+        perturbed[i] = orig - cfg.eps;
+        let f_minus = f(&perturbed, shape);
+
+        perturbed[i] = orig; // restore for the next element
+
+        let numeric = (f_plus - f_minus) / (2.0 * cfg.eps);
+        let analytic = analytic_grad[i];
+
+        let rel_err = (numeric - analytic).abs() / (numeric.abs() + analytic.abs() + cfg.atol);
+        // NaN compares false against everything, so `>` alone would let a NaN
+        // error slip through as a pass. Latch the first NaN as the worst element;
+        // `NaN <= rel_tol` is false, so the check then fails.
+        if rel_err > max_rel_err || (rel_err.is_nan() && !max_rel_err.is_nan()) {
+            max_rel_err = rel_err;
+            worst_index = i;
+            worst_numeric = numeric;
+            worst_analytic = analytic;
+        }
+    }
+
+    GradCheckResult {
+        passed: max_rel_err <= cfg.rel_tol,
+        max_rel_err,
+        worst_index,
+        worst_numeric,
+        worst_analytic,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Host-only sanity check (no GPU): loss = sum(x²), grad = 2x.
+    #[test]
+    fn quadratic_grad_check() {
+        let x = vec![1.0f32, -2.0, 3.0, 0.5];
+        let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
+        let grad: Vec<f32> = x.iter().map(|t| 2.0 * t).collect();
+
+        let res = grad_check(&x, &[4], &f, &grad, GradCheckConfig::default());
+        assert!(res.passed, "max_rel_err = {}", res.max_rel_err);
+    }
+
+    // A deliberately wrong gradient must be rejected.
+    #[test]
+    fn wrong_grad_is_rejected() {
+        let x = vec![1.0f32, 2.0, 3.0];
+        let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
+        let bad_grad = vec![0.0f32, 0.0, 0.0];
+
+        let res = grad_check(&x, &[3], &f, &bad_grad, GradCheckConfig::default());
+        assert!(!res.passed);
+    }
+
+    // A NaN in the analytic gradient must fail the check, not slip through.
+    #[test]
+    fn nan_grad_is_rejected() {
+        let x = vec![1.0f32, 2.0];
+        let f = |v: &[f32], _shape: &[usize]| v.iter().map(|t| t * t).sum::<f32>();
+        let nan_grad = vec![2.0f32, f32::NAN];
+
+        let res = grad_check(&x, &[2], &f, &nan_grad, GradCheckConfig::default());
+        assert!(!res.passed);
+        assert_eq!(res.worst_index, 1);
+    }
+}
diff --git a/crates/xtrain-autodiff/src/lib.rs b/crates/xtrain-autodiff/src/lib.rs
new file mode 100644
index 0000000..d8584ad
--- /dev/null
+++ b/crates/xtrain-autodiff/src/lib.rs
@@ -0,0 +1,15 @@
+//! Reusable numerical-gradient checking for xtrain (Phase T3+).
+//!
+//! Given a scalar loss `f(x)` and an analytic gradient `g`, verify that `g`
+//! matches the central finite-difference estimate
+//! `(f(x+ε·eᵢ) - f(x-ε·eᵢ)) / 2ε` for every element `i`, within a relative
+//! tolerance. Later phases (T4 autograd) reuse this per-op: wrap each op's
+//! forward as the loss, run its backward to get `g`, and `grad_check`.
+//!
+//! The harness is host-only and dtype-agnostic at this layer: it works on a
+//! flat `&[f32]` parameter vector + shape and a closure. The closure is free to
+//! push the data to the GPU and run kernels — that detail stays out of here.
+
+pub mod finite_diff;
+
+pub use finite_diff::{GradCheckConfig, GradCheckResult, ParamFn, grad_check};