ops: differentiable autograd nodes + per-op grad-check tests

ops.rs wraps each Tensor op as a Var node with its backward closure (forward
caches captured by move). swiglu = mul(silu(gate), up); attention is composed
(matmul+scale+softmax+matmul), no fused kernel. tests/autograd.rs grad-checks
every op via the L=sum(W∘out) template, plus a fan-out grad-accumulation test
(dL/dx=4x) and an end-to-end composed-attention grad-check (dQ/dK/dV). Adds
xtrain-cuda dev-dep for device selection in tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 15:47:43 +08:00
parent 224f750ee4
commit e7ce504b1f
4 changed files with 723 additions and 0 deletions

1
Cargo.lock generated
View File

@@ -92,6 +92,7 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
name = "xtrain-autodiff"
version = "0.1.0"
dependencies = [
"xtrain-cuda",
"xtrain-tensor",
]

View File

@@ -5,3 +5,7 @@ edition.workspace = true
[dependencies]
xtrain-tensor = { path = "../xtrain-tensor" }
[dev-dependencies]
# Acceptance tests need device selection (set_device) to drive the GPU.
xtrain-cuda = { path = "../xtrain-cuda" }

View File

@@ -0,0 +1,172 @@
//! Differentiable ops as autograd nodes (Phase T4).
//!
//! Each function runs the forward [`Tensor`] kernel, then builds a [`Var`] whose
//! backward closure computes the analytic gradient (see
//! `docs/03-autograd-engine.md` for the math) and pushes it to each parent via
//! [`Var::push_grad`] (which SUMs — correct under fan-out). Forward outputs that
//! the backward needs (softmax `y`, rms `inv_rms`, cross-entropy `probs`) are
//! cached by moving them into the closure.
//!
//! Attention is NOT a node here: it is composed from `matmul` + `scale` +
//! `softmax` in user code, and its backward falls out of theirs.
#![cfg(not(no_cuda))]
use crate::tape::Var;
use xtrain_tensor::Tensor;
/// `C = A @ B` (2D). Backward: `dA = dC @ Bᵀ`, `dB = Aᵀ @ dC`.
pub fn matmul(a: &Var, b: &Var) -> Var {
let out = a.value().matmul(&b.value());
Var::from_op(
out,
vec![a.clone(), b.clone()],
Box::new(|dc, parents| {
let a = parents[0].value();
let b = parents[1].value();
let (da, db) = Tensor::matmul_backward(&a, &b, dc);
Var::push_grad(&parents[0], da);
Var::push_grad(&parents[1], db);
}),
)
}
/// Elementwise `out = a + b` (same shape). Backward: grad flows unchanged to both.
pub fn add(a: &Var, b: &Var) -> Var {
let out = a.value().add(&b.value());
Var::from_op(
out,
vec![a.clone(), b.clone()],
Box::new(|d, parents| {
Var::push_grad(&parents[0], d.clone());
Var::push_grad(&parents[1], d.clone());
}),
)
}
/// Elementwise `out = a * b` (Hadamard). Backward: `da = d∘b`, `db = d∘a`.
pub fn mul(a: &Var, b: &Var) -> Var {
let out = a.value().mul(&b.value());
Var::from_op(
out,
vec![a.clone(), b.clone()],
Box::new(|d, parents| {
let a = parents[0].value();
let b = parents[1].value();
Var::push_grad(&parents[0], d.mul(&b));
Var::push_grad(&parents[1], d.mul(&a));
}),
)
}
/// Broadcast bias add: `out[r,c] = x[r,c] + bias[c]`. Backward: `dx = d`,
/// `dbias[c] = sum_r d[r,c]` (sum over the broadcast dim).
pub fn add_bias(x: &Var, bias: &Var) -> Var {
let out = x.value().add_bias(&bias.value());
Var::from_op(
out,
vec![x.clone(), bias.clone()],
Box::new(|d, parents| {
Var::push_grad(&parents[0], d.clone());
Var::push_grad(&parents[1], d.sum_rows());
}),
)
}
/// Scale by a constant: `out = x * alpha`. Backward: `dx = d * alpha`.
pub fn scale(x: &Var, alpha: f32) -> Var {
let out = x.value().scale(alpha);
Var::from_op(
out,
vec![x.clone()],
Box::new(move |d, parents| {
Var::push_grad(&parents[0], d.scale(alpha));
}),
)
}
/// RMSNorm: `y = x * rsqrt(mean(x²)+eps) * gamma`. Caches `inv_rms` for backward.
pub fn rms_norm(x: &Var, gamma: &Var, eps: f32) -> Var {
let (y, inv_rms) = x.value().rms_norm(&gamma.value(), eps);
Var::from_op(
y,
vec![x.clone(), gamma.clone()],
Box::new(move |dy, parents| {
let x = parents[0].value();
let gamma = parents[1].value();
let (dx, dgamma) = Tensor::rms_norm_backward(&x, &gamma, dy, &inv_rms);
Var::push_grad(&parents[0], dx);
Var::push_grad(&parents[1], dgamma);
}),
)
}
/// SiLU: `y = x * sigmoid(x)`. Backward uses the forward `x`.
pub fn silu(x: &Var) -> Var {
let out = x.value().silu();
Var::from_op(
out,
vec![x.clone()],
Box::new(|dy, parents| {
let x = parents[0].value();
Var::push_grad(&parents[0], Tensor::silu_backward(&x, dy));
}),
)
}
/// SwiGLU (SiLU-gated GLU): `out = silu(gate) ∘ up`. Composed from `silu` + `mul`
/// so its backward comes from theirs — no dedicated kernel needed.
pub fn swiglu(gate: &Var, up: &Var) -> Var {
mul(&silu(gate), up)
}
/// RoPE (rotate_half) over `x:[tokens,heads,head_dim]`. Orthogonal map, so the
/// backward is the inverse rotation of `dy` — no cached forward values needed.
pub fn rope(x: &Var, theta: f32) -> Var {
let out = x.value().rope(theta);
Var::from_op(
out,
vec![x.clone()],
Box::new(move |dy, parents| {
Var::push_grad(&parents[0], Tensor::rope_backward(dy, theta));
}),
)
}
/// Row-wise softmax. Caches the output `y` for the Jacobian backward.
pub fn softmax(x: &Var) -> Var {
let y = x.value().softmax();
let y_cache = y.clone();
Var::from_op(
y,
vec![x.clone()],
Box::new(move |dy, parents| {
Var::push_grad(&parents[0], Tensor::softmax_backward(&y_cache, dy));
}),
)
}
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// scaled by the upstream scalar grad.
pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
let (probs, per_row) = x.value().cross_entropy(target);
let rows = x.value().shape()[0];
// Mean loss as a host scalar wrapped back into a [1] tensor.
let mean = per_row.to_device(xtrain_tensor::Device::Cpu);
let mean_val: f32 = mean.as_slice::<f32>().iter().sum::<f32>() / rows as f32;
let loss = Tensor::from_slice(&[mean_val], &[1]).to_device(x.value().device());
let target = target.clone();
Var::from_op(
loss,
vec![x.clone()],
Box::new(move |d, parents| {
// `d` is the scalar upstream grad (1.0 when this is the loss root).
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
let scale = upstream / rows as f32;
let dx = Tensor::cross_entropy_backward(&probs, &target, scale);
Var::push_grad(&parents[0], dx);
}),
)
}

View File

@@ -0,0 +1,546 @@
// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
// through the tape, call backward(), and grad-check each input's .grad() against
// central finite differences of L.
//
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_autodiff::{GradCheckConfig, grad_check};
use xtrain_cuda::device;
use xtrain_tensor::{Device, Tensor};
// Deterministic LCG fill in [-0.5, 0.5), same as the gemm tests.
fn fill(n: usize, seed: u64) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5
})
.collect()
}
fn require_gpu() {
assert!(
device::device_count().expect("device count") > 0,
"no CUDA device"
);
device::set_device(0).unwrap();
}
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
Tensor::from_slice(data, shape).to_device(Device::Cuda(0))
}
// L = sum(W ∘ out) for fixed weights W over the op output.
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
out.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.zip(w)
.map(|(o, w)| o * w)
.sum()
}
// Tolerances: ops with elementwise/linear forwards (add, mul, scale, bias, rope)
// are exactly linear in each input, so a large eps just sharpens f32 resolution.
// Nonlinear ops (rms_norm, silu, softmax, cross_entropy) carry O(eps²) truncation
// → smaller eps. atol floors near-zero grads.
fn cfg_linear() -> GradCheckConfig {
GradCheckConfig {
eps: 1e-2,
rel_tol: 2e-2,
atol: 1e-3,
}
}
fn cfg_nonlinear() -> GradCheckConfig {
GradCheckConfig {
eps: 1e-3,
rel_tol: 3e-2,
atol: 1e-3,
}
}
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
println!(
"{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index
);
assert!(res.passed, "{name} grad-check failed: {res:?}");
}
// ---- add ----
#[test]
fn add_bwd() {
require_gpu();
let (m, n) = (8, 6);
let a_h = fill(m * n, 1);
let b_h = fill(m * n, 2);
let w = fill(m * n, 3);
let a = Var::leaf(cuda(&a_h, &[m, n]));
let b = Var::leaf(cuda(&b_h, &[m, n]));
let out = ops::add(&a, &b);
let loss = scalar_loss(&out, &w);
loss.backward();
let da = a.grad().unwrap().to_device(Device::Cpu);
let db = b.grad().unwrap().to_device(Device::Cpu);
let bf = b_h.clone();
let wf = w.clone();
let la = move |v: &[f32], s: &[usize]| {
let o = cuda(v, s).add(&cuda(&bf, &[m, n]));
weighted_sum(&o, &wf)
};
report(
"add dA",
&grad_check(&a_h, &[m, n], &la, da.as_slice::<f32>(), cfg_linear()),
);
let af = a_h.clone();
let wf = w.clone();
let lb = move |v: &[f32], s: &[usize]| {
let o = cuda(&af, &[m, n]).add(&cuda(v, s));
weighted_sum(&o, &wf)
};
report(
"add dB",
&grad_check(&b_h, &[m, n], &lb, db.as_slice::<f32>(), cfg_linear()),
);
}
// ---- mul ----
#[test]
fn mul_bwd() {
require_gpu();
let (m, n) = (8, 6);
let a_h = fill(m * n, 11);
let b_h = fill(m * n, 22);
let w = fill(m * n, 33);
let a = Var::leaf(cuda(&a_h, &[m, n]));
let b = Var::leaf(cuda(&b_h, &[m, n]));
let out = ops::mul(&a, &b);
scalar_loss(&out, &w).backward();
let da = a.grad().unwrap().to_device(Device::Cpu);
let db = b.grad().unwrap().to_device(Device::Cpu);
let bf = b_h.clone();
let wf = w.clone();
let la = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).mul(&cuda(&bf, &[m, n])), &wf);
report(
"mul dA",
&grad_check(&a_h, &[m, n], &la, da.as_slice::<f32>(), cfg_linear()),
);
let af = a_h.clone();
let wf = w.clone();
let lb = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&af, &[m, n]).mul(&cuda(v, s)), &wf);
report(
"mul dB",
&grad_check(&b_h, &[m, n], &lb, db.as_slice::<f32>(), cfg_linear()),
);
}
// ---- add_bias (broadcast) ----
#[test]
fn add_bias_bwd() {
require_gpu();
let (m, n) = (10, 7);
let x_h = fill(m * n, 5);
let b_h = fill(n, 6);
let w = fill(m * n, 7);
let x = Var::leaf(cuda(&x_h, &[m, n]));
let bias = Var::leaf(cuda(&b_h, &[n]));
let out = ops::add_bias(&x, &bias);
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let dbias = bias.grad().unwrap().to_device(Device::Cpu);
let bf = b_h.clone();
let wf = w.clone();
let lx =
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).add_bias(&cuda(&bf, &[n])), &wf);
report(
"add_bias dX",
&grad_check(&x_h, &[m, n], &lx, dx.as_slice::<f32>(), cfg_linear()),
);
let xf = x_h.clone();
let wf = w.clone();
let lb =
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&xf, &[m, n]).add_bias(&cuda(v, s)), &wf);
report(
"add_bias dBias",
&grad_check(&b_h, &[n], &lb, dbias.as_slice::<f32>(), cfg_linear()),
);
}
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
#[test]
fn matmul_bwd() {
require_gpu();
let (m, k, n) = (6, 5, 4);
let a_h = fill(m * k, 41);
let b_h = fill(k * n, 42);
let w = fill(m * n, 43);
let a = Var::leaf(cuda(&a_h, &[m, k]));
let b = Var::leaf(cuda(&b_h, &[k, n]));
let out = ops::matmul(&a, &b);
scalar_loss(&out, &w).backward();
let da = a.grad().unwrap().to_device(Device::Cpu);
let db = b.grad().unwrap().to_device(Device::Cpu);
let bf = b_h.clone();
let wf = w.clone();
let la =
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).matmul(&cuda(&bf, &[k, n])), &wf);
report(
"matmul dA",
&grad_check(&a_h, &[m, k], &la, da.as_slice::<f32>(), cfg_linear()),
);
let af = a_h.clone();
let wf = w.clone();
let lb =
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&af, &[m, k]).matmul(&cuda(v, s)), &wf);
report(
"matmul dB",
&grad_check(&b_h, &[k, n], &lb, db.as_slice::<f32>(), cfg_linear()),
);
}
// ---- rms_norm ----
#[test]
fn rms_norm_bwd() {
require_gpu();
let (rows, cols) = (5, 16);
let eps = 1e-5;
let x_h = fill(rows * cols, 51);
let g_h: Vec<f32> = fill(cols, 52).iter().map(|v| v + 1.0).collect(); // gamma ~1
let w = fill(rows * cols, 53);
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
let gamma = Var::leaf(cuda(&g_h, &[cols]));
let out = ops::rms_norm(&x, &gamma, eps);
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let dg = gamma.grad().unwrap().to_device(Device::Cpu);
let gf = g_h.clone();
let wf = w.clone();
let lx = move |v: &[f32], s: &[usize]| {
let (o, _) = cuda(v, s).rms_norm(&cuda(&gf, &[cols]), eps);
weighted_sum(&o, &wf)
};
report(
"rms_norm dX",
&grad_check(
&x_h,
&[rows, cols],
&lx,
dx.as_slice::<f32>(),
cfg_nonlinear(),
),
);
let xf = x_h.clone();
let wf = w.clone();
let lg = move |v: &[f32], s: &[usize]| {
let (o, _) = cuda(&xf, &[rows, cols]).rms_norm(&cuda(v, s), eps);
weighted_sum(&o, &wf)
};
report(
"rms_norm dGamma",
&grad_check(&g_h, &[cols], &lg, dg.as_slice::<f32>(), cfg_nonlinear()),
);
}
// ---- silu ----
#[test]
fn silu_bwd() {
require_gpu();
let n = 64;
let x_h = fill(n, 61);
let w = fill(n, 62);
let x = Var::leaf(cuda(&x_h, &[n]));
let out = ops::silu(&x);
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let wf = w.clone();
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu(), &wf);
report(
"silu dX",
&grad_check(&x_h, &[n], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
);
}
// ---- swiglu (composed: silu(gate) ∘ up) ----
#[test]
fn swiglu_bwd() {
require_gpu();
let n = 48;
let g_h = fill(n, 71);
let u_h = fill(n, 72);
let w = fill(n, 73);
let gate = Var::leaf(cuda(&g_h, &[n]));
let up = Var::leaf(cuda(&u_h, &[n]));
let out = ops::swiglu(&gate, &up);
scalar_loss(&out, &w).backward();
let dg = gate.grad().unwrap().to_device(Device::Cpu);
let du = up.grad().unwrap().to_device(Device::Cpu);
let uf = u_h.clone();
let wf = w.clone();
let lg =
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu().mul(&cuda(&uf, &[n])), &wf);
report(
"swiglu dGate",
&grad_check(&g_h, &[n], &lg, dg.as_slice::<f32>(), cfg_nonlinear()),
);
let gf = g_h.clone();
let wf = w.clone();
let lu =
move |v: &[f32], s: &[usize]| weighted_sum(&cuda(&gf, &[n]).silu().mul(&cuda(v, s)), &wf);
report(
"swiglu dUp",
&grad_check(&u_h, &[n], &lu, du.as_slice::<f32>(), cfg_linear()),
);
}
// ---- rope ----
#[test]
fn rope_bwd() {
require_gpu();
let (tokens, heads, head_dim) = (4, 2, 8);
let n = tokens * heads * head_dim;
let theta = 10000.0;
let x_h = fill(n, 81);
let w = fill(n, 82);
let x = Var::leaf(cuda(&x_h, &[tokens, heads, head_dim]));
let out = ops::rope(&x, theta);
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let wf = w.clone();
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta), &wf);
report(
"rope dX",
&grad_check(
&x_h,
&[tokens, heads, head_dim],
&lx,
dx.as_slice::<f32>(),
cfg_linear(),
),
);
}
// ---- softmax ----
#[test]
fn softmax_bwd() {
require_gpu();
let (rows, cols) = (4, 10);
let x_h = fill(rows * cols, 91);
let w = fill(rows * cols, 92);
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
let out = ops::softmax(&x);
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let wf = w.clone();
let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).softmax(), &wf);
report(
"softmax dX",
&grad_check(
&x_h,
&[rows, cols],
&lx,
dx.as_slice::<f32>(),
cfg_nonlinear(),
),
);
}
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
#[test]
fn cross_entropy_bwd() {
require_gpu();
let (rows, cols) = (5, 8);
let x_h = fill(rows * cols, 101);
let targets: Vec<i32> = (0..rows).map(|r| (r * 3 % cols) as i32).collect();
let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
let loss = ops::cross_entropy(&x, &target);
loss.backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
// Loss is already scalar (mean NLL) — grad-check it directly, no W weighting.
let tgt = targets.clone();
let lx = move |v: &[f32], s: &[usize]| {
let t = Tensor::from_slice(&tgt, &[rows]).to_device(Device::Cuda(0));
let (_, per_row) = cuda(v, s).cross_entropy(&t);
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.sum::<f32>()
/ rows as f32
};
report(
"cross_entropy dX",
&grad_check(
&x_h,
&[rows, cols],
&lx,
dx.as_slice::<f32>(),
cfg_nonlinear(),
),
);
}
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
// y = x*x + x*x via two separate mul nodes on the same Var x → dL/dx must be the
// sum of both branches. With W=1, out=2x², so dOut=W=1 and dx (numeric) = 4x.
#[test]
fn fanout_grad_accumulation() {
require_gpu();
let n = 12;
let x_h = fill(n, 111);
let w = vec![1.0f32; n];
let x = Var::leaf(cuda(&x_h, &[n]));
let sq1 = ops::mul(&x, &x); // x∘x (x consumed twice within one node)
let sq2 = ops::mul(&x, &x); // x∘x (x consumed again across nodes)
let out = ops::add(&sq1, &sq2); // 2x²
scalar_loss(&out, &w).backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let wf = w.clone();
let lx = move |v: &[f32], s: &[usize]| {
let t = cuda(v, s);
let o = t.mul(&t).add(&t.mul(&t));
weighted_sum(&o, &wf)
};
// Analytic dx should be 4x; fan-out summed all four uses of x.
report(
"fanout dX",
&grad_check(&x_h, &[n], &lx, dx.as_slice::<f32>(), cfg_linear()),
);
}
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
// Single head, single batch. Backward falls out of matmul+scale+softmax nodes.
#[test]
fn attention_composed_bwd() {
require_gpu();
let (s, d) = (5, 6); // seq_len, head_dim
let scale = 1.0 / (d as f32).sqrt();
let q_h = fill(s * d, 121);
let k_h = fill(s * d, 122);
let v_h = fill(s * d, 123);
let w = fill(s * d, 124); // weights over the [s,d] attention output
let attn = |q: &Var, k: &Var, v: &Var| -> Var {
let kt = transpose_var(k); // [d,s] (manual transpose node)
let scores = ops::scale(&ops::matmul(q, &kt), scale); // [s,s]
let probs = ops::softmax(&scores);
ops::matmul(&probs, v) // [s,d]
};
let q = Var::leaf(cuda(&q_h, &[s, d]));
let k = Var::leaf(cuda(&k_h, &[s, d]));
let v = Var::leaf(cuda(&v_h, &[s, d]));
let out = attn(&q, &k, &v);
scalar_loss(&out, &w).backward();
let dq = q.grad().unwrap().to_device(Device::Cpu);
let dk = k.grad().unwrap().to_device(Device::Cpu);
let dv = v.grad().unwrap().to_device(Device::Cpu);
// Re-run the same forward inside the loss closures (host-side) per input.
let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
let qv = cuda(qh, &[s, d]);
let kv = cuda(kh, &[s, d]);
let vv = cuda(vh, &[s, d]);
let scores = qv.matmul(&kv.transpose_2d()).scale(scale);
let probs = scores.softmax();
weighted_sum(&probs.matmul(&vv), &w)
};
let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
report(
"attn dQ",
&grad_check(&q_h, &[s, d], &lq, dq.as_slice::<f32>(), cfg_nonlinear()),
);
let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
report(
"attn dK",
&grad_check(&k_h, &[s, d], &lk, dk.as_slice::<f32>(), cfg_nonlinear()),
);
let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
report(
"attn dV",
&grad_check(&v_h, &[s, d], &lv, dv.as_slice::<f32>(), cfg_linear()),
);
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
// implement it as: elementwise mul by a constant-W leaf, then sum-to-scalar.
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
let wt = Var::leaf(cuda(w, out.value().shape()));
let prod = ops::mul(out, &wt);
sum_all(&prod)
}
// Sum-to-scalar node: out = sum(x). Backward broadcasts the scalar grad to a
// ones-shaped tensor over x. Implemented here (test-local) since the engine's
// op set doesn't include a generic reduction; cross_entropy is the only loss op.
fn sum_all(x: &Var) -> Var {
let xv = x.value();
let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
let shape: Vec<usize> = xv.shape().to_vec();
Var::from_op(
scalar,
vec![x.clone()],
Box::new(move |d, parents| {
// d is [1]; broadcast d to a same-shape tensor over the input.
let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
let ones = vec![dval; shape.iter().product()];
let g = Tensor::from_slice(&ones, &shape).to_device(Device::Cuda(0));
Var::push_grad(&parents[0], g);
}),
)
}
// Manual transpose node for the composed-attention test (the engine has no
// transpose op; xserv does the equivalent host-side reshape around RoPE).
fn transpose_var(x: &Var) -> Var {
let xt = x.value().transpose_2d();
Var::from_op(
xt,
vec![x.clone()],
Box::new(|d, parents| {
Var::push_grad(&parents[0], d.transpose_2d());
}),
)
}