Files
xtrain/crates/xtrain-train/src/bin/train_grpo.rs
Gahow Wang 7fb3b32fd9 post-train: M4 — GRPO actor-learner loop + cached temperature rollout
train_grpo: the online, critic-free RL loop — per step sample B prompts, roll
out G completions each, score with the rule-based checker (reward 0/1), compute
group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss
epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness
(KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward
(the falsifiable "it learns" signal). Periodic checkpoint save.

decode: generate_cached adds temperature sampling to the KV-cache engine (M2) —
single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far
lighter on the caching allocator (the naive sampler fragments it over a long
rollout). generate_greedy_cached now routes through it (temp 0); decode_kv
token-identical gate still passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 16:59:05 +08:00

289 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! GRPO training on the verifiable arithmetic task (M4 / Stage P3) — online,
//! critic-free RL. The centerpiece: generation INSIDE the training loop.
//!
//! Each step: sample B prompts (fresh problems), roll out G completions per prompt
//! (temperature sampling through the KV-cache engine, `generate_cached` — batched
//! rollout is the M2b/M4-perf follow-up), score each with the rule-based checker (reward ∈ {0,1}),
//! compute the **group-relative advantage** `A_i = (r_i - mean) / (std + ε)` (no
//! critic), then K inner clipped-PG epochs minimising [`clipped_pg_loss`] with a KL
//! leash to the frozen reference (πref = the SFT checkpoint). Reward = pure 0/1
//! correctness; the KL term (β) is what keeps format/coherence (the M3 collapse
//! lesson — here it is an explicit leash, not just a hope).
//!
//! Health signal (the falsifiable "it learns"): **mean rollout reward must rise**
//! (the RL analogue of T5's overfit-27/27). Held-out correctness is measured by
//! eval_arith on the saved checkpoint.
//!
//! train_grpo <tokenizer.json> --init-ckpt <sft.ckpt> <arch flags> \
//! --steps 200 --group 6 --prompts 8 --temp 1.0 --beta 0.04 --eps 0.2 \
//! --lr 1e-6 --max-add 20 --max-mul 9 --ckpt <out.ckpt>
/// Stub entry point compiled only when the `no_cuda` cfg is set: it does no
/// training and just reports on stderr that a GPU build is required.
#[cfg(no_cuda)]
fn main() {
    const NOTICE: &str = "train_grpo: built without CUDA (no_cuda); run on a GPU host.";
    eprintln!("{NOTICE}");
}
#[cfg(not(no_cuda))]
use xtrain_autodiff::ops;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
#[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device};
#[cfg(not(no_cuda))]
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
/// Produce `n` deterministic pseudo-random `f32` values derived from `seed`.
///
/// The seed is scrambled once, then a 64-bit linear congruential generator is
/// advanced once per element. The top 31 bits of each state are mapped to a
/// unit-interval float, re-centred around zero and multiplied by `2 * scale`,
/// so values lie in roughly `[-scale, scale]`.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // LCG multiplier / increment used for every per-element advance.
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;

    // One-off scramble so nearby seeds start from unrelated states.
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);

    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        // Keep the high 31 bits (the better-mixed part of an LCG state).
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
/// Look up `name` in `args` and parse the argument that follows it as `T`.
///
/// Returns `default` when the flag is absent. When the flag is present but has
/// no following argument, or that argument does not parse as `T`, `default` is
/// still returned (so existing invocations behave as before) but a warning is
/// written to stderr — a mistyped hyperparameter should not silently turn into
/// the built-in default for a whole training run.
///
/// Only the first occurrence of `name` is considered.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let raw = match args.iter().position(|a| a == name) {
        // Flag not given at all: the default is the intended value, no warning.
        None => return default,
        Some(i) => args.get(i + 1),
    };
    match raw.map(|s| (s, s.parse::<T>())) {
        Some((_, Ok(value))) => value,
        Some((s, Err(_))) => {
            eprintln!("warning: could not parse value {s:?} for {name}; using the default");
            default
        }
        None => {
            eprintln!("warning: {name} was given without a value; using the default");
            default
        }
    }
}
/// Return an owned copy of the argument immediately after the first occurrence
/// of `name`, or `None` if `name` is absent or is the last argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let idx = args.iter().position(|candidate| candidate == name)?;
    let value = args.get(idx + 1)?;
    Some(value.clone())
}
/// Return the leading part of a decoded completion: everything before the first
/// `<|endoftext|>` marker, cut again at the first newline. If neither delimiter
/// occurs the whole string is returned. No trimming is done here.
#[cfg(not(no_cuda))]
fn first_answer_segment(c: &str) -> &str {
    // Drop the end-of-text marker and anything generated after it.
    let before_eot = match c.find("<|endoftext|>") {
        Some(end) => &c[..end],
        None => c,
    };
    // Keep only the first line of what remains.
    match before_eot.find('\n') {
        Some(end) => &before_eot[..end],
        None => before_eot,
    }
}
/// Construct a `TinyTransformer` for `cfg` on `device` and load the weights
/// stored at `ckpt` into its parameters.
///
/// The model is built with BF16 compute and flash attention enabled; activation
/// recompute is switched on or off by the caller through `recompute`. Parameters
/// first receive a deterministic placeholder initialisation from [`fill`]
/// (rank-1 shapes get small noise around 1.0, all other shapes small noise around
/// 0.0) before the checkpoint is loaded — presumably overwriting those values;
/// confirm against `checkpoint::load_into`. The model is switched to eval mode
/// before being returned.
///
/// # Panics
/// Panics with `"load ckpt"` if the checkpoint cannot be loaded.
#[cfg(not(no_cuda))]
fn load_model(cfg: Config, device: Device, ckpt: &str, recompute: bool) -> TinyTransformer {
    // A fresh seed per parameter so each tensor gets its own placeholder noise.
    let mut init_seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        init_seed = init_seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        match shape.len() {
            1 => fill(numel, init_seed, 0.02)
                .into_iter()
                .map(|w| w + 1.0)
                .collect(),
            _ => fill(numel, init_seed, 0.04),
        }
    })
    .with_compute_dtype(DType::BF16)
    .with_recompute(recompute)
    .with_flash(true);

    let ckpt_path = std::path::Path::new(ckpt);
    xtrain_train::checkpoint::load_into(ckpt_path, &model.params()).expect("load ckpt");
    model.eval();
    model
}
/// Tokenise a (question, completion) pair into a next-token training example.
///
/// The prompt is rendered as `User: {question}\nAssistant:` and the answer as
/// ` {completion}\n<|endoftext|>`; both are encoded separately and concatenated.
/// Returns `(input, target)` where `input` is the sequence without its last token
/// and `target` is the label sequence shifted left by one, with every prompt
/// position set to `-100` so that only answer tokens carry a label.
#[cfg(not(no_cuda))]
fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) -> (Vec<i32>, Vec<i32>) {
    let prompt_text = format!("User: {question}\nAssistant:");
    let answer_text = format!(" {completion}\n<|endoftext|>");

    let mut prompt_ids: Vec<i32> = Vec::new();
    for t in tok.encode(&prompt_text) {
        prompt_ids.push(t as i32);
    }
    let mut answer_ids: Vec<i32> = Vec::new();
    for t in tok.encode(&answer_text) {
        answer_ids.push(t as i32);
    }

    let total = prompt_ids.len() + answer_ids.len();
    let sequence: Vec<i32> = prompt_ids
        .iter()
        .chain(answer_ids.iter())
        .copied()
        .collect();
    // -100 marks prompt positions as "no label".
    let labels: Vec<i32> = std::iter::repeat(-100i32)
        .take(prompt_ids.len())
        .chain(answer_ids.iter().copied())
        .collect();

    // Shift by one: position t of `input` is paired with token t+1 as its label.
    (sequence[..total - 1].to_vec(), labels[1..total].to_vec())
}
/// Score a framed `(input, target)` pair under `model`, one value per position.
///
/// Runs a forward pass, takes the per-row output of `cross_entropy` against
/// `target`, copies it to the CPU and negates each entry, i.e. returns
/// `log π(target_t)` assuming the per-row values are negative log-likelihoods.
/// Only the forward value is used (`.value()`), so presumably no gradient is
/// retained — confirm against the autodiff API.
#[cfg(not(no_cuda))]
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
    let logits = model.forward(&ids_tensor(input, device)).value();
    let (_, row_losses) = logits.cross_entropy(&ids_tensor(target, device));
    let host = row_losses.to_device(Device::Cpu);
    // Flip the sign: per-row loss -> log-probability.
    let logps: Vec<f32> = host.as_slice::<f32>().iter().map(|nll| -nll).collect();
    logps
}
/// GPU entry point: parse the CLI, load the policy (and, when β > 0, a frozen
/// reference) from `--init-ckpt`, then run `--steps` GRPO steps.
///
/// Each step rolls out `--group` completions for each of `--prompts` generated
/// problems, scores them 0/1 with `check_answer`, turns the rewards into
/// group-relative advantages, and runs `--inner` optimizer updates over the
/// captured batch. The windowed mean rollout reward is printed every
/// `--log-every` steps and on the last step; the checkpoint is written to
/// `--ckpt` at every intermediate log point and once after the loop.
///
/// # Panics
/// Panics if the tokenizer path, `--init-ckpt` or `--ckpt` is missing, if no
/// CUDA device is available, or if loading/saving a checkpoint fails.
#[cfg(not(no_cuda))]
fn main() {
    use xserv_tokenizer::Tokenizer;
    use xtrain_optim::GpuAdamW;

    /// One framed completion captured during rollout and replayed by the inner
    /// policy-gradient epochs.
    struct Sample {
        /// Framed input token ids (see `frame`).
        input: Vec<i32>,
        /// Shifted labels, prompt positions masked with -100 (see `frame`).
        target: Vec<i32>,
        /// Group-relative advantage of this completion.
        adv: f32,
        /// Per-token log-probs under the policy at rollout time.
        logp_old: Vec<f32>,
        /// Per-token log-probs under the reference (zeros when there is none).
        logp_ref: Vec<f32>,
    }

    let args: Vec<String> = std::env::args().collect();
    let positionals: Vec<&String> = args
        .iter()
        .skip(1)
        .filter(|a| !a.starts_with("--"))
        .collect();
    let tok_path = positionals.first().expect("usage: train_grpo <tokenizer.json> [flags]");

    // Architecture flags.
    let n_heads = flag(&args, "--heads", 52usize);
    let head_dim = flag(&args, "--head-dim", 32usize);
    let n_layers = flag(&args, "--layers", 22usize);
    let ffn = flag(&args, "--ffn", 6656usize);
    let kv_heads = flag(&args, "--kv-heads", n_heads);

    // RL / optimisation flags.
    let steps: usize = flag(&args, "--steps", 200);
    let group: usize = flag(&args, "--group", 6);
    let n_prompts: usize = flag(&args, "--prompts", 8);
    let inner: usize = flag(&args, "--inner", 1);
    let temp: f32 = flag(&args, "--temp", 1.0);
    let beta: f32 = flag(&args, "--beta", 0.04);
    let eps: f32 = flag(&args, "--eps", 0.2);
    let lr: f32 = flag(&args, "--lr", 1e-6);
    let clip: f32 = flag(&args, "--clip", 1.0);
    let max_new: usize = flag(&args, "--max-tokens", 24);
    let max_add: i64 = flag(&args, "--max-add", 20);
    let max_mul: i64 = flag(&args, "--max-mul", 9);
    let seed: u64 = flag(&args, "--seed", 20260630);
    // Clamp to 1: the logging check below takes `% log_every`, which would panic
    // with a division by zero for `--log-every 0`.
    let log_every: usize = flag(&args, "--log-every", 20usize).max(1);
    let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
    let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");

    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
    let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
    let policy = load_model(cfg, device, &init_ckpt, false); // flash keeps attn memory bounded
    // Frozen πref for the KL leash — only resident when β>0 (a second 1B model is the
    // memory long-pole; β=0 is pure PG and skips it, the gated degenerate).
    let reference = if beta > 0.0 {
        Some(load_model(cfg, device, &init_ckpt, false))
    } else {
        None
    };

    let gcfg = GenConfig {
        max_add,
        max_mul,
        ops: vec![Op::Add, Op::Sub, Op::Mul],
    };
    let params = policy.params();
    let mut opt = GpuAdamW::new(0.0);
    // Never start the generator from 0.
    let mut rng = seed.max(1);
    let start = std::time::Instant::now();
    // Logging window accumulators: summed reward, solved count, rollout count.
    let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);

    for step in 0..steps {
        // ---- Rollout: B prompts × G completions, scored, group-advantage ----
        let mut batch: Vec<Sample> = Vec::new();
        for _ in 0..n_prompts {
            let p = gen_problem(&mut rng, &gcfg);
            let prompt_ids: Vec<i32> = tok
                .encode(&format!("User: {}\nAssistant:", p.question()))
                .into_iter()
                .map(|t| t as i32)
                .collect();

            // (answer segment, 0/1 reward) for each completion in the group.
            let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
            for _ in 0..group {
                // KV-cache temperature rollout (M2 engine): single-row logits per
                // step → far lighter on the allocator than the naive sampler, which
                // fragments it over a long rollout (the M4 rollout long-pole).
                let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
                let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
                let seg = first_answer_segment(&cont).trim().to_string();
                let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
                comps.push((seg, r));
            }

            let mean = comps.iter().map(|c| c.1).sum::<f32>() / group as f32;
            let var = comps.iter().map(|c| (c.1 - mean).powi(2)).sum::<f32>() / group as f32;
            let std = var.sqrt();
            win_reward += mean * group as f32;
            win_solved += comps.iter().filter(|c| c.1 > 0.5).count();
            win_n += group;

            // A whole group with no reward variance gives zero advantage → skip
            // (no learning signal, and avoids dividing by ~0).
            if std < 1e-6 {
                continue;
            }

            for (seg, reward) in &comps {
                let adv = (reward - mean) / (std + 1e-4);
                let (input, target) = frame(&tok, &p.question(), seg);
                let logp_old = per_token_logp(&policy, device, &input, &target);
                // β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
                let logp_ref = match &reference {
                    Some(ref_model) => per_token_logp(ref_model, device, &input, &target),
                    None => vec![0.0; logp_old.len()],
                };
                batch.push(Sample { input, target, adv, logp_old, logp_ref });
            }
        }

        // ---- K inner clipped-PG epochs over the captured batch ----
        if !batch.is_empty() {
            // Average the per-sample losses so the gradient scale is batch-size independent.
            let scale = 1.0 / batch.len() as f32;
            for _ in 0..inner {
                for s in &batch {
                    let logits = policy.forward(&ids_tensor(&s.input, device));
                    let loss = ops::clipped_pg_loss(
                        &logits,
                        &ids_tensor(&s.target, device),
                        &s.logp_old,
                        &s.logp_ref,
                        s.adv,
                        eps,
                        beta,
                    );
                    ops::scale(&loss, scale).backward();
                }
                let _ = xtrain_train::clip::clip_grad_norm_gpu(&params, clip, 1.0);
                opt.step(lr, &params);
                for p in &params {
                    p.zero_grad();
                }
            }
        }

        let last_step = step + 1 == steps;
        if (step + 1) % log_every == 0 || last_step {
            println!(
                "step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
                step + 1,
                win_reward / win_n.max(1) as f32,
                win_solved,
                win_n,
                start.elapsed().as_secs_f32(),
            );
            win_reward = 0.0;
            win_solved = 0;
            win_n = 0;
            // Periodic save so a later failure (e.g. an OOM during rollout) still
            // leaves an evaluatable ckpt. Skipped on the last step because the
            // final save right after the loop writes the same parameters.
            if !last_step {
                xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save");
            }
        }
    }

    xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save ckpt");
    println!("GRPO done: {steps} steps, G={group}, B={n_prompts}, beta {beta}, lr {lr:.1e} → {out_ckpt}");
}