post-train: M4 — GRPO actor-learner loop + cached temperature rollout

train_grpo: the online, critic-free RL loop — per step sample B prompts, roll
out G completions each, score with the rule-based checker (reward 0/1), compute
group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss
epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness
(KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward
(the falsifiable "it learns" signal). Periodic checkpoint save.

decode: generate_cached adds temperature sampling to the KV-cache engine (M2) —
single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far
lighter on the caching allocator (the naive sampler fragments it over a long
rollout). generate_greedy_cached now routes through it (temp 0); decode_kv
token-identical gate still passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 16:59:05 +08:00
parent aaa77082ef
commit 7fb3b32fd9
3 changed files with 332 additions and 2 deletions

View File

@@ -83,6 +83,24 @@ pub fn generate_greedy_cached(
device: Device,
prompt: &[i32],
max_new: usize,
) -> Vec<i32> {
let mut rng = 0u64;
generate_cached(model, device, prompt, max_new, 0.0, &mut rng)
}
/// KV-cache decode with temperature sampling (`temperature == 0` → greedy argmax,
/// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`).
/// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row
/// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far
/// lighter on memory + the allocator — the naive sampler fragments the caching
/// allocator over a long rollout, which is the M4 "rollout is the long pole" wall.
pub fn generate_cached(
model: &TinyTransformer,
device: Device,
prompt: &[i32],
max_new: usize,
temperature: f32,
rng_state: &mut u64,
) -> Vec<i32> {
assert!(!prompt.is_empty(), "prompt must be non-empty");
let cfg = model.config();
@@ -116,7 +134,11 @@ pub fn generate_greedy_cached(
}
for _ in 0..max_new {
let next = argmax(&logits) as i32;
let next = if temperature <= 0.0 {
argmax(&logits) as i32
} else {
sample_temperature(&logits, temperature, rng_state) as i32
};
tokens.push(next);
let pos = tokens.len() - 1; // absolute position of the token just appended
logits = decode_step(&params, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
@@ -124,6 +146,26 @@ pub fn generate_greedy_cached(
tokens
}
/// Sample a token from `softmax(logits / temperature)` (numerically stable). Same
/// LCG + inverse-CDF scheme as the naive `sample::sample_temperature`.
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = row.iter().map(|&x| ((x - max) / temperature).exp()).collect();
let sum: f32 = exps.iter().sum();
*rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let r = ((*rng_state >> 32) as f32 / u32::MAX as f32) * sum;
let mut acc = 0.0;
for (i, &e) in exps.iter().enumerate() {
acc += e;
if acc >= r {
return i;
}
}
exps.len() - 1
}
/// One incremental decode step for token `tok` at absolute position `pos`: append
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
#[allow(clippy::too_many_arguments)]

View File

@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
#[cfg(not(no_cuda))]
pub mod decode;
#[cfg(not(no_cuda))]
pub use decode::generate_greedy_cached;
pub use decode::{generate_cached, generate_greedy_cached};

View File

@@ -0,0 +1,288 @@
//! GRPO training on the verifiable arithmetic task (M4 / Stage P3) — online,
//! critic-free RL. The centerpiece: generation INSIDE the training loop.
//!
//! Each step: sample B prompts (fresh problems), roll out G completions per prompt
//! (temperature sampling via the naive sampler — batched/cached rollout is the M2b/
//! M4-perf follow-up), score each with the rule-based checker (reward ∈ {0,1}),
//! compute the **group-relative advantage** `A_i = (r_i mean) / (std + ε)` (no
//! critic), then K inner clipped-PG epochs minimising [`clipped_pg_loss`] with a KL
//! leash to the frozen reference (πref = the SFT checkpoint). Reward = pure 0/1
//! correctness; the KL term (β) is what keeps format/coherence (the M3 collapse
//! lesson — here it is an explicit leash, not just a hope).
//!
//! Health signal (the falsifiable "it learns"): **mean rollout reward must rise**
//! (the RL analogue of T5's overfit-27/27). Held-out correctness is measured by
//! eval_arith on the saved checkpoint.
//!
//! train_grpo <tokenizer.json> --init-ckpt <sft.ckpt> <arch flags> \
//! --steps 200 --group 6 --prompts 8 --temp 1.0 --beta 0.04 --eps 0.2 \
//! --lr 1e-6 --max-add 20 --max-mul 9 --ckpt <out.ckpt>
#[cfg(no_cuda)]
fn main() {
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use xtrain_autodiff::ops;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
#[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device};
#[cfg(not(no_cuda))]
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.cloned()
}
#[cfg(not(no_cuda))]
fn first_answer_segment(c: &str) -> &str {
let s = c.split("<|endoftext|>").next().unwrap_or(c);
s.split('\n').next().unwrap_or(s)
}
/// Build a model from the SFT checkpoint (bf16 compute to fit two 1B models). The
/// policy enables activation recompute (T13) so its backward fits alongside the
/// frozen reference + the Adam state; the reference only forwards (no backward).
#[cfg(not(no_cuda))]
fn load_model(cfg: Config, device: Device, ckpt: &str, recompute: bool) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.04)
}
})
.with_compute_dtype(DType::BF16)
.with_recompute(recompute)
.with_flash(true);
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
m.eval();
m
}
/// Frame (question, completion) like the SFT loader and return the next-token
/// (input, target) pair (prompt masked to -100). Same as train_dpo.
#[cfg(not(no_cuda))]
fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) -> (Vec<i32>, Vec<i32>) {
let p_ids: Vec<i32> = tok
.encode(&format!("User: {question}\nAssistant:"))
.into_iter()
.map(|t| t as i32)
.collect();
let a_ids: Vec<i32> = tok
.encode(&format!(" {completion}\n<|endoftext|>"))
.into_iter()
.map(|t| t as i32)
.collect();
let mut tokens = p_ids.clone();
tokens.extend_from_slice(&a_ids);
let mut labels = vec![-100i32; p_ids.len()];
labels.extend_from_slice(&a_ids);
let l = tokens.len();
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
}
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= per_row
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
#[cfg(not(no_cuda))]
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
let logits = model.forward(&ids_tensor(input, device)).value();
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect()
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
use xtrain_optim::GpuAdamW;
let args: Vec<String> = std::env::args().collect();
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let tok_path = positionals.first().expect("usage: train_grpo <tokenizer.json> [flags]");
let n_heads = flag(&args, "--heads", 52usize);
let head_dim = flag(&args, "--head-dim", 32usize);
let n_layers = flag(&args, "--layers", 22usize);
let ffn = flag(&args, "--ffn", 6656usize);
let kv_heads = flag(&args, "--kv-heads", n_heads);
let steps: usize = flag(&args, "--steps", 200);
let group: usize = flag(&args, "--group", 6);
let n_prompts: usize = flag(&args, "--prompts", 8);
let inner: usize = flag(&args, "--inner", 1);
let temp: f32 = flag(&args, "--temp", 1.0);
let beta: f32 = flag(&args, "--beta", 0.04);
let eps: f32 = flag(&args, "--eps", 0.2);
let lr: f32 = flag(&args, "--lr", 1e-6);
let clip: f32 = flag(&args, "--clip", 1.0);
let max_new: usize = flag(&args, "--max-tokens", 24);
let max_add: i64 = flag(&args, "--max-add", 20);
let max_mul: i64 = flag(&args, "--max-mul", 9);
let seed: u64 = flag(&args, "--seed", 20260630);
let log_every: usize = flag(&args, "--log-every", 20);
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
let policy = load_model(cfg, device, &init_ckpt, false); // flash keeps attn memory bounded
// Frozen πref for the KL leash — only resident when β>0 (a second 1B model is the
// memory long-pole; β=0 is pure PG and skips it, the gated degenerate).
let reference = if beta > 0.0 {
Some(load_model(cfg, device, &init_ckpt, false))
} else {
None
};
let gcfg = GenConfig {
max_add,
max_mul,
ops: vec![Op::Add, Op::Sub, Op::Mul],
};
let params = policy.params();
let mut opt = GpuAdamW::new(0.0);
let mut rng = seed.max(1);
let start = std::time::Instant::now();
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
for step in 0..steps {
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
struct Sample {
input: Vec<i32>,
target: Vec<i32>,
adv: f32,
logp_old: Vec<f32>,
logp_ref: Vec<f32>,
}
let mut batch: Vec<Sample> = Vec::new();
for _ in 0..n_prompts {
let p = gen_problem(&mut rng, &gcfg);
let prompt_ids: Vec<i32> = tok
.encode(&format!("User: {}\nAssistant:", p.question()))
.into_iter()
.map(|t| t as i32)
.collect();
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
for _ in 0..group {
// KV-cache temperature rollout (M2 engine): single-row logits per
// step → far lighter on the allocator than the naive sampler, which
// fragments it over a long rollout (the M4 rollout long-pole).
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
let seg = first_answer_segment(&cont).trim().to_string();
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
comps.push((seg, r));
}
let mean = comps.iter().map(|c| c.1).sum::<f32>() / group as f32;
let var = comps.iter().map(|c| (c.1 - mean).powi(2)).sum::<f32>() / group as f32;
let std = var.sqrt();
win_reward += mean * group as f32;
win_solved += comps.iter().filter(|c| c.1 > 0.5).count();
win_n += group;
// A whole group with no reward variance gives zero advantage → skip
// (no learning signal, and avoids dividing by ~0).
if std < 1e-6 {
continue;
}
for (seg, r) in &comps {
let adv = (r - mean) / (std + 1e-4);
let (input, target) = frame(&tok, &p.question(), seg);
let logp_old = per_token_logp(&policy, device, &input, &target);
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
let logp_ref = match &reference {
Some(r) => per_token_logp(r, device, &input, &target),
None => vec![0.0; logp_old.len()],
};
batch.push(Sample { input, target, adv, logp_old, logp_ref });
}
}
// ---- K inner clipped-PG epochs over the captured batch ----
if !batch.is_empty() {
let scale = 1.0 / batch.len() as f32;
for _ in 0..inner {
for s in &batch {
let logits = policy.forward(&ids_tensor(&s.input, device));
let loss = ops::clipped_pg_loss(
&logits,
&ids_tensor(&s.target, device),
&s.logp_old,
&s.logp_ref,
s.adv,
eps,
beta,
);
ops::scale(&loss, scale).backward();
}
let _ = xtrain_train::clip::clip_grad_norm_gpu(&params, clip, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
}
}
if (step + 1) % log_every == 0 || step == steps - 1 {
println!(
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
step + 1,
win_reward / win_n.max(1) as f32,
win_solved,
win_n,
start.elapsed().as_secs_f32(),
);
win_reward = 0.0;
win_solved = 0;
win_n = 0;
// Periodic save so a later OOM (naive rollout fragments the allocator —
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save");
}
}
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save ckpt");
println!("GRPO done: {steps} steps, G={group}, B={n_prompts}, beta {beta}, lr {lr:.1e}{out_ckpt}");
}