post-train: M1 — verifiable arithmetic task + SFT data generator

First post-training milestone (docs/18). Lands the verifiable task + its data
pipeline, all verified host-side (no CUDA); the SFT run itself reuses the
existing --sft-tsv path on the GPU box.

- task.rs: the shared task spec — two-operand integer arithmetic, answer in
  \boxed{N}, with parse_boxed_answer + check_answer (exact-match rule-based
  reward). One module reused by M1 (SFT data), M3 (DPO pairs), M4 (GRPO reward).
- gen_arith_task bin: writes arith_sft.tsv (--sft-tsv format) + held-out
  arith_eval_prompts.txt (greedy_sample format) + arith_eval_gold.txt; train
  deduped, eval disjoint from train.
- data.rs: extract assistant-only masking into a pure, testable sft_row()
  (behavior-preserving; single-turn bit-identical to fbf4ac2).

Gate (verified locally, no_cuda): cargo test -p xtrain-train --lib = 9/9 pass
(masking, SFT-target self-consistency over 2000 samples, parser edges, seed
determinism); a 200/50 gen run = clean 2-col TSV, correct gold incl. negatives,
0 train/eval leakage. SFT training run + format-eval pending on dash5.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
94
crates/xtrain-train/src/bin/gen_arith_task.rs
Normal file
@@ -0,0 +1,94 @@
+//! Generate the M1 verifiable-arithmetic post-training dataset. Pure host tool (no
+//! CUDA): writes
+//!   <out>/arith_sft.tsv           user<TAB>assistant rows for `train --sft-tsv`
+//!   <out>/arith_eval_prompts.txt  greedy_sample `--prompts-file` format (held out)
+//!   <out>/arith_eval_gold.txt     parallel gold integers for the checker
+//!
+//! Eval problems are deduped against train (no leakage). The SFT rows carry just the
+//! user/assistant content; `data::load_sft_tsv_cached` adds the `User:/Assistant:`
+//! frame + `<|endoftext|>` and masks the prompt, so the eval prompt lines here
+//! reconstruct exactly that frame (`User: <q>\nAssistant:`, literal `\n` decoded by
+//! greedy_sample).
+//!
+//! Example:
+//!   cargo run -p xtrain-train --release --bin gen_arith_task -- \
+//!     --n 20000 --eval 500 --seed 1 --out-dir /dashscope-tmp/wjh/xtrain_post/arith
+
+use std::collections::HashSet;
+use std::fs::{self, File};
+use std::io::{BufWriter, Write};
+use std::path::PathBuf;
+
+use xtrain_train::task::{GenConfig, Op, gen_problem};
+
+fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
+    args.iter()
+        .position(|a| a == name)
+        .and_then(|i| args.get(i + 1))
+        .and_then(|v| v.parse().ok())
+        .unwrap_or(default)
+}
+
+fn main() {
+    let args: Vec<String> = std::env::args().collect();
+    let n_train: usize = flag(&args, "--n", 20000);
+    let n_eval: usize = flag(&args, "--eval", 500);
+    let seed: u64 = flag(&args, "--seed", 1);
+    let max_add: i64 = flag(&args, "--max-add", 99);
+    let max_mul: i64 = flag(&args, "--max-mul", 12);
+    let out_dir: PathBuf = args
+        .iter()
+        .position(|a| a == "--out-dir")
+        .and_then(|i| args.get(i + 1))
+        .map(PathBuf::from)
+        .expect("--out-dir <dir> is required");
+
+    fs::create_dir_all(&out_dir).expect("create out dir");
+    let cfg = GenConfig {
+        max_add,
+        max_mul,
+        ops: vec![Op::Add, Op::Sub, Op::Mul],
+    };
+    let mut rng = seed.max(1);
+
+    // Train: dedup so the same problem is not repeated and so eval can be held out.
+    let mut train_keys = HashSet::new();
+    let mut tsv = BufWriter::new(File::create(out_dir.join("arith_sft.tsv")).expect("create tsv"));
+    while train_keys.len() < n_train {
+        let p = gen_problem(&mut rng, &cfg);
+        if !train_keys.insert(p.key()) {
+            continue;
+        }
+        writeln!(tsv, "{}\t{}", p.question(), p.sft_answer()).expect("write tsv");
+    }
+    tsv.flush().expect("flush tsv");
+
+    // Eval: disjoint from train (skip any key seen in train) and from itself.
+    let mut prompts =
+        BufWriter::new(File::create(out_dir.join("arith_eval_prompts.txt")).expect("create eval"));
+    let mut golds =
+        BufWriter::new(File::create(out_dir.join("arith_eval_gold.txt")).expect("create gold"));
+    writeln!(prompts, "# verifiable arithmetic eval prompts (held out from arith_sft.tsv)")
+        .expect("write header");
+    let mut eval_keys = HashSet::new();
+    while eval_keys.len() < n_eval {
+        let p = gen_problem(&mut rng, &cfg);
+        if train_keys.contains(&p.key()) || !eval_keys.insert(p.key()) {
+            continue;
+        }
+        writeln!(prompts, "User: {}\\nAssistant:", p.question()).expect("write prompt");
+        writeln!(golds, "{}", p.answer()).expect("write gold");
+    }
+    prompts.flush().expect("flush prompts");
+    golds.flush().expect("flush golds");
+
+    println!(
+        "wrote {} train rows + {} eval prompts to {} (ops=+,-,* max_add={} max_mul={} seed={})",
+        train_keys.len(),
+        eval_keys.len(),
+        out_dir.display(),
+        max_add,
+        max_mul,
+        seed
+    );
+}
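A note on the two dedup loops in the generator: each loop can only terminate if enough distinct `(a, op, b)` keys exist. The sketch below is a hypothetical pre-flight check, not part of this commit. `key_space` and `fits` are illustrative names, and the sketch assumes the entries of `ops` are distinct and that both operands are drawn from `[0, max]` inclusive, as the `gen_problem` doc comment states.

```rust
/// Mirror of `task::Op`, redeclared so this sketch compiles on its own.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Op {
    Add,
    Sub,
    Mul,
}

/// Hypothetical helper (not in the commit): number of distinct `(a, op, b)` keys
/// when both operands come from `[0, max]` inclusive and `ops` has no repeats.
pub fn key_space(max_add: u64, max_mul: u64, ops: &[Op]) -> u64 {
    ops.iter()
        .map(|&op| {
            // Multiplication uses its own, smaller operand range.
            let max = if op == Op::Mul { max_mul } else { max_add };
            (max + 1) * (max + 1)
        })
        .sum()
}

/// Hypothetical guard: can `n_train + n_eval` unique problems exist at all?
pub fn fits(n_train: u64, n_eval: u64, capacity: u64) -> bool {
    n_train
        .checked_add(n_eval)
        .map_or(false, |need| need <= capacity)
}
```

If this reading of the code is right, the default config (`max_add = 99`, `max_mul = 12`, all three ops) has 2 * 100^2 + 13^2 = 20169 distinct problems. A request such as `--n 20000 --eval 500` would then need more unique keys than exist, and the eval loop could not finish, while the 200/50 gate run is far below the limit.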
@@ -124,10 +124,9 @@ impl Corpus {
             let answer = format!(" {assistant}\n<|endoftext|>");
             let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
             let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
-            labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
-            labels.extend(answer_ids.iter().copied());
-            tokens.extend(prompt_ids);
-            tokens.extend(answer_ids);
+            let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids);
+            tokens.extend(row_tokens);
+            labels.extend(row_labels);
         }
         assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
         write_u16_cache(&token_cache, &tokens);
@@ -291,6 +290,20 @@ fn decode_tsv_escapes(s: &str) -> String {
     s.replace("\\n", "\n").replace("\\t", "\t")
 }
 
+/// Build one SFT example's `(tokens, labels)` from already-tokenized prompt/answer
+/// ids: prompt tokens are masked to the ignore-index (`-100`, which `cross_entropy`
+/// skips) so only the answer + EOS tokens contribute to the loss. Pure (no tokenizer
+/// / no CUDA) so the assistant-only masking is unit-testable directly.
+fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
+    let mut tokens = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
+    tokens.extend_from_slice(prompt_ids);
+    tokens.extend_from_slice(answer_ids);
+    let mut labels = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
+    labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
+    labels.extend_from_slice(answer_ids);
+    (tokens, labels)
+}
+
 /// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
 /// sampling is reproducible from a single u64 seed.
 fn next_rand(state: &mut u64) -> u64 {
@@ -299,3 +312,27 @@ fn next_rand(state: &mut u64) -> u64 {
         .wrapping_add(1442695040888963407);
     *state >> 16
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn sft_row_masks_prompt_supervises_answer() {
+        let prompt = [5, 6, 7];
+        let answer = [8, 9]; // includes the EOS token in real use
+        let (tokens, labels) = sft_row(&prompt, &answer);
+        // Tokens are prompt then answer, in order.
+        assert_eq!(tokens, vec![5, 6, 7, 8, 9]);
+        // Prompt positions are ignore-index (-100); answer positions are supervised.
+        assert_eq!(labels, vec![-100, -100, -100, 8, 9]);
+        assert_eq!(tokens.len(), labels.len());
+    }
+
+    #[test]
+    fn sft_row_handles_empty_answer() {
+        let (tokens, labels) = sft_row(&[1, 2], &[]);
+        assert_eq!(tokens, vec![1, 2]);
+        assert_eq!(labels, vec![-100, -100]);
+    }
+}
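To illustrate what the `-100` mask buys, here is a minimal sketch of a loss reduction that skips ignore-index positions. It is not the repo's `cross_entropy`; `masked_mean_loss` and `IGNORE_INDEX` are hypothetical names, the per-position losses are assumed to be computed elsewhere, and `sft_row` is copied from the diff so the block compiles on its own.

```rust
/// Ignore-index used for masked label positions, matching the `-100` in the diff.
const IGNORE_INDEX: i32 = -100;

/// Copied from the data.rs diff so the sketch is self-contained.
pub fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
    let mut tokens = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
    tokens.extend_from_slice(prompt_ids);
    tokens.extend_from_slice(answer_ids);
    let mut labels = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
    labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
    labels.extend_from_slice(answer_ids);
    (tokens, labels)
}

/// Hypothetical reference reduction (an assumption, not the repo's kernel):
/// average per-position losses over supervised positions only. Returns `None`
/// when every position is masked, e.g. a row with an empty answer.
pub fn masked_mean_loss(per_token_loss: &[f32], labels: &[i32]) -> Option<f32> {
    assert_eq!(per_token_loss.len(), labels.len(), "loss/label length mismatch");
    let mut sum = 0.0f32;
    let mut count = 0usize;
    for (&loss, &label) in per_token_loss.iter().zip(labels) {
        if label != IGNORE_INDEX {
            sum += loss;
            count += 1;
        }
    }
    if count == 0 {
        None
    } else {
        Some(sum / count as f32)
    }
}
```

With a three-token prompt and a two-token answer, only the last two losses enter the mean, so large prompt losses have no effect on the result. A row with an empty answer has no supervised positions at all, which is the case `sft_row_handles_empty_answer` pins down.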
@@ -10,6 +10,7 @@
 pub mod clip;
 pub mod data;
 pub mod schedule;
+pub mod task;
 
 #[cfg(not(no_cuda))]
 pub mod checkpoint;
212
crates/xtrain-train/src/task.rs
Normal file
@@ -0,0 +1,212 @@
+//! Verifiable arithmetic task (post-training, M1). A tiny two-operand integer
+//! arithmetic task with a deterministic, rule-based checker: the assistant must end
+//! its answer with `\boxed{N}`, and the reward is exact-match on `N`.
+//!
+//! This single module is the shared task spec for the whole post-training stack —
+//! M1 SFT-data generation, M3 DPO preference-pair construction, and M4 GRPO reward
+//! scoring all parse/score through here, so the task lives in exactly one place.
+//!
+//! Host-only (no CUDA): generation + parsing + checking are pure, so this compiles
+//! and unit-tests on a GPU-less host.
+
+use std::fmt;
+
+/// The supported binary operations.
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum Op {
+    Add,
+    Sub,
+    Mul,
+}
+
+impl Op {
+    pub fn symbol(self) -> char {
+        match self {
+            Op::Add => '+',
+            Op::Sub => '-',
+            Op::Mul => '*',
+        }
+    }
+}
+
+impl fmt::Display for Op {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.symbol())
+    }
+}
+
+/// A single two-operand arithmetic problem.
+#[derive(Clone, Copy, Debug)]
+pub struct Problem {
+    pub a: i64,
+    pub b: i64,
+    pub op: Op,
+}
+
+impl Problem {
+    /// The exact integer answer (the verifiable gold label).
+    pub fn answer(self) -> i64 {
+        match self.op {
+            Op::Add => self.a + self.b,
+            Op::Sub => self.a - self.b,
+            Op::Mul => self.a * self.b,
+        }
+    }
+
+    /// The user-turn question text. No template wrapping — the SFT loader
+    /// (`data::load_sft_tsv_cached`) adds the `User:/Assistant:` frame.
+    pub fn question(self) -> String {
+        format!("What is {} {} {}?", self.a, self.op, self.b)
+    }
+
+    /// The assistant-turn SFT target: restate the equation and end with the boxed
+    /// answer. This teaches the answer FORMAT (the checker only reads `\boxed{}`);
+    /// arithmetic correctness is what DPO (M3) / GRPO (M4) later improve.
+    pub fn sft_answer(self) -> String {
+        format!("{} {} {} = \\boxed{{{}}}.", self.a, self.op, self.b, self.answer())
+    }
+
+    /// A stable dedup key, so eval problems can be held out from train.
+    pub fn key(self) -> (i64, char, i64) {
+        (self.a, self.op.symbol(), self.b)
+    }
+}
+
+/// Operand-range configuration for problem sampling. Multiplication uses a smaller
+/// range (`max_mul`) so products stay modest; add/sub use `max_add`.
+#[derive(Clone)]
+pub struct GenConfig {
+    pub max_add: i64,
+    pub max_mul: i64,
+    pub ops: Vec<Op>,
+}
+
+impl Default for GenConfig {
+    fn default() -> Self {
+        Self {
+            max_add: 99,
+            max_mul: 12,
+            ops: vec![Op::Add, Op::Sub, Op::Mul],
+        }
+    }
+}
+
+/// Sample one problem deterministically from the LCG state `rng`. Operands are drawn
+/// in `[0, max]` per the op; subtraction may yield a negative answer (the checker /
+/// parser handle a leading `-`).
+pub fn gen_problem(rng: &mut u64, cfg: &GenConfig) -> Problem {
+    let op = cfg.ops[(next_rand(rng) as usize) % cfg.ops.len()];
+    let max = if op == Op::Mul { cfg.max_mul } else { cfg.max_add };
+    let a = rand_range(rng, max);
+    let b = rand_range(rng, max);
+    Problem { a, b, op }
+}
+
+/// Parse the integer inside the LAST `\boxed{...}` in `text`. Returns `None` if there
+/// is no well-formed boxed integer (no box, empty, or non-integer contents). "Last"
+/// so a model that emits intermediate boxes still scores on its final answer.
+pub fn parse_boxed_answer(text: &str) -> Option<i64> {
+    const TAG: &str = "\\boxed{";
+    let mut found = None;
+    let mut rest = text;
+    while let Some(i) = rest.find(TAG) {
+        let after = &rest[i + TAG.len()..];
+        match after.find('}') {
+            Some(j) => {
+                if let Ok(n) = after[..j].trim().parse::<i64>() {
+                    found = Some(n);
+                }
+                rest = &after[j + 1..];
+            }
+            None => break,
+        }
+    }
+    found
+}
+
+/// Verifiable reward: does the completion's boxed answer exactly match `gold`?
+pub fn check_answer(completion: &str, gold: i64) -> bool {
+    parse_boxed_answer(completion) == Some(gold)
+}
+
+/// `[0, max]` inclusive draw from the LCG.
+fn rand_range(rng: &mut u64, max: i64) -> i64 {
+    debug_assert!(max >= 0);
+    (next_rand(rng) % (max as u64 + 1)) as i64
+}
+
+/// Same LCG constants as the dataset sampler (`data::next_rand`), kept local so the
+/// task module stays dependency-free and host-only.
+fn next_rand(state: &mut u64) -> u64 {
+    *state = state
+        .wrapping_mul(6364136223846793005)
+        .wrapping_add(1442695040888963407);
+    *state >> 16
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn answer_question_and_sft_target() {
+        let p = Problem {
+            a: 12,
+            b: 13,
+            op: Op::Mul,
+        };
+        assert_eq!(p.answer(), 156);
+        assert_eq!(p.question(), "What is 12 * 13?");
+        assert_eq!(p.sft_answer(), "12 * 13 = \\boxed{156}.");
+        let s = Problem {
+            a: 3,
+            b: 8,
+            op: Op::Sub,
+        };
+        assert_eq!(s.answer(), -5);
+    }
+
+    #[test]
+    fn parse_takes_last_boxed_and_handles_edges() {
+        assert_eq!(parse_boxed_answer("\\boxed{3} then \\boxed{156}."), Some(156));
+        assert_eq!(parse_boxed_answer("\\boxed{-7}"), Some(-7));
+        assert_eq!(parse_boxed_answer("\\boxed{ 42 }"), Some(42));
+        assert_eq!(parse_boxed_answer("no box here"), None);
+        assert_eq!(parse_boxed_answer("\\boxed{abc}"), None);
+        assert_eq!(parse_boxed_answer("\\boxed{unterminated"), None);
+    }
+
+    #[test]
+    fn check_is_exact_match() {
+        assert!(check_answer("the result is \\boxed{156}.", 156));
+        assert!(!check_answer("the result is \\boxed{155}.", 156));
+        assert!(!check_answer("no boxed answer at all", 156));
+    }
+
+    #[test]
+    fn sft_target_is_always_self_consistent() {
+        // The SFT target's boxed answer must always check against the problem's own
+        // gold — across all ops/operands. This is the M1 data invariant.
+        let cfg = GenConfig::default();
+        let mut rng = 12345u64;
+        for _ in 0..2000 {
+            let p = gen_problem(&mut rng, &cfg);
+            assert!(
+                check_answer(&p.sft_answer(), p.answer()),
+                "self-inconsistent SFT target for {p:?}"
+            );
+        }
+    }
+
+    #[test]
+    fn generation_is_deterministic_from_seed() {
+        let cfg = GenConfig::default();
+        let (mut r1, mut r2) = (7u64, 7u64);
+        for _ in 0..200 {
+            assert_eq!(
+                gen_problem(&mut r1, &cfg).key(),
+                gen_problem(&mut r2, &cfg).key()
+            );
+        }
+    }
+}
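As a worked illustration of the checker's "last well-formed box" rule, the sketch below copies `parse_boxed_answer` and `check_answer` from the task.rs diff and wraps them in a scalar reward. The `reward` function and its 1.0 / 0.0 scale are assumptions made for illustration only; the commit does not show how M4 will consume the checker.

```rust
/// Copied from the task.rs diff so the sketch is self-contained.
pub fn parse_boxed_answer(text: &str) -> Option<i64> {
    const TAG: &str = "\\boxed{";
    let mut found = None;
    let mut rest = text;
    while let Some(i) = rest.find(TAG) {
        let after = &rest[i + TAG.len()..];
        match after.find('}') {
            Some(j) => {
                // Only a box that parses as an integer updates the result.
                if let Ok(n) = after[..j].trim().parse::<i64>() {
                    found = Some(n);
                }
                rest = &after[j + 1..];
            }
            None => break,
        }
    }
    found
}

/// Copied from the task.rs diff.
pub fn check_answer(completion: &str, gold: i64) -> bool {
    parse_boxed_answer(completion) == Some(gold)
}

/// Hypothetical scalar reward (an assumption, not part of the commit):
/// 1.0 for an exact boxed match, 0.0 for anything else.
pub fn reward(completion: &str, gold: i64) -> f32 {
    if check_answer(completion, gold) {
        1.0
    } else {
        0.0
    }
}
```

One consequence of the loop that the unit tests do not cover: a malformed or unterminated box after a valid one does not reset the result, so `\boxed{5} then \boxed{abc}` still parses to `Some(5)`. Whether that is the desired behavior for reward scoring is a design choice left to M3/M4.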
@@ -282,3 +282,47 @@ gates as flash/GQA.
 
 > Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
 > row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.
+
+## Implementation log
+
+### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box)
+
+The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
+needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and
+runs on dash5.
+
+**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
+
+- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,99]` for `+/−`, `[0,12]` for `×`
+  (modest products); subtraction may be negative.
+- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
+  the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
+- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
+  (exact match vs. gold). This is the single shared checker reused by M3 (pair construction)
+  and M4 (GRPO reward).
+- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly
+  probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`).
+
+**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
+writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
+(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
+Train rows are deduped; eval is held out from train (no leakage).
+
+**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
+extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
+single-turn path is bit-identical to `fbf4ac2`).
+
+**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass,
+including `sft_row` masks prompt→`-100` / supervises answer, the SFT-target self-consistency
+invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism.
+A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
+train/eval leakage.
+
+**Pending on dash5 (needs GPU):**
+1. generate the dataset: `gen_arith_task --n <N> --eval <M> --seed 1 --out-dir <dir>`;
+2. SFT from the v12 base: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt>
+   <arch flags> --bf16 --recompute --flash --ckpt <out.ckpt>` → the P0 reference/init checkpoint;
+3. format eval: `greedy_sample <out.ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt`,
+   then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a
+   one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only
+   buys the format; correctness is M3/M4's job.
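The "tiny scorer" in step 3 could look roughly like the sketch below. It is hypothetical: `pass_rate` is an illustrative name, the completions are assumed to be already extracted from `greedy_sample`'s output (whose format is not shown in this commit) in the same order as the gold file, and the checker is passed in as a closure so the block stays self-contained. In the repo that closure would presumably be `task::check_answer`.

```rust
/// Hypothetical scorer (not part of the commit): pairs completion `i` with line
/// `i` of the gold file contents and counts how many pass `check`.
/// Returns `(passed, total)`, or an error on a malformed or misaligned gold file.
pub fn pass_rate<F>(
    completions: &[&str],
    gold_txt: &str,
    check: F,
) -> Result<(usize, usize), String>
where
    F: Fn(&str, i64) -> bool,
{
    // One integer per line, as written by the generator's `writeln!(golds, ...)`.
    let golds: Vec<i64> = gold_txt
        .lines()
        .map(|l| {
            l.trim()
                .parse::<i64>()
                .map_err(|e| format!("bad gold line {l:?}: {e}"))
        })
        .collect::<Result<_, _>>()?;

    if golds.len() != completions.len() {
        return Err(format!(
            "{} completions but {} gold lines",
            completions.len(),
            golds.len()
        ));
    }

    let mut passed = 0;
    for (&completion, &gold) in completions.iter().zip(golds.iter()) {
        if check(completion, gold) {
            passed += 1;
        }
    }
    Ok((passed, golds.len()))
}
```

Failing loudly on a length mismatch matters here because the gold file is only positionally aligned with the prompts; note also that `arith_eval_prompts.txt` starts with a `#` header line that has no gold counterpart.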