post-train: M3 — seq_logprob + dpo_loss autograd ops
Two new ops for DPO (M3), both reusing existing kernels (no new CUDA):

- seq_logprob(logits, target): Σ log πθ(target) over non-ignored (target≥0)
  positions — the per-sequence logprob DPO compares between policy and reference.
  Equals −Σ per_row of cross_entropy (ignored rows already 0, like SFT masking);
  backward = cross_entropy_backward(probs, target, −upstream) (sum, no mean division).
  Gate: finite-diff grad-check with a -100 completion mask.
- dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β): scalar
  L = −log σ(Δ) = softplus(−Δ) with the two policy logprobs as parents
  (ref logprobs constant).
  Gate: grad-check both parents + degenerate points (policy==ref ⇒ Δ=0, L=log2,
  grads ∓β/2; β=0 ⇒ grads 0). Same formula as TRL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -439,3 +439,81 @@ pub fn cross_entropy(x: &Var, target: &Tensor) -> Var {
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Per-sequence log-probability: `Σ log πθ(target)` over the non-ignored
|
||||||
|
/// (`target ≥ 0`) positions — the quantity DPO (M3) compares between policy and
|
||||||
|
/// reference. `target` is `[rows]` I32 carrying `-100` (ignore) at masked positions
|
||||||
|
/// (e.g. the prompt) and the gold token id elsewhere; ignored positions contribute
|
||||||
|
/// 0, exactly like the SFT cross-entropy masking. Returns a scalar `[1]` Var.
|
||||||
|
///
|
||||||
|
/// Reuses the CE forward (per-row `−log p(target)`) and backward, so no new kernel:
|
||||||
|
/// `seq_logprob = −Σ per_row`, and `d(seq_logprob)/d(logits) = −(probs − onehot)`
|
||||||
|
/// = `cross_entropy_backward(probs, target, −upstream)` (a SUM, so no mean
|
||||||
|
/// division — contrast [`cross_entropy`], which divides by `valid_rows`).
|
||||||
|
pub fn seq_logprob(x: &Var, target: &Tensor) -> Var {
|
||||||
|
let logit_dtype = x.value().dtype();
|
||||||
|
let (probs, per_row) = x.value().cross_entropy(target);
|
||||||
|
// per_row[r] = −log p(target_r), and is 0 for ignored rows (target < 0), so the
|
||||||
|
// sum already counts only the supervised (completion) positions.
|
||||||
|
let sum_neg_lp: f32 = per_row
|
||||||
|
.to_device(xtrain_tensor::Device::Cpu)
|
||||||
|
.as_slice::<f32>()
|
||||||
|
.iter()
|
||||||
|
.sum();
|
||||||
|
let out = Tensor::from_slice(&[-sum_neg_lp], &[1]).to_device(x.value().device());
|
||||||
|
|
||||||
|
let target = target.clone();
|
||||||
|
Var::from_op(
|
||||||
|
out,
|
||||||
|
vec![x.clone()],
|
||||||
|
Box::new(move |d, parents| {
|
||||||
|
let upstream = d.to_device(xtrain_tensor::Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
// d(Σ log p)/d(logits) = −(probs − onehot); SUM, so no /valid_rows.
|
||||||
|
let dx = Tensor::cross_entropy_backward(&probs, &target, -upstream);
|
||||||
|
Var::push_grad(&parents[0], dx.to_dtype(logit_dtype));
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DPO loss (Rafailov et al., M3) for one preference pair, as a scalar `[1]` Var
|
||||||
|
/// whose two parents are the POLICY sequence-logprobs of the chosen and rejected
|
||||||
|
/// completions (from [`seq_logprob`]); the REFERENCE logprobs are constants
|
||||||
|
/// (precomputed once from the frozen SFT model). With
|
||||||
|
/// `Δ = β·[(lpθ_chosen − lpref_chosen) − (lpθ_rejected − lpref_rejected)]`
|
||||||
|
/// the loss is `L = −log σ(Δ) = softplus(−Δ)`. Only the policy terms carry gradient:
|
||||||
|
/// `∂L/∂lpθ_chosen = −β·(1−σ(Δ))`, `∂L/∂lpθ_rejected = +β·(1−σ(Δ))`.
|
||||||
|
/// Degenerate points the M3 gate pins: `πθ == πref` ⇒ `Δ = 0`, `L = log 2`, implicit
|
||||||
|
/// reward 0; `β → 0` ⇒ gradient → 0. Same formula as TRL
|
||||||
|
/// (`-logsigmoid(β·(pol_c − pol_r − (ref_c − ref_r)))`).
|
||||||
|
pub fn dpo_loss(
|
||||||
|
lp_pol_chosen: &Var,
|
||||||
|
lp_pol_rejected: &Var,
|
||||||
|
lp_ref_chosen: f32,
|
||||||
|
lp_ref_rejected: f32,
|
||||||
|
beta: f32,
|
||||||
|
) -> Var {
|
||||||
|
use xtrain_tensor::Device;
|
||||||
|
let scalar = |v: &Var| v.value().to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
let pc = scalar(lp_pol_chosen);
|
||||||
|
let pr = scalar(lp_pol_rejected);
|
||||||
|
let delta = beta * ((pc - lp_ref_chosen) - (pr - lp_ref_rejected));
|
||||||
|
// L = softplus(−Δ) = log(1 + e^{−Δ}) (numerically stable).
|
||||||
|
let nd = -delta;
|
||||||
|
let l = nd.max(0.0) + (-(nd.abs())).exp().ln_1p();
|
||||||
|
let dev = lp_pol_chosen.value().device();
|
||||||
|
let out = Tensor::from_slice(&[l], &[1]).to_device(dev);
|
||||||
|
|
||||||
|
Var::from_op(
|
||||||
|
out,
|
||||||
|
vec![lp_pol_chosen.clone(), lp_pol_rejected.clone()],
|
||||||
|
Box::new(move |d, parents| {
|
||||||
|
let up = d.to_device(Device::Cpu).as_slice::<f32>()[0];
|
||||||
|
// s = σ(−Δ) = 1 − σ(Δ); ∂L/∂Δ = −s, and ∂Δ/∂pc = β, ∂Δ/∂pr = −β.
|
||||||
|
let s = 1.0 / (1.0 + delta.exp());
|
||||||
|
let g = up * beta * s;
|
||||||
|
let dev = parents[0].value().device();
|
||||||
|
Var::push_grad(&parents[0], Tensor::from_slice(&[-g], &[1]).to_device(dev));
|
||||||
|
Var::push_grad(&parents[1], Tensor::from_slice(&[g], &[1]).to_device(dev));
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1005,3 +1005,83 @@ fn transpose_var(x: &Var) -> Var {
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// seq_logprob (M3 DPO): Σ log p(target) over the non-ignored rows. The analytic
// gradient is compared with finite differences under a completion mask: rows 0 and
// 1 carry -100 (prompt, contribute 0) and rows 2..6 are supervised.
#[test]
fn seq_logprob_bwd() {
    require_gpu();
    let (rows, cols) = (6usize, 9usize);
    let x_h = fill(rows * cols, 202);

    // First two rows are masked out; the rest get a deterministic in-range token id.
    let targets: Vec<i32> = (0..rows)
        .map(|row| match row {
            0 | 1 => -100,
            _ => (row * 2 % cols) as i32,
        })
        .collect();
    let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    // Analytic gradient from the autograd op.
    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let lp = ops::seq_logprob(&x, &target);
    lp.backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    // Reference scalar for the finite-difference check: seq_logprob = −Σ per_row,
    // where per_row is 0 on ignored rows.
    let labels = targets.clone();
    let lx = move |v: &[f32], s: &[usize]| {
        let label_tensor = Tensor::from_slice(&labels, &[rows]).to_device(Device::Cuda(0));
        let (_, per_row) = cuda(v, s).cross_entropy(&label_tensor);
        let nll_total = per_row
            .to_device(Device::Cpu)
            .as_slice::<f32>()
            .iter()
            .sum::<f32>();
        -nll_total
    };

    let outcome = grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear());
    report("seq_logprob dX", &outcome);
}
|
||||||
|
|
||||||
|
// dpo_loss (M3): scalar DPO loss with the two policy logprobs as parents. Each
// parent is grad-checked against a finite difference of softplus(−Δ), then the
// degenerate points the gate pins are asserted:
//   * policy == reference ⇒ Δ=0, L=log2, grads ∓β/2;
//   * β=0 ⇒ Δ=0, L=log2, and BOTH grads are 0.
#[test]
fn dpo_loss_bwd_and_degenerate() {
    require_gpu();
    let (ref_c, ref_r, beta) = (0.5f32, 0.9f32, 0.1f32);
    let (pc0, pr0) = (1.2f32, 0.7f32);
    // Host-side reference softplus, written in the same overflow-safe form as the op.
    let softplus = |z: f32| z.max(0.0) + (-(z.abs())).exp().ln_1p();

    let pc = Var::leaf(cuda(&[pc0], &[1]));
    let pr = Var::leaf(cuda(&[pr0], &[1]));
    let l = ops::dpo_loss(&pc, &pr, ref_c, ref_r, beta);
    l.backward();
    let dpc = pc.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let dpr = pr.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];

    // Finite-difference check of each parent, holding the other at its base value.
    let l_of_pc = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((v[0] - ref_c) - (pr0 - ref_r))));
    report("dpo_loss dpc", &grad_check(&[pc0], &[1], &l_of_pc, &[dpc], cfg_nonlinear()));
    let l_of_pr = move |v: &[f32], _s: &[usize]| softplus(-(beta * ((pc0 - ref_c) - (v[0] - ref_r))));
    report("dpo_loss dpr", &grad_check(&[pr0], &[1], &l_of_pr, &[dpr], cfg_nonlinear()));

    // Degenerate 1: policy == reference ⇒ Δ=0 ⇒ L=log2, grads = (∓β/2).
    let pc2 = Var::leaf(cuda(&[ref_c], &[1]));
    let pr2 = Var::leaf(cuda(&[ref_r], &[1]));
    let l2 = ops::dpo_loss(&pc2, &pr2, ref_c, ref_r, beta);
    let lval = l2.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    l2.backward();
    let d2c = pc2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let d2r = pr2.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    assert!((lval - 2f32.ln()).abs() < 1e-5, "L at Δ=0 must be log2, got {lval}");
    assert!(
        (d2c + beta * 0.5).abs() < 1e-5 && (d2r - beta * 0.5).abs() < 1e-5,
        "grads at Δ=0 must be ∓β/2, got ({d2c},{d2r})"
    );

    // Degenerate 2: β=0 ⇒ Δ=0 ⇒ L=log2 and both grads are 0. The rejected side is
    // checked as well as the chosen side, so a sign or scale slip on either parent
    // is caught.
    let pc3 = Var::leaf(cuda(&[pc0], &[1]));
    let pr3 = Var::leaf(cuda(&[pr0], &[1]));
    let l3 = ops::dpo_loss(&pc3, &pr3, ref_c, ref_r, 0.0);
    let l3val = l3.value().to_device(Device::Cpu).as_slice::<f32>()[0];
    l3.backward();
    let d3c = pc3.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    let d3r = pr3.grad().unwrap().to_device(Device::Cpu).as_slice::<f32>()[0];
    assert!((l3val - 2f32.ln()).abs() < 1e-5, "L at β=0 must be log2, got {l3val}");
    assert!(d3c.abs() < 1e-9, "β=0 ⇒ grad 0, got {d3c}");
    assert!(d3r.abs() < 1e-9, "β=0 ⇒ rejected grad 0, got {d3r}");
    println!("dpo_loss OK: grad-check (dpc,dpr) + degenerate (Δ=0→log2 & ∓β/2, β=0→0)");
}
|
||||||
|
|||||||
Reference in New Issue
Block a user