From f3c764ce9504d88e218fb41eb15d3edf95bb2e82 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 30 Jun 2026 12:11:01 +0800 Subject: [PATCH] =?UTF-8?q?post-train:=20M3=20=E2=80=94=20seq=5Flogprob=20?= =?UTF-8?q?+=20dpo=5Floss=20autograd=20ops?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two new ops for DPO (M3), both reusing existing kernels (no new CUDA): - seq_logprob(logits, target): Σ log πθ(target) over non-ignored (target≥0) positions — the per-sequence logprob DPO compares between policy and reference. = −Σ per_row of cross_entropy (ignored rows already 0, like SFT masking); backward = cross_entropy_backward(probs, target, −upstream) (sum, no mean division). Gate: finite-diff grad-check with a -100 completion mask. - dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β): scalar L = −log σ(Δ) = softplus(−Δ) with the two policy logprobs as parents (ref logprobs constant). Gate: grad-check both parents + degenerate points (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0). Same formula as TRL. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-autodiff/src/ops.rs | 78 +++++++++++++++++++++++ crates/xtrain-autodiff/tests/autograd.rs | 80 ++++++++++++++++++++++++ 2 files changed, 158 insertions(+) diff --git a/crates/xtrain-autodiff/src/ops.rs b/crates/xtrain-autodiff/src/ops.rs index e7a98a0..0eb860d 100644 --- a/crates/xtrain-autodiff/src/ops.rs +++ b/crates/xtrain-autodiff/src/ops.rs @@ -439,3 +439,81 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var { }), ) } + +/// Per-sequence log-probability: `Σ log πθ(target)` over the non-ignored +/// (`target ≥ 0`) positions — the quantity DPO (M3) compares between policy and +/// reference. `target` is `[rows]` I32 carrying `-100` (ignore) at masked positions +/// (e.g. the prompt) and the gold token id elsewhere; ignored positions contribute +/// 0, exactly like the SFT cross-entropy masking. Returns a scalar `[1]` Var. +/// +/// Reuses the CE forward (per-row `−log p(target)`) and backward, so no new kernel: +/// `seq_logprob = −Σ per_row`, and `d(seq_logprob)/d(logits) = −(probs − onehot)` +/// = `cross_entropy_backward(probs, target, −upstream)` (a SUM, so no mean +/// division — contrast [`cross_entropy`], which divides by `valid_rows`). +pub fn seq_logprob(x: &Var, target: &Tensor) -> Var { + let logit_dtype = x.value().dtype(); + let (probs, per_row) = x.value().cross_entropy(target); + // per_row[r] = −log p(target_r), and is 0 for ignored rows (target < 0), so the + // sum already counts only the supervised (completion) positions. + let sum_neg_lp: f32 = per_row + .to_device(xtrain_tensor::Device::Cpu) + .as_slice::() + .iter() + .sum(); + let out = Tensor::from_slice(&[-sum_neg_lp], &[1]).to_device(x.value().device()); + + let target = target.clone(); + Var::from_op( + out, + vec![x.clone()], + Box::new(move |d, parents| { + let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::()[0]; + // d(Σ log p)/d(logits) = −(probs − onehot); SUM, so no /valid_rows. + let dx = Tensor::cross_entropy_backward(&probs, &target, -upstream); + Var::push_grad(&parents[0], dx.to_dtype(logit_dtype)); + }), + ) +} + +/// DPO loss (Rafailov et al., M3) for one preference pair, as a scalar `[1]` Var +/// whose two parents are the POLICY sequence-logprobs of the chosen and rejected +/// completions (from [`seq_logprob`]); the REFERENCE logprobs are constants +/// (precomputed once from the frozen SFT model). With +/// `Δ = β·[(lpθ_chosen − lpref_chosen) − (lpθ_rejected − lpref_rejected)]` +/// the loss is `L = −log σ(Δ) = softplus(−Δ)`. Only the policy terms carry gradient: +/// `∂L/∂lpθ_chosen = −β·(1−σ(Δ))`, `∂L/∂lpθ_rejected = +β·(1−σ(Δ))`. +/// Degenerate points the M3 gate pins: `πθ == πref` ⇒ `Δ = 0`, `L = log 2`, implicit +/// reward 0; `β → 0` ⇒ gradient → 0. Same formula as TRL +/// (`-logsigmoid(β·(pol_c − pol_r − (ref_c − ref_r)))`). +pub fn dpo_loss( + lp_pol_chosen: &Var, + lp_pol_rejected: &Var, + lp_ref_chosen: f32, + lp_ref_rejected: f32, + beta: f32, +) -> Var { + use xtrain_tensor::Device; + let scalar = |v: &Var| v.value().to_device(Device::Cpu).as_slice::()[0]; + let pc = scalar(lp_pol_chosen); + let pr = scalar(lp_pol_rejected); + let delta = beta * ((pc - lp_ref_chosen) - (pr - lp_ref_rejected)); + // L = softplus(−Δ) = log(1 + e^{−Δ}) (numerically stable). + let nd = -delta; + let l = nd.max(0.0) + (-(nd.abs())).exp().ln_1p(); + let dev = lp_pol_chosen.value().device(); + let out = Tensor::from_slice(&[l], &[1]).to_device(dev); + + Var::from_op( + out, + vec![lp_pol_chosen.clone(), lp_pol_rejected.clone()], + Box::new(move |d, parents| { + let up = d.to_device(Device::Cpu).as_slice::()[0]; + // s = σ(−Δ) = 1 − σ(Δ); ∂L/∂Δ = −s, and ∂Δ/∂pc = β, ∂Δ/∂pr = −β. + let s = 1.0 / (1.0 + delta.exp()); + let g = up * beta * s; + let dev = parents[0].value().device(); + Var::push_grad(&parents[0], Tensor::from_slice(&[-g], &[1]).to_device(dev)); + Var::push_grad(&parents[1], Tensor::from_slice(&[g], &[1]).to_device(dev)); + }), + ) +} diff --git a/crates/xtrain-autodiff/tests/autograd.rs b/crates/xtrain-autodiff/tests/autograd.rs index 91bc788..e2b8d57 100644 --- a/crates/xtrain-autodiff/tests/autograd.rs +++ b/crates/xtrain-autodiff/tests/autograd.rs @@ -1005,3 +1005,83 @@ fn transpose_var(x: &Var) -> Var { }), ) } + +// seq_logprob (M3 DPO): Σ log p(target) over non-ignored rows. Grad-check with a +// completion mask — rows 0,1 are -100 (prompt, contribute 0), rows 2..6 supervised. +#[test] +fn seq_logprob_bwd() { + require_gpu(); + let (rows, cols) = (6usize, 9usize); + let x_h = fill(rows * cols, 202); + let targets: Vec = (0..rows) + .map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 }) + .collect(); + let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0)); + + let x = Var::leaf(cuda(&x_h, &[rows, cols])); + let lp = ops::seq_logprob(&x, &target); + lp.backward(); + let dx = x.grad().unwrap().to_device(Device::Cpu); + + // Numeric scalar = seq_logprob = −Σ per_row (per_row is 0 for ignored rows). + let tgt = targets.clone(); + let lx = move |v: &[f32], s: &[usize]| { + let t = Tensor::from_slice(&tgt, &[rows]).to_device(Device::Cuda(0)); + let (_, per_row) = cuda(v, s).cross_entropy(&t); + -per_row + .to_device(Device::Cpu) + .as_slice::() + .iter() + .sum::() + }; + report( + "seq_logprob dX", + &grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::(), cfg_nonlinear()), + ); +} + +// dpo_loss (M3): scalar DPO loss with the two policy logprobs as parents. Grad-check +// each parent (finite diff of softplus(−Δ)) + the degenerate points the gate pins: +// policy==reference ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0. +#[test] +fn dpo_loss_bwd_and_degenerate() { + require_gpu(); + let (ref_c, ref_r, beta) = (0.5f32, 0.9f32, 0.1f32); + let (pc0, pr0) = (1.2f32, 0.7f32); + let softplus = |z: f32| z.max(0.0) + (-(z.abs())).exp().ln_1p(); + + let pc = Var::leaf(cuda(&[pc0], &[1])); + let pr = Var::leaf(cuda(&[pr0], &[1])); + let l = ops::dpo_loss(&pc, &pr, ref_c, ref_r, beta); + l.backward(); + let dpc = pc.grad().unwrap().to_device(Device::Cpu).as_slice::()[0]; + let dpr = pr.grad().unwrap().to_device(Device::Cpu).as_slice::()[0]; + + let l_of_pc = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((v[0] - ref_c) - (pr0 - ref_r)))); + report("dpo_loss dpc", &grad_check(&[pc0], &[1], &l_of_pc, &[dpc], cfg_nonlinear())); + let l_of_pr = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((pc0 - ref_c) - (v[0] - ref_r)))); + report("dpo_loss dpr", &grad_check(&[pr0], &[1], &l_of_pr, &[dpr], cfg_nonlinear())); + + // Degenerate 1: policy == reference ⇒ Δ=0 ⇒ L=log2, grads = (∓β/2). + let pc2 = Var::leaf(cuda(&[ref_c], &[1])); + let pr2 = Var::leaf(cuda(&[ref_r], &[1])); + let l2 = ops::dpo_loss(&pc2, &pr2, ref_c, ref_r, beta); + let lval = l2.value().to_device(Device::Cpu).as_slice::()[0]; + l2.backward(); + let d2c = pc2.grad().unwrap().to_device(Device::Cpu).as_slice::()[0]; + let d2r = pr2.grad().unwrap().to_device(Device::Cpu).as_slice::()[0]; + assert!((lval - 2f32.ln()).abs() < 1e-5, "L at Δ=0 must be log2, got {lval}"); + assert!( + (d2c + beta * 0.5).abs() < 1e-5 && (d2r - beta * 0.5).abs() < 1e-5, + "grads at Δ=0 must be ∓β/2, got ({d2c},{d2r})" + ); + + // Degenerate 2: β=0 ⇒ grads 0. + let pc3 = Var::leaf(cuda(&[pc0], &[1])); + let pr3 = Var::leaf(cuda(&[pr0], &[1])); + let l3 = ops::dpo_loss(&pc3, &pr3, ref_c, ref_r, 0.0); + l3.backward(); + let d3c = pc3.grad().unwrap().to_device(Device::Cpu).as_slice::()[0]; + assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}"); + println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)"); +}