GRPO clipped policy-gradient loss (M4): reviewed patch.

Line structure restored from a whitespace-collapsed dump of the original patch.
Review changes on top of the submitted code:
  * nn.cu: scale_rows_k computes the row offset in size_t (rows*cols of a
    [tokens, vocab] gradient can exceed INT_MAX); the launcher skips empty
    inputs and rounds the block size up to whole warps.
  * tensor.rs: scale_rows returns early on empty input, converts extents with
    i32::try_from instead of a truncating `as`, and documents its preconditions.
  * ops.rs: clipped_pg_loss validates `eps` and the `target` length, and
    evaluates exp(logp_ref - logp) once per token.
  * autograd.rs: new checks (5)/(6) drive the ratio outside the trust region, so
    the clip branch and the "no clip" limit are actually exercised.

NOTE(review): the dump had dropped everything between angle brackets. The generic
arguments (`as_slice::<f32>()`, `as_slice::<i32>()`, `Vec<i32>`, `Vec<f32>`) and the
`<<<rows, blk, 0, (cudaStream_t)st>>>` launch are reconstructions; confirm against
the tree. Context-line indentation is likewise restored by convention, and the
`rope_at_k<<>>` context line is kept as received because its launch arguments are
not recoverable from here.

diff --git a/crates/xtrain-autodiff/src/ops.rs b/crates/xtrain-autodiff/src/ops.rs
index 0eb860d..8e0c8d5 100644
--- a/crates/xtrain-autodiff/src/ops.rs
+++ b/crates/xtrain-autodiff/src/ops.rs
@@ -517,3 +517,93 @@ pub fn dpo_loss(
         }),
     )
 }
+
+/// GRPO clipped policy-gradient loss (M4) for ONE completion, a scalar `[1]` Var
+/// with the policy logits as the single parent. Per non-ignored (completion) token
+/// `t` (`target[t] ≥ 0`):
+///   `logπθ_t = log softmax(logits[t])[target_t]` (`= −per_row[t]` of cross_entropy)
+///   `ρ_t = exp(logπθ_t − logp_old[t])`
+///   `pg_t = min(ρ_t·A, clip(ρ_t, 1−ε, 1+ε)·A)`
+///   `kl_t = exp(logp_ref[t] − logπθ_t) − (logp_ref[t] − logπθ_t) − 1` (k3 estimator)
+///   `L = −mean_t pg_t + β·mean_t kl_t` over the `N` completion tokens.
+///
+/// `advantage` `A` is the group-relative advantage (constant per completion in
+/// GRPO); `logp_old`/`logp_ref` are per-position constants (old policy at rollout
+/// time / frozen reference). Backward reuses the CE machinery + the per-row
+/// `scale_rows`: `dL/dlogits[t,:] = g_t·(onehot − probs)[t,:]` with
+/// `g_t = −(1/N)A·ρ_t·[unclipped active] + (β/N)(1 − exp(logp_ref_t − logπθ_t))`.
+/// Degenerate points the gate pins: `A=0` ⇒ only the KL term; `ε→∞` ⇒ vanilla PG
+/// (no clip); `β=0` ⇒ no KL term. With no completion token (`N = 0`) the loss and
+/// every gradient row are zero.
+///
+/// # Panics
+/// If `eps` is negative or NaN, or if `target`, `logp_old` or `logp_ref` does not
+/// have exactly one entry per position.
+#[allow(clippy::too_many_arguments)]
+pub fn clipped_pg_loss(
+    logits: &Var,
+    target: &Tensor,
+    logp_old: &[f32],
+    logp_ref: &[f32],
+    advantage: f32,
+    eps: f32,
+    beta: f32,
+) -> Var {
+    use xtrain_tensor::Device;
+    // `f32::clamp` panics on an inverted or NaN range; reject a bad ε up front with
+    // a message that names the argument.
+    assert!(eps >= 0.0, "eps must be non-negative and not NaN, got {eps}");
+    let logit_dtype = logits.value().dtype();
+    let (probs, per_row) = logits.value().cross_entropy(target);
+    let rows = per_row.shape()[0];
+    let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
+    let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
+    assert_eq!(target_h.len(), rows, "target must have one entry per position");
+    assert_eq!(logp_old.len(), rows, "logp_old must have one entry per position");
+    assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per position");
+
+    let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
+    let (mut pg_sum, mut kl_sum, mut n) = (0f32, 0f32, 0f32);
+    for (t, (&tgt, &nll)) in target_h.iter().zip(&per_row_h).enumerate() {
+        if tgt < 0 {
+            continue; // masked (prompt) position — no contribution, no gradient
+        }
+        n += 1.0;
+        let lp = -nll; // logπθ_t
+        let ratio = (lp - logp_old[t]).exp();
+        let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
+        let (unclipped_term, clipped_term) = (ratio * advantage, clipped * advantage);
+        pg_sum += unclipped_term.min(clipped_term);
+        let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
+        let d = logp_ref[t] - lp;
+        let exp_d = d.exp(); // shared by the KL value and its derivative
+        kl_sum += exp_d - d - 1.0;
+        let pg_grad = if active { -advantage * ratio } else { 0.0 };
+        let kl_grad = beta * (1.0 - exp_d);
+        s[t] = -(pg_grad + kl_grad); // dL/dlogits = g·(onehot−probs) = −g·(probs−onehot)
+    }
+    let inv_n = if n > 0.0 { 1.0 / n } else { 1.0 }; // N = 0 ⇒ sums are 0, avoid 0/0
+    for v in &mut s {
+        *v *= inv_n;
+    }
+    let loss_val = -pg_sum * inv_n + beta * kl_sum * inv_n;
+    let dev = logits.value().device();
+    let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
+    let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
+
+    let target = target.clone();
+    Var::from_op(
+        out,
+        vec![logits.clone()],
+        Box::new(move |d, parents| {
+            let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
+            // (probs − onehot), masked rows already 0; per-row scale by s; × upstream.
+            let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
+            let mut dx = ce.scale_rows(&s_dev);
+            if up != 1.0 {
+                dx = dx.scale(up);
+            }
+            Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
+        }),
+    )
+}
diff --git a/crates/xtrain-autodiff/tests/autograd.rs b/crates/xtrain-autodiff/tests/autograd.rs
index e2b8d57..ff15be4 100644
--- a/crates/xtrain-autodiff/tests/autograd.rs
+++ b/crates/xtrain-autodiff/tests/autograd.rs
@@ -1085,3 +1085,116 @@ fn dpo_loss_bwd_and_degenerate() {
     assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
     println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
 }
+
+// clipped_pg_loss (M4 GRPO): per-token clipped PG + k3 KL, one completion. Grad-check
+// the active (in-trust-region) path + the A=0 (KL-only) path, plus value-level
+// degenerate checks (ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL) and the off-policy clip path
+// (ρ outside the trust region ⇒ clipped loss value, zero gradient).
+#[test]
+fn clipped_pg_loss_bwd_and_degenerate() {
+    require_gpu();
+    let (rows, cols) = (6usize, 10usize);
+    let x_h = fill(rows * cols, 303);
+    // rows 0,1 masked (prompt); 2..6 supervised (completion).
+    let targets: Vec<i32> = (0..rows)
+        .map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
+        .collect();
+    let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
+
+    // logp_old = logπθ at the base logits ⇒ ρ≈1 (in trust region → active path).
+    let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
+    let logp_old: Vec<f32> = per_row0
+        .to_device(Device::Cpu)
+        .as_slice::<f32>()
+        .iter()
+        .map(|p| -p)
+        .collect();
+    let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect(); // exercise KL
+    let (eps, beta) = (0.2f32, 0.1f32);
+
+    // Host replica of the forward loss as a function of per-row CE values.
+    let host_loss = {
+        let (tg, lo, lr) = (targets.clone(), logp_old.clone(), logp_ref.clone());
+        move |per_row_h: &[f32], a: f32, e: f32, b: f32| -> f32 {
+            let (mut pg, mut kl, mut n) = (0f32, 0f32, 0f32);
+            for t in 0..per_row_h.len() {
+                if tg[t] < 0 {
+                    continue;
+                }
+                n += 1.0;
+                let lp = -per_row_h[t];
+                let ratio = (lp - lo[t]).exp();
+                let clipped = ratio.clamp(1.0 - e, 1.0 + e);
+                pg += (ratio * a).min(clipped * a);
+                let d = lr[t] - lp;
+                kl += d.exp() - d - 1.0;
+            }
+            let inv = if n > 0.0 { 1.0 / n } else { 1.0 };
+            -pg * inv + b * kl * inv
+        }
+    };
+    let per_row_of = |v: &[f32], s: &[usize]| {
+        let (_, pr) = cuda(v, s).cross_entropy(&mk_target());
+        pr.to_device(Device::Cpu).as_slice::<f32>().to_vec()
+    };
+
+    // (1) grad-check the active PG path (A>0, ρ≈1).
+    let adv = 0.7f32;
+    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
+    let loss = ops::clipped_pg_loss(&x, &mk_target(), &logp_old, &logp_ref, adv, eps, beta);
+    loss.backward();
+    let dx = x.grad().unwrap().to_device(Device::Cpu);
+    let hl = host_loss.clone();
+    let lx = move |v: &[f32], s: &[usize]| hl(&per_row_of(v, s), adv, eps, beta);
+    report(
+        "clipped_pg dX (active)",
+        &grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
+    );
+
+    // (2) grad-check the A=0 path (loss = β·mean KL; PG gradient must vanish).
+    let x0 = Var::leaf(cuda(&x_h, &[rows, cols]));
+    let loss0 = ops::clipped_pg_loss(&x0, &mk_target(), &logp_old, &logp_ref, 0.0, eps, beta);
+    loss0.backward();
+    let dx0 = x0.grad().unwrap().to_device(Device::Cpu);
+    let hl0 = host_loss.clone();
+    let lx0 = move |v: &[f32], s: &[usize]| hl0(&per_row_of(v, s), 0.0, eps, beta);
+    report(
+        "clipped_pg dX (A=0, KL only)",
+        &grad_check(&x_h, &[rows, cols], &lx0, dx0.as_slice::<f32>(), cfg_nonlinear()),
+    );
+
+    // (3) ε→∞ ⇒ vanilla PG (no clip): loss value == −mean(ρA) + β·mean KL. ρ≈1 here, so
+    // the clip is idle for any ε; (6) repeats this with ρ far outside the trust region.
+    let big = 1e9f32;
+    let lv = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_old, &logp_ref, adv, big, beta);
+    let got = lv.value().to_device(Device::Cpu).as_slice::<f32>()[0];
+    let pr0 = per_row_of(&x_h, &[rows, cols]);
+    let want = host_loss(&pr0, adv, big, beta);
+    assert!((got - want).abs() < 1e-4, "ε→∞ vanilla loss mismatch: {got} vs {want}");
+
+    // (4) β=0 ⇒ no KL term (loss == −mean pg only).
+    let lvb = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_old, &logp_ref, adv, eps, 0.0);
+    let gotb = lvb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
+    let wantb = host_loss(&pr0, adv, eps, 0.0);
+    assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
+
+    // (5) off-policy clip path: logp_old shifted by −1 so ρ = e > 1+ε with A>0 ⇒ the clipped
+    // branch wins the min. With β=0 the loss is −(1+ε)·A and no gradient may flow.
+    let logp_off: Vec<f32> = logp_old.iter().map(|l| l - 1.0).collect();
+    let xc = Var::leaf(cuda(&x_h, &[rows, cols]));
+    let lc = ops::clipped_pg_loss(&xc, &mk_target(), &logp_off, &logp_ref, adv, eps, 0.0);
+    let gotc = lc.value().to_device(Device::Cpu).as_slice::<f32>()[0];
+    let wantc = -(1.0 + eps) * adv;
+    assert!((gotc - wantc).abs() < 1e-5, "clipped loss mismatch: {gotc} vs {wantc}");
+    lc.backward();
+    let dxc = xc.grad().unwrap().to_device(Device::Cpu);
+    let max_abs = dxc.as_slice::<f32>().iter().fold(0f32, |m, g| m.max(g.abs()));
+    assert!(max_abs < 1e-7, "clipped ⇒ grad 0, got max |g| = {max_abs}");
+
+    // (6) same off-policy ratio with ε→∞: the clip must not engage, loss == −e·A.
+    let lvv = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_off, &logp_ref, adv, big, 0.0);
+    let gotv = lvv.value().to_device(Device::Cpu).as_slice::<f32>()[0];
+    let wantv = -std::f32::consts::E * adv;
+    assert!((gotv - wantv).abs() < 1e-4, "off-policy vanilla loss mismatch: {gotv} vs {wantv}");
+    println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL) + clip path");
+}
diff --git a/crates/xtrain-cuda/src/ffi.rs b/crates/xtrain-cuda/src/ffi.rs
index 170f3e1..5599df2 100644
--- a/crates/xtrain-cuda/src/ffi.rs
+++ b/crates/xtrain-cuda/src/ffi.rs
@@ -152,6 +152,15 @@ unsafe extern "C" {
         pos0: i32,
         s: CudaStream,
     );
+    // Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
+    pub fn launch_scale_rows_f32(
+        x: *const f32,
+        s: *const f32,
+        y: *mut f32,
+        rows: i32,
+        cols: i32,
+        stream: CudaStream,
+    );
     pub fn launch_rope_dx_f32(
         dy: *const f32,
         dx: *mut f32,
diff --git a/crates/xtrain-tensor/src/tensor.rs b/crates/xtrain-tensor/src/tensor.rs
index 7c305fe..a07f6c0 100644
--- a/crates/xtrain-tensor/src/tensor.rs
+++ b/crates/xtrain-tensor/src/tensor.rs
@@ -941,6 +941,49 @@ impl Tensor {
         dx
     }
 
+    /// Per-row scale: `out[r,c] = self[r,c] * s[r]`. `self`:[rows,cols] F32,
+    /// `s`:[rows] F32. Used by the GRPO (M4) policy-gradient backward, where each
+    /// completion token's row of `(probs − onehot)` is scaled by its own per-token
+    /// coefficient (the per-token clipped-PG + KL gradient). Forward-only.
+    ///
+    /// Both tensors are handed to one CUDA kernel as raw pointers, so `s` must
+    /// live on the same CUDA device as `self`; that is not checked here. An empty
+    /// input (`rows == 0` or `cols == 0`) returns the empty result without a launch.
+    ///
+    /// # Panics
+    /// If `self` is not 2D, either tensor is not F32, `s` does not have `rows`
+    /// entries, or a dimension does not fit the kernel's `i32` arguments.
+    #[cfg(not(no_cuda))]
+    pub fn scale_rows(&self, s: &Tensor) -> Self {
+        assert_eq!(self.ndim(), 2, "scale_rows requires a 2D tensor");
+        assert_eq!(self.dtype, DType::F32, "scale_rows is F32");
+        assert_eq!(s.dtype, DType::F32, "scale vector is F32");
+        let (rows, cols) = (self.shape[0], self.shape[1]);
+        assert_eq!(s.numel(), rows, "scale vector must have one entry per row");
+        let out = Tensor::zeros(&self.shape, DType::F32, self.device());
+        if rows == 0 || cols == 0 {
+            // Nothing to scale, and a zero-sized grid is an invalid CUDA launch.
+            return out;
+        }
+        // The FFI takes `i32` extents; refuse to truncate silently.
+        let rows_i32 = i32::try_from(rows).expect("scale_rows: rows exceeds i32::MAX");
+        let cols_i32 = i32::try_from(cols).expect("scale_rows: cols exceeds i32::MAX");
+        // SAFETY: `self` and `out` both have shape [rows, cols] in F32 and `s` has
+        // `rows` F32 entries (asserted above), which is what the kernel is told to
+        // access. NOTE(review): assumes contiguous buffers on one CUDA device.
+        unsafe {
+            xtrain_cuda::ffi::launch_scale_rows_f32(
+                self.data_ptr() as *const f32,
+                s.data_ptr() as *const f32,
+                out.data_ptr() as *mut f32,
+                rows_i32,
+                cols_i32,
+                std::ptr::null_mut(),
+            );
+        }
+        out
+    }
+
     // --- Structural / model ops (the T5 kernels) ---
 
     /// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a
diff --git a/csrc/ops/nn.cu b/csrc/ops/nn.cu
index 37399ef..7b93210 100644
--- a/csrc/ops/nn.cu
+++ b/csrc/ops/nn.cu
@@ -269,6 +269,28 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
     rope_at_k<<>>(x, y, heads, head_dim, theta, pos0);
 }
 
+// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
+// (M4) policy-gradient backward, where each completion token's row of
+// (probs − onehot) is scaled by its own per-token coefficient.
+// The row offset is computed in size_t: rows*cols of a [tokens, vocab] logits
+// gradient can exceed INT_MAX, which would overflow an int index.
+__global__ void scale_rows_k(const float* x, const float* s, float* y,
+                             int rows, int cols) {
+    int r = blockIdx.x;
+    if (r >= rows) return;
+    float sr = s[r];
+    size_t base = (size_t)r * (size_t)cols;
+    for (int c = threadIdx.x; c < cols; c += blockDim.x)
+        y[base + c] = x[base + c] * sr;
+}
+void launch_scale_rows_f32(const float* x, const float* s, float* y,
+                           int rows, int cols, void* st) {
+    if (rows <= 0 || cols <= 0) return;  // an empty grid is an invalid launch
+    // Whole warps, capped at 1024 threads per block.
+    int blk = cols < 1024 ? (cols + 31) / 32 * 32 : 1024;
+    scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
+}
+
 __global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
                           float theta, int period) {
     int tok = blockIdx.x;