Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> # Conflicts: # README.md # crates/xtrain-autodiff/tests/autograd.rs # crates/xtrain-model/src/model.rs # crates/xtrain-train/src/bin/train.rs # crates/xtrain-train/src/train_loop.rs # docs/evolution.md
912 lines
30 KiB
Rust
912 lines
30 KiB
Rust
// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
|
||
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
|
||
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
|
||
// through the tape, call backward(), and grad-check each input's .grad() against
|
||
// central finite differences of L.
|
||
//
|
||
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
|
||
#![cfg(not(no_cuda))]
|
||
|
||
use xtrain_autodiff::ops;
|
||
use xtrain_autodiff::tape::Var;
|
||
use xtrain_autodiff::{GradCheckConfig, grad_check};
|
||
use xtrain_cuda::device;
|
||
use xtrain_tensor::{Device, Tensor};
|
||
|
||
// Deterministic pseudo-random test data. A 64-bit LCG is seeded from `seed`;
// each step keeps the top 31 bits of the state, normalises by 2^31 and centres
// the result on zero (nominally [-0.5, 0.5)). Per the original note this is the
// same generator the gemm tests use.
fn fill(n: usize, seed: u64) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut state = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        state = state.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        let unit = (state >> 33) as f32 / (1u64 << 31) as f32;
        values.push(unit - 0.5);
    }
    values
}
|
||
|
||
// Fail the calling test unless at least one CUDA device is reported, then
// select device 0 for everything that follows.
fn require_gpu() {
    let visible = device::device_count().expect("device count");
    assert!(visible > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
|
||
|
||
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
|
||
Tensor::from_slice(data, shape).to_device(Device::Cuda(0))
|
||
}
|
||
|
||
// L = sum(W ∘ out) for fixed weights W over the op output.
|
||
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
|
||
out.to_device(Device::Cpu)
|
||
.as_slice::<f32>()
|
||
.iter()
|
||
.zip(w)
|
||
.map(|(o, w)| o * w)
|
||
.sum()
|
||
}
|
||
|
||
// Tolerances: ops with elementwise/linear forwards (add, mul, scale, bias, rope)
|
||
// are exactly linear in each input, so a large eps just sharpens f32 resolution.
|
||
// Nonlinear ops (rms_norm, silu, softmax, cross_entropy) carry O(eps²) truncation
|
||
// → smaller eps. atol floors near-zero grads.
|
||
fn cfg_linear() -> GradCheckConfig {
|
||
GradCheckConfig {
|
||
eps: 1e-2,
|
||
rel_tol: 2e-2,
|
||
atol: 1e-3,
|
||
}
|
||
}
|
||
fn cfg_nonlinear() -> GradCheckConfig {
|
||
GradCheckConfig {
|
||
eps: 1e-3,
|
||
rel_tol: 3e-2,
|
||
atol: 1e-3,
|
||
}
|
||
}
|
||
|
||
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
|
||
println!(
|
||
"{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
|
||
res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index
|
||
);
|
||
assert!(res.passed, "{name} grad-check failed: {res:?}");
|
||
}
|
||
|
||
// ---- add ----
// Elementwise add through the tape: both input grads of L = sum(W ∘ (A + B))
// are compared with central finite differences of the same loss.
#[test]
fn add_bwd() {
    require_gpu();
    let (rows, cols) = (8, 6);
    let dims = [rows, cols];
    let lhs_host = fill(rows * cols, 1);
    let rhs_host = fill(rows * cols, 2);
    let weights = fill(rows * cols, 3);

    // Analytic side: record the op on the tape and backpropagate the scalar loss.
    let lhs = Var::leaf(cuda(&lhs_host, &dims));
    let rhs = Var::leaf(cuda(&rhs_host, &dims));
    let sum = ops::add(&lhs, &rhs);
    scalar_loss(&sum, &weights).backward();
    let grad_lhs = lhs.grad().unwrap().to_device(Device::Cpu);
    let grad_rhs = rhs.grad().unwrap().to_device(Device::Cpu);

    // Numeric side for dA: A is perturbed, B is held at its original value.
    let loss_of_lhs = {
        let (rhs_fixed, wts) = (rhs_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(v, s).add(&cuda(&rhs_fixed, &dims)), &wts)
        }
    };
    let checked_lhs = grad_check(
        &lhs_host,
        &dims,
        &loss_of_lhs,
        grad_lhs.as_slice::<f32>(),
        cfg_linear(),
    );
    report("add dA", &checked_lhs);

    // Numeric side for dB: B is perturbed, A is held fixed.
    let loss_of_rhs = {
        let (lhs_fixed, wts) = (lhs_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(&lhs_fixed, &dims).add(&cuda(v, s)), &wts)
        }
    };
    let checked_rhs = grad_check(
        &rhs_host,
        &dims,
        &loss_of_rhs,
        grad_rhs.as_slice::<f32>(),
        cfg_linear(),
    );
    report("add dB", &checked_rhs);
}
|
||
|
||
// ---- mul ----
// Elementwise product: each operand's grad of L = sum(W ∘ (A ∘ B)) is verified
// against central finite differences.
#[test]
fn mul_bwd() {
    require_gpu();
    let (rows, cols) = (8, 6);
    let dims = [rows, cols];
    let lhs_host = fill(rows * cols, 11);
    let rhs_host = fill(rows * cols, 22);
    let weights = fill(rows * cols, 33);

    // Tape forward + backward.
    let lhs = Var::leaf(cuda(&lhs_host, &dims));
    let rhs = Var::leaf(cuda(&rhs_host, &dims));
    let product = ops::mul(&lhs, &rhs);
    scalar_loss(&product, &weights).backward();
    let grad_lhs = lhs.grad().unwrap().to_device(Device::Cpu);
    let grad_rhs = rhs.grad().unwrap().to_device(Device::Cpu);

    // dA: vary A with B fixed.
    let loss_of_lhs = {
        let (rhs_fixed, wts) = (rhs_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(v, s).mul(&cuda(&rhs_fixed, &dims)), &wts)
        }
    };
    let checked_lhs = grad_check(
        &lhs_host,
        &dims,
        &loss_of_lhs,
        grad_lhs.as_slice::<f32>(),
        cfg_linear(),
    );
    report("mul dA", &checked_lhs);

    // dB: vary B with A fixed.
    let loss_of_rhs = {
        let (lhs_fixed, wts) = (lhs_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(&lhs_fixed, &dims).mul(&cuda(v, s)), &wts)
        }
    };
    let checked_rhs = grad_check(
        &rhs_host,
        &dims,
        &loss_of_rhs,
        grad_rhs.as_slice::<f32>(),
        cfg_linear(),
    );
    report("mul dB", &checked_rhs);
}
|
||
|
||
// ---- add_bias (broadcast) ----
// A length-`cols` bias is added to a [rows, cols] input; both dX and dBias of
// L = sum(W ∘ out) are checked against finite differences.
#[test]
fn add_bias_bwd() {
    require_gpu();
    let (rows, cols) = (10, 7);
    let input_dims = [rows, cols];
    let bias_dims = [cols];
    let input_host = fill(rows * cols, 5);
    let bias_host = fill(cols, 6);
    let weights = fill(rows * cols, 7);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &input_dims));
    let bias = Var::leaf(cuda(&bias_host, &bias_dims));
    let biased = ops::add_bias(&input, &bias);
    scalar_loss(&biased, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);
    let grad_bias = bias.grad().unwrap().to_device(Device::Cpu);

    // dX: vary the input with the bias fixed.
    let loss_of_input = {
        let (bias_fixed, wts) = (bias_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(v, s).add_bias(&cuda(&bias_fixed, &bias_dims)), &wts)
        }
    };
    let checked_input = grad_check(
        &input_host,
        &input_dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_linear(),
    );
    report("add_bias dX", &checked_input);

    // dBias: vary the bias with the input fixed.
    let loss_of_bias = {
        let (input_fixed, wts) = (input_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(&input_fixed, &input_dims).add_bias(&cuda(v, s)), &wts)
        }
    };
    let checked_bias = grad_check(
        &bias_host,
        &bias_dims,
        &loss_of_bias,
        grad_bias.as_slice::<f32>(),
        cfg_linear(),
    );
    report("add_bias dBias", &checked_bias);
}
|
||
|
||
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
// [m,k] x [k,n] product: dA and dB of L = sum(W ∘ (A·B)) versus finite differences.
#[test]
fn matmul_bwd() {
    require_gpu();
    let (m, k, n) = (6, 5, 4);
    let lhs_dims = [m, k];
    let rhs_dims = [k, n];
    let lhs_host = fill(m * k, 41);
    let rhs_host = fill(k * n, 42);
    let weights = fill(m * n, 43);

    // Tape forward + backward.
    let lhs = Var::leaf(cuda(&lhs_host, &lhs_dims));
    let rhs = Var::leaf(cuda(&rhs_host, &rhs_dims));
    let product = ops::matmul(&lhs, &rhs);
    scalar_loss(&product, &weights).backward();
    let grad_lhs = lhs.grad().unwrap().to_device(Device::Cpu);
    let grad_rhs = rhs.grad().unwrap().to_device(Device::Cpu);

    // dA: vary A with B fixed.
    let loss_of_lhs = {
        let (rhs_fixed, wts) = (rhs_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(v, s).matmul(&cuda(&rhs_fixed, &rhs_dims)), &wts)
        }
    };
    let checked_lhs = grad_check(
        &lhs_host,
        &lhs_dims,
        &loss_of_lhs,
        grad_lhs.as_slice::<f32>(),
        cfg_linear(),
    );
    report("matmul dA", &checked_lhs);

    // dB: vary B with A fixed.
    let loss_of_rhs = {
        let (lhs_fixed, wts) = (lhs_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(&lhs_fixed, &lhs_dims).matmul(&cuda(v, s)), &wts)
        }
    };
    let checked_rhs = grad_check(
        &rhs_host,
        &rhs_dims,
        &loss_of_rhs,
        grad_rhs.as_slice::<f32>(),
        cfg_linear(),
    );
    report("matmul dB", &checked_rhs);
}
|
||
|
||
// ---- rms_norm ----
// RMS norm with a learnable gamma: dX and dGamma of L = sum(W ∘ out) versus
// finite differences, using the nonlinear tolerances.
#[test]
fn rms_norm_bwd() {
    require_gpu();
    let (rows, cols) = (5, 16);
    let input_dims = [rows, cols];
    let gamma_dims = [cols];
    let norm_eps = 1e-5;
    let input_host = fill(rows * cols, 51);
    // Shift gamma so it sits around 1 instead of around 0.
    let gamma_host: Vec<f32> = fill(cols, 52).into_iter().map(|g| g + 1.0).collect();
    let weights = fill(rows * cols, 53);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &input_dims));
    let gamma = Var::leaf(cuda(&gamma_host, &gamma_dims));
    let normed = ops::rms_norm(&input, &gamma, norm_eps);
    scalar_loss(&normed, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);
    let grad_gamma = gamma.grad().unwrap().to_device(Device::Cpu);

    // dX: vary the input with gamma fixed. The second element returned by the
    // tensor-level op is not needed here.
    let loss_of_input = {
        let (gamma_fixed, wts) = (gamma_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let (o, _) = cuda(v, s).rms_norm(&cuda(&gamma_fixed, &gamma_dims), norm_eps);
            weighted_sum(&o, &wts)
        }
    };
    let checked_input = grad_check(
        &input_host,
        &input_dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("rms_norm dX", &checked_input);

    // dGamma: vary gamma with the input fixed.
    let loss_of_gamma = {
        let (input_fixed, wts) = (input_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let (o, _) = cuda(&input_fixed, &input_dims).rms_norm(&cuda(v, s), norm_eps);
            weighted_sum(&o, &wts)
        }
    };
    let checked_gamma = grad_check(
        &gamma_host,
        &gamma_dims,
        &loss_of_gamma,
        grad_gamma.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("rms_norm dGamma", &checked_gamma);
}
|
||
|
||
// ---- silu ----
// SiLU activation on a 1-D input: dX of L = sum(W ∘ silu(x)) versus finite
// differences.
#[test]
fn silu_bwd() {
    require_gpu();
    let len = 64;
    let dims = [len];
    let input_host = fill(len, 61);
    let weights = fill(len, 62);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &dims));
    let activated = ops::silu(&input);
    scalar_loss(&activated, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference through the tensor-level op.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu(), &wts)
    };
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("silu dX", &checked);
}
|
||
|
||
// ---- swiglu (composed: silu(gate) ∘ up) ----
// The numeric reference rebuilds the op as `silu(gate).mul(up)` on tensors. The
// gate path is nonlinear; the up path is linear, hence the different configs.
#[test]
fn swiglu_bwd() {
    require_gpu();
    let len = 48;
    let dims = [len];
    let gate_host = fill(len, 71);
    let up_host = fill(len, 72);
    let weights = fill(len, 73);

    // Tape forward + backward.
    let gate = Var::leaf(cuda(&gate_host, &dims));
    let up = Var::leaf(cuda(&up_host, &dims));
    let gated = ops::swiglu(&gate, &up);
    scalar_loss(&gated, &weights).backward();
    let grad_gate = gate.grad().unwrap().to_device(Device::Cpu);
    let grad_up = up.grad().unwrap().to_device(Device::Cpu);

    // dGate: vary the gate with `up` fixed.
    let loss_of_gate = {
        let (up_fixed, wts) = (up_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(v, s).silu().mul(&cuda(&up_fixed, &dims)), &wts)
        }
    };
    let checked_gate = grad_check(
        &gate_host,
        &dims,
        &loss_of_gate,
        grad_gate.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("swiglu dGate", &checked_gate);

    // dUp: vary `up` with the gate fixed.
    let loss_of_up = {
        let (gate_fixed, wts) = (gate_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            weighted_sum(&cuda(&gate_fixed, &dims).silu().mul(&cuda(v, s)), &wts)
        }
    };
    let checked_up = grad_check(
        &up_host,
        &dims,
        &loss_of_up,
        grad_up.as_slice::<f32>(),
        cfg_linear(),
    );
    report("swiglu dUp", &checked_up);
}
|
||
|
||
// ---- rope ----
// RoPE over a [tokens, heads, head_dim] input with period = tokens: dX of
// L = sum(W ∘ rope(x)) versus finite differences.
#[test]
fn rope_bwd() {
    require_gpu();
    let (tokens, heads, head_dim) = (4, 2, 8);
    let dims = [tokens, heads, head_dim];
    let count = tokens * heads * head_dim;
    let theta = 10000.0;
    let input_host = fill(count, 81);
    let weights = fill(count, 82);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &dims));
    let rotated = ops::rope(&input, theta, tokens);
    scalar_loss(&rotated, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference through the tensor-level op.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta, tokens), &wts)
    };
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_linear(),
    );
    report("rope dX", &checked);
}
|
||
|
||
// ---- rope batched (per-sequence position = row % period) ----
// tokens = B*S rows laid end to end with period = S. Sequences 2 and 3 re-use
// positions 0..S, so the kernel's `tok % period` has to restart RoPE for every
// sequence.
#[test]
fn rope_batched_bwd() {
    require_gpu();
    let (batch, seq_len, heads, head_dim) = (3, 4, 2, 8);
    let tokens = batch * seq_len;
    let dims = [tokens, heads, head_dim];
    let count = tokens * heads * head_dim;
    let theta = 10000.0;
    let input_host = fill(count, 83);
    let weights = fill(count, 84);

    // Tape forward + backward, with the period set to one sequence length.
    let input = Var::leaf(cuda(&input_host, &dims));
    let rotated = ops::rope(&input, theta, seq_len);
    scalar_loss(&rotated, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference through the tensor-level op with the same period.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], sh: &[usize]| weighted_sum(&cuda(v, sh).rope(theta, seq_len), &wts)
    };
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_linear(),
    );
    report("rope batched dX", &checked);
}
|
||
|
||
// ---- softmax ----
// Softmax on a [rows, cols] input: dX of L = sum(W ∘ softmax(x)) versus finite
// differences.
#[test]
fn softmax_bwd() {
    require_gpu();
    let (rows, cols) = (4, 10);
    let dims = [rows, cols];
    let input_host = fill(rows * cols, 91);
    let weights = fill(rows * cols, 92);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &dims));
    let probs = ops::softmax(&input);
    scalar_loss(&probs, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference through the tensor-level op.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).softmax(), &wts)
    };
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("softmax dX", &checked);
}
|
||
|
||
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
// The op already yields a scalar, so it is backpropagated directly and the
// numeric reference is the mean of the per-row losses (no W weighting).
#[test]
fn cross_entropy_bwd() {
    require_gpu();
    let (rows, cols) = (5, 8);
    let logit_dims = [rows, cols];
    let logits_host = fill(rows * cols, 101);
    // One class index per row.
    let labels: Vec<i32> = (0..rows).map(|row| (row * 3 % cols) as i32).collect();
    let labels_dev = Tensor::from_slice(&labels, &[rows]).to_device(Device::Cuda(0));

    // Tape forward + backward.
    let logits = Var::leaf(cuda(&logits_host, &logit_dims));
    ops::cross_entropy(&logits, &labels_dev).backward();
    let grad_logits = logits.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference: average the per-row losses from the tensor-level op.
    let labels_for_loss = labels.clone();
    let mean_loss = move |v: &[f32], s: &[usize]| {
        let t = Tensor::from_slice(&labels_for_loss, &[rows]).to_device(Device::Cuda(0));
        let (_, per_row) = cuda(v, s).cross_entropy(&t);
        let per_row_host = per_row.to_device(Device::Cpu);
        let total = per_row_host.as_slice::<f32>().iter().sum::<f32>();
        total / rows as f32
    };
    let checked = grad_check(
        &logits_host,
        &logit_dims,
        &mean_loss,
        grad_logits.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("cross_entropy dX", &checked);
}
|
||
|
||
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
// y = x*x + x*x is built from two separate mul nodes on the same Var x, so
// dL/dx has to accumulate every use of x. With W = 1 the output is 2x², the
// upstream grad is 1, and the expected dx is 4x.
#[test]
fn fanout_grad_accumulation() {
    require_gpu();
    let len = 12;
    let dims = [len];
    let input_host = fill(len, 111);
    let weights = vec![1.0f32; len];

    let input = Var::leaf(cuda(&input_host, &dims));
    let first_square = ops::mul(&input, &input); // x used twice inside one node
    let second_square = ops::mul(&input, &input); // and again in a second node
    let doubled = ops::add(&first_square, &second_square); // 2x²
    scalar_loss(&doubled, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Same expression on plain tensors for the numeric reference.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], s: &[usize]| {
            let t = cuda(v, s);
            let o = t.mul(&t).add(&t.mul(&t));
            weighted_sum(&o, &wts)
        }
    };
    // A correct engine has summed all four uses of x into the analytic grad.
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_linear(),
    );
    report("fanout dX", &checked);
}
|
||
|
||
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
// Single head, single batch. No dedicated backward is involved: the grads come
// from chaining the matmul, scale, softmax and (test-local) transpose nodes.
#[test]
fn attention_composed_bwd() {
    require_gpu();
    let (seq_len, head_dim) = (5, 6);
    let dims = [seq_len, head_dim];
    let count = seq_len * head_dim;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let q_host = fill(count, 121);
    let k_host = fill(count, 122);
    let v_host = fill(count, 123);
    let weights = fill(count, 124); // weights over the [s,d] attention output

    // Tape forward, one node per step.
    let q = Var::leaf(cuda(&q_host, &dims));
    let k = Var::leaf(cuda(&k_host, &dims));
    let v = Var::leaf(cuda(&v_host, &dims));
    let k_t = transpose_var(&k); // [d,s]
    let scores = ops::scale(&ops::matmul(&q, &k_t), scale); // [s,s]
    let probs = ops::softmax(&scores);
    let attended = ops::matmul(&probs, &v); // [s,d]
    scalar_loss(&attended, &weights).backward();

    let grad_q = q.grad().unwrap().to_device(Device::Cpu);
    let grad_k = k.grad().unwrap().to_device(Device::Cpu);
    let grad_v = v.grad().unwrap().to_device(Device::Cpu);

    // The same forward on plain tensors; each check below perturbs one argument.
    let forward_loss = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let q_t = cuda(qh, &dims);
        let k_t = cuda(kh, &dims);
        let v_t = cuda(vh, &dims);
        let scores = q_t.matmul(&k_t.transpose_2d()).scale(scale);
        let probs = scores.softmax();
        weighted_sum(&probs.matmul(&v_t), &weights)
    };

    let loss_of_q = {
        let (k_fixed, v_fixed, f) = (k_host.clone(), v_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(x, &k_fixed, &v_fixed)
    };
    let checked_q = grad_check(
        &q_host,
        &dims,
        &loss_of_q,
        grad_q.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn dQ", &checked_q);

    let loss_of_k = {
        let (q_fixed, v_fixed, f) = (q_host.clone(), v_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fixed, x, &v_fixed)
    };
    let checked_k = grad_check(
        &k_host,
        &dims,
        &loss_of_k,
        grad_k.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn dK", &checked_k);

    let loss_of_v = {
        let (q_fixed, k_fixed, f) = (q_host.clone(), k_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fixed, &k_fixed, x)
    };
    let checked_v = grad_check(
        &v_host,
        &dims,
        &loss_of_v,
        grad_v.as_slice::<f32>(),
        cfg_linear(),
    );
    report("attn dV", &checked_v);
}
|
||
|
||
// ---- transpose_4d12 ([a,b,c,d] -> [a,c,b,d]) ----
// Swap of the two middle axes: dX of L = sum(W ∘ out) versus finite differences.
#[test]
fn transpose_4d12_bwd() {
    require_gpu();
    let (d0, d1, d2, d3) = (2, 3, 4, 5);
    let dims = [d0, d1, d2, d3];
    let count = d0 * d1 * d2 * d3;
    let input_host = fill(count, 131);
    let weights = fill(count, 132);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &dims));
    let swapped = ops::transpose_4d12(&input);
    scalar_loss(&swapped, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference through the tensor-level op.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_4d12(), &wts)
    };
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_linear(),
    );
    report("transpose_4d12 dX", &checked);
}
|
||
|
||
// ---- fused batched causal attention (the T10 op) ----
// q, k, v are [bh, seq, hd]; dq/dk/dv are grad-checked against finite
// differences of L = sum(W∘out). bh = 2 (e.g. batch 1 × 2 heads, or 2 sequences
// × 1 head) exercises the batched GEMM stride. The causal mask is applied
// inside the op.
#[test]
fn attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let dims = [bh, seq, hd];
    let count = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_host = fill(count, 141);
    let k_host = fill(count, 142);
    let v_host = fill(count, 143);
    let weights = fill(count, 144);

    // Tape forward + backward through the fused op.
    let q = Var::leaf(cuda(&q_host, &dims));
    let k = Var::leaf(cuda(&k_host, &dims));
    let v = Var::leaf(cuda(&v_host, &dims));
    let attended = ops::attention(&q, &k, &v, scale);
    scalar_loss(&attended, &weights).backward();

    let grad_q = q.grad().unwrap().to_device(Device::Cpu);
    let grad_k = k.grad().unwrap().to_device(Device::Cpu);
    let grad_v = v.grad().unwrap().to_device(Device::Cpu);

    // Tensor-level forward for the numeric reference; only the first element of
    // the returned pair (the attention output) is used.
    let forward_loss = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let q_t = cuda(qh, &dims);
        let k_t = cuda(kh, &dims);
        let v_t = cuda(vh, &dims);
        let (o, _) = q_t.attention(&k_t, &v_t, scale);
        weighted_sum(&o, &weights)
    };

    let loss_of_q = {
        let (k_fixed, v_fixed, f) = (k_host.clone(), v_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(x, &k_fixed, &v_fixed)
    };
    let checked_q = grad_check(
        &q_host,
        &dims,
        &loss_of_q,
        grad_q.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn(batched) dQ", &checked_q);

    let loss_of_k = {
        let (q_fixed, v_fixed, f) = (q_host.clone(), v_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fixed, x, &v_fixed)
    };
    let checked_k = grad_check(
        &k_host,
        &dims,
        &loss_of_k,
        grad_k.as_slice::<f32>(),
        cfg_nonlinear(),
    );
    report("attn(batched) dK", &checked_k);

    let loss_of_v = {
        let (q_fixed, k_fixed, f) = (q_host.clone(), k_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fixed, &k_fixed, x)
    };
    let checked_v = grad_check(
        &v_host,
        &dims,
        &loss_of_v,
        grad_v.as_slice::<f32>(),
        cfg_linear(),
    );
    report("attn(batched) dV", &checked_v);
}
|
||
|
||
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure and dimensions as attention_batched_bwd (bh=2, seq=5, hd=6),
// but routed through ops::flash_attention; dq/dk/dv are grad-checked against
// finite differences of L=sum(W∘out). This covers the SINGLE-tile regime
// (seq<FA_TILE=32). The MULTI-tile online-softmax path (seq>FA_TILE) is covered
// by `flash_bwd_matches_composed_bwd` (seq=40), which compares against the
// already grad-checked composed backward instead of finite differences —
// finite differences are unreliable on the near-zero grad elements that a long
// softmax produces.
#[test]
fn flash_attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let dims = [bh, seq, hd];
    let count = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    // Q and K are scaled up so the softmax is non-uniform. A uniform softmax
    // yields near-zero dQ/dK, where central differences give spurious 0.0 values
    // or sign flips that are not backward bugs.
    let q_host: Vec<f32> = fill(count, 241).into_iter().map(|x| x * 2.5).collect();
    let k_host: Vec<f32> = fill(count, 242).into_iter().map(|x| x * 2.5).collect();
    let v_host = fill(count, 243);
    let weights = fill(count, 244);

    // Tape forward + backward through the flash op.
    let q = Var::leaf(cuda(&q_host, &dims));
    let k = Var::leaf(cuda(&k_host, &dims));
    let v = Var::leaf(cuda(&v_host, &dims));
    let attended = ops::flash_attention(&q, &k, &v, scale);
    scalar_loss(&attended, &weights).backward();

    let grad_q = q.grad().unwrap().to_device(Device::Cpu);
    let grad_k = k.grad().unwrap().to_device(Device::Cpu);
    let grad_v = v.grad().unwrap().to_device(Device::Cpu);

    // Tensor-level flash forward for the numeric reference.
    let forward_loss = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let q_t = cuda(qh, &dims);
        let k_t = cuda(kh, &dims);
        let v_t = cuda(vh, &dims);
        let (o, _) = q_t.flash_attention(&k_t, &v_t, scale);
        weighted_sum(&o, &weights)
    };

    // dQ/dK carry softmax curvature. For the small grad magnitudes here a larger
    // eps (2e-3) reduces the f32 rounding term (∝|L|/eps), which dominates the
    // O(eps²) truncation on a ~4e-4 grad. dV is exactly linear and uses
    // cfg_linear.
    let cfg_attn = GradCheckConfig {
        eps: 2e-3,
        rel_tol: 3e-2,
        atol: 1e-3,
    };

    let loss_of_q = {
        let (k_fixed, v_fixed, f) = (k_host.clone(), v_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(x, &k_fixed, &v_fixed)
    };
    let checked_q = grad_check(
        &q_host,
        &dims,
        &loss_of_q,
        grad_q.as_slice::<f32>(),
        cfg_attn,
    );
    report("flash dQ", &checked_q);

    let loss_of_k = {
        let (q_fixed, v_fixed, f) = (q_host.clone(), v_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fixed, x, &v_fixed)
    };
    let checked_k = grad_check(
        &k_host,
        &dims,
        &loss_of_k,
        grad_k.as_slice::<f32>(),
        cfg_attn,
    );
    report("flash dK", &checked_k);

    let loss_of_v = {
        let (q_fixed, k_fixed, f) = (q_host.clone(), k_host.clone(), forward_loss.clone());
        move |x: &[f32], _s: &[usize]| f(&q_fixed, &k_fixed, x)
    };
    let checked_v = grad_check(
        &v_host,
        &dims,
        &loss_of_v,
        grad_v.as_slice::<f32>(),
        cfg_linear(),
    );
    report("flash dV", &checked_v);
}
|
||
|
||
// flash forward must equal the composed attention forward (same SDPA math).
//
// Both outputs are copied to the host and compared element by element with a
// relative error. The reduction is NaN-aware: `f32::max` returns the non-NaN
// operand, so a plain `fold(0.0, f32::max)` would silently discard a NaN (for
// example from a kernel that produced NaN/inf) and let the test pass. Here a
// NaN poisons the maximum, and the final `<` comparison then fails.
#[test]
fn flash_matches_composed_fwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q = cuda(&fill(n, 341), &[bh, seq, hd]);
    let k = cuda(&fill(n, 342), &[bh, seq, hd]);
    let v = cuda(&fill(n, 343), &[bh, seq, hd]);

    // Only the first element of each returned pair (the output) is compared.
    let (oc, _) = q.attention(&k, &v, scale);
    let (of, _) = q.flash_attention(&k, &v, scale);
    let oc = oc.to_device(Device::Cpu);
    let of = of.to_device(Device::Cpu);
    let composed = oc.as_slice::<f32>();
    let flash = of.as_slice::<f32>();

    // `zip` stops at the shorter side, so a size mismatch must be caught first.
    assert_eq!(
        composed.len(),
        flash.len(),
        "flash fwd output length differs from composed"
    );

    let max_rel = composed
        .iter()
        .zip(flash)
        .map(|(c, f)| (c - f).abs() / (c.abs() + 1e-6))
        .fold(0.0f32, |worst, rel| {
            if worst.is_nan() || rel.is_nan() {
                f32::NAN
            } else {
                worst.max(rel)
            }
        });
    println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
    // A NaN `max_rel` compares false here, so non-finite outputs fail the test.
    assert!(
        max_rel < 1e-4,
        "flash fwd diverges from composed: {max_rel:.3e}"
    );
}
|
||
|
||
// flash backward must equal the (already grad-checked) composed backward. This
// is a sharper test than finite differences: it compares the flash dQ/dK/dV
// directly with the composed op's grads, avoiding finite-diff noise on near-zero
// gradient elements.
//
// The comparison is NaN-aware: `f32::max` returns the non-NaN operand, so a
// plain `fold(0.0, f32::max)` would discard NaN errors and let a backward that
// emits NaN/inf pass. Here a NaN poisons the maximum, and the `<` assertions
// below then fail.
#[test]
fn flash_bwd_matches_composed_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 441);
    let k_h = fill(n, 442);
    let v_h = fill(n, 443);
    let w = fill(n, 444);

    // Run one forward/backward on fresh leaves and return (dQ, dK, dV) on the host.
    let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
        let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
        let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
        let out = if flash {
            ops::flash_attention(&q, &k, &v, scale)
        } else {
            ops::attention(&q, &k, &v, scale)
        };
        scalar_loss(&out, &w).backward();
        let g = |x: &Var| {
            x.grad()
                .unwrap()
                .to_device(Device::Cpu)
                .as_slice::<f32>()
                .to_vec()
        };
        (g(&q), g(&k), g(&v))
    };
    let (cq, ck, cv) = run(false);
    let (fq, fk, fv) = run(true);

    // Symmetric relative error with a 1e-4 floor in the denominator.
    let maxrel = |a: &[f32], b: &[f32]| -> f32 {
        // `zip` stops at the shorter side, so a size mismatch must be caught first.
        assert_eq!(a.len(), b.len(), "gradient lengths differ");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
            .fold(0.0f32, |worst, rel| {
                if worst.is_nan() || rel.is_nan() {
                    f32::NAN
                } else {
                    worst.max(rel)
                }
            })
    };
    let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
    println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
    // NaN compares false against the threshold, so non-finite grads fail here.
    assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
    assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
    assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
|
||
|
||
// ---- dropout (Phase T18) ----
//
// Fixed-seed finite-difference grad-check. With a fixed `seed` the mask depends
// only on (seed, index) and NOT on x, so dropout is a fixed elementwise linear
// map `out_i = c_i·x_i`. The central difference of L is therefore well defined:
// both the + and - perturbation of each x_i see the SAME mask. The reference
// closure calls `ops::dropout(x, p, SEED)` with that same SEED, so it reproduces
// the mask on every evaluation.
#[test]
fn dropout_bwd() {
    require_gpu();
    const SEED: u64 = 0xD120_FE5E;
    let p = 0.3f32;
    let (rows, cols) = (16, 12);
    let dims = [rows, cols];
    let input_host = fill(rows * cols, 71);
    let weights = fill(rows * cols, 72);

    // Tape forward + backward.
    let input = Var::leaf(cuda(&input_host, &dims));
    let dropped = ops::dropout(&input, p, SEED);
    scalar_loss(&dropped, &weights).backward();
    let grad_input = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference: run the same op on a fresh leaf and read its value.
    let loss_of_input = {
        let wts = weights.clone();
        move |v: &[f32], s: &[usize]| {
            let o = ops::dropout(&Var::leaf(cuda(v, s)), p, SEED);
            weighted_sum(&o.value(), &wts)
        }
    };
    let checked = grad_check(
        &input_host,
        &dims,
        &loss_of_input,
        grad_input.as_slice::<f32>(),
        cfg_linear(),
    );
    report("dropout dX", &checked);
}
|
||
|
||
// Inverted-dropout expectation and keep-rate check. On a large all-ones tensor,
// averaged over a sweep of seeds, mean(dropout(x)) should track mean(x) (the
// inverted 1/(1-p) scaling gives E[out] ≈ x) and the fraction of non-zero mask
// entries should track 1-p (the RNG is ~Bernoulli).
#[test]
fn dropout_expectation_and_keep_rate() {
    require_gpu();
    let p = 0.25f32;
    let n = 200_000usize;
    // mean(x) = 1, so mean(out) should be ≈ 1 as well.
    let input = cuda(&vec![1.0f32; n], &[n]);

    let trials = 8;
    let (mut mean_sum, mut keep_sum) = (0.0f64, 0.0f64);
    for trial in 0..trials {
        // A different seed per trial.
        let (dropped, mask) = input.dropout(p, 0x5EED_0000 + trial as u64);
        let dropped_host = dropped.to_device(Device::Cpu);
        let mask_host = mask.to_device(Device::Cpu);
        let total: f64 = dropped_host
            .as_slice::<f32>()
            .iter()
            .map(|&v| v as f64)
            .sum();
        let survivors = mask_host
            .as_slice::<f32>()
            .iter()
            .filter(|&&m| m != 0.0)
            .count();
        mean_sum += total / n as f64;
        keep_sum += survivors as f64 / n as f64;
    }
    let mean_out = mean_sum / trials as f64;
    let keep_rate = keep_sum / trials as f64;
    println!(
        "dropout p={p}: E[out]={mean_out:.5} (input mean 1.0), keep_rate={keep_rate:.5} (1-p={:.3})",
        1.0 - p
    );

    let mean_gap = (mean_out - 1.0).abs();
    assert!(
        mean_gap < 0.01,
        "E[out] {mean_out} not ≈ input mean 1.0 (inverted scaling broken)"
    );
    let keep_gap = (keep_rate - (1.0 - p) as f64).abs();
    assert!(
        keep_gap < 0.01,
        "keep_rate {keep_rate} not ≈ 1-p {}",
        1.0 - p
    );
}
|
||
|
||
// p=0 must be a no-op: the tensor-level `dropout(0.0, seed)` output has to equal
// its input element for element. This is the op-level regression guard; per the
// original note, the model-level bit-identity check lives in
// xtrain-model/tests/dropout.rs. This test does not exercise gradients or the
// tape-level `ops::dropout`.
//
// The length is asserted before the elementwise comparison because `zip` stops
// at the shorter side — without it, a truncated (or empty) output would pass
// vacuously.
#[test]
fn dropout_p0_is_identity() {
    require_gpu();
    let (m, n) = (8, 5);
    let x_h = fill(m * n, 91);
    let x = cuda(&x_h, &[m, n]);
    let (out, _mask) = x.dropout(0.0, 12345);
    let out_h = out.to_device(Device::Cpu);
    let out_values = out_h.as_slice::<f32>();
    assert_eq!(
        out_values.len(),
        x_h.len(),
        "p=0 dropout must preserve the element count"
    );
    for (a, b) in x_h.iter().zip(out_values) {
        assert_eq!(*a, *b, "p=0 dropout must be identity");
    }
}
|
||
|
||
// --- test helpers ---
|
||
|
||
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
||
// implement it as: elementwise mul by a constant-W leaf, then sum-to-scalar.
|
||
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
|
||
let wt = Var::leaf(cuda(w, out.value().shape()));
|
||
let prod = ops::mul(out, &wt);
|
||
sum_all(&prod)
|
||
}
|
||
|
||
// Sum-to-scalar node: out = sum(x). Backward broadcasts the scalar grad to a
|
||
// ones-shaped tensor over x. Implemented here (test-local) since the engine's
|
||
// op set doesn't include a generic reduction; cross_entropy is the only loss op.
|
||
fn sum_all(x: &Var) -> Var {
|
||
let xv = x.value();
|
||
let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
|
||
let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
|
||
let shape: Vec<usize> = xv.shape().to_vec();
|
||
Var::from_op(
|
||
scalar,
|
||
vec![x.clone()],
|
||
Box::new(move |d, parents| {
|
||
// d is [1]; broadcast d to a same-shape tensor over the input.
|
||
let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||
let ones = vec![dval; shape.iter().product()];
|
||
let g = Tensor::from_slice(&ones, &shape).to_device(Device::Cuda(0));
|
||
Var::push_grad(&parents[0], g);
|
||
}),
|
||
)
|
||
}
|
||
|
||
// Manual transpose node for the composed-attention test (the engine has no
|
||
// transpose op; xserv does the equivalent host-side reshape around RoPE).
|
||
fn transpose_var(x: &Var) -> Var {
|
||
let xt = x.value().transpose_2d();
|
||
Var::from_op(
|
||
xt,
|
||
vec![x.clone()],
|
||
Box::new(|d, parents| {
|
||
Var::push_grad(&parents[0], d.transpose_2d());
|
||
}),
|
||
)
|
||
}
|