post-train: M3 — seq_logprob + dpo_loss autograd ops

Two new ops for DPO (M3), both reusing existing kernels (no new CUDA):

- seq_logprob(logits, target): Σ log πθ(target) over non-ignored (target≥0)
  positions — the per-sequence logprob DPO compares between policy and
  reference. = −Σ per_row of cross_entropy (ignored rows already 0, like SFT
  masking); backward = cross_entropy_backward(probs, target, −upstream) (sum,
  no mean division). Gate: finite-diff grad-check with a -100 completion mask.

- dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β): scalar
  L = −log σ(Δ) = softplus(−Δ) with the two policy logprobs as parents (ref
  logprobs constant). Gate: grad-check both parents + degenerate points
  (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0). Same formula as TRL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 12:11:01 +08:00
parent b39e6e7110
commit f3c764ce95
2 changed files with 158 additions and 0 deletions

View File

@@ -439,3 +439,81 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
}),
)
}
/// Per-sequence log-probability: `Σ log πθ(target)` over the non-ignored
/// (`target ≥ 0`) positions — the quantity DPO (M3) compares between policy and
/// reference. `target` is `[rows]` I32 carrying `-100` (ignore) at masked positions
/// (e.g. the prompt) and the gold token id elsewhere; ignored positions contribute
/// 0, exactly like the SFT cross-entropy masking. Returns a scalar `[1]` Var.
///
/// Reuses the CE forward (per-row `log p(target)`) and backward, so no new kernel:
/// `seq_logprob = −Σ per_row`, and `d(seq_logprob)/d(logits) = (probs onehot)`
/// = `cross_entropy_backward(probs, target, upstream)` (a SUM, so no mean
/// division — contrast [`cross_entropy`], which divides by `valid_rows`).
pub fn seq_logprob(x: &Var, target: &Tensor) -> Var {
let logit_dtype = x.value().dtype();
let (probs, per_row) = x.value().cross_entropy(target);
// per_row[r] = log p(target_r), and is 0 for ignored rows (target < 0), so the
// sum already counts only the supervised (completion) positions.
let sum_neg_lp: f32 = per_row
.to_device(xtrain_tensor::Device::Cpu)
.as_slice::<f32>()
.iter()
.sum();
let out = Tensor::from_slice(&[-sum_neg_lp], &[1]).to_device(x.value().device());
let target = target.clone();
Var::from_op(
out,
vec![x.clone()],
Box::new(move |d, parents| {
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
// d(Σ log p)/d(logits) = (probs onehot); SUM, so no /valid_rows.
let dx = Tensor::cross_entropy_backward(&probs, &target, -upstream);
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
}),
)
}
/// DPO loss (Rafailov et al., M3) for one preference pair, as a scalar `[1]` Var
/// whose two parents are the POLICY sequence-logprobs of the chosen and rejected
/// completions (from [`seq_logprob`]); the REFERENCE logprobs are constants
/// (precomputed once from the frozen SFT model). With
/// `Δ = β·[(lpθ_chosen lpref_chosen) (lpθ_rejected lpref_rejected)]`
/// the loss is `L = log σ(Δ) = softplus(−Δ)`. Only the policy terms carry gradient:
/// `∂L/∂lpθ_chosen = −β·(1σ(Δ))`, `∂L/∂lpθ_rejected = +β·(1σ(Δ))`.
/// Degenerate points the M3 gate pins: `πθ == πref` ⇒ `Δ = 0`, `L = log 2`, implicit
/// reward 0; `β → 0` ⇒ gradient → 0. Same formula as TRL
/// (`-logsigmoid(β·(pol_c pol_r (ref_c ref_r)))`).
pub fn dpo_loss(
lp_pol_chosen: &Var,
lp_pol_rejected: &Var,
lp_ref_chosen: f32,
lp_ref_rejected: f32,
beta: f32,
) -> Var {
use xtrain_tensor::Device;
let scalar = |v: &Var| v.value().to_device(Device::Cpu).as_slice::<f32>()[0];
let pc = scalar(lp_pol_chosen);
let pr = scalar(lp_pol_rejected);
let delta = beta * ((pc - lp_ref_chosen) - (pr - lp_ref_rejected));
// L = softplus(−Δ) = log(1 + e^{−Δ}) (numerically stable).
let nd = -delta;
let l = nd.max(0.0) + (-(nd.abs())).exp().ln_1p();
let dev = lp_pol_chosen.value().device();
let out = Tensor::from_slice(&[l], &[1]).to_device(dev);
Var::from_op(
out,
vec![lp_pol_chosen.clone(), lp_pol_rejected.clone()],
Box::new(move |d, parents| {
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
// s = σ(−Δ) = 1 σ(Δ); ∂L/∂Δ = s, and ∂Δ/∂pc = β, ∂Δ/∂pr = −β.
let s = 1.0 / (1.0 + delta.exp());
let g = up * beta * s;
let dev = parents[0].value().device();
Var::push_grad(&parents[0], Tensor::from_slice(&[-g], &[1]).to_device(dev));
Var::push_grad(&parents[1], Tensor::from_slice(&[g], &[1]).to_device(dev));
}),
)
}

View File

@@ -1005,3 +1005,83 @@ fn transpose_var(x: &Var) -> Var {
}),
)
}
// seq_logprob (M3 DPO): Σ log p(target) over non-ignored rows. Grad-check with a
// completion mask — rows 0,1 are -100 (prompt, contribute 0), rows 2..6 supervised.
#[test]
fn seq_logprob_bwd() {
require_gpu();
let (rows, cols) = (6usize, 9usize);
let x_h = fill(rows * cols, 202);
let targets: Vec<i32> = (0..rows)
.map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
.collect();
let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
let lp = ops::seq_logprob(&x, &target);
lp.backward();
let dx = x.grad().unwrap().to_device(Device::Cpu);
// Numeric scalar = seq_logprob = −Σ per_row (per_row is 0 for ignored rows).
let tgt = targets.clone();
let lx = move |v: &[f32], s: &[usize]| {
let t = Tensor::from_slice(&tgt, &[rows]).to_device(Device::Cuda(0));
let (_, per_row) = cuda(v, s).cross_entropy(&t);
-per_row
.to_device(Device::Cpu)
.as_slice::<f32>()
.iter()
.sum::<f32>()
};
report(
"seq_logprob dX",
&grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
);
}
// dpo_loss (M3): scalar DPO loss with the two policy logprobs as parents. Grad-check
// each parent (finite diff of softplus(−Δ)) + the degenerate points the gate pins:
// policy==reference ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0.
#[test]
fn dpo_loss_bwd_and_degenerate() {
require_gpu();
let (ref_c, ref_r, beta) = (0.5f32, 0.9f32, 0.1f32);
let (pc0, pr0) = (1.2f32, 0.7f32);
let softplus = |z: f32| z.max(0.0) + (-(z.abs())).exp().ln_1p();
let pc = Var::leaf(cuda(&[pc0], &[1]));
let pr = Var::leaf(cuda(&[pr0], &[1]));
let l = ops::dpo_loss(&pc, &pr, ref_c, ref_r, beta);
l.backward();
let dpc = pc.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
let dpr = pr.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
let l_of_pc = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((v[0] - ref_c) - (pr0 - ref_r))));
report("dpo_loss dpc", &grad_check(&[pc0], &[1], &l_of_pc, &[dpc], cfg_nonlinear()));
let l_of_pr = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((pc0 - ref_c) - (v[0] - ref_r))));
report("dpo_loss dpr", &grad_check(&[pr0], &[1], &l_of_pr, &[dpr], cfg_nonlinear()));
// Degenerate 1: policy == reference ⇒ Δ=0 ⇒ L=log2, grads = (∓β/2).
let pc2 = Var::leaf(cuda(&[ref_c], &[1]));
let pr2 = Var::leaf(cuda(&[ref_r], &[1]));
let l2 = ops::dpo_loss(&pc2, &pr2, ref_c, ref_r, beta);
let lval = l2.value().to_device(Device::Cpu).as_slice::<f32>()[0];
l2.backward();
let d2c = pc2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
let d2r = pr2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
assert!((lval - 2f32.ln()).abs() < 1e-5, "L at Δ=0 must be log2, got {lval}");
assert!(
(d2c + beta * 0.5).abs() < 1e-5 && (d2r - beta * 0.5).abs() < 1e-5,
"grads at Δ=0 must be ∓β/2, got ({d2c},{d2r})"
);
// Degenerate 2: β=0 ⇒ grads 0.
let pc3 = Var::leaf(cuda(&[pc0], &[1]));
let pr3 = Var::leaf(cuda(&[pr0], &[1]));
let l3 = ops::dpo_loss(&pc3, &pr3, ref_c, ref_r, 0.0);
l3.backward();
let d3c = pc3.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
}