post-train: M2d — batch the GRPO training-side forwards (op + module + wiring)
After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backwards. M2d packs all N=B·G ragged
samples of a step into ONE forward_batched.
Enabling property — right-padding is free under causal attention: a real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal the unpadded single-sequence forward; pad
rows are masked out (target=-100).
- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
(the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
forward_batched(N) → slice back to real length) + looped baselines +
inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
[chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
(rollout / capture / inner) to keep the step decomposition honest. Default micro =
B·G; bench-measured 9× on the training forwards.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -597,3 +597,87 @@ pub fn clipped_pg_loss(
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Batched GRPO clipped-PG loss over `N` ragged completions packed into ONE
|
||||
/// `forward_batched` (M2d): `logits` is `[R, vocab]` with `R = N·Lmax` rows in
|
||||
/// sequence-major order (sample 0's `Lmax` rows, then sample 1's, …), each ragged
|
||||
/// completion right-padded to the batch's `Lmax`. Prompt AND pad rows are masked
|
||||
/// (`target < 0`), so they contribute nothing and carry no gradient — the
|
||||
/// **right-pad-is-free-under-causal-attention** property (a real completion row
|
||||
/// never attends to the trailing pad rows, so its logits equal the unpadded
|
||||
/// single-sequence forward's).
|
||||
///
|
||||
/// Unlike the per-sample [`clipped_pg_loss`] (which folds a single scalar
|
||||
/// `advantage` and a global `1/N_tokens` normaliser), this op takes **per-row**
|
||||
/// `advantage[t]` (the owning sample's group-relative `A`) and **per-row**
|
||||
/// `weight[t]` (the full normaliser, e.g. `1/(N_samples · n_s)` for sample `s`'s
|
||||
/// completion rows, `0` at masked rows). It does NOT compute its own `inv_n`. With
|
||||
/// `weight[t] = 1/(N_samples·n_s)` and `advantage[t] = A_s` this is **bit-equivalent
|
||||
/// to the looped path** `Σ_s scale·(1/n_s)·clipped_pg_loss_s` (`scale = 1/N_samples`):
|
||||
/// the per-row backward is local (`cross_entropy_backward` is row-wise), so the
|
||||
/// batched row-`t` gradient equals the looped sample-`s` row-`t` gradient, and the
|
||||
/// scalar loss equals the looped weighted sum. (`tests/autograd.rs`:
|
||||
/// `clipped_pg_loss_batched_matches_looped`.) Degenerate points match
|
||||
/// [`clipped_pg_loss`] (`A=0` ⇒ KL only; `ε→∞` ⇒ vanilla PG; `β=0` ⇒ no KL).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn clipped_pg_loss_batched(
|
||||
logits: &Var,
|
||||
target: &Tensor,
|
||||
logp_old: &[f32],
|
||||
logp_ref: &[f32],
|
||||
advantage: &[f32],
|
||||
weight: &[f32],
|
||||
eps: f32,
|
||||
beta: f32,
|
||||
) -> Var {
|
||||
use xtrain_tensor::Device;
|
||||
let logit_dtype = logits.value().dtype();
|
||||
let (probs, per_row) = logits.value().cross_entropy(target);
|
||||
let rows = per_row.shape()[0];
|
||||
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
|
||||
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per row");
|
||||
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per row");
|
||||
assert_eq!(advantage.len(), rows, "advantage must have one entry per row");
|
||||
assert_eq!(weight.len(), rows, "weight must have one entry per row");
|
||||
|
||||
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
|
||||
let mut loss_val = 0f32;
|
||||
for t in 0..rows {
|
||||
if target_h[t] < 0 {
|
||||
continue; // masked (prompt or pad) row — no contribution, no gradient
|
||||
}
|
||||
let (a, w) = (advantage[t], weight[t]);
|
||||
let lp = -per_row_h[t]; // logπθ_t
|
||||
let ratio = (lp - logp_old[t]).exp();
|
||||
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
|
||||
let (unclipped_term, clipped_term) = (ratio * a, clipped * a);
|
||||
let pg_t = unclipped_term.min(clipped_term);
|
||||
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
|
||||
let d = logp_ref[t] - lp;
|
||||
let kl_t = d.exp() - d - 1.0;
|
||||
let pg_grad = if active { -a * ratio } else { 0.0 };
|
||||
let kl_grad = beta * (1.0 - d.exp());
|
||||
// The full per-row normaliser is folded into s (no global inv_n here).
|
||||
s[t] = -(pg_grad + kl_grad) * w;
|
||||
loss_val += (-pg_t + beta * kl_t) * w;
|
||||
}
|
||||
let dev = logits.value().device();
|
||||
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
|
||||
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
|
||||
|
||||
let target = target.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![logits.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
|
||||
let mut dx = ce.scale_rows(&s_dev);
|
||||
if up != 1.0 {
|
||||
dx = dx.scale(up);
|
||||
}
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user