From 0acfa5df11bab0a3de070a1f5f030622f9774826 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Mon, 15 Jun 2026 16:05:20 +0800 Subject: [PATCH] ops: grad-check the T5 structural ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Finite-diff grad-checks (same L=sum(W∘out) harness as autograd.rs) for embedding (incl. repeated ids), reshape, transpose_3d01, transpose_2d, and split/merge_heads round-trip. Gated #![cfg(not(no_cuda))]. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-autodiff/tests/structural.rs | 220 +++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 crates/xtrain-autodiff/tests/structural.rs diff --git a/crates/xtrain-autodiff/tests/structural.rs b/crates/xtrain-autodiff/tests/structural.rs new file mode 100644 index 0000000..f5d076b --- /dev/null +++ b/crates/xtrain-autodiff/tests/structural.rs @@ -0,0 +1,220 @@ +// GPU grad-checks for the Phase T5 structural ops added on top of the T4 set: +// embedding (gather fwd / scatter-add bwd), reshape, transpose_3d01, +// transpose_2d, and split/merge_heads. Same harness as autograd.rs: +// L = sum(W ∘ out), W fixed random ⇒ upstream dOut = W; run backward(), then +// grad-check each leaf's .grad() against central finite differences. +// +// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5. +#![cfg(not(no_cuda))] + +use xtrain_autodiff::ops; +use xtrain_autodiff::tape::Var; +use xtrain_autodiff::{GradCheckConfig, grad_check}; +use xtrain_cuda::device; +use xtrain_tensor::{Device, Tensor}; + +fn fill(n: usize, seed: u64) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + ((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5 + }) + .collect() +} + +fn require_gpu() { + assert!( + device::device_count().expect("device count") > 0, + "no CUDA device" + ); + device::set_device(0).unwrap(); +} + +fn cuda(data: &[f32], shape: &[usize]) -> Tensor { + Tensor::from_slice(data, shape).to_device(Device::Cuda(0)) +} + +fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 { + out.to_device(Device::Cpu) + .as_slice::() + .iter() + .zip(w) + .map(|(o, w)| o * w) + .sum() +} + +// Structural ops are exactly linear in their input → a large eps just sharpens +// f32 resolution (same as add/mul/transpose in autograd.rs). +fn cfg_linear() -> GradCheckConfig { + GradCheckConfig { + eps: 1e-2, + rel_tol: 2e-2, + atol: 1e-3, + } +} + +fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) { + println!( + "{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})", + res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index + ); + assert!(res.passed, "{name} grad-check failed: {res:?}"); +} + +// L = sum(W ∘ out): a constant-W leaf mul + sum-to-scalar reduction. +fn scalar_loss(out: &Var, w: &[f32]) -> Var { + let wt = Var::leaf(cuda(w, out.value().shape())); + sum_all(&ops::mul(out, &wt)) +} + +fn sum_all(x: &Var) -> Var { + let xv = x.value(); + let total: f32 = xv.to_device(Device::Cpu).as_slice::().iter().sum(); + let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device()); + let shape: Vec = xv.shape().to_vec(); + Var::from_op( + scalar, + vec![x.clone()], + Box::new(move |d, parents| { + let dval = d.to_device(Device::Cpu).as_slice::()[0]; + let ones = vec![dval; shape.iter().product()]; + let g = Tensor::from_slice(&ones, &shape).to_device(Device::Cuda(0)); + Var::push_grad(&parents[0], g); + }), + ) +} + +// ---- embedding (gather fwd / scatter-add bwd) ---- +// Includes a repeated id so the atomic scatter-add accumulation is exercised. +#[test] +fn embedding_bwd() { + require_gpu(); + let (vocab, dim) = (5, 7); + let ids_host: Vec = vec![0, 3, 1, 3, 2, 0]; // 0 and 3 repeat + let seq = ids_host.len(); + let table_h = fill(vocab * dim, 201); + let w = fill(seq * dim, 202); + + let ids = Tensor::from_slice(&ids_host, &[seq]).to_device(Device::Cuda(0)); + let table = Var::leaf(cuda(&table_h, &[vocab, dim])); + let out = ops::embedding(&table, &ids); + scalar_loss(&out, &w).backward(); + + let dtable = table.grad().unwrap().to_device(Device::Cpu); + let idf = ids_host.clone(); + let wf = w.clone(); + let lt = move |v: &[f32], s: &[usize]| { + let ids = Tensor::from_slice(&idf, &[seq]).to_device(Device::Cuda(0)); + weighted_sum(&cuda(v, s).embedding(&ids), &wf) + }; + report( + "embedding dTable", + &grad_check( + &table_h, + &[vocab, dim], + <, + dtable.as_slice::(), + cfg_linear(), + ), + ); +} + +// ---- reshape ---- +#[test] +fn reshape_bwd() { + require_gpu(); + let (rows, cols) = (6, 8); + let x_h = fill(rows * cols, 211); + let w = fill(rows * cols, 212); + + let x = Var::leaf(cuda(&x_h, &[rows, cols])); + let out = ops::reshape(&x, &[rows * 2, cols / 2]); + scalar_loss(&out, &w).backward(); + + let dx = x.grad().unwrap().to_device(Device::Cpu); + let wf = w.clone(); + let lx = + move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).reshape(&[rows * 2, cols / 2]), &wf); + report( + "reshape dX", + &grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::(), cfg_linear()), + ); +} + +// ---- transpose_3d01 ([a,b,c] -> [b,a,c]) ---- +#[test] +fn transpose_3d01_bwd() { + require_gpu(); + let (a, b, c) = (3, 4, 5); + let x_h = fill(a * b * c, 221); + let w = fill(a * b * c, 222); + + let x = Var::leaf(cuda(&x_h, &[a, b, c])); + let out = ops::transpose_3d01(&x); + scalar_loss(&out, &w).backward(); + + let dx = x.grad().unwrap().to_device(Device::Cpu); + let wf = w.clone(); + let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_3d01(), &wf); + report( + "transpose_3d01 dX", + &grad_check(&x_h, &[a, b, c], &lx, dx.as_slice::(), cfg_linear()), + ); +} + +// ---- transpose_2d ---- +#[test] +fn transpose_2d_bwd() { + require_gpu(); + let (r, c) = (5, 7); + let x_h = fill(r * c, 231); + let w = fill(r * c, 232); + + let x = Var::leaf(cuda(&x_h, &[r, c])); + let out = ops::transpose_2d(&x); + scalar_loss(&out, &w).backward(); + + let dx = x.grad().unwrap().to_device(Device::Cpu); + let wf = w.clone(); + let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_2d(), &wf); + report( + "transpose_2d dX", + &grad_check(&x_h, &[r, c], &lx, dx.as_slice::(), cfg_linear()), + ); +} + +// ---- split_heads + merge_heads round-trip (identity reshuffle of [nh,seq,hd]) ---- +// out = merge_heads(split_heads(x)) must equal x, and its grad must be dOut=W +// reshuffled identically — i.e. dx grad-checks against the identity composition. +#[test] +fn split_merge_heads_bwd() { + require_gpu(); + let (nh, seq, hd) = (3, 4, 5); + let x_h = fill(nh * seq * hd, 241); + let w = fill(nh * seq * hd, 242); + + let x = Var::leaf(cuda(&x_h, &[nh, seq, hd])); + let heads = ops::split_heads(&x); + let out = ops::merge_heads(&heads); // back to [nh,seq,hd] + scalar_loss(&out, &w).backward(); + + let dx = x.grad().unwrap().to_device(Device::Cpu); + // forward is identity, so grad-check the identity map. + let wf = w.clone(); + let lx = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s), &wf); + report( + "split/merge_heads dX", + &grad_check( + &x_h, + &[nh, seq, hd], + &lx, + dx.as_slice::(), + cfg_linear(), + ), + ); +}