post-train: M4 — use M2b batched rollout in GRPO (~1.7× step)
train_grpo rolls out a prompt's G samples with one generate_cached_batch call instead of G sequential generate_cached calls. Measured on v12 1.05B (G=6, B=6, easy task): ~8.5 s/step vs ~14-16 s/step single-seq cached — ~1.7× (rollout- inclusive; short of G× because per_token_logp + the PG update also cost, and the M2a host round-trip remains). Also more stable memory: one batched forward per step vs G allocations that fragment the caching allocator. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -28,7 +28,7 @@ use xtrain_autodiff::ops;
|
|||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_cuda::device;
|
use xtrain_cuda::device;
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
|
use xtrain_model::{Config, TinyTransformer, generate_cached_batch, ids_tensor};
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
use xtrain_tensor::{DType, Device};
|
use xtrain_tensor::{DType, Device};
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
@@ -205,12 +205,12 @@ fn main() {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|t| t as i32)
|
.map(|t| t as i32)
|
||||||
.collect();
|
.collect();
|
||||||
|
// M2b batched rollout: the G samples of this prompt decode in lockstep
|
||||||
|
// (one forward per step over the whole group → G× fewer kernel launches
|
||||||
|
// than G sequential single-seq rollouts; the M4 rollout long-pole fix).
|
||||||
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
||||||
for _ in 0..group {
|
let outs = generate_cached_batch(&policy, device, &prompt_ids, group, max_new, temp, &mut rng);
|
||||||
// KV-cache temperature rollout (M2 engine): single-row logits per
|
for out in &outs {
|
||||||
// step → far lighter on the allocator than the naive sampler, which
|
|
||||||
// fragments it over a long rollout (the M4 rollout long-pole).
|
|
||||||
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
|
|
||||||
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||||
let seg = first_answer_segment(&cont).trim().to_string();
|
let seg = first_answer_segment(&cont).trim().to_string();
|
||||||
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
||||||
|
|||||||
Reference in New Issue
Block a user