post-train: M3 — DPO pair-gen + training loop (verifiable arithmetic)
gen_dpo_pairs: chosen = gold answer, rejected = the SFT model's own greedy (KV-cache engine, M2a) completion when it's a format-valid WRONG boxed answer — a hard negative from the model's distribution. ~8% of prompts skipped (greedy correct). Writes question<TAB>chosen<TAB>rejected (bare, SFT-framed at train). train_dpo: loads the SFT ckpt as policy AND frozen reference; precomputes the reference logprobs ONCE (policy==ref) and caches them (one resident model). Each step forwards the policy on chosen+rejected, seq_logprob each, minimises dpo_loss; the two forwards share params so backward accumulates both branches. Tracks reward margin + preference accuracy (the doc-13 "don't trust loss alone" health signal). Loss starts at exactly log2 (Δ=0 at init) — a built-in check. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
157
crates/xtrain-train/src/bin/gen_dpo_pairs.rs
Normal file
157
crates/xtrain-train/src/bin/gen_dpo_pairs.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
//! Generate DPO preference pairs for the verifiable arithmetic task (M3).
|
||||
//!
|
||||
//! Per the aligned decision: **chosen = the gold answer** (`sft_answer`, always
|
||||
//! correct), **rejected = a sampled-incorrect completion from the SFT model** — a
|
||||
//! format-valid but wrong boxed answer, i.e. a hard negative drawn from the model's
|
||||
//! own distribution. Since the SFT model is only ~8% correct (M1), a single GREEDY
|
||||
//! decode is wrong ~92% of the time, so we use the KV-cache greedy engine (M2a) and
|
||||
//! simply skip the ~8% of prompts where greedy happens to be correct (no usable
|
||||
//! negative). Fast (cached), deterministic, and one clean hard negative per prompt.
|
||||
//!
|
||||
//! Writes `<out>` as `question<TAB>chosen<TAB>rejected` (bare text, like the SFT
|
||||
//! TSV — `train_dpo` adds the `User:/Assistant:` frame). Problems are deduped.
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
eprintln!("gen_dpo_pairs: built without CUDA (no_cuda); run on a GPU host.");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::collections::HashSet;
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::io::Write;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer, generate_greedy_cached};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::task::{Op, GenConfig, check_answer, gen_problem, parse_boxed_answer};
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag_value(args: &[String], name: &str) -> Option<String> {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Keep only the first answer "turn": cut at the first `<|endoftext|>` then the
|
||||
/// first newline (mirrors eval_arith).
|
||||
#[cfg(not(no_cuda))]
|
||||
fn first_answer_segment(continuation: &str) -> &str {
|
||||
let s = continuation
|
||||
.split("<|endoftext|>")
|
||||
.next()
|
||||
.unwrap_or(continuation);
|
||||
s.split('\n').next().unwrap_or(s)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let ckpt = positionals.first().expect("usage: gen_dpo_pairs <sft_ckpt> <tokenizer.json> [flags]");
|
||||
let tok_path = positionals
|
||||
.get(1)
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("/opt/wjh/models/gpt2/tokenizer.json");
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let n_pairs: usize = flag(&args, "--n", 2000);
|
||||
let seed: u64 = flag(&args, "--seed", 1234);
|
||||
let max_add: i64 = flag(&args, "--max-add", 999);
|
||||
let max_mul: i64 = flag(&args, "--max-mul", 99);
|
||||
let max_new: usize = flag(&args, "--max-tokens", 32);
|
||||
let out = flag_value(&args, "--out").expect("--out <file> is required");
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(std::path::Path::new(tok_path));
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
|
||||
.with_kv_heads(kv_heads);
|
||||
let mut seed_init = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed_init = seed_init.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed_init, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed_init, 0.04)
|
||||
}
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(std::path::Path::new(ckpt.as_str()), &model.params())
|
||||
.expect("load SFT checkpoint");
|
||||
|
||||
let gcfg = GenConfig {
|
||||
max_add,
|
||||
max_mul,
|
||||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||
};
|
||||
let mut rng = seed.max(1);
|
||||
let mut keys = HashSet::new();
|
||||
let mut writer = std::io::BufWriter::new(std::fs::File::create(&out).expect("create out"));
|
||||
let (mut written, mut skipped, mut attempts) = (0usize, 0usize, 0usize);
|
||||
|
||||
while written < n_pairs {
|
||||
attempts += 1;
|
||||
if attempts > n_pairs * 4 {
|
||||
eprintln!("gen_dpo_pairs: stopping early at {written} pairs after {attempts} attempts");
|
||||
break;
|
||||
}
|
||||
let p = gen_problem(&mut rng, &gcfg);
|
||||
if !keys.insert(p.key()) {
|
||||
continue;
|
||||
}
|
||||
let prompt_text = format!("User: {}\nAssistant:", p.question());
|
||||
let ids: Vec<i32> = tok.encode(&prompt_text).into_iter().map(|t| t as i32).collect();
|
||||
let out_ids = generate_greedy_cached(&model, device, &ids, max_new);
|
||||
let cont = tok.decode(&out_ids[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
let seg = first_answer_segment(&cont).trim();
|
||||
// A valid hard negative: a well-formed boxed answer that is WRONG.
|
||||
if parse_boxed_answer(seg).is_some() && !check_answer(seg, p.answer()) {
|
||||
writeln!(writer, "{}\t{}\t{}", p.question(), p.sft_answer(), seg).expect("write");
|
||||
written += 1;
|
||||
} else {
|
||||
skipped += 1; // greedy was correct (~8%) or malformed → no clean negative
|
||||
}
|
||||
}
|
||||
writer.flush().expect("flush");
|
||||
println!(
|
||||
"wrote {written} DPO pairs to {out} (skipped {skipped} no-negative; {attempts} attempts; \
|
||||
chosen=gold, rejected=greedy-incorrect)"
|
||||
);
|
||||
}
|
||||
233
crates/xtrain-train/src/bin/train_dpo.rs
Normal file
233
crates/xtrain-train/src/bin/train_dpo.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
//! DPO training on the verifiable arithmetic task (M3 / Stage P1).
|
||||
//!
|
||||
//! Loads the SFT checkpoint as the policy AND uses it as the frozen reference:
|
||||
//! reference logprobs `log πref(chosen)` / `log πref(rejected)` are **precomputed
|
||||
//! once** before any optimizer step (when policy == reference), then cached as
|
||||
//! constants — so only one model stays resident (the design's reference-logprob
|
||||
//! caching). Each step forwards the policy on the chosen and rejected completions,
|
||||
//! takes [`seq_logprob`] of each, and minimises [`dpo_loss`]; the two forwards
|
||||
//! share the policy params, so backward accumulates both branches' grads.
|
||||
//!
|
||||
//! Health metrics (per docs/18, the doc-13 "don't trust loss alone" lesson): the
|
||||
//! chosen−rejected **reward margin** and **preference accuracy** (margin > 0) — both
|
||||
//! should rise. The arithmetic-correctness payoff is measured separately by running
|
||||
//! `eval_arith` on the saved checkpoint.
|
||||
//!
|
||||
//! train_dpo <tokenizer.json> <dpo.tsv> --init-ckpt <sft.ckpt> <arch flags> \
|
||||
//! --beta 0.1 --steps 1000 --lr 5e-7 --ckpt <out.ckpt>
|
||||
|
||||
#[cfg(no_cuda)]
|
||||
fn main() {
|
||||
eprintln!("train_dpo: built without CUDA (no_cuda); run on a GPU host.");
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_autodiff::ops;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn flag_value(args: &[String], name: &str) -> Option<String> {
|
||||
args.iter()
|
||||
.position(|a| a == name)
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Frame a (question, completion) the same way the SFT loader does
|
||||
/// (`User: …\nAssistant:` prompt + ` {completion}\n<|endoftext|>`), then return the
|
||||
/// next-token (input, target) pair: input = tokens[..L-1], target = labels[1..L]
|
||||
/// with the prompt positions masked to -100 (only completion tokens supervised).
|
||||
#[cfg(not(no_cuda))]
|
||||
fn frame(
|
||||
tok: &xserv_tokenizer::Tokenizer,
|
||||
question: &str,
|
||||
completion: &str,
|
||||
) -> (Vec<i32>, Vec<i32>) {
|
||||
let prompt = format!("User: {question}\nAssistant:");
|
||||
let answer = format!(" {completion}\n<|endoftext|>");
|
||||
let p_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
|
||||
let a_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
|
||||
let mut tokens = p_ids.clone();
|
||||
tokens.extend_from_slice(&a_ids);
|
||||
let mut labels = vec![-100i32; p_ids.len()];
|
||||
labels.extend_from_slice(&a_ids);
|
||||
let l = tokens.len();
|
||||
(tokens[..l - 1].to_vec(), labels[1..l].to_vec())
|
||||
}
|
||||
|
||||
/// Sequence logprob `Σ log πθ(completion)` of a framed (input, target) pair.
|
||||
#[cfg(not(no_cuda))]
|
||||
fn seq_lp(
|
||||
model: &TinyTransformer,
|
||||
device: Device,
|
||||
input: &[i32],
|
||||
target: &[i32],
|
||||
) -> xtrain_autodiff::tape::Var {
|
||||
let logits = model.forward(&ids_tensor(input, device));
|
||||
ops::seq_logprob(&logits, &ids_tensor(target, device))
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn scalar(v: &xtrain_autodiff::tape::Var) -> f32 {
|
||||
v.value().to_device(Device::Cpu).as_slice::<f32>()[0]
|
||||
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
use xtrain_optim::GpuAdamW;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let tok_path = positionals.first().expect("usage: train_dpo <tokenizer.json> <dpo.tsv> [flags]");
|
||||
let tsv_path = positionals.get(1).expect("usage: train_dpo <tokenizer.json> <dpo.tsv> [flags]");
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let beta: f32 = flag(&args, "--beta", 0.1);
|
||||
let steps: usize = flag(&args, "--steps", 1000);
|
||||
let lr: f32 = flag(&args, "--lr", 5e-7);
|
||||
let wd: f32 = flag(&args, "--wd", 0.0);
|
||||
let clip: f32 = flag(&args, "--clip", 1.0);
|
||||
let log_every: usize = flag(&args, "--log-every", 50);
|
||||
let init_ckpt = flag_value(&args, "--init-ckpt").expect("--init-ckpt <sft.ckpt> is required");
|
||||
let out_ckpt = flag_value(&args, "--ckpt").expect("--ckpt <out> is required");
|
||||
|
||||
// Load preference pairs: question<TAB>chosen<TAB>rejected.
|
||||
let raw = std::fs::read_to_string(tsv_path).expect("read dpo tsv");
|
||||
let pairs: Vec<(String, String, String)> = raw
|
||||
.lines()
|
||||
.filter(|l| !l.trim().is_empty())
|
||||
.map(|l| {
|
||||
let mut it = l.splitn(3, '\t');
|
||||
let q = it.next().expect("question").to_string();
|
||||
let c = it.next().expect("chosen").to_string();
|
||||
let r = it.next().expect("rejected").to_string();
|
||||
(q, c, r)
|
||||
})
|
||||
.collect();
|
||||
assert!(!pairs.is_empty(), "no DPO pairs in {tsv_path}");
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(std::path::Path::new(tok_path.as_str()));
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
|
||||
.with_kv_heads(kv_heads);
|
||||
let mut seed_init = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed_init = seed_init.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed_init, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed_init, 0.04)
|
||||
}
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(std::path::Path::new(&init_ckpt), &model.params())
|
||||
.expect("load SFT checkpoint");
|
||||
model.eval(); // DPO runs without dropout (deterministic logprobs)
|
||||
|
||||
// Pre-tokenize every pair once.
|
||||
let framed: Vec<((Vec<i32>, Vec<i32>), (Vec<i32>, Vec<i32>))> = pairs
|
||||
.iter()
|
||||
.map(|(q, c, r)| (frame(&tok, q, c), frame(&tok, q, r)))
|
||||
.collect();
|
||||
|
||||
// Reference logprobs: computed ONCE while policy == reference (SFT init), cached.
|
||||
println!("precomputing reference logprobs for {} pairs…", framed.len());
|
||||
let mut ref_c = Vec::with_capacity(framed.len());
|
||||
let mut ref_r = Vec::with_capacity(framed.len());
|
||||
for ((ci, ct), (ri, rt)) in &framed {
|
||||
ref_c.push(scalar(&seq_lp(&model, device, ci, ct)));
|
||||
ref_r.push(scalar(&seq_lp(&model, device, ri, rt)));
|
||||
}
|
||||
|
||||
let params = model.params();
|
||||
let mut opt = GpuAdamW::new(wd);
|
||||
let n = framed.len();
|
||||
// A fixed shuffle (LCG-strided) so steps sweep the dataset without bias.
|
||||
let mut order: Vec<usize> = (0..n).collect();
|
||||
let mut s = 0x9E3779B97F4A7C15u64;
|
||||
for i in (1..n).rev() {
|
||||
s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
|
||||
let j = (s >> 33) as usize % (i + 1);
|
||||
order.swap(i, j);
|
||||
}
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let (mut win_loss, mut win_margin, mut win_acc) = (0f32, 0f32, 0usize);
|
||||
for step in 0..steps {
|
||||
let i = order[step % n];
|
||||
let ((ci, ct), (ri, rt)) = &framed[i];
|
||||
let lpc = seq_lp(&model, device, ci, ct);
|
||||
let lpr = seq_lp(&model, device, ri, rt);
|
||||
let (lpc_v, lpr_v) = (scalar(&lpc), scalar(&lpr));
|
||||
let margin = (lpc_v - ref_c[i]) - (lpr_v - ref_r[i]); // implicit reward margin
|
||||
let loss = ops::dpo_loss(&lpc, &lpr, ref_c[i], ref_r[i], beta);
|
||||
win_loss += scalar(&loss);
|
||||
win_margin += margin;
|
||||
win_acc += (margin > 0.0) as usize;
|
||||
|
||||
loss.backward();
|
||||
let _ = xtrain_train::clip::clip_grad_norm_gpu(¶ms, clip, 1.0);
|
||||
opt.step(lr, ¶ms);
|
||||
for p in ¶ms {
|
||||
p.zero_grad();
|
||||
}
|
||||
|
||||
if (step + 1) % log_every == 0 || step == steps - 1 {
|
||||
let w = log_every.min(step + 1) as f32;
|
||||
println!(
|
||||
"step {:5}/{steps}: loss {:.4} | reward-margin {:+.4} | pref-acc {:.1}% | {:.1}s",
|
||||
step + 1,
|
||||
win_loss / w,
|
||||
win_margin / w,
|
||||
100.0 * win_acc as f32 / w,
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
win_loss = 0.0;
|
||||
win_margin = 0.0;
|
||||
win_acc = 0;
|
||||
}
|
||||
}
|
||||
|
||||
xtrain_train::checkpoint::save(std::path::Path::new(&out_ckpt), ¶ms).expect("save ckpt");
|
||||
println!(
|
||||
"DPO done: {} pairs, {steps} steps, beta {beta}, lr {lr:.1e} → {out_ckpt}",
|
||||
framed.len()
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user