ops: differentiable autograd nodes + per-op grad-check tests
ops.rs wraps each Tensor op as a Var node with its backward closure (forward caches captured by move). swiglu = mul(silu(gate), up); attention is composed (matmul+scale+softmax+matmul), no fused kernel. tests/autograd.rs grad-checks every op via the L=sum(W∘out) template, plus a fan-out grad-accumulation test (dL/dx=4x) and an end-to-end composed-attention grad-check (dQ/dK/dV). Adds xtrain-cuda dev-dep for device selection in tests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -92,6 +92,7 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
|
||||
name = "xtrain-autodiff"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"xtrain-cuda",
|
||||
"xtrain-tensor",
|
||||
]
|
||||
|
||||
|
||||
@@ -5,3 +5,7 @@ edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
xtrain-tensor = { path = "../xtrain-tensor" }
|
||||
|
||||
[dev-dependencies]
|
||||
# Acceptance tests need device selection (set_device) to drive the GPU.
|
||||
xtrain-cuda = { path = "../xtrain-cuda" }
|
||||
|
||||
172
crates/xtrain-autodiff/src/ops.rs
Normal file
172
crates/xtrain-autodiff/src/ops.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
//! Differentiable ops as autograd nodes (Phase T4).
|
||||
//!
|
||||
//! Each function runs the forward [`Tensor`] kernel, then builds a [`Var`] whose
|
||||
//! backward closure computes the analytic gradient (see
|
||||
//! `docs/03-autograd-engine.md` for the math) and pushes it to each parent via
|
||||
//! [`Var::push_grad`] (which SUMs — correct under fan-out). Forward outputs that
|
||||
//! the backward needs (softmax `y`, rms `inv_rms`, cross-entropy `probs`) are
|
||||
//! cached by moving them into the closure.
|
||||
//!
|
||||
//! Attention is NOT a node here: it is composed from `matmul` + `scale` +
|
||||
//! `softmax` in user code, and its backward falls out of theirs.
|
||||
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use crate::tape::Var;
|
||||
use xtrain_tensor::Tensor;
|
||||
|
||||
/// `C = A @ B` (2D). Backward: `dA = dC @ Bᵀ`, `dB = Aᵀ @ dC`.
|
||||
pub fn matmul(a: &Var, b: &Var) -> Var {
|
||||
let out = a.value().matmul(&b.value());
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![a.clone(), b.clone()],
|
||||
Box::new(|dc, parents| {
|
||||
let a = parents[0].value();
|
||||
let b = parents[1].value();
|
||||
let (da, db) = Tensor::matmul_backward(&a, &b, dc);
|
||||
Var::push_grad(&parents[0], da);
|
||||
Var::push_grad(&parents[1], db);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Elementwise `out = a + b` (same shape). Backward: grad flows unchanged to both.
|
||||
pub fn add(a: &Var, b: &Var) -> Var {
|
||||
let out = a.value().add(&b.value());
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![a.clone(), b.clone()],
|
||||
Box::new(|d, parents| {
|
||||
Var::push_grad(&parents[0], d.clone());
|
||||
Var::push_grad(&parents[1], d.clone());
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Elementwise `out = a * b` (Hadamard). Backward: `da = d∘b`, `db = d∘a`.
|
||||
pub fn mul(a: &Var, b: &Var) -> Var {
|
||||
let out = a.value().mul(&b.value());
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![a.clone(), b.clone()],
|
||||
Box::new(|d, parents| {
|
||||
let a = parents[0].value();
|
||||
let b = parents[1].value();
|
||||
Var::push_grad(&parents[0], d.mul(&b));
|
||||
Var::push_grad(&parents[1], d.mul(&a));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Broadcast bias add: `out[r,c] = x[r,c] + bias[c]`. Backward: `dx = d`,
|
||||
/// `dbias[c] = sum_r d[r,c]` (sum over the broadcast dim).
|
||||
pub fn add_bias(x: &Var, bias: &Var) -> Var {
|
||||
let out = x.value().add_bias(&bias.value());
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone(), bias.clone()],
|
||||
Box::new(|d, parents| {
|
||||
Var::push_grad(&parents[0], d.clone());
|
||||
Var::push_grad(&parents[1], d.sum_rows());
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Scale by a constant: `out = x * alpha`. Backward: `dx = d * alpha`.
|
||||
pub fn scale(x: &Var, alpha: f32) -> Var {
|
||||
let out = x.value().scale(alpha);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
Var::push_grad(&parents[0], d.scale(alpha));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// RMSNorm: `y = x * rsqrt(mean(x²)+eps) * gamma`. Caches `inv_rms` for backward.
|
||||
pub fn rms_norm(x: &Var, gamma: &Var, eps: f32) -> Var {
|
||||
let (y, inv_rms) = x.value().rms_norm(&gamma.value(), eps);
|
||||
Var::from_op(
|
||||
y,
|
||||
vec![x.clone(), gamma.clone()],
|
||||
Box::new(move |dy, parents| {
|
||||
let x = parents[0].value();
|
||||
let gamma = parents[1].value();
|
||||
let (dx, dgamma) = Tensor::rms_norm_backward(&x, &gamma, dy, &inv_rms);
|
||||
Var::push_grad(&parents[0], dx);
|
||||
Var::push_grad(&parents[1], dgamma);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// SiLU: `y = x * sigmoid(x)`. Backward uses the forward `x`.
|
||||
pub fn silu(x: &Var) -> Var {
|
||||
let out = x.value().silu();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(|dy, parents| {
|
||||
let x = parents[0].value();
|
||||
Var::push_grad(&parents[0], Tensor::silu_backward(&x, dy));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// SwiGLU (SiLU-gated GLU): `out = silu(gate) ∘ up`. Composed from `silu` + `mul`
|
||||
/// so its backward comes from theirs — no dedicated kernel needed.
|
||||
pub fn swiglu(gate: &Var, up: &Var) -> Var {
|
||||
mul(&silu(gate), up)
|
||||
}
|
||||
|
||||
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]`. Orthogonal map, so the
|
||||
/// backward is the inverse rotation of `dy` — no cached forward values needed.
|
||||
pub fn rope(x: &Var, theta: f32) -> Var {
|
||||
let out = x.value().rope(theta);
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |dy, parents| {
|
||||
Var::push_grad(&parents[0], Tensor::rope_backward(dy, theta));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Row-wise softmax. Caches the output `y` for the Jacobian backward.
|
||||
pub fn softmax(x: &Var) -> Var {
|
||||
let y = x.value().softmax();
|
||||
let y_cache = y.clone();
|
||||
Var::from_op(
|
||||
y,
|
||||
vec![x.clone()],
|
||||
Box::new(move |dy, parents| {
|
||||
Var::push_grad(&parents[0], Tensor::softmax_backward(&y_cache, dy));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
|
||||
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
|
||||
/// scaled by the upstream scalar grad.
|
||||
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
let (probs, per_row) = x.value().cross_entropy(target);
|
||||
let rows = x.value().shape()[0];
|
||||
// Mean loss as a host scalar wrapped back into a [1] tensor.
|
||||
let mean = per_row.to_device(xtrain_tensor::Device::Cpu);
|
||||
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / rows as f32;
|
||||
let loss = Tensor::from_slice(&[mean_val], &[1]).to_device(x.value().device());
|
||||
|
||||
let target = target.clone();
|
||||
Var::from_op(
|
||||
loss,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
// `d` is the scalar upstream grad (1.0 when this is the loss root).
|
||||
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||
let scale = upstream / rows as f32;
|
||||
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
|
||||
Var::push_grad(&parents[0], dx);
|
||||
}),
|
||||
)
|
||||
}
|
||||
546
crates/xtrain-autodiff/tests/autograd.rs
Normal file
546
crates/xtrain-autodiff/tests/autograd.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
|
||||
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
|
||||
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
|
||||
// through the tape, call backward(), and grad-check each input's .grad() against
|
||||
// central finite differences of L.
|
||||
//
|
||||
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_autodiff::{GradCheckConfig, grad_check};
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
|
||||
// Deterministic LCG fill in [-0.5, 0.5), same as the gemm tests.
|
||||
fn fill(n: usize, seed: u64) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn require_gpu() {
|
||||
assert!(
|
||||
device::device_count().expect("device count") > 0,
|
||||
"no CUDA device"
|
||||
);
|
||||
device::set_device(0).unwrap();
|
||||
}
|
||||
|
||||
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
|
||||
Tensor::from_slice(data, shape).to_device(Device::Cuda(0))
|
||||
}
|
||||
|
||||
// L = sum(W ∘ out) for fixed weights W over the op output.
|
||||
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
|
||||
out.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.zip(w)
|
||||
.map(|(o, w)| o * w)
|
||||
.sum()
|
||||
}
|
||||
|
||||
// Tolerances: ops with elementwise/linear forwards (add, mul, scale, bias, rope)
|
||||
// are exactly linear in each input, so a large eps just sharpens f32 resolution.
|
||||
// Nonlinear ops (rms_norm, silu, softmax, cross_entropy) carry O(eps²) truncation
|
||||
// → smaller eps. atol floors near-zero grads.
|
||||
fn cfg_linear() -> GradCheckConfig {
|
||||
GradCheckConfig {
|
||||
eps: 1e-2,
|
||||
rel_tol: 2e-2,
|
||||
atol: 1e-3,
|
||||
}
|
||||
}
|
||||
fn cfg_nonlinear() -> GradCheckConfig {
|
||||
GradCheckConfig {
|
||||
eps: 1e-3,
|
||||
rel_tol: 3e-2,
|
||||
atol: 1e-3,
|
||||
}
|
||||
}
|
||||
|
||||
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
|
||||
println!(
|
||||
"{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
|
||||
res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index
|
||||
);
|
||||
assert!(res.passed, "{name} grad-check failed: {res:?}");
|
||||
}
|
||||
|
||||
// ---- add ----
|
||||
#[test]
|
||||
fn add_bwd() {
|
||||
require_gpu();
|
||||
let (m, n) = (8, 6);
|
||||
let a_h = fill(m * n, 1);
|
||||
let b_h = fill(m * n, 2);
|
||||
let w = fill(m * n, 3);
|
||||
|
||||
let a = Var::leaf(cuda(&a_h, &[m, n]));
|
||||
let b = Var::leaf(cuda(&b_h, &[m, n]));
|
||||
let out = ops::add(&a, &b);
|
||||
let loss = scalar_loss(&out, &w);
|
||||
loss.backward();
|
||||
|
||||
let da = a.grad().unwrap().to_device(Device::Cpu);
|
||||
let db = b.grad().unwrap().to_device(Device::Cpu);
|
||||
let bf = b_h.clone();
|
||||
let wf = w.clone();
|
||||
let la = move |v: &[f32], s: &[usize]| {
|
||||
let o = cuda(v, s).add(&cuda(&bf, &[m, n]));
|
||||
weighted_sum(&o, &wf)
|
||||
};
|
||||
report(
|
||||
"add dA",
|
||||
&grad_check(&a_h, &[m, n], &la, da.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
let af = a_h.clone();
|
||||
let wf = w.clone();
|
||||
let lb = move |v: &[f32], s: &[usize]| {
|
||||
let o = cuda(&af, &[m, n]).add(&cuda(v, s));
|
||||
weighted_sum(&o, &wf)
|
||||
};
|
||||
report(
|
||||
"add dB",
|
||||
&grad_check(&b_h, &[m, n], &lb, db.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- mul ----
|
||||
#[test]
|
||||
fn mul_bwd() {
|
||||
require_gpu();
|
||||
let (m, n) = (8, 6);
|
||||
let a_h = fill(m * n, 11);
|
||||
let b_h = fill(m * n, 22);
|
||||
let w = fill(m * n, 33);
|
||||
|
||||
let a = Var::leaf(cuda(&a_h, &[m, n]));
|
||||
let b = Var::leaf(cuda(&b_h, &[m, n]));
|
||||
let out = ops::mul(&a, &b);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let da = a.grad().unwrap().to_device(Device::Cpu);
|
||||
let db = b.grad().unwrap().to_device(Device::Cpu);
|
||||
let bf = b_h.clone();
|
||||
let wf = w.clone();
|
||||
let la = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).mul(&cuda(&bf, &[m, n])), &wf);
|
||||
report(
|
||||
"mul dA",
|
||||
&grad_check(&a_h, &[m, n], &la, da.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
let af = a_h.clone();
|
||||
let wf = w.clone();
|
||||
let lb = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&af, &[m, n]).mul(&cuda(v, s)), &wf);
|
||||
report(
|
||||
"mul dB",
|
||||
&grad_check(&b_h, &[m, n], &lb, db.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- add_bias (broadcast) ----
|
||||
#[test]
|
||||
fn add_bias_bwd() {
|
||||
require_gpu();
|
||||
let (m, n) = (10, 7);
|
||||
let x_h = fill(m * n, 5);
|
||||
let b_h = fill(n, 6);
|
||||
let w = fill(m * n, 7);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[m, n]));
|
||||
let bias = Var::leaf(cuda(&b_h, &[n]));
|
||||
let out = ops::add_bias(&x, &bias);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let dbias = bias.grad().unwrap().to_device(Device::Cpu);
|
||||
let bf = b_h.clone();
|
||||
let wf = w.clone();
|
||||
let lx =
|
||||
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).add_bias(&cuda(&bf, &[n])), &wf);
|
||||
report(
|
||||
"add_bias dX",
|
||||
&grad_check(&x_h, &[m, n], &lx, dx.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
let xf = x_h.clone();
|
||||
let wf = w.clone();
|
||||
let lb =
|
||||
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&xf, &[m, n]).add_bias(&cuda(v, s)), &wf);
|
||||
report(
|
||||
"add_bias dBias",
|
||||
&grad_check(&b_h, &[n], &lb, dbias.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
|
||||
#[test]
|
||||
fn matmul_bwd() {
|
||||
require_gpu();
|
||||
let (m, k, n) = (6, 5, 4);
|
||||
let a_h = fill(m * k, 41);
|
||||
let b_h = fill(k * n, 42);
|
||||
let w = fill(m * n, 43);
|
||||
|
||||
let a = Var::leaf(cuda(&a_h, &[m, k]));
|
||||
let b = Var::leaf(cuda(&b_h, &[k, n]));
|
||||
let out = ops::matmul(&a, &b);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let da = a.grad().unwrap().to_device(Device::Cpu);
|
||||
let db = b.grad().unwrap().to_device(Device::Cpu);
|
||||
let bf = b_h.clone();
|
||||
let wf = w.clone();
|
||||
let la =
|
||||
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).matmul(&cuda(&bf, &[k, n])), &wf);
|
||||
report(
|
||||
"matmul dA",
|
||||
&grad_check(&a_h, &[m, k], &la, da.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
let af = a_h.clone();
|
||||
let wf = w.clone();
|
||||
let lb =
|
||||
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&af, &[m, k]).matmul(&cuda(v, s)), &wf);
|
||||
report(
|
||||
"matmul dB",
|
||||
&grad_check(&b_h, &[k, n], &lb, db.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- rms_norm ----
|
||||
#[test]
|
||||
fn rms_norm_bwd() {
|
||||
require_gpu();
|
||||
let (rows, cols) = (5, 16);
|
||||
let eps = 1e-5;
|
||||
let x_h = fill(rows * cols, 51);
|
||||
let g_h: Vec<f32> = fill(cols, 52).iter().map(|v| v + 1.0).collect(); // gamma ~1
|
||||
let w = fill(rows * cols, 53);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||
let gamma = Var::leaf(cuda(&g_h, &[cols]));
|
||||
let out = ops::rms_norm(&x, &gamma, eps);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let dg = gamma.grad().unwrap().to_device(Device::Cpu);
|
||||
let gf = g_h.clone();
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| {
|
||||
let (o, _) = cuda(v, s).rms_norm(&cuda(&gf, &[cols]), eps);
|
||||
weighted_sum(&o, &wf)
|
||||
};
|
||||
report(
|
||||
"rms_norm dX",
|
||||
&grad_check(
|
||||
&x_h,
|
||||
&[rows, cols],
|
||||
&lx,
|
||||
dx.as_slice::<f32>(),
|
||||
cfg_nonlinear(),
|
||||
),
|
||||
);
|
||||
let xf = x_h.clone();
|
||||
let wf = w.clone();
|
||||
let lg = move |v: &[f32], s: &[usize]| {
|
||||
let (o, _) = cuda(&xf, &[rows, cols]).rms_norm(&cuda(v, s), eps);
|
||||
weighted_sum(&o, &wf)
|
||||
};
|
||||
report(
|
||||
"rms_norm dGamma",
|
||||
&grad_check(&g_h, &[cols], &lg, dg.as_slice::<f32>(), cfg_nonlinear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- silu ----
|
||||
#[test]
|
||||
fn silu_bwd() {
|
||||
require_gpu();
|
||||
let n = 64;
|
||||
let x_h = fill(n, 61);
|
||||
let w = fill(n, 62);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[n]));
|
||||
let out = ops::silu(&x);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu(), &wf);
|
||||
report(
|
||||
"silu dX",
|
||||
&grad_check(&x_h, &[n], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- swiglu (composed: silu(gate) ∘ up) ----
|
||||
#[test]
|
||||
fn swiglu_bwd() {
|
||||
require_gpu();
|
||||
let n = 48;
|
||||
let g_h = fill(n, 71);
|
||||
let u_h = fill(n, 72);
|
||||
let w = fill(n, 73);
|
||||
|
||||
let gate = Var::leaf(cuda(&g_h, &[n]));
|
||||
let up = Var::leaf(cuda(&u_h, &[n]));
|
||||
let out = ops::swiglu(&gate, &up);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dg = gate.grad().unwrap().to_device(Device::Cpu);
|
||||
let du = up.grad().unwrap().to_device(Device::Cpu);
|
||||
let uf = u_h.clone();
|
||||
let wf = w.clone();
|
||||
let lg =
|
||||
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu().mul(&cuda(&uf, &[n])), &wf);
|
||||
report(
|
||||
"swiglu dGate",
|
||||
&grad_check(&g_h, &[n], &lg, dg.as_slice::<f32>(), cfg_nonlinear()),
|
||||
);
|
||||
let gf = g_h.clone();
|
||||
let wf = w.clone();
|
||||
let lu =
|
||||
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&gf, &[n]).silu().mul(&cuda(v, s)), &wf);
|
||||
report(
|
||||
"swiglu dUp",
|
||||
&grad_check(&u_h, &[n], &lu, du.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- rope ----
|
||||
#[test]
|
||||
fn rope_bwd() {
|
||||
require_gpu();
|
||||
let (tokens, heads, head_dim) = (4, 2, 8);
|
||||
let n = tokens * heads * head_dim;
|
||||
let theta = 10000.0;
|
||||
let x_h = fill(n, 81);
|
||||
let w = fill(n, 82);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[tokens, heads, head_dim]));
|
||||
let out = ops::rope(&x, theta);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta), &wf);
|
||||
report(
|
||||
"rope dX",
|
||||
&grad_check(
|
||||
&x_h,
|
||||
&[tokens, heads, head_dim],
|
||||
&lx,
|
||||
dx.as_slice::<f32>(),
|
||||
cfg_linear(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- softmax ----
|
||||
#[test]
|
||||
fn softmax_bwd() {
|
||||
require_gpu();
|
||||
let (rows, cols) = (4, 10);
|
||||
let x_h = fill(rows * cols, 91);
|
||||
let w = fill(rows * cols, 92);
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||
let out = ops::softmax(&x);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).softmax(), &wf);
|
||||
report(
|
||||
"softmax dX",
|
||||
&grad_check(
|
||||
&x_h,
|
||||
&[rows, cols],
|
||||
&lx,
|
||||
dx.as_slice::<f32>(),
|
||||
cfg_nonlinear(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
|
||||
#[test]
|
||||
fn cross_entropy_bwd() {
|
||||
require_gpu();
|
||||
let (rows, cols) = (5, 8);
|
||||
let x_h = fill(rows * cols, 101);
|
||||
let targets: Vec<i32> = (0..rows).map(|r| (r * 3 % cols) as i32).collect();
|
||||
let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||
let loss = ops::cross_entropy(&x, &target);
|
||||
loss.backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
// Loss is already scalar (mean NLL) — grad-check it directly, no W weighting.
|
||||
let tgt = targets.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| {
|
||||
let t = Tensor::from_slice(&tgt, &[rows]).to_device(Device::Cuda(0));
|
||||
let (_, per_row) = cuda(v, s).cross_entropy(&t);
|
||||
per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.sum::<f32>()
|
||||
/ rows as f32
|
||||
};
|
||||
report(
|
||||
"cross_entropy dX",
|
||||
&grad_check(
|
||||
&x_h,
|
||||
&[rows, cols],
|
||||
&lx,
|
||||
dx.as_slice::<f32>(),
|
||||
cfg_nonlinear(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
|
||||
// y = x*x + x*x via two separate mul nodes on the same Var x → dL/dx must be the
|
||||
// sum of both branches. With W=1, out=2x², so dOut=W=1 and dx (numeric) = 4x.
|
||||
#[test]
|
||||
fn fanout_grad_accumulation() {
|
||||
require_gpu();
|
||||
let n = 12;
|
||||
let x_h = fill(n, 111);
|
||||
let w = vec![1.0f32; n];
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[n]));
|
||||
let sq1 = ops::mul(&x, &x); // x∘x (x consumed twice within one node)
|
||||
let sq2 = ops::mul(&x, &x); // x∘x (x consumed again across nodes)
|
||||
let out = ops::add(&sq1, &sq2); // 2x²
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
let wf = w.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| {
|
||||
let t = cuda(v, s);
|
||||
let o = t.mul(&t).add(&t.mul(&t));
|
||||
weighted_sum(&o, &wf)
|
||||
};
|
||||
// Analytic dx should be 4x; fan-out summed all four uses of x.
|
||||
report(
|
||||
"fanout dX",
|
||||
&grad_check(&x_h, &[n], &lx, dx.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
|
||||
// Single head, single batch. Backward falls out of matmul+scale+softmax nodes.
|
||||
#[test]
|
||||
fn attention_composed_bwd() {
|
||||
require_gpu();
|
||||
let (s, d) = (5, 6); // seq_len, head_dim
|
||||
let scale = 1.0 / (d as f32).sqrt();
|
||||
let q_h = fill(s * d, 121);
|
||||
let k_h = fill(s * d, 122);
|
||||
let v_h = fill(s * d, 123);
|
||||
let w = fill(s * d, 124); // weights over the [s,d] attention output
|
||||
|
||||
let attn = |q: &Var, k: &Var, v: &Var| -> Var {
|
||||
let kt = transpose_var(k); // [d,s] (manual transpose node)
|
||||
let scores = ops::scale(&ops::matmul(q, &kt), scale); // [s,s]
|
||||
let probs = ops::softmax(&scores);
|
||||
ops::matmul(&probs, v) // [s,d]
|
||||
};
|
||||
|
||||
let q = Var::leaf(cuda(&q_h, &[s, d]));
|
||||
let k = Var::leaf(cuda(&k_h, &[s, d]));
|
||||
let v = Var::leaf(cuda(&v_h, &[s, d]));
|
||||
let out = attn(&q, &k, &v);
|
||||
scalar_loss(&out, &w).backward();
|
||||
|
||||
let dq = q.grad().unwrap().to_device(Device::Cpu);
|
||||
let dk = k.grad().unwrap().to_device(Device::Cpu);
|
||||
let dv = v.grad().unwrap().to_device(Device::Cpu);
|
||||
|
||||
// Re-run the same forward inside the loss closures (host-side) per input.
|
||||
let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
|
||||
let qv = cuda(qh, &[s, d]);
|
||||
let kv = cuda(kh, &[s, d]);
|
||||
let vv = cuda(vh, &[s, d]);
|
||||
let scores = qv.matmul(&kv.transpose_2d()).scale(scale);
|
||||
let probs = scores.softmax();
|
||||
weighted_sum(&probs.matmul(&vv), &w)
|
||||
};
|
||||
|
||||
let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
|
||||
let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
|
||||
report(
|
||||
"attn dQ",
|
||||
&grad_check(&q_h, &[s, d], &lq, dq.as_slice::<f32>(), cfg_nonlinear()),
|
||||
);
|
||||
|
||||
let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
|
||||
let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
|
||||
report(
|
||||
"attn dK",
|
||||
&grad_check(&k_h, &[s, d], &lk, dk.as_slice::<f32>(), cfg_nonlinear()),
|
||||
);
|
||||
|
||||
let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
|
||||
let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
|
||||
report(
|
||||
"attn dV",
|
||||
&grad_check(&v_h, &[s, d], &lv, dv.as_slice::<f32>(), cfg_linear()),
|
||||
);
|
||||
}
|
||||
|
||||
// --- test helpers ---
|
||||
|
||||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||||
// implement it as: elementwise mul by a constant-W leaf, then sum-to-scalar.
|
||||
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
|
||||
let wt = Var::leaf(cuda(w, out.value().shape()));
|
||||
let prod = ops::mul(out, &wt);
|
||||
sum_all(&prod)
|
||||
}
|
||||
|
||||
// Sum-to-scalar node: out = sum(x). Backward broadcasts the scalar grad to a
|
||||
// ones-shaped tensor over x. Implemented here (test-local) since the engine's
|
||||
// op set doesn't include a generic reduction; cross_entropy is the only loss op.
|
||||
fn sum_all(x: &Var) -> Var {
|
||||
let xv = x.value();
|
||||
let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
|
||||
let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
|
||||
let shape: Vec<usize> = xv.shape().to_vec();
|
||||
Var::from_op(
|
||||
scalar,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
// d is [1]; broadcast d to a same-shape tensor over the input.
|
||||
let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let ones = vec![dval; shape.iter().product()];
|
||||
let g = Tensor::from_slice(&ones, &shape).to_device(Device::Cuda(0));
|
||||
Var::push_grad(&parents[0], g);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
// Manual transpose node for the composed-attention test (the engine has no
|
||||
// transpose op; xserv does the equivalent host-side reshape around RoPE).
|
||||
fn transpose_var(x: &Var) -> Var {
|
||||
let xt = x.value().transpose_2d();
|
||||
Var::from_op(
|
||||
xt,
|
||||
vec![x.clone()],
|
||||
Box::new(|d, parents| {
|
||||
Var::push_grad(&parents[0], d.transpose_2d());
|
||||
}),
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user