From 9c70e99ae496e3c1cc1261cd73b37eec7a644640 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Mon, 29 Jun 2026 22:52:25 +0800 Subject: [PATCH] =?UTF-8?q?post-train:=20M1=20=E2=80=94=20verifiable=20ari?= =?UTF-8?q?thmetic=20task=20+=20SFT=20data=20generator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First post-training milestone (docs/18). Lands the verifiable task + its data pipeline, all verified host-side (no CUDA); the SFT run itself reuses the existing --sft-tsv path on the GPU box. - task.rs: the shared task spec — two-operand integer arithmetic, answer in \boxed{N}, with parse_boxed_answer + check_answer (exact-match rule-based reward). One module reused by M1 (SFT data), M3 (DPO pairs), M4 (GRPO reward). - gen_arith_task bin: writes arith_sft.tsv (--sft-tsv format) + held-out arith_eval_prompts.txt (greedy_sample format) + arith_eval_gold.txt; train deduped, eval disjoint from train. - data.rs: extract assistant-only masking into a pure, testable sft_row() (behavior-preserving; single-turn bit-identical to fbf4ac2). Gate (verified locally, no_cuda): cargo test -p xtrain-train --lib = 9/9 pass (masking, SFT-target self-consistency over 2000 samples, parser edges, seed determinism); a 200/50 gen run = clean 2-col TSV, correct gold incl. negatives, 0 train/eval leakage. SFT training run + format-eval pending on dash5. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-train/src/bin/gen_arith_task.rs | 94 ++++++++ crates/xtrain-train/src/data.rs | 45 +++- crates/xtrain-train/src/lib.rs | 1 + crates/xtrain-train/src/task.rs | 212 ++++++++++++++++++ docs/18-post-training-rl-sft.md | 44 ++++ 5 files changed, 392 insertions(+), 4 deletions(-) create mode 100644 crates/xtrain-train/src/bin/gen_arith_task.rs create mode 100644 crates/xtrain-train/src/task.rs diff --git a/crates/xtrain-train/src/bin/gen_arith_task.rs b/crates/xtrain-train/src/bin/gen_arith_task.rs new file mode 100644 index 0000000..81f8a85 --- /dev/null +++ b/crates/xtrain-train/src/bin/gen_arith_task.rs @@ -0,0 +1,94 @@ +//! Generate the M1 verifiable-arithmetic post-training dataset. Pure host tool (no +//! CUDA): writes +//! /arith_sft.tsv userassistant rows for `train --sft-tsv` +//! /arith_eval_prompts.txt greedy_sample `--prompts-file` format (held out) +//! /arith_eval_gold.txt parallel gold integers for the checker +//! +//! Eval problems are deduped against train (no leakage). The SFT rows carry just the +//! user/assistant content; `data::load_sft_tsv_cached` adds the `User:/Assistant:` +//! frame + `<|endoftext|>` and masks the prompt, so the eval prompt lines here +//! reconstruct exactly that frame (`User: \nAssistant:`, literal `\n` decoded by +//! greedy_sample). +//! +//! Example: +//! cargo run -p xtrain-train --release --bin gen_arith_task -- \ +//! --n 20000 --eval 500 --seed 1 --out-dir /dashscope-tmp/wjh/xtrain_post/arith + +use std::collections::HashSet; +use std::fs::{self, File}; +use std::io::{BufWriter, Write}; +use std::path::PathBuf; + +use xtrain_train::task::{GenConfig, Op, gen_problem}; + +fn flag(args: &[String], name: &str, default: T) -> T { + args.iter() + .position(|a| a == name) + .and_then(|i| args.get(i + 1)) + .and_then(|v| v.parse().ok()) + .unwrap_or(default) +} + +fn main() { + let args: Vec = std::env::args().collect(); + let n_train: usize = flag(&args, "--n", 20000); + let n_eval: usize = flag(&args, "--eval", 500); + let seed: u64 = flag(&args, "--seed", 1); + let max_add: i64 = flag(&args, "--max-add", 99); + let max_mul: i64 = flag(&args, "--max-mul", 12); + let out_dir: PathBuf = args + .iter() + .position(|a| a == "--out-dir") + .and_then(|i| args.get(i + 1)) + .map(PathBuf::from) + .expect("--out-dir is required"); + + fs::create_dir_all(&out_dir).expect("create out dir"); + let cfg = GenConfig { + max_add, + max_mul, + ops: vec![Op::Add, Op::Sub, Op::Mul], + }; + let mut rng = seed.max(1); + + // Train: dedup so the same problem is not repeated and so eval can be held out. + let mut train_keys = HashSet::new(); + let mut tsv = BufWriter::new(File::create(out_dir.join("arith_sft.tsv")).expect("create tsv")); + while train_keys.len() < n_train { + let p = gen_problem(&mut rng, &cfg); + if !train_keys.insert(p.key()) { + continue; + } + writeln!(tsv, "{}\t{}", p.question(), p.sft_answer()).expect("write tsv"); + } + tsv.flush().expect("flush tsv"); + + // Eval: disjoint from train (skip any key seen in train) and from itself. + let mut prompts = + BufWriter::new(File::create(out_dir.join("arith_eval_prompts.txt")).expect("create eval")); + let mut golds = + BufWriter::new(File::create(out_dir.join("arith_eval_gold.txt")).expect("create gold")); + writeln!(prompts, "# verifiable arithmetic eval prompts (held out from arith_sft.tsv)") + .expect("write header"); + let mut eval_keys = HashSet::new(); + while eval_keys.len() < n_eval { + let p = gen_problem(&mut rng, &cfg); + if train_keys.contains(&p.key()) || !eval_keys.insert(p.key()) { + continue; + } + writeln!(prompts, "User: {}\\nAssistant:", p.question()).expect("write prompt"); + writeln!(golds, "{}", p.answer()).expect("write gold"); + } + prompts.flush().expect("flush prompts"); + golds.flush().expect("flush golds"); + + println!( + "wrote {} train rows + {} eval prompts to {} (ops=+,-,* max_add={} max_mul={} seed={})", + train_keys.len(), + eval_keys.len(), + out_dir.display(), + max_add, + max_mul, + seed + ); +} diff --git a/crates/xtrain-train/src/data.rs b/crates/xtrain-train/src/data.rs index 12bdc83..38523c9 100644 --- a/crates/xtrain-train/src/data.rs +++ b/crates/xtrain-train/src/data.rs @@ -124,10 +124,9 @@ impl Corpus { let answer = format!(" {assistant}\n<|endoftext|>"); let prompt_ids: Vec = tok.encode(&prompt).into_iter().map(|t| t as i32).collect(); let answer_ids: Vec = tok.encode(&answer).into_iter().map(|t| t as i32).collect(); - labels.extend(std::iter::repeat(-100).take(prompt_ids.len())); - labels.extend(answer_ids.iter().copied()); - tokens.extend(prompt_ids); - tokens.extend(answer_ids); + let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids); + tokens.extend(row_tokens); + labels.extend(row_labels); } assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch"); write_u16_cache(&token_cache, &tokens); @@ -291,6 +290,20 @@ fn decode_tsv_escapes(s: &str) -> String { s.replace("\\n", "\n").replace("\\t", "\t") } +/// Build one SFT example's `(tokens, labels)` from already-tokenized prompt/answer +/// ids: prompt tokens are masked to the ignore-index (`-100`, which `cross_entropy` +/// skips) so only the answer + EOS tokens contribute to the loss. Pure (no tokenizer +/// / no CUDA) so the assistant-only masking is unit-testable directly. +fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec, Vec) { + let mut tokens = Vec::with_capacity(prompt_ids.len() + answer_ids.len()); + tokens.extend_from_slice(prompt_ids); + tokens.extend_from_slice(answer_ids); + let mut labels = Vec::with_capacity(prompt_ids.len() + answer_ids.len()); + labels.extend(std::iter::repeat(-100).take(prompt_ids.len())); + labels.extend_from_slice(answer_ids); + (tokens, labels) +} + /// Tiny LCG (same constants as the model tests' deterministic fill) so dataset /// sampling is reproducible from a single u64 seed. fn next_rand(state: &mut u64) -> u64 { @@ -299,3 +312,27 @@ fn next_rand(state: &mut u64) -> u64 { .wrapping_add(1442695040888963407); *state >> 16 } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sft_row_masks_prompt_supervises_answer() { + let prompt = [5, 6, 7]; + let answer = [8, 9]; // includes the EOS token in real use + let (tokens, labels) = sft_row(&prompt, &answer); + // Tokens are prompt then answer, in order. + assert_eq!(tokens, vec![5, 6, 7, 8, 9]); + // Prompt positions are ignore-index (-100); answer positions are supervised. + assert_eq!(labels, vec![-100, -100, -100, 8, 9]); + assert_eq!(tokens.len(), labels.len()); + } + + #[test] + fn sft_row_handles_empty_answer() { + let (tokens, labels) = sft_row(&[1, 2], &[]); + assert_eq!(tokens, vec![1, 2]); + assert_eq!(labels, vec![-100, -100]); + } +} diff --git a/crates/xtrain-train/src/lib.rs b/crates/xtrain-train/src/lib.rs index 2c00d1d..58e9f85 100644 --- a/crates/xtrain-train/src/lib.rs +++ b/crates/xtrain-train/src/lib.rs @@ -10,6 +10,7 @@ pub mod clip; pub mod data; pub mod schedule; +pub mod task; #[cfg(not(no_cuda))] pub mod checkpoint; diff --git a/crates/xtrain-train/src/task.rs b/crates/xtrain-train/src/task.rs new file mode 100644 index 0000000..b1efb6b --- /dev/null +++ b/crates/xtrain-train/src/task.rs @@ -0,0 +1,212 @@ +//! Verifiable arithmetic task (post-training, M1). A tiny two-operand integer +//! arithmetic task with a deterministic, rule-based checker: the assistant must end +//! its answer with `\boxed{N}`, and the reward is exact-match on `N`. +//! +//! This single module is the shared task spec for the whole post-training stack — +//! M1 SFT-data generation, M3 DPO preference-pair construction, and M4 GRPO reward +//! scoring all parse/score through here, so the task lives in exactly one place. +//! +//! Host-only (no CUDA): generation + parsing + checking are pure, so this compiles +//! and unit-tests on a GPU-less host. + +use std::fmt; + +/// The supported binary operations. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Op { + Add, + Sub, + Mul, +} + +impl Op { + pub fn symbol(self) -> char { + match self { + Op::Add => '+', + Op::Sub => '-', + Op::Mul => '*', + } + } +} + +impl fmt::Display for Op { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.symbol()) + } +} + +/// A single two-operand arithmetic problem. +#[derive(Clone, Copy, Debug)] +pub struct Problem { + pub a: i64, + pub b: i64, + pub op: Op, +} + +impl Problem { + /// The exact integer answer (the verifiable gold label). + pub fn answer(self) -> i64 { + match self.op { + Op::Add => self.a + self.b, + Op::Sub => self.a - self.b, + Op::Mul => self.a * self.b, + } + } + + /// The user-turn question text. No template wrapping — the SFT loader + /// (`data::load_sft_tsv_cached`) adds the `User:/Assistant:` frame. + pub fn question(self) -> String { + format!("What is {} {} {}?", self.a, self.op, self.b) + } + + /// The assistant-turn SFT target: restate the equation and end with the boxed + /// answer. This teaches the answer FORMAT (the checker only reads `\boxed{}`); + /// arithmetic correctness is what DPO (M3) / GRPO (M4) later improve. + pub fn sft_answer(self) -> String { + format!("{} {} {} = \\boxed{{{}}}.", self.a, self.op, self.b, self.answer()) + } + + /// A stable dedup key, so eval problems can be held out from train. + pub fn key(self) -> (i64, char, i64) { + (self.a, self.op.symbol(), self.b) + } +} + +/// Operand-range configuration for problem sampling. Multiplication uses a smaller +/// range (`max_mul`) so products stay modest; add/sub use `max_add`. +#[derive(Clone)] +pub struct GenConfig { + pub max_add: i64, + pub max_mul: i64, + pub ops: Vec, +} + +impl Default for GenConfig { + fn default() -> Self { + Self { + max_add: 99, + max_mul: 12, + ops: vec![Op::Add, Op::Sub, Op::Mul], + } + } +} + +/// Sample one problem deterministically from the LCG state `rng`. Operands are drawn +/// in `[0, max]` per the op; subtraction may yield a negative answer (the checker / +/// parser handle a leading `-`). +pub fn gen_problem(rng: &mut u64, cfg: &GenConfig) -> Problem { + let op = cfg.ops[(next_rand(rng) as usize) % cfg.ops.len()]; + let max = if op == Op::Mul { cfg.max_mul } else { cfg.max_add }; + let a = rand_range(rng, max); + let b = rand_range(rng, max); + Problem { a, b, op } +} + +/// Parse the integer inside the LAST `\boxed{...}` in `text`. Returns `None` if there +/// is no well-formed boxed integer (no box, empty, or non-integer contents). "Last" +/// so a model that emits intermediate boxes still scores on its final answer. +pub fn parse_boxed_answer(text: &str) -> Option { + const TAG: &str = "\\boxed{"; + let mut found = None; + let mut rest = text; + while let Some(i) = rest.find(TAG) { + let after = &rest[i + TAG.len()..]; + match after.find('}') { + Some(j) => { + if let Ok(n) = after[..j].trim().parse::() { + found = Some(n); + } + rest = &after[j + 1..]; + } + None => break, + } + } + found +} + +/// Verifiable reward: does the completion's boxed answer exactly match `gold`? +pub fn check_answer(completion: &str, gold: i64) -> bool { + parse_boxed_answer(completion) == Some(gold) +} + +/// `[0, max]` inclusive draw from the LCG. +fn rand_range(rng: &mut u64, max: i64) -> i64 { + debug_assert!(max >= 0); + (next_rand(rng) % (max as u64 + 1)) as i64 +} + +/// Same LCG constants as the dataset sampler (`data::next_rand`), kept local so the +/// task module stays dependency-free and host-only. +fn next_rand(state: &mut u64) -> u64 { + *state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + *state >> 1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn answer_question_and_sft_target() { + let p = Problem { + a: 12, + b: 13, + op: Op::Mul, + }; + assert_eq!(p.answer(), 156); + assert_eq!(p.question(), "What is 12 * 13?"); + assert_eq!(p.sft_answer(), "12 * 13 = \\boxed{156}."); + let s = Problem { + a: 3, + b: 8, + op: Op::Sub, + }; + assert_eq!(s.answer(), -5); + } + + #[test] + fn parse_takes_last_boxed_and_handles_edges() { + assert_eq!(parse_boxed_answer("\\boxed{3} then \\boxed{156}."), Some(156)); + assert_eq!(parse_boxed_answer("\\boxed{-7}"), Some(-7)); + assert_eq!(parse_boxed_answer("\\boxed{ 42 }"), Some(42)); + assert_eq!(parse_boxed_answer("no box here"), None); + assert_eq!(parse_boxed_answer("\\boxed{abc}"), None); + assert_eq!(parse_boxed_answer("\\boxed{unterminated"), None); + } + + #[test] + fn check_is_exact_match() { + assert!(check_answer("the result is \\boxed{156}.", 156)); + assert!(!check_answer("the result is \\boxed{155}.", 156)); + assert!(!check_answer("no boxed answer at all", 156)); + } + + #[test] + fn sft_target_is_always_self_consistent() { + // The SFT target's boxed answer must always check against the problem's own + // gold — across all ops/operands. This is the M1 data invariant. + let cfg = GenConfig::default(); + let mut rng = 12345u64; + for _ in 0..2000 { + let p = gen_problem(&mut rng, &cfg); + assert!( + check_answer(&p.sft_answer(), p.answer()), + "self-inconsistent SFT target for {p:?}" + ); + } + } + + #[test] + fn generation_is_deterministic_from_seed() { + let cfg = GenConfig::default(); + let (mut r1, mut r2) = (7u64, 7u64); + for _ in 0..200 { + assert_eq!( + gen_problem(&mut r1, &cfg).key(), + gen_problem(&mut r2, &cfg).key() + ); + } + } +} diff --git a/docs/18-post-training-rl-sft.md b/docs/18-post-training-rl-sft.md index 4fddd11..f3b60a2 100644 --- a/docs/18-post-training-rl-sft.md +++ b/docs/18-post-training-rl-sft.md @@ -282,3 +282,47 @@ gates as flash/GQA. > Each milestone is one design+gate cycle; results get appended here (like the run docs) and a > row in `docs/evolution.md` (algorithm/infra dimensions) when it lands. + +## Implementation log + +### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box) + +The verifiable task and its data pipeline are implemented and verified host-side (no CUDA +needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and +runs on dash5. + +**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):** + +- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,99]` for `+/−`, `[0,12]` for `×` + (modest products); subtraction may be negative. +- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT; + the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve. +- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer` + (exact match vs. gold). This is the single shared checker reused by M3 (pair construction) + and M4 (GRPO reward). +- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly + probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`). + +**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):** +writes `arith_sft.tsv` (`userassistant` for `--sft-tsv`), `arith_eval_prompts.txt` +(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints). +Train rows are deduped; eval is held out from train (no leakage). + +**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was +extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the +single-turn path is bit-identical to `fbf4ac2`). + +**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass, +including `sft_row` masks prompt→`-100` / supervises answer, the SFT-target self-consistency +invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism. +A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0 +train/eval leakage. + +**Pending on dash5 (needs GPU):** +1. generate the dataset: `gen_arith_task --n --eval --seed 1 --out-dir `; +2. SFT from the v12 base: `train /arith_sft.tsv --sft-tsv --init-ckpt + --bf16 --recompute --flash --ckpt ` → the P0 reference/init checkpoint; +3. format eval: `greedy_sample --prompts-file /arith_eval_prompts.txt`, + then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a + one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only + buys the format; correctness is M3/M4's job.