train_grpo: the online, critic-free RL loop — per step sample B prompts, roll out G completions each, score with the rule-based checker (reward 0/1), compute group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness (KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward (the falsifiable "it learns" signal). Periodic checkpoint save. decode: generate_cached adds temperature sampling to the KV-cache engine (M2) — single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far lighter on the caching allocator (the naive sampler fragments it over a long rollout). generate_greedy_cached now routes through it (temp 0); decode_kv token-identical gate still passes. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
289 lines
12 KiB
Rust
289 lines
12 KiB
Rust
//! GRPO training on the verifiable arithmetic task (M4 / Stage P3) — online,
|
||
//! critic-free RL. The centerpiece: generation INSIDE the training loop.
|
||
//!
|
||
//! Each step: sample B prompts (fresh problems), roll out G completions per prompt
//! (temperature sampling through the KV-cache engine, `generate_cached`), score each
//! with the rule-based checker (reward ∈ {0,1}),
|
||
//! compute the **group-relative advantage** `A_i = (r_i − mean) / (std + ε)` (no
|
||
//! critic), then K inner clipped-PG epochs minimising [`clipped_pg_loss`] with a KL
|
||
//! leash to the frozen reference (πref = the SFT checkpoint). Reward = pure 0/1
|
||
//! correctness; the KL term (β) is what keeps format/coherence (the M3 collapse
|
||
//! lesson — here it is an explicit leash, not just a hope).
|
||
//!
|
||
//! Health signal (the falsifiable "it learns"): **mean rollout reward must rise**
|
||
//! (the RL analogue of T5's overfit-27/27). Held-out correctness is measured by
|
||
//! eval_arith on the saved checkpoint.
|
||
//!
|
||
//! train_grpo <tokenizer.json> --init-ckpt <sft.ckpt> <arch flags> \
|
||
//! --steps 200 --group 6 --prompts 8 --temp 1.0 --beta 0.04 --eps 0.2 \
|
||
//! --lr 1e-6 --max-add 20 --max-mul 9 --ckpt <out.ckpt>
|
||
|
||
/// Stub entry point for builds configured with `--cfg no_cuda`: GRPO training needs a
/// CUDA device, so this build only tells the user where to run it.
#[cfg(no_cuda)]
fn main() {
    const HINT: &str = "train_grpo: built without CUDA (no_cuda); run on a GPU host.";
    eprintln!("{HINT}");
}
|
||
|
||
#[cfg(not(no_cuda))]
|
||
use xtrain_autodiff::ops;
|
||
#[cfg(not(no_cuda))]
|
||
use xtrain_cuda::device;
|
||
#[cfg(not(no_cuda))]
|
||
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
|
||
#[cfg(not(no_cuda))]
|
||
use xtrain_tensor::{DType, Device};
|
||
#[cfg(not(no_cuda))]
|
||
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
|
||
|
||
/// Deterministic pseudo-random init values: `n` floats in `[-scale, scale]`.
///
/// The state is scrambled once from `seed`, then advanced by a 64-bit LCG (the Knuth
/// MMIX multiplier/increment); each output uses the state's top 31 bits mapped to
/// `[0, 1]` (f32 rounding can reach exactly 1.0), then recentred and scaled.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;

    let mut x = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        x = x.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        let unit = (x >> 33) as f32 / denom;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||
|
||
/// Parse the value following the first occurrence of `name` in `args`.
///
/// Falls back to `default` when the flag is absent, is the last argument (no value), or
/// its value does not parse as `T`.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let raw = match args.iter().position(|a| a == name) {
        Some(idx) => args.get(idx + 1),
        None => None,
    };
    match raw.map(|s| s.parse::<T>()) {
        Some(Ok(value)) => value,
        _ => default,
    }
}
|
||
|
||
/// Raw (owned) value following the first occurrence of `name` in `args`; `None` if the
/// flag is absent or has no following argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let idx = args.iter().position(|a| a == name)?;
    args.get(idx + 1).map(String::clone)
}
|
||
|
||
/// Cut a decoded completion down to its first answer: the text before the first
/// `<|endoftext|>` marker, then the text before that part's first newline. Returns `c`
/// unchanged when neither delimiter occurs.
#[cfg(not(no_cuda))]
fn first_answer_segment(c: &str) -> &str {
    let head = match c.split_once("<|endoftext|>") {
        Some((before, _)) => before,
        None => c,
    };
    match head.split_once('\n') {
        Some((line, _)) => line,
        None => head,
    }
}
|
||
|
||
/// Build a model from the SFT checkpoint (bf16 compute to fit two 1B models). The
|
||
/// policy enables activation recompute (T13) so its backward fits alongside the
|
||
/// frozen reference + the Adam state; the reference only forwards (no backward).
|
||
#[cfg(not(no_cuda))]
|
||
fn load_model(cfg: Config, device: Device, ckpt: &str, recompute: bool) -> TinyTransformer {
|
||
let mut seed = 1u64;
|
||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||
seed = seed.wrapping_add(1);
|
||
let n: usize = shape.iter().product();
|
||
if shape.len() == 1 {
|
||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||
} else {
|
||
fill(n, seed, 0.04)
|
||
}
|
||
})
|
||
.with_compute_dtype(DType::BF16)
|
||
.with_recompute(recompute)
|
||
.with_flash(true);
|
||
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
|
||
m.eval();
|
||
m
|
||
}
|
||
|
||
/// Frame (question, completion) like the SFT loader and return the next-token
|
||
/// (input, target) pair (prompt masked to -100). Same as train_dpo.
|
||
#[cfg(not(no_cuda))]
|
||
fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) -> (Vec<i32>, Vec<i32>) {
|
||
let p_ids: Vec<i32> = tok
|
||
.encode(&format!("User: {question}\nAssistant:"))
|
||
.into_iter()
|
||
.map(|t| t as i32)
|
||
.collect();
|
||
let a_ids: Vec<i32> = tok
|
||
.encode(&format!(" {completion}\n<|endoftext|>"))
|
||
.into_iter()
|
||
.map(|t| t as i32)
|
||
.collect();
|
||
let mut tokens = p_ids.clone();
|
||
tokens.extend_from_slice(&a_ids);
|
||
let mut labels = vec![-100i32; p_ids.len()];
|
||
labels.extend_from_slice(&a_ids);
|
||
let l = tokens.len();
|
||
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
|
||
}
|
||
|
||
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= −per_row
|
||
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
|
||
#[cfg(not(no_cuda))]
|
||
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||
let logits = model.forward(&ids_tensor(input, device)).value();
|
||
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||
per_row
|
||
.to_device(Device::Cpu)
|
||
.as_slice::<f32>()
|
||
.iter()
|
||
.map(|p| -p)
|
||
.collect()
|
||
}
|
||
|
||
#[cfg(not(no_cuda))]
|
||
fn main() {
|
||
use xserv_tokenizer::Tokenizer;
|
||
use xtrain_optim::GpuAdamW;
|
||
|
||
let args: Vec<String> = std::env::args().collect();
|
||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||
let tok_path = positionals.first().expect("usage: train_grpo <tokenizer.json> [flags]");
|
||
|
||
let n_heads = flag(&args, "--heads", 52usize);
|
||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||
let n_layers = flag(&args, "--layers", 22usize);
|
||
let ffn = flag(&args, "--ffn", 6656usize);
|
||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||
let steps: usize = flag(&args, "--steps", 200);
|
||
let group: usize = flag(&args, "--group", 6);
|
||
let n_prompts: usize = flag(&args, "--prompts", 8);
|
||
let inner: usize = flag(&args, "--inner", 1);
|
||
let temp: f32 = flag(&args, "--temp", 1.0);
|
||
let beta: f32 = flag(&args, "--beta", 0.04);
|
||
let eps: f32 = flag(&args, "--eps", 0.2);
|
||
let lr: f32 = flag(&args, "--lr", 1e-6);
|
||
let clip: f32 = flag(&args, "--clip", 1.0);
|
||
let max_new: usize = flag(&args, "--max-tokens", 24);
|
||
let max_add: i64 = flag(&args, "--max-add", 20);
|
||
let max_mul: i64 = flag(&args, "--max-mul", 9);
|
||
let seed: u64 = flag(&args, "--seed", 20260630);
|
||
let log_every: usize = flag(&args, "--log-every", 20);
|
||
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
|
||
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
|
||
|
||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||
device::set_device(0).unwrap();
|
||
let device = Device::Cuda(0);
|
||
|
||
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
|
||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||
let policy = load_model(cfg, device, &init_ckpt, false); // flash keeps attn memory bounded
|
||
// Frozen πref for the KL leash — only resident when β>0 (a second 1B model is the
|
||
// memory long-pole; β=0 is pure PG and skips it, the gated degenerate).
|
||
let reference = if beta > 0.0 {
|
||
Some(load_model(cfg, device, &init_ckpt, false))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let gcfg = GenConfig {
|
||
max_add,
|
||
max_mul,
|
||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||
};
|
||
let params = policy.params();
|
||
let mut opt = GpuAdamW::new(0.0);
|
||
let mut rng = seed.max(1);
|
||
|
||
let start = std::time::Instant::now();
|
||
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
|
||
for step in 0..steps {
|
||
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
|
||
struct Sample {
|
||
input: Vec<i32>,
|
||
target: Vec<i32>,
|
||
adv: f32,
|
||
logp_old: Vec<f32>,
|
||
logp_ref: Vec<f32>,
|
||
}
|
||
let mut batch: Vec<Sample> = Vec::new();
|
||
for _ in 0..n_prompts {
|
||
let p = gen_problem(&mut rng, &gcfg);
|
||
let prompt_ids: Vec<i32> = tok
|
||
.encode(&format!("User: {}\nAssistant:", p.question()))
|
||
.into_iter()
|
||
.map(|t| t as i32)
|
||
.collect();
|
||
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
||
for _ in 0..group {
|
||
// KV-cache temperature rollout (M2 engine): single-row logits per
|
||
// step → far lighter on the allocator than the naive sampler, which
|
||
// fragments it over a long rollout (the M4 rollout long-pole).
|
||
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
|
||
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||
let seg = first_answer_segment(&cont).trim().to_string();
|
||
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
||
comps.push((seg, r));
|
||
}
|
||
let mean = comps.iter().map(|c| c.1).sum::<f32>() / group as f32;
|
||
let var = comps.iter().map(|c| (c.1 - mean).powi(2)).sum::<f32>() / group as f32;
|
||
let std = var.sqrt();
|
||
win_reward += mean * group as f32;
|
||
win_solved += comps.iter().filter(|c| c.1 > 0.5).count();
|
||
win_n += group;
|
||
// A whole group with no reward variance gives zero advantage → skip
|
||
// (no learning signal, and avoids dividing by ~0).
|
||
if std < 1e-6 {
|
||
continue;
|
||
}
|
||
for (seg, r) in &comps {
|
||
let adv = (r - mean) / (std + 1e-4);
|
||
let (input, target) = frame(&tok, &p.question(), seg);
|
||
let logp_old = per_token_logp(&policy, device, &input, &target);
|
||
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||
let logp_ref = match &reference {
|
||
Some(r) => per_token_logp(r, device, &input, &target),
|
||
None => vec![0.0; logp_old.len()],
|
||
};
|
||
batch.push(Sample { input, target, adv, logp_old, logp_ref });
|
||
}
|
||
}
|
||
|
||
// ---- K inner clipped-PG epochs over the captured batch ----
|
||
if !batch.is_empty() {
|
||
let scale = 1.0 / batch.len() as f32;
|
||
for _ in 0..inner {
|
||
for s in &batch {
|
||
let logits = policy.forward(&ids_tensor(&s.input, device));
|
||
let loss = ops::clipped_pg_loss(
|
||
&logits,
|
||
&ids_tensor(&s.target, device),
|
||
&s.logp_old,
|
||
&s.logp_ref,
|
||
s.adv,
|
||
eps,
|
||
beta,
|
||
);
|
||
ops::scale(&loss, scale).backward();
|
||
}
|
||
let _ = xtrain_train::clip::clip_grad_norm_gpu(¶ms, clip, 1.0);
|
||
opt.step(lr, ¶ms);
|
||
for p in ¶ms {
|
||
p.zero_grad();
|
||
}
|
||
}
|
||
}
|
||
|
||
if (step + 1) % log_every == 0 || step == steps - 1 {
|
||
println!(
|
||
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
|
||
step + 1,
|
||
win_reward / win_n.max(1) as f32,
|
||
win_solved,
|
||
win_n,
|
||
start.elapsed().as_secs_f32(),
|
||
);
|
||
win_reward = 0.0;
|
||
win_solved = 0;
|
||
win_n = 0;
|
||
// Periodic save so a later OOM (naive rollout fragments the allocator —
|
||
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
|
||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save");
|
||
}
|
||
}
|
||
|
||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save ckpt");
|
||
println!("GRPO done: {steps} steps, G={group}, B={n_prompts}, beta {beta}, lr {lr:.1e} → {out_ckpt}");
|
||
}
|