post-train: M3 — DPO pair-gen + training loop (verifiable arithmetic)

gen_dpo_pairs: chosen = gold answer, rejected = the SFT model's own greedy
(KV-cache engine, M2a) completion when it's a format-valid WRONG boxed answer —
a hard negative from the model's distribution. ~8% of prompts skipped (greedy
correct). Writes question<TAB>chosen<TAB>rejected (bare, SFT-framed at train).

train_dpo: loads the SFT ckpt as policy AND frozen reference; precomputes the
reference logprobs ONCE (policy==ref) and caches them (one resident model). Each
step forwards the policy on chosen+rejected, seq_logprob each, minimises
dpo_loss; the two forwards share params so backward accumulates both branches.
Tracks reward margin + preference accuracy (the doc-13 "don't trust loss alone"
health signal). Loss starts at exactly log2 (Δ=0 at init) — a built-in check.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 12:37:01 +08:00
parent f3c764ce95
commit 2f827fd6d8
2 changed files with 390 additions and 0 deletions

View File

@@ -0,0 +1,157 @@
//! Generate DPO preference pairs for the verifiable arithmetic task (M3).
//!
//! Per the aligned decision: **chosen = the gold answer** (`sft_answer`, always
//! correct), **rejected = a sampled-incorrect completion from the SFT model** — a
//! format-valid but wrong boxed answer, i.e. a hard negative drawn from the model's
//! own distribution. Since the SFT model is only ~8% correct (M1), a single GREEDY
//! decode is wrong ~92% of the time, so we use the KV-cache greedy engine (M2a) and
//! simply skip the ~8% of prompts where greedy happens to be correct (no usable
//! negative). Fast (cached), deterministic, and one clean hard negative per prompt.
//!
//! Writes `<out>` as `question<TAB>chosen<TAB>rejected` (bare text, like the SFT
//! TSV — `train_dpo` adds the `User:/Assistant:` frame). Problems are deduped.
#[cfg(no_cuda)]
fn main() {
eprintln!("gen_dpo_pairs: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use std::collections::HashSet;
#[cfg(not(no_cuda))]
use std::io::Write;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, generate_greedy_cached};
#[cfg(not(no_cuda))]
use xtrain_tensor::Device;
#[cfg(not(no_cuda))]
use xtrain_train::task::{Op, GenConfig, check_answer, gen_problem, parse_boxed_answer};
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.cloned()
}
/// Keep only the first answer "turn": cut at the first `<|endoftext|>` then the
/// first newline (mirrors eval_arith).
#[cfg(not(no_cuda))]
fn first_answer_segment(continuation: &str) -> &str {
let s = continuation
.split("<|endoftext|>")
.next()
.unwrap_or(continuation);
s.split('\n').next().unwrap_or(s)
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
let args: Vec<String> = std::env::args().collect();
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let ckpt = positionals.first().expect("usage: gen_dpo_pairs <sft_ckpt> <tokenizer.json> [flags]");
let tok_path = positionals
.get(1)
.map(|s| s.as_str())
.unwrap_or("/opt/wjh/models/gpt2/tokenizer.json");
let n_heads = flag(&args, "--heads", 52usize);
let head_dim = flag(&args, "--head-dim", 32usize);
let n_layers = flag(&args, "--layers", 22usize);
let ffn = flag(&args, "--ffn", 6656usize);
let kv_heads = flag(&args, "--kv-heads", n_heads);
let n_pairs: usize = flag(&args, "--n", 2000);
let seed: u64 = flag(&args, "--seed", 1234);
let max_add: i64 = flag(&args, "--max-add", 999);
let max_mul: i64 = flag(&args, "--max-mul", 99);
let max_new: usize = flag(&args, "--max-tokens", 32);
let out = flag_value(&args, "--out").expect("--out <file> is required");
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(std::path::Path::new(tok_path));
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
.with_kv_heads(kv_heads);
let mut seed_init = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| {
seed_init = seed_init.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed_init, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed_init, 0.04)
}
});
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt.as_str()), &model.params())
.expect("load SFT checkpoint");
let gcfg = GenConfig {
max_add,
max_mul,
ops: vec![Op::Add, Op::Sub, Op::Mul],
};
let mut rng = seed.max(1);
let mut keys = HashSet::new();
let mut writer = std::io::BufWriter::new(std::fs::File::create(&out).expect("create out"));
let (mut written, mut skipped, mut attempts) = (0usize, 0usize, 0usize);
while written < n_pairs {
attempts += 1;
if attempts > n_pairs * 4 {
eprintln!("gen_dpo_pairs: stopping early at {written} pairs after {attempts} attempts");
break;
}
let p = gen_problem(&mut rng, &gcfg);
if !keys.insert(p.key()) {
continue;
}
let prompt_text = format!("User: {}\nAssistant:", p.question());
let ids: Vec<i32> = tok.encode(&prompt_text).into_iter().map(|t| t as i32).collect();
let out_ids = generate_greedy_cached(&model, device, &ids, max_new);
let cont = tok.decode(&out_ids[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
let seg = first_answer_segment(&cont).trim();
// A valid hard negative: a well-formed boxed answer that is WRONG.
if parse_boxed_answer(seg).is_some() && !check_answer(seg, p.answer()) {
writeln!(writer, "{}\t{}\t{}", p.question(), p.sft_answer(), seg).expect("write");
written += 1;
} else {
skipped += 1; // greedy was correct (~8%) or malformed → no clean negative
}
}
writer.flush().expect("flush");
println!(
"wrote {written} DPO pairs to {out} (skipped {skipped} no-negative; {attempts} attempts; \
chosen=gold, rejected=greedy-incorrect)"
);
}

View File

@@ -0,0 +1,233 @@
//! DPO training on the verifiable arithmetic task (M3 / Stage P1).
//!
//! Loads the SFT checkpoint as the policy AND uses it as the frozen reference:
//! reference logprobs `log πref(chosen)` / `log πref(rejected)` are **precomputed
//! once** before any optimizer step (when policy == reference), then cached as
//! constants — so only one model stays resident (the design's reference-logprob
//! caching). Each step forwards the policy on the chosen and rejected completions,
//! takes [`seq_logprob`] of each, and minimises [`dpo_loss`]; the two forwards
//! share the policy params, so backward accumulates both branches' grads.
//!
//! Health metrics (per docs/18, the doc-13 "don't trust loss alone" lesson): the
//! chosenrejected **reward margin** and **preference accuracy** (margin > 0) — both
//! should rise. The arithmetic-correctness payoff is measured separately by running
//! `eval_arith` on the saved checkpoint.
//!
//! train_dpo <tokenizer.json> <dpo.tsv> --init-ckpt <sft.ckpt> <arch flags> \
//! --beta 0.1 --steps 1000 --lr 5e-7 --ckpt <out.ckpt>
#[cfg(no_cuda)]
fn main() {
eprintln!("train_dpo: built without CUDA (no_cuda); run on a GPU host.");
}
#[cfg(not(no_cuda))]
use xtrain_autodiff::ops;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer, ids_tensor};
#[cfg(not(no_cuda))]
use xtrain_tensor::Device;
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.cloned()
}
/// Frame a (question, completion) the same way the SFT loader does
/// (`User: …\nAssistant:` prompt + ` {completion}\n<|endoftext|>`), then return the
/// next-token (input, target) pair: input = tokens[..L-1], target = labels[1..L]
/// with the prompt positions masked to -100 (only completion tokens supervised).
#[cfg(not(no_cuda))]
fn frame(
tok: &xserv_tokenizer::Tokenizer,
question: &str,
completion: &str,
) -> (Vec<i32>, Vec<i32>) {
let prompt = format!("User: {question}\nAssistant:");
let answer = format!(" {completion}\n<|endoftext|>");
let p_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
let a_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
let mut tokens = p_ids.clone();
tokens.extend_from_slice(&a_ids);
let mut labels = vec![-100i32; p_ids.len()];
labels.extend_from_slice(&a_ids);
let l = tokens.len();
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
}
/// Sequence logprob `Σ log πθ(completion)` of a framed (input, target) pair.
#[cfg(not(no_cuda))]
fn seq_lp(
model: &TinyTransformer,
device: Device,
input: &[i32],
target: &[i32],
) -> xtrain_autodiff::tape::Var {
let logits = model.forward(&ids_tensor(input, device));
ops::seq_logprob(&logits, &ids_tensor(target, device))
}
#[cfg(not(no_cuda))]
fn scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
use xtrain_optim::GpuAdamW;
let args: Vec<String> = std::env::args().collect();
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let tok_path = positionals.first().expect("usage: train_dpo <tokenizer.json> <dpo.tsv> [flags]");
let tsv_path = positionals.get(1).expect("usage: train_dpo <tokenizer.json> <dpo.tsv> [flags]");
let n_heads = flag(&args, "--heads", 52usize);
let head_dim = flag(&args, "--head-dim", 32usize);
let n_layers = flag(&args, "--layers", 22usize);
let ffn = flag(&args, "--ffn", 6656usize);
let kv_heads = flag(&args, "--kv-heads", n_heads);
let beta: f32 = flag(&args, "--beta", 0.1);
let steps: usize = flag(&args, "--steps", 1000);
let lr: f32 = flag(&args, "--lr", 5e-7);
let wd: f32 = flag(&args, "--wd", 0.0);
let clip: f32 = flag(&args, "--clip", 1.0);
let log_every: usize = flag(&args, "--log-every", 50);
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
// Load preference pairs: question<TAB>chosen<TAB>rejected.
let raw = std::fs::read_to_string(tsv_path).expect("read dpo tsv");
let pairs: Vec<(String, String, String)> = raw
.lines()
.filter(|l| !l.trim().is_empty())
.map(|l| {
let mut it = l.splitn(3, '\t');
let q = it.next().expect("question").to_string();
let c = it.next().expect("chosen").to_string();
let r = it.next().expect("rejected").to_string();
(q, c, r)
})
.collect();
assert!(!pairs.is_empty(), "no DPO pairs in {tsv_path}");
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
.with_kv_heads(kv_heads);
let mut seed_init = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| {
seed_init = seed_init.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed_init, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed_init, 0.04)
}
});
xtrain_train::checkpoint::load_into(std::path::Path::new(&init_ckpt), &model.params())
.expect("load SFT checkpoint");
model.eval(); // DPO runs without dropout (deterministic logprobs)
// Pre-tokenize every pair once.
let framed: Vec<((Vec<i32>, Vec<i32>), (Vec<i32>, Vec<i32>))> = pairs
.iter()
.map(|(q, c, r)| (frame(&tok, q, c), frame(&tok, q, r)))
.collect();
// Reference logprobs: computed ONCE while policy == reference (SFT init), cached.
println!("precomputing reference logprobs for {} pairs…", framed.len());
let mut ref_c = Vec::with_capacity(framed.len());
let mut ref_r = Vec::with_capacity(framed.len());
for ((ci, ct), (ri, rt)) in &framed {
ref_c.push(scalar(&seq_lp(&model, device, ci, ct)));
ref_r.push(scalar(&seq_lp(&model, device, ri, rt)));
}
let params = model.params();
let mut opt = GpuAdamW::new(wd);
let n = framed.len();
// A fixed shuffle (LCG-strided) so steps sweep the dataset without bias.
let mut order: Vec<usize> = (0..n).collect();
let mut s = 0x9E3779B97F4A7C15u64;
for i in (1..n).rev() {
s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
let j = (s >> 33) as usize % (i + 1);
order.swap(i, j);
}
let start = std::time::Instant::now();
let (mut win_loss, mut win_margin, mut win_acc) = (0f32, 0f32, 0usize);
for step in 0..steps {
let i = order[step % n];
let ((ci, ct), (ri, rt)) = &framed[i];
let lpc = seq_lp(&model, device, ci, ct);
let lpr = seq_lp(&model, device, ri, rt);
let (lpc_v, lpr_v) = (scalar(&lpc), scalar(&lpr));
let margin = (lpc_v - ref_c[i]) - (lpr_v - ref_r[i]); // implicit reward margin
let loss = ops::dpo_loss(&lpc, &lpr, ref_c[i], ref_r[i], beta);
win_loss += scalar(&loss);
win_margin += margin;
win_acc += (margin > 0.0) as usize;
loss.backward();
let _ = xtrain_train::clip::clip_grad_norm_gpu(&params, clip, 1.0);
opt.step(lr, &params);
for p in &params {
p.zero_grad();
}
if (step + 1) % log_every == 0 || step == steps - 1 {
let w = log_every.min(step + 1) as f32;
println!(
"step {:5}/{steps}: loss {:.4} | reward-margin {:+.4} | pref-acc {:.1}% | {:.1}s",
step + 1,
win_loss / w,
win_margin / w,
100.0 * win_acc as f32 / w,
start.elapsed().as_secs_f32(),
);
win_loss = 0.0;
win_margin = 0.0;
win_acc = 0;
}
}
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), &params).expect("save ckpt");
println!(
"DPO done: {} pairs, {steps} steps, beta {beta}, lr {lr:.1e}{out_ckpt}",
framed.len()
);
}