post-train: M2d — batch the GRPO training-side forwards (op + module + wiring)

After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backwards. M2d packs all N=B·G ragged
samples of a step into ONE forward_batched.

Enabling property — right-padding is free under causal attention: a real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal the unpadded single-sequence forward; pad
rows are masked out (target=-100).

- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
  (the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
  compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
  looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
  row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
  forward_batched(N) → slice back to real length) + looped baselines +
  inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
  [chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
  grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
  (rollout / capture / inner) to keep the step decomposition honest. Default micro =
  B·G; bench-measured 9× on the training forwards.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 23:02:56 +08:00
parent 41d46208a6
commit c2ebf62ae1
4 changed files with 303 additions and 49 deletions

View File

@@ -597,3 +597,87 @@ pub fn clipped_pg_loss(
}),
)
}
/// Batched GRPO clipped-PG loss over `N` ragged completions packed into ONE
/// `forward_batched` (M2d): `logits` is `[R, vocab]` with `R = N·Lmax` rows in
/// sequence-major order (sample 0's `Lmax` rows, then sample 1's, …), each ragged
/// completion right-padded to the batch's `Lmax`. Prompt AND pad rows are masked
/// (`target < 0`), so they contribute nothing and carry no gradient — the
/// **right-pad-is-free-under-causal-attention** property (a real completion row
/// never attends to the trailing pad rows, so its logits equal the unpadded
/// single-sequence forward's).
///
/// Unlike the per-sample [`clipped_pg_loss`] (which folds a single scalar
/// `advantage` and a global `1/N_tokens` normaliser), this op takes **per-row**
/// `advantage[t]` (the owning sample's group-relative `A`) and **per-row**
/// `weight[t]` (the full normaliser, e.g. `1/(N_samples · n_s)` for sample `s`'s
/// completion rows, `0` at masked rows). It does NOT compute its own `inv_n`. With
/// `weight[t] = 1/(N_samples·n_s)` and `advantage[t] = A_s` this is **bit-equivalent
/// to the looped path** `Σ_s scale·(1/n_s)·clipped_pg_loss_s` (`scale = 1/N_samples`):
/// the per-row backward is local (`cross_entropy_backward` is row-wise), so the
/// batched row-`t` gradient equals the looped sample-`s` row-`t` gradient, and the
/// scalar loss equals the looped weighted sum. (`tests/autograd.rs`:
/// `clipped_pg_loss_batched_matches_looped`.) Degenerate points match
/// [`clipped_pg_loss`] (`A=0` ⇒ KL only; `ε→∞` ⇒ vanilla PG; `β=0` ⇒ no KL).
#[allow(clippy::too_many_arguments)]
pub fn clipped_pg_loss_batched(
logits: &Var,
target: &Tensor,
logp_old: &[f32],
logp_ref: &[f32],
advantage: &[f32],
weight: &[f32],
eps: f32,
beta: f32,
) -> Var {
use xtrain_tensor::Device;
let logit_dtype = logits.value().dtype();
let (probs, per_row) = logits.value().cross_entropy(target);
let rows = per_row.shape()[0];
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per row");
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per row");
assert_eq!(advantage.len(), rows, "advantage must have one entry per row");
assert_eq!(weight.len(), rows, "weight must have one entry per row");
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
let mut loss_val = 0f32;
for t in 0..rows {
if target_h[t] < 0 {
continue; // masked (prompt or pad) row — no contribution, no gradient
}
let (a, w) = (advantage[t], weight[t]);
let lp = -per_row_h[t]; // logπθ_t
let ratio = (lp - logp_old[t]).exp();
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
let (unclipped_term, clipped_term) = (ratio * a, clipped * a);
let pg_t = unclipped_term.min(clipped_term);
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
let d = logp_ref[t] - lp;
let kl_t = d.exp() - d - 1.0;
let pg_grad = if active { -a * ratio } else { 0.0 };
let kl_grad = beta * (1.0 - d.exp());
// The full per-row normaliser is folded into s (no global inv_n here).
s[t] = -(pg_grad + kl_grad) * w;
loss_val += (-pg_t + beta * kl_t) * w;
}
let dev = logits.value().device();
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
let target = target.clone();
Var::from_op(
out,
vec![logits.clone()],
Box::new(move |d, parents| {
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
let mut dx = ce.scale_rows(&s_dev);
if up != 1.0 {
dx = dx.scale(up);
}
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),
)
}

View File

@@ -23,15 +23,15 @@ fn main() {
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use xtrain_autodiff::ops;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor};
use xtrain_model::{Config, TinyTransformer, generate_cached_batch};
#[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device};
#[cfg(not(no_cuda))]
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, per_token_logp_batched};
#[cfg(not(no_cuda))]
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
#[cfg(not(no_cuda))]
@@ -117,20 +117,6 @@ fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) ->
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
}
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= per_row
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
#[cfg(not(no_cuda))]
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
let logits = model.forward(&ids_tensor(input, device)).value();
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect()
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
@@ -149,6 +135,9 @@ fn main() {
let group: usize = flag(&args, "--group", 6);
let n_prompts: usize = flag(&args, "--prompts", 8);
let inner: usize = flag(&args, "--inner", 1);
// M2d: pack the step's N=B·G ragged samples into forward_batched chunks of this
// many samples (bounds the [chunk·Lmax, vocab] logits memory). Default = whole batch.
let micro: usize = flag(&args, "--micro", n_prompts * group.max(1));
let temp: f32 = flag(&args, "--temp", 1.0);
let beta: f32 = flag(&args, "--beta", 0.04);
let eps: f32 = flag(&args, "--eps", 0.2);
@@ -188,16 +177,17 @@ fn main() {
let start = std::time::Instant::now();
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
// Per-window phase timers (ms): rollout / capture / inner — to keep the step
// decomposition honest (M2d cut the training-side forwards 9×, so the question is
// what now dominates the step).
let (mut t_roll, mut t_cap, mut t_inner) = (0f32, 0f32, 0f32);
for step in 0..steps {
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
struct Sample {
input: Vec<i32>,
target: Vec<i32>,
adv: f32,
logp_old: Vec<f32>,
logp_ref: Vec<f32>,
}
let mut batch: Vec<Sample> = Vec::new();
// Collect ALL the step's framed samples first (input, target, adv), so the
// training-side forwards can be batched across the whole step (M2d) instead of
// run one ragged sequence at a time.
let t0 = std::time::Instant::now();
let mut raw: Vec<(Vec<i32>, Vec<i32>, f32)> = Vec::new();
for _ in 0..n_prompts {
let p = gen_problem(&mut rng, &gcfg);
let prompt_ids: Vec<i32> = tok
@@ -230,53 +220,69 @@ fn main() {
for (seg, r) in &comps {
let adv = (r - mean) / (std + 1e-4);
let (input, target) = frame(&tok, &p.question(), seg);
let logp_old = per_token_logp(&policy, device, &input, &target);
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
let logp_ref = match &reference {
Some(r) => per_token_logp(r, device, &input, &target),
None => vec![0.0; logp_old.len()],
};
batch.push(Sample { input, target, adv, logp_old, logp_ref });
raw.push((input, target, adv));
}
}
// ---- K inner clipped-PG epochs over the captured batch ----
if !batch.is_empty() {
let scale = 1.0 / batch.len() as f32;
t_roll += t0.elapsed().as_secs_f32() * 1e3;
// ---- Batched capture (M2d): logπ_old (policy) + logπ_ref (frozen) over ALL
// samples in forward_batched chunks, instead of one forward per sample. ----
if !raw.is_empty() {
let t1 = std::time::Instant::now();
let io: Vec<(Vec<i32>, Vec<i32>)> = raw.iter().map(|(i, t, _)| (i.clone(), t.clone())).collect();
let logp_old = per_token_logp_batched(&policy, device, &io, micro);
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
let logp_ref = match &reference {
Some(r) => per_token_logp_batched(r, device, &io, micro),
None => raw.iter().map(|(i, _, _)| vec![0.0; i.len()]).collect(),
};
let batch: Vec<PgSample> = raw
.iter()
.zip(logp_old)
.zip(logp_ref)
.map(|(((input, target, adv), lo), lr)| PgSample {
input: input.clone(),
target: target.clone(),
adv: *adv,
logp_old: lo,
logp_ref: lr,
})
.collect();
t_cap += t1.elapsed().as_secs_f32() * 1e3;
// ---- K inner clipped-PG epochs, batched over the captured samples ----
let t2 = std::time::Instant::now();
for _ in 0..inner {
for s in &batch {
let logits = policy.forward(&ids_tensor(&s.input, device));
let loss = ops::clipped_pg_loss(
&logits,
&ids_tensor(&s.target, device),
&s.logp_old,
&s.logp_ref,
s.adv,
eps,
beta,
);
ops::scale(&loss, scale).backward();
}
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
let _ = xtrain_train::clip::clip_grad_norm_gpu(&params, clip, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
}
t_inner += t2.elapsed().as_secs_f32() * 1e3;
}
if (step + 1) % log_every == 0 || step == steps - 1 {
let w = log_every.min(step + 1) as f32; // steps in this window
println!(
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s | ms/step roll {:.0} cap {:.0} inner {:.0}",
step + 1,
win_reward / win_n.max(1) as f32,
win_solved,
win_n,
start.elapsed().as_secs_f32(),
t_roll / w,
t_cap / w,
t_inner / w,
);
win_reward = 0.0;
win_solved = 0;
win_n = 0;
t_roll = 0.0;
t_cap = 0.0;
t_inner = 0.0;
// Periodic save so a later OOM (naive rollout fragments the allocator —
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save");

View File

@@ -0,0 +1,162 @@
//! Batched GRPO training-side forwards (post-training M2d). After M2b/M2c made the
//! rollout cheap, the GRPO **step** is dominated by the per-sample full-sequence
//! forwards: the `per_token_logp` captures (policy + reference) and the inner
//! clipped-PG `forward`/`backward`s — each a single-sequence `forward` over a short
//! ragged completion. This module packs the `N = B·G` ragged samples of a step into
//! ONE `forward_batched`, amortising the per-launch overhead across N (the same win
//! M2b gave the rollout).
//!
//! The enabling property: **right-padding is free under causal attention.** Pad each
//! ragged completion on the RIGHT to the batch's `Lmax`; a real completion row is at
//! an earlier position than the trailing pad, and causal masking forbids attending
//! forward, so its logits are bit-identical to the unpadded single-sequence forward.
//! The pad rows' own outputs are garbage but are masked out (`target = -100`).
//!
//! Both the looped (baseline) and batched paths live here so they share one source of
//! truth — `bin/bench_grpo_batch` A/Bs them (timing + a closeness gate), and the
//! per-row equivalence of the loss op is pinned by `clipped_pg_loss_batched_matches_looped`
//! in `xtrain-autodiff/tests/autograd.rs`.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_model::{TinyTransformer, ids_tensor};
use xtrain_tensor::{Device, Tensor};
/// One framed completion of a GRPO step: the next-token `(input, target)` pair
/// (prompt positions masked to `-100` in `target`), its group-relative `adv`, and the
/// per-position rollout-time / reference logprobs the clipped-PG loss needs.
pub struct PgSample {
pub input: Vec<i32>,
pub target: Vec<i32>,
pub adv: f32,
pub logp_old: Vec<f32>,
pub logp_ref: Vec<f32>,
}
// ------------------------------- looped (baseline) -------------------------------
/// Per-position `logπ(target_t)` of one framed `(input, target)` pair (= `per_row`
/// of cross_entropy; masked positions are 0). One single-sequence forward, no grad.
pub fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
let logits = model.forward(&ids_tensor(input, device)).value();
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect()
}
/// One inner clipped-PG epoch the looped way: per sample, a single-sequence forward +
/// [`ops::clipped_pg_loss`] scaled by `1/N` + backward (grads accumulate on `model`'s
/// params). Returns the summed scaled loss. Caller does clip + opt.step + zero_grad.
pub fn inner_pg_step_looped(
model: &TinyTransformer,
device: Device,
batch: &[PgSample],
eps: f32,
beta: f32,
) -> f32 {
let scale = 1.0 / batch.len() as f32;
let mut total = 0f32;
for s in batch {
let logits = model.forward(&ids_tensor(&s.input, device));
let loss = ops::clipped_pg_loss(&logits, &ids_tensor(&s.target, device), &s.logp_old, &s.logp_ref, s.adv, eps, beta);
let scaled = ops::scale(&loss, scale);
total += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
scaled.backward();
}
total
}
// ------------------------------- batched (M2d) -----------------------------------
/// Right-pad `m` ragged `i32` rows (each `< lmax` long) to `[m*lmax]` sequence-major,
/// filling with `pad`. Used for both the id stream (pad = 0, arbitrary) and the target
/// stream (pad = 100, ignored by cross_entropy).
fn pack_i32(rows: &[&[i32]], lmax: usize, pad: i32) -> Vec<i32> {
let mut flat = vec![pad; rows.len() * lmax];
for (i, r) in rows.iter().enumerate() {
flat[i * lmax..i * lmax + r.len()].copy_from_slice(r);
}
flat
}
/// Batched [`per_token_logp`]: pack `samples` (each `(input, target)`) right-padded to
/// `Lmax`, run ONE `forward_batched(batch = N)`, and slice each sample's `logπ` back to
/// its real length. Equal to looping [`per_token_logp`] (right-pad is free under causal
/// attention), to bf16 batch-reduction tolerance. `samples` are processed in chunks of
/// `micro` (≥1) to bound the `[chunk*Lmax, vocab]` logits memory.
pub fn per_token_logp_batched(
model: &TinyTransformer,
device: Device,
samples: &[(Vec<i32>, Vec<i32>)],
micro: usize,
) -> Vec<Vec<f32>> {
let mut out = Vec::with_capacity(samples.len());
for chunk in samples.chunks(micro.max(1)) {
let m = chunk.len();
let lmax = chunk.iter().map(|(i, _)| i.len()).max().unwrap();
let ins: Vec<&[i32]> = chunk.iter().map(|(i, _)| i.as_slice()).collect();
let tgs: Vec<&[i32]> = chunk.iter().map(|(_, t)| t.as_slice()).collect();
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
let logits = model.forward_batched(&ids, m).value();
let (_, per_row) = logits.cross_entropy(&tgt);
let pr = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
for (i, (inp, _)) in chunk.iter().enumerate() {
let b = i * lmax;
out.push((0..inp.len()).map(|r| -pr[b + r]).collect());
}
}
out
}
/// One inner clipped-PG epoch, batched: pack the batch (in `micro`-sized chunks) and run
/// ONE `forward_batched` + [`ops::clipped_pg_loss_batched`] + backward per chunk. The
/// per-row `weight = 1/(N·n_s)` uses the GLOBAL `N = batch.len()` (not the chunk size),
/// so chunked grad-accumulation reproduces the looped `Σ_s (1/N)(1/n_s)…` exactly.
/// Returns the summed loss. Caller does clip + opt.step + zero_grad.
pub fn inner_pg_step_batched(
model: &TinyTransformer,
device: Device,
batch: &[PgSample],
eps: f32,
beta: f32,
micro: usize,
) -> f32 {
let inv_n = 1.0 / batch.len() as f32;
let mut total = 0f32;
for chunk in batch.chunks(micro.max(1)) {
let m = chunk.len();
let lmax = chunk.iter().map(|s| s.input.len()).max().unwrap();
let ins: Vec<&[i32]> = chunk.iter().map(|s| s.input.as_slice()).collect();
let tgs: Vec<&[i32]> = chunk.iter().map(|s| s.target.as_slice()).collect();
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
let mut logp_old = vec![0f32; m * lmax];
let mut logp_ref = vec![0f32; m * lmax];
let mut advantage = vec![0f32; m * lmax];
let mut weight = vec![0f32; m * lmax];
for (i, s) in chunk.iter().enumerate() {
let b = i * lmax;
let li = s.input.len();
logp_old[b..b + li].copy_from_slice(&s.logp_old);
logp_ref[b..b + li].copy_from_slice(&s.logp_ref);
let n_s = s.target.iter().filter(|&&t| t >= 0).count().max(1) as f32;
let w = inv_n / n_s; // = 1/(N · n_s)
for r in 0..lmax {
advantage[b + r] = s.adv;
weight[b + r] = w;
}
}
let logits = model.forward_batched(&ids, m);
let loss = ops::clipped_pg_loss_batched(&logits, &tgt, &logp_old, &logp_ref, &advantage, &weight, eps, beta);
total += loss.value().to_device(Device::Cpu).as_slice::<f32>()[0];
loss.backward();
}
total
}

View File

@@ -15,6 +15,8 @@ pub mod task;
#[cfg(not(no_cuda))]
pub mod checkpoint;
#[cfg(not(no_cuda))]
pub mod grpo_batch;
#[cfg(not(no_cuda))]
pub mod sample;
#[cfg(not(no_cuda))]
mod train_loop;