Two exact correctness gates (composed, they give end-to-end batched GRPO step == looped):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
ragged sequences == per-sequence single-seq forward on the real rows. fp32
max|Δlogit| = 3.7e-7, bf16 = 0.0, under both composed and flash SDPA. Pins "right-pad
is free under causal masking".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
Σ_s (1/N)·clipped_pg_loss_s. Loss Δ = 1.5e-8, grad max|Δ| = 7.5e-9 (fp32).
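
As a rough illustration of what the two gates assert, here is a minimal pure-Python sketch (not the xtrain code). It assumes a toy single-head causal attention and a PPO-style clipped surrogate averaged over each sample's tokens. All function names below are hypothetical, and the real forward_batched and clipped_pg_loss may be defined differently.

```python
import math

def causal_attention(q, k, v):
    """Toy single-head causal attention for one sequence (lists of float vectors)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(len(q)):
        # Causal mask: row i only sees keys/values at positions 0..i.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        norm = sum(weights)
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / norm
                    for d in range(len(v[0]))])
    return out

def right_pad(rows, length, pad_row):
    """Append copies of pad_row on the RIGHT until the sequence has `length` rows."""
    return rows + [list(pad_row) for _ in range(length - len(rows))]

def clipped_pg_loss(logp_new, logp_old, adv, eps=0.2):
    """Assumed per-sample loss: mean over tokens of the clipped surrogate."""
    total = 0.0
    for new, old in zip(logp_new, logp_old):
        ratio = math.exp(new - old)
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        total -= min(ratio * adv, clipped * adv)
    return total / len(logp_new)

def looped_loss(samples, eps=0.2):
    """Reference: sum over samples of (1/N) * per-sample loss."""
    n = len(samples)
    return sum(clipped_pg_loss(new, old, adv, eps) / n for new, old, adv in samples)

def batched_loss(samples, eps=0.2):
    """Single flattened pass; each token carries weight 1 / (N * len_s)."""
    n = len(samples)
    total = 0.0
    for logp_new, logp_old, adv in samples:
        weight = 1.0 / (n * len(logp_new))
        for new, old in zip(logp_new, logp_old):
            ratio = math.exp(new - old)
            clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
            total -= weight * min(ratio * adv, clipped * adv)
    return total

# Gate 1 analogue: padded rows sit to the right, so no real row can attend to them.
seq = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4]]
padded = right_pad(seq, 5, [9.0, -9.0])
assert causal_attention(padded, padded, padded)[:3] == causal_attention(seq, seq, seq)

# Gate 2 analogue: the flattened pass matches the per-sample loop up to rounding.
samples = [([-1.0, -0.5, -2.0], [-1.1, -0.4, -2.0], 0.7),
           ([-0.3, -0.9], [-0.5, -0.9], -1.2)]
assert math.isclose(batched_loss(samples), looped_loss(samples), abs_tol=1e-12)
```

In this toy the padded and unpadded outputs agree bit for bit because each real row runs the same operations in the same order. A real batched kernel may reorder reductions, which would be consistent with the small nonzero fp32 delta reported above.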

bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke test isn't −log-amplified by random low-prob tokens). Measured
on dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).
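
One possible reading of the "−log-amplified" remark is that −log p has slope 1/p, so a fixed absolute error in a token's probability moves its log-prob far more when the token is unlikely. The sketch below illustrates the effect under that assumption. `nll_shift` and `argmax_target` are hypothetical helpers, not part of bench_grpo_batch, and the distribution is made up for illustration.

```python
import math

def nll_shift(p, err):
    """Absolute change in -log p when the probability moves from p to p + err."""
    return abs(math.log(p + err) - math.log(p))

def argmax_target(probs):
    """Teacher-forced target: index of the most likely token."""
    return max(range(len(probs)), key=lambda i: probs[i])

probs = [0.001, 0.9, 0.099]   # toy next-token distribution (illustrative only)
err = 1e-4                    # same absolute probability error for every token

high = nll_shift(probs[argmax_target(probs)], err)  # argmax target, p = 0.9
low = nll_shift(probs[0], err)                      # unlikely target, p = 0.001
print(f"argmax target: {high:.2e}, low-prob target: {low:.2e}")
```

With argmax targets every compared token has a high probability, so the log-prob comparison mostly reflects numerical differences between the two paths, not the steep part of the −log curve.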

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>