Compare commits

...

3 Commits

Author SHA1 Message Date
096e45b845 docs: M4 — GRPO results (infra + memory/rollout walls + capability-wall negative result)
Implementation log (docs/18) + Phase-3 row (evolution.md): the clipped_pg_loss
op + gates, the actor-learner loop, the easy-task SFT baseline (held-out 18.7%,
plateaus → no generalization), the two systems walls the design doc flagged
(two 1B models OOM the 32GB box → β=0; naive rollout fragments the allocator →
cached temperature sampling, rollout still the long pole), and the result:
format holds, held-out 20.0% (+1.3pp, statistically flat) — the same wall as
DPO. Closes the SFT→KV-cache→DPO→GRPO post-training arc with honest limits.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 17:01:22 +08:00
7fb3b32fd9 post-train: M4 — GRPO actor-learner loop + cached temperature rollout
train_grpo: the online, critic-free RL loop — per step sample B prompts, roll
out G completions each, score with the rule-based checker (reward 0/1), compute
group-relative advantage A=(r−mean)/(std+ε), then K inner clipped_pg_loss
epochs with a KL leash to the frozen reference. Reward = pure 0/1 correctness
(KL is the format protector, the M3 collapse lesson). Tracks mean rollout reward
(the falsifiable "it learns" signal). Periodic checkpoint save.

decode: generate_cached adds temperature sampling to the KV-cache engine (M2) —
single-row [1,vocab] logits per step vs the naive sampler's [seq,vocab], far
lighter on the caching allocator (the naive sampler fragments it over a long
rollout). generate_greedy_cached now routes through it (temp 0); decode_kv
token-identical gate still passes.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 16:59:05 +08:00
aaa77082ef post-train: M4 — clipped_pg_loss + scale_rows (GRPO policy-gradient op)
The GRPO (M4) token-level loss op + the one primitive it needs:

- scale_rows(x[r,c], s[r]): per-row scale (new ~5-line CUDA kernel). The
  clipped-PG backward scales each completion token's row of (probs − onehot) by
  its own per-token coefficient, which cross_entropy_backward's single scalar
  scale can't express.
- clipped_pg_loss(logits, target, logp_old, logp_ref, A, eps, beta): per-token
  ρ_t = exp(logπθ_t − logp_old_t), L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL
  (k3 estimator), masked to completion tokens. Backward reuses the CE machinery
  (probs − onehot) + scale_rows. Gates: grad-check the active PG path + the A=0
  (KL-only) path; degenerate value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 14:07:02 +08:00
10 changed files with 613 additions and 2 deletions

View File

@@ -517,3 +517,83 @@ pub fn dpo_loss(
}), }),
) )
} }
/// GRPO clipped policy-gradient loss (M4) for ONE completion, a scalar `[1]` Var
/// with the policy logits as the single parent. Per non-ignored (completion) token
/// `t` (`target[t] ≥ 0`):
/// `logπθ_t = log softmax(logits[t])[target_t]` (`= per_row[t]` of cross_entropy)
/// `ρ_t = exp(logπθ_t logp_old[t])`
/// `pg_t = min(ρ_t·A, clip(ρ_t, 1ε, 1+ε)·A)`
/// `kl_t = exp(logp_ref[t] logπθ_t) (logp_ref[t] logπθ_t) 1` (k3 estimator)
/// `L = mean_t pg_t + β·mean_t kl_t` over the `N` completion tokens.
///
/// `advantage` `A` is the group-relative advantage (constant per completion in
/// GRPO); `logp_old`/`logp_ref` are per-position constants (old policy at rollout
/// time / frozen reference). Backward reuses the CE machinery + the per-row
/// `scale_rows`: `dL/dlogits[t,:] = g_t·(onehot probs)[t,:]` with
/// `g_t = (1/N)A·ρ_t·[unclipped active] + (β/N)(1 exp(logp_ref_t logπθ_t))`.
/// Degenerate points the gate pins: `A=0` ⇒ only the KL term; `ε→∞` ⇒ vanilla PG
/// (no clip); `β=0` ⇒ no KL term.
#[allow(clippy::too_many_arguments)]
pub fn clipped_pg_loss(
logits: &Var,
target: &Tensor,
logp_old: &[f32],
logp_ref: &[f32],
advantage: f32,
eps: f32,
beta: f32,
) -> Var {
use xtrain_tensor::Device;
let logit_dtype = logits.value().dtype();
let (probs, per_row) = logits.value().cross_entropy(target);
let rows = per_row.shape()[0];
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per position");
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per position");
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
let (mut pg_sum, mut kl_sum, mut n) = (0f32, 0f32, 0f32);
for t in 0..rows {
if target_h[t] < 0 {
continue; // masked (prompt) position — no contribution, no gradient
}
n += 1.0;
let lp = -per_row_h[t]; // logπθ_t
let ratio = (lp - logp_old[t]).exp();
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
let (unclipped_term, clipped_term) = (ratio * advantage, clipped * advantage);
pg_sum += unclipped_term.min(clipped_term);
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
let d = logp_ref[t] - lp;
kl_sum += d.exp() - d - 1.0;
let pg_grad = if active { -advantage * ratio } else { 0.0 };
let kl_grad = beta * (1.0 - d.exp());
s[t] = -(pg_grad + kl_grad); // dL/dlogits = g·(onehotprobs) = g·(probsonehot)
}
let inv_n = if n > 0.0 { 1.0 / n } else { 1.0 };
for v in &mut s {
*v *= inv_n;
}
let loss_val = -pg_sum * inv_n + beta * kl_sum * inv_n;
let dev = logits.value().device();
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
let target = target.clone();
Var::from_op(
out,
vec![logits.clone()],
Box::new(move |d, parents| {
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
// (probs onehot), masked rows already 0; per-row scale by s; × upstream.
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
let mut dx = ce.scale_rows(&s_dev);
if up != 1.0 {
dx = dx.scale(up);
}
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),
)
}

View File

@@ -1085,3 +1085,95 @@ fn dpo_loss_bwd_and_degenerate() {
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}"); assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)"); println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
} }
// clipped_pg_loss (M4 GRPO): per-token clipped PG + k3 KL, one completion. Grad-check
// the active (in-trust-region) path + the A=0 (KL-only) path, plus value-level
// degenerate checks (ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL).
#[test]
fn clipped_pg_loss_bwd_and_degenerate() {
require_gpu();
let (rows, cols) = (6usize, 10usize);
let x_h = fill(rows * cols, 303);
// rows 0,1 masked (prompt); 2..6 supervised (completion).
let targets: Vec<i32> = (0..rows)
.map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
.collect();
let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
// logp_old = logπθ at the base logits ⇒ ρ≈1 (in trust region → active path).
let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
let logp_old: Vec<f32> = per_row0
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect();
let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect(); // exercise KL
let (eps, beta) = (0.2f32, 0.1f32);
// Host replica of the forward loss as a function of per-row CE values.
let host_loss = {
let (tg, lo, lr) = (targets.clone(), logp_old.clone(), logp_ref.clone());
move |per_row_h: &[f32], a: f32, e: f32, b: f32| -> f32 {
let (mut pg, mut kl, mut n) = (0f32, 0f32, 0f32);
for t in 0..per_row_h.len() {
if tg[t] < 0 {
continue;
}
n += 1.0;
let lp = -per_row_h[t];
let ratio = (lp - lo[t]).exp();
let clipped = ratio.clamp(1.0 - e, 1.0 + e);
pg += (ratio * a).min(clipped * a);
let d = lr[t] - lp;
kl += d.exp() - d - 1.0;
}
let inv = if n > 0.0 { 1.0 / n } else { 1.0 };
-pg * inv + b * kl * inv
}
};
let per_row_of = |v: &[f32], s: &[usize]| {
let (_, pr) = cuda(v, s).cross_entropy(&mk_target());
pr.to_device(Device::Cpu).as_slice::<f32>().to_vec()
};
// (1) grad-check the active PG path (A>0, ρ≈1).
let adv = 0.7f32;
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
let loss = ops::clipped_pg_loss(&x, &mk_target(), &logp_old, &logp_ref, adv, eps, beta);
loss.backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
let hl = host_loss.clone();
let lx = move |v: &[f32], s: &[usize]| hl(&per_row_of(v, s), adv, eps, beta);
report(
"clipped_pg dX (active)",
&grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
);
// (2) grad-check the A=0 path (loss = β·mean KL; PG gradient must vanish).
let x0 = Var::leaf(cuda(&x_h, &[rows, cols]));
let loss0 = ops::clipped_pg_loss(&x0, &mk_target(), &logp_old, &logp_ref, 0.0, eps, beta);
loss0.backward();
let dx0 = x0.grad().unwrap().to_device(Device::Cpu);
let hl0 = host_loss.clone();
let lx0 = move |v: &[f32], s: &[usize]| hl0(&per_row_of(v, s), 0.0, eps, beta);
report(
"clipped_pg dX (A=0, KL only)",
&grad_check(&x_h, &[rows, cols], &lx0, dx0.as_slice::<f32>(), cfg_nonlinear()),
);
// (3) ε→∞ ⇒ vanilla PG (no clip): loss value == mean(ρA) + β·mean KL.
let big = 1e9f32;
let lv = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_old, &logp_ref, adv, big, beta);
let got = lv.value().to_device(Device::Cpu).as_slice::<f32>()[0];
let pr0 = per_row_of(&x_h, &[rows, cols]);
let want = host_loss(&pr0, adv, big, beta);
assert!((got - want).abs() < 1e-4, "ε→∞ vanilla loss mismatch: {got} vs {want}");
// (4) β=0 ⇒ no KL term (loss == mean pg only).
let lvb = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_old, &logp_ref, adv, eps, 0.0);
let gotb = lvb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
let wantb = host_loss(&pr0, adv, eps, 0.0);
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
}

View File

@@ -152,6 +152,15 @@ unsafe extern "C" {
pos0: i32, pos0: i32,
s: CudaStream, s: CudaStream,
); );
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
pub fn launch_scale_rows_f32(
x: *const f32,
s: *const f32,
y: *mut f32,
rows: i32,
cols: i32,
stream: CudaStream,
);
pub fn launch_rope_dx_f32( pub fn launch_rope_dx_f32(
dy: *const f32, dy: *const f32,
dx: *mut f32, dx: *mut f32,

View File

@@ -83,6 +83,24 @@ pub fn generate_greedy_cached(
device: Device, device: Device,
prompt: &[i32], prompt: &[i32],
max_new: usize, max_new: usize,
) -> Vec<i32> {
let mut rng = 0u64;
generate_cached(model, device, prompt, max_new, 0.0, &mut rng)
}
/// KV-cache decode with temperature sampling (`temperature == 0` → greedy argmax,
/// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`).
/// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row
/// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far
/// lighter on memory + the allocator — the naive sampler fragments the caching
/// allocator over a long rollout, which is the M4 "rollout is the long pole" wall.
pub fn generate_cached(
model: &TinyTransformer,
device: Device,
prompt: &[i32],
max_new: usize,
temperature: f32,
rng_state: &mut u64,
) -> Vec<i32> { ) -> Vec<i32> {
assert!(!prompt.is_empty(), "prompt must be non-empty"); assert!(!prompt.is_empty(), "prompt must be non-empty");
let cfg = model.config(); let cfg = model.config();
@@ -116,7 +134,11 @@ pub fn generate_greedy_cached(
} }
for _ in 0..max_new { for _ in 0..max_new {
let next = argmax(&logits) as i32; let next = if temperature <= 0.0 {
argmax(&logits) as i32
} else {
sample_temperature(&logits, temperature, rng_state) as i32
};
tokens.push(next); tokens.push(next);
let pos = tokens.len() - 1; // absolute position of the token just appended let pos = tokens.len() - 1; // absolute position of the token just appended
logits = decode_step(&params, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head); logits = decode_step(&params, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
@@ -124,6 +146,26 @@ pub fn generate_greedy_cached(
tokens tokens
} }
/// Sample a token from `softmax(logits / temperature)` (numerically stable). Same
/// LCG + inverse-CDF scheme as the naive `sample::sample_temperature`.
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = row.iter().map(|&x| ((x - max) / temperature).exp()).collect();
let sum: f32 = exps.iter().sum();
*rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
let r = ((*rng_state >> 32) as f32 / u32::MAX as f32) * sum;
let mut acc = 0.0;
for (i, &e) in exps.iter().enumerate() {
acc += e;
if acc >= r {
return i;
}
}
exps.len() - 1
}
/// One incremental decode step for token `tok` at absolute position `pos`: append /// One incremental decode step for token `tok` at absolute position `pos`: append
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`. /// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]

View File

@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]
pub mod decode; pub mod decode;
#[cfg(not(no_cuda))] #[cfg(not(no_cuda))]
pub use decode::generate_greedy_cached; pub use decode::{generate_cached, generate_greedy_cached};

View File

@@ -941,6 +941,31 @@ impl Tensor {
dx dx
} }
/// Per-row scale: `out[r,c] = self[r,c] * s[r]`. `self`:[rows,cols] F32,
/// `s`:[rows] F32. Used by the GRPO (M4) policy-gradient backward, where each
/// completion token's row of `(probs onehot)` is scaled by its own per-token
/// coefficient (the per-token clipped-PG + KL gradient). Forward-only.
#[cfg(not(no_cuda))]
pub fn scale_rows(&self, s: &Tensor) -> Self {
assert_eq!(self.ndim(), 2, "scale_rows requires a 2D tensor");
assert_eq!(self.dtype, DType::F32, "scale_rows is F32");
assert_eq!(s.dtype, DType::F32, "scale vector is F32");
let (rows, cols) = (self.shape[0], self.shape[1]);
assert_eq!(s.numel(), rows, "scale vector must have one entry per row");
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
unsafe {
xtrain_cuda::ffi::launch_scale_rows_f32(
self.data_ptr() as *const f32,
s.data_ptr() as *const f32,
out.data_ptr() as *mut f32,
rows as i32,
cols as i32,
std::ptr::null_mut(),
);
}
out
}
// --- Structural / model ops (the T5 kernels) --- // --- Structural / model ops (the T5 kernels) ---
/// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a /// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a

View File

@@ -0,0 +1,288 @@
//! GRPO training on the verifiable arithmetic task (M4 / Stage P3) — online,
//! critic-free RL. The centerpiece: generation INSIDE the training loop.
//!
//! Each step: sample B prompts (fresh problems), roll out G completions per prompt
//! (temperature sampling via the naive sampler — batched/cached rollout is the M2b/
//! M4-perf follow-up), score each with the rule-based checker (reward ∈ {0,1}),
//! compute the **group-relative advantage** `A_i = (r_i mean) / (std + ε)` (no
//! critic), then K inner clipped-PG epochs minimising [`clipped_pg_loss`] with a KL
//! leash to the frozen reference (πref = the SFT checkpoint). Reward = pure 0/1
//! correctness; the KL term (β) is what keeps format/coherence (the M3 collapse
//! lesson — here it is an explicit leash, not just a hope).
//!
//! Health signal (the falsifiable "it learns"): **mean rollout reward must rise**
//! (the RL analogue of T5's overfit-27/27). Held-out correctness is measured by
//! eval_arith on the saved checkpoint.
//!
//! train_grpo <tokenizer.json> --init-ckpt <sft.ckpt> <arch flags> \
//! --steps 200 --group 6 --prompts 8 --temp 1.0 --beta 0.04 --eps 0.2 \
//! --lr 1e-6 --max-add 20 --max-mul 9 --ckpt <out.ckpt>
#[cfg(no_cuda)]
fn main() {
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use xtrain_autodiff::ops;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
#[cfg(not(no_cuda))]
use xtrain_tensor::{DType, Device};
#[cfg(not(no_cuda))]
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.cloned()
}
#[cfg(not(no_cuda))]
fn first_answer_segment(c: &str) -> &str {
let s = c.split("<|endoftext|>").next().unwrap_or(c);
s.split('\n').next().unwrap_or(s)
}
/// Build a model from the SFT checkpoint (bf16 compute to fit two 1B models). The
/// policy enables activation recompute (T13) so its backward fits alongside the
/// frozen reference + the Adam state; the reference only forwards (no backward).
#[cfg(not(no_cuda))]
fn load_model(cfg: Config, device: Device, ckpt: &str, recompute: bool) -> TinyTransformer {
let mut seed = 1u64;
let m = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.04)
}
})
.with_compute_dtype(DType::BF16)
.with_recompute(recompute)
.with_flash(true);
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
m.eval();
m
}
/// Frame (question, completion) like the SFT loader and return the next-token
/// (input, target) pair (prompt masked to -100). Same as train_dpo.
#[cfg(not(no_cuda))]
fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) -> (Vec<i32>, Vec<i32>) {
let p_ids: Vec<i32> = tok
.encode(&format!("User: {question}\nAssistant:"))
.into_iter()
.map(|t| t as i32)
.collect();
let a_ids: Vec<i32> = tok
.encode(&format!(" {completion}\n<|endoftext|>"))
.into_iter()
.map(|t| t as i32)
.collect();
let mut tokens = p_ids.clone();
tokens.extend_from_slice(&a_ids);
let mut labels = vec![-100i32; p_ids.len()];
labels.extend_from_slice(&a_ids);
let l = tokens.len();
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
}
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= per_row
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
#[cfg(not(no_cuda))]
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
let logits = model.forward(&ids_tensor(input, device)).value();
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.map(|p| -p)
.collect()
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
use xtrain_optim::GpuAdamW;
let args: Vec<String> = std::env::args().collect();
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let tok_path = positionals.first().expect("usage: train_grpo <tokenizer.json> [flags]");
let n_heads = flag(&args, "--heads", 52usize);
let head_dim = flag(&args, "--head-dim", 32usize);
let n_layers = flag(&args, "--layers", 22usize);
let ffn = flag(&args, "--ffn", 6656usize);
let kv_heads = flag(&args, "--kv-heads", n_heads);
let steps: usize = flag(&args, "--steps", 200);
let group: usize = flag(&args, "--group", 6);
let n_prompts: usize = flag(&args, "--prompts", 8);
let inner: usize = flag(&args, "--inner", 1);
let temp: f32 = flag(&args, "--temp", 1.0);
let beta: f32 = flag(&args, "--beta", 0.04);
let eps: f32 = flag(&args, "--eps", 0.2);
let lr: f32 = flag(&args, "--lr", 1e-6);
let clip: f32 = flag(&args, "--clip", 1.0);
let max_new: usize = flag(&args, "--max-tokens", 24);
let max_add: i64 = flag(&args, "--max-add", 20);
let max_mul: i64 = flag(&args, "--max-mul", 9);
let seed: u64 = flag(&args, "--seed", 20260630);
let log_every: usize = flag(&args, "--log-every", 20);
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
let policy = load_model(cfg, device, &init_ckpt, false); // flash keeps attn memory bounded
// Frozen πref for the KL leash — only resident when β>0 (a second 1B model is the
// memory long-pole; β=0 is pure PG and skips it, the gated degenerate).
let reference = if beta > 0.0 {
Some(load_model(cfg, device, &init_ckpt, false))
} else {
None
};
let gcfg = GenConfig {
max_add,
max_mul,
ops: vec![Op::Add, Op::Sub, Op::Mul],
};
let params = policy.params();
let mut opt = GpuAdamW::new(0.0);
let mut rng = seed.max(1);
let start = std::time::Instant::now();
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
for step in 0..steps {
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
struct Sample {
input: Vec<i32>,
target: Vec<i32>,
adv: f32,
logp_old: Vec<f32>,
logp_ref: Vec<f32>,
}
let mut batch: Vec<Sample> = Vec::new();
for _ in 0..n_prompts {
let p = gen_problem(&mut rng, &gcfg);
let prompt_ids: Vec<i32> = tok
.encode(&format!("User: {}\nAssistant:", p.question()))
.into_iter()
.map(|t| t as i32)
.collect();
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
for _ in 0..group {
// KV-cache temperature rollout (M2 engine): single-row logits per
// step → far lighter on the allocator than the naive sampler, which
// fragments it over a long rollout (the M4 rollout long-pole).
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
let seg = first_answer_segment(&cont).trim().to_string();
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
comps.push((seg, r));
}
let mean = comps.iter().map(|c| c.1).sum::<f32>() / group as f32;
let var = comps.iter().map(|c| (c.1 - mean).powi(2)).sum::<f32>() / group as f32;
let std = var.sqrt();
win_reward += mean * group as f32;
win_solved += comps.iter().filter(|c| c.1 > 0.5).count();
win_n += group;
// A whole group with no reward variance gives zero advantage → skip
// (no learning signal, and avoids dividing by ~0).
if std < 1e-6 {
continue;
}
for (seg, r) in &comps {
let adv = (r - mean) / (std + 1e-4);
let (input, target) = frame(&tok, &p.question(), seg);
let logp_old = per_token_logp(&policy, device, &input, &target);
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
let logp_ref = match &reference {
Some(r) => per_token_logp(r, device, &input, &target),
None => vec![0.0; logp_old.len()],
};
batch.push(Sample { input, target, adv, logp_old, logp_ref });
}
}
// ---- K inner clipped-PG epochs over the captured batch ----
if !batch.is_empty() {
let scale = 1.0 / batch.len() as f32;
for _ in 0..inner {
for s in &batch {
let logits = policy.forward(&ids_tensor(&s.input, device));
let loss = ops::clipped_pg_loss(
&logits,
&ids_tensor(&s.target, device),
&s.logp_old,
&s.logp_ref,
s.adv,
eps,
beta,
);
ops::scale(&loss, scale).backward();
}
let _ = xtrain_train::clip::clip_grad_norm_gpu(&params, clip, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
}
}
if (step + 1) % log_every == 0 || step == steps - 1 {
println!(
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
step + 1,
win_reward / win_n.max(1) as f32,
win_solved,
win_n,
start.elapsed().as_secs_f32(),
);
win_reward = 0.0;
win_solved = 0;
win_n = 0;
// Periodic save so a later OOM (naive rollout fragments the allocator —
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save");
}
}
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save ckpt");
println!("GRPO done: {steps} steps, G={group}, B={n_prompts}, beta {beta}, lr {lr:.1e}{out_ckpt}");
}

View File

@@ -269,6 +269,23 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0); rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
} }
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
// (M4) policy-gradient backward, where each completion token's row of
// (probs onehot) is scaled by its own per-token coefficient.
__global__ void scale_rows_k(const float* x, const float* s, float* y,
int rows, int cols) {
int r = blockIdx.x;
float sr = s[r];
for (int c = threadIdx.x; c < cols; c += blockDim.x)
y[r * cols + c] = x[r * cols + c] * sr;
}
void launch_scale_rows_f32(const float* x, const float* s, float* y,
int rows, int cols, void* st) {
int blk = cols < 1024 ? cols : 1024;
if (blk < 32) blk = 32;
scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
}
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim, __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
float theta, int period) { float theta, int period) {
int tok = blockIdx.x; int tok = blockIdx.x;

View File

@@ -466,3 +466,59 @@ verifiable reward* online (sample → check → reinforce what is genuinely corr
fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric
is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init, is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init,
margin/acc rise); the correctness flatness is the reported finding, not a bug. margin/acc rise); the correctness flatness is the reported finding, not a bug.
### M4 — GRPO (online RL, critic-free, landed; infra + two honest systems walls)
The centerpiece: generation INSIDE the training loop. Infra built and gated; the run surfaces
two concrete systems findings (the memory long-pole + the rollout long-pole, both flagged in the
design doc's Risks) and the same capability wall as M3.
**Task made learnable first (per the aligned decision "easier task → then M4"):** the v12 SFT
model scores ~8% on the hard task *and* on easy problems — it learned format, not arithmetic. So
the easy task (operands ≤20, ops `+ ×`) was re-SFT'd from the v12 base → **held-out 18.7%**
(100% format), a baseline with reward variance for GRPO. Note: even easy arithmetic plateaus at
~19% held-out (250 vs 600 SFT steps identical) — a 1B web-text model does not generalize the
add/sub algorithm from ~550 examples; it memorizes train (982 total problems, 550 seen).
**New op (`xtrain-autodiff`, reuses the CE kernel + one new primitive):**
- `clipped_pg_loss(logits, target, logp_old, logp_ref, A, ε, β)` — per completion token
`ρ_t = exp(logπθ_t logp_old_t)`, `L = mean min(ρA, clip(ρ,1±ε)A) + β·mean KL` (k3), masked
to completion tokens. Backward reuses `(probs onehot)` + `scale_rows` (a new ~5-line per-row
scale kernel — the per-token coefficient varies, which CE-backward's single scalar can't
express). **Gate:** grad-check the active PG path + the A=0 (KL-only) path; degenerate value
checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.
**Loop (`train_grpo`):** per step — sample B prompts, roll out G completions each, score (reward
0/1), group-relative advantage `A=(rmean)/(std+ε)` (no critic; all-correct/all-wrong groups
skipped — zero advantage), capture `logπθ_old`/`logπref` per token, K inner clipped-PG epochs.
Rollout uses the M2 KV-cache engine with **temperature sampling** (added in M4): single-row
`[1,vocab]` logits per step vs the naive sampler's `[seq,vocab]`.
**Systems wall #1 — memory (the design doc's "two/three resident models"):** KL-leash GRPO needs
policy + frozen reference, two 1.05B fp32-master models + AdamW m/v ≈ 21 GB fixed + training
activations → unreliably OOMs on a 32 GB 5090 (fragmentation tips it over). To get a completing
run, `β=0` (pure PG) drops the reference model (4.2 GB). So the *principled* KL-leash version is
memory-bound at this model size on this hardware — a real, reported constraint, not a bug.
**Systems wall #2 — rollout (the design doc's "rollout is the long pole"):** the naive sampler's
growing `[seq,vocab]` allocations fragment the caching allocator over a long rollout → OOM. The
cached temperature rollout (single-row logits) is lighter; but single-sequence cached decode is
slow (the M2a host-round-trip), so rollout still dominates wall-clock (~16 s/step at G=6·B=6).
Batched ragged decode (M2b) is the real fix and is deferred to where it is load-bearing.
**Result (easy task, β=0, G=6·B=6, 40 steps, lr 5e-7; 150 held-out, vs SFT 28/150 = 18.7%):**
mean rollout reward fluctuates ~0.580.81 (noisy, inflated by train-set overlap in the sampled
problems); **format stays 100/100** (no collapse even without the KL leash, at this gentle lr);
**held-out 30/150 = 20.0%**`+1.3 pp`, within the ~3% std-error of 150 prompts, i.e.
**statistically flat**, the same wall as M3 DPO.
**The consistent M3+M4 lesson:** on a task where the base model lacks the underlying capability,
**neither offline preference optimization (DPO) nor online RL (GRPO) moves held-out correctness**
— each optimizes its objective (margin / reward) on the *training distribution* it can reach
(here inflated by memorization), but cannot install a *generalizable* algorithm the model never
had. RL reinforces what the model already does; it does not teach arithmetic. Gate met for M4 =
the infra is correct (PG/KL grad-checks + degenerate checks, the loop runs, reward signal + KL
leash wired, format held); the held-out flatness + the two memory/throughput walls are the
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
limits on what alignment can do for a capability the base model lacks.

View File

@@ -101,6 +101,8 @@ Phase 1/2 把**预训练全栈**学完后Phase 3 转向**后训练 infra**
**M3DPO离线偏好优化已落地 + 诚实负结果)**两个复用 CE kernel 的新算子零新 CUDA)——`seq_logprob`Σ log πθ over mask 反向 = CE_backward 取负求和grad-check + mask)、`dpo_loss`log σ(Δ) policy logprob 父节点grad-check + 退化 Δ=0→log2/∓β·½、β=0→0。造对`gen_dpo_pairs`= chosen=gold、rejected=SFT 自己 greedy M2a 引擎的格式合法**错误**答案8% greedy 答对的跳过)。训练`train_dpo` SFT ckpt 同时作 policy 和冻结 reference**一次性预算 reference logprob 并缓存**单模型驻留每步 policy forward chosen+rejected seq_logprob dpo_loss forward 共享 param 累积梯度**loss 起步恰好 log2**Δ=0 内置校验)。**结果v12, 1500 , β0.1100 留出题 vs SFT 8/100**reward-margin pref-acc 干净上升loss 被正确优化infra **不转化为 held-out 正确率**——lr5e-7×3007%、×8005%、lr1e-6×2000margin+34 **崩溃**0% 格式输出垃圾三档都在 100 ~2.7% 标准误内 = 统计持平。**教训**chosen/rejected 只差最终数字 tokenDPO 提升的是**特定训练对的 token 偏好reweight 现有分布, install 能力**base 模型没有算术算法,偏好优化不泛化,推狠了只是全局扭曲分布不连贯。**DPO chosen 本就 plausible 时有效,不能凭空造模型没有的知识**——这正是 M4 GRPO 的动机:在线优化**真实可验证 reward**(采样check强化真正对的)而非固定对的 proxy( GRPO 同样面对 8% 稀疏,能否抬动指标是 M4 open question)。 v8/T17 同源的诚实账跑通+闸门齐全,负结果如实记 **M3DPO离线偏好优化已落地 + 诚实负结果)**两个复用 CE kernel 的新算子零新 CUDA)——`seq_logprob`Σ log πθ over mask 反向 = CE_backward 取负求和grad-check + mask)、`dpo_loss`log σ(Δ) policy logprob 父节点grad-check + 退化 Δ=0→log2/∓β·½、β=0→0。造对`gen_dpo_pairs`= chosen=gold、rejected=SFT 自己 greedy M2a 引擎的格式合法**错误**答案8% greedy 答对的跳过)。训练`train_dpo` SFT ckpt 同时作 policy 和冻结 reference**一次性预算 reference logprob 并缓存**单模型驻留每步 policy forward chosen+rejected seq_logprob dpo_loss forward 共享 param 累积梯度**loss 起步恰好 log2**Δ=0 内置校验)。**结果v12, 1500 , β0.1100 留出题 vs SFT 8/100**reward-margin pref-acc 干净上升loss 被正确优化infra **不转化为 held-out 正确率**——lr5e-7×3007%、×8005%、lr1e-6×2000margin+34 **崩溃**0% 格式输出垃圾三档都在 100 ~2.7% 标准误内 = 统计持平。**教训**chosen/rejected 只差最终数字 tokenDPO 提升的是**特定训练对的 token 偏好reweight 现有分布, install 能力**base 模型没有算术算法,偏好优化不泛化,推狠了只是全局扭曲分布不连贯。**DPO chosen 本就 plausible 时有效,不能凭空造模型没有的知识**——这正是 M4 GRPO 的动机:在线优化**真实可验证 reward**(采样check强化真正对的)而非固定对的 proxy( GRPO 同样面对 8% 稀疏,能否抬动指标是 M4 open question)。 v8/T17 同源的诚实账跑通+闸门齐全,负结果如实记
**M4GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**新算子 `clipped_pg_loss`per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。 `train_grpo`:采 B prompt × rollout G checker reward 0/1 group-relative advantage `(rmean)/(std+ε)`( critic,全对/全错组跳过)→ πθ_old/πref per-token K 内层 clipped-PGrollout **M2 引擎 + 新加的 temperature 采样**单行 logits naive `[seq,vocab]` )。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ easy(操作数20)上从 v12 base 重训 SFT held-out **18.7%**; 250/600 步同样 18.7% = 1B web-text 模型从 ~550 **不泛化加减法只记 train**。**两道系统墙(设计文档 Risks 预言)**: 显存——KL-leash policy+reference 两个 1B fp32-master+Adam21GB,加激活在 32GB 5090 上不稳定 OOM 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81( train 重叠抬),**format 100/100 不崩**(温和 lr β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT KV-cache DPO GRPO **,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md) ## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)
- **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。 - **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。