post-train: M2d — batch the GRPO training-side forwards (op + module + wiring)

After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backwards. M2d packs all N=B·G ragged
samples of a step into ONE forward_batched.

Enabling property — right-padding is free under causal attention: a real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal the unpadded single-sequence forward; pad
rows are masked out (target=-100).

- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
  (the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
  compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
  looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
  row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
  forward_batched(N) → slice back to real length) + looped baselines +
  inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
  [chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
  grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
  (rollout / capture / inner) to keep the step decomposition honest. Default micro =
  B·G; bench-measured 9× on the training forwards.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 23:02:56 +08:00
parent 41d46208a6
commit c2ebf62ae1
4 changed files with 303 additions and 49 deletions

View File

@@ -597,3 +597,87 @@ pub fn clipped_pg_loss(
}),
)
}
/// Batched GRPO clipped-PG loss over `N` ragged completions packed into ONE
/// `forward_batched` (M2d): `logits` is `[R, vocab]` with `R = N·Lmax` rows in
/// sequence-major order (sample 0's `Lmax` rows, then sample 1's, …), each ragged
/// completion right-padded to the batch's `Lmax`. Prompt AND pad rows are masked
/// (`target < 0`), so they contribute nothing and carry no gradient — the
/// **right-pad-is-free-under-causal-attention** property (a real completion row
/// never attends to the trailing pad rows, so its logits equal the unpadded
/// single-sequence forward's).
///
/// Unlike the per-sample [`clipped_pg_loss`] (which folds a single scalar
/// `advantage` and a global `1/N_tokens` normaliser), this op takes **per-row**
/// `advantage[t]` (the owning sample's group-relative `A`) and **per-row**
/// `weight[t]` (the full normaliser, e.g. `1/(N_samples · n_s)` for sample `s`'s
/// completion rows, `0` at masked rows). It does NOT compute its own `inv_n`. With
/// `weight[t] = 1/(N_samples·n_s)` and `advantage[t] = A_s` this is **bit-equivalent
/// to the looped path** `Σ_s scale·(1/n_s)·clipped_pg_loss_s` (`scale = 1/N_samples`):
/// the per-row backward is local (`cross_entropy_backward` is row-wise), so the
/// batched row-`t` gradient equals the looped sample-`s` row-`t` gradient, and the
/// scalar loss equals the looped weighted sum. (`tests/autograd.rs`:
/// `clipped_pg_loss_batched_matches_looped`.) Degenerate points match
/// [`clipped_pg_loss`] (`A=0` ⇒ KL only; `ε→∞` ⇒ vanilla PG; `β=0` ⇒ no KL).
#[allow(clippy::too_many_arguments)]
pub fn clipped_pg_loss_batched(
logits: &Var,
target: &Tensor,
logp_old: &[f32],
logp_ref: &[f32],
advantage: &[f32],
weight: &[f32],
eps: f32,
beta: f32,
) -> Var {
use xtrain_tensor::Device;
let logit_dtype = logits.value().dtype();
let (probs, per_row) = logits.value().cross_entropy(target);
let rows = per_row.shape()[0];
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per row");
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per row");
assert_eq!(advantage.len(), rows, "advantage must have one entry per row");
assert_eq!(weight.len(), rows, "weight must have one entry per row");
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
let mut loss_val = 0f32;
for t in 0..rows {
if target_h[t] < 0 {
continue; // masked (prompt or pad) row — no contribution, no gradient
}
let (a, w) = (advantage[t], weight[t]);
let lp = -per_row_h[t]; // logπθ_t
let ratio = (lp - logp_old[t]).exp();
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
let (unclipped_term, clipped_term) = (ratio * a, clipped * a);
let pg_t = unclipped_term.min(clipped_term);
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
let d = logp_ref[t] - lp;
let kl_t = d.exp() - d - 1.0;
let pg_grad = if active { -a * ratio } else { 0.0 };
let kl_grad = beta * (1.0 - d.exp());
// The full per-row normaliser is folded into s (no global inv_n here).
s[t] = -(pg_grad + kl_grad) * w;
loss_val += (-pg_t + beta * kl_t) * w;
}
let dev = logits.value().device();
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
let target = target.clone();
Var::from_op(
out,
vec![logits.clone()],
Box::new(move |d, parents| {
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
let mut dx = ce.scale_rows(&s_dev);
if up != 1.0 {
dx = dx.scale(up);
}
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),
)
}