post-train: M3 — seq_logprob + dpo_loss autograd ops
Two new ops for DPO (M3), both reusing existing kernels (no new CUDA): - seq_logprob(logits, target): Σ log πθ(target) over non-ignored (target≥0) positions — the per-sequence logprob DPO compares between policy and reference. = −Σ per_row of cross_entropy (ignored rows already 0, like SFT masking); backward = cross_entropy_backward(probs, target, −upstream) (sum, no mean division). Gate: finite-diff grad-check with a -100 completion mask. - dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β): scalar L = −log σ(Δ) = softplus(−Δ) with the two policy logprobs as parents (ref logprobs constant). Gate: grad-check both parents + degenerate points (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0). Same formula as TRL. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -439,3 +439,81 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-sequence log-probability: `Σ log πθ(target)` over the non-ignored
|
||||
/// (`target ≥ 0`) positions — the quantity DPO (M3) compares between policy and
|
||||
/// reference. `target` is `[rows]` I32 carrying `-100` (ignore) at masked positions
|
||||
/// (e.g. the prompt) and the gold token id elsewhere; ignored positions contribute
|
||||
/// 0, exactly like the SFT cross-entropy masking. Returns a scalar `[1]` Var.
|
||||
///
|
||||
/// Reuses the CE forward (per-row `−log p(target)`) and backward, so no new kernel:
|
||||
/// `seq_logprob = −Σ per_row`, and `d(seq_logprob)/d(logits) = −(probs − onehot)`
|
||||
/// = `cross_entropy_backward(probs, target, −upstream)` (a SUM, so no mean
|
||||
/// division — contrast [`cross_entropy`], which divides by `valid_rows`).
|
||||
pub fn seq_logprob(x: &Var, target: &Tensor) -> Var {
|
||||
let logit_dtype = x.value().dtype();
|
||||
let (probs, per_row) = x.value().cross_entropy(target);
|
||||
// per_row[r] = −log p(target_r), and is 0 for ignored rows (target < 0), so the
|
||||
// sum already counts only the supervised (completion) positions.
|
||||
let sum_neg_lp: f32 = per_row
|
||||
.to_device(xtrain_tensor::Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.sum();
|
||||
let out = Tensor::from_slice(&[-sum_neg_lp], &[1]).to_device(x.value().device());
|
||||
|
||||
let target = target.clone();
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![x.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||
// d(Σ log p)/d(logits) = −(probs − onehot); SUM, so no /valid_rows.
|
||||
let dx = Tensor::cross_entropy_backward(&probs, &target, -upstream);
|
||||
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// DPO loss (Rafailov et al., M3) for one preference pair, as a scalar `[1]` Var
|
||||
/// whose two parents are the POLICY sequence-logprobs of the chosen and rejected
|
||||
/// completions (from [`seq_logprob`]); the REFERENCE logprobs are constants
|
||||
/// (precomputed once from the frozen SFT model). With
|
||||
/// `Δ = β·[(lpθ_chosen − lpref_chosen) − (lpθ_rejected − lpref_rejected)]`
|
||||
/// the loss is `L = −log σ(Δ) = softplus(−Δ)`. Only the policy terms carry gradient:
|
||||
/// `∂L/∂lpθ_chosen = −β·(1−σ(Δ))`, `∂L/∂lpθ_rejected = +β·(1−σ(Δ))`.
|
||||
/// Degenerate points the M3 gate pins: `πθ == πref` ⇒ `Δ = 0`, `L = log 2`, implicit
|
||||
/// reward 0; `β → 0` ⇒ gradient → 0. Same formula as TRL
|
||||
/// (`-logsigmoid(β·(pol_c − pol_r − (ref_c − ref_r)))`).
|
||||
pub fn dpo_loss(
|
||||
lp_pol_chosen: &Var,
|
||||
lp_pol_rejected: &Var,
|
||||
lp_ref_chosen: f32,
|
||||
lp_ref_rejected: f32,
|
||||
beta: f32,
|
||||
) -> Var {
|
||||
use xtrain_tensor::Device;
|
||||
let scalar = |v: &Var| v.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let pc = scalar(lp_pol_chosen);
|
||||
let pr = scalar(lp_pol_rejected);
|
||||
let delta = beta * ((pc - lp_ref_chosen) - (pr - lp_ref_rejected));
|
||||
// L = softplus(−Δ) = log(1 + e^{−Δ}) (numerically stable).
|
||||
let nd = -delta;
|
||||
let l = nd.max(0.0) + (-(nd.abs())).exp().ln_1p();
|
||||
let dev = lp_pol_chosen.value().device();
|
||||
let out = Tensor::from_slice(&[l], &[1]).to_device(dev);
|
||||
|
||||
Var::from_op(
|
||||
out,
|
||||
vec![lp_pol_chosen.clone(), lp_pol_rejected.clone()],
|
||||
Box::new(move |d, parents| {
|
||||
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
// s = σ(−Δ) = 1 − σ(Δ); ∂L/∂Δ = −s, and ∂Δ/∂pc = β, ∂Δ/∂pr = −β.
|
||||
let s = 1.0 / (1.0 + delta.exp());
|
||||
let g = up * beta * s;
|
||||
let dev = parents[0].value().device();
|
||||
Var::push_grad(&parents[0], Tensor::from_slice(&[-g], &[1]).to_device(dev));
|
||||
Var::push_grad(&parents[1], Tensor::from_slice(&[g], &[1]).to_device(dev));
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1005,3 +1005,83 @@ fn transpose_var(x: &Var) -> Var {
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
// seq_logprob (M3 DPO): Σ log p(target) over non-ignored rows. Grad-check with a
|
||||
// completion mask — rows 0,1 are -100 (prompt, contribute 0), rows 2..6 supervised.
|
||||
#[test]
|
||||
fn seq_logprob_bwd() {
|
||||
require_gpu();
|
||||
let (rows, cols) = (6usize, 9usize);
|
||||
let x_h = fill(rows * cols, 202);
|
||||
let targets: Vec<i32> = (0..rows)
|
||||
.map(|r| if r < 2 { -100 } else { (r * 2 % cols) as i32 })
|
||||
.collect();
|
||||
let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));
|
||||
|
||||
let x = Var::leaf(cuda(&x_h, &[rows, cols]));
|
||||
let lp = ops::seq_logprob(&x, &target);
|
||||
lp.backward();
|
||||
let dx = x.grad().unwrap().to_device(Device::Cpu);
|
||||
|
||||
// Numeric scalar = seq_logprob = −Σ per_row (per_row is 0 for ignored rows).
|
||||
let tgt = targets.clone();
|
||||
let lx = move |v: &[f32], s: &[usize]| {
|
||||
let t = Tensor::from_slice(&tgt, &[rows]).to_device(Device::Cuda(0));
|
||||
let (_, per_row) = cuda(v, s).cross_entropy(&t);
|
||||
-per_row
|
||||
.to_device(Device::Cpu)
|
||||
.as_slice::<f32>()
|
||||
.iter()
|
||||
.sum::<f32>()
|
||||
};
|
||||
report(
|
||||
"seq_logprob dX",
|
||||
&grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear()),
|
||||
);
|
||||
}
|
||||
|
||||
// dpo_loss (M3): scalar DPO loss with the two policy logprobs as parents. Grad-check
|
||||
// each parent (finite diff of softplus(−Δ)) + the degenerate points the gate pins:
|
||||
// policy==reference ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0.
|
||||
#[test]
|
||||
fn dpo_loss_bwd_and_degenerate() {
|
||||
require_gpu();
|
||||
let (ref_c, ref_r, beta) = (0.5f32, 0.9f32, 0.1f32);
|
||||
let (pc0, pr0) = (1.2f32, 0.7f32);
|
||||
let softplus = |z: f32| z.max(0.0) + (-(z.abs())).exp().ln_1p();
|
||||
|
||||
let pc = Var::leaf(cuda(&[pc0], &[1]));
|
||||
let pr = Var::leaf(cuda(&[pr0], &[1]));
|
||||
let l = ops::dpo_loss(&pc, &pr, ref_c, ref_r, beta);
|
||||
l.backward();
|
||||
let dpc = pc.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let dpr = pr.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
|
||||
let l_of_pc = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((v[0] - ref_c) - (pr0 - ref_r))));
|
||||
report("dpo_loss dpc", &grad_check(&[pc0], &[1], &l_of_pc, &[dpc], cfg_nonlinear()));
|
||||
let l_of_pr = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((pc0 - ref_c) - (v[0] - ref_r))));
|
||||
report("dpo_loss dpr", &grad_check(&[pr0], &[1], &l_of_pr, &[dpr], cfg_nonlinear()));
|
||||
|
||||
// Degenerate 1: policy == reference ⇒ Δ=0 ⇒ L=log2, grads = (∓β/2).
|
||||
let pc2 = Var::leaf(cuda(&[ref_c], &[1]));
|
||||
let pr2 = Var::leaf(cuda(&[ref_r], &[1]));
|
||||
let l2 = ops::dpo_loss(&pc2, &pr2, ref_c, ref_r, beta);
|
||||
let lval = l2.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
l2.backward();
|
||||
let d2c = pc2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
let d2r = pr2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
assert!((lval - 2f32.ln()).abs() < 1e-5, "L at Δ=0 must be log2, got {lval}");
|
||||
assert!(
|
||||
(d2c + beta * 0.5).abs() < 1e-5 && (d2r - beta * 0.5).abs() < 1e-5,
|
||||
"grads at Δ=0 must be ∓β/2, got ({d2c},{d2r})"
|
||||
);
|
||||
|
||||
// Degenerate 2: β=0 ⇒ grads 0.
|
||||
let pc3 = Var::leaf(cuda(&[pc0], &[1]));
|
||||
let pr3 = Var::leaf(cuda(&[pr0], &[1]));
|
||||
let l3 = ops::dpo_loss(&pc3, &pr3, ref_c, ref_r, 0.0);
|
||||
l3.backward();
|
||||
let d3c = pc3.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||
assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
|
||||
println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user