Compare commits
3 Commits
99090465bf
...
096e45b845
| Author | SHA1 | Date | |
|---|---|---|---|
| 096e45b845 | |||
| 7fb3b32fd9 | |||
| aaa77082ef |
@@ -517,3 +517,83 @@ pub fn dpo_loss(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// GRPO clipped policy-gradient loss (M4) for ONE completion, a scalar `[1]` Var
|
||||||
|
/// with the policy logits as the single parent. Per non-ignored (completion) token
|
||||||
|
/// `t` (`target[t] ≥ 0`):
|
||||||
|
/// `logπθ_t = log softmax(logits[t])[target_t]` (`= −per_row[t]` of cross_entropy)
|
||||||
|
/// `ρ_t = exp(logπθ_t − logp_old[t])`
|
||||||
|
/// `pg_t = min(ρ_t·A, clip(ρ_t, 1−ε, 1+ε)·A)`
|
||||||
|
/// `kl_t = exp(logp_ref[t] − logπθ_t) − (logp_ref[t] − logπθ_t) − 1` (k3 estimator)
|
||||||
|
/// `L = −mean_t pg_t + β·mean_t kl_t` over the `N` completion tokens.
|
||||||
|
///
|
||||||
|
/// `advantage` `A` is the group-relative advantage (constant per completion in
|
||||||
|
/// GRPO); `logp_old`/`logp_ref` are per-position constants (old policy at rollout
|
||||||
|
/// time / frozen reference). Backward reuses the CE machinery + the per-row
|
||||||
|
/// `scale_rows`: `dL/dlogits[t,:] = g_t·(onehot − probs)[t,:]` with
|
||||||
|
/// `g_t = −(1/N)A·ρ_t·[unclipped active] + (β/N)(1 − exp(logp_ref_t − logπθ_t))`.
|
||||||
|
/// Degenerate points the gate pins: `A=0` ⇒ only the KL term; `ε→∞` ⇒ vanilla PG
|
||||||
|
/// (no clip); `β=0` ⇒ no KL term.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn clipped_pg_loss(
|
||||||
|
logits: &Var,
|
||||||
|
target: &Tensor,
|
||||||
|
logp_old: &[f32],
|
||||||
|
logp_ref: &[f32],
|
||||||
|
advantage: f32,
|
||||||
|
eps: f32,
|
||||||
|
beta: f32,
|
||||||
|
) -> Var {
|
||||||
|
use xtrain_tensor::Device;
|
||||||
|
let logit_dtype = logits.value().dtype();
|
||||||
|
let (probs, per_row) = logits.value().cross_entropy(target);
|
||||||
|
let rows = per_row.shape()[0];
|
||||||
|
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||||
|
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
|
||||||
|
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per position");
|
||||||
|
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per position");
|
||||||
|
|
||||||
|
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
|
||||||
|
let (mut pg_sum, mut kl_sum, mut n) = (0f32, 0f32, 0f32);
|
||||||
|
for t in 0..rows {
|
||||||
|
if target_h[t] < 0 {
|
||||||
|
continue; // masked (prompt) position — no contribution, no gradient
|
||||||
|
}
|
||||||
|
n += 1.0;
|
||||||
|
let lp = -per_row_h[t]; // logπθ_t
|
||||||
|
let ratio = (lp - logp_old[t]).exp();
|
||||||
|
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
|
||||||
|
let (unclipped_term, clipped_term) = (ratio * advantage, clipped * advantage);
|
||||||
|
pg_sum += unclipped_term.min(clipped_term);
|
||||||
|
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
|
||||||
|
let d = logp_ref[t] - lp;
|
||||||
|
kl_sum += d.exp() - d - 1.0;
|
||||||
|
let pg_grad = if active { -advantage * ratio } else { 0.0 };
|
||||||
|
let kl_grad = beta * (1.0 - d.exp());
|
||||||
|
s[t] = -(pg_grad + kl_grad); // dL/dlogits = g·(onehot−probs) = −g·(probs−onehot)
|
||||||
|
}
|
||||||
|
let inv_n = if n > 0.0 { 1.0 / n } else { 1.0 };
|
||||||
|
for v in &mut s {
|
||||||
|
*v *= inv_n;
|
||||||
|
}
|
||||||
|
let loss_val = -pg_sum * inv_n + beta * kl_sum * inv_n;
|
||||||
|
let dev = logits.value().device();
|
||||||
|
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
|
||||||
|
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
|
||||||
|
|
||||||
|
let target = target.clone();
|
||||||
|
Var::from_op(
|
||||||
|
out,
|
||||||
|
vec![logits.clone()],
|
||||||
|
Box::new(move |d, parents| {
|
||||||
|
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
// (probs − onehot), masked rows already 0; per-row scale by s; × upstream.
|
||||||
|
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
|
||||||
|
let mut dx = ce.scale_rows(&s_dev);
|
||||||
|
if up != 1.0 {
|
||||||
|
dx = dx.scale(up);
|
||||||
|
}
|
||||||
|
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1085,3 +1085,95 @@ fn dpo_loss_bwd_and_degenerate() {
|
|||||||
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
|
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
|
||||||
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
|
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clipped_pg_loss (M4 GRPO) — per-token clipped PG plus k3 KL for one completion.
// Finite-difference gradient checks cover the in-trust-region (active) path and the
// A=0 (KL-only) path; two value-level checks cover the degenerate settings
// (huge ε ⇒ unclipped PG, β=0 ⇒ no KL term).
#[test]
fn clipped_pg_loss_bwd_and_degenerate() {
    require_gpu();
    let (rows, cols) = (6usize, 10usize);
    let x_h = fill(rows * cols, 303);

    // Rows 0 and 1 play the masked prompt; rows 2..6 are supervised completion tokens.
    let targets: Vec<i32> = (0..rows)
        .map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
        .collect();
    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    // Taking logp_old at the base logits puts ρ≈1, i.e. inside the trust region.
    let (_, base_nll) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
    let logp_old: Vec<f32> = base_nll
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .iter()
        .map(|nll| -nll)
        .collect();
    // Offset the reference so the KL term is non-trivial.
    let logp_ref: Vec<f32> = logp_old.iter().map(|lp| lp - 0.3).collect();
    let (eps, beta) = (0.2f32, 0.1f32);

    // Host-side replica of the forward loss, expressed over the per-row CE values.
    let host_loss = {
        let (tg, lo, lr) = (targets.clone(), logp_old.clone(), logp_ref.clone());
        move |per_row_h: &[f32], a: f32, e: f32, b: f32| -> f32 {
            let mut pg = 0f32;
            let mut kl = 0f32;
            let mut n = 0f32;
            for (t, &nll) in per_row_h.iter().enumerate() {
                if tg[t] < 0 {
                    continue;
                }
                n += 1.0;
                let lp = -nll;
                let ratio = (lp - lo[t]).exp();
                let clipped = ratio.clamp(1.0 - e, 1.0 + e);
                pg += (ratio * a).min(clipped * a);
                let d = lr[t] - lp;
                kl += d.exp() - d - 1.0;
            }
            let inv = if n > 0.0 { 1.0 / n } else { 1.0 };
            -pg * inv + b * kl * inv
        }
    };
    let per_row_of = |v: &[f32], s: &[usize]| {
        let (_, pr) = cuda(v, s).cross_entropy(&mk_target());
        pr.to_device(Device::Cpu).as_slice::<f32>().to_vec()
    };

    // (1) Gradient check on the active PG path (A>0, ρ≈1).
    let adv = 0.7f32;
    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let loss = ops::clipped_pg_loss(&x, &mk_target(), &logp_old, &logp_ref, adv, eps, beta);
    loss.backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);
    let hl = host_loss.clone();
    let lx = move |v: &[f32], s: &[usize]| hl(&per_row_of(v, s), adv, eps, beta);
    report(
        "clipped_pg dX (active)",
        &grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
    );

    // (2) Gradient check with A=0: the loss is β·mean KL and the PG gradient must vanish.
    let x0 = Var::leaf(cuda(&x_h, &[rows, cols]));
    let loss0 = ops::clipped_pg_loss(&x0, &mk_target(), &logp_old, &logp_ref, 0.0, eps, beta);
    loss0.backward();
    let dx0 = x0.grad().unwrap().to_device(Device::Cpu);
    let hl0 = host_loss.clone();
    let lx0 = move |v: &[f32], s: &[usize]| hl0(&per_row_of(v, s), 0.0, eps, beta);
    report(
        "clipped_pg dX (A=0, KL only)",
        &grad_check(&x_h, &[rows, cols], &lx0, dx0.as_slice::<f32>(), cfg_nonlinear()),
    );

    // (3) ε→∞ ⇒ vanilla PG (no clip): the value must equal −mean(ρA) + β·mean KL.
    let big = 1e9f32;
    let unclipped = ops::clipped_pg_loss(
        &Var::leaf(cuda(&x_h, &[rows, cols])),
        &mk_target(),
        &logp_old,
        &logp_ref,
        adv,
        big,
        beta,
    );
    let got = unclipped.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    let pr0 = per_row_of(&x_h, &[rows, cols]);
    let want = host_loss(&pr0, adv, big, beta);
    assert!((got - want).abs() < 1e-4, "ε→∞ vanilla loss mismatch: {got} vs {want}");

    // (4) β=0 ⇒ no KL term: the value must equal −mean pg alone.
    let no_kl = ops::clipped_pg_loss(
        &Var::leaf(cuda(&x_h, &[rows, cols])),
        &mk_target(),
        &logp_old,
        &logp_ref,
        adv,
        eps,
        0.0,
    );
    let gotb = no_kl.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    let wantb = host_loss(&pr0, adv, eps, 0.0);
    assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
    println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
}
|
||||||
|
|||||||
@@ -152,6 +152,15 @@ unsafe extern "C" {
|
|||||||
pos0: i32,
|
pos0: i32,
|
||||||
s: CudaStream,
|
s: CudaStream,
|
||||||
);
|
);
|
||||||
|
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
// FFI binding for the CUDA launcher of the same name. The kernel behind it reads
// `x[r * cols + c]` and `s[r]` and writes `y[r * cols + c]` for every `r < rows`,
// `c < cols`, so `x` and `y` must each cover `rows * cols` f32 values and `s` must
// cover `rows`. `stream` is handed to the launch as the CUDA stream.
// NOTE(review): the pointers are dereferenced inside the kernel, so they presumably
// have to be device-accessible — confirm against the allocator used by callers.
|
||||||
|
pub fn launch_scale_rows_f32(
|
||||||
|
x: *const f32,
|
||||||
|
s: *const f32,
|
||||||
|
y: *mut f32,
|
||||||
|
rows: i32,
|
||||||
|
cols: i32,
|
||||||
|
stream: CudaStream,
|
||||||
|
);
|
||||||
pub fn launch_rope_dx_f32(
|
pub fn launch_rope_dx_f32(
|
||||||
dy: *const f32,
|
dy: *const f32,
|
||||||
dx: *mut f32,
|
dx: *mut f32,
|
||||||
|
|||||||
@@ -83,6 +83,24 @@ pub fn generate_greedy_cached(
|
|||||||
device: Device,
|
device: Device,
|
||||||
prompt: &[i32],
|
prompt: &[i32],
|
||||||
max_new: usize,
|
max_new: usize,
|
||||||
|
) -> Vec<i32> {
|
||||||
|
let mut rng = 0u64;
|
||||||
|
generate_cached(model, device, prompt, max_new, 0.0, &mut rng)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// KV-cache decode with temperature sampling (`temperature == 0` → greedy argmax,
|
||||||
|
/// matching [`generate_greedy_cached`]; otherwise sample from `softmax(logits/T)`).
|
||||||
|
/// The KV-cache rollout the GRPO loop uses: each step allocates only a single-row
|
||||||
|
/// `[1, vocab]` logits buffer (vs the naive sampler's `[seq, vocab]`), so it is far
|
||||||
|
/// lighter on memory + the allocator — the naive sampler fragments the caching
|
||||||
|
/// allocator over a long rollout, which is the M4 "rollout is the long pole" wall.
|
||||||
|
pub fn generate_cached(
|
||||||
|
model: &TinyTransformer,
|
||||||
|
device: Device,
|
||||||
|
prompt: &[i32],
|
||||||
|
max_new: usize,
|
||||||
|
temperature: f32,
|
||||||
|
rng_state: &mut u64,
|
||||||
) -> Vec<i32> {
|
) -> Vec<i32> {
|
||||||
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
assert!(!prompt.is_empty(), "prompt must be non-empty");
|
||||||
let cfg = model.config();
|
let cfg = model.config();
|
||||||
@@ -116,7 +134,11 @@ pub fn generate_greedy_cached(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _ in 0..max_new {
|
for _ in 0..max_new {
|
||||||
let next = argmax(&logits) as i32;
|
let next = if temperature <= 0.0 {
|
||||||
|
argmax(&logits) as i32
|
||||||
|
} else {
|
||||||
|
sample_temperature(&logits, temperature, rng_state) as i32
|
||||||
|
};
|
||||||
tokens.push(next);
|
tokens.push(next);
|
||||||
let pos = tokens.len() - 1; // absolute position of the token just appended
|
let pos = tokens.len() - 1; // absolute position of the token just appended
|
||||||
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
|
logits = decode_step(¶ms, cfg, cdt, device, &mut cache, next, pos, embed, final_norm, lm_head);
|
||||||
@@ -124,6 +146,26 @@ pub fn generate_greedy_cached(
|
|||||||
tokens
|
tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Draws one index from `softmax(row / temperature)`.
///
/// The row maximum is subtracted before exponentiating so the largest weight is
/// `exp(0)`, which keeps the softmax numerically stable. `rng_state` is advanced by one
/// 64-bit LCG step; its upper 32 bits become a uniform fraction that is scaled by the
/// total weight and located by walking the running sum of weights (inverse CDF).
///
/// Returns the first index whose cumulative weight reaches the threshold, or the last
/// index when no cumulative weight does (for example when the weights contain NaN).
fn sample_temperature(row: &[f32], temperature: f32, rng_state: &mut u64) -> usize {
    let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = row
        .iter()
        .map(|&logit| ((logit - peak) / temperature).exp())
        .collect();
    let total: f32 = weights.iter().sum();

    // One LCG step; the caller's state is updated in place so successive draws differ.
    let advanced = rng_state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    *rng_state = advanced;
    let threshold = ((advanced >> 32) as f32 / u32::MAX as f32) * total;

    let mut running = 0.0;
    let hit = weights.iter().position(|&w| {
        running += w;
        running >= threshold
    });
    match hit {
        Some(index) => index,
        None => weights.len() - 1,
    }
}
|
||||||
|
|
||||||
/// One incremental decode step for token `tok` at absolute position `pos`: append
|
/// One incremental decode step for token `tok` at absolute position `pos`: append
|
||||||
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
|
/// its K/V to the cache and return the next-token logits as host f32 `[vocab]`.
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
|||||||
@@ -29,4 +29,4 @@ pub use model::{TinyTransformer, batched_ids_tensor, ids_tensor, param_to_host};
|
|||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
pub mod decode;
|
pub mod decode;
|
||||||
#[cfg(not(no_cuda))]
|
#[cfg(not(no_cuda))]
|
||||||
pub use decode::generate_greedy_cached;
|
pub use decode::{generate_cached, generate_greedy_cached};
|
||||||
|
|||||||
@@ -941,6 +941,31 @@ impl Tensor {
|
|||||||
dx
|
dx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Per-row scale: `out[r,c] = self[r,c] * s[r]`. `self`:[rows,cols] F32,
|
||||||
|
/// `s`:[rows] F32. Used by the GRPO (M4) policy-gradient backward, where each
|
||||||
|
/// completion token's row of `(probs − onehot)` is scaled by its own per-token
|
||||||
|
/// coefficient (the per-token clipped-PG + KL gradient). Forward-only.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
pub fn scale_rows(&self, s: &Tensor) -> Self {
|
||||||
|
assert_eq!(self.ndim(), 2, "scale_rows requires a 2D tensor");
|
||||||
|
assert_eq!(self.dtype, DType::F32, "scale_rows is F32");
|
||||||
|
assert_eq!(s.dtype, DType::F32, "scale vector is F32");
|
||||||
|
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||||
|
assert_eq!(s.numel(), rows, "scale vector must have one entry per row");
|
||||||
|
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||||
|
unsafe {
|
||||||
|
xtrain_cuda::ffi::launch_scale_rows_f32(
|
||||||
|
self.data_ptr() as *const f32,
|
||||||
|
s.data_ptr() as *const f32,
|
||||||
|
out.data_ptr() as *mut f32,
|
||||||
|
rows as i32,
|
||||||
|
cols as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
// --- Structural / model ops (the T5 kernels) ---
|
// --- Structural / model ops (the T5 kernels) ---
|
||||||
|
|
||||||
/// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a
|
/// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a
|
||||||
|
|||||||
288
crates/xtrain-train/src/bin/train_grpo.rs
Normal file
288
crates/xtrain-train/src/bin/train_grpo.rs
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
//! GRPO training on the verifiable arithmetic task (M4 / Stage P3) — online,
|
||||||
|
//! critic-free RL. The centerpiece: generation INSIDE the training loop.
|
||||||
|
//!
|
||||||
|
//! Each step: sample B prompts (fresh problems), roll out G completions per prompt
|
||||||
|
//! (temperature sampling via the KV-cache decoder `generate_cached`, one completion at a
|
||||||
|
//! time — batched rollout is the M2b/M4-perf follow-up), score each with the rule-based checker (reward ∈ {0,1}),
|
||||||
|
//! compute the **group-relative advantage** `A_i = (r_i − mean) / (std + ε)` (no
|
||||||
|
//! critic), then K inner clipped-PG epochs minimising [`clipped_pg_loss`] with a KL
|
||||||
|
//! leash to the frozen reference (πref = the SFT checkpoint). Reward = pure 0/1
|
||||||
|
//! correctness; the KL term (β) is what keeps format/coherence (the M3 collapse
|
||||||
|
//! lesson — here it is an explicit leash, not just a hope).
|
||||||
|
//!
|
||||||
|
//! Health signal (the falsifiable "it learns"): **mean rollout reward must rise**
|
||||||
|
//! (the RL analogue of T5's overfit-27/27). Held-out correctness is measured by
|
||||||
|
//! eval_arith on the saved checkpoint.
|
||||||
|
//!
|
||||||
|
//! train_grpo <tokenizer.json> --init-ckpt <sft.ckpt> <arch flags> \
|
||||||
|
//! --steps 200 --group 6 --prompts 8 --temp 1.0 --beta 0.04 --eps 0.2 \
|
||||||
|
//! --lr 1e-6 --max-add 20 --max-mul 9 --ckpt <out.ckpt>
|
||||||
|
|
||||||
|
// Stub entry point compiled only under the `no_cuda` cfg: it does no training and
// just reports on stderr that the binary needs a CUDA build.
#[cfg(no_cuda)]
|
||||||
|
fn main() {
|
||||||
|
eprintln!("train_grpo: built without CUDA (no_cuda); run on a GPU host.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
use xtrain_autodiff::ops;
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
use xtrain_cuda::device;
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
use xtrain_model::{Config, TinyTransformer, generate_cached, ids_tensor};
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
use xtrain_tensor::{DType, Device};
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
use xtrain_train::task::{check_answer, gen_problem, GenConfig, Op};
|
||||||
|
|
||||||
|
/// Produces `n` deterministic pseudo-random values in `[-scale, scale]`.
///
/// `seed` is first scrambled with one multiply-add, then a 64-bit LCG is stepped once
/// per value; the top 31 bits of each state are mapped to the unit interval, centred on
/// zero and stretched to the requested scale. Equal arguments always give equal output.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;

    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||||
|
|
||||||
|
/// Reads the value that follows the first occurrence of `name` in `args` and parses it
/// as `T`.
///
/// Falls back to `default` when the flag is absent, when it is the last argument (no
/// value follows), or when the value does not parse as `T`. Only the first occurrence
/// of the flag is considered.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    // Each window is a (candidate flag, following value) pair; a flag in the final
    // position never heads a window, which is exactly the "no value" case.
    let parsed = args
        .windows(2)
        .find(|pair| pair[0] == name)
        .and_then(|pair| pair[1].parse().ok());
    match parsed {
        Some(value) => value,
        None => default,
    }
}
|
||||||
|
|
||||||
|
/// Returns an owned copy of the argument that follows the first occurrence of `name`,
/// or `None` when the flag is missing or is the last argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    // Windows pair every argument with its successor; a trailing flag heads no window.
    args.windows(2)
        .find(|pair| pair[0] == name)
        .map(|pair| pair[1].clone())
}
|
||||||
|
|
||||||
|
/// Cuts a decoded continuation down to its first answer: everything before the first
/// `<|endoftext|>` marker, and of that only the text before the first newline.
/// The returned slice borrows from `c`; no trimming is applied.
#[cfg(not(no_cuda))]
fn first_answer_segment(c: &str) -> &str {
    let before_eot = match c.find("<|endoftext|>") {
        Some(end) => &c[..end],
        None => c,
    };
    match before_eot.find('\n') {
        Some(end) => &before_eot[..end],
        None => before_eot,
    }
}
|
||||||
|
|
||||||
|
/// Build a model from the SFT checkpoint (bf16 compute to fit two 1B models). The
|
||||||
|
/// policy enables activation recompute (T13) so its backward fits alongside the
|
||||||
|
/// frozen reference + the Adam state; the reference only forwards (no backward).
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
fn load_model(cfg: Config, device: Device, ckpt: &str, recompute: bool) -> TinyTransformer {
|
||||||
|
let mut seed = 1u64;
|
||||||
|
let m = TinyTransformer::new(cfg, device, |shape| {
|
||||||
|
seed = seed.wrapping_add(1);
|
||||||
|
let n: usize = shape.iter().product();
|
||||||
|
if shape.len() == 1 {
|
||||||
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||||
|
} else {
|
||||||
|
fill(n, seed, 0.04)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.with_compute_dtype(DType::BF16)
|
||||||
|
.with_recompute(recompute)
|
||||||
|
.with_flash(true);
|
||||||
|
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt), &m.params()).expect("load ckpt");
|
||||||
|
m.eval();
|
||||||
|
m
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Frame (question, completion) like the SFT loader and return the next-token
|
||||||
|
/// (input, target) pair (prompt masked to -100). Same as train_dpo.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
fn frame(tok: &xserv_tokenizer::Tokenizer, question: &str, completion: &str) -> (Vec<i32>, Vec<i32>) {
|
||||||
|
let p_ids: Vec<i32> = tok
|
||||||
|
.encode(&format!("User: {question}\nAssistant:"))
|
||||||
|
.into_iter()
|
||||||
|
.map(|t| t as i32)
|
||||||
|
.collect();
|
||||||
|
let a_ids: Vec<i32> = tok
|
||||||
|
.encode(&format!(" {completion}\n<|endoftext|>"))
|
||||||
|
.into_iter()
|
||||||
|
.map(|t| t as i32)
|
||||||
|
.collect();
|
||||||
|
let mut tokens = p_ids.clone();
|
||||||
|
tokens.extend_from_slice(&a_ids);
|
||||||
|
let mut labels = vec![-100i32; p_ids.len()];
|
||||||
|
labels.extend_from_slice(&a_ids);
|
||||||
|
let l = tokens.len();
|
||||||
|
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-position logprob `logπ(target_t)` of a framed (input, target) pair (= −per_row
|
||||||
|
/// of cross_entropy; masked positions are 0 and unused). No grad kept.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
fn per_token_logp(model: &TinyTransformer, device: Device, input: &[i32], target: &[i32]) -> Vec<f32> {
|
||||||
|
let logits = model.forward(&ids_tensor(input, device)).value();
|
||||||
|
let (_, per_row) = logits.cross_entropy(&ids_tensor(target, device));
|
||||||
|
per_row
|
||||||
|
.to_device(Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.iter()
|
||||||
|
.map(|p| -p)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
fn main() {
|
||||||
|
use xserv_tokenizer::Tokenizer;
|
||||||
|
use xtrain_optim::GpuAdamW;
|
||||||
|
|
||||||
|
let args: Vec<String> = std::env::args().collect();
|
||||||
|
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||||
|
let tok_path = positionals.first().expect("usage: train_grpo <tokenizer.json> [flags]");
|
||||||
|
|
||||||
|
let n_heads = flag(&args, "--heads", 52usize);
|
||||||
|
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||||
|
let n_layers = flag(&args, "--layers", 22usize);
|
||||||
|
let ffn = flag(&args, "--ffn", 6656usize);
|
||||||
|
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||||
|
let steps: usize = flag(&args, "--steps", 200);
|
||||||
|
let group: usize = flag(&args, "--group", 6);
|
||||||
|
let n_prompts: usize = flag(&args, "--prompts", 8);
|
||||||
|
let inner: usize = flag(&args, "--inner", 1);
|
||||||
|
let temp: f32 = flag(&args, "--temp", 1.0);
|
||||||
|
let beta: f32 = flag(&args, "--beta", 0.04);
|
||||||
|
let eps: f32 = flag(&args, "--eps", 0.2);
|
||||||
|
let lr: f32 = flag(&args, "--lr", 1e-6);
|
||||||
|
let clip: f32 = flag(&args, "--clip", 1.0);
|
||||||
|
let max_new: usize = flag(&args, "--max-tokens", 24);
|
||||||
|
let max_add: i64 = flag(&args, "--max-add", 20);
|
||||||
|
let max_mul: i64 = flag(&args, "--max-mul", 9);
|
||||||
|
let seed: u64 = flag(&args, "--seed", 20260630);
|
||||||
|
let log_every: usize = flag(&args, "--log-every", 20);
|
||||||
|
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
|
||||||
|
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
|
||||||
|
|
||||||
|
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||||
|
device::set_device(0).unwrap();
|
||||||
|
let device = Device::Cuda(0);
|
||||||
|
|
||||||
|
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
|
||||||
|
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
|
||||||
|
let policy = load_model(cfg, device, &init_ckpt, false); // flash keeps attn memory bounded
|
||||||
|
// Frozen πref for the KL leash — only resident when β>0 (a second 1B model is the
|
||||||
|
// memory long-pole; β=0 is pure PG and skips it, the gated degenerate).
|
||||||
|
let reference = if beta > 0.0 {
|
||||||
|
Some(load_model(cfg, device, &init_ckpt, false))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let gcfg = GenConfig {
|
||||||
|
max_add,
|
||||||
|
max_mul,
|
||||||
|
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||||
|
};
|
||||||
|
let params = policy.params();
|
||||||
|
let mut opt = GpuAdamW::new(0.0);
|
||||||
|
let mut rng = seed.max(1);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let (mut win_reward, mut win_solved, mut win_n) = (0f32, 0usize, 0usize);
|
||||||
|
for step in 0..steps {
|
||||||
|
// ---- Rollout: B prompts × G completions, scored, group-advantage ----
|
||||||
|
struct Sample {
|
||||||
|
input: Vec<i32>,
|
||||||
|
target: Vec<i32>,
|
||||||
|
adv: f32,
|
||||||
|
logp_old: Vec<f32>,
|
||||||
|
logp_ref: Vec<f32>,
|
||||||
|
}
|
||||||
|
let mut batch: Vec<Sample> = Vec::new();
|
||||||
|
for _ in 0..n_prompts {
|
||||||
|
let p = gen_problem(&mut rng, &gcfg);
|
||||||
|
let prompt_ids: Vec<i32> = tok
|
||||||
|
.encode(&format!("User: {}\nAssistant:", p.question()))
|
||||||
|
.into_iter()
|
||||||
|
.map(|t| t as i32)
|
||||||
|
.collect();
|
||||||
|
let mut comps: Vec<(String, f32)> = Vec::with_capacity(group);
|
||||||
|
for _ in 0..group {
|
||||||
|
// KV-cache temperature rollout (M2 engine): single-row logits per
|
||||||
|
// step → far lighter on the allocator than the naive sampler, which
|
||||||
|
// fragments it over a long rollout (the M4 rollout long-pole).
|
||||||
|
let out = generate_cached(&policy, device, &prompt_ids, max_new, temp, &mut rng);
|
||||||
|
let cont = tok.decode(&out[prompt_ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||||
|
let seg = first_answer_segment(&cont).trim().to_string();
|
||||||
|
let r = if check_answer(&seg, p.answer()) { 1.0 } else { 0.0 };
|
||||||
|
comps.push((seg, r));
|
||||||
|
}
|
||||||
|
let mean = comps.iter().map(|c| c.1).sum::<f32>() / group as f32;
|
||||||
|
let var = comps.iter().map(|c| (c.1 - mean).powi(2)).sum::<f32>() / group as f32;
|
||||||
|
let std = var.sqrt();
|
||||||
|
win_reward += mean * group as f32;
|
||||||
|
win_solved += comps.iter().filter(|c| c.1 > 0.5).count();
|
||||||
|
win_n += group;
|
||||||
|
// A whole group with no reward variance gives zero advantage → skip
|
||||||
|
// (no learning signal, and avoids dividing by ~0).
|
||||||
|
if std < 1e-6 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (seg, r) in &comps {
|
||||||
|
let adv = (r - mean) / (std + 1e-4);
|
||||||
|
let (input, target) = frame(&tok, &p.question(), seg);
|
||||||
|
let logp_old = per_token_logp(&policy, device, &input, &target);
|
||||||
|
// β=0 ⇒ KL term drops ⇒ logp_ref unused; pass zeros (no reference model).
|
||||||
|
let logp_ref = match &reference {
|
||||||
|
Some(r) => per_token_logp(r, device, &input, &target),
|
||||||
|
None => vec![0.0; logp_old.len()],
|
||||||
|
};
|
||||||
|
batch.push(Sample { input, target, adv, logp_old, logp_ref });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- K inner clipped-PG epochs over the captured batch ----
|
||||||
|
if !batch.is_empty() {
|
||||||
|
let scale = 1.0 / batch.len() as f32;
|
||||||
|
for _ in 0..inner {
|
||||||
|
for s in &batch {
|
||||||
|
let logits = policy.forward(&ids_tensor(&s.input, device));
|
||||||
|
let loss = ops::clipped_pg_loss(
|
||||||
|
&logits,
|
||||||
|
&ids_tensor(&s.target, device),
|
||||||
|
&s.logp_old,
|
||||||
|
&s.logp_ref,
|
||||||
|
s.adv,
|
||||||
|
eps,
|
||||||
|
beta,
|
||||||
|
);
|
||||||
|
ops::scale(&loss, scale).backward();
|
||||||
|
}
|
||||||
|
let _ = xtrain_train::clip::clip_grad_norm_gpu(¶ms, clip, 1.0);
|
||||||
|
opt.step(lr, ¶ms);
|
||||||
|
for p in ¶ms {
|
||||||
|
p.zero_grad();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (step + 1) % log_every == 0 || step == steps - 1 {
|
||||||
|
println!(
|
||||||
|
"step {:5}/{steps}: mean-reward {:.3} | solved {}/{} | {:.0}s",
|
||||||
|
step + 1,
|
||||||
|
win_reward / win_n.max(1) as f32,
|
||||||
|
win_solved,
|
||||||
|
win_n,
|
||||||
|
start.elapsed().as_secs_f32(),
|
||||||
|
);
|
||||||
|
win_reward = 0.0;
|
||||||
|
win_solved = 0;
|
||||||
|
win_n = 0;
|
||||||
|
// Periodic save so a later OOM (naive rollout fragments the allocator —
|
||||||
|
// the long-pole the design doc flagged) still leaves an evaluatable ckpt.
|
||||||
|
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save ckpt");
|
||||||
|
println!("GRPO done: {steps} steps, G={group}, B={n_prompts}, beta {beta}, lr {lr:.1e} → {out_ckpt}");
|
||||||
|
}
|
||||||
@@ -269,6 +269,23 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
|
|||||||
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row; the threads of a block
// stride across that row's columns. Used by the GRPO (M4) policy-gradient backward,
// where each completion token's row of (probs − onehot) is scaled by its own
// per-token coefficient.
//
// x, y: row-major [rows, cols] floats; s: [rows] floats.
__global__ void scale_rows_k(const float* x, const float* s, float* y,
                             int rows, int cols) {
    int r = blockIdx.x;
    // Defensive bound: a grid larger than `rows` must not read or write past the buffers.
    if (r >= rows) return;
    float sr = s[r];
    // Compute the row offset in size_t: `r * cols` overflows 32-bit int once the
    // buffer exceeds 2^31 elements (large [tokens, vocab] logits).
    size_t base = (size_t)r * (size_t)cols;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        y[base + c] = x[base + c] * sr;
}
|
||||||
|
// Host launcher for scale_rows_k: one block per row, with a block size of `cols`
// clamped to [32, 1024] threads, enqueued on stream `st` (passed as an opaque pointer
// and cast to cudaStream_t). The launch is not synchronized here.
void launch_scale_rows_f32(const float* x, const float* s, float* y,
                           int rows, int cols, void* st) {
    // Nothing to scale; a zero-sized grid is an invalid launch configuration, so
    // return before reaching the launch.
    if (rows <= 0 || cols <= 0) return;
    int blk = cols < 1024 ? cols : 1024;
    if (blk < 32) blk = 32;
    scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
}
|
||||||
|
|
||||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||||
float theta, int period) {
|
float theta, int period) {
|
||||||
int tok = blockIdx.x;
|
int tok = blockIdx.x;
|
||||||
|
|||||||
@@ -466,3 +466,59 @@ verifiable reward* online (sample → check → reinforce what is genuinely corr
|
|||||||
fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric
|
fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric
|
||||||
is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init,
|
is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init,
|
||||||
margin/acc rise); the correctness flatness is the reported finding, not a bug.
|
margin/acc rise); the correctness flatness is the reported finding, not a bug.
|
||||||
|
|
||||||
|
### M4 — GRPO (online RL, critic-free, landed; infra + two honest systems walls)
|
||||||
|
|
||||||
|
The centerpiece: generation INSIDE the training loop. Infra built and gated; the run surfaces
|
||||||
|
two concrete systems findings (the memory long-pole + the rollout long-pole, both flagged in the
|
||||||
|
design doc's Risks) and the same capability wall as M3.
|
||||||
|
|
||||||
|
**Task made learnable first (per the aligned decision "easier task → then M4"):** the v12 SFT
|
||||||
|
model scores ~8% on the hard task *and* on easy problems — it learned format, not arithmetic. So
|
||||||
|
the easy task (operands ≤20, ops `+ − ×`) was re-SFT'd from the v12 base → **held-out 18.7%**
|
||||||
|
(100% format), a baseline with reward variance for GRPO. Note: even easy arithmetic plateaus at
|
||||||
|
~19% held-out (250 vs 600 SFT steps identical) — a 1B web-text model does not generalize the
|
||||||
|
add/sub algorithm from ~550 examples; it memorizes train (982 total problems, 550 seen).
|
||||||
|
|
||||||
|
**New op (`xtrain-autodiff`, reuses the CE kernel + one new primitive):**
|
||||||
|
- `clipped_pg_loss(logits, target, logp_old, logp_ref, A, ε, β)` — per completion token
|
||||||
|
`ρ_t = exp(logπθ_t − logp_old_t)`, `L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL` (k3), masked
|
||||||
|
to completion tokens. Backward reuses `(probs − onehot)` + `scale_rows` (a new ~5-line per-row
|
||||||
|
scale kernel — the per-token coefficient varies, which CE-backward's single scalar can't
|
||||||
|
express). **Gate:** grad-check the active PG path + the A=0 (KL-only) path; degenerate value
|
||||||
|
checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.
|
||||||
|
|
||||||
|
**Loop (`train_grpo`):** per step — sample B prompts, roll out G completions each, score (reward
|
||||||
|
0/1), group-relative advantage `A=(r−mean)/(std+ε)` (no critic; all-correct/all-wrong groups
|
||||||
|
skipped — zero advantage), capture `logπθ_old`/`logπref` per token, K inner clipped-PG epochs.
|
||||||
|
Rollout uses the M2 KV-cache engine with **temperature sampling** (added in M4): single-row
|
||||||
|
`[1,vocab]` logits per step vs the naive sampler's `[seq,vocab]`.
|
||||||
|
|
||||||
|
**Systems wall #1 — memory (the design doc's "two/three resident models"):** KL-leash GRPO needs
|
||||||
|
policy + frozen reference, two 1.05B fp32-master models + AdamW m/v ≈ 21 GB fixed + training
|
||||||
|
activations → unreliably OOMs on a 32 GB 5090 (fragmentation tips it over). To get a completing
|
||||||
|
run, `β=0` (pure PG) drops the reference model (−4.2 GB). So the *principled* KL-leash version is
|
||||||
|
memory-bound at this model size on this hardware — a real, reported constraint, not a bug.
|
||||||
|
|
||||||
|
**Systems wall #2 — rollout (the design doc's "rollout is the long pole"):** the naive sampler's
|
||||||
|
growing `[seq,vocab]` allocations fragment the caching allocator over a long rollout → OOM. The
|
||||||
|
cached temperature rollout (single-row logits) is lighter; but single-sequence cached decode is
|
||||||
|
slow (the M2a host-round-trip), so rollout still dominates wall-clock (~16 s/step at G=6·B=6).
|
||||||
|
Batched ragged decode (M2b) is the real fix and is deferred to where it is load-bearing.
|
||||||
|
|
||||||
|
**Result (easy task, β=0, G=6·B=6, 40 steps, lr 5e-7; 150 held-out, vs SFT 28/150 = 18.7%):**
|
||||||
|
mean rollout reward fluctuates ~0.58–0.81 (noisy, inflated by train-set overlap in the sampled
|
||||||
|
problems); **format stays 100/100** (no collapse even without the KL leash, at this gentle lr);
|
||||||
|
**held-out 30/150 = 20.0%** — `+1.3 pp`, within the ~3% std-error of 150 prompts, i.e.
|
||||||
|
**statistically flat**, the same wall as M3 DPO.
|
||||||
|
|
||||||
|
**The consistent M3+M4 lesson:** on a task where the base model lacks the underlying capability,
|
||||||
|
**neither offline preference optimization (DPO) nor online RL (GRPO) moves held-out correctness**
|
||||||
|
— each optimizes its objective (margin / reward) on the *training distribution* it can reach
|
||||||
|
(here inflated by memorization), but cannot install a *generalizable* algorithm the model never
|
||||||
|
had. RL reinforces what the model already does; it does not teach arithmetic. Gate met for M4 =
|
||||||
|
the infra is correct (PG/KL grad-checks + degenerate checks, the loop runs, reward signal + KL
|
||||||
|
leash wired, format held); the held-out flatness + the two memory/throughput walls are the
|
||||||
|
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
|
||||||
|
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
|
||||||
|
limits on what alignment can do for a capability the base model lacks.
|
||||||
|
|||||||
@@ -101,6 +101,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
|||||||
|
|
||||||
**M3(DPO,离线偏好优化,已落地 + 诚实负结果)**:两个复用 CE kernel 的新算子(零新 CUDA)——`seq_logprob`(Σ log πθ over 非 mask 位,反向 = CE_backward 取负求和;grad-check + mask)、`dpo_loss`(−log σ(Δ),双 policy logprob 父节点;grad-check + 退化 Δ=0→log2/∓β·½、β=0→0)。造对(`gen_dpo_pairs`)= chosen=gold、rejected=SFT 自己 greedy(用 M2a 引擎)的格式合法**错误**答案(8% greedy 答对的跳过)。训练(`train_dpo`)把 SFT ckpt 同时作 policy 和冻结 reference,**一次性预算 reference logprob 并缓存**(单模型驻留),每步 policy forward chosen+rejected → seq_logprob → dpo_loss,两 forward 共享 param 累积梯度;**loss 起步恰好 log2**(Δ=0 内置校验)。**结果(v12, 1500 对, β0.1;100 留出题 vs SFT 8/100)**:reward-margin 与 pref-acc 干净上升(loss 被正确优化、infra 对),但**不转化为 held-out 正确率**——lr5e-7×300→7%、×800→5%、lr1e-6×2000→margin+34 **崩溃**(0% 格式、输出垃圾),三档都在 100 题 ~2.7% 标准误内 = 统计持平。**教训**:chosen/rejected 只差最终数字 token,DPO 提升的是**特定训练对的 token 偏好、reweight 现有分布,不 install 能力**;base 模型没有算术算法,偏好优化不泛化,推狠了只是全局扭曲分布→不连贯。**DPO 在 chosen 本就 plausible 时有效,不能凭空造模型没有的知识**——这正是 M4 GRPO 的动机:在线优化**真实可验证 reward**(采样→check→强化真正对的)而非固定对的 proxy(但 GRPO 同样面对 8% 稀疏,能否抬动指标是 M4 的 open question)。与 v8/T17 同源的诚实账:跑通+闸门齐全,负结果如实记。
|
**M3(DPO,离线偏好优化,已落地 + 诚实负结果)**:两个复用 CE kernel 的新算子(零新 CUDA)——`seq_logprob`(Σ log πθ over 非 mask 位,反向 = CE_backward 取负求和;grad-check + mask)、`dpo_loss`(−log σ(Δ),双 policy logprob 父节点;grad-check + 退化 Δ=0→log2/∓β·½、β=0→0)。造对(`gen_dpo_pairs`)= chosen=gold、rejected=SFT 自己 greedy(用 M2a 引擎)的格式合法**错误**答案(8% greedy 答对的跳过)。训练(`train_dpo`)把 SFT ckpt 同时作 policy 和冻结 reference,**一次性预算 reference logprob 并缓存**(单模型驻留),每步 policy forward chosen+rejected → seq_logprob → dpo_loss,两 forward 共享 param 累积梯度;**loss 起步恰好 log2**(Δ=0 内置校验)。**结果(v12, 1500 对, β0.1;100 留出题 vs SFT 8/100)**:reward-margin 与 pref-acc 干净上升(loss 被正确优化、infra 对),但**不转化为 held-out 正确率**——lr5e-7×300→7%、×800→5%、lr1e-6×2000→margin+34 **崩溃**(0% 格式、输出垃圾),三档都在 100 题 ~2.7% 标准误内 = 统计持平。**教训**:chosen/rejected 只差最终数字 token,DPO 提升的是**特定训练对的 token 偏好、reweight 现有分布,不 install 能力**;base 模型没有算术算法,偏好优化不泛化,推狠了只是全局扭曲分布→不连贯。**DPO 在 chosen 本就 plausible 时有效,不能凭空造模型没有的知识**——这正是 M4 GRPO 的动机:在线优化**真实可验证 reward**(采样→check→强化真正对的)而非固定对的 proxy(但 GRPO 同样面对 8% 稀疏,能否抬动指标是 M4 的 open question)。与 v8/T17 同源的诚实账:跑通+闸门齐全,负结果如实记。
|
||||||
|
|
||||||
|
**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。
|
||||||
|
|
||||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||||
|
|
||||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||||
|
|||||||
Reference in New Issue
Block a user