post-train: M4 — GRPO actor-learner loop + cached temperature rollout
train_grpo: the online, critic-free RL loop — per step sample B prompts, roll out G completions each, score with the rule-based checker (reward 0/1), compute group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness (KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward (the falsifiable "it learns" signal). Periodic checkpoint save. decode: generate_cached adds temperature sampling to the KV-cache engine (M2) — single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far lighter on the caching allocator (the naive sampler fragments it over a long rollout). generate_greedy_cached now routes through it (temp 0); decode_kv token-identical gate still passes. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -83,6 +83,24 @@ pub fn generate_greedy_cached(
|
||||
device: Device,
|
||||
prompt: &[i32],
|
||||
max_new: usize,
|
||||
) -> Vec<i32> {
|
||||
let mut rng = 0u64;
|
||||
generate_cached(model, device, prompt, max_new, 0.0, &mut rng)
|
||||
}
|
||||
|
||||
/// KV-cache decode with temperature sampling (`temperature == 0` → greedy argmax,
|
||||
/// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`).
|
||||
/// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row
|
||||
/// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far
|
||||
/// lighter on memory + the allocator — the naive sampler fragments the caching
|
||||
/// allocator over a long rollout, which is the M4 "rollout is the long pole" wall.
|
||||
pub fn generate_cached(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
prompt: &[i32],
|
||||
max_new: usize,
|
||||
temperature: f32,
|
||||
rng_state: &mut u64,
|
||||
) -> Vec<i32> {
|
||||
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||||
let cfg = model.config();
|
||||
@@ -116,7 +134,11 @@ pub fn generate_greedy_cached(
|
||||
}
|
||||
|
||||
for _ in 0..max_new {
|
||||
let next = argmax(&logits) as i32;
|
||||
let next = if temperature <= 0.0 {
|
||||
argmax(&logits) as i32
|
||||
} else {
|
||||
sample_temperature(&logits, temperature, rng_state) as i32
|
||||
};
|
||||
tokens.push(next);
|
||||
let pos = tokens.len() - 1; // absolute position of the token just appended
|
||||
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
|
||||
@@ -124,6 +146,26 @@ pub fn generate_greedy_cached(
|
||||
tokens
|
||||
}
|
||||
|
||||
/// Sample a token from `softmax(logits / temperature)` (numerically stable). Same
|
||||
/// LCG + inverse-CDF scheme as the naive `sample::sample_temperature`.
|
||||
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
|
||||
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exps: Vec<f32> = row.iter().map(|&x| ((x - max) / temperature).exp()).collect();
|
||||
let sum: f32 = exps.iter().sum();
|
||||
*rng_state = rng_state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
let r = ((*rng_state >> 32) as f32 / u32::MAX as f32) * sum;
|
||||
let mut acc = 0.0;
|
||||
for (i, &e) in exps.iter().enumerate() {
|
||||
acc += e;
|
||||
if acc >= r {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
exps.len() - 1
|
||||
}
|
||||
|
||||
/// One incremental decode step for token `tok` at absolute position `pos`: append
|
||||
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
||||
@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod decode;
|
||||
#[cfg(not(no_cuda))]
|
||||
pub use decode::generate_greedy_cached;
|
||||
pub use decode::{generate_cached, generate_greedy_cached};
|
||||
|
||||
288
crates/xtrain-train/src/bin/train_grpo.rs
Normal file
288
crates/xtrain-train/src/bin/train_grpo.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
//! GRPO training on the verifiable arithmetic task (M4 / Stage P3) — online,
|
||||
//! critic-free RL. The centerpiece: generation INSIDE the training loop.
|
||||
//!
|
||||
//! Each step: sample B prompts (fresh problems), roll out G completions per prompt
|
||||
//! (temperature sampling via the naive sampler — batched/cached rollout is the M2b/
|
||||
//! M4-perf follow-up), score each with the rule-based checker (reward ∈ {0,1}),
|
||||
//! compute the **group-relative advantage** `A_i = (r_i − mean) / (std + ε)` (no
|
||||
//! critic), then K inner clipped-PG epochs minimising [`clipped_pg_loss`] with a KL
|
||||
//! leash to the frozen reference (πref = the SFT checkpoint). Reward = pure 0/1
|
||||
//! correctness; the KL term (β) is what keeps format/coherence (the M3 collapse
|
||||
//! lesson — here it is an explicit leash, not just a hope).
|
||||
//!
|
||||
//! Health signal (the falsifiable "it learns"): **mean rollout reward must rise**
|
||||
//! (the RL analogue of T5's overfit-27/27). Held-out correctness is measured by
|
||||
//! eval_arith on the saved checkpoint.
|
||||
//!
|
||||
//! train_grpo <tokenizer.json> --init-ckpt <sft.ckpt> <arch flags> \
|
||||
//! --steps 200 --group 6 --prompts 8 --temp 1.0 --beta 0.04 --eps 0.2 \
|
||||
//! --lr 1e-6 --max-add 20 --max-mul 9 --ckpt <out.ckpt>
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_autodiff::ops;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::{DType, Device};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag_value(args: &[String], name: &str) -> Option<String> {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn first_answer_segment(c: &str) -> &str {
|
||||
let s = c.split("<|endoftext|>").next().unwrap_or(c);
|
||||
s.split('\n').next().unwrap_or(s)
|
||||
}
|
||||
|
||||
/// Build a model from the SFT checkpoint (bf16 compute to fit two 1B models). The
|
||||
/// policy enables activation recompute (T13) so its backward fits alongside the
|
||||
/// frozen reference + the Adam state; the reference only forwards (no backward).
|
||||
#[cfg(not(no_cuda))]
|
||||
fn load_model(cfg: Config, device: Device, ckpt: &str, recompute: bool) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
})
|
||||
.with_compute_dtype(DType::BF16)
|
||||
.with_recompute(recompute)
|
||||
.with_flash(true);
|
||||
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
|
||||
m.eval();
|
||||
m
|
||||
}
|
||||
|
||||
/// Frame (question, completion) like the SFT loader and return the next-token
|
||||
/// (input, target) pair (prompt masked to -100). Same as train_dpo.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) -> (Vec<i32>, Vec<i32>) {
|
||||
let p_ids: Vec<i32> = tok
|
||||
.encode(&format!("User: {question}\nAssistant:"))
|
||||
.into_iter()
|
||||
.map(|t| t as i32)
|
||||
.collect();
|
||||
let a_ids: Vec<i32> = tok
|
||||
.encode(&format!(" {completion}\n<|endoftext|>"))
|
||||
.into_iter()
|
||||
.map(|t| t as i32)
|
||||
.collect();
|
||||
let mut tokens = p_ids.clone();
|
||||
tokens.extend_from_slice(&a_ids);
|
||||
let mut labels = vec![-100i32; p_ids.len()];
|
||||
labels.extend_from_slice(&a_ids);
|
||||
let l = tokens.len();
|
||||
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
|
||||
}
|
||||
|
||||
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= −per_row
|
||||
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||||
let logits = model.forward(&ids_tensor(input, device)).value();
|
||||
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||||
per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.map(|p| -p)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
use xtrain_optim::GpuAdamW;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals.first().expect("usage: train_grpo <tokenizer.json> [flags]");
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let steps: usize = flag(&args, "--steps", 200);
|
||||
let group: usize = flag(&args, "--group", 6);
|
||||
let n_prompts: usize = flag(&args, "--prompts", 8);
|
||||
let inner: usize = flag(&args, "--inner", 1);
|
||||
let temp: f32 = flag(&args, "--temp", 1.0);
|
||||
let beta: f32 = flag(&args, "--beta", 0.04);
|
||||
let eps: f32 = flag(&args, "--eps", 0.2);
|
||||
let lr: f32 = flag(&args, "--lr", 1e-6);
|
||||
let clip: f32 = flag(&args, "--clip", 1.0);
|
||||
let max_new: usize = flag(&args, "--max-tokens", 24);
|
||||
let max_add: i64 = flag(&args, "--max-add", 20);
|
||||
let max_mul: i64 = flag(&args, "--max-mul", 9);
|
||||
let seed: u64 = flag(&args, "--seed", 20260630);
|
||||
let log_every: usize = flag(&args, "--log-every", 20);
|
||||
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
|
||||
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||
let policy = load_model(cfg, device, &init_ckpt, false); // flash keeps attn memory bounded
|
||||
// Frozen πref for the KL leash — only resident when β>0 (a second 1B model is the
|
||||
// memory long-pole; β=0 is pure PG and skips it, the gated degenerate).
|
||||
let reference = if beta > 0.0 {
|
||||
Some(load_model(cfg, device, &init_ckpt, false))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let gcfg = GenConfig {
|
||||
max_add,
|
||||
max_mul,
|
||||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||
};
|
||||
let params = policy.params();
|
||||
let mut opt = GpuAdamW::new(0.0);
|
||||
let mut rng = seed.max(1);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
|
||||
for step in 0..steps {
|
||||
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
|
||||
struct Sample {
|
||||
input: Vec<i32>,
|
||||
target: Vec<i32>,
|
||||
adv: f32,
|
||||
logp_old: Vec<f32>,
|
||||
logp_ref: Vec<f32>,
|
||||
}
|
||||
let mut batch: Vec<Sample> = Vec::new();
|
||||
for _ in 0..n_prompts {
|
||||
let p = gen_problem(&mut rng, &gcfg);
|
||||
let prompt_ids: Vec<i32> = tok
|
||||
.encode(&format!("User: {}\nAssistant:", p.question()))
|
||||
.into_iter()
|
||||
.map(|t| t as i32)
|
||||
.collect();
|
||||
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
||||
for _ in 0..group {
|
||||
// KV-cache temperature rollout (M2 engine): single-row logits per
|
||||
// step → far lighter on the allocator than the naive sampler, which
|
||||
// fragments it over a long rollout (the M4 rollout long-pole).
|
||||
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
|
||||
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
let seg = first_answer_segment(&cont).trim().to_string();
|
||||
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
||||
comps.push((seg, r));
|
||||
}
|
||||
let mean = comps.iter().map(|c| c.1).sum::<f32>() / group as f32;
|
||||
let var = comps.iter().map(|c| (c.1 - mean).powi(2)).sum::<f32>() / group as f32;
|
||||
let std = var.sqrt();
|
||||
win_reward += mean * group as f32;
|
||||
win_solved += comps.iter().filter(|c| c.1 > 0.5).count();
|
||||
win_n += group;
|
||||
// A whole group with no reward variance gives zero advantage → skip
|
||||
// (no learning signal, and avoids dividing by ~0).
|
||||
if std < 1e-6 {
|
||||
continue;
|
||||
}
|
||||
for (seg, r) in &comps {
|
||||
let adv = (r - mean) / (std + 1e-4);
|
||||
let (input, target) = frame(&tok, &p.question(), seg);
|
||||
let logp_old = per_token_logp(&policy, device, &input, &target);
|
||||
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||||
let logp_ref = match &reference {
|
||||
Some(r) => per_token_logp(r, device, &input, &target),
|
||||
None => vec![0.0; logp_old.len()],
|
||||
};
|
||||
batch.push(Sample { input, target, adv, logp_old, logp_ref });
|
||||
}
|
||||
}
|
||||
|
||||
// ---- K inner clipped-PG epochs over the captured batch ----
|
||||
if !batch.is_empty() {
|
||||
let scale = 1.0 / batch.len() as f32;
|
||||
for _ in 0..inner {
|
||||
for s in &batch {
|
||||
let logits = policy.forward(&ids_tensor(&s.input, device));
|
||||
let loss = ops::clipped_pg_loss(
|
||||
&logits,
|
||||
&ids_tensor(&s.target, device),
|
||||
&s.logp_old,
|
||||
&s.logp_ref,
|
||||
s.adv,
|
||||
eps,
|
||||
beta,
|
||||
);
|
||||
ops::scale(&loss, scale).backward();
|
||||
}
|
||||
let _ = xtrain_train::clip::clip_grad_norm_gpu(¶ms, clip, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (step + 1) % log_every == 0 || step == steps - 1 {
|
||||
println!(
|
||||
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
|
||||
step + 1,
|
||||
win_reward / win_n.max(1) as f32,
|
||||
win_solved,
|
||||
win_n,
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
win_reward = 0.0;
|
||||
win_solved = 0;
|
||||
win_n = 0;
|
||||
// Periodic save so a later OOM (naive rollout fragments the allocator —
|
||||
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
|
||||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save");
|
||||
}
|
||||
}
|
||||
|
||||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save ckpt");
|
||||
println!("GRPO done: {steps} steps, G={group}, B={n_prompts}, beta {beta}, lr {lr:.1e} → {out_ckpt}");
|
||||
}
|
||||
Reference in New Issue
Block a user