post-train: M1 — verifiable arithmetic task + SFT data generator

First post-training milestone (docs/18). Lands the verifiable task + its data
pipeline, all verified host-side (no CUDA); the SFT run itself reuses the
existing --sft-tsv path on the GPU box.

- task.rs: the shared task spec — two-operand integer arithmetic, answer in
  \boxed{N}, with parse_boxed_answer + check_answer (exact-match rule-based
  reward). One module reused by M1 (SFT data), M3 (DPO pairs), M4 (GRPO reward).
- gen_arith_task bin: writes arith_sft.tsv (--sft-tsv format) + held-out
  arith_eval_prompts.txt (greedy_sample format) + arith_eval_gold.txt; train
  deduped, eval disjoint from train.
- data.rs: extract assistant-only masking into a pure, testable sft_row()
  (behavior-preserving; single-turn bit-identical to fbf4ac2).

Gate (verified locally, no_cuda): cargo test -p xtrain-train --lib = 9/9 pass
(masking, SFT-target self-consistency over 2000 samples, parser edges, seed
determinism); a 200/50 gen run = clean 2-col TSV, correct gold incl. negatives,
0 train/eval leakage. SFT training run + format-eval pending on dash5.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 22:52:25 +08:00
parent ab32168dcc
commit 9c70e99ae4
5 changed files with 392 additions and 4 deletions

@@ -0,0 +1,94 @@
//! Generate the M1 verifiable-arithmetic post-training dataset. Pure host tool (no
//! CUDA): writes
//! <out>/arith_sft.tsv user<TAB>assistant rows for `train --sft-tsv`
//! <out>/arith_eval_prompts.txt greedy_sample `--prompts-file` format (held out)
//! <out>/arith_eval_gold.txt parallel gold integers for the checker
//!
//! Eval problems are deduped against train (no leakage). The SFT rows carry just the
//! user/assistant content; `data::load_sft_tsv_cached` adds the `User:/Assistant:`
//! frame + `<|endoftext|>` and masks the prompt, so the eval prompt lines here
//! reconstruct exactly that frame (`User: <q>\nAssistant:`, literal `\n` decoded by
//! greedy_sample).
//!
//! Example:
//! cargo run -p xtrain-train --release --bin gen_arith_task -- \
//! --n 20000 --eval 500 --seed 1 --out-dir /dashscope-tmp/wjh/xtrain_post/arith
use std::collections::HashSet;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use xtrain_train::task::{GenConfig, Op, gen_problem};
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|v| v.parse().ok())
.unwrap_or(default)
}
fn main() {
let args: Vec<String> = std::env::args().collect();
let n_train: usize = flag(&args, "--n", 20000);
let n_eval: usize = flag(&args, "--eval", 500);
let seed: u64 = flag(&args, "--seed", 1);
let max_add: i64 = flag(&args, "--max-add", 99);
let max_mul: i64 = flag(&args, "--max-mul", 12);
let out_dir: PathBuf = args
.iter()
.position(|a| a == "--out-dir")
.and_then(|i| args.get(i + 1))
.map(PathBuf::from)
.expect("--out-dir <dir> is required");
fs::create_dir_all(&out_dir).expect("create out dir");
let cfg = GenConfig {
max_add,
max_mul,
ops: vec![Op::Add, Op::Sub, Op::Mul],
};
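// Fail fast when more distinct problems are requested than the operand ranges can
// supply: both dedup loops below would otherwise never terminate. Each of add/sub
// contributes (max_add + 1)^2 keys and mul contributes (max_mul + 1)^2.
let capacity = 2 * (max_add as i128 + 1).pow(2) + (max_mul as i128 + 1).pow(2);
assert!(
(n_train as i128 + n_eval as i128) <= capacity,
"--n + --eval = {} exceeds the {capacity} distinct problems for these ranges",
n_train + n_eval
);
let mut rng = seed.max(1);
// Train: dedup so the same problem is not repeated and so eval can be held out.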
let mut train_keys = HashSet::new();
let mut tsv = BufWriter::new(File::create(out_dir.join("arith_sft.tsv")).expect("create tsv"));
while train_keys.len() < n_train {
let p = gen_problem(&mut rng, &cfg);
if !train_keys.insert(p.key()) {
continue;
}
writeln!(tsv, "{}\t{}", p.question(), p.sft_answer()).expect("write tsv");
}
tsv.flush().expect("flush tsv");
// Eval: disjoint from train (skip any key seen in train) and from itself.
let mut prompts =
BufWriter::new(File::create(out_dir.join("arith_eval_prompts.txt")).expect("create eval"));
let mut golds =
BufWriter::new(File::create(out_dir.join("arith_eval_gold.txt")).expect("create gold"));
writeln!(prompts, "# verifiable arithmetic eval prompts (held out from arith_sft.tsv)")
.expect("write header");
let mut eval_keys = HashSet::new();
while eval_keys.len() < n_eval {
let p = gen_problem(&mut rng, &cfg);
if train_keys.contains(&p.key()) || !eval_keys.insert(p.key()) {
continue;
}
writeln!(prompts, "User: {}\\nAssistant:", p.question()).expect("write prompt");
writeln!(golds, "{}", p.answer()).expect("write gold");
}
prompts.flush().expect("flush prompts");
golds.flush().expect("flush golds");
println!(
"wrote {} train rows + {} eval prompts to {} (ops=+,-,* max_add={} max_mul={} seed={})",
train_keys.len(),
eval_keys.len(),
out_dir.display(),
max_add,
max_mul,
seed
);
}

@@ -124,10 +124,9 @@ impl Corpus {
let answer = format!(" {assistant}\n<|endoftext|>");
let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
-labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
-labels.extend(answer_ids.iter().copied());
-tokens.extend(prompt_ids);
-tokens.extend(answer_ids);
+let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids);
+tokens.extend(row_tokens);
+labels.extend(row_labels);
}
assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
write_u16_cache(&token_cache, &tokens);
@@ -291,6 +290,20 @@ fn decode_tsv_escapes(s: &str) -> String {
s.replace("\\n", "\n").replace("\\t", "\t")
}
/// Build one SFT example's `(tokens, labels)` from already-tokenized prompt/answer
/// ids: prompt tokens are masked to the ignore-index (`-100`, which `cross_entropy`
/// skips) so only the answer + EOS tokens contribute to the loss. Pure (no tokenizer
/// / no CUDA) so the assistant-only masking is unit-testable directly.
fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
let mut tokens = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
tokens.extend_from_slice(prompt_ids);
tokens.extend_from_slice(answer_ids);
let mut labels = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
labels.extend_from_slice(answer_ids);
(tokens, labels)
}
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
/// sampling is reproducible from a single u64 seed.
fn next_rand(state: &mut u64) -> u64 {
@@ -299,3 +312,27 @@ fn next_rand(state: &mut u64) -> u64 {
.wrapping_add(1442695040888963407);
*state >> 16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sft_row_masks_prompt_supervises_answer() {
let prompt = [5, 6, 7];
let answer = [8, 9]; // includes the EOS token in real use
let (tokens, labels) = sft_row(&prompt, &answer);
// Tokens are prompt then answer, in order.
assert_eq!(tokens, vec![5, 6, 7, 8, 9]);
// Prompt positions are ignore-index (-100); answer positions are supervised.
assert_eq!(labels, vec![-100, -100, -100, 8, 9]);
assert_eq!(tokens.len(), labels.len());
}
#[test]
fn sft_row_handles_empty_answer() {
let (tokens, labels) = sft_row(&[1, 2], &[]);
assert_eq!(tokens, vec![1, 2]);
assert_eq!(labels, vec![-100, -100]);
}
}

@@ -10,6 +10,7 @@
pub mod clip;
pub mod data;
pub mod schedule;
pub mod task;
#[cfg(not(no_cuda))]
pub mod checkpoint;

@@ -0,0 +1,212 @@
//! Verifiable arithmetic task (post-training, M1). A tiny two-operand integer
//! arithmetic task with a deterministic, rule-based checker: the assistant must end
//! its answer with `\boxed{N}`, and the reward is exact-match on `N`.
//!
//! This single module is the shared task spec for the whole post-training stack —
//! M1 SFT-data generation, M3 DPO preference-pair construction, and M4 GRPO reward
//! scoring all parse/score through here, so the task lives in exactly one place.
//!
//! Host-only (no CUDA): generation + parsing + checking are pure, so this compiles
//! and unit-tests on a GPU-less host.
use std::fmt;
/// The supported binary operations.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Op {
Add,
Sub,
Mul,
}
impl Op {
pub fn symbol(self) -> char {
match self {
Op::Add => '+',
Op::Sub => '-',
Op::Mul => '*',
}
}
}
impl fmt::Display for Op {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.symbol())
}
}
/// A single two-operand arithmetic problem.
#[derive(Clone, Copy, Debug)]
pub struct Problem {
pub a: i64,
pub b: i64,
pub op: Op,
}
impl Problem {
/// The exact integer answer (the verifiable gold label).
pub fn answer(self) -> i64 {
match self.op {
Op::Add => self.a + self.b,
Op::Sub => self.a - self.b,
Op::Mul => self.a * self.b,
}
}
/// The user-turn question text. No template wrapping — the SFT loader
/// (`data::load_sft_tsv_cached`) adds the `User:/Assistant:` frame.
pub fn question(self) -> String {
format!("What is {} {} {}?", self.a, self.op, self.b)
}
/// The assistant-turn SFT target: restate the equation and end with the boxed
/// answer. This teaches the answer FORMAT (the checker only reads `\boxed{}`);
/// arithmetic correctness is what DPO (M3) / GRPO (M4) later improve.
pub fn sft_answer(self) -> String {
format!("{} {} {} = \\boxed{{{}}}.", self.a, self.op, self.b, self.answer())
}
/// A stable dedup key, so eval problems can be held out from train.
pub fn key(self) -> (i64, char, i64) {
(self.a, self.op.symbol(), self.b)
}
}
/// Operand-range configuration for problem sampling. Multiplication uses a smaller
/// range (`max_mul`) so products stay modest; add/sub use `max_add`.
#[derive(Clone)]
pub struct GenConfig {
pub max_add: i64,
pub max_mul: i64,
pub ops: Vec<Op>,
}
impl Default for GenConfig {
fn default() -> Self {
Self {
max_add: 99,
max_mul: 12,
ops: vec![Op::Add, Op::Sub, Op::Mul],
}
}
}
/// Sample one problem deterministically from the LCG state `rng`. Operands are drawn
/// in `[0, max]` per the op; subtraction may yield a negative answer (the checker /
/// parser handle a leading `-`).
pub fn gen_problem(rng: &mut u64, cfg: &GenConfig) -> Problem {
let op = cfg.ops[(next_rand(rng) as usize) % cfg.ops.len()];
let max = if op == Op::Mul { cfg.max_mul } else { cfg.max_add };
let a = rand_range(rng, max);
let b = rand_range(rng, max);
Problem { a, b, op }
}
/// Parse the integer inside the LAST `\boxed{...}` in `text`. Returns `None` if there
/// is no well-formed boxed integer (no box, empty, or non-integer contents). "Last"
/// so a model that emits intermediate boxes still scores on its final answer.
pub fn parse_boxed_answer(text: &str) -> Option<i64> {
const TAG: &str = "\\boxed{";
let mut found = None;
let mut rest = text;
while let Some(i) = rest.find(TAG) {
let after = &rest[i + TAG.len()..];
match after.find('}') {
Some(j) => {
if let Ok(n) = after[..j].trim().parse::<i64>() {
found = Some(n);
}
rest = &after[j + 1..];
}
None => break,
}
}
found
}
/// Verifiable reward: does the completion's boxed answer exactly match `gold`?
pub fn check_answer(completion: &str, gold: i64) -> bool {
parse_boxed_answer(completion) == Some(gold)
}
/// `[0, max]` inclusive draw from the LCG.
fn rand_range(rng: &mut u64, max: i64) -> i64 {
debug_assert!(max >= 0);
(next_rand(rng) % (max as u64 + 1)) as i64
}
/// Same LCG constants as the dataset sampler (`data::next_rand`), kept local so the
/// task module stays dependency-free and host-only.
fn next_rand(state: &mut u64) -> u64 {
*state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
// Same shift as `data::next_rand`: drops the short-period low bits of the LCG state.
*state >> 16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn answer_question_and_sft_target() {
let p = Problem {
a: 12,
b: 13,
op: Op::Mul,
};
assert_eq!(p.answer(), 156);
assert_eq!(p.question(), "What is 12 * 13?");
assert_eq!(p.sft_answer(), "12 * 13 = \\boxed{156}.");
let s = Problem {
a: 3,
b: 8,
op: Op::Sub,
};
assert_eq!(s.answer(), -5);
}
#[test]
fn parse_takes_last_boxed_and_handles_edges() {
assert_eq!(parse_boxed_answer("\\boxed{3} then \\boxed{156}."), Some(156));
assert_eq!(parse_boxed_answer("\\boxed{-7}"), Some(-7));
assert_eq!(parse_boxed_answer("\\boxed{ 42 }"), Some(42));
assert_eq!(parse_boxed_answer("no box here"), None);
assert_eq!(parse_boxed_answer("\\boxed{abc}"), None);
assert_eq!(parse_boxed_answer("\\boxed{unterminated"), None);
}
#[test]
fn check_is_exact_match() {
assert!(check_answer("the result is \\boxed{156}.", 156));
assert!(!check_answer("the result is \\boxed{155}.", 156));
assert!(!check_answer("no boxed answer at all", 156));
}
#[test]
fn sft_target_is_always_self_consistent() {
// The SFT target's boxed answer must always check against the problem's own
// gold — across all ops/operands. This is the M1 data invariant.
let cfg = GenConfig::default();
let mut rng = 12345u64;
for _ in 0..2000 {
let p = gen_problem(&mut rng, &cfg);
assert!(
check_answer(&p.sft_answer(), p.answer()),
"self-inconsistent SFT target for {p:?}"
);
}
}
#[test]
fn generation_is_deterministic_from_seed() {
let cfg = GenConfig::default();
let (mut r1, mut r2) = (7u64, 7u64);
for _ in 0..200 {
assert_eq!(
gen_problem(&mut r1, &cfg).key(),
gen_problem(&mut r2, &cfg).key()
);
}
}
}

@@ -282,3 +282,47 @@ gates as flash/GQA.
> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.
## Implementation log
### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box)
The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and
runs on dash5.
**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
- Two-operand integer arithmetic, ops `+ - ×`; operands `[0,99]` for `+/-`, `[0,12]` for `×`
(modest products); subtraction results may be negative.
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
(exact match vs. gold). This is the single shared checker reused by M3 (pair construction)
and M4 (GRPO reward).
- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly
probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`).
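
Because train and eval are deduplicated on the `(a, op, b)` key, the operand ranges above bound how many distinct problems exist. The sketch below counts them; `distinct_problems` is a hypothetical helper (not part of `task.rs`) and assumes all three ops are enabled, as in the generator's config.

```rust
/// Hypothetical helper: number of distinct `(a, op, b)` keys when operands are
/// drawn from `[0, max]` inclusive. Add and sub share `max_add`; mul uses `max_mul`.
fn distinct_problems(max_add: u64, max_mul: u64) -> u64 {
    let add_or_sub = (max_add + 1) * (max_add + 1);
    let mul = (max_mul + 1) * (max_mul + 1);
    2 * add_or_sub + mul
}

fn main() {
    // Ranges from the spec: [0,99] for + and -, [0,12] for ×.
    let total = distinct_problems(99, 12);
    println!("distinct problems: {total}"); // 2 * 10_000 + 169 = 20_169
}
```

Under these ranges a deduplicating generator can supply at most 20,169 rows across train and eval combined, so `--n` plus `--eval` has to stay within that bound or the ranges have to be widened.
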
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
Train rows are deduped; eval is held out from train (no leakage).
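
As an illustration of the three file formats, the sketch below formats one problem the way the generator's `writeln!` calls do. The function names `sft_tsv_row` and `eval_prompt_line` are hypothetical; the string layouts follow `Problem::question`, `Problem::sft_answer`, and the eval-prompt line in `gen_arith_task.rs`.

```rust
/// One SFT row: `user<TAB>assistant`, with no chat frame (the loader adds it).
fn sft_tsv_row(a: i64, op: char, b: i64, answer: i64) -> String {
    format!("What is {a} {op} {b}?\t{a} {op} {b} = \\boxed{{{answer}}}.")
}

/// One eval prompt line: the `User:/Assistant:` frame with a literal `\n`
/// (backslash followed by `n`), left for the sampler to decode.
fn eval_prompt_line(a: i64, op: char, b: i64) -> String {
    format!("User: What is {a} {op} {b}?\\nAssistant:")
}

fn main() {
    println!("{}", sft_tsv_row(12, '*', 13, 156)); // a row of arith_sft.tsv
    println!("{}", eval_prompt_line(12, '*', 13)); // a line of arith_eval_prompts.txt
    println!("{}", 12 * 13); // the parallel line of arith_eval_gold.txt
}
```
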
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
single-turn path is bit-identical to `fbf4ac2`).
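
To show what the `-100` labels buy, here is a minimal sketch of a mean loss that skips ignore-index positions. `sft_row` is copied from `data.rs`; `masked_mean` and the per-token loss values are hypothetical stand-ins, since the real `cross_entropy` is not part of this commit, and any label shifting done by the trainer is ignored here.

```rust
const IGNORE_INDEX: i32 = -100;

/// Copy of the helper from `data.rs`: prompt positions are masked, answer
/// positions keep their token ids as labels.
fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
    let mut tokens = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
    tokens.extend_from_slice(prompt_ids);
    tokens.extend_from_slice(answer_ids);
    let mut labels = Vec::with_capacity(tokens.len());
    labels.extend(std::iter::repeat(IGNORE_INDEX).take(prompt_ids.len()));
    labels.extend_from_slice(answer_ids);
    (tokens, labels)
}

/// Hypothetical reduction: average per-token losses over supervised positions
/// only. Returns `None` when every position is masked.
fn masked_mean(per_token_loss: &[f32], labels: &[i32]) -> Option<f32> {
    let (sum, n) = per_token_loss
        .iter()
        .zip(labels)
        .filter(|(_, label)| **label != IGNORE_INDEX)
        .fold((0.0f32, 0usize), |(s, n), (x, _)| (s + *x, n + 1));
    (n > 0).then(|| sum / n as f32)
}

fn main() {
    let (_tokens, labels) = sft_row(&[5, 6, 7], &[8, 9]);
    // Made-up per-token losses: the three prompt positions do not count.
    let losses = [1.0, 1.0, 1.0, 2.0, 4.0];
    println!("{:?}", masked_mean(&losses, &labels));
}
```
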
**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass,
including `sft_row` masks prompt→`-100` / supervises answer, the SFT-target self-consistency
invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism.
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
train/eval leakage.
**Pending on dash5 (needs GPU):**
1. generate the dataset: `gen_arith_task --n <N> --eval <M> --seed 1 --out-dir <dir>`;
2. SFT from the v12 base: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt>
<arch flags> --bf16 --recompute --flash --ckpt <out.ckpt>` → the P0 reference/init checkpoint;
3. format eval: `greedy_sample <out.ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt`,
then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a
one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only
buys the format; correctness is M3/M4's job.
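
The one-off scorer from step 3 could look roughly like the sketch below. It assumes completions have already been collected one per eval prompt, in the same order as `arith_eval_gold.txt`; the output format of `greedy_sample` is not shown here, so that pairing step and the `pass_rate` function are hypothetical. The parser is a local copy of `task::parse_boxed_answer` so the example stands alone.

```rust
/// Local copy of the parser from `task.rs`: integer inside the LAST `\boxed{...}`.
fn parse_boxed_answer(text: &str) -> Option<i64> {
    const TAG: &str = "\\boxed{";
    let mut found = None;
    let mut rest = text;
    while let Some(i) = rest.find(TAG) {
        let after = &rest[i + TAG.len()..];
        match after.find('}') {
            Some(j) => {
                if let Ok(n) = after[..j].trim().parse::<i64>() {
                    found = Some(n);
                }
                rest = &after[j + 1..];
            }
            None => break,
        }
    }
    found
}

/// Hypothetical scorer: fraction of completions whose boxed answer equals the
/// gold integer at the same index. Returns 0.0 for an empty eval set.
fn pass_rate(completions: &[&str], golds: &[i64]) -> f64 {
    assert_eq!(completions.len(), golds.len(), "completions/gold length mismatch");
    if golds.is_empty() {
        return 0.0;
    }
    let passed = completions
        .iter()
        .zip(golds)
        .filter(|(c, g)| parse_boxed_answer(c) == Some(**g))
        .count();
    passed as f64 / golds.len() as f64
}

fn main() {
    // Made-up completions: one correct, one wrong sign, one with no box.
    let completions = ["12 * 13 = \\boxed{156}.", "3 - 8 = \\boxed{5}.", "the answer is 7"];
    let golds: [i64; 3] = [156, -5, 7];
    println!("pass rate: {:.3}", pass_rate(&completions, &golds));
}
```

A completion with the right number but no `\boxed{}` scores as a miss, which is the format signal M1 is meant to measure.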