ops.rs wraps each Tensor op as a Var node with its backward closure (forward caches captured by move). swiglu = mul(silu(gate), up); attention is composed (matmul+scale+softmax+matmul), no fused kernel. tests/autograd.rs grad-checks every op via the L=sum(W∘out) template, plus a fan-out grad-accumulation test (dL/dx=4x) and an end-to-end composed-attention grad-check (dQ/dK/dV). Adds xtrain-cuda dev-dep for device selection in tests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
547 lines
17 KiB
Rust
547 lines
17 KiB
Rust
// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
|
|
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
|
|
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
|
|
// through the tape, call backward(), and grad-check each input's .grad() against
|
|
// central finite differences of L.
|
|
//
|
|
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_autodiff::ops;
|
|
use xtrain_autodiff::tape::Var;
|
|
use xtrain_autodiff::{GradCheckConfig, grad_check};
|
|
use xtrain_cuda::device;
|
|
use xtrain_tensor::{Device, Tensor};
|
|
|
|
// Deterministic pseudo-random test data, same generator as the gemm tests.
//
// Returns `n` values centred on zero in [-0.5, 0.5]. The upper end is hit only
// when f32 rounding of the largest 31-bit draw produces exactly 1.0. The same
// `(n, seed)` always yields the same vector.
fn fill(n: usize, seed: u64) -> Vec<f32> {
    // Scramble the seed once so small consecutive seeds start far apart.
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Top 31 bits of the state, scaled into [0, 1], then shifted to zero-mean.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push(unit - 0.5);
    }
    values
}
|
|
|
|
fn require_gpu() {
|
|
assert!(
|
|
device::device_count().expect("device count") > 0,
|
|
"no CUDA device"
|
|
);
|
|
device::set_device(0).unwrap();
|
|
}
|
|
|
|
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
|
|
Tensor::from_slice(data, shape).to_device(Device::Cuda(0))
|
|
}
|
|
|
|
// L = sum(W ∘ out) for fixed weights W over the op output.
|
|
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
|
|
out.to_device(Device::Cpu)
|
|
.as_slice::<f32>()
|
|
.iter()
|
|
.zip(w)
|
|
.map(|(o, w)| o * w)
|
|
.sum()
|
|
}
|
|
|
|
// Tolerances: ops with elementwise/linear forwards (add, mul, scale, bias, rope)
|
|
// are exactly linear in each input, so a large eps just sharpens f32 resolution.
|
|
// Nonlinear ops (rms_norm, silu, softmax, cross_entropy) carry O(eps²) truncation
|
|
// → smaller eps. atol floors near-zero grads.
|
|
fn cfg_linear() -> GradCheckConfig {
|
|
GradCheckConfig {
|
|
eps: 1e-2,
|
|
rel_tol: 2e-2,
|
|
atol: 1e-3,
|
|
}
|
|
}
|
|
fn cfg_nonlinear() -> GradCheckConfig {
|
|
GradCheckConfig {
|
|
eps: 1e-3,
|
|
rel_tol: 3e-2,
|
|
atol: 1e-3,
|
|
}
|
|
}
|
|
|
|
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
|
|
println!(
|
|
"{name}: max_rel_err = {:.3e} (worst num={:.5} ana={:.5} @ {})",
|
|
res.max_rel_err, res.worst_numeric, res.worst_analytic, res.worst_index
|
|
);
|
|
assert!(res.passed, "{name} grad-check failed: {res:?}");
|
|
}
|
|
|
|
// ---- add ----
// out = A + B, so both inputs should receive the upstream gradient W unchanged.
#[test]
fn add_bwd() {
    require_gpu();
    let (m, n) = (8, 6);
    let shape = [m, n];
    let lhs = fill(m * n, 1);
    let rhs = fill(m * n, 2);
    let weights = fill(m * n, 3);

    let a = Var::leaf(cuda(&lhs, &shape));
    let b = Var::leaf(cuda(&rhs, &shape));
    let out = ops::add(&a, &b);
    let loss = scalar_loss(&out, &weights);
    loss.backward();

    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    // Reference loss as a function of A alone (B and W fixed).
    let loss_wrt_a = {
        let (rhs, weights) = (rhs.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let sum = cuda(v, s).add(&cuda(&rhs, &shape));
            weighted_sum(&sum, &weights)
        }
    };
    report(
        "add dA",
        &grad_check(&lhs, &shape, &loss_wrt_a, grad_a.as_slice::<f32>(), cfg_linear()),
    );

    // Reference loss as a function of B alone (A and W fixed).
    let loss_wrt_b = move |v: &[f32], s: &[usize]| {
        let sum = cuda(&lhs, &shape).add(&cuda(v, s));
        weighted_sum(&sum, &weights)
    };
    report(
        "add dB",
        &grad_check(&rhs, &shape, &loss_wrt_b, grad_b.as_slice::<f32>(), cfg_linear()),
    );
}
|
|
|
|
// ---- mul ----
// out = A ∘ B, so the expected grads are dA = W ∘ B and dB = W ∘ A.
#[test]
fn mul_bwd() {
    require_gpu();
    let (m, n) = (8, 6);
    let shape = [m, n];
    let lhs = fill(m * n, 11);
    let rhs = fill(m * n, 22);
    let weights = fill(m * n, 33);

    let a = Var::leaf(cuda(&lhs, &shape));
    let b = Var::leaf(cuda(&rhs, &shape));
    let out = ops::mul(&a, &b);
    scalar_loss(&out, &weights).backward();

    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    // Reference loss in A (B and W fixed).
    let loss_wrt_a = {
        let (rhs, weights) = (rhs.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let prod = cuda(v, s).mul(&cuda(&rhs, &shape));
            weighted_sum(&prod, &weights)
        }
    };
    report(
        "mul dA",
        &grad_check(&lhs, &shape, &loss_wrt_a, grad_a.as_slice::<f32>(), cfg_linear()),
    );

    // Reference loss in B (A and W fixed).
    let loss_wrt_b = move |v: &[f32], s: &[usize]| {
        let prod = cuda(&lhs, &shape).mul(&cuda(v, s));
        weighted_sum(&prod, &weights)
    };
    report(
        "mul dB",
        &grad_check(&rhs, &shape, &loss_wrt_b, grad_b.as_slice::<f32>(), cfg_linear()),
    );
}
|
|
|
|
// ---- add_bias (broadcast) ----
// out[i, j] = X[i, j] + bias[j]. The expected grads are dX = W, and dBias is
// W summed over the broadcast (row) axis.
#[test]
fn add_bias_bwd() {
    require_gpu();
    let (m, n) = (10, 7);
    let x_shape = [m, n];
    let bias_shape = [n];
    let x_host = fill(m * n, 5);
    let bias_host = fill(n, 6);
    let weights = fill(m * n, 7);

    let x = Var::leaf(cuda(&x_host, &x_shape));
    let bias = Var::leaf(cuda(&bias_host, &bias_shape));
    let out = ops::add_bias(&x, &bias);
    scalar_loss(&out, &weights).backward();

    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    let grad_bias = bias.grad().unwrap().to_device(Device::Cpu);

    // Reference loss in X (bias and W fixed).
    let loss_wrt_x = {
        let (bias_host, weights) = (bias_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let biased = cuda(v, s).add_bias(&cuda(&bias_host, &bias_shape));
            weighted_sum(&biased, &weights)
        }
    };
    report(
        "add_bias dX",
        &grad_check(&x_host, &x_shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );

    // Reference loss in bias (X and W fixed).
    let loss_wrt_bias = move |v: &[f32], s: &[usize]| {
        let biased = cuda(&x_host, &x_shape).add_bias(&cuda(v, s));
        weighted_sum(&biased, &weights)
    };
    report(
        "add_bias dBias",
        &grad_check(
            &bias_host,
            &bias_shape,
            &loss_wrt_bias,
            grad_bias.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
|
|
|
|
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
// out = A·B with A: [m, k] and B: [k, n]. The expected grads are dA = W·Bᵀ and
// dB = Aᵀ·W. The loss is linear in each operand.
#[test]
fn matmul_bwd() {
    require_gpu();
    let (m, k, n) = (6, 5, 4);
    let a_shape = [m, k];
    let b_shape = [k, n];
    let a_host = fill(m * k, 41);
    let b_host = fill(k * n, 42);
    let weights = fill(m * n, 43);

    let a = Var::leaf(cuda(&a_host, &a_shape));
    let b = Var::leaf(cuda(&b_host, &b_shape));
    let out = ops::matmul(&a, &b);
    scalar_loss(&out, &weights).backward();

    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    // Reference loss in A (B and W fixed).
    let loss_wrt_a = {
        let (b_host, weights) = (b_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let prod = cuda(v, s).matmul(&cuda(&b_host, &b_shape));
            weighted_sum(&prod, &weights)
        }
    };
    report(
        "matmul dA",
        &grad_check(&a_host, &a_shape, &loss_wrt_a, grad_a.as_slice::<f32>(), cfg_linear()),
    );

    // Reference loss in B (A and W fixed).
    let loss_wrt_b = move |v: &[f32], s: &[usize]| {
        let prod = cuda(&a_host, &a_shape).matmul(&cuda(v, s));
        weighted_sum(&prod, &weights)
    };
    report(
        "matmul dB",
        &grad_check(&b_host, &b_shape, &loss_wrt_b, grad_b.as_slice::<f32>(), cfg_linear()),
    );
}
|
|
|
|
// ---- rms_norm ----
// Checks dX and dGamma of ops::rms_norm against finite differences of
// Tensor::rms_norm, using the nonlinear (small-eps) tolerances.
#[test]
fn rms_norm_bwd() {
    require_gpu();
    let (rows, cols) = (5, 16);
    let eps = 1e-5;
    let x_shape = [rows, cols];
    let g_shape = [cols];
    let x_host = fill(rows * cols, 51);
    let gamma_host: Vec<f32> = fill(cols, 52).into_iter().map(|g| g + 1.0).collect(); // gamma ~1
    let weights = fill(rows * cols, 53);

    let x = Var::leaf(cuda(&x_host, &x_shape));
    let gamma = Var::leaf(cuda(&gamma_host, &g_shape));
    let out = ops::rms_norm(&x, &gamma, eps);
    scalar_loss(&out, &weights).backward();

    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    let grad_gamma = gamma.grad().unwrap().to_device(Device::Cpu);

    // Reference loss in X (gamma and W fixed).
    let loss_wrt_x = {
        let (gamma_host, weights) = (gamma_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let (normed, _) = cuda(v, s).rms_norm(&cuda(&gamma_host, &g_shape), eps);
            weighted_sum(&normed, &weights)
        }
    };
    report(
        "rms_norm dX",
        &grad_check(
            &x_host,
            &x_shape,
            &loss_wrt_x,
            grad_x.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );

    // Reference loss in gamma (X and W fixed).
    let loss_wrt_gamma = move |v: &[f32], s: &[usize]| {
        let (normed, _) = cuda(&x_host, &x_shape).rms_norm(&cuda(v, s), eps);
        weighted_sum(&normed, &weights)
    };
    report(
        "rms_norm dGamma",
        &grad_check(
            &gamma_host,
            &g_shape,
            &loss_wrt_gamma,
            grad_gamma.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
}
|
|
|
|
// ---- silu ----
// Single-input nonlinear op: dX from the tape vs finite differences of
// Tensor::silu.
#[test]
fn silu_bwd() {
    require_gpu();
    let n = 64;
    let shape = [n];
    let x_host = fill(n, 61);
    let weights = fill(n, 62);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::silu(&x);
    scalar_loss(&out, &weights).backward();

    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        let activated = cuda(v, s).silu();
        weighted_sum(&activated, &weights)
    };
    report(
        "silu dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_nonlinear()),
    );
}
|
|
|
|
// ---- swiglu (composed: silu(gate) ∘ up) ----
// The gate path goes through silu, so it uses the nonlinear tolerances. The
// output is linear in `up`, so that side uses the linear tolerances.
#[test]
fn swiglu_bwd() {
    require_gpu();
    let n = 48;
    let shape = [n];
    let gate_host = fill(n, 71);
    let up_host = fill(n, 72);
    let weights = fill(n, 73);

    let gate = Var::leaf(cuda(&gate_host, &shape));
    let up = Var::leaf(cuda(&up_host, &shape));
    let out = ops::swiglu(&gate, &up);
    scalar_loss(&out, &weights).backward();

    let grad_gate = gate.grad().unwrap().to_device(Device::Cpu);
    let grad_up = up.grad().unwrap().to_device(Device::Cpu);

    // Reference loss in gate (up and W fixed).
    let loss_wrt_gate = {
        let (up_host, weights) = (up_host.clone(), weights.clone());
        move |v: &[f32], s: &[usize]| {
            let gated = cuda(v, s).silu().mul(&cuda(&up_host, &shape));
            weighted_sum(&gated, &weights)
        }
    };
    report(
        "swiglu dGate",
        &grad_check(
            &gate_host,
            &shape,
            &loss_wrt_gate,
            grad_gate.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );

    // Reference loss in up (gate and W fixed).
    let loss_wrt_up = move |v: &[f32], s: &[usize]| {
        let gated = cuda(&gate_host, &shape).silu().mul(&cuda(v, s));
        weighted_sum(&gated, &weights)
    };
    report(
        "swiglu dUp",
        &grad_check(&up_host, &shape, &loss_wrt_up, grad_up.as_slice::<f32>(), cfg_linear()),
    );
}
|
|
|
|
// ---- rope ----
// Rotary embedding on a [tokens, heads, head_dim] tensor. The output is linear
// in x, so the linear tolerances apply.
#[test]
fn rope_bwd() {
    require_gpu();
    let (tokens, heads, head_dim) = (4, 2, 8);
    let shape = [tokens, heads, head_dim];
    let n = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_host = fill(n, 81);
    let weights = fill(n, 82);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::rope(&x, theta);
    scalar_loss(&out, &weights).backward();

    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        let rotated = cuda(v, s).rope(theta);
        weighted_sum(&rotated, &weights)
    };
    report(
        "rope dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );
}
|
|
|
|
// ---- softmax ----
// Softmax on a [rows, cols] tensor: dX from the tape vs finite differences of
// Tensor::softmax, using the nonlinear tolerances.
#[test]
fn softmax_bwd() {
    require_gpu();
    let (rows, cols) = (4, 10);
    let shape = [rows, cols];
    let x_host = fill(rows * cols, 91);
    let weights = fill(rows * cols, 92);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::softmax(&x);
    scalar_loss(&out, &weights).backward();

    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        let probs = cuda(v, s).softmax();
        weighted_sum(&probs, &weights)
    };
    report(
        "softmax dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_nonlinear()),
    );
}
|
|
|
|
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
// cross_entropy already produces the scalar loss (mean NLL over rows). The
// reference is therefore the mean of the per-row losses, with no W weighting.
#[test]
fn cross_entropy_bwd() {
    require_gpu();
    let (rows, cols) = (5, 8);
    let shape = [rows, cols];
    let logits_host = fill(rows * cols, 101);
    // One deterministic class per row: (3·r) mod cols.
    let targets: Vec<i32> = (0..rows).map(|r| (r * 3 % cols) as i32).collect();
    let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    let logits = Var::leaf(cuda(&logits_host, &shape));
    let loss = ops::cross_entropy(&logits, &target);
    loss.backward();

    let grad_logits = logits.grad().unwrap().to_device(Device::Cpu);

    let mean_nll = move |v: &[f32], s: &[usize]| {
        let t = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
        let (_, per_row) = cuda(v, s).cross_entropy(&t);
        let host = per_row.to_device(Device::Cpu);
        let total = host.as_slice::<f32>().iter().sum::<f32>();
        total / rows as f32
    };
    report(
        "cross_entropy dX",
        &grad_check(
            &logits_host,
            &shape,
            &mean_nll,
            grad_logits.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
}
|
|
|
|
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
// y = x*x + x*x via two separate mul nodes on the same Var x, so dL/dx must
// accumulate every use of x:
// - each mul(x, x) contributes dOut·x from both operand slots (2x per node);
// - the two nodes together contribute 4x.
// With W = 1 the upstream grad is dOut = 1, so the analytic grad must equal 4x
// up to f32 rounding. That is asserted in closed form, which pinpoints a
// dropped or double-counted branch (2x, 6x, 8x, ...) independent of the
// finite-difference tolerances. The loss 2x² is quadratic, so the central
// difference is exact and the large-eps `cfg_linear` config is appropriate.
#[test]
fn fanout_grad_accumulation() {
    require_gpu();
    let n = 12;
    let x_h = fill(n, 111);
    let w = vec![1.0f32; n];

    let x = Var::leaf(cuda(&x_h, &[n]));
    let sq1 = ops::mul(&x, &x); // x∘x (x consumed twice within one node)
    let sq2 = ops::mul(&x, &x); // x∘x (x consumed again across nodes)
    let out = ops::add(&sq1, &sq2); // 2x²
    scalar_loss(&out, &w).backward();

    let dx = x.grad().unwrap().to_device(Device::Cpu);
    let dx_h = dx.as_slice::<f32>();

    // Closed-form check of the accumulated gradient: dL/dx = 4x elementwise.
    assert_eq!(dx_h.len(), n, "fanout dX has {} elements, expected {n}", dx_h.len());
    for (i, (&got, &xi)) in dx_h.iter().zip(&x_h).enumerate() {
        let want = 4.0 * xi;
        assert!(
            (got - want).abs() <= 1e-5 * (1.0 + want.abs()),
            "fanout dX[{i}] = {got}, expected 4x = {want}"
        );
    }

    // Finite-difference check of the same gradient through the plain Tensor ops.
    let lx = move |v: &[f32], s: &[usize]| {
        let t = cuda(v, s);
        let o = t.mul(&t).add(&t.mul(&t));
        weighted_sum(&o, &w)
    };
    report(
        "fanout dX",
        &grad_check(&x_h, &[n], &lx, dx_h, cfg_linear()),
    );
}
|
|
|
|
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
// Single head, single batch, no fused kernel. The backward is whatever the
// transpose, matmul, scale and softmax nodes compose to. The finite-difference
// reference reruns the same forward with plain Tensor ops (no tape). The output
// is linear in V (linear tolerances) but nonlinear in Q and K.
#[test]
fn attention_composed_bwd() {
    require_gpu();
    let (s, d) = (5, 6); // seq_len, head_dim
    let scale = 1.0 / (d as f32).sqrt();
    let q_h = fill(s * d, 121);
    let k_h = fill(s * d, 122);
    let v_h = fill(s * d, 123);
    let w = fill(s * d, 124); // weights over the [s,d] attention output

    let q = Var::leaf(cuda(&q_h, &[s, d]));
    let k = Var::leaf(cuda(&k_h, &[s, d]));
    let v = Var::leaf(cuda(&v_h, &[s, d]));
    let out = {
        let k_t = transpose_var(&k); // [d,s] (manual transpose node)
        let scores = ops::scale(&ops::matmul(&q, &k_t), scale); // [s,s]
        let probs = ops::softmax(&scores);
        ops::matmul(&probs, &v) // [s,d]
    };
    scalar_loss(&out, &w).backward();

    let [dq, dk, dv] = [&q, &k, &v].map(|leaf| leaf.grad().unwrap().to_device(Device::Cpu));

    // Tape-free reference forward, parameterised by the host Q, K, V buffers.
    let reference = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let shape = [s, d];
        let probs = cuda(qh, &shape)
            .matmul(&cuda(kh, &shape).transpose_2d())
            .scale(scale)
            .softmax();
        weighted_sum(&probs.matmul(&cuda(vh, &shape)), &w)
    };

    let loss_q = {
        let (k_h, v_h, f) = (k_h.clone(), v_h.clone(), reference.clone());
        move |x: &[f32], _: &[usize]| f(x, &k_h, &v_h)
    };
    report(
        "attn dQ",
        &grad_check(&q_h, &[s, d], &loss_q, dq.as_slice::<f32>(), cfg_nonlinear()),
    );

    let loss_k = {
        let (q_h, v_h, f) = (q_h.clone(), v_h.clone(), reference.clone());
        move |x: &[f32], _: &[usize]| f(&q_h, x, &v_h)
    };
    report(
        "attn dK",
        &grad_check(&k_h, &[s, d], &loss_k, dk.as_slice::<f32>(), cfg_nonlinear()),
    );

    let loss_v = {
        let (q_h, k_h) = (q_h.clone(), k_h.clone());
        move |x: &[f32], _: &[usize]| reference(&q_h, &k_h, x)
    };
    report(
        "attn dV",
        &grad_check(&v_h, &[s, d], &loss_v, dv.as_slice::<f32>(), cfg_linear()),
    );
}
|
|
|
|
// --- test helpers ---
|
|
|
|
// Scalar loss node L = sum(W ∘ out): wraps a fixed-weight Var and reduces. We
|
|
// implement it as: elementwise mul by a constant-W leaf, then sum-to-scalar.
|
|
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
|
|
let wt = Var::leaf(cuda(w, out.value().shape()));
|
|
let prod = ops::mul(out, &wt);
|
|
sum_all(&prod)
|
|
}
|
|
|
|
// Sum-to-scalar node: out = sum(x). Backward broadcasts the scalar grad to a
|
|
// ones-shaped tensor over x. Implemented here (test-local) since the engine's
|
|
// op set doesn't include a generic reduction; cross_entropy is the only loss op.
|
|
fn sum_all(x: &Var) -> Var {
|
|
let xv = x.value();
|
|
let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
|
|
let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
|
|
let shape: Vec<usize> = xv.shape().to_vec();
|
|
Var::from_op(
|
|
scalar,
|
|
vec![x.clone()],
|
|
Box::new(move |d, parents| {
|
|
// d is [1]; broadcast d to a same-shape tensor over the input.
|
|
let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
|
let ones = vec![dval; shape.iter().product()];
|
|
let g = Tensor::from_slice(&ones, &shape).to_device(Device::Cuda(0));
|
|
Var::push_grad(&parents[0], g);
|
|
}),
|
|
)
|
|
}
|
|
|
|
// Manual transpose node for the composed-attention test (the engine has no
|
|
// transpose op; xserv does the equivalent host-side reshape around RoPE).
|
|
fn transpose_var(x: &Var) -> Var {
|
|
let xt = x.value().transpose_2d();
|
|
Var::from_op(
|
|
xt,
|
|
vec![x.clone()],
|
|
Box::new(|d, parents| {
|
|
Var::push_grad(&parents[0], d.transpose_2d());
|
|
}),
|
|
)
|
|
}
|