test: M2d — ragged-forward + batched-op equivalence gates + throughput bench

Two exact correctness gates (together they imply the end-to-end batched GRPO step == looped):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
  ragged sequences == per-sequence single-seq forward on the real rows. fp32
  max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is
  free under causal".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
  Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32).

bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on
dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 23:03:09 +08:00
parent c2ebf62ae1
commit 0e82b2438e
3 changed files with 456 additions and 0 deletions

View File

@@ -1177,3 +1177,94 @@ fn clipped_pg_loss_bwd_and_degenerate() {
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
}
// clipped_pg_loss_batched (M2d): N ragged completions packed + right-padded into ONE
// forward must equal the looped per-sample path Σ_s (1/N)·clipped_pg_loss_s. The
// per-row CE backward is row-local, so folding weight = 1/(N·n_s) into the batched
// op reproduces the looped gradient and weighted-sum loss up to f32 rounding (the
// test gates both at |Δ| < 1e-5; observed deltas are ~1e-8, not bit-identical).
#[test]
fn clipped_pg_loss_batched_matches_looped() {
    // Equivalence gate: `ops::clipped_pg_loss_batched` over N packed, right-padded
    // samples must give the same loss and input gradient as looping
    // `ops::clipped_pg_loss` per sample, scaling each by 1/N and summing.
    require_gpu();
    let (n, lmax, cols) = (3usize, 5usize, 10usize);
    let rows = n * lmax;
    let x_h = fill(rows * cols, 909);

    // Per sample: row 0 = prompt (-100); rows 1..real_len = completion; rest = pad
    // (-100). Different real_len ⇒ n_s = {2, 3, 1} completion rows.
    let real_len = [3usize, 4, 2];
    let adv_s = [0.7f32, -0.5, 0.3];
    let mut targets = vec![-100i32; rows];
    for s in 0..n {
        for r in 1..real_len[s] {
            let t = s * lmax + r;
            targets[t] = ((t * 3) % cols) as i32;
        }
    }
    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    // logp_old ≈ logπθ at base logits (ρ≈1): the negated second output of
    // cross_entropy (presumably the per-row −log p — confirm against the op).
    // logp_ref is offset by a constant to exercise the KL term.
    let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
    let logp_old: Vec<f32> = per_row0
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .iter()
        .map(|p| -p)
        .collect();
    let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect();
    let (eps, beta) = (0.2f32, 0.1f32);

    // Per-row advantage (sample's A) + per-row weight 1/(N·n_s) (full normaliser).
    // n_s counts the rows of sample s whose target is not the -100 ignore marker.
    let n_of = |s: usize| (0..lmax).filter(|&r| targets[s * lmax + r] >= 0).count() as f32;
    let mut advantage = vec![0f32; rows];
    let mut weight = vec![0f32; rows];
    for s in 0..n {
        let w = (1.0 / n as f32) * (1.0 / n_of(s));
        for r in 0..lmax {
            advantage[s * lmax + r] = adv_s[s];
            weight[s * lmax + r] = w;
        }
    }

    // Batched: one packed [R, vocab] forward + one backward.
    let xb = Var::leaf(cuda(&x_h, &[rows, cols]));
    let lb = ops::clipped_pg_loss_batched(
        &xb, &mk_target(), &logp_old, &logp_ref, &advantage, &weight, eps, beta,
    );
    lb.backward();
    let gb = xb.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    let lb_val = lb.value().to_device(Device::Cpu).as_slice::<f32>()[0];

    // Looped reference: per-sample slice → clipped_pg_loss → scale(1/N) → backward.
    // Each sample's gradient is written into its own row range of g_ref.
    let mut g_ref = vec![0f32; rows * cols];
    let mut loss_ref = 0f32;
    for s in 0..n {
        let r0 = s * lmax;
        let xs_h = x_h[r0 * cols..(r0 + lmax) * cols].to_vec();
        let tgt_s: Vec<i32> = targets[r0..r0 + lmax].to_vec();
        let lo_s = logp_old[r0..r0 + lmax].to_vec();
        let lr_s = logp_ref[r0..r0 + lmax].to_vec();
        let xs = Var::leaf(cuda(&xs_h, &[lmax, cols]));
        let tgt = Tensor::from_slice(&tgt_s, &[lmax]).to_device(Device::Cuda(0));
        let ls = ops::clipped_pg_loss(&xs, &tgt, &lo_s, &lr_s, adv_s[s], eps, beta);
        let scaled = ops::scale(&ls, 1.0 / n as f32);
        scaled.backward();
        let gs = xs.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
        g_ref[r0 * cols..(r0 + lmax) * cols].copy_from_slice(&gs);
        loss_ref += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    }

    // `zip` stops at the shorter side, so a wrongly-sized batched gradient would
    // otherwise be compared only on its common prefix.
    assert_eq!(
        gb.len(),
        g_ref.len(),
        "batched grad has {} elements, looped reference has {}",
        gb.len(),
        g_ref.len()
    );
    // `f32::max` returns the non-NaN operand, so a NaN difference would be dropped
    // by the fold below and the gate would pass with max|Δ| = 0. Reject
    // non-finite gradients up front.
    assert!(
        gb.iter().chain(&g_ref).all(|g| g.is_finite()),
        "non-finite value in batched or looped gradient"
    );
    let max_g = gb
        .iter()
        .zip(&g_ref)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);
    // A NaN loss fails this comparison on its own (NaN < x is false).
    assert!(
        (lb_val - loss_ref).abs() < 1e-5,
        "batched loss {lb_val} vs looped {loss_ref}"
    );
    assert!(max_g < 1e-5, "batched grad vs looped: max|Δ| = {max_g}");
    println!(
        "clipped_pg_loss_batched OK: loss Δ={:.2e}, grad max|Δ|={:.2e} (== looped Σ_s 1/N·pg_s)",
        (lb_val - loss_ref).abs(),
        max_g
    );
}