Files
xtrain/crates/xtrain-autodiff/tests/autograd.rs
Gahow Wang 0e82b2438e test: M2d — ragged-forward + batched-op equivalence gates + throughput bench
Two exact correctness gates (composed = the end-to-end batched GRPO step == looped):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
  ragged sequences == per-sequence single-seq forward on the real rows. fp32
  max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is
  free under causal".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
  Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32).

bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on
dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:03:09 +08:00

1271 lines
44 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
// through the tape, call backward(), and grad-check each input's .grad() against
// central finite differences of L.
//
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_autodiff::{GradCheckConfig, grad_check};
use xtrain_cuda::device;
use xtrain_tensor::{Device, Tensor};
// Deterministic pseudo-random fill, intended to match the gemm tests.
//
// The seed is scrambled once. Each element then advances a 64-bit wrapping LCG,
// keeps the top 31 bits, scales them by 2^-31 in f32 and subtracts 0.5.
// Rounding the 31-bit integer to f32 can round up to exactly 2^31, so the output
// lies in [-0.5, 0.5]. The upper end is reachable only through that rounding.
fn fill(n: usize, seed: u64) -> Vec<f32> {
    const SCRAMBLE_MUL: u64 = 2862933555777941757;
    const SCRAMBLE_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;
    let mut lcg = seed.wrapping_mul(SCRAMBLE_MUL).wrapping_add(SCRAMBLE_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        values.push(((lcg >> 33) as f32 / denom) - 0.5);
    }
    values
}
// Panics ("no CUDA device") unless at least one CUDA device is reported, then
// selects device 0. Panics if querying the count or selecting the device fails.
fn require_gpu() {
    let visible = device::device_count().expect("device count");
    assert!(visible > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
// Builds a host tensor of the given shape from `data` and moves it to CUDA device 0.
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
    let host = Tensor::from_slice(data, shape);
    host.to_device(Device::Cuda(0))
}
// Host-side reference loss L = Σ_i W_i · out_i for fixed weights W.
//
// Copies `out` to the CPU and takes an f32 dot product with `w`, accumulating in
// element order. If the lengths differ, `zip` stops at the shorter one.
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
    let host = out.to_device(Device::Cpu);
    let values = host.as_slice::<f32>();
    values.iter().zip(w.iter()).map(|(&o, &wi)| o * wi).sum()
}
// Grad-check tolerances.
//
// Ops whose forward is linear in each input taken on its own (add, mul, scale,
// bias, rope) have no finite-difference truncation error, so a large eps only
// sharpens the f32 resolution of the central difference. Nonlinear ops
// (rms_norm, silu, softmax, cross_entropy) carry O(eps²) truncation and use a
// smaller eps. In both configs `atol` is the absolute floor for near-zero grads.
fn cfg_linear() -> GradCheckConfig {
    let (eps, rel_tol, atol) = (1e-2, 2e-2, 1e-3);
    GradCheckConfig { eps, rel_tol, atol }
}
// Tolerances for nonlinear ops: smaller eps to limit O(eps²) truncation error,
// and a slightly looser relative tolerance than the linear config.
fn cfg_nonlinear() -> GradCheckConfig {
    let eps = 1e-3;
    let rel_tol = 3e-2;
    let atol = 1e-3;
    GradCheckConfig { eps, rel_tol, atol }
}
// Prints a one-line summary of a grad-check result, then panics with the full
// (Debug-formatted) result if the check did not pass.
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
    let rel = &res.max_rel_err;
    let num = &res.worst_numeric;
    let ana = &res.worst_analytic;
    let at = &res.worst_index;
    println!("{name}: max_rel_err = {rel:.3e} (worst num={num:.5} ana={ana:.5} @ {at})");
    if !res.passed {
        panic!("{name} grad-check failed: {res:?}");
    }
}
// ---- add ----
// out = a + b, so with L = sum(W ∘ out) both analytic grads should equal W.
#[test]
fn add_bwd() {
    require_gpu();
    let (rows, cols) = (8, 6);
    let shape = [rows, cols];
    let len = rows * cols;
    let (lhs_host, rhs_host, weights) = (fill(len, 1), fill(len, 2), fill(len, 3));

    // Analytic grads from the tape.
    let lhs = Var::leaf(cuda(&lhs_host, &shape));
    let rhs = Var::leaf(cuda(&rhs_host, &shape));
    let out = ops::add(&lhs, &rhs);
    scalar_loss(&out, &weights).backward();
    let grad_lhs = lhs.grad().unwrap().to_device(Device::Cpu);
    let grad_rhs = rhs.grad().unwrap().to_device(Device::Cpu);

    // Numeric grads: perturb one operand while the other stays fixed.
    let (fixed_rhs, wts) = (rhs_host.clone(), weights.clone());
    let loss_wrt_lhs = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).add(&cuda(&fixed_rhs, &shape)), &wts)
    };
    report(
        "add dA",
        &grad_check(&lhs_host, &shape, &loss_wrt_lhs, grad_lhs.as_slice::<f32>(), cfg_linear()),
    );

    let (fixed_lhs, wts) = (lhs_host.clone(), weights.clone());
    let loss_wrt_rhs = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&fixed_lhs, &shape).add(&cuda(v, s)), &wts)
    };
    report(
        "add dB",
        &grad_check(&rhs_host, &shape, &loss_wrt_rhs, grad_rhs.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- mul ----
// out = a ∘ b: expected dL/da = W ∘ b and dL/db = W ∘ a. L is linear in each
// operand separately, so the linear tolerances apply to both checks.
#[test]
fn mul_bwd() {
    require_gpu();
    let (rows, cols) = (8, 6);
    let shape = [rows, cols];
    let len = rows * cols;
    let lhs_host = fill(len, 11);
    let rhs_host = fill(len, 22);
    let weights = fill(len, 33);

    let lhs = Var::leaf(cuda(&lhs_host, &shape));
    let rhs = Var::leaf(cuda(&rhs_host, &shape));
    let out = ops::mul(&lhs, &rhs);
    scalar_loss(&out, &weights).backward();
    let grad_lhs = lhs.grad().unwrap().to_device(Device::Cpu);
    let grad_rhs = rhs.grad().unwrap().to_device(Device::Cpu);

    let (fixed_rhs, wts) = (rhs_host.clone(), weights.clone());
    let loss_wrt_lhs = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).mul(&cuda(&fixed_rhs, &shape)), &wts)
    };
    report(
        "mul dA",
        &grad_check(&lhs_host, &shape, &loss_wrt_lhs, grad_lhs.as_slice::<f32>(), cfg_linear()),
    );

    let (fixed_lhs, wts) = (lhs_host.clone(), weights.clone());
    let loss_wrt_rhs = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&fixed_lhs, &shape).mul(&cuda(v, s)), &wts)
    };
    report(
        "mul dB",
        &grad_check(&rhs_host, &shape, &loss_wrt_rhs, grad_rhs.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- add_bias (broadcast) ----
// out[i, j] = x[i, j] + bias[j]: expected dL/dx = W and dL/dbias[j] = Σ_i W[i, j],
// i.e. the backward must reduce over the broadcast rows.
#[test]
fn add_bias_bwd() {
    require_gpu();
    let (rows, cols) = (10, 7);
    let x_shape = [rows, cols];
    let bias_shape = [cols];
    let x_host = fill(rows * cols, 5);
    let bias_host = fill(cols, 6);
    let weights = fill(rows * cols, 7);

    let x = Var::leaf(cuda(&x_host, &x_shape));
    let bias = Var::leaf(cuda(&bias_host, &bias_shape));
    let out = ops::add_bias(&x, &bias);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    let grad_bias = bias.grad().unwrap().to_device(Device::Cpu);

    let (fixed_bias, wts) = (bias_host.clone(), weights.clone());
    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).add_bias(&cuda(&fixed_bias, &bias_shape)), &wts)
    };
    report(
        "add_bias dX",
        &grad_check(&x_host, &x_shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );

    let (fixed_x, wts) = (x_host.clone(), weights.clone());
    let loss_wrt_bias = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&fixed_x, &x_shape).add_bias(&cuda(v, s)), &wts)
    };
    report(
        "add_bias dBias",
        &grad_check(
            &bias_host,
            &bias_shape,
            &loss_wrt_bias,
            grad_bias.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
// A:[m,k] · B:[k,n] → [m,n]. Expected dA = W·Bᵀ and dB = Aᵀ·W. L is linear in
// each operand separately, hence the linear tolerances.
#[test]
fn matmul_bwd() {
    require_gpu();
    let (m, k, n) = (6, 5, 4);
    let (a_shape, b_shape) = ([m, k], [k, n]);
    let a_host = fill(m * k, 41);
    let b_host = fill(k * n, 42);
    let weights = fill(m * n, 43);

    let a = Var::leaf(cuda(&a_host, &a_shape));
    let b = Var::leaf(cuda(&b_host, &b_shape));
    let out = ops::matmul(&a, &b);
    scalar_loss(&out, &weights).backward();
    let grad_a = a.grad().unwrap().to_device(Device::Cpu);
    let grad_b = b.grad().unwrap().to_device(Device::Cpu);

    let (fixed_b, wts) = (b_host.clone(), weights.clone());
    let loss_wrt_a = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).matmul(&cuda(&fixed_b, &b_shape)), &wts)
    };
    report(
        "matmul dA",
        &grad_check(&a_host, &a_shape, &loss_wrt_a, grad_a.as_slice::<f32>(), cfg_linear()),
    );

    let (fixed_a, wts) = (a_host.clone(), weights.clone());
    let loss_wrt_b = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&fixed_a, &a_shape).matmul(&cuda(v, s)), &wts)
    };
    report(
        "matmul dB",
        &grad_check(&b_host, &b_shape, &loss_wrt_b, grad_b.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- rms_norm ----
// Grad-checks both dX and dGamma of ops::rms_norm against the tensor-level
// rms_norm (whose second return value is ignored here).
#[test]
fn rms_norm_bwd() {
    require_gpu();
    let (rows, cols) = (5, 16);
    let x_shape = [rows, cols];
    let gamma_shape = [cols];
    let norm_eps = 1e-5;
    let x_host = fill(rows * cols, 51);
    // Shift the centred fill by +1 so gamma sits around 1.
    let gamma_host: Vec<f32> = fill(cols, 52).into_iter().map(|g| g + 1.0).collect();
    let weights = fill(rows * cols, 53);

    let x = Var::leaf(cuda(&x_host, &x_shape));
    let gamma = Var::leaf(cuda(&gamma_host, &gamma_shape));
    let out = ops::rms_norm(&x, &gamma, norm_eps);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);
    let grad_gamma = gamma.grad().unwrap().to_device(Device::Cpu);

    let (fixed_gamma, wts) = (gamma_host.clone(), weights.clone());
    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        let (normed, _) = cuda(v, s).rms_norm(&cuda(&fixed_gamma, &gamma_shape), norm_eps);
        weighted_sum(&normed, &wts)
    };
    report(
        "rms_norm dX",
        &grad_check(&x_host, &x_shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_nonlinear()),
    );

    let (fixed_x, wts) = (x_host.clone(), weights.clone());
    let loss_wrt_gamma = move |v: &[f32], s: &[usize]| {
        let (normed, _) = cuda(&fixed_x, &x_shape).rms_norm(&cuda(v, s), norm_eps);
        weighted_sum(&normed, &wts)
    };
    report(
        "rms_norm dGamma",
        &grad_check(
            &gamma_host,
            &gamma_shape,
            &loss_wrt_gamma,
            grad_gamma.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
}
// ---- silu ----
// Single-input elementwise nonlinearity: dX against the tensor-level silu.
#[test]
fn silu_bwd() {
    require_gpu();
    let len = 64;
    let shape = [len];
    let x_host = fill(len, 61);
    let weights = fill(len, 62);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::silu(&x);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    // `weights` is not needed after this point, so it moves into the closure.
    let loss_wrt_x = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).silu(), &weights);
    report(
        "silu dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_nonlinear()),
    );
}
// ---- swiglu (composed: silu(gate) ∘ up) ----
// The reference recomputes silu(gate) ∘ up with tensor ops. The loss is
// nonlinear in gate but linear in up, hence the two different tolerance configs.
#[test]
fn swiglu_bwd() {
    require_gpu();
    let len = 48;
    let shape = [len];
    let gate_host = fill(len, 71);
    let up_host = fill(len, 72);
    let weights = fill(len, 73);

    let gate = Var::leaf(cuda(&gate_host, &shape));
    let up = Var::leaf(cuda(&up_host, &shape));
    let out = ops::swiglu(&gate, &up);
    scalar_loss(&out, &weights).backward();
    let grad_gate = gate.grad().unwrap().to_device(Device::Cpu);
    let grad_up = up.grad().unwrap().to_device(Device::Cpu);

    let (fixed_up, wts) = (up_host.clone(), weights.clone());
    let loss_wrt_gate = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(v, s).silu().mul(&cuda(&fixed_up, &shape)), &wts)
    };
    report(
        "swiglu dGate",
        &grad_check(
            &gate_host,
            &shape,
            &loss_wrt_gate,
            grad_gate.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );

    let (fixed_gate, wts) = (gate_host.clone(), weights.clone());
    let loss_wrt_up = move |v: &[f32], s: &[usize]| {
        weighted_sum(&cuda(&fixed_gate, &shape).silu().mul(&cuda(v, s)), &wts)
    };
    report(
        "swiglu dUp",
        &grad_check(&up_host, &shape, &loss_wrt_up, grad_up.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- rope ----
// Single sequence: the period passed to rope equals `tokens`, so positions run
// 0..tokens with no reset. RoPE is linear in x for fixed positions.
#[test]
fn rope_bwd() {
    require_gpu();
    let (tokens, heads, head_dim) = (4, 2, 8);
    let shape = [tokens, heads, head_dim];
    let len = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_host = fill(len, 81);
    let weights = fill(len, 82);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::rope(&x, theta, tokens);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x =
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta, tokens), &weights);
    report(
        "rope dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- rope batched (per-sequence position = row % period) ----
// B sequences of length S are laid end to end as tokens = B*S rows with
// period = S. Every sequence after the first re-uses positions 0..S, so the
// kernel's `tok % period` must reset RoPE at each sequence boundary.
#[test]
fn rope_batched_bwd() {
    require_gpu();
    let (batch, seq_len, heads, head_dim) = (3, 4, 2, 8);
    let tokens = batch * seq_len;
    let shape = [tokens, heads, head_dim];
    let len = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_host = fill(len, 83);
    let weights = fill(len, 84);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::rope(&x, theta, seq_len);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x =
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).rope(theta, seq_len), &weights);
    report(
        "rope batched dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- softmax ----
// Softmax over a [rows, cols] input (presumably along the last axis — TODO
// confirm); nonlinear, so the smaller-eps tolerances are used.
#[test]
fn softmax_bwd() {
    require_gpu();
    let (rows, cols) = (4, 10);
    let shape = [rows, cols];
    let x_host = fill(rows * cols, 91);
    let weights = fill(rows * cols, 92);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::softmax(&x);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x = move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).softmax(), &weights);
    report(
        "softmax dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_nonlinear()),
    );
}
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
// ops::cross_entropy already yields a scalar (mean NLL), so there is no W
// weighting. The numeric reference sums the tensor-level per-row NLL and divides
// by `rows`.
#[test]
fn cross_entropy_bwd() {
    require_gpu();
    let (rows, cols) = (5, 8);
    let shape = [rows, cols];
    let logits_host = fill(rows * cols, 101);
    let labels: Vec<i32> = (0..rows).map(|r| (r * 3 % cols) as i32).collect();
    let label_dev = Tensor::from_slice(&labels, &[rows]).to_device(Device::Cuda(0));

    let logits = Var::leaf(cuda(&logits_host, &shape));
    ops::cross_entropy(&logits, &label_dev).backward();
    let grad_logits = logits.grad().unwrap().to_device(Device::Cpu);

    let mean_nll = move |v: &[f32], s: &[usize]| {
        let t = Tensor::from_slice(&labels, &[rows]).to_device(Device::Cuda(0));
        let (_, per_row) = cuda(v, s).cross_entropy(&t);
        let total: f32 = per_row.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
        total / rows as f32
    };
    report(
        "cross_entropy dX",
        &grad_check(
            &logits_host,
            &shape,
            &mean_nll,
            grad_logits.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
}
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
// out = x∘x + x∘x is built from two separate mul nodes. Both read the same Var x,
// and each reads it twice. With W = 1, L = Σ 2x², so the accumulated analytic
// grad must be 4x (all four uses of x summed). Central differences are exact
// for a quadratic, so the linear tolerances apply.
#[test]
fn fanout_grad_accumulation() {
    require_gpu();
    let len = 12;
    let shape = [len];
    let x_host = fill(len, 111);
    let ones = vec![1.0f32; len];

    let x = Var::leaf(cuda(&x_host, &shape));
    let first = ops::mul(&x, &x); // x consumed twice within one node
    let second = ops::mul(&x, &x); // x consumed again by another node
    let out = ops::add(&first, &second); // 2x²
    scalar_loss(&out, &ones).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        let t = cuda(v, s);
        weighted_sum(&t.mul(&t).add(&t.mul(&t)), &ones)
    };
    report(
        "fanout dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
// Single head, single batch, no mask applied here. There is no dedicated attention
// kernel: the backward falls out of the matmul, scale, softmax and test-local
// `transpose_var` nodes recorded on the tape.
#[test]
fn attention_composed_bwd() {
    require_gpu();
    let (s, d) = (5, 6); // seq_len, head_dim
    let scale = 1.0 / (d as f32).sqrt();
    let q_h = fill(s * d, 121);
    let k_h = fill(s * d, 122);
    let v_h = fill(s * d, 123);
    let w = fill(s * d, 124); // weights over the [s,d] attention output
    // Tape-side forward built only from Var ops, so backward() reaches Q, K and V.
    let attn = |q: &Var, k: &Var, v: &Var| -> Var {
        let kt = transpose_var(k); // [d,s] (manual transpose node)
        let scores = ops::scale(&ops::matmul(q, &kt), scale); // [s,s]
        let probs = ops::softmax(&scores);
        ops::matmul(&probs, v) // [s,d]
    };
    let q = Var::leaf(cuda(&q_h, &[s, d]));
    let k = Var::leaf(cuda(&k_h, &[s, d]));
    let v = Var::leaf(cuda(&v_h, &[s, d]));
    let out = attn(&q, &k, &v);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Re-run the same forward on plain tensors (no tape) as the finite-difference
    // reference. `w` is MOVED into this closure, so the closure must stay after
    // the scalar_loss call above that borrows `w`.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[s, d]);
        let kv = cuda(kh, &[s, d]);
        let vv = cuda(vh, &[s, d]);
        let scores = qv.matmul(&kv.transpose_2d()).scale(scale);
        let probs = scores.softmax();
        weighted_sum(&probs.matmul(&vv), &w)
    };
    // Each input gets its own clone of `fwd` plus copies of the two inputs held fixed.
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "attn dQ",
        &grad_check(&q_h, &[s, d], &lq, dq.as_slice::<f32>(), cfg_nonlinear()),
    );
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "attn dK",
        &grad_check(&k_h, &[s, d], &lk, dk.as_slice::<f32>(), cfg_nonlinear()),
    );
    // The output is linear in V (probs · V), so dV uses the linear tolerances.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "attn dV",
        &grad_check(&v_h, &[s, d], &lv, dv.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- transpose_4d12 ([a,b,c,d] -> [a,c,b,d]) ----
// A pure permutation of elements: the backward must route every grad element back
// to its source position. The map is linear, hence the linear tolerances.
#[test]
fn transpose_4d12_bwd() {
    require_gpu();
    let dims: [usize; 4] = [2, 3, 4, 5];
    let len: usize = dims.iter().product();
    let x_host = fill(len, 131);
    let weights = fill(len, 132);

    let x = Var::leaf(cuda(&x_host, &dims));
    let out = ops::transpose_4d12(&x);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x =
        move |v: &[f32], s: &[usize]| weighted_sum(&cuda(v, s).transpose_4d12(), &weights);
    report(
        "transpose_4d12 dX",
        &grad_check(&x_host, &dims, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- fused batched causal attention (the T10 op) ----
// q,k,v: [bh, seq, hd]. Grad-check dq/dk/dv against finite-diff of L = sum(W∘out).
// bh = 2 (e.g. batch 1 × 2 heads, or 2 sequences × 1 head) exercises the batched
// GEMM stride; the causal mask is applied inside the op.
#[test]
fn attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 141);
    let k_h = fill(n, 142);
    let v_h = fill(n, 143);
    let w = fill(n, 144);
    let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
    let out = ops::attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Reference forward via the tensor-level attention (its second return value is
    // unused here). `w` is moved in, so this closure must follow scalar_loss above.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[bh, seq, hd]);
        let kv = cuda(kh, &[bh, seq, hd]);
        let vv = cuda(vh, &[bh, seq, hd]);
        let (o, _) = qv.attention(&kv, &vv, scale);
        weighted_sum(&o, &w)
    };
    // One closure per input: a clone of `fwd` plus copies of the two fixed inputs.
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "attn(batched) dQ",
        &grad_check(
            &q_h,
            &[bh, seq, hd],
            &lq,
            dq.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "attn(batched) dK",
        &grad_check(
            &k_h,
            &[bh, seq, hd],
            &lk,
            dk.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
    // dV: the output is linear in V, so the linear tolerances apply.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "attn(batched) dV",
        &grad_check(
            &v_h,
            &[bh, seq, hd],
            &lv,
            dv.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure + dimensions as attention_batched_bwd (bh=2,seq=5,hd=6), but
// exercises ops::flash_attention. Grad-check dq/dk/dv against finite-diff of
// L=sum(W∘out). This is the SINGLE-tile regime (seq<FA_TILE=32), matching the
// trusted composed grad-check's clean near-zero behavior; the MULTI-tile online-
// softmax path (seq>FA_TILE) is validated against the already-grad-checked
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40) — sharper than
// finite-diff, which is unreliable on the near-zero grad elements a long softmax
// produces.
#[test]
fn flash_attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    // Scale Q/K up so the softmax is non-uniform (sharper attention) → the dQ/dK
    // gradients are well-conditioned, not the near-zero saddle values a uniform
    // softmax produces (those make central finite-diff give spurious 0.0 / sign
    // flips that aren't backward bugs — cf. flash_bwd_matches_composed_bwd).
    let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
    let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
    let v_h = fill(n, 243);
    let w = fill(n, 244);
    let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
    let out = ops::flash_attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Reference forward via the tensor-level flash_attention (second return value
    // unused). `w` is moved in, so this closure must follow scalar_loss above.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[bh, seq, hd]);
        let kv = cuda(kh, &[bh, seq, hd]);
        let vv = cuda(vh, &[bh, seq, hd]);
        let (o, _) = qv.flash_attention(&kv, &vv, scale);
        weighted_sum(&o, &w)
    };
    // Attention dQ/dK carry softmax curvature; for the small grad magnitudes here
    // a larger eps (2e-3) cuts the f32 rounding term (∝|L|/eps) that dominates the
    // O(eps²) truncation on a ~4e-4 grad. (dV is exactly linear → cfg_linear.)
    let cfg_attn = GradCheckConfig {
        eps: 2e-3,
        rel_tol: 3e-2,
        atol: 1e-3,
    };
    // One closure per input: a clone of `fwd` plus copies of the two fixed inputs.
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "flash dQ",
        &grad_check(&q_h, &[bh, seq, hd], &lq, dq.as_slice::<f32>(), cfg_attn),
    );
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "flash dK",
        &grad_check(&k_h, &[bh, seq, hd], &lk, dk.as_slice::<f32>(), cfg_attn),
    );
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "flash dV",
        &grad_check(
            &v_h,
            &[bh, seq, hd],
            &lv,
            dv.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// The flash forward must equal the composed attention forward (same SDPA math).
// seq = 40 exceeds one flash tile (FA_TILE = 32), so the multi-tile path is
// compared too. The metric is max over elements of |composed − flash| /
// (|composed| + 1e-6), which is sensitive to near-zero outputs.
#[test]
fn flash_matches_composed_fwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let shape = [bh, seq, hd];
    let len = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q = cuda(&fill(len, 341), &shape);
    let k = cuda(&fill(len, 342), &shape);
    let v = cuda(&fill(len, 343), &shape);

    let (composed, _) = q.attention(&k, &v, scale);
    let (flash, _) = q.flash_attention(&k, &v, scale);
    let composed_h = composed.to_device(Device::Cpu);
    let flash_h = flash.to_device(Device::Cpu);

    let rel_err = |c: &f32, f: &f32| (c - f).abs() / (c.abs() + 1e-6);
    let max_rel = composed_h
        .as_slice::<f32>()
        .iter()
        .zip(flash_h.as_slice::<f32>())
        .map(|(c, f)| rel_err(c, f))
        .fold(0.0f32, f32::max);
    println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
    assert!(max_rel < 1e-4, "flash fwd diverges from composed: {max_rel:.3e}");
}
// flash backward must equal the (already grad-checked) composed backward. This is
// a sharper test than finite-diff: both share the trusted composed forward as the
// reference, so it isolates the flash bwd dQ/dK/dV math from finite-diff noise on
// near-zero gradient elements. seq = 40 spans more than one flash tile.
#[test]
fn flash_bwd_matches_composed_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 441);
    let k_h = fill(n, 442);
    let v_h = fill(n, 443);
    let w = fill(n, 444);
    // Runs one forward + backward on FRESH leaves (so the two runs never share
    // accumulated grads) and returns host copies of (dQ, dK, dV).
    let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
        let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
        let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
        let out = if flash {
            ops::flash_attention(&q, &k, &v, scale)
        } else {
            ops::attention(&q, &k, &v, scale)
        };
        scalar_loss(&out, &w).backward();
        let g = |x: &Var| {
            x.grad()
                .unwrap()
                .to_device(Device::Cpu)
                .as_slice::<f32>()
                .to_vec()
        };
        (g(&q), g(&k), g(&v))
    };
    let (cq, ck, cv) = run(false);
    let (fq, fk, fv) = run(true);
    // Symmetric relative error |a−b| / (|a|+|b|+1e-4); the 1e-4 floor keeps
    // near-zero grads from dominating the max.
    let maxrel = |a: &[f32], b: &[f32]| -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
            .fold(0.0f32, f32::max)
    };
    let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
    println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
    assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
    assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
    assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
// ---- GQA repeat_kv head broadcast (Phase T15) ----
//
// repeat_kv expands K/V from [batch·num_kv, seq, hd] to [batch·nh, seq, hd].
// Query head qh of batch b reads kv head qh / group, where group = nh / num_kv.
// The forward is a pure gather (a linear map), so finite differences are clean.
// The part GQA correctness hinges on is the BACKWARD: every kv head must receive
// the SUM of the grads of the `group` query heads sharing it. This test uses
// group = 3 (> 1). It checks the forward broadcast element by element, then
// grad-checks d(in) against finite-diff of L = sum(W∘out).
#[test]
fn repeat_kv_grad() {
    require_gpu();
    // 2 batches × 2 kv heads = 4 input rows; 6 query heads ⇒ group 3, 12 output rows.
    let (batch, num_kv, nh, seq, hd) = (2usize, 2usize, 6usize, 4usize, 5usize);
    let group = nh / num_kv;
    let head_len = seq * hd;
    let in_shape = [batch * num_kv, seq, hd];
    let x_h = fill(batch * num_kv * head_len, 711);
    let w = fill(batch * nh * head_len, 712);

    let kv = Var::leaf(cuda(&x_h, &in_shape));
    let out = ops::repeat_kv(&kv, nh, batch);
    assert_eq!(out.value().shape(), &[batch * nh, seq, hd]);

    // Forward: output head (b·nh + qh) must be a copy of input head (b·num_kv + qh/group).
    let out_h = out.value().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    for b in 0..batch {
        for qh in 0..nh {
            let dst = (b * nh + qh) * head_len;
            let src = (b * num_kv + qh / group) * head_len;
            let got = &out_h[dst..dst + head_len];
            let want = &x_h[src..src + head_len];
            for (o, i) in got.iter().zip(want) {
                assert_eq!(o, i, "repeat_kv fwd mismatch");
            }
        }
    }

    // Backward: `w` is borrowed here before it moves into the reference closure.
    scalar_loss(&out, &w).backward();
    let din = kv.grad().unwrap().to_device(Device::Cpu);
    let loss_of_input = move |xh: &[f32], _s: &[usize]| -> f32 {
        weighted_sum(&cuda(xh, &in_shape).repeat_kv(nh, batch), &w)
    };
    // repeat_kv is exactly linear (gather forward, sum backward): linear tolerances.
    report(
        "repeat_kv din",
        &grad_check(&x_h, &in_shape, &loss_of_input, din.as_slice::<f32>(), cfg_linear()),
    );
}
// group == 1 (num_kv == nh): repeat_kv must be a bit-exact identity in BOTH
// directions. This is the regression guard that keeps the MHA path
// (kv_heads == n_heads) unchanged.
#[test]
fn repeat_kv_identity_group1() {
    require_gpu();
    let (batch, nh, seq, hd) = (2usize, 3usize, 4usize, 5usize);
    let len = batch * nh * seq * hd;
    let x_h = fill(len, 721);
    let w = fill(len, 722);

    let kv = Var::leaf(cuda(&x_h, &[batch * nh, seq, hd]));
    // nh kv heads feeding nh query heads ⇒ group 1.
    let out = ops::repeat_kv(&kv, nh, batch);
    let fwd_host: Vec<f32> = out.value().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    assert_eq!(fwd_host, x_h, "group-1 repeat_kv fwd must be identity");

    // Identity forward ⇒ the grad passes through unchanged: dL/din == W exactly.
    scalar_loss(&out, &w).backward();
    let din = kv.grad().unwrap().to_device(Device::Cpu);
    din.as_slice::<f32>()
        .iter()
        .zip(&w)
        .for_each(|(g, expect)| assert_eq!(*g, *expect, "group-1 repeat_kv bwd must be identity"));
}
// ---- dropout (Phase T18) ----
//
// Fixed-seed finite-difference grad-check. For a fixed `seed` the mask depends only
// on (seed, index) and never on x, so dropout is the fixed elementwise linear map
// out_i = c_i·x_i. L is then differentiable in x, and the ± perturbations of each
// x_i see the SAME mask. The reference closure reruns `ops::dropout(x, p, SEED)`
// with the same SEED to reproduce that mask.
#[test]
fn dropout_bwd() {
    require_gpu();
    const SEED: u64 = 0xD120_FE5E;
    let p = 0.3f32;
    let (rows, cols) = (16, 12);
    let shape = [rows, cols];
    let x_host = fill(rows * cols, 71);
    let weights = fill(rows * cols, 72);

    let x = Var::leaf(cuda(&x_host, &shape));
    let out = ops::dropout(&x, p, SEED);
    scalar_loss(&out, &weights).backward();
    let grad_x = x.grad().unwrap().to_device(Device::Cpu);

    let loss_wrt_x = move |v: &[f32], s: &[usize]| {
        let dropped = ops::dropout(&Var::leaf(cuda(v, s)), p, SEED);
        weighted_sum(&dropped.value(), &weights)
    };
    report(
        "dropout dX",
        &grad_check(&x_host, &shape, &loss_wrt_x, grad_x.as_slice::<f32>(), cfg_linear()),
    );
}
// Inverted-dropout expectation + keep-rate check. Over a large all-ones tensor and
// a sweep of seeds, mean(dropout(x)) should track mean(x) = 1 (the inverted
// 1/(1-p) scaling), and the fraction of nonzero mask entries should track 1 - p.
// Per-trial statistics are accumulated in f64 and averaged over the trials.
#[test]
fn dropout_expectation_and_keep_rate() {
    require_gpu();
    let p = 0.25f32;
    let n = 200_000usize;
    let x = cuda(&vec![1.0f32; n], &[n]);
    let trials = 8;
    let (mean_out_acc, keep_acc) =
        (0..trials).fold((0.0f64, 0.0f64), |(mean_sum, keep_sum), t| {
            let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64);
            let out_h = out.to_device(Device::Cpu);
            let mask_h = mask.to_device(Device::Cpu);
            let trial_mean =
                out_h.as_slice::<f32>().iter().map(|&v| v as f64).sum::<f64>() / n as f64;
            let kept = mask_h.as_slice::<f32>().iter().filter(|&&m| m != 0.0).count();
            (mean_sum + trial_mean, keep_sum + kept as f64 / n as f64)
        });
    let mean_out = mean_out_acc / trials as f64;
    let keep_rate = keep_acc / trials as f64;
    println!(
        "dropout p={p}: E[out]={mean_out:.5} (input mean 1.0), keep_rate={keep_rate:.5} (1-p={:.3})",
        1.0 - p
    );
    assert!(
        (mean_out - 1.0).abs() < 0.01,
        "E[out] {mean_out} not ≈ input mean 1.0 (inverted scaling broken)"
    );
    assert!(
        (keep_rate - (1.0 - p) as f64).abs() < 0.01,
        "keep_rate {keep_rate} not ≈ 1-p {}",
        1.0 - p
    );
}
// p=0 dropout must be a no-op at both levels:
//   * tensor level: `Tensor::dropout(0.0, _)` returns an output bit-identical to x;
//   * op level: `ops::dropout(&x, 0.0, _)` (by design it returns x.clone(), no
//     node) must leave the forward unchanged AND pass the upstream grad straight
//     through, so with L = sum(W∘out) we need dL/dx == W exactly.
// The op-level grad check is what makes this a default-graph regression guard at
// the op level. The model-level bit-identity lives in xtrain-model/tests/dropout.rs.
#[test]
fn dropout_p0_is_identity() {
    require_gpu();
    let (m, n) = (8, 5);
    let x_h = fill(m * n, 91);

    // Tensor-level forward identity.
    let x = cuda(&x_h, &[m, n]);
    let (out, _mask) = x.dropout(0.0, 12345);
    let out_h = out.to_device(Device::Cpu);
    for (a, b) in x_h.iter().zip(out_h.as_slice::<f32>()) {
        assert_eq!(*a, *b, "p=0 dropout must be identity");
    }

    // Op-level forward identity through the Var layer.
    let w = fill(m * n, 92);
    let xv = Var::leaf(cuda(&x_h, &[m, n]));
    let outv = ops::dropout(&xv, 0.0, 12345);
    let outv_h = outv.value().to_device(Device::Cpu);
    for (a, b) in x_h.iter().zip(outv_h.as_slice::<f32>()) {
        assert_eq!(*a, *b, "p=0 ops::dropout fwd must be identity");
    }

    // Op-level grad pass-through: previously claimed above but never asserted.
    scalar_loss(&outv, &w).backward();
    let dx = xv.grad().unwrap().to_device(Device::Cpu);
    for (g, expect) in dx.as_slice::<f32>().iter().zip(&w) {
        assert_eq!(*g, *expect, "p=0 ops::dropout bwd must pass the grad through");
    }
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out). W is uploaded to CUDA device 0 as a leaf
// shaped like `out`, multiplied elementwise with `out`, then reduced to a scalar
// with `sum_all`. With this loss the upstream grad reaching `out` is W.
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
    let weights = Var::leaf(cuda(w, out.value().shape()));
    sum_all(&ops::mul(out, &weights))
}
// Sum-to-scalar node: out = Σ x as a shape-[1] tensor on x's device.
//
// The forward reduces on the host in f32 (x is copied to the CPU). The backward
// reads the single upstream grad value d[0], fills an x-shaped tensor with it and
// pushes that to the sole parent. This is test-local because the engine's op set
// has no generic reduction.
fn sum_all(x: &Var) -> Var {
    let xv = x.value();
    let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
    let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
    let shape: Vec<usize> = xv.shape().to_vec();
    Var::from_op(
        scalar,
        vec![x.clone()],
        Box::new(move |d, parents| {
            // d is [1]; broadcast its single value over the input's shape.
            let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
            let filled = vec![dval; shape.iter().product()];
            // Put the grad on the upstream grad's device (presumably the forward
            // output's device, i.e. x's), matching the forward's
            // `to_device(xv.device())` instead of hard-coding CUDA device 0.
            let g = Tensor::from_slice(&filled, &shape).to_device(d.device());
            Var::push_grad(&parents[0], g);
        }),
    )
}
// Test-local 2-D transpose node used by the composed-attention test (the engine
// has no transpose op). Forward: `transpose_2d` of x's value. Backward: the
// upstream grad is transposed back and pushed to x.
fn transpose_var(x: &Var) -> Var {
    let parents = vec![x.clone()];
    let transposed = x.value().transpose_2d();
    Var::from_op(
        transposed,
        parents,
        Box::new(|d, ps| {
            Var::push_grad(&ps[0], d.transpose_2d());
        }),
    )
}
// seq_logprob (M3 DPO): Σ log p(target) over non-ignored rows. Grad-check with a
// completion mask: rows 0 and 1 carry the ignore label -100 (prompt, contributing
// 0), and rows 2..6 are supervised. The numeric reference is −Σ of the
// tensor-level per-row cross-entropy, which is 0 on ignored rows.
#[test]
fn seq_logprob_bwd() {
    require_gpu();
    let (rows, cols) = (6usize, 9usize);
    let shape = [rows, cols];
    let logits_host = fill(rows * cols, 202);
    let labels: Vec<i32> = (0..rows)
        .map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
        .collect();
    let label_dev = Tensor::from_slice(&labels, &[rows]).to_device(Device::Cuda(0));

    let logits = Var::leaf(cuda(&logits_host, &shape));
    ops::seq_logprob(&logits, &label_dev).backward();
    let grad_logits = logits.grad().unwrap().to_device(Device::Cpu);

    let logprob_sum = move |v: &[f32], s: &[usize]| {
        let t = Tensor::from_slice(&labels, &[rows]).to_device(Device::Cuda(0));
        let (_, per_row) = cuda(v, s).cross_entropy(&t);
        let nll: f32 = per_row.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
        -nll
    };
    report(
        "seq_logprob dX",
        &grad_check(
            &logits_host,
            &shape,
            &logprob_sum,
            grad_logits.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
}
// dpo_loss (M3): scalar DPO loss with the two policy logprobs as parents.
// The finite-difference reference below is L = softplus(−β·Δ) with
// Δ = (pc − ref_c) − (pr − ref_r). Checks:
//   * grad-check of each parent at a generic point;
//   * policy == reference ⇒ Δ = 0 ⇒ L = ln 2 and (dpc, dpr) = (−β/2, +β/2);
//   * β = 0 ⇒ BOTH parent grads are 0.
#[test]
fn dpo_loss_bwd_and_degenerate() {
    require_gpu();
    let (ref_c, ref_r, beta) = (0.5f32, 0.9f32, 0.1f32);
    let (pc0, pr0) = (1.2f32, 0.7f32);
    // Numerically stable softplus: max(z, 0) + ln(1 + e^{−|z|}).
    let softplus = |z: f32| z.max(0.0) + (-(z.abs())).exp().ln_1p();
    let pc = Var::leaf(cuda(&[pc0], &[1]));
    let pr = Var::leaf(cuda(&[pr0], &[1]));
    let l = ops::dpo_loss(&pc, &pr, ref_c, ref_r, beta);
    l.backward();
    let dpc = pc.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let dpr = pr.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let l_of_pc =
        move |v: &[f32], _s: &[usize]| softplus(-(beta * ((v[0] - ref_c) - (pr0 - ref_r))));
    report(
        "dpo_loss dpc",
        &grad_check(&[pc0], &[1], &l_of_pc, &[dpc], cfg_nonlinear()),
    );
    let l_of_pr =
        move |v: &[f32], _s: &[usize]| softplus(-(beta * ((pc0 - ref_c) - (v[0] - ref_r))));
    report(
        "dpo_loss dpr",
        &grad_check(&[pr0], &[1], &l_of_pr, &[dpr], cfg_nonlinear()),
    );
    // Degenerate 1: policy == reference ⇒ Δ=0 ⇒ L=log2, (dpc, dpr) = (−β/2, +β/2).
    let pc2 = Var::leaf(cuda(&[ref_c], &[1]));
    let pr2 = Var::leaf(cuda(&[ref_r], &[1]));
    let l2 = ops::dpo_loss(&pc2, &pr2, ref_c, ref_r, beta);
    let lval = l2.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    l2.backward();
    let d2c = pc2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let d2r = pr2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    assert!((lval - 2f32.ln()).abs() < 1e-5, "L at Δ=0 must be log2, got {lval}");
    assert!(
        (d2c + beta * 0.5).abs() < 1e-5 && (d2r - beta * 0.5).abs() < 1e-5,
        "grads at Δ=0 must be ∓β/2, got ({d2c},{d2r})"
    );
    // Degenerate 2: β=0 ⇒ the reference loss is the constant ln 2, so BOTH parents
    // must get a zero grad (previously only dpc was asserted).
    let pc3 = Var::leaf(cuda(&[pc0], &[1]));
    let pr3 = Var::leaf(cuda(&[pr0], &[1]));
    let l3 = ops::dpo_loss(&pc3, &pr3, ref_c, ref_r, 0.0);
    l3.backward();
    let d3c = pc3.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let d3r = pr3.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
    assert!(d3r.abs() < 1e-9, "β=0 ⇒ dpr grad 0, got {d3r}");
    println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
}
// clipped_pg_loss (M4 GRPO): per-token clipped PG + k3 KL, one completion. Grad-check
// the active (in-trust-region) path + the A=0 (KL-only) path, plus value-level
// degenerate checks (ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL). The host replica below pins
//   L = −(1/n)·Σ_t min(ρ_t·A, clip(ρ_t, 1−ε, 1+ε)·A) + β·(1/n)·Σ_t (e^{d_t} − d_t − 1)
// with ρ_t = exp(logπθ_t − logp_old_t), d_t = logp_ref_t − logπθ_t, the sums over
// supervised (target ≥ 0) rows, and n their count (1/n replaced by 1 when n = 0).
#[test]
fn clipped_pg_loss_bwd_and_degenerate() {
    require_gpu();
    let (rows, cols) = (6usize, 10usize);
    let x_h = fill(rows * cols, 303);
    // rows 0,1 masked (prompt); 2..6 supervised (completion).
    let targets: Vec<i32> = (0..rows)
        .map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
        .collect();
    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
    // logp_old = logπθ at the base logits ⇒ ρ≈1 (in trust region → active path).
    let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
    let logp_old: Vec<f32> = per_row0
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .iter()
        .map(|p| -p)
        .collect();
    // Offset the reference so d_t ≠ 0 and the KL term is exercised.
    let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect();
    let (eps, beta) = (0.2f32, 0.1f32);
    // Host replica of the forward loss as a function of per-row CE values
    // (per_row_h[t] = −logπθ_t).
    let host_loss = {
        let (tg, lo, lr) = (targets.clone(), logp_old.clone(), logp_ref.clone());
        move |per_row_h: &[f32], a: f32, e: f32, b: f32| -> f32 {
            let (mut pg, mut kl, mut n) = (0f32, 0f32, 0f32);
            for t in 0..per_row_h.len() {
                if tg[t] < 0 {
                    continue;
                }
                n += 1.0;
                let lp = -per_row_h[t];
                let ratio = (lp - lo[t]).exp();
                let clipped = ratio.clamp(1.0 - e, 1.0 + e);
                pg += (ratio * a).min(clipped * a);
                let d = lr[t] - lp;
                kl += d.exp() - d - 1.0;
            }
            let inv = if n > 0.0 { 1.0 / n } else { 1.0 };
            -pg * inv + b * kl * inv
        }
    };
    let per_row_of = |v: &[f32], s: &[usize]| {
        let (_, pr) = cuda(v, s).cross_entropy(&mk_target());
        pr.to_device(Device::Cpu).as_slice::<f32>().to_vec()
    };
    // (1) grad-check the active PG path (A>0, ρ≈1).
    let adv = 0.7f32;
    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let loss = ops::clipped_pg_loss(&x, &mk_target(), &logp_old, &logp_ref, adv, eps, beta);
    loss.backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);
    let hl = host_loss.clone();
    let lx = move |v: &[f32], s: &[usize]| hl(&per_row_of(v, s), adv, eps, beta);
    report(
        "clipped_pg dX (active)",
        &grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
    );
    // (2) grad-check the A=0 path (loss = β·mean KL; PG gradient must vanish).
    let x0 = Var::leaf(cuda(&x_h, &[rows, cols]));
    let loss0 = ops::clipped_pg_loss(&x0, &mk_target(), &logp_old, &logp_ref, 0.0, eps, beta);
    loss0.backward();
    let dx0 = x0.grad().unwrap().to_device(Device::Cpu);
    let hl0 = host_loss.clone();
    let lx0 = move |v: &[f32], s: &[usize]| hl0(&per_row_of(v, s), 0.0, eps, beta);
    report(
        "clipped_pg dX (A=0, KL only)",
        &grad_check(&x_h, &[rows, cols], &lx0, dx0.as_slice::<f32>(), cfg_nonlinear()),
    );
    // (3) ε→∞ ⇒ vanilla PG (clip is a no-op): loss value == −mean(ρA) + β·mean KL.
    let big = 1e9f32;
    let lv = ops::clipped_pg_loss(
        &Var::leaf(cuda(&x_h, &[rows, cols])),
        &mk_target(),
        &logp_old,
        &logp_ref,
        adv,
        big,
        beta,
    );
    let got = lv.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    let pr0 = per_row_of(&x_h, &[rows, cols]);
    let want = host_loss(&pr0, adv, big, beta);
    assert!((got - want).abs() < 1e-4, "ε→∞ vanilla loss mismatch: {got} vs {want}");
    // (4) β=0 ⇒ no KL term (loss == −mean of the clipped PG term only).
    let lvb = ops::clipped_pg_loss(
        &Var::leaf(cuda(&x_h, &[rows, cols])),
        &mk_target(),
        &logp_old,
        &logp_ref,
        adv,
        eps,
        0.0,
    );
    let gotb = lvb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    let wantb = host_loss(&pr0, adv, eps, 0.0);
    assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
    println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
}
// clipped_pg_loss_batched (M2d): N ragged completions packed + right-padded into ONE
// forward must equal the looped per-sample path Σ_s (1/N)·clipped_pg_loss_s. The
// per-row CE backward is row-local, so folding weight = 1/(N·n_s) into the batched
// op reproduces the looped gradient and weighted-sum loss up to f32 rounding. The
// two paths need not accumulate in the same order, so exact bit equality is not
// asserted; both loss and gradients are gated at 1e-5.
#[test]
fn clipped_pg_loss_batched_matches_looped() {
    require_gpu();
    let (n, lmax, cols) = (3usize, 5usize, 10usize);
    let rows = n * lmax;
    let x_h = fill(rows * cols, 909);
    // Per sample: row 0 = prompt (-100); rows 1..real_len = completion; rest = pad
    // (-100). Different real_len ⇒ n_s = {2, 3, 1} completion rows.
    let real_len = [3usize, 4, 2];
    let adv_s = [0.7f32, -0.5, 0.3];
    let mut targets = vec![-100i32; rows];
    for s in 0..n {
        for r in 1..real_len[s] {
            let t = s * lmax + r;
            targets[t] = ((t * 3) % cols) as i32;
        }
    }
    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
    // logp_old ≈ logπθ at base logits (ρ≈1), logp_ref offset to exercise the KL term.
    let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
    let logp_old: Vec<f32> = per_row0
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .iter()
        .map(|p| -p)
        .collect();
    let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect();
    let (eps, beta) = (0.2f32, 0.1f32);
    // Per-row advantage (sample's A) + per-row weight 1/(N·n_s) (full normaliser).
    let n_of = |s: usize| (0..lmax).filter(|&r| targets[s * lmax + r] >= 0).count() as f32;
    let mut advantage = vec![0f32; rows];
    let mut weight = vec![0f32; rows];
    for s in 0..n {
        let n_s = n_of(s);
        // A sample with no completion rows would get weight 1/0 = inf; fail loudly
        // instead of comparing against a non-finite batched loss.
        assert!(n_s > 0.0, "sample {s} has no completion rows (weight would be 1/0)");
        let w = (1.0 / n as f32) * (1.0 / n_s);
        let span = s * lmax..(s + 1) * lmax;
        advantage[span.clone()].fill(adv_s[s]);
        weight[span].fill(w);
    }
    // Batched: one packed [R, vocab] forward + one backward.
    let xb = Var::leaf(cuda(&x_h, &[rows, cols]));
    let lb = ops::clipped_pg_loss_batched(
        &xb, &mk_target(), &logp_old, &logp_ref, &advantage, &weight, eps, beta,
    );
    lb.backward();
    let gb = xb.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let lb_val = lb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    // Looped reference: per-sample slice → clipped_pg_loss → scale(1/N) → backward.
    let mut g_ref = vec![0f32; rows * cols];
    let mut loss_ref = 0f32;
    for s in 0..n {
        let r0 = s * lmax;
        let xs_h = x_h[r0 * cols..(r0 + lmax) * cols].to_vec();
        let tgt_s: Vec<i32> = targets[r0..r0 + lmax].to_vec();
        let lo_s = logp_old[r0..r0 + lmax].to_vec();
        let lr_s = logp_ref[r0..r0 + lmax].to_vec();
        let xs = Var::leaf(cuda(&xs_h, &[lmax, cols]));
        let tgt = Tensor::from_slice(&tgt_s, &[lmax]).to_device(Device::Cuda(0));
        let ls = ops::clipped_pg_loss(&xs, &tgt, &lo_s, &lr_s, adv_s[s], eps, beta);
        let scaled = ops::scale(&ls, 1.0 / n as f32);
        scaled.backward();
        let gs = xs.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
        g_ref[r0 * cols..(r0 + lmax) * cols].copy_from_slice(&gs);
        loss_ref += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    }
    // `zip` would silently truncate on a shape mismatch; check lengths explicitly.
    assert_eq!(gb.len(), g_ref.len(), "batched vs looped grad length mismatch");
    let max_g = gb
        .iter()
        .zip(&g_ref)
        .map(|(a, b)| (a - b).abs())
        // `f32::max` returns the non-NaN operand, which would let a NaN gradient slip
        // past the tolerance gate; propagate NaN instead so the assert below fails.
        .fold(0.0f32, |m, d| if d.is_nan() || d > m { d } else { m });
    assert!(
        (lb_val - loss_ref).abs() < 1e-5,
        "batched loss {lb_val} vs looped {loss_ref}"
    );
    assert!(max_g < 1e-5, "batched grad vs looped: max|Δ| = {max_g}");
    println!(
        "clipped_pg_loss_batched OK: loss Δ={:.2e}, grad max|Δ|={:.2e} (== looped Σ_s 1/N·pg_s)",
        (lb_val - loss_ref).abs(),
        max_g
    );
}