Files
xtrain/crates/xtrain-autodiff/tests/autograd.rs

822 lines
26 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
// through the tape, call backward(), and grad-check each input's .grad() against
// central finite differences of L.
//
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_autodiff::{GradCheckConfig, grad_check};
use xtrain_cuda::device;
use xtrain_tensor::{Device, Tensor};
// Deterministic LCG fill, same generator as the gemm tests. Values lie in
// [-0.5, 0.5]: each 31-bit draw is divided by 2^31, but the integer→f32 cast can
// round the very largest draws up to exactly 2^31, so +0.5 is (rarely) reachable.
fn fill(n: usize, seed: u64) -> Vec<f32> {
    // Scramble the seed before the first LCG step.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    // Knuth's MMIX LCG multiplier / increment.
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;
    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        // Top 31 bits of the 64-bit state, mapped onto [0, 1] and re-centred.
        values.push((state >> 33) as f32 / denom - 0.5);
    }
    values
}
// Fails the test unless at least one CUDA device is visible, then makes device 0
// the current device for everything that follows.
fn require_gpu() {
    let count = device::device_count().expect("device count");
    assert!(count > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
// Builds a host tensor from `data` with the given shape and uploads it to CUDA
// device 0.
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
    let host = Tensor::from_slice(data, shape);
    host.to_device(Device::Cuda(0))
}
// L = sum(W ∘ out) for fixed weights W over the op output (copied to the host).
//
// Panics if W does not hold exactly one weight per output element. A plain `zip`
// would silently truncate, letting a wrong-shaped op output still "pass" the grad
// check. Products are accumulated in f64 so summation rounding adds as little
// noise as possible to the central finite differences, which divide loss
// differences by a small step. The result is returned as f32 because that is
// what the grad checker consumes.
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
    let host = out.to_device(Device::Cpu);
    let vals = host.as_slice::<f32>();
    assert_eq!(
        vals.len(),
        w.len(),
        "weighted_sum: output has {} elements but {} weights were given",
        vals.len(),
        w.len()
    );
    let total: f64 = vals
        .iter()
        .zip(w)
        .map(|(&o, &wi)| f64::from(o) * f64::from(wi))
        .sum();
    total as f32
}
// Tolerances: ops with elementwise/linear forwards (add, mul, scale, bias, rope)
// are exactly linear in each input, so a large eps just sharpens f32 resolution.
// Nonlinear ops (rms_norm, silu, softmax, cross_entropy) carry O(eps²) truncation
// → smaller eps. atol floors near-zero grads.
// Grad-check settings for inputs the loss depends on linearly: central differences
// have no truncation error there, so a large step only improves f32 resolution.
fn cfg_linear() -> GradCheckConfig {
    let (eps, rel_tol, atol) = (1e-2, 2e-2, 1e-3);
    GradCheckConfig { eps, rel_tol, atol }
}
// Grad-check settings for nonlinear dependencies: a smaller step keeps the
// O(eps²) central-difference truncation error down, paired with a relative
// tolerance of 3e-2 (vs 2e-2 for the linear config).
fn cfg_nonlinear() -> GradCheckConfig {
    let (eps, rel_tol, atol) = (1e-3, 3e-2, 1e-3);
    GradCheckConfig { eps, rel_tol, atol }
}
// Prints the worst element of a grad check, then panics (failing the test) if the
// check did not pass.
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
    let rel = &res.max_rel_err;
    let num = &res.worst_numeric;
    let ana = &res.worst_analytic;
    let at = &res.worst_index;
    println!("{name}: max_rel_err = {rel:.3e} (worst num={num:.5} ana={ana:.5} @ {at})");
    if !res.passed {
        panic!("{name} grad-check failed: {res:?}");
    }
}
// ---- add ----
// L = sum(W ∘ (A + B)); both dA and dB should equal W.
#[test]
fn add_bwd() {
    require_gpu();
    let (m, n) = (8, 6);
    let shape = [m, n];
    let a_h = fill(m * n, 1);
    let b_h = fill(m * n, 2);
    let w = fill(m * n, 3);

    // Analytic grads through the tape.
    let a = Var::leaf(cuda(&a_h, &shape));
    let b = Var::leaf(cuda(&b_h, &shape));
    let out = ops::add(&a, &b);
    let loss = scalar_loss(&out, &w);
    loss.backward();
    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference for dA, with B held at its sampled value.
    let (b_fixed, w_a) = (b_h.clone(), w.clone());
    let loss_a = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).add(&cuda(&b_fixed, &shape)), &w_a)
    };
    let res_a = grad_check(&a_h, &shape, &loss_a, grad_a.as_slice::<f32>(), cfg_linear());
    report("add dA", &res_a);

    // Numeric reference for dB, with A held fixed (last uses of a_h / w move in).
    let loss_b = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&a_h, &shape).add(&cuda(v, s)), &w)
    };
    let res_b = grad_check(&b_h, &shape, &loss_b, grad_b.as_slice::<f32>(), cfg_linear());
    report("add dB", &res_b);
}
// ---- mul ----
// L = sum(W ∘ (A ∘ B)), so dA = W ∘ B and dB = W ∘ A.
#[test]
fn mul_bwd() {
    require_gpu();
    let (m, n) = (8, 6);
    let shape = [m, n];
    let a_h = fill(m * n, 11);
    let b_h = fill(m * n, 22);
    let w = fill(m * n, 33);

    let a = Var::leaf(cuda(&a_h, &shape));
    let b = Var::leaf(cuda(&b_h, &shape));
    let out = ops::mul(&a, &b);
    scalar_loss(&out, &w).backward();
    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    // dA: B frozen. L is linear in A alone, so the linear config applies.
    let (b_fixed, w_a) = (b_h.clone(), w.clone());
    let loss_a = move |v: &[f32], s: &[usize]| {
        let prod = cuda(v, s).mul(&cuda(&b_fixed, &shape));
        weighted_sum(&prod, &w_a)
    };
    let res_a = grad_check(&a_h, &shape, &loss_a, grad_a.as_slice::<f32>(), cfg_linear());
    report("mul dA", &res_a);

    // dB: A frozen; the closure takes the last uses of a_h and w.
    let loss_b = move |v: &[f32], s: &[usize]| {
        let prod = cuda(&a_h, &shape).mul(&cuda(v, s));
        weighted_sum(&prod, &w)
    };
    let res_b = grad_check(&b_h, &shape, &loss_b, grad_b.as_slice::<f32>(), cfg_linear());
    report("mul dB", &res_b);
}
// ---- add_bias (broadcast) ----
// Bias [n] is broadcast over the m rows: dX should be W, dBias the row-sum of W.
#[test]
fn add_bias_bwd() {
    require_gpu();
    let (m, n) = (10, 7);
    let x_shape = [m, n];
    let bias_shape = [n];
    let x_h = fill(m * n, 5);
    let b_h = fill(n, 6);
    let w = fill(m * n, 7);

    let x = Var::leaf(cuda(&x_h, &x_shape));
    let bias = Var::leaf(cuda(&b_h, &bias_shape));
    let out = ops::add_bias(&x, &bias);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    let grad_bias = bias.grad().unwrap().to_device(Device::Cpu);

    // dX: bias frozen.
    let (bias_fixed, w_x) = (b_h.clone(), w.clone());
    let loss_x = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).add_bias(&cuda(&bias_fixed, &bias_shape)), &w_x)
    };
    let res_x = grad_check(&x_h, &x_shape, &loss_x, grad_x.as_slice::<f32>(), cfg_linear());
    report("add_bias dX", &res_x);

    // dBias: X frozen; the last uses of x_h and w move into the closure.
    let loss_b = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&x_h, &x_shape).add_bias(&cuda(v, s)), &w)
    };
    let res_b = grad_check(
        &b_h,
        &bias_shape,
        &loss_b,
        grad_bias.as_slice::<f32>(),
        cfg_linear(),
    );
    report("add_bias dBias", &res_b);
}
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
// L = sum(W ∘ AB), so dA = W·Bᵀ and dB = Aᵀ·W. L is linear in each factor alone.
#[test]
fn matmul_bwd() {
    require_gpu();
    let (m, k, n) = (6, 5, 4);
    let a_shape = [m, k];
    let b_shape = [k, n];
    let a_h = fill(m * k, 41);
    let b_h = fill(k * n, 42);
    let w = fill(m * n, 43);

    let a = Var::leaf(cuda(&a_h, &a_shape));
    let b = Var::leaf(cuda(&b_h, &b_shape));
    let out = ops::matmul(&a, &b);
    scalar_loss(&out, &w).backward();
    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    // dA: B frozen.
    let (b_fixed, w_a) = (b_h.clone(), w.clone());
    let loss_a = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).matmul(&cuda(&b_fixed, &b_shape)), &w_a)
    };
    let res_a = grad_check(&a_h, &a_shape, &loss_a, grad_a.as_slice::<f32>(), cfg_linear());
    report("matmul dA", &res_a);

    // dB: A frozen; a_h and w are not needed afterwards, so they move in.
    let loss_b = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&a_h, &a_shape).matmul(&cuda(v, s)), &w)
    };
    let res_b = grad_check(&b_h, &b_shape, &loss_b, grad_b.as_slice::<f32>(), cfg_linear());
    report("matmul dB", &res_b);
}
// ---- rms_norm ----
// rms_norm is nonlinear in X and gamma → both checks use the nonlinear config.
#[test]
fn rms_norm_bwd() {
    require_gpu();
    let (rows, cols) = (5, 16);
    let x_shape = [rows, cols];
    let g_shape = [cols];
    let eps = 1e-5;
    let x_h = fill(rows * cols, 51);
    // gamma centred on 1 rather than 0 (fill's values are centred on 0).
    let g_h: Vec<f32> = fill(cols, 52).into_iter().map(|v| v + 1.0).collect();
    let w = fill(rows * cols, 53);

    let x = Var::leaf(cuda(&x_h, &x_shape));
    let gamma = Var::leaf(cuda(&g_h, &g_shape));
    let out = ops::rms_norm(&x, &gamma, eps);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    let grad_gamma = gamma.grad().unwrap().to_device(Device::Cpu);

    // The Tensor-level rms_norm returns a pair; only the normalised output feeds L.
    let (gamma_fixed, w_x) = (g_h.clone(), w.clone());
    let loss_x = move |v: &[f32], s: &[usize]| {
        let (normed, _) = cuda(v, s).rms_norm(&cuda(&gamma_fixed, &g_shape), eps);
        weighted_sum(&normed, &w_x)
    };
    let res_x = grad_check(&x_h, &x_shape, &loss_x, grad_x.as_slice::<f32>(), cfg_nonlinear());
    report("rms_norm dX", &res_x);

    // dGamma: X frozen; x_h and w move into the closure.
    let loss_g = move |v: &[f32], s: &[usize]| {
        let (normed, _) = cuda(&x_h, &x_shape).rms_norm(&cuda(v, s), eps);
        weighted_sum(&normed, &w)
    };
    let res_g = grad_check(
        &g_h,
        &g_shape,
        &loss_g,
        grad_gamma.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("rms_norm dGamma", &res_g);
}
// ---- silu ----
// Elementwise SiLU; the reference forward is the Tensor-level silu.
#[test]
fn silu_bwd() {
    require_gpu();
    let n = 64;
    let x_h = fill(n, 61);
    let w = fill(n, 62);
    let x = Var::leaf(cuda(&x_h, &[n]));
    let out = ops::silu(&x);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    // W is not needed after this point, so the closure takes ownership of it.
    let loss_x = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu(), &w);
    let res = grad_check(&x_h, &[n], &loss_x, grad_x.as_slice::<f32>(), cfg_nonlinear());
    report("silu dX", &res);
}
// ---- swiglu (composed: silu(gate) ∘ up) ----
// Reference forward for swiglu is silu(gate) ∘ up on plain tensors.
#[test]
fn swiglu_bwd() {
    require_gpu();
    let n = 48;
    let shape = [n];
    let g_h = fill(n, 71);
    let u_h = fill(n, 72);
    let w = fill(n, 73);

    let gate = Var::leaf(cuda(&g_h, &shape));
    let up = Var::leaf(cuda(&u_h, &shape));
    let out = ops::swiglu(&gate, &up);
    scalar_loss(&out, &w).backward();
    let grad_gate = gate.grad().unwrap().to_device(Device::Cpu);
    let grad_up = up.grad().unwrap().to_device(Device::Cpu);

    // dGate passes through SiLU → nonlinear config.
    let (up_fixed, w_gate) = (u_h.clone(), w.clone());
    let loss_gate = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).silu().mul(&cuda(&up_fixed, &shape)), &w_gate)
    };
    let res_gate = grad_check(
        &g_h,
        &shape,
        &loss_gate,
        grad_gate.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("swiglu dGate", &res_gate);

    // dUp: the output is linear in `up` → linear config. g_h and w move in.
    let loss_up = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&g_h, &shape).silu().mul(&cuda(v, s)), &w)
    };
    let res_up = grad_check(&u_h, &shape, &loss_up, grad_up.as_slice::<f32>(), cfg_linear());
    report("swiglu dUp", &res_up);
}
// ---- rope ----
// Single sequence: period = tokens, so each token's position is its row index.
#[test]
fn rope_bwd() {
    require_gpu();
    let (tokens, heads, head_dim) = (4, 2, 8);
    let shape = [tokens, heads, head_dim];
    let n = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_h = fill(n, 81);
    let w = fill(n, 82);

    let x = Var::leaf(cuda(&x_h, &shape));
    let out = ops::rope(&x, theta, tokens);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    // RoPE is a fixed rotation per position, i.e. linear in X.
    let loss_x = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta, tokens), &w);
    let res = grad_check(&x_h, &shape, &loss_x, grad_x.as_slice::<f32>(), cfg_linear());
    report("rope dX", &res);
}
// ---- rope batched (per-sequence position = row % period) ----
// tokens = B*S laid end to end; period = S. Sequences 2 and 3 re-use positions
// 0..S, so the kernel's `tok % period` must reset RoPE per sequence.
#[test]
fn rope_batched_bwd() {
    require_gpu();
    let (b, s, heads, head_dim) = (3, 4, 2, 8);
    let tokens = b * s;
    let shape = [tokens, heads, head_dim];
    let n = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_h = fill(n, 83);
    let w = fill(n, 84);

    let x = Var::leaf(cuda(&x_h, &shape));
    // period = s: positions restart at 0 for each of the b packed sequences.
    let out = ops::rope(&x, theta, s);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    // `dims` (not `s`) names the shape argument so the period `s` stays visible.
    let loss_x =
        move |v: &[f32], dims: &[usize]| weighted_sum(&cuda(v, dims).rope(theta, s), &w);
    let res = grad_check(&x_h, &shape, &loss_x, grad_x.as_slice::<f32>(), cfg_linear());
    report("rope batched dX", &res);
}
// ---- softmax ----
// Softmax is nonlinear → nonlinear grad-check config.
#[test]
fn softmax_bwd() {
    require_gpu();
    let (rows, cols) = (4, 10);
    let shape = [rows, cols];
    let x_h = fill(rows * cols, 91);
    let w = fill(rows * cols, 92);

    let x = Var::leaf(cuda(&x_h, &shape));
    let out = ops::softmax(&x);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_x = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).softmax(), &w);
    let res = grad_check(&x_h, &shape, &loss_x, grad_x.as_slice::<f32>(), cfg_nonlinear());
    report("softmax dX", &res);
}
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
#[test]
fn cross_entropy_bwd() {
    require_gpu();
    let (rows, cols) = (5, 8);
    let x_h = fill(rows * cols, 101);
    // One distinct class per row: 0, 3, 6, 1, 4.
    let targets: Vec<i32> = (0..rows).map(|r| (r * 3 % cols) as i32).collect();
    // Uploaded once. The targets never change, so the tape forward and every
    // finite-difference evaluation below reuse this same device tensor.
    let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let loss = ops::cross_entropy(&x, &target);
    loss.backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);
    // Loss is already scalar (mean NLL), so it is grad-checked directly with no W
    // weighting. The closure rebuilds that mean from the per-row losses,
    // accumulating in f64 to keep summation rounding out of the finite differences.
    let lx = move |v: &[f32], s: &[usize]| {
        let (_, per_row) = cuda(v, s).cross_entropy(&target);
        let host = per_row.to_device(Device::Cpu);
        let total: f64 = host
            .as_slice::<f32>()
            .iter()
            .map(|&l| f64::from(l))
            .sum();
        (total / rows as f64) as f32
    };
    report(
        "cross_entropy dX",
        &grad_check(
            &x_h,
            &[rows, cols],
            &lx,
            dx.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
}
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
// y = x*x + x*x via two separate mul nodes on the same Var x → dL/dx must be the
// sum of both branches. With W=1, out=2x², so dOut=W=1 and dx (numeric) = 4x.
#[test]
fn fanout_grad_accumulation() {
    require_gpu();
    let n = 12;
    let x_h = fill(n, 111);
    let w = vec![1.0f32; n];
    let x = Var::leaf(cuda(&x_h, &[n]));
    let sq1 = ops::mul(&x, &x); // x∘x (x consumed twice within one node)
    let sq2 = ops::mul(&x, &x); // x∘x (x consumed again across nodes)
    let out = ops::add(&sq1, &sq2); // 2x²
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);
    let dx_h = dx.as_slice::<f32>();

    // Exact check of the accumulation itself. With W = 1, each of the four uses
    // of x contributes 1·x, so the summed grad is 4x. An overwritten (rather than
    // summed) contribution would leave a smaller multiple such as x or 2x, and this
    // catches it without depending on finite-difference tolerances.
    assert_eq!(dx_h.len(), n, "fanout dX has the wrong length");
    for (i, (&g, &xi)) in dx_h.iter().zip(&x_h).enumerate() {
        let expected = 4.0 * xi;
        assert!(
            (g - expected).abs() <= 1e-5 * (1.0 + expected.abs()),
            "fanout dX[{i}] = {g}, expected 4x = {expected}"
        );
    }

    // Cross-check against finite differences of the same loss. L is quadratic in
    // x, so the central difference has no truncation error and the linear
    // config's large eps is safe.
    let lx = move |v: &[f32], s: &[usize]| {
        let t = cuda(v, s);
        let o = t.mul(&t).add(&t.mul(&t));
        weighted_sum(&o, &w)
    };
    report(
        "fanout dX",
        &grad_check(&x_h, &[n], &lx, dx_h, cfg_linear()),
    );
}
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
// Single head, single batch. Backward falls out of matmul+scale+softmax nodes.
#[test]
fn attention_composed_bwd() {
    require_gpu();
    let (s, d) = (5, 6); // seq_len, head_dim
    // 1/sqrt(head_dim) score scaling.
    let scale = 1.0 / (d as f32).sqrt();
    let q_h = fill(s * d, 121);
    let k_h = fill(s * d, 122);
    let v_h = fill(s * d, 123);
    let w = fill(s * d, 124); // weights over the [s,d] attention output
    // Tape forward built only from Var nodes (transpose, matmul, scale, softmax), so
    // backward is the composition of their individual backwards. No mask is
    // applied: every query attends to every key.
    let attn = |q: &Var, k: &Var, v: &Var| -> Var {
        let kt = transpose_var(k); // [d,s] (manual transpose node)
        let scores = ops::scale(&ops::matmul(q, &kt), scale); // [s,s]
        let probs = ops::softmax(&scores);
        ops::matmul(&probs, v) // [s,d]
    };
    let q = Var::leaf(cuda(&q_h, &[s, d]));
    let k = Var::leaf(cuda(&k_h, &[s, d]));
    let v = Var::leaf(cuda(&v_h, &[s, d]));
    let out = attn(&q, &k, &v);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Reference forward on plain Tensors (no tape), executed on the CUDA device via
    // `cuda`. It takes ownership of W. Each loss closure below gets its own clone
    // of `fwd`, plus host copies of the two inputs it holds fixed.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[s, d]);
        let kv = cuda(kh, &[s, d]);
        let vv = cuda(vh, &[s, d]);
        let scores = qv.matmul(&kv.transpose_2d()).scale(scale);
        let probs = scores.softmax();
        weighted_sum(&probs.matmul(&vv), &w)
    };
    // dQ: perturb Q, hold K and V. The shape argument is ignored because `fwd`
    // already fixes [s, d]. Q enters through the softmax → nonlinear config.
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "attn dQ",
        &grad_check(&q_h, &[s, d], &lq, dq.as_slice::<f32>(), cfg_nonlinear()),
    );
    // dK: perturb K, hold Q and V (also through the softmax).
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "attn dK",
        &grad_check(&k_h, &[s, d], &lk, dk.as_slice::<f32>(), cfg_nonlinear()),
    );
    // dV: the output is probs·V, linear in V → linear config.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "attn dV",
        &grad_check(&v_h, &[s, d], &lv, dv.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- transpose_4d12 ([a,b,c,d] -> [a,c,b,d]) ----
// [a,b,c,d] = [2,3,4,5]; the op swaps axes 1 and 2, which is its own inverse, so
// dX is W swapped back and L is linear in X.
#[test]
fn transpose_4d12_bwd() {
    require_gpu();
    let dims: [usize; 4] = [2, 3, 4, 5];
    let n: usize = dims.iter().product();
    let x_h = fill(n, 131);
    let w = fill(n, 132);

    let x = Var::leaf(cuda(&x_h, &dims));
    let out = ops::transpose_4d12(&x);
    scalar_loss(&out, &w).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_x = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_4d12(), &w);
    let res = grad_check(&x_h, &dims, &loss_x, grad_x.as_slice::<f32>(), cfg_linear());
    report("transpose_4d12 dX", &res);
}
// ---- fused batched causal attention (the T10 op) ----
// q,k,v: [bh, seq, hd]. Grad-check dq/dk/dv against finite-diff of L = sum(W∘out).
// bh = 2 (e.g. batch 1 × 2 heads, or 2 sequences × 1 head) exercises the batched
// GEMM stride; the causal mask is applied inside the op.
#[test]
fn attention_batched_bwd() {
    require_gpu();
    // q, k, v are [bh, seq, hd]; the fused op applies the causal mask internally.
    let (bh, seq, hd) = (2, 5, 6);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 141);
    let k_h = fill(n, 142);
    let v_h = fill(n, 143);
    let w = fill(n, 144);
    let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
    let out = ops::attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Reference forward through the Tensor-level fused attention (no tape). It
    // returns a pair; only the attention output enters the loss. `fwd` owns W and is
    // cloned into each per-input loss closure below.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[bh, seq, hd]);
        let kv = cuda(kh, &[bh, seq, hd]);
        let vv = cuda(vh, &[bh, seq, hd]);
        let (o, _) = qv.attention(&kv, &vv, scale);
        weighted_sum(&o, &w)
    };
    // dQ: perturb Q, hold K and V. The shape argument is ignored because `fwd`
    // already fixes [bh, seq, hd].
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "attn(batched) dQ",
        &grad_check(
            &q_h,
            &[bh, seq, hd],
            &lq,
            dq.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
    // dK: perturb K, hold Q and V.
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "attn(batched) dK",
        &grad_check(
            &k_h,
            &[bh, seq, hd],
            &lk,
            dk.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
    // dV: the output is linear in V → linear config.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "attn(batched) dV",
        &grad_check(
            &v_h,
            &[bh, seq, hd],
            &lv,
            dv.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure + dimensions as attention_batched_bwd (bh=2,seq=5,hd=6), but
// exercises ops::flash_attention. Grad-check dq/dk/dv against finite-diff of
// L=sum(W∘out). This is the SINGLE-tile regime (seq<FA_TILE=32), matching the
// trusted composed grad-check's clean near-zero behavior; the MULTI-tile online-
// softmax path (seq>FA_TILE) is validated against the already-grad-checked
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40) — sharper than
// finite-diff, which is unreliable on the near-zero grad elements a long softmax
// produces.
#[test]
fn flash_attention_batched_bwd() {
    require_gpu();
    // seq=5 stays within one flash tile (single-tile regime).
    let (bh, seq, hd) = (2, 5, 6);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    // Scale Q/K up so the softmax is non-uniform (sharper attention) → the dQ/dK
    // gradients are well-conditioned, not the near-zero saddle values a uniform
    // softmax produces (those make central finite-diff give spurious 0.0 / sign
    // flips that aren't backward bugs — cf. flash_bwd_matches_composed_bwd).
    let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
    let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
    let v_h = fill(n, 243);
    let w = fill(n, 244);
    let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
    let out = ops::flash_attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Reference forward through the Tensor-level flash kernel (no tape). Only the
    // first element of the returned pair (the attention output) enters the loss.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[bh, seq, hd]);
        let kv = cuda(kh, &[bh, seq, hd]);
        let vv = cuda(vh, &[bh, seq, hd]);
        let (o, _) = qv.flash_attention(&kv, &vv, scale);
        weighted_sum(&o, &w)
    };
    // Attention dQ/dK carry softmax curvature; for the small grad magnitudes here
    // a larger eps (2e-3) cuts the f32 rounding term (∝|L|/eps) that dominates the
    // O(eps²) truncation on a ~4e-4 grad. (dV is exactly linear → cfg_linear.)
    // The same config value is reused for both dQ and dK below.
    let cfg_attn = GradCheckConfig {
        eps: 2e-3,
        rel_tol: 3e-2,
        atol: 1e-3,
    };
    // dQ: perturb Q, hold K and V (shape argument unused; `fwd` fixes the shape).
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "flash dQ",
        &grad_check(&q_h, &[bh, seq, hd], &lq, dq.as_slice::<f32>(), cfg_attn),
    );
    // dK: perturb K, hold Q and V.
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "flash dK",
        &grad_check(&k_h, &[bh, seq, hd], &lk, dk.as_slice::<f32>(), cfg_attn),
    );
    // dV: perturb V, hold Q and K.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "flash dV",
        &grad_check(
            &v_h,
            &[bh, seq, hd],
            &lv,
            dv.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// flash forward must equal the composed attention forward (same SDPA math).
// seq=40 exceeds FA_TILE=32, so the flash kernel's multi-tile path is compared.
#[test]
fn flash_matches_composed_fwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q = cuda(&fill(n, 341), &[bh, seq, hd]);
    let k = cuda(&fill(n, 342), &[bh, seq, hd]);
    let v = cuda(&fill(n, 343), &[bh, seq, hd]);
    let (oc, _) = q.attention(&k, &v, scale);
    let (of, _) = q.flash_attention(&k, &v, scale);
    let oc = oc.to_device(Device::Cpu);
    let of = of.to_device(Device::Cpu);
    let composed = oc.as_slice::<f32>();
    let flash = of.as_slice::<f32>();
    assert_eq!(
        composed.len(),
        flash.len(),
        "flash/composed output sizes differ"
    );
    // `f32::max` returns the non-NaN operand. A NaN relative error would therefore
    // fold away to 0 and the tolerance check below would pass, so reject
    // non-finite outputs first.
    assert!(
        composed.iter().chain(flash).all(|x| x.is_finite()),
        "non-finite attention output (composed or flash)"
    );
    let max_rel = composed
        .iter()
        .zip(flash)
        .map(|(c, f)| (c - f).abs() / (c.abs() + 1e-6))
        .fold(0.0f32, f32::max);
    println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
    assert!(
        max_rel < 1e-4,
        "flash fwd diverges from composed: {max_rel:.3e}"
    );
}
// flash backward must equal the (already grad-checked) composed backward. This is
// a sharper test than finite-diff: both share the trusted composed forward as the
// reference, so it isolates the flash bwd dQ/dK/dV math from finite-diff noise on
// near-zero gradient elements. seq=40 exceeds FA_TILE=32, so the multi-tile
// online-softmax backward is exercised.
#[test]
fn flash_bwd_matches_composed_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 441);
    let k_h = fill(n, 442);
    let v_h = fill(n, 443);
    let w = fill(n, 444);
    // Builds fresh leaves (a new tape) per call, so the two runs never share
    // accumulated grads. Returns host copies of (dQ, dK, dV).
    let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
        let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
        let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
        let out = if flash {
            ops::flash_attention(&q, &k, &v, scale)
        } else {
            ops::attention(&q, &k, &v, scale)
        };
        scalar_loss(&out, &w).backward();
        let g = |x: &Var| {
            x.grad()
                .unwrap()
                .to_device(Device::Cpu)
                .as_slice::<f32>()
                .to_vec()
        };
        (g(&q), g(&k), g(&v))
    };
    let (cq, ck, cv) = run(false);
    let (fq, fk, fv) = run(true);
    // `f32::max` returns the non-NaN operand, so folding a NaN relative error would
    // silently report 0 and pass. Reject non-finite gradients up front.
    for (label, g) in [
        ("composed dQ", &cq),
        ("composed dK", &ck),
        ("composed dV", &cv),
        ("flash dQ", &fq),
        ("flash dK", &fk),
        ("flash dV", &fv),
    ] {
        assert!(
            g.iter().all(|x| x.is_finite()),
            "{label} contains non-finite values"
        );
    }
    // Symmetric relative error, floored at 1e-4 so near-zero elements don't blow up.
    let maxrel = |a: &[f32], b: &[f32]| -> f32 {
        assert_eq!(a.len(), b.len(), "gradient length mismatch");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
            .fold(0.0f32, f32::max)
    };
    let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
    println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
    assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
    assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
    assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out). `w` is uploaded as a leaf shaped like `out`,
// multiplied elementwise into it, and the product is reduced with `sum_all`. W is
// an ordinary leaf, so it may also accumulate a gradient; these tests never read it.
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
    let out_shape = out.value().shape().to_vec();
    let weights = Var::leaf(cuda(w, &out_shape));
    let weighted = ops::mul(out, &weights);
    sum_all(&weighted)
}
// Sum-to-scalar node: out = sum(x). Backward broadcasts the scalar grad to a
// tensor of x's shape. Implemented here (test-local) since the engine's op set
// doesn't include a generic reduction; cross_entropy is the only loss op.
//
// The reduction itself runs on the host. Both the [1] forward result and the
// backward grad are placed on x's own device, so the grad pushed to the parent
// always lives where the parent does.
fn sum_all(x: &Var) -> Var {
    let xv = x.value();
    let device = xv.device();
    let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
    let scalar = Tensor::from_slice(&[total], &[1]).to_device(device);
    let shape: Vec<usize> = xv.shape().to_vec();
    Var::from_op(
        scalar,
        vec![x.clone()],
        Box::new(move |d, parents| {
            // d is [1]; broadcast its single value over the input's shape.
            let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
            let broadcast = vec![dval; shape.iter().product()];
            let g = Tensor::from_slice(&broadcast, &shape).to_device(device);
            Var::push_grad(&parents[0], g);
        }),
    )
}
// Test-local 2-D transpose node for the composed-attention test. The engine's op
// set has `transpose_4d12` but no 2-D transpose Var op (xserv does the equivalent
// host-side reshape around RoPE). The backward of a transpose is simply the
// transpose of the upstream grad.
fn transpose_var(x: &Var) -> Var {
    let transposed = x.value().transpose_2d();
    let parents = vec![x.clone()];
    Var::from_op(
        transposed,
        parents,
        Box::new(|grad_out, inputs| {
            let grad_in = grad_out.transpose_2d();
            Var::push_grad(&inputs[0], grad_in);
        }),
    )
}