From 361c5290fa2c1600cd9cd4bae4b75613c1d52e9f Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 30 Jun 2026 17:18:54 +0800 Subject: [PATCH] =?UTF-8?q?post-train:=20M4=20=E2=80=94=20use=20M2b=20batc?= =?UTF-8?q?hed=20rollout=20in=20GRPO=20(~1.7=C3=97=20step)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit train_grpo rolls out a prompt's G samples with one generate_cached_batch call instead of G sequential generate_cached calls. Measured on v12 1.05B (G=6, B=6, easy task): ~8.5 s/step vs ~14-16 s/step single-seq cached — ~1.7× (rollout- inclusive; short of G× because per_token_logp + the PG update also cost, and the M2a host round-trip remains). Also more stable memory: one batched forward per step vs G allocations that fragment the caching allocator. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-train/src/bin/train_grpo.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/xtrain-train/src/bin/train_grpo.rs b/crates/xtrain-train/src/bin/train_grpo.rs index 2de31b5..4802668 100644 --- a/crates/xtrain-train/src/bin/train_grpo.rs +++ b/crates/xtrain-train/src/bin/train_grpo.rs @@ -28,7 +28,7 @@ use xtrain_autodiff::ops; #[cfg(not(no_cuda))] use xtrain_cuda::device; #[cfg(not(no_cuda))] -use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor}; +use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor}; #[cfg(not(no_cuda))] use xtrain_tensor::{DType, Device}; #[cfg(not(no_cuda))] @@ -205,12 +205,12 @@ fn main() { .into_iter() .map(|t| t as i32) .collect(); + // M2b batched rollout: the G samples of this prompt decode in lockstep + // (one forward per step over the whole group → G× fewer kernel launches + // than G sequential single-seq rollouts; the M4 rollout long-pole fix). 
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group); - for _ in 0..group { - // KV-cache temperature rollout (M2 engine): single-row logits per - // step → far lighter on the allocator than the naive sampler, which - // fragments it over a long rollout (the M4 rollout long-pole). - let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng); + let outs = generate_cached_batch(&policy, device, &prompt_ids, group, max_new, temp, &mut rng); + for out in &outs { let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<u32>>()); let seg = first_answer_segment(&cont).trim().to_string(); let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };