post-train: M4 — use M2b batched rollout in GRPO (~1.7× step)

train_grpo rolls out a prompt's G samples with one generate_cached_batch call
instead of G sequential generate_cached calls. Measured on v12 1.05B (G=6, B=6,
easy task): ~8.5 s/step vs ~14-16 s/step single-seq cached — ~1.7× (rollout-
inclusive; short of G× because per_token_logp + the PG update also cost, and the
M2a host round-trip remains). Also more stable memory: one batched forward per
step vs G allocations that fragment the caching allocator.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 17:18:54 +08:00
parent 2c9b58cb3b
commit 361c5290fa

View File

@@ -28,7 +28,7 @@ use xtrain_autodiff::ops;
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]
use xtrain_cuda::device; use xtrain_cuda::device;
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor}; use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor};
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device}; use xtrain_tensor::{DType, Device};
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]
@@ -205,12 +205,12 @@ fn main() {
.into_iter() .into_iter()
.map(|t| t as i32) .map(|t| t as i32)
.collect(); .collect();
// M2b batched rollout: the G samples of this prompt decode in lockstep
// (one forward per step over the whole group → G× fewer kernel launches
// than G sequential single-seq rollouts; the M4 rollout long-pole fix).
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group); let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
for _ in 0..group { let outs = generate_cached_batch(&policy, device, &prompt_ids, group, max_new, temp, &mut rng);
// KV-cache temperature rollout (M2 engine): single-row logits per for out in &outs {
// step → far lighter on the allocator than the naive sampler, which
// fragments it over a long rollout (the M4 rollout long-pole).
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>()); let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
let seg = first_answer_segment(&cont).trim().to_string(); let seg = first_answer_segment(&cont).trim().to_string();
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 }; let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };