post-train: M1 — verifiable arithmetic task + SFT data generator
First post-training milestone (docs/18). Lands the verifiable task + its data
pipeline, all verified host-side (no CUDA); the SFT run itself reuses the
existing --sft-tsv path on the GPU box.
- task.rs: the shared task spec — two-operand integer arithmetic, answer in
\boxed{N}, with parse_boxed_answer + check_answer (exact-match rule-based
reward). One module reused by M1 (SFT data), M3 (DPO pairs), M4 (GRPO reward).
- gen_arith_task bin: writes arith_sft.tsv (--sft-tsv format) + held-out
arith_eval_prompts.txt (greedy_sample format) + arith_eval_gold.txt; train
deduped, eval disjoint from train.
- data.rs: extract assistant-only masking into a pure, testable sft_row()
(behavior-preserving; single-turn bit-identical to fbf4ac2).
Gate (verified locally, no_cuda): cargo test -p xtrain-train --lib = 9/9 pass
(masking, SFT-target self-consistency over 2000 samples, parser edges, seed
determinism); a 200/50 gen run = clean 2-col TSV, correct gold incl. negatives,
0 train/eval leakage. SFT training run + format-eval pending on dash5.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
94
crates/xtrain-train/src/bin/gen_arith_task.rs
Normal file
94
crates/xtrain-train/src/bin/gen_arith_task.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Generate the M1 verifiable-arithmetic post-training dataset. Pure host tool (no
|
||||
//! CUDA): writes
|
||||
//! <out>/arith_sft.tsv user<TAB>assistant rows for `train --sft-tsv`
|
||||
//! <out>/arith_eval_prompts.txt greedy_sample `--prompts-file` format (held out)
|
||||
//! <out>/arith_eval_gold.txt parallel gold integers for the checker
|
||||
//!
|
||||
//! Eval problems are deduped against train (no leakage). The SFT rows carry just the
|
||||
//! user/assistant content; `data::load_sft_tsv_cached` adds the `User:/Assistant:`
|
||||
//! frame + `<|endoftext|>` and masks the prompt, so the eval prompt lines here
|
||||
//! reconstruct exactly that frame (`User: <q>\nAssistant:`, literal `\n` decoded by
|
||||
//! greedy_sample).
|
||||
//!
|
||||
//! Example:
|
||||
//! cargo run -p xtrain-train --release --bin gen_arith_task -- \
|
||||
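//!       --n 19000 --eval 500 --seed 1 --out-dir /dashscope-tmp/wjh/xtrain_post/arith
//!
//! Train and eval rows are all distinct problems, and with the default ranges
//! (`--max-add 99 --max-mul 12`) only 2*100^2 + 13^2 = 20169 distinct problems exist,
//! so `--n` + `--eval` must not exceed that; raise `--max-add` / `--max-mul` for more.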
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xtrain_train::task::{GenConfig, Op, gen_problem};
|
||||
|
||||
/// Look up `name` in `args` and parse the argument that follows it as a `T`.
///
/// Returns `default` only when the flag is absent altogether.
///
/// # Panics
///
/// Panics if the flag is present but has no following argument, or if that argument
/// does not parse as `T`. Falling back to the default in those cases would silently
/// generate a different dataset than the one the caller asked for.
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    match args.iter().position(|a| a == name) {
        // Flag not given at all: the documented default applies.
        None => default,
        Some(idx) => {
            let raw = args
                .get(idx + 1)
                .unwrap_or_else(|| panic!("{name} requires a value"));
            raw.parse::<T>()
                .unwrap_or_else(|_| panic!("invalid value {raw:?} for {name}"))
        }
    }
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let n_train: usize = flag(&args, "--n", 20000);
|
||||
let n_eval: usize = flag(&args, "--eval", 500);
|
||||
let seed: u64 = flag(&args, "--seed", 1);
|
||||
let max_add: i64 = flag(&args, "--max-add", 99);
|
||||
let max_mul: i64 = flag(&args, "--max-mul", 12);
|
||||
let out_dir: PathBuf = args
|
||||
.iter()
|
||||
.position(|a| a == "--out-dir")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from)
|
||||
.expect("--out-dir <dir> is required");
|
||||
|
||||
fs::create_dir_all(&out_dir).expect("create out dir");
|
||||
let cfg = GenConfig {
|
||||
max_add,
|
||||
max_mul,
|
||||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||
};
|
||||
let mut rng = seed.max(1);
|
||||
|
||||
// Train: dedup so the same problem is not repeated and so eval can be held out.
|
||||
let mut train_keys = HashSet::new();
|
||||
let mut tsv = BufWriter::new(File::create(out_dir.join("arith_sft.tsv")).expect("create tsv"));
|
||||
while train_keys.len() < n_train {
|
||||
let p = gen_problem(&mut rng, &cfg);
|
||||
if !train_keys.insert(p.key()) {
|
||||
continue;
|
||||
}
|
||||
writeln!(tsv, "{}\t{}", p.question(), p.sft_answer()).expect("write tsv");
|
||||
}
|
||||
tsv.flush().expect("flush tsv");
|
||||
|
||||
// Eval: disjoint from train (skip any key seen in train) and from itself.
|
||||
let mut prompts =
|
||||
BufWriter::new(File::create(out_dir.join("arith_eval_prompts.txt")).expect("create eval"));
|
||||
let mut golds =
|
||||
BufWriter::new(File::create(out_dir.join("arith_eval_gold.txt")).expect("create gold"));
|
||||
writeln!(prompts, "# verifiable arithmetic eval prompts (held out from arith_sft.tsv)")
|
||||
.expect("write header");
|
||||
let mut eval_keys = HashSet::new();
|
||||
while eval_keys.len() < n_eval {
|
||||
let p = gen_problem(&mut rng, &cfg);
|
||||
if train_keys.contains(&p.key()) || !eval_keys.insert(p.key()) {
|
||||
continue;
|
||||
}
|
||||
writeln!(prompts, "User: {}\\nAssistant:", p.question()).expect("write prompt");
|
||||
writeln!(golds, "{}", p.answer()).expect("write gold");
|
||||
}
|
||||
prompts.flush().expect("flush prompts");
|
||||
golds.flush().expect("flush golds");
|
||||
|
||||
println!(
|
||||
"wrote {} train rows + {} eval prompts to {} (ops=+,-,* max_add={} max_mul={} seed={})",
|
||||
train_keys.len(),
|
||||
eval_keys.len(),
|
||||
out_dir.display(),
|
||||
max_add,
|
||||
max_mul,
|
||||
seed
|
||||
);
|
||||
}
|
||||
@@ -124,10 +124,9 @@ impl Corpus {
|
||||
let answer = format!(" {assistant}\n<|endoftext|>");
|
||||
let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
|
||||
let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
|
||||
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
|
||||
labels.extend(answer_ids.iter().copied());
|
||||
tokens.extend(prompt_ids);
|
||||
tokens.extend(answer_ids);
|
||||
let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids);
|
||||
tokens.extend(row_tokens);
|
||||
labels.extend(row_labels);
|
||||
}
|
||||
assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
|
||||
write_u16_cache(&token_cache, &tokens);
|
||||
@@ -291,6 +290,20 @@ fn decode_tsv_escapes(s: &str) -> String {
|
||||
s.replace("\\n", "\n").replace("\\t", "\t")
|
||||
}
|
||||
|
||||
/// Assemble the `(tokens, labels)` pair for one SFT example from pre-tokenized ids.
///
/// `tokens` is the prompt followed by the answer. `labels` has the same length: every
/// prompt position carries the ignore-index `-100` (the value `cross_entropy` skips),
/// and every answer position carries its own token id, so only the answer (including
/// its trailing EOS, when the caller appended one) is supervised. The function touches
/// neither the tokenizer nor CUDA, which keeps the masking rule unit-testable.
fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
    let tokens = [prompt_ids, answer_ids].concat();
    let labels: Vec<i32> = prompt_ids
        .iter()
        .map(|_| -100) // mask the prompt
        .chain(answer_ids.iter().copied()) // supervise the answer
        .collect();
    (tokens, labels)
}
|
||||
|
||||
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
||||
/// sampling is reproducible from a single u64 seed.
|
||||
fn next_rand(state: &mut u64) -> u64 {
|
||||
@@ -299,3 +312,27 @@ fn next_rand(state: &mut u64) -> u64 {
|
||||
.wrapping_add(1442695040888963407);
|
||||
*state >> 16
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Prompt positions must be masked with `-100`; answer positions keep their ids.
    #[test]
    fn sft_row_masks_prompt_supervises_answer() {
        let prompt_ids = [5, 6, 7];
        let answer_ids = [8, 9]; // includes the EOS token in real use
        let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids);

        // Tokens are prompt then answer, in order.
        let expected_tokens = vec![5, 6, 7, 8, 9];
        // Prompt positions are ignore-index (-100); answer positions are supervised.
        let expected_labels = vec![-100, -100, -100, 8, 9];

        assert_eq!(row_tokens, expected_tokens);
        assert_eq!(row_labels, expected_labels);
        assert_eq!(row_tokens.len(), row_labels.len());
    }

    /// An empty answer yields a fully masked row rather than a panic.
    #[test]
    fn sft_row_handles_empty_answer() {
        let (row_tokens, row_labels) = sft_row(&[1, 2], &[]);
        assert_eq!(row_tokens, vec![1, 2]);
        assert_eq!(row_labels, vec![-100, -100]);
    }
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
pub mod clip;
|
||||
pub mod data;
|
||||
pub mod schedule;
|
||||
pub mod task;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod checkpoint;
|
||||
|
||||
212
crates/xtrain-train/src/task.rs
Normal file
212
crates/xtrain-train/src/task.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Verifiable arithmetic task (post-training, M1). A tiny two-operand integer
|
||||
//! arithmetic task with a deterministic, rule-based checker: the assistant must end
|
||||
//! its answer with `\boxed{N}`, and the reward is exact-match on `N`.
|
||||
//!
|
||||
//! This single module is the shared task spec for the whole post-training stack —
|
||||
//! M1 SFT-data generation, M3 DPO preference-pair construction, and M4 GRPO reward
|
||||
//! scoring all parse/score through here, so the task lives in exactly one place.
|
||||
//!
|
||||
//! Host-only (no CUDA): generation + parsing + checking are pure, so this compiles
|
||||
//! and unit-tests on a GPU-less host.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// The binary operations a problem can use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Op {
    /// Addition, rendered as `+`.
    Add,
    /// Subtraction, rendered as `-`.
    Sub,
    /// Multiplication, rendered as `*`.
    Mul,
}

impl Op {
    /// The single-character operator used in question and answer text.
    pub fn symbol(self) -> char {
        match self {
            Self::Add => '+',
            Self::Sub => '-',
            Self::Mul => '*',
        }
    }
}

/// Displays an operation as its operator character (see [`Op::symbol`]).
impl fmt::Display for Op {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sym = self.symbol();
        write!(f, "{}", sym)
    }
}
|
||||
|
||||
/// A single two-operand arithmetic problem.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Problem {
|
||||
pub a: i64,
|
||||
pub b: i64,
|
||||
pub op: Op,
|
||||
}
|
||||
|
||||
impl Problem {
|
||||
/// The exact integer answer (the verifiable gold label).
|
||||
pub fn answer(self) -> i64 {
|
||||
match self.op {
|
||||
Op::Add => self.a + self.b,
|
||||
Op::Sub => self.a - self.b,
|
||||
Op::Mul => self.a * self.b,
|
||||
}
|
||||
}
|
||||
|
||||
/// The user-turn question text. No template wrapping — the SFT loader
|
||||
/// (`data::load_sft_tsv_cached`) adds the `User:/Assistant:` frame.
|
||||
pub fn question(self) -> String {
|
||||
format!("What is {} {} {}?", self.a, self.op, self.b)
|
||||
}
|
||||
|
||||
/// The assistant-turn SFT target: restate the equation and end with the boxed
|
||||
/// answer. This teaches the answer FORMAT (the checker only reads `\boxed{}`);
|
||||
/// arithmetic correctness is what DPO (M3) / GRPO (M4) later improve.
|
||||
pub fn sft_answer(self) -> String {
|
||||
format!("{} {} {} = \\boxed{{{}}}.", self.a, self.op, self.b, self.answer())
|
||||
}
|
||||
|
||||
/// A stable dedup key, so eval problems can be held out from train.
|
||||
pub fn key(self) -> (i64, char, i64) {
|
||||
(self.a, self.op.symbol(), self.b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Operand-range configuration for problem sampling. Multiplication uses a smaller
|
||||
/// range (`max_mul`) so products stay modest; add/sub use `max_add`.
|
||||
#[derive(Clone)]
|
||||
pub struct GenConfig {
|
||||
pub max_add: i64,
|
||||
pub max_mul: i64,
|
||||
pub ops: Vec<Op>,
|
||||
}
|
||||
|
||||
impl Default for GenConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_add: 99,
|
||||
max_mul: 12,
|
||||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample one problem deterministically from the LCG state `rng`. Operands are drawn
|
||||
/// in `[0, max]` per the op; subtraction may yield a negative answer (the checker /
|
||||
/// parser handle a leading `-`).
|
||||
pub fn gen_problem(rng: &mut u64, cfg: &GenConfig) -> Problem {
|
||||
let op = cfg.ops[(next_rand(rng) as usize) % cfg.ops.len()];
|
||||
let max = if op == Op::Mul { cfg.max_mul } else { cfg.max_add };
|
||||
let a = rand_range(rng, max);
|
||||
let b = rand_range(rng, max);
|
||||
Problem { a, b, op }
|
||||
}
|
||||
|
||||
/// Extract the integer from the last well-formed `\boxed{...}` in `text`.
///
/// Boxes are scanned left to right. Each time a box's contents (whitespace-trimmed)
/// parse as an `i64`, that value replaces the previous one; boxes that are empty or
/// non-integer are skipped. Scanning stops at a `\boxed{` that has no closing `}`.
/// Returns `None` when no box held an integer. Taking the last one lets a model that
/// emits intermediate boxes still be scored on its final answer.
pub fn parse_boxed_answer(text: &str) -> Option<i64> {
    const TAG: &str = "\\boxed{";
    let mut latest = None;
    let mut cursor = text;
    // Each step peels off "...\boxed{<inner>}" and resumes just past the closing brace.
    while let Some((inner, tail)) = cursor
        .split_once(TAG)
        .and_then(|(_, after_tag)| after_tag.split_once('}'))
    {
        if let Ok(value) = inner.trim().parse::<i64>() {
            latest = Some(value);
        }
        cursor = tail;
    }
    latest
}
|
||||
|
||||
/// Verifiable reward: does the completion's boxed answer exactly match `gold`?
|
||||
pub fn check_answer(completion: &str, gold: i64) -> bool {
|
||||
parse_boxed_answer(completion) == Some(gold)
|
||||
}
|
||||
|
||||
/// `[0, max]` inclusive draw from the LCG.
|
||||
fn rand_range(rng: &mut u64, max: i64) -> i64 {
|
||||
debug_assert!(max >= 0);
|
||||
(next_rand(rng) % (max as u64 + 1)) as i64
|
||||
}
|
||||
|
||||
/// Advance the LCG `state` and return the next pseudo-random value.
///
/// Mirrors the dataset sampler (`data::next_rand`) — same multiplier, increment and
/// output shift — but is kept local so the task module stays dependency-free and
/// host-only.
///
/// The low 16 bits of the new state are dropped before returning. With a
/// power-of-two modulus, bit `k` of an LCG state repeats every `2^(k+1)` steps, so
/// returning the low bits would make small reductions such as `% 4` cycle after a
/// handful of draws.
fn next_rand(state: &mut u64) -> u64 {
    *state = state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    // Discard the short-period low bits, as `data::next_rand` does.
    *state >> 16
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the gold answer, the question text and the SFT target text.
    #[test]
    fn answer_question_and_sft_target() {
        let product = Problem {
            a: 12,
            b: 13,
            op: Op::Mul,
        };
        assert_eq!(product.answer(), 156);
        assert_eq!(product.question(), "What is 12 * 13?");
        assert_eq!(product.sft_answer(), "12 * 13 = \\boxed{156}.");

        // Subtraction may go negative.
        let difference = Problem {
            a: 3,
            b: 8,
            op: Op::Sub,
        };
        assert_eq!(difference.answer(), -5);
    }

    /// The parser keeps the last integer box and rejects malformed input.
    #[test]
    fn parse_takes_last_boxed_and_handles_edges() {
        let cases: [(&str, Option<i64>); 6] = [
            ("\\boxed{3} then \\boxed{156}.", Some(156)),
            ("\\boxed{-7}", Some(-7)),
            ("\\boxed{ 42 }", Some(42)),
            ("no box here", None),
            ("\\boxed{abc}", None),
            ("\\boxed{unterminated", None),
        ];
        for &(text, expected) in cases.iter() {
            assert_eq!(parse_boxed_answer(text), expected);
        }
    }

    /// The reward is true only for an exact match on the boxed integer.
    #[test]
    fn check_is_exact_match() {
        let gold = 156;
        assert!(check_answer("the result is \\boxed{156}.", gold));
        assert!(!check_answer("the result is \\boxed{155}.", gold));
        assert!(!check_answer("no boxed answer at all", gold));
    }

    /// The M1 data invariant: every generated SFT target checks against the gold
    /// answer of its own problem, across all ops and operands.
    #[test]
    fn sft_target_is_always_self_consistent() {
        let cfg = GenConfig::default();
        let mut state = 12345u64;
        for _ in 0..2000 {
            let p = gen_problem(&mut state, &cfg);
            let target = p.sft_answer();
            assert!(
                check_answer(&target, p.answer()),
                "self-inconsistent SFT target for {p:?}"
            );
        }
    }

    /// Two generators started from the same seed emit the same problem sequence.
    #[test]
    fn generation_is_deterministic_from_seed() {
        let cfg = GenConfig::default();
        let mut first = 7u64;
        let mut second = 7u64;
        for _ in 0..200 {
            let left = gen_problem(&mut first, &cfg).key();
            let right = gen_problem(&mut second, &cfg).key();
            assert_eq!(left, right);
        }
    }
}
|
||||
Reference in New Issue
Block a user