Finite-diff grad-checks (same L=sum(W∘out) harness as autograd.rs) for embedding (incl. repeated ids), reshape, transpose_3d01, transpose_2d, and split/merge_heads round-trip. Gated #![cfg(not(no_cuda))]. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
221 lines
6.8 KiB
Rust
221 lines
6.8 KiB
Rust
// GPU grad-checks for the Phase T5 structural ops added on top of the T4 set:
|
|
// embedding (gather fwd / scatter-add bwd), reshape, transpose_3d01,
|
|
// transpose_2d, and split/merge_heads. Same harness as autograd.rs:
|
|
// L = sum(W ∘ out), W fixed random ⇒ upstream dOut = W; run backward(), then
|
|
// grad-check each leaf's .grad() against central finite differences.
|
|
//
|
|
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_autodiff::ops;
|
|
use xtrain_autodiff::tape::Var;
|
|
use xtrain_autodiff::{GradCheckConfig, grad_check};
|
|
use xtrain_cuda::device;
|
|
use xtrain_tensor::{Device, Tensor};
|
|
|
|
/// Builds a deterministic pseudo-random vector of `n` values for test inputs.
///
/// The seed is scrambled once, then a 64-bit linear congruential generator is
/// advanced once per element. The top 31 bits of each state are scaled by
/// 2^31 and shifted by 0.5, so values land in roughly `[-0.5, 0.5]`.
/// The same `(n, seed)` pair always yields the same vector.
fn fill(n: usize, seed: u64) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    // Decorrelate nearby seeds before the first draw.
    let mut rng = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);

    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        rng = rng.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep only the high 31 bits of the state.
        let unit = (rng >> 33) as f32 / (1u64 << 31) as f32;
        values.push(unit - 0.5);
    }
    values
}
|
|
|
|
fn require_gpu() {
|
|
assert!(
|
|
device::device_count().expect("device count") > 0,
|
|
"no CUDA device"
|
|
);
|
|
device::set_device(0).unwrap();
|
|
}
|
|
|
|
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
|
|
Tensor::from_slice(data, shape).to_device(Device::Cuda(0))
|
|
}
|
|
|
|
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
|
|
out.to_device(Device::Cpu)
|
|
.as_slice::<f32>()
|
|
.iter()
|
|
.zip(w)
|
|
.map(|(o, w)| o * w)
|
|
.sum()
|
|
}
|
|
|
|
// Structural ops are exactly linear in their input → a large eps just sharpens
|
|
// f32 resolution (same as add/mul/transpose in autograd.rs).
|
|
fn cfg_linear() -> GradCheckConfig {
|
|
GradCheckConfig {
|
|
eps: 1e-2,
|
|
rel_tol: 2e-2,
|
|
atol: 1e-3,
|
|
}
|
|
}
|
|
|
|
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
|
|
println!(
|
|
"{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
|
|
res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index
|
|
);
|
|
assert!(res.passed, "{name} grad-check failed: {res:?}");
|
|
}
|
|
|
|
// L = sum(W ∘ out): a constant-W leaf mul + sum-to-scalar reduction.
|
|
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
|
|
let wt = Var::leaf(cuda(w, out.value().shape()));
|
|
sum_all(&ops::mul(out, &wt))
|
|
}
|
|
|
|
fn sum_all(x: &Var) -> Var {
|
|
let xv = x.value();
|
|
let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
|
|
let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
|
|
let shape: Vec<usize> = xv.shape().to_vec();
|
|
Var::from_op(
|
|
scalar,
|
|
vec![x.clone()],
|
|
Box::new(move |d, parents| {
|
|
let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
|
let ones = vec![dval; shape.iter().product()];
|
|
let g = Tensor::from_slice(&ones, &shape).to_device(Device::Cuda(0));
|
|
Var::push_grad(&parents[0], g);
|
|
}),
|
|
)
|
|
}
|
|
|
|
// ---- embedding (gather fwd / scatter-add bwd) ----
|
|
// Includes a repeated id so the atomic scatter-add accumulation is exercised.
|
|
/// Grad-checks `ops::embedding` with respect to the embedding table.
///
/// The id list repeats 0 and 3, so those table rows are gathered more than
/// once and their gradients must accumulate contributions from several output
/// positions. Row 4 is never looked up.
#[test]
fn embedding_bwd() {
    require_gpu();
    let (vocab, dim) = (5, 7);
    let ids_host: Vec<i32> = vec![0, 3, 1, 3, 2, 0]; // 0 and 3 repeat
    let seq = ids_host.len();
    let table_h = fill(vocab * dim, 201);
    let w = fill(seq * dim, 202);

    // Analytic gradient via the tape.
    let ids = Tensor::from_slice(&ids_host, &[seq]).to_device(Device::Cuda(0));
    let table = Var::leaf(cuda(&table_h, &[vocab, dim]));
    let out = ops::embedding(&table, &ids);
    scalar_loss(&out, &w).backward();
    let dtable = table.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference: the same loss evaluated directly on a perturbed table.
    // `ids_host` and `w` are not needed after this point, so the closure takes
    // ownership of them instead of cloning.
    let lt = move |v: &[f32], s: &[usize]| {
        let ids = Tensor::from_slice(&ids_host, &[seq]).to_device(Device::Cuda(0));
        weighted_sum(&cuda(v, s).embedding(&ids), &w)
    };
    report(
        "embedding dTable",
        &grad_check(
            &table_h,
            &[vocab, dim],
            &lt,
            dtable.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
|
|
|
|
// ---- reshape ----
|
|
/// Grad-checks `ops::reshape` for a `[6, 8]` input viewed as `[12, 4]`.
#[test]
fn reshape_bwd() {
    require_gpu();
    let (rows, cols) = (6, 8);
    let numel = rows * cols;
    let target = [rows * 2, cols / 2];
    let input_host = fill(numel, 211);
    let weights = fill(numel, 212);

    // Analytic gradient via the tape.
    let input = Var::leaf(cuda(&input_host, &[rows, cols]));
    let reshaped = ops::reshape(&input, &target);
    scalar_loss(&reshaped, &weights).backward();
    let analytic = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference: the same weighted sum on the raw tensor op.
    let loss_weights = weights.clone();
    let loss = move |v: &[f32], s: &[usize]| {
        let reshaped = cuda(v, s).reshape(&target);
        weighted_sum(&reshaped, &loss_weights)
    };
    let result = grad_check(
        &input_host,
        &[rows, cols],
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("reshape dX", &result);
}
|
|
|
|
// ---- transpose_3d01 ([a,b,c] -> [b,a,c]) ----
|
|
/// Grad-checks `ops::transpose_3d01` on a `[3, 4, 5]` input.
#[test]
fn transpose_3d01_bwd() {
    require_gpu();
    let (a, b, c) = (3, 4, 5);
    let numel = a * b * c;
    let input_host = fill(numel, 221);
    let weights = fill(numel, 222);

    // Analytic gradient via the tape.
    let input = Var::leaf(cuda(&input_host, &[a, b, c]));
    let swapped = ops::transpose_3d01(&input);
    scalar_loss(&swapped, &weights).backward();
    let analytic = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference: the same weighted sum on the raw tensor op.
    let loss_weights = weights.clone();
    let loss = move |v: &[f32], s: &[usize]| {
        let swapped = cuda(v, s).transpose_3d01();
        weighted_sum(&swapped, &loss_weights)
    };
    let result = grad_check(
        &input_host,
        &[a, b, c],
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("transpose_3d01 dX", &result);
}
|
|
|
|
// ---- transpose_2d ----
|
|
/// Grad-checks `ops::transpose_2d` on a `[5, 7]` input.
#[test]
fn transpose_2d_bwd() {
    require_gpu();
    let (r, c) = (5, 7);
    let numel = r * c;
    let input_host = fill(numel, 231);
    let weights = fill(numel, 232);

    // Analytic gradient via the tape.
    let input = Var::leaf(cuda(&input_host, &[r, c]));
    let flipped = ops::transpose_2d(&input);
    scalar_loss(&flipped, &weights).backward();
    let analytic = input.grad().unwrap().to_device(Device::Cpu);

    // Numeric reference: the same weighted sum on the raw tensor op.
    let loss_weights = weights.clone();
    let loss = move |v: &[f32], s: &[usize]| {
        let flipped = cuda(v, s).transpose_2d();
        weighted_sum(&flipped, &loss_weights)
    };
    let result = grad_check(
        &input_host,
        &[r, c],
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("transpose_2d dX", &result);
}
|
|
|
|
// ---- split_heads + merge_heads round-trip (identity reshuffle of [nh,seq,hd]) ----
|
|
// out = merge_heads(split_heads(x)) must equal x, and its grad must be dOut=W
|
|
// reshuffled identically — i.e. dx grad-checks against the identity composition.
|
|
/// Grad-checks the `split_heads` → `merge_heads` round trip on a `[3, 4, 5]`
/// input.
///
/// The numeric reference is the identity map: the loss closure applies no op
/// to the perturbed input, so the analytic gradient through the round trip
/// must match the gradient of `sum(W ∘ x)` itself.
#[test]
fn split_merge_heads_bwd() {
    require_gpu();
    let (nh, seq, hd) = (3, 4, 5);
    let numel = nh * seq * hd;
    let input_host = fill(numel, 241);
    let weights = fill(numel, 242);

    // Analytic gradient through split followed by merge.
    let input = Var::leaf(cuda(&input_host, &[nh, seq, hd]));
    let split = ops::split_heads(&input);
    let merged = ops::merge_heads(&split); // back to [nh,seq,hd]
    scalar_loss(&merged, &weights).backward();
    let analytic = input.grad().unwrap().to_device(Device::Cpu);

    // forward is identity, so grad-check the identity map.
    let loss_weights = weights.clone();
    let loss = move |v: &[f32], s: &[usize]| {
        let passthrough = cuda(v, s);
        weighted_sum(&passthrough, &loss_weights)
    };
    let result = grad_check(
        &input_host,
        &[nh, seq, hd],
        &loss,
        analytic.as_slice::<f32>(),
        cfg_linear(),
    );
    report("split/merge_heads dX", &result);
}
|