xtrain/crates/xtrain-train
Gahow Wang 0e82b2438e test: M2d — ragged-forward + batched-op equivalence gates + throughput bench
Two exact correctness gates (composed, they establish that the end-to-end batched GRPO step == the looped one):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
  ragged sequences == per-sequence single-seq forward on the real rows. fp32
  max|Δlogit| = 3.7e-7, bf16 = 0.0, on both the composed and flash SDPA paths. Pins
  "right-pad is free under causal attention".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
  Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32).
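
The two gates can be illustrated with a minimal, dependency-free sketch. It is not the xtrain code: `causal_attention`, `right_pad_max_delta`, `Sample`, `loss_looped` and `loss_batched` are hypothetical names, attention uses Q = K = V = x with a single head, and the per-sample loss is assumed to be a mean over that sample's tokens with a scalar advantage (the commit does not state the normalization).

```rust
/// Toy single-head causal self-attention with Q = K = V = x (assumption for brevity).
fn causal_attention(x: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d = x[0].len();
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        // Causal mask: position i attends only to positions 0..=i.
        let scores: Vec<f32> = (0..=i)
            .map(|j| scale * x[i].iter().zip(&x[j]).map(|(a, b)| a * b).sum::<f32>())
            .collect();
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let denom: f32 = exps.iter().sum();
        let mut row = vec![0.0f32; d];
        for (j, e) in exps.iter().enumerate() {
            for k in 0..d {
                row[k] += *e / denom * x[j][k];
            }
        }
        out.push(row);
    }
    out
}

/// Max |delta| over the real rows between the unpadded and right-padded forward.
fn right_pad_max_delta(seq: &[Vec<f32>], padded_len: usize) -> f32 {
    assert!(!seq.is_empty() && padded_len >= seq.len());
    let mut padded = seq.to_vec();
    // Arbitrary non-zero pad rows: real rows never attend to later positions.
    padded.resize(padded_len, vec![7.0; seq[0].len()]);
    let looped = causal_attention(seq);
    let batched = causal_attention(&padded);
    looped.iter().zip(&batched)
        .flat_map(|(a, b)| a.iter().zip(b).map(|(x, y)| (x - y).abs()))
        .fold(0.0, f32::max)
}

/// One ragged sample: per-token log-probs (new/old policy) and a scalar advantage.
struct Sample {
    logp_new: Vec<f32>,
    logp_old: Vec<f32>,
    advantage: f32,
}

/// PPO-style clipped surrogate for one token, negated so that it is a loss.
fn clipped_term(logp_new: f32, logp_old: f32, adv: f32, eps: f32) -> f32 {
    let ratio = (logp_new - logp_old).exp();
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
    -(ratio * adv).min(clipped * adv)
}

/// Looped reference: sum over samples of (1/N) * per-sample loss (token mean, assumed).
fn loss_looped(samples: &[Sample], eps: f32) -> f32 {
    let n = samples.len() as f32;
    samples.iter().map(|s| {
        let sum: f32 = s.logp_new.iter().zip(&s.logp_old)
            .map(|(&ln, &lo)| clipped_term(ln, lo, s.advantage, eps))
            .sum();
        sum / s.logp_new.len() as f32 / n
    }).sum()
}

/// Batched form: one masked pass over a right-padded [N, T_max] grid.
fn loss_batched(samples: &[Sample], eps: f32) -> f32 {
    let n = samples.len() as f32;
    let t_max = samples.iter().map(|s| s.logp_new.len()).max().unwrap_or(0);
    let mut total = 0.0f32;
    for t in 0..t_max {
        for s in samples {
            // Mask: pad positions (t >= len) contribute nothing.
            if t < s.logp_new.len() {
                let w = 1.0 / (n * s.logp_new.len() as f32);
                total += w * clipped_term(s.logp_new[t], s.logp_old[t], s.advantage, eps);
            }
        }
    }
    total
}
```

In this toy the padded and unpadded forwards run the same arithmetic in the same order on the real rows, so the difference is exactly zero. A real batched kernel may reorder reductions, which is consistent with the small non-zero fp32 deltas reported above. The two loss forms differ only in summation order, so they agree up to rounding.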

bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke check isn't amplified by −log on random low-prob tokens). Measured on
dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).
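
Why argmax targets help the closeness check can be sketched numerically: to first order, a perturbation dp in a token probability p shifts −ln p by about dp/p, so the same numerical noise looks far larger on a low-probability target. `nll_shift` below is a hypothetical helper for illustration, not part of bench_grpo_batch.

```rust
/// Absolute change in the NLL term -ln(p) when probability p is perturbed by dp.
fn nll_shift(p: f64, dp: f64) -> f64 {
    ((p + dp).ln() - p.ln()).abs()
}
```

For example, `nll_shift(1e-4, 1e-5)` is several thousand times larger than `nll_shift(0.9, 1e-5)`, even though the perturbation in probability is identical.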

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:03:09 +08:00