diff --git a/crates/xtrain-train/src/bin/train_grpo.rs b/crates/xtrain-train/src/bin/train_grpo.rs index 2de31b5..4802668 100644 --- a/crates/xtrain-train/src/bin/train_grpo.rs +++ b/crates/xtrain-train/src/bin/train_grpo.rs @@ -28,7 +28,7 @@ use xtrain_autodiff::ops; #[cfg(not(no_cuda))] use xtrain_cuda::device; #[cfg(not(no_cuda))] -use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor}; +use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor}; #[cfg(not(no_cuda))] use xtrain_tensor::{DType, Device}; #[cfg(not(no_cuda))] @@ -205,12 +205,12 @@ fn main() { .into_iter() .map(|t| t as i32) .collect(); + // M2b batched rollout: the G samples of this prompt decode in lockstep + // (one forward per step over the whole group → G× fewer kernel launches + // than G sequential single-seq rollouts; the M4 rollout long-pole fix). let mut comps: Vec<(String, f32)> = Vec::with_capacity(group); - for _ in 0..group { - // KV-cache temperature rollout (M2 engine): single-row logits per - // step → far lighter on the allocator than the naive sampler, which - // fragments it over a long rollout (the M4 rollout long-pole). - let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng); + let outs = generate_cached_batch(&policy, device, &prompt_ids, group, max_new, temp, &mut rng); + for out in &outs { let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::>()); let seg = first_answer_segment(&cont).trim().to_string(); let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };