Compare commits

...

3 Commits

Author SHA1 Message Date
4379868f2d docs: M2d — ragged-batching lever, 9× measured, step bottleneck → rollout
Records the M2d lever (batch the GRPO training-side forwards), the right-pad-is-free
insight, both exact gates, the end-to-end no-OOM smoke, and the 9× throughput.

The honest decomposition correction: M2c claimed the training forwards "dominate" the
step; the clean per-component bench falsifies the strong form — they were ~2.5 s of
the ~8.5 s step (~30%), worth the 9×, but the rollout (~6 s) was always the larger
share. After M2d the step is ~95% rollout, so the next step-level lever is full B×G
rollout batching (today only the G samples of each prompt decode in lockstep; the B
prompts are still sequential). Same measure-first lesson, once more.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:03:28 +08:00
0e82b2438e test: M2d — ragged-forward + batched-op equivalence gates + throughput bench
Two exact correctness gates (composed = the end-to-end batched GRPO step == looped):
- xtrain-model forward_batched_ragged_matches_looped: forward_batched on RIGHT-padded
  ragged sequences == per-sequence single-seq forward on the real rows. fp32
  max|Δlogit| = 3.7e-7, bf16 = 0.0, both composed + flash SDPA. Pins "right-pad is
  free under causal".
- xtrain-autodiff clipped_pg_loss_batched_matches_looped: batched op == looped
  Σ_s (1/N)·clipped_pg_loss_s. loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32).

bench_grpo_batch: weight-independent micro-bench of the per-sample training forwards
(loads v12 base as policy, N realistic ragged samples, teacher-forced argmax targets
so the closeness smoke isn't −log-amplified by random low-prob tokens). Measured on
dash5 (v12 1.05B, N=48, micro=16): capture 622→71 ms (8.7×), inner 1907→208 ms
(9.2×), training forwards 2526→280 ms (9.0×).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:03:09 +08:00
c2ebf62ae1 post-train: M2d — batch the GRPO training-side forwards (op + module + wiring)
After M2b/M2c made the rollout cheap, the GRPO step is dominated by the per-sample
single-sequence training-side forwards: the per_token_logp captures (policy +
reference) and the inner clipped-PG forward/backwards. M2d packs all N=B·G ragged
samples of a step into ONE forward_batched.

Enabling property — right-padding is free under causal attention: a real completion
row sits at an earlier position than the trailing pad, and causal masking forbids
attending forward, so its logits equal the unpadded single-sequence forward; pad
rows are masked out (target=-100).

- ops::clipped_pg_loss_batched: like clipped_pg_loss but takes per-row advantage[t]
  (the owning sample's A) and per-row weight[t] (the full normaliser). It does NOT
  compute its own 1/n_tokens, so the caller passing weight=1/(N·n_s) reproduces the
  looped Σ_s (1/N)(1/n_s)·clipped_pg_loss_s bit-for-bit (per-row CE backward is
  row-local).
- grpo_batch.rs (shared module): per_token_logp_batched (right-pad → one
  forward_batched(N) → slice back to real length) + looped baselines +
  inner_pg_step_{looped,batched}. A --micro knob chunks the pack to bound the
  [chunk·Lmax, vocab] logits memory; weight uses the GLOBAL N so chunked
  grad-accumulation stays exact.
- train_grpo restructured to collect-all-samples-then-batch; per-window phase timers
  (rollout / capture / inner) to keep the step decomposition honest. Default micro =
  B·G; bench-measured 9× on the training forwards.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:02:56 +08:00
9 changed files with 812 additions and 49 deletions

View File

@@ -597,3 +597,87 @@ pub fn clipped_pg_loss(
}),
)
}
/// Batched GRPO clipped-PG loss over `N` ragged completions packed into ONE
/// `forward_batched` (M2d): `logits` is `[R, vocab]` with `R = N·Lmax` rows in
/// sequence-major order (sample 0's `Lmax` rows, then sample 1's, …), each ragged
/// completion right-padded to the batch's `Lmax`. Prompt AND pad rows are masked
/// (`target < 0`), so they contribute nothing and carry no gradient — the
/// **right-pad-is-free-under-causal-attention** property (a real completion row
/// never attends to the trailing pad rows, so its logits equal the unpadded
/// single-sequence forward's).
///
/// Unlike the per-sample [`clipped_pg_loss`] (which folds a single scalar
/// `advantage` and a global `1/N_tokens` normaliser), this op takes **per-row**
/// `advantage[t]` (the owning sample's group-relative `A`) and **per-row**
/// `weight[t]` (the full normaliser, e.g. `1/(N_samples · n_s)` for sample `s`'s
/// completion rows, `0` at masked rows). It does NOT compute its own `inv_n`. With
/// `weight[t] = 1/(N_samples·n_s)` and `advantage[t] = A_s` this is **bit-equivalent
/// to the looped path** `Σ_s scale·(1/n_s)·clipped_pg_loss_s` (`scale = 1/N_samples`):
/// the per-row backward is local (`cross_entropy_backward` is row-wise), so the
/// batched row-`t` gradient equals the looped sample-`s` row-`t` gradient, and the
/// scalar loss equals the looped weighted sum. (`tests/autograd.rs`:
/// `clipped_pg_loss_batched_matches_looped`.) Degenerate points match
/// [`clipped_pg_loss`] (`A=0` ⇒ KL only; `ε→∞` ⇒ vanilla PG; `β=0` ⇒ no KL).
#[allow(clippy::too_many_arguments)]
pub fn clipped_pg_loss_batched(
logits: &Var,
target: &Tensor,
logp_old: &[f32],
logp_ref: &[f32],
advantage: &[f32],
weight: &[f32],
eps: f32,
beta: f32,
) -> Var {
use xtrain_tensor::Device;
let logit_dtype = logits.value().dtype();
let (probs, per_row) = logits.value().cross_entropy(target);
let rows = per_row.shape()[0];
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per row");
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per row");
assert_eq!(advantage.len(), rows, "advantage must have one entry per row");
assert_eq!(weight.len(), rows, "weight must have one entry per row");
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
let mut loss_val = 0f32;
for t in 0..rows {
if target_h[t] < 0 {
continue; // masked (prompt or pad) row — no contribution, no gradient
}
let (a, w) = (advantage[t], weight[t]);
let lp = -per_row_h[t]; // logπθ_t
let ratio = (lp - logp_old[t]).exp();
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
let (unclipped_term, clipped_term) = (ratio * a, clipped * a);
let pg_t = unclipped_term.min(clipped_term);
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
let d = logp_ref[t] - lp;
let kl_t = d.exp() - d - 1.0;
let pg_grad = if active { -a * ratio } else { 0.0 };
let kl_grad = beta * (1.0 - d.exp());
// The full per-row normaliser is folded into s (no global inv_n here).
s[t] = -(pg_grad + kl_grad) * w;
loss_val += (-pg_t + beta * kl_t) * w;
}
let dev = logits.value().device();
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
let target = target.clone();
Var::from_op(
out,
vec![logits.clone()],
Box::new(move |d, parents| {
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
let mut dx = ce.scale_rows(&s_dev);
if up != 1.0 {
dx = dx.scale(up);
}
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),
)
}

View File

@@ -1177,3 +1177,94 @@ fn clipped_pg_loss_bwd_and_degenerate() {
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
}
// clipped_pg_loss_batched (M2d): N ragged completions packed + right-padded into ONE
// forward must equal the looped per-sample path Σ_s (1/N)·clipped_pg_loss_s. The
// per-row CE backward is row-local, so folding weight = 1/(N·n_s) into the batched
// op reproduces the looped gradient and weighted-sum loss bit-for-bit (f32 path).
#[test]
fn clipped_pg_loss_batched_matches_looped() {
require_gpu();
let (n, lmax, cols) = (3usize, 5usize, 10usize);
let rows = n * lmax;
let x_h = fill(rows * cols, 909);
// Per sample: row 0 = prompt (-100); rows 1..real_len = completion; rest = pad
// (-100). Different real_len ⇒ n_s = {2, 3, 1} completion rows.
let real_len = [3usize, 4, 2];
let adv_s = [0.7f32, -0.5, 0.3];
let mut targets = vec![-100i32; rows];
for s in 0..n {
for r in 1..real_len[s] {
let t = s * lmax + r;
targets[t] = ((t * 3) % cols) as i32;
}
}
let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
// logp_old ≈ logπθ at base logits (ρ≈1), logp_ref offset to exercise the KL term.
let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
let logp_old: Vec<f32> = per_row0
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect();
let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect();
let (eps, beta) = (0.2f32, 0.1f32);
// Per-row advantage (sample's A) + per-row weight 1/(N·n_s) (full normaliser).
let n_of = |s: usize| (0..lmax).filter(|&r| targets[s * lmax + r] >= 0).count() as f32;
let mut advantage = vec![0f32; rows];
let mut weight = vec![0f32; rows];
for s in 0..n {
let w = (1.0 / n as f32) * (1.0 / n_of(s));
for r in 0..lmax {
advantage[s * lmax + r] = adv_s[s];
weight[s * lmax + r] = w;
}
}
// Batched: one packed [R, vocab] forward + one backward.
let xb = Var::leaf(cuda(&x_h, &[rows, cols]));
let lb = ops::clipped_pg_loss_batched(
&xb, &mk_target(), &logp_old, &logp_ref, &advantage, &weight, eps, beta,
);
lb.backward();
let gb = xb.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
let lb_val = lb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
// Looped reference: per-sample slice → clipped_pg_loss → scale(1/N) → backward.
let mut g_ref = vec![0f32; rows * cols];
let mut loss_ref = 0f32;
for s in 0..n {
let r0 = s * lmax;
let xs_h = x_h[r0 * cols..(r0 + lmax) * cols].to_vec();
let tgt_s: Vec<i32> = targets[r0..r0 + lmax].to_vec();
let lo_s = logp_old[r0..r0 + lmax].to_vec();
let lr_s = logp_ref[r0..r0 + lmax].to_vec();
let xs = Var::leaf(cuda(&xs_h, &[lmax, cols]));
let tgt = Tensor::from_slice(&tgt_s, &[lmax]).to_device(Device::Cuda(0));
let ls = ops::clipped_pg_loss(&xs, &tgt, &lo_s, &lr_s, adv_s[s], eps, beta);
let scaled = ops::scale(&ls, 1.0 / n as f32);
scaled.backward();
let gs = xs.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>().to_vec();
g_ref[r0 * cols..(r0 + lmax) * cols].copy_from_slice(&gs);
loss_ref += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
}
let max_g = gb
.iter()
.zip(&g_ref)
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(
(lb_val - loss_ref).abs() < 1e-5,
"batched loss {lb_val} vs looped {loss_ref}"
);
assert!(max_g < 1e-5, "batched grad vs looped: max|Δ| = {max_g}");
println!(
"clipped_pg_loss_batched OK: loss Δ={:.2e}, grad max|Δ|={:.2e} (== looped Σ_s 1/N·pg_s)",
(lb_val - loss_ref).abs(),
max_g
);
}

View File

@@ -0,0 +1,97 @@
// M2d gate: does forward_batched on RIGHT-PADDED ragged sequences reproduce the
// per-sequence single-seq forward on the real (non-pad) rows? The batched GRPO
// training-side forwards depend on this "right-pad is free under causal attention"
// property — a real completion row is at an earlier position than the trailing pad,
// and causal masking forbids attending forward, so its logits should be unchanged.
//
// Tested in fp32 (exact) over both SDPA cores (composed + fused flash), since the
// bench uses flash and a kernel could in principle leak the pad keys into the online
// softmax.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, ids_tensor};
use xtrain_tensor::{DType, Device, Tensor};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
});
m.with_compute_dtype(dtype).with_flash(flash)
}
fn host(t: &Tensor) -> Vec<f32> {
t.to_dtype(DType::F32).to_device(Device::Cpu).as_slice::<f32>().to_vec()
}
#[test]
fn forward_batched_ragged_matches_looped() {
if device::device_count().unwrap_or(0) == 0 {
eprintln!("no CUDA device; skipping");
return;
}
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let mut cfg = Config::tiny();
cfg.vocab = 32;
cfg.n_layers = 2;
let vocab = cfg.vocab;
// Ragged lengths incl. one crossing the flash tile (>32) and short ones.
let lens = [6usize, 40, 9, 4];
let lmax = *lens.iter().max().unwrap();
let n = lens.len();
let seqs: Vec<Vec<i32>> = lens
.iter()
.enumerate()
.map(|(b, &l)| (0..l).map(|i| ((b * 7 + i * 3 + 1) % vocab) as i32).collect())
.collect();
for (dtype, tol) in [(DType::F32, 2e-3f32), (DType::BF16, 3e-1f32)] {
for flash in [false, true] {
let m = build(cfg, device, dtype, flash);
// Looped: each sequence on its own (the ground truth).
let looped: Vec<Vec<f32>> = seqs.iter().map(|s| host(&m.forward(&ids_tensor(s, device)).value())).collect();
// Batched: right-pad each to lmax (pad id 0), one forward_batched(batch = n).
let mut flat = vec![0i32; n * lmax];
for (i, s) in seqs.iter().enumerate() {
flat[i * lmax..i * lmax + s.len()].copy_from_slice(s);
}
let ids = Tensor::from_slice(&flat, &[n * lmax]).to_device(device);
let batched = host(&m.forward_batched(&ids, n).value()); // [n*lmax, vocab]
let mut dmax = 0f32;
for (i, s) in seqs.iter().enumerate() {
for r in 0..s.len() {
for c in 0..vocab {
let a = looped[i][r * vocab + c];
let b = batched[(i * lmax + r) * vocab + c];
dmax = dmax.max((a - b).abs());
}
}
}
println!("dtype={dtype:?} flash={flash}: ragged right-pad vs looped, max|Δlogit| (real rows) = {dmax:.3e}");
assert!(dmax < tol, "dtype={dtype:?} flash={flash}: right-pad NOT free under causal — max|Δ| = {dmax}");
}
}
println!("forward_batched_ragged_matches_looped OK: right-pad is free under causal (fp32+bf16, composed + flash)");
}

View File

@@ -0,0 +1,268 @@
//! Micro-benchmark + closeness gate for the M2d batched GRPO training-side forwards.
//!
//! After M2b/M2c the GRPO *step* is no longer rollout-bound — it is the `N = B·G`
//! per-sample full-sequence forwards (the `per_token_logp` captures + the inner
//! clipped-PG forward/backwards). This bin isolates exactly that, weight-independently
//! (step wall-clock depends on shapes + launch counts, not on what the weights are), by
//! synthesising `N` realistic ragged samples and A/B-timing the looped vs batched path
//! for BOTH phases — plus asserting they agree numerically (the looped-vs-batched
//! closeness gate; per-row bit-equivalence of the loss op is pinned by the autograd
//! test `clipped_pg_loss_batched_matches_looped`).
//!
//! bench_grpo_batch <tokenizer.json> --init-ckpt <base.ckpt> <arch flags> \
//! --n 48 --plen 12 --clen 24 --micro 16 --reps 3
#[cfg(no_cuda)]
fn main() {
eprintln!("bench_grpo_batch: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer};
#[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device, Tensor};
#[cfg(not(no_cuda))]
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, inner_pg_step_looped, per_token_logp, per_token_logp_batched};
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed.wrapping_mul(2862933555777941757).wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter().position(|a| a == name).and_then(|i| args.get(i + 1)).and_then(|s| s.parse().ok()).unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter().position(|a| a == name).and_then(|i| args.get(i + 1)).cloned()
}
#[cfg(not(no_cuda))]
fn load_model(cfg: Config, device: Device, ckpt: &str) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.04)
}
})
.with_compute_dtype(DType::BF16)
.with_flash(true);
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
m.eval();
m
}
#[cfg(not(no_cuda))]
fn elapsed_ms<F: FnMut()>(reps: usize, mut f: F) -> f32 {
let start = std::time::Instant::now();
for _ in 0..reps {
f();
}
start.elapsed().as_secs_f32() * 1e3 / reps as f32
}
/// Per-position argmax of the model over each ragged `input` (one `forward_batched`
/// per `micro`-chunk). Used to teacher-force WELL-CONDITIONED targets (the top-1 token,
/// high prob) so the closeness gate's logp isn't the ~20 of a random token — where
/// `log p` amplifies bf16 noise. This matches real GRPO (targets are model samples).
#[cfg(not(no_cuda))]
fn model_argmax(model: &TinyTransformer, device: Device, inputs: &[Vec<i32>], vocab: usize, micro: usize) -> Vec<Vec<i32>> {
let mut out = Vec::with_capacity(inputs.len());
for chunk in inputs.chunks(micro.max(1)) {
let m = chunk.len();
let lmax = chunk.iter().map(|s| s.len()).max().unwrap();
let mut flat = vec![0i32; m * lmax];
for (i, s) in chunk.iter().enumerate() {
flat[i * lmax..i * lmax + s.len()].copy_from_slice(s);
}
let ids = Tensor::from_slice(&flat, &[m * lmax]).to_device(device);
let logits = model.forward_batched(&ids, m).value().to_dtype(DType::F32).to_device(Device::Cpu);
let v = logits.as_slice::<f32>();
for (i, s) in chunk.iter().enumerate() {
let mut row = Vec::with_capacity(s.len());
for r in 0..s.len() {
let base = (i * lmax + r) * vocab;
let mut best = 0usize;
for c in 1..vocab {
if v[base + c] > v[base + best] {
best = c;
}
}
row.push(best as i32);
}
out.push(row);
}
}
out
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
let args: Vec<String> = std::env::args().collect();
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let tok_path = positionals.first().expect("usage: bench_grpo_batch <tokenizer.json> [flags]");
let n_heads = flag(&args, "--heads", 52usize);
let head_dim = flag(&args, "--head-dim", 32usize);
let n_layers = flag(&args, "--layers", 22usize);
let ffn = flag(&args, "--ffn", 6656usize);
let kv_heads = flag(&args, "--kv-heads", n_heads);
let n: usize = flag(&args, "--n", 48); // B·G samples per step
let plen: usize = flag(&args, "--plen", 12); // prompt tokens
let clen: usize = flag(&args, "--clen", 24); // max completion tokens
let micro: usize = flag(&args, "--micro", 16);
let reps: usize = flag(&args, "--reps", 3);
let (eps, beta) = (flag(&args, "--eps", 0.2f32), flag(&args, "--beta", 0.0f32));
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <base.ckpt> required");
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
let vocab = tok.vocab_size();
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
let policy = load_model(cfg, device, &init_ckpt);
let params = policy.params();
// --- Synthesise N ragged samples (frame-shaped: prompt masked, ragged completion).
// Token IDs are random-but-valid; only the SHAPES drive the forward cost.
let mut rng = 0xC0FFEEu64;
let mut next = || {
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
(rng >> 33) as usize
};
let mut io: Vec<(Vec<i32>, Vec<i32>)> = Vec::with_capacity(n);
let mut advs: Vec<f32> = Vec::with_capacity(n);
for _ in 0..n {
let pl = plen.saturating_sub(2) + next() % 5; // jitter prompt length a little
let cl = 4 + next() % clen.max(1); // completion 4..=clen
let total = pl + cl;
let toks: Vec<i32> = (0..total).map(|_| (next() % vocab) as i32).collect();
let mut labels = vec![-100i32; pl]; // prompt masked
labels.extend_from_slice(&toks[pl..]);
let l = toks.len();
io.push((toks[..l - 1].to_vec(), labels[1..l].to_vec())); // target masked at [..pl-1]
advs.push(if next() % 2 == 0 { 0.7 } else { -0.7 });
}
let toklens: Vec<usize> = io.iter().map(|(i, _)| i.len()).collect();
let (lmin, lmax) = (*toklens.iter().min().unwrap(), *toklens.iter().max().unwrap());
println!("samples N={n}, seq len {lmin}..{lmax} (ragged), micro={micro}, β={beta}\n");
// Replace random completion targets with the model's own argmax (teacher forcing):
// well-conditioned logp (top-1, not the ~20 of a random token where bf16 noise
// blows up via log p). The completion target positions are where the skeleton is
// ≥0; prompt positions stay masked (100).
let inputs: Vec<Vec<i32>> = io.iter().map(|(i, _)| i.clone()).collect();
let preds = model_argmax(&policy, device, &inputs, vocab, micro);
for (s, (_, target)) in io.iter_mut().enumerate() {
for j in 0..target.len() {
if target[j] >= 0 {
target[j] = preds[s][j];
}
}
}
// ---------------- Phase 1: capture (per_token_logp) ----------------
let logp_loop: Vec<Vec<f32>> = io.iter().map(|(i, t)| per_token_logp(&policy, device, i, t)).collect();
let logp_batch = per_token_logp_batched(&policy, device, &io, micro);
let cap_dmax = logp_loop
.iter()
.zip(&logp_batch)
.flat_map(|(a, b)| a.iter().zip(b).map(|(x, y)| (x - y).abs()))
.fold(0.0f32, f32::max);
let t_cap_loop = elapsed_ms(reps, || {
let _: Vec<Vec<f32>> = io.iter().map(|(i, t)| per_token_logp(&policy, device, i, t)).collect();
});
let t_cap_batch = elapsed_ms(reps, || {
let _ = per_token_logp_batched(&policy, device, &io, micro);
});
// Build PgSamples from the (matching) capture; ref = old 0.3 to exercise KL.
let batch: Vec<PgSample> = io
.iter()
.zip(&advs)
.zip(&logp_batch)
.map(|(((input, target), &adv), lp)| PgSample {
input: input.clone(),
target: target.clone(),
adv,
logp_old: lp.clone(),
logp_ref: lp.iter().map(|v| v - 0.3).collect(),
})
.collect();
// ---------------- Phase 2: inner clipped-PG (forward + backward) ----------------
// Representative grad snapshots: layer-0 wq (params[2]) + final_norm.
let wq0 = &params[2];
let fnorm = &params[1 + n_layers * 11];
let snap = |v: &xtrain_autodiff::Var| -> Vec<f32> {
v.grad().map(|g| g.to_device(Device::Cpu).as_slice::<f32>().to_vec()).unwrap_or_default()
};
let zero = |ps: &[xtrain_autodiff::Var]| ps.iter().for_each(|p| p.zero_grad());
zero(&params);
inner_pg_step_looped(&policy, device, &batch, eps, beta);
let (gq_loop, gn_loop) = (snap(wq0), snap(fnorm));
zero(&params);
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
let (gq_batch, gn_batch) = (snap(wq0), snap(fnorm));
zero(&params);
let reldiff = |a: &[f32], b: &[f32]| -> f32 {
let num = a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max);
let den = a.iter().map(|x| x.abs()).fold(0.0f32, f32::max).max(1e-12);
num / den
};
let gq_rel = reldiff(&gq_loop, &gq_batch);
let gn_rel = reldiff(&gn_loop, &gn_batch);
// Time only forward+backward — the lever. opt.step + grad-clip are identical in
// both paths (one call over `params` after the per-sample loop), so they would
// only add a constant; excluding them also dodges the unrelated 1B-Adam-state
// memory wall (the M4 finding) that this diagnostic doesn't need to reproduce.
let t_inner_loop = elapsed_ms(reps, || {
inner_pg_step_looped(&policy, device, &batch, eps, beta);
zero(&params);
});
let t_inner_batch = elapsed_ms(reps, || {
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
zero(&params);
});
// ---------------- Report ----------------
let spd = |a: f32, b: f32| if b > 0.0 { a / b } else { 0.0 };
println!("=== closeness gate (looped vs batched) ===");
println!(" capture per_token_logp : max|Δ| = {cap_dmax:.3e}");
println!(" inner grad wq[0] : rel|Δ| = {gq_rel:.3e}");
println!(" inner grad final_norm : rel|Δ| = {gn_rel:.3e}");
println!("\n=== timing (mean of {reps} reps, ms/phase) ===");
println!(" capture : looped {t_cap_loop:8.1} batched {t_cap_batch:8.1} ({:.2}× )", spd(t_cap_loop, t_cap_batch));
println!(" inner : looped {t_inner_loop:8.1} batched {t_inner_batch:8.1} ({:.2}× )", spd(t_inner_loop, t_inner_batch));
let (step_loop, step_batch) = (t_cap_loop + t_inner_loop, t_cap_batch + t_inner_batch);
println!(" STEP : looped {step_loop:8.1} batched {step_batch:8.1} ({:.2}× )", spd(step_loop, step_batch));
// The RIGOROUS correctness gates live in the test suite (exact, not bf16-noisy):
// - xtrain-model forward_batched_ragged_matches_looped (forward+pad == looped)
// - xtrain-autodiff clipped_pg_loss_batched_matches_looped (op == looped, f32)
// This is a smoke check at the 1B/bf16 scale: single-seq vs batched GEMM differ in
// batch-reduction order, so a loose band, with well-conditioned (argmax) targets.
assert!(cap_dmax < 0.2, "capture closeness smoke FAILED: max|Δlogp| = {cap_dmax}");
assert!(gq_rel < 0.2 && gn_rel < 0.2, "inner grad closeness smoke FAILED: wq {gq_rel}, fn {gn_rel}");
println!("\nSMOKE PASS (bf16 band): batched ≈ looped; rigorous gates are the two tests above.");
}

View File

@@ -23,15 +23,15 @@ fn main() {
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use xtrain_autodiff::ops;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor};
use xtrain_model::{Config, TinyTransformer, generate_cached_batch};
#[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device};
#[cfg(not(no_cuda))]
use xtrain_train::grpo_batch::{PgSample, inner_pg_step_batched, per_token_logp_batched};
#[cfg(not(no_cuda))]
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
#[cfg(not(no_cuda))]
@@ -117,20 +117,6 @@ fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) ->
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
}
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= per_row
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
#[cfg(not(no_cuda))]
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
let logits = model.forward(&ids_tensor(input, device)).value();
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect()
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
@@ -149,6 +135,9 @@ fn main() {
let group: usize = flag(&args, "--group", 6);
let n_prompts: usize = flag(&args, "--prompts", 8);
let inner: usize = flag(&args, "--inner", 1);
// M2d: pack the step's N=B·G ragged samples into forward_batched chunks of this
// many samples (bounds the [chunk·Lmax, vocab] logits memory). Default = whole batch.
let micro: usize = flag(&args, "--micro", n_prompts * group.max(1));
let temp: f32 = flag(&args, "--temp", 1.0);
let beta: f32 = flag(&args, "--beta", 0.04);
let eps: f32 = flag(&args, "--eps", 0.2);
@@ -188,16 +177,17 @@ fn main() {
let start = std::time::Instant::now();
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
// Per-window phase timers (ms): rollout / capture / inner — to keep the step
// decomposition honest (M2d cut the training-side forwards 9×, so the question is
// what now dominates the step).
let (mut t_roll, mut t_cap, mut t_inner) = (0f32, 0f32, 0f32);
for step in 0..steps {
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
struct Sample {
input: Vec<i32>,
target: Vec<i32>,
adv: f32,
logp_old: Vec<f32>,
logp_ref: Vec<f32>,
}
let mut batch: Vec<Sample> = Vec::new();
// Collect ALL the step's framed samples first (input, target, adv), so the
// training-side forwards can be batched across the whole step (M2d) instead of
// run one ragged sequence at a time.
let t0 = std::time::Instant::now();
let mut raw: Vec<(Vec<i32>, Vec<i32>, f32)> = Vec::new();
for _ in 0..n_prompts {
let p = gen_problem(&mut rng, &gcfg);
let prompt_ids: Vec<i32> = tok
@@ -230,53 +220,69 @@ fn main() {
for (seg, r) in &comps {
let adv = (r - mean) / (std + 1e-4);
let (input, target) = frame(&tok, &p.question(), seg);
let logp_old = per_token_logp(&policy, device, &input, &target);
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
let logp_ref = match &reference {
Some(r) => per_token_logp(r, device, &input, &target),
None => vec![0.0; logp_old.len()],
};
batch.push(Sample { input, target, adv, logp_old, logp_ref });
raw.push((input, target, adv));
}
}
// ---- K inner clipped-PG epochs over the captured batch ----
if !batch.is_empty() {
let scale = 1.0 / batch.len() as f32;
t_roll += t0.elapsed().as_secs_f32() * 1e3;
// ---- Batched capture (M2d): logπ_old (policy) + logπ_ref (frozen) over ALL
// samples in forward_batched chunks, instead of one forward per sample. ----
if !raw.is_empty() {
let t1 = std::time::Instant::now();
let io: Vec<(Vec<i32>, Vec<i32>)> = raw.iter().map(|(i, t, _)| (i.clone(), t.clone())).collect();
let logp_old = per_token_logp_batched(&policy, device, &io, micro);
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
let logp_ref = match &reference {
Some(r) => per_token_logp_batched(r, device, &io, micro),
None => raw.iter().map(|(i, _, _)| vec![0.0; i.len()]).collect(),
};
let batch: Vec<PgSample> = raw
.iter()
.zip(logp_old)
.zip(logp_ref)
.map(|(((input, target, adv), lo), lr)| PgSample {
input: input.clone(),
target: target.clone(),
adv: *adv,
logp_old: lo,
logp_ref: lr,
})
.collect();
t_cap += t1.elapsed().as_secs_f32() * 1e3;
// ---- K inner clipped-PG epochs, batched over the captured samples ----
let t2 = std::time::Instant::now();
for _ in 0..inner {
for s in &batch {
let logits = policy.forward(&ids_tensor(&s.input, device));
let loss = ops::clipped_pg_loss(
&logits,
&ids_tensor(&s.target, device),
&s.logp_old,
&s.logp_ref,
s.adv,
eps,
beta,
);
ops::scale(&loss, scale).backward();
}
inner_pg_step_batched(&policy, device, &batch, eps, beta, micro);
let _ = xtrain_train::clip::clip_grad_norm_gpu(&params, clip, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
}
t_inner += t2.elapsed().as_secs_f32() * 1e3;
}
if (step + 1) % log_every == 0 || step == steps - 1 {
let w = log_every.min(step + 1) as f32; // steps in this window
println!(
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s | ms/step roll {:.0} cap {:.0} inner {:.0}",
step + 1,
win_reward / win_n.max(1) as f32,
win_solved,
win_n,
start.elapsed().as_secs_f32(),
t_roll / w,
t_cap / w,
t_inner / w,
);
win_reward = 0.0;
win_solved = 0;
win_n = 0;
t_roll = 0.0;
t_cap = 0.0;
t_inner = 0.0;
// Periodic save so a later OOM (naive rollout fragments the allocator —
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save");

View File

@@ -0,0 +1,162 @@
//! Batched GRPO training-side forwards (post-training M2d). After M2b/M2c made the
//! rollout cheap, the GRPO **step** is dominated by the per-sample full-sequence
//! forwards: the `per_token_logp` captures (policy + reference) and the inner
//! clipped-PG `forward`/`backward`s — each a single-sequence `forward` over a short
//! ragged completion. This module packs the `N = B·G` ragged samples of a step into
//! ONE `forward_batched`, amortising the per-launch overhead across N (the same win
//! M2b gave the rollout).
//!
//! The enabling property: **right-padding is free under causal attention.** Pad each
//! ragged completion on the RIGHT to the batch's `Lmax`; a real completion row is at
//! an earlier position than the trailing pad, and causal masking forbids attending
//! forward, so its logits are bit-identical to the unpadded single-sequence forward.
//! The pad rows' own outputs are garbage but are masked out (`target = -100`).
//!
//! Both the looped (baseline) and batched paths live here so they share one source of
//! truth — `bin/bench_grpo_batch` A/Bs them (timing + a closeness gate), and the
//! per-row equivalence of the loss op is pinned by `clipped_pg_loss_batched_matches_looped`
//! in `xtrain-autodiff/tests/autograd.rs`.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_model::{TinyTransformer, ids_tensor};
use xtrain_tensor::{Device, Tensor};
/// One framed completion of a GRPO step: the next-token `(input, target)` pair
/// (prompt positions masked to `-100` in `target`), its group-relative `adv`, and the
/// per-position rollout-time / reference logprobs the clipped-PG loss needs.
pub struct PgSample {
pub input: Vec<i32>,
pub target: Vec<i32>,
pub adv: f32,
pub logp_old: Vec<f32>,
pub logp_ref: Vec<f32>,
}
// ------------------------------- looped (baseline) -------------------------------
/// Per-position `logπ(target_t)` of one framed `(input, target)` pair (= `per_row`
/// of cross_entropy; masked positions are 0). One single-sequence forward, no grad.
pub fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
let logits = model.forward(&ids_tensor(input, device)).value();
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect()
}
/// One inner clipped-PG epoch the looped way: per sample, a single-sequence forward +
/// [`ops::clipped_pg_loss`] scaled by `1/N` + backward (grads accumulate on `model`'s
/// params). Returns the summed scaled loss. Caller does clip + opt.step + zero_grad.
pub fn inner_pg_step_looped(
model: &TinyTransformer,
device: Device,
batch: &[PgSample],
eps: f32,
beta: f32,
) -> f32 {
let scale = 1.0 / batch.len() as f32;
let mut total = 0f32;
for s in batch {
let logits = model.forward(&ids_tensor(&s.input, device));
let loss = ops::clipped_pg_loss(&logits, &ids_tensor(&s.target, device), &s.logp_old, &s.logp_ref, s.adv, eps, beta);
let scaled = ops::scale(&loss, scale);
total += scaled.value().to_device(Device::Cpu).as_slice::<f32>()[0];
scaled.backward();
}
total
}
// ------------------------------- batched (M2d) -----------------------------------
/// Right-pad `m` ragged `i32` rows (each `< lmax` long) to `[m*lmax]` sequence-major,
/// filling with `pad`. Used for both the id stream (pad = 0, arbitrary) and the target
/// stream (pad = 100, ignored by cross_entropy).
fn pack_i32(rows: &[&[i32]], lmax: usize, pad: i32) -> Vec<i32> {
let mut flat = vec![pad; rows.len() * lmax];
for (i, r) in rows.iter().enumerate() {
flat[i * lmax..i * lmax + r.len()].copy_from_slice(r);
}
flat
}
/// Batched [`per_token_logp`]: pack `samples` (each `(input, target)`) right-padded to
/// `Lmax`, run ONE `forward_batched(batch = N)`, and slice each sample's `logπ` back to
/// its real length. Equal to looping [`per_token_logp`] (right-pad is free under causal
/// attention), to bf16 batch-reduction tolerance. `samples` are processed in chunks of
/// `micro` (≥1) to bound the `[chunk*Lmax, vocab]` logits memory.
pub fn per_token_logp_batched(
model: &TinyTransformer,
device: Device,
samples: &[(Vec<i32>, Vec<i32>)],
micro: usize,
) -> Vec<Vec<f32>> {
let mut out = Vec::with_capacity(samples.len());
for chunk in samples.chunks(micro.max(1)) {
let m = chunk.len();
let lmax = chunk.iter().map(|(i, _)| i.len()).max().unwrap();
let ins: Vec<&[i32]> = chunk.iter().map(|(i, _)| i.as_slice()).collect();
let tgs: Vec<&[i32]> = chunk.iter().map(|(_, t)| t.as_slice()).collect();
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
let logits = model.forward_batched(&ids, m).value();
let (_, per_row) = logits.cross_entropy(&tgt);
let pr = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
for (i, (inp, _)) in chunk.iter().enumerate() {
let b = i * lmax;
out.push((0..inp.len()).map(|r| -pr[b + r]).collect());
}
}
out
}
/// One inner clipped-PG epoch, batched: pack the batch (in `micro`-sized chunks) and run
/// ONE `forward_batched` + [`ops::clipped_pg_loss_batched`] + backward per chunk. The
/// per-row `weight = 1/(N·n_s)` uses the GLOBAL `N = batch.len()` (not the chunk size),
/// so chunked grad-accumulation reproduces the looped `Σ_s (1/N)(1/n_s)…` exactly.
/// Returns the summed loss. Caller does clip + opt.step + zero_grad.
pub fn inner_pg_step_batched(
model: &TinyTransformer,
device: Device,
batch: &[PgSample],
eps: f32,
beta: f32,
micro: usize,
) -> f32 {
let inv_n = 1.0 / batch.len() as f32;
let mut total = 0f32;
for chunk in batch.chunks(micro.max(1)) {
let m = chunk.len();
let lmax = chunk.iter().map(|s| s.input.len()).max().unwrap();
let ins: Vec<&[i32]> = chunk.iter().map(|s| s.input.as_slice()).collect();
let tgs: Vec<&[i32]> = chunk.iter().map(|s| s.target.as_slice()).collect();
let ids = Tensor::from_slice(&pack_i32(&ins, lmax, 0), &[m * lmax]).to_device(device);
let tgt = Tensor::from_slice(&pack_i32(&tgs, lmax, -100), &[m * lmax]).to_device(device);
let mut logp_old = vec![0f32; m * lmax];
let mut logp_ref = vec![0f32; m * lmax];
let mut advantage = vec![0f32; m * lmax];
let mut weight = vec![0f32; m * lmax];
for (i, s) in chunk.iter().enumerate() {
let b = i * lmax;
let li = s.input.len();
logp_old[b..b + li].copy_from_slice(&s.logp_old);
logp_ref[b..b + li].copy_from_slice(&s.logp_ref);
let n_s = s.target.iter().filter(|&&t| t >= 0).count().max(1) as f32;
let w = inv_n / n_s; // = 1/(N · n_s)
for r in 0..lmax {
advantage[b + r] = s.adv;
weight[b + r] = w;
}
}
let logits = model.forward_batched(&ids, m);
let loss = ops::clipped_pg_loss_batched(&logits, &tgt, &logp_old, &logp_ref, &advantage, &weight, eps, beta);
total += loss.value().to_device(Device::Cpu).as_slice::<f32>()[0];
loss.backward();
}
total
}

View File

@@ -15,6 +15,8 @@ pub mod task;
#[cfg(not(no_cuda))]
pub mod checkpoint;
#[cfg(not(no_cuda))]
pub mod grpo_batch;
#[cfg(not(no_cuda))]
pub mod sample;
#[cfg(not(no_cuda))]
mod train_loop;

View File

@@ -576,3 +576,54 @@ still a real, correctness-gated improvement (cleaner code, less PCIe, ~10% decod
headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**,
not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache),
all token-identical-gated; the post-training stack remains complete with its bottleneck mapped.
### M2d — batch the GRPO training-side forwards (landed; the lever M2c named, + a decomposition correction)
M2c named the next lever: **ragged batched prefill of the per-sample training-side forwards**. Those
forwards are the two phases that, per step, run one single-sequence `forward` per sample: the
`per_token_logp` **captures** (logπ_old policy + logπ_ref reference) and the inner **clipped-PG**
forward/backwards. M2d packs all `N = B·G` ragged samples of a step into ONE `forward_batched`.
**The enabling property — right-padding is free under causal attention.** Pad each ragged completion
on the RIGHT to the batch's `Lmax`. A real completion row sits at an earlier position than the
trailing pad, and causal masking forbids attending forward, so its logits are **bit-identical** to
the unpadded single-sequence forward; the pad rows are garbage but masked out (`target = -100`). This
is exactly why training engines pad-and-mask rather than run ragged. Two new pieces:
- `per_token_logp_batched` (`crates/xtrain-train/src/grpo_batch.rs`): right-pad → one
`forward_batched(batch = N)` → slice each sample's logπ back to its real length.
- `ops::clipped_pg_loss_batched` (`crates/xtrain-autodiff/src/ops.rs`): like the per-sample
`clipped_pg_loss`, but takes **per-row** `advantage[t]` (the owning sample's `A`) and **per-row**
`weight[t]` (the full normaliser; the caller passes `1/(N·n_s)`). It does NOT compute its own
`1/n_tokens`, so folding `weight = 1/(N·n_s)` reproduces the looped `Σ_s (1/N)(1/n_s)…`
**bit-for-bit** (the per-row CE backward is row-local). A `--micro` knob packs in chunks to bound
the `[chunk·Lmax, vocab]` logits memory; the weight uses the GLOBAL `N`, so chunked
grad-accumulation is exact. Both `train_grpo` and the bench call these shared helpers.
**Correctness gates (exact, not bf16-noisy):**
- `xtrain-model::forward_batched_ragged_matches_looped` — forward_batched on right-padded ragged
sequences == per-sequence single-seq forward on the real rows, **max|Δlogit| = 3.7e-7 (fp32) and
0.0 (bf16)**, both composed + flash. Pins "right-pad is free".
- `xtrain-autodiff::clipped_pg_loss_batched_matches_looped` — batched op == looped
`Σ_s (1/N)·clipped_pg_loss_s`, **loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32)**.
Composed, these prove the batched GRPO step == the looped step. End-to-end: a short SFT (v12 base,
150 steps, arith) → `train_grpo` 12 steps runs clean — **no OOM** (1B master + AdamW + batched
activations fit with `micro=16`), mean-reward rises, the batched inner executes.
**Throughput (bench `bin/bench_grpo_batch`, v12 1.05B, N=48 ragged, micro=16, β=0, weight-independent):**
| phase (per step) | looped (single-seq) | batched (M2d) | speedup |
|-------------------------|---------------------|---------------|---------|
| capture `per_token_logp`| 622 ms | 71 ms | 8.7× |
| inner clipped-PG fwd+bwd| 1907 ms | 208 ms | 9.2× |
| **training forwards** | **2526 ms** | **280 ms** | **9.0×**|
**The decomposition correction (the honest finding).** M2c claimed "the per-sample training forwards
now dominate the step." The clean per-component bench falsifies the strong form: the training
forwards were **~2.5 s of the ~8.5 s step (~30%)** — substantial and worth the 9× win, but the
**rollout (`generate_cached_batch`, ~6 s) was always the larger share.** After M2d cuts the training
forwards to ~0.28 s, the step is **~95% rollout** — the long pole has swung back to the rollout. So
M2d removes the training-forward overhang (a real, exactly-gated 9× on its component), and re-confirms
the same measure-first lesson one more time: the next **step-level** lever is **full B×G rollout
batching** — today only the `G` samples of each prompt decode in lockstep (M2b); the `B` prompts are
still sequential. M2d closes the "ragged batched per-sample forwards" lever M2c named; the post-
training stack stays complete, now with the step decomposition measured, not asserted.

View File

@@ -107,6 +107,8 @@ Phase 1/2 把**预训练全栈**学完后Phase 3 转向**后训练 infra**
**M2cdevice 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**K/V device `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它( host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concatdecode_kv 单序列 + decode_batch 批量仍 **token-identical**GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133147 tok/s@128), **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward( T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实闸门齐全的改进(更干净 PCIe解码 +10%),但下一杠杆是 **per-sample forward ragged 批量**而非 cacheM2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache), token-identical-gated;后训练栈完整瓶颈已测绘
**M2d批量 GRPO 训练侧 forward,已落地,M2c 点名的杠杆 + 一处 decomposition 纠正)**M2c 点名的下一杠杆——把每步 `N=B·G` ragged 样本的训练侧 forward(`per_token_logp` 捕获 + inner clipped-PG fwd/bwd)打包进**一次 `forward_batched`**。**使能性质 = causal 下右 padding 免费**:真 completion 行位置早于尾部 pad,causal 禁止前向 attend,故真行 logits 与单序列 forward **逐位相同**,pad 行垃圾被 `target=-100` 屏蔽——这正是训练引擎 pad-and-mask 而非跑 ragged 的原因两件新东西:`per_token_logp_batched`( pad 一次 `forward_batched(N)` 按真长切片)、`ops::clipped_pg_loss_batched`(per-row `advantage[t]` + per-row `weight[t]`,caller `1/(N·n_s)`,op 不再自算 `1/n_tokens` 折进 weight 即与 looped `Σ_s (1/N)(1/n_s)…` **逐位等价**;`--micro` 分块界定 `[chunk·Lmax,vocab]` logits 显存,weight 用全局 N 故分块梯度累积精确)。**两道精确闸门**:`forward_batched_ragged_matches_looped`( pad 批量 forward == 单序列,fp32 max|Δ|=3.7e-7bf16 **0.0**,composed+flash)+ `clipped_pg_loss_batched_matches_looped`(批量 op == looped,loss Δ=1.5e-8/grad 7.5e-9,f32),复合即证端到端等价;端到端短 SFT`train_grpo` 12 ** OOM**(1B master+AdamW+批量激活 micro=16 容得下)、批量 inner 执行。**吞吐(bench,v12 1.05B,N=48,micro16,权重无关)**:capture 62271ms(8.7×)、inner 1907208ms(9.2×)、**训练侧 forward 合计 2526280ms(9.0×)**。**Decomposition 纠正(诚实发现)**:M2c "训练侧 forward 主导 step",干净分量 bench 证伪强形式——训练侧 forward **~8.5s step 里的 ~2.5s(~30%)**,可观值这 9×, **rollout(`generate_cached_batch` ~6s)一直是更大头**;M2d 把训练侧砍到 ~0.28s ,step **~95% rollout**,长杆又摆回 rollout。⇒ M2d 拔掉训练侧 forward 这块 overhang(分量级精确 9×),再次印证 measure-first:**step 级下一杠杆 = B×G rollout 批量**(今天只有每 prompt G 同步B prompt 仍串行)。后训练栈保持完整,step decomposition 现为**实测**而非断言
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)
- **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。