test: M2d — ragged-forward + batched-op equivalence gates + throughput bench
Two exact correctness gates (composed = the end-to-end batched GRPO step == looped): - xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded ragged sequences == per-sequence single-seq forward on the real rows. fp32 max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is free under causal". - xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32). bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards (loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms (9.2×), training forwards 2526→280 ms (9.0×). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -1177,3 +1177,94 @@ fn clipped_pg_loss_bwd_and_degenerate() {
|
||||
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
|
||||
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
|
||||
}
|
||||
|
||||
// clipped_pg_loss_batched (M2d): N ragged completions packed + right-padded into ONE
|
||||
// forward must equal the looped per-sample path Σ_s (1/N)·clipped_pg_loss_s. The
|
||||
// per-row CE backward is row-local, so folding weight = 1/(N·n_s) into the batched
|
||||
// op reproduces the looped gradient and weighted-sum loss bit-for-bit (f32 path).
|
||||
#[test]
|
||||
fn clipped_pg_loss_batched_matches_looped() {
|
||||
require_gpu();
|
||||
let (n, lmax, cols) = (3usize, 5usize, 10usize);
|
||||
let rows = n * lmax;
|
||||
let x_h = fill(rows * cols, 909);
|
||||
// Per sample: row 0 = prompt (-100); rows 1..real_len = completion; rest = pad
|
||||
// (-100). Different real_len ⇒ n_s = {2, 3, 1} completion rows.
|
||||
let real_len = [3usize, 4, 2];
|
||||
let adv_s = [0.7f32, -0.5, 0.3];
|
||||
let mut targets = vec![-100i32; rows];
|
||||
for s in 0..n {
|
||||
for r in 1..real_len[s] {
|
||||
let t = s * lmax + r;
|
||||
targets[t] = ((t * 3) % cols) as i32;
|
||||
}
|
||||
}
|
||||
let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
|
||||
|
||||
// logp_old ≈ logπθ at base logits (ρ≈1), logp_ref offset to exercise the KL term.
|
||||
let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
|
||||
let logp_old: Vec<f32> = per_row0
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|p| -p)
|
||||
.collect();
|
||||
let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect();
|
||||
let (eps, beta) = (0.2f32, 0.1f32);
|
||||
|
||||
// Per-row advantage (sample's A) + per-row weight 1/(N·n_s) (full normaliser).
|
||||
let n_of = |s: usize| (0..lmax).filter(|&r| targets[s * lmax + r] >= 0).count() as f32;
|
||||
let mut advantage = vec![0f32; rows];
|
||||
let mut weight = vec![0f32; rows];
|
||||
for s in 0..n {
|
||||
let w = (1.0 / n as f32) * (1.0 / n_of(s));
|
||||
for r in 0..lmax {
|
||||
advantage[s * lmax + r] = adv_s[s];
|
||||
weight[s * lmax + r] = w;
|
||||
}
|
||||
}
|
||||
|
||||
// Batched: one packed [R, vocab] forward + one backward.
|
||||
let xb = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||
let lb = ops::clipped_pg_loss_batched(
|
||||
&xb, &mk_target(), &logp_old, &logp_ref, &advantage, &weight, eps, beta,
|
||||
);
|
||||
lb.backward();
|
||||
let gb = xb.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let lb_val = lb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
|
||||
// Looped reference: per-sample slice → clipped_pg_loss → scale(1/N) → backward.
|
||||
let mut g_ref = vec![0f32; rows * cols];
|
||||
let mut loss_ref = 0f32;
|
||||
for s in 0..n {
|
||||
let r0 = s * lmax;
|
||||
let xs_h = x_h[r0 * cols..(r0 + lmax) * cols].to_vec();
|
||||
let tgt_s: Vec<i32> = targets[r0..r0 + lmax].to_vec();
|
||||
let lo_s = logp_old[r0..r0 + lmax].to_vec();
|
||||
let lr_s = logp_ref[r0..r0 + lmax].to_vec();
|
||||
let xs = Var::leaf(cuda(&xs_h, &[lmax, cols]));
|
||||
let tgt = Tensor::from_slice(&tgt_s, &[lmax]).to_device(Device::Cuda(0));
|
||||
let ls = ops::clipped_pg_loss(&xs, &tgt, &lo_s, &lr_s, adv_s[s], eps, beta);
|
||||
let scaled = ops::scale(&ls, 1.0 / n as f32);
|
||||
scaled.backward();
|
||||
let gs = xs.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
g_ref[r0 * cols..(r0 + lmax) * cols].copy_from_slice(&gs);
|
||||
loss_ref += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
}
|
||||
|
||||
let max_g = gb
|
||||
.iter()
|
||||
.zip(&g_ref)
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.fold(0.0f32, f32::max);
|
||||
assert!(
|
||||
(lb_val - loss_ref).abs() < 1e-5,
|
||||
"batched loss {lb_val} vs looped {loss_ref}"
|
||||
);
|
||||
assert!(max_g < 1e-5, "batched grad vs looped: max|Δ| = {max_g}");
|
||||
println!(
|
||||
"clipped_pg_loss_batched OK: loss Δ={:.2e}, grad max|Δ|={:.2e} (== looped Σ_s 1/N·pg_s)",
|
||||
(lb_val - loss_ref).abs(),
|
||||
max_g
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user