post-train: M4 — clipped_pg_loss + scale_rows (GRPO policy-gradient op)
The GRPO (M4) token-level loss op + the one primitive it needs: - scale_rows(x[r,c], s[r]): per-row scale (new ~5-line CUDA kernel). The clipped-PG backward scales each completion token's row of (probs − onehot) by its own per-token coefficient, which cross_entropy_backward's single scalar scale can't express. - clipped_pg_loss(logits, target, logp_old, logp_ref, A, eps, beta): per-token ρ_t = exp(logπθ_t − logp_old_t), L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL (k3 estimator), masked to completion tokens. Backward reuses the CE machinery (probs − onehot) + scale_rows. Gates: grad-check the active PG path + the A=0 (KL-only) path; degenerate value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -517,3 +517,83 @@ pub fn dpo_loss(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// GRPO clipped policy-gradient loss (M4) for ONE completion, a scalar `[1]` Var
|
||||||
|
/// with the policy logits as the single parent. Per non-ignored (completion) token
|
||||||
|
/// `t` (`target[t] ≥ 0`):
|
||||||
|
/// `logπθ_t = log softmax(logits[t])[target_t]` (`= −per_row[t]` of cross_entropy)
|
||||||
|
/// `ρ_t = exp(logπθ_t − logp_old[t])`
|
||||||
|
/// `pg_t = min(ρ_t·A, clip(ρ_t, 1−ε, 1+ε)·A)`
|
||||||
|
/// `kl_t = exp(logp_ref[t] − logπθ_t) − (logp_ref[t] − logπθ_t) − 1` (k3 estimator)
|
||||||
|
/// `L = −mean_t pg_t + β·mean_t kl_t` over the `N` completion tokens.
|
||||||
|
///
|
||||||
|
/// `advantage` `A` is the group-relative advantage (constant per completion in
|
||||||
|
/// GRPO); `logp_old`/`logp_ref` are per-position constants (old policy at rollout
|
||||||
|
/// time / frozen reference). Backward reuses the CE machinery + the per-row
|
||||||
|
/// `scale_rows`: `dL/dlogits[t,:] = g_t·(onehot − probs)[t,:]` with
|
||||||
|
/// `g_t = −(1/N)A·ρ_t·[unclipped active] + (β/N)(1 − exp(logp_ref_t − logπθ_t))`.
|
||||||
|
/// Degenerate points the gate pins: `A=0` ⇒ only the KL term; `ε→∞` ⇒ vanilla PG
|
||||||
|
/// (no clip); `β=0` ⇒ no KL term.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn clipped_pg_loss(
|
||||||
|
logits: &Var,
|
||||||
|
target: &Tensor,
|
||||||
|
logp_old: &[f32],
|
||||||
|
logp_ref: &[f32],
|
||||||
|
advantage: f32,
|
||||||
|
eps: f32,
|
||||||
|
beta: f32,
|
||||||
|
) -> Var {
|
||||||
|
use xtrain_tensor::Device;
|
||||||
|
let logit_dtype = logits.value().dtype();
|
||||||
|
let (probs, per_row) = logits.value().cross_entropy(target);
|
||||||
|
let rows = per_row.shape()[0];
|
||||||
|
let per_row_h = per_row.to_device(Device::Cpu).as_slice::<f32>().to_vec();
|
||||||
|
let target_h = target.to_device(Device::Cpu).as_slice::<i32>().to_vec();
|
||||||
|
assert_eq!(logp_old.len(), rows, "logp_old must have one entry per position");
|
||||||
|
assert_eq!(logp_ref.len(), rows, "logp_ref must have one entry per position");
|
||||||
|
|
||||||
|
let mut s = vec![0f32; rows]; // per-row scale for cross_entropy_backward(·,·,1.0)
|
||||||
|
let (mut pg_sum, mut kl_sum, mut n) = (0f32, 0f32, 0f32);
|
||||||
|
for t in 0..rows {
|
||||||
|
if target_h[t] < 0 {
|
||||||
|
continue; // masked (prompt) position — no contribution, no gradient
|
||||||
|
}
|
||||||
|
n += 1.0;
|
||||||
|
let lp = -per_row_h[t]; // logπθ_t
|
||||||
|
let ratio = (lp - logp_old[t]).exp();
|
||||||
|
let clipped = ratio.clamp(1.0 - eps, 1.0 + eps);
|
||||||
|
let (unclipped_term, clipped_term) = (ratio * advantage, clipped * advantage);
|
||||||
|
pg_sum += unclipped_term.min(clipped_term);
|
||||||
|
let active = unclipped_term <= clipped_term; // min picks unclipped ⇒ grad flows
|
||||||
|
let d = logp_ref[t] - lp;
|
||||||
|
kl_sum += d.exp() - d - 1.0;
|
||||||
|
let pg_grad = if active { -advantage * ratio } else { 0.0 };
|
||||||
|
let kl_grad = beta * (1.0 - d.exp());
|
||||||
|
s[t] = -(pg_grad + kl_grad); // dL/dlogits = g·(onehot−probs) = −g·(probs−onehot)
|
||||||
|
}
|
||||||
|
let inv_n = if n > 0.0 { 1.0 / n } else { 1.0 };
|
||||||
|
for v in &mut s {
|
||||||
|
*v *= inv_n;
|
||||||
|
}
|
||||||
|
let loss_val = -pg_sum * inv_n + beta * kl_sum * inv_n;
|
||||||
|
let dev = logits.value().device();
|
||||||
|
let out = Tensor::from_slice(&[loss_val], &[1]).to_device(dev);
|
||||||
|
let s_dev = Tensor::from_slice(&s, &[rows]).to_device(dev);
|
||||||
|
|
||||||
|
let target = target.clone();
|
||||||
|
Var::from_op(
|
||||||
|
out,
|
||||||
|
vec![logits.clone()],
|
||||||
|
Box::new(move |d, parents| {
|
||||||
|
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
// (probs − onehot), masked rows already 0; per-row scale by s; × upstream.
|
||||||
|
let ce = Tensor::cross_entropy_backward(&probs, &target, 1.0);
|
||||||
|
let mut dx = ce.scale_rows(&s_dev);
|
||||||
|
if up != 1.0 {
|
||||||
|
dx = dx.scale(up);
|
||||||
|
}
|
||||||
|
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1085,3 +1085,95 @@ fn dpo_loss_bwd_and_degenerate() {
|
|||||||
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
|
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
|
||||||
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
|
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clipped_pg_loss (M4 GRPO): per-token clipped PG + k3 KL, one completion. Grad-check
|
||||||
|
// the active (in-trust-region) path + the A=0 (KL-only) path, plus value-level
|
||||||
|
// degenerate checks (ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL).
|
||||||
|
#[test]
|
||||||
|
fn clipped_pg_loss_bwd_and_degenerate() {
|
||||||
|
require_gpu();
|
||||||
|
let (rows, cols) = (6usize, 10usize);
|
||||||
|
let x_h = fill(rows * cols, 303);
|
||||||
|
// rows 0,1 masked (prompt); 2..6 supervised (completion).
|
||||||
|
let targets: Vec<i32> = (0..rows)
|
||||||
|
.map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
|
||||||
|
.collect();
|
||||||
|
let mk_target = || Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
|
||||||
|
|
||||||
|
// logp_old = logπθ at the base logits ⇒ ρ≈1 (in trust region → active path).
|
||||||
|
let (_, per_row0) = cuda(&x_h, &[rows, cols]).cross_entropy(&mk_target());
|
||||||
|
let logp_old: Vec<f32> = per_row0
|
||||||
|
.to_device(Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.iter()
|
||||||
|
.map(|p| -p)
|
||||||
|
.collect();
|
||||||
|
let logp_ref: Vec<f32> = logp_old.iter().map(|l| l - 0.3).collect(); // exercise KL
|
||||||
|
let (eps, beta) = (0.2f32, 0.1f32);
|
||||||
|
|
||||||
|
// Host replica of the forward loss as a function of per-row CE values.
|
||||||
|
let host_loss = {
|
||||||
|
let (tg, lo, lr) = (targets.clone(), logp_old.clone(), logp_ref.clone());
|
||||||
|
move |per_row_h: &[f32], a: f32, e: f32, b: f32| -> f32 {
|
||||||
|
let (mut pg, mut kl, mut n) = (0f32, 0f32, 0f32);
|
||||||
|
for t in 0..per_row_h.len() {
|
||||||
|
if tg[t] < 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
n += 1.0;
|
||||||
|
let lp = -per_row_h[t];
|
||||||
|
let ratio = (lp - lo[t]).exp();
|
||||||
|
let clipped = ratio.clamp(1.0 - e, 1.0 + e);
|
||||||
|
pg += (ratio * a).min(clipped * a);
|
||||||
|
let d = lr[t] - lp;
|
||||||
|
kl += d.exp() - d - 1.0;
|
||||||
|
}
|
||||||
|
let inv = if n > 0.0 { 1.0 / n } else { 1.0 };
|
||||||
|
-pg * inv + b * kl * inv
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let per_row_of = |v: &[f32], s: &[usize]| {
|
||||||
|
let (_, pr) = cuda(v, s).cross_entropy(&mk_target());
|
||||||
|
pr.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||||
|
};
|
||||||
|
|
||||||
|
// (1) grad-check the active PG path (A>0, ρ≈1).
|
||||||
|
let adv = 0.7f32;
|
||||||
|
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||||
|
let loss = ops::clipped_pg_loss(&x, &mk_target(), &logp_old, &logp_ref, adv, eps, beta);
|
||||||
|
loss.backward();
|
||||||
|
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||||
|
let hl = host_loss.clone();
|
||||||
|
let lx = move |v: &[f32], s: &[usize]| hl(&per_row_of(v, s), adv, eps, beta);
|
||||||
|
report(
|
||||||
|
"clipped_pg dX (active)",
|
||||||
|
&grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// (2) grad-check the A=0 path (loss = β·mean KL; PG gradient must vanish).
|
||||||
|
let x0 = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||||
|
let loss0 = ops::clipped_pg_loss(&x0, &mk_target(), &logp_old, &logp_ref, 0.0, eps, beta);
|
||||||
|
loss0.backward();
|
||||||
|
let dx0 = x0.grad().unwrap().to_device(Device::Cpu);
|
||||||
|
let hl0 = host_loss.clone();
|
||||||
|
let lx0 = move |v: &[f32], s: &[usize]| hl0(&per_row_of(v, s), 0.0, eps, beta);
|
||||||
|
report(
|
||||||
|
"clipped_pg dX (A=0, KL only)",
|
||||||
|
&grad_check(&x_h, &[rows, cols], &lx0, dx0.as_slice::<f32>(), cfg_nonlinear()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// (3) ε→∞ ⇒ vanilla PG (no clip): loss value == −mean(ρA) + β·mean KL.
|
||||||
|
let big = 1e9f32;
|
||||||
|
let lv = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_old, &logp_ref, adv, big, beta);
|
||||||
|
let got = lv.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
let pr0 = per_row_of(&x_h, &[rows, cols]);
|
||||||
|
let want = host_loss(&pr0, adv, big, beta);
|
||||||
|
assert!((got - want).abs() < 1e-4, "ε→∞ vanilla loss mismatch: {got} vs {want}");
|
||||||
|
|
||||||
|
// (4) β=0 ⇒ no KL term (loss == −mean pg only).
|
||||||
|
let lvb = ops::clipped_pg_loss(&Var::leaf(cuda(&x_h, &[rows, cols])), &mk_target(), &logp_old, &logp_ref, adv, eps, 0.0);
|
||||||
|
let gotb = lvb.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
let wantb = host_loss(&pr0, adv, eps, 0.0);
|
||||||
|
assert!((gotb - wantb).abs() < 1e-5, "β=0 loss mismatch: {gotb} vs {wantb}");
|
||||||
|
println!("clipped_pg_loss OK: grad-check (active + A=0) + degenerate (ε→∞ vanilla, β=0 no KL)");
|
||||||
|
}
|
||||||
|
|||||||
@@ -152,6 +152,15 @@ unsafe extern "C" {
|
|||||||
pos0: i32,
|
pos0: i32,
|
||||||
s: CudaStream,
|
s: CudaStream,
|
||||||
);
|
);
|
||||||
|
// Per-row scale: y[r,c] = x[r,c] * s[r] (GRPO policy-gradient backward).
|
||||||
|
pub fn launch_scale_rows_f32(
|
||||||
|
x: *const f32,
|
||||||
|
s: *const f32,
|
||||||
|
y: *mut f32,
|
||||||
|
rows: i32,
|
||||||
|
cols: i32,
|
||||||
|
stream: CudaStream,
|
||||||
|
);
|
||||||
pub fn launch_rope_dx_f32(
|
pub fn launch_rope_dx_f32(
|
||||||
dy: *const f32,
|
dy: *const f32,
|
||||||
dx: *mut f32,
|
dx: *mut f32,
|
||||||
|
|||||||
@@ -941,6 +941,31 @@ impl Tensor {
|
|||||||
dx
|
dx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Per-row scale: `out[r,c] = self[r,c] * s[r]`. `self`:[rows,cols] F32,
|
||||||
|
/// `s`:[rows] F32. Used by the GRPO (M4) policy-gradient backward, where each
|
||||||
|
/// completion token's row of `(probs − onehot)` is scaled by its own per-token
|
||||||
|
/// coefficient (the per-token clipped-PG + KL gradient). Forward-only.
|
||||||
|
#[cfg(not(no_cuda))]
|
||||||
|
pub fn scale_rows(&self, s: &Tensor) -> Self {
|
||||||
|
assert_eq!(self.ndim(), 2, "scale_rows requires a 2D tensor");
|
||||||
|
assert_eq!(self.dtype, DType::F32, "scale_rows is F32");
|
||||||
|
assert_eq!(s.dtype, DType::F32, "scale vector is F32");
|
||||||
|
let (rows, cols) = (self.shape[0], self.shape[1]);
|
||||||
|
assert_eq!(s.numel(), rows, "scale vector must have one entry per row");
|
||||||
|
let out = Tensor::zeros(&self.shape, DType::F32, self.device());
|
||||||
|
unsafe {
|
||||||
|
xtrain_cuda::ffi::launch_scale_rows_f32(
|
||||||
|
self.data_ptr() as *const f32,
|
||||||
|
s.data_ptr() as *const f32,
|
||||||
|
out.data_ptr() as *mut f32,
|
||||||
|
rows as i32,
|
||||||
|
cols as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
// --- Structural / model ops (the T5 kernels) ---
|
// --- Structural / model ops (the T5 kernels) ---
|
||||||
|
|
||||||
/// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a
|
/// Reshape to `new_shape` (must keep `numel`). Pure metadata change on a
|
||||||
|
|||||||
@@ -269,6 +269,23 @@ void launch_rope_at_f32(const float* x, float* y, int tokens, int heads,
|
|||||||
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
rope_at_k<<<grid, blk, 0, (cudaStream_t)s>>>(x, y, heads, head_dim, theta, pos0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Per-row scale: y[r,c] = x[r,c] * s[r]. One block per row. Used by the GRPO
|
||||||
|
// (M4) policy-gradient backward, where each completion token's row of
|
||||||
|
// (probs − onehot) is scaled by its own per-token coefficient.
|
||||||
|
__global__ void scale_rows_k(const float* x, const float* s, float* y,
|
||||||
|
int rows, int cols) {
|
||||||
|
int r = blockIdx.x;
|
||||||
|
float sr = s[r];
|
||||||
|
for (int c = threadIdx.x; c < cols; c += blockDim.x)
|
||||||
|
y[r * cols + c] = x[r * cols + c] * sr;
|
||||||
|
}
|
||||||
|
void launch_scale_rows_f32(const float* x, const float* s, float* y,
|
||||||
|
int rows, int cols, void* st) {
|
||||||
|
int blk = cols < 1024 ? cols : 1024;
|
||||||
|
if (blk < 32) blk = 32;
|
||||||
|
scale_rows_k<<<rows, blk, 0, (cudaStream_t)st>>>(x, s, y, rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
__global__ void rope_dx_k(const float* dy, float* dx, int heads, int head_dim,
|
||||||
float theta, int period) {
|
float theta, int period) {
|
||||||
int tok = blockIdx.x;
|
int tok = blockIdx.x;
|
||||||
|
|||||||
Reference in New Issue
Block a user