xtrain/crates
Gahow Wang c2ebf62ae1 post-train: M2d — batch the GRPO training-side forwards (op + module + wiring)
After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backward passes. M2d packs all N=B·G
ragged samples of a step into ONE forward_batched call.

Enabling property: right-padding is free under causal attention. A real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal those of the unpadded single-sequence
forward. Pad rows are masked out of the loss (target=-100).
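
The padding argument can be checked on a toy model. The sketch below is not the
repository's forward_batched; it is a hypothetical single-head causal attention
with q = k = v = x and no learned weights, and the names causal_attention,
right_pad and real_rows_match are made up for illustration. It compares the
outputs at real positions with and without trailing pad rows.

```rust
/// Toy single-head causal self-attention with q = k = v = x (no learned weights).
/// Hypothetical stand-in for a real transformer forward, for illustration only.
fn causal_attention(x: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d = x[0].len();
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        // Causal mask: position t only attends to positions 0..=t.
        let scores: Vec<f32> = (0..=t)
            .map(|s| x[t].iter().zip(&x[s]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        // Numerically stable softmax over the visible prefix.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
        let z: f32 = exps.iter().sum();
        let mut row = vec![0.0f32; d];
        for (s, e) in exps.iter().enumerate() {
            for j in 0..d {
                row[j] += e / z * x[s][j];
            }
        }
        out.push(row);
    }
    out
}

/// Append pad rows (filled with `pad`) on the right until the sequence has `len` rows.
fn right_pad(x: &[Vec<f32>], len: usize, pad: f32) -> Vec<Vec<f32>> {
    let d = x[0].len();
    let mut padded = x.to_vec();
    padded.resize(len, vec![pad; d]);
    padded
}

/// True if every real row of the padded forward equals the unpadded forward.
fn real_rows_match(x: &[Vec<f32>], padded_len: usize) -> bool {
    let unpadded = causal_attention(x);
    let padded = causal_attention(&right_pad(x, padded_len, 9.0));
    // zip stops at the shorter (unpadded) output, so pad rows are ignored.
    unpadded.iter().zip(padded.iter()).all(|(a, b)| a == b)
}
```

Calling real_rows_match on a three-token sequence padded to length six returns
true in this toy, because the per-row arithmetic never touches a later position.
A real batched kernel may reorder reductions, so exact equality there is a
property to verify rather than assume.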

- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
  (the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
  compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
  looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
  row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
  forward_batched(N) → slice back to real length) + looped baselines +
  inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
  [chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
  grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
  (rollout / capture / inner) keep the step decomposition honest. Default micro =
  B·G; bench-measured 9× speedup on the training forwards.
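
The weighting scheme in the bullets above can be sketched on scalars. This is an
illustration under assumptions, not the code in ops or grpo_batch.rs: Sample,
Row, pack and the loss_* functions are hypothetical names, the per-token term is
a generic PPO-style clipped surrogate, every sample is assumed non-empty, and
pad rows are omitted. It shows why a per-row weight of 1/(N·n_s) with the GLOBAL
N makes the looped, batched and micro-chunked losses agree.

```rust
/// One sample: per-token probability ratios and the sample's scalar advantage.
struct Sample {
    ratios: Vec<f64>,
    advantage: f64,
}

/// One flattened row of the packed batch.
struct Row {
    ratio: f64,
    advantage: f64, // the owning sample's A
    weight: f64,    // the full normaliser 1/(N * n_s)
}

/// PPO-style clipped policy-gradient term for one token, negated to form a loss.
fn clipped_term(ratio: f64, adv: f64, eps: f64) -> f64 {
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
    -(ratio * adv).min(clipped * adv)
}

/// Looped baseline: sum over samples of (1/N) * (1/n_s) * sum over tokens.
fn loss_looped(samples: &[Sample], eps: f64) -> f64 {
    let n = samples.len() as f64;
    samples
        .iter()
        .map(|s| {
            let n_s = s.ratios.len() as f64;
            let sum: f64 = s.ratios.iter().map(|&r| clipped_term(r, s.advantage, eps)).sum();
            sum / (n * n_s)
        })
        .sum()
}

/// Flatten samples into rows; `global_n` is the step-wide N, not the chunk size.
fn pack(samples: &[Sample], global_n: usize) -> Vec<Row> {
    let mut rows = Vec::new();
    for s in samples {
        let weight = 1.0 / (global_n as f64 * s.ratios.len() as f64);
        for &ratio in &s.ratios {
            rows.push(Row { ratio, advantage: s.advantage, weight });
        }
    }
    rows
}

/// Batched loss: the op does no normalisation of its own, only sums weight * term.
fn loss_batched(rows: &[Row], eps: f64) -> f64 {
    rows.iter().map(|r| r.weight * clipped_term(r.ratio, r.advantage, eps)).sum()
}

/// Chunked accumulation: each chunk of `micro` samples still uses the global N.
fn loss_chunked(samples: &[Sample], micro: usize, eps: f64) -> f64 {
    let n = samples.len();
    samples.chunks(micro).map(|c| loss_batched(&pack(c, n), eps)).sum()
}
```

In this scalar sketch the three losses agree up to floating-point rounding. If
pack were given the chunk length instead of the global N, each chunk would be
over-weighted and the accumulated total would no longer match the looped
baseline.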

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:02:56 +08:00