Files
xtrain/crates/xtrain-autodiff/tests/structural.rs
Gahow Wang 0acfa5df11 ops: grad-check the T5 structural ops
Finite-diff grad-checks (same L=sum(W∘out) harness as autograd.rs) for
embedding (incl. repeated ids), reshape, transpose_3d01, transpose_2d,
and split/merge_heads round-trip. Gated #![cfg(not(no_cuda))].

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:05:20 +08:00

221 lines
6.8 KiB
Rust

// GPU grad-checks for the Phase T5 structural ops added on top of the T4 set:
// embedding (gather fwd / scatter-add bwd), reshape, transpose_3d01,
// transpose_2d, and split/merge_heads. Same harness as autograd.rs:
// L = sum(W ∘ out), W fixed random ⇒ upstream dOut = W; run backward(), then
// grad-check each leaf's .grad() against central finite differences.
//
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_autodiff::{GradCheckConfig, grad_check};
use xtrain_cuda::device;
use xtrain_tensor::{Device, Tensor};
// Deterministic pseudo-random test data: `n` values roughly uniform in
// [-0.5, 0.5], reproducible from `seed`. The seed is first scrambled with one
// affine step, then each value comes from a 64-bit LCG step (Knuth's MMIX
// multiplier/increment); the top 31 bits of the state are scaled to [0, 1].
fn fill(n: usize, seed: u64) -> Vec<f32> {
    const SEED_MUL: u64 = 2_862_933_555_777_941_757;
    const SEED_ADD: u64 = 3_037_000_493;
    const LCG_MUL: u64 = 6_364_136_223_846_793_005;
    const LCG_INC: u64 = 1_442_695_040_888_963_407;
    const SCALE: f32 = (1u64 << 31) as f32;

    let start = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    (0..n)
        .scan(start, |lcg, _| {
            *lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
            Some(((*lcg >> 33) as f32 / SCALE) - 0.5)
        })
        .collect()
}
// Test prelude: panics unless at least one CUDA device is reported, then
// selects device 0 (every tensor helper below targets `Device::Cuda(0)`).
fn require_gpu() {
    let visible = device::device_count().expect("device count");
    assert!(visible > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
// Builds a host tensor of `shape` from `data` and copies it to device 0.
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
    let host = Tensor::from_slice(data, shape);
    host.to_device(Device::Cuda(0))
}
// Host-side weighted sum `Σ out[i] · w[i]` over the flattened elements of
// `out`: the numeric (finite-difference) side of the `L = sum(W ∘ out)`
// harness, evaluated on a perturbed input.
//
// Panics if `out` and `w` have different element counts. Plain `zip` would
// silently drop the tail, and the numeric loss would then no longer match the
// loss the analytic gradient was computed for.
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
    let host = out.to_device(Device::Cpu);
    let values = host.as_slice::<f32>();
    assert_eq!(
        values.len(),
        w.len(),
        "weighted_sum: out has {} elements but w has {}",
        values.len(),
        w.len()
    );
    // Accumulate in f64: the central difference subtracts two nearly equal
    // losses, so rounding in an f32 running sum leaks straight into the
    // numeric gradient.
    values
        .iter()
        .zip(w)
        .map(|(&o, &wi)| f64::from(o) * f64::from(wi))
        .sum::<f64>() as f32
}
// Tolerances for ops that are exactly linear in their input (all the
// structural ops here). A central difference of a linear map has no
// truncation error, so a large step only helps: it makes the perturbation big
// relative to f32 rounding in the loss (same choice as for
// add/mul/transpose in autograd.rs).
fn cfg_linear() -> GradCheckConfig {
    let (eps, rel_tol, atol) = (1e-2, 2e-2, 1e-3);
    GradCheckConfig { eps, rel_tol, atol }
}
// Prints the worst-element summary of a grad-check under `name`, then panics
// (with the full result) if the check did not pass.
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
    let summary = format!(
        "{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
        res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index
    );
    println!("{summary}");
    assert!(res.passed, "{name} grad-check failed: {res:?}");
}
// Analytic side of the harness: L = sum(W ∘ out). W is uploaded as a leaf
// shaped like `out`, multiplied in elementwise, and reduced to a scalar, so
// the upstream gradient reaching `out` is W.
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
    let weights = Var::leaf(cuda(w, out.value().shape()));
    let weighted = ops::mul(out, &weights);
    sum_all(&weighted)
}
// Differentiable sum of every element of `x`, returned as a shape-[1] tensor
// on `x`'s device. The reduction runs on the host. The backward broadcasts
// the upstream scalar to `x`'s shape and pushes it as `x`'s gradient, on the
// same device the forward value lives on (previously hard-coded to Cuda(0),
// which disagreed with the forward whenever `x` was elsewhere).
fn sum_all(x: &Var) -> Var {
    let xv = x.value();
    let device = xv.device();
    let shape: Vec<usize> = xv.shape().to_vec();
    let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
    let scalar = Tensor::from_slice(&[total], &[1]).to_device(device);
    Var::from_op(
        scalar,
        vec![x.clone()],
        Box::new(move |d, parents| {
            // `d` is the gradient of the one-element output; every input
            // element receives that same value.
            let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
            let broadcast = vec![dval; shape.iter().product()];
            let g = Tensor::from_slice(&broadcast, &shape).to_device(device);
            Var::push_grad(&parents[0], g);
        }),
    )
}
// ---- embedding (gather fwd / scatter-add bwd) ----
// Ids 0 and 3 each appear twice, so their table rows receive two scatter-add
// contributions in the backward pass.
#[test]
fn embedding_bwd() {
    require_gpu();
    let vocab = 5;
    let dim = 7;
    let token_ids: Vec<i32> = vec![0, 3, 1, 3, 2, 0];
    let seq = token_ids.len();
    let table_host = fill(vocab * dim, 201);
    let weights = fill(seq * dim, 202);

    // Analytic gradient through the autodiff graph.
    let ids_dev = Tensor::from_slice(&token_ids, &[seq]).to_device(Device::Cuda(0));
    let table = Var::leaf(cuda(&table_host, &[vocab, dim]));
    let gathered = ops::embedding(&table, &ids_dev);
    scalar_loss(&gathered, &weights).backward();
    let d_table = table.grad().unwrap().to_device(Device::Cpu);

    // Numeric gradient: redo the raw tensor gather for each perturbed table.
    let loss_of_table = move |v: &[f32], s: &[usize]| {
        let ids = Tensor::from_slice(&token_ids, &[seq]).to_device(Device::Cuda(0));
        weighted_sum(&cuda(v, s).embedding(&ids), &weights)
    };
    let result = grad_check(
        &table_host,
        &[vocab, dim],
        &loss_of_table,
        d_table.as_slice::<f32>(),
        cfg_linear(),
    );
    report("embedding dTable", &result);
}
// ---- reshape ----
// [rows, cols] -> [2*rows, cols/2]: same elements, different shape.
#[test]
fn reshape_bwd() {
    require_gpu();
    let rows = 6;
    let cols = 8;
    let new_shape = [rows * 2, cols / 2];
    let x_host = fill(rows * cols, 211);
    let weights = fill(rows * cols, 212);

    let x = Var::leaf(cuda(&x_host, &[rows, cols]));
    let reshaped = ops::reshape(&x, &new_shape);
    scalar_loss(&reshaped, &weights).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let loss_of_x = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).reshape(&new_shape), &weights)
    };
    let result = grad_check(
        &x_host,
        &[rows, cols],
        &loss_of_x,
        dx.as_slice::<f32>(),
        cfg_linear(),
    );
    report("reshape dX", &result);
}
// ---- transpose_3d01: swaps the two leading axes, [a, b, c] -> [b, a, c] ----
#[test]
fn transpose_3d01_bwd() {
    require_gpu();
    let (a, b, c) = (3, 4, 5);
    let shape = [a, b, c];
    let x_host = fill(a * b * c, 221);
    let weights = fill(a * b * c, 222);

    let x = Var::leaf(cuda(&x_host, &shape));
    let swapped = ops::transpose_3d01(&x);
    scalar_loss(&swapped, &weights).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let loss_of_x =
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_3d01(), &weights);
    let result = grad_check(&x_host, &shape, &loss_of_x, dx.as_slice::<f32>(), cfg_linear());
    report("transpose_3d01 dX", &result);
}
// ---- transpose_2d: [r, c] -> [c, r] (non-square, so a missed swap is visible) ----
#[test]
fn transpose_2d_bwd() {
    require_gpu();
    let (r, c) = (5, 7);
    let shape = [r, c];
    let x_host = fill(r * c, 231);
    let weights = fill(r * c, 232);

    let x = Var::leaf(cuda(&x_host, &shape));
    let flipped = ops::transpose_2d(&x);
    scalar_loss(&flipped, &weights).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let loss_of_x =
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_2d(), &weights);
    let result = grad_check(&x_host, &shape, &loss_of_x, dx.as_slice::<f32>(), cfg_linear());
    report("transpose_2d dX", &result);
}
// ---- split_heads / merge_heads ----
// The test has three parts:
// 1. Forward round trip: merge_heads(split_heads(x)) must reproduce x's shape
//    and values exactly, since both ops only move values.
// 2. Backward round trip: the composition is grad-checked against the
//    identity map.
// 3. Per-op checks: each op is grad-checked on its own. The round trip alone
//    cannot catch a backward that skips its reshuffle in *both* ops. The two
//    omissions cancel and dx still equals W, so these checks are what pin
//    each backward individually.
// The [nh, seq, hd] labels are this test's naming. Which axes split_heads
// actually permutes is defined in ops, not here.
#[test]
fn split_merge_heads_bwd() {
    require_gpu();
    let (nh, seq, hd) = (3, 4, 5);
    let x_shape = [nh, seq, hd];
    let x_h = fill(nh * seq * hd, 241);
    let w = fill(nh * seq * hd, 242);

    // -- round trip: forward must be the identity --
    let x = Var::leaf(cuda(&x_h, &x_shape));
    let heads = ops::split_heads(&x);
    let out = ops::merge_heads(&heads); // back to [nh,seq,hd]
    assert_eq!(
        out.value().shape().to_vec(),
        x_shape.to_vec(),
        "merge_heads(split_heads(x)) changed the shape"
    );
    let out_h = out.value().to_device(Device::Cpu);
    assert_eq!(
        out_h.as_slice::<f32>(),
        &x_h[..],
        "merge_heads(split_heads(x)) != x"
    );

    // -- round trip: backward grad-checked against the identity map --
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);
    let w_rt = w.clone();
    let l_rt = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s), &w_rt);
    report(
        "split/merge_heads dX",
        &grad_check(&x_h, &x_shape, &l_rt, dx.as_slice::<f32>(), cfg_linear()),
    );

    // -- split_heads alone --
    let xs = Var::leaf(cuda(&x_h, &x_shape));
    let split = ops::split_heads(&xs);
    scalar_loss(&split, &w).backward();
    let dxs = xs.grad().unwrap().to_device(Device::Cpu);
    let w_split = w.clone();
    let l_split = move |v: &[f32], s: &[usize]| {
        weighted_sum(&ops::split_heads(&Var::leaf(cuda(v, s))).value(), &w_split)
    };
    report(
        "split_heads dX",
        &grad_check(&x_h, &x_shape, &l_split, dxs.as_slice::<f32>(), cfg_linear()),
    );

    // -- merge_heads alone, fed the split layout produced above --
    let h_shape = heads.value().shape().to_vec();
    let h_h = heads.value().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let hm = Var::leaf(cuda(&h_h, &h_shape));
    let merged = ops::merge_heads(&hm);
    scalar_loss(&merged, &w).backward();
    let dh = hm.grad().unwrap().to_device(Device::Cpu);
    let l_merge = move |v: &[f32], s: &[usize]| {
        weighted_sum(&ops::merge_heads(&Var::leaf(cuda(v, s))).value(), &w)
    };
    report(
        "merge_heads dHeads",
        &grad_check(&h_h, &h_shape, &l_merge, dh.as_slice::<f32>(), cfg_linear()),
    );
}