post-train: M1 — verifiable-arith eval scorer + SFT format-baseline result
eval_arith: load ckpt, greedy-generate per held-out prompt, parse \boxed{}
via the shared task checker, report format(boxed) + correctness pass-rates.
Reused as the verifiable-eval harness for M3 (DPO) / M4 (GRPO).
M1 result (100 held-out prompts, v12 1.05B base): SFT moves answer-format
adherence from 0% to 100%, while arithmetic correctness sits at 8% -- the
intended split (SFT buys the format; correctness is the verifiable-reward job
of M3/M4). Logged in the docs/18 implementation log, plus a Phase-3 row in
docs/evolution.md.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
189
crates/xtrain-train/src/bin/eval_arith.rs
Normal file
189
crates/xtrain-train/src/bin/eval_arith.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! Verifiable-task eval (post-training, M1+). Load a checkpoint, greedily generate an
//! answer for each held-out arithmetic prompt, parse the `\boxed{}` answer, and report
//! the exact-match pass-rate against the gold file. Two signals are printed:
//! **format** (fraction that emitted any boxed integer) and **correctness** (fraction
//! whose boxed answer matches gold). This is the M1 format-baseline metric and the
//! reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO).
//!
//! ```text
//! eval_arith <ckpt> <tokenizer.json> --heads 52 --head-dim 32 --kv-heads 13 \
//!     --layers 22 --ffn 6656 \
//!     --prompts-file <dir>/arith_eval_prompts.txt \
//!     --gold-file <dir>/arith_eval_gold.txt --max-tokens 48 --show 8
//! ```
|
||||
|
||||
/// Stub entry point for builds compiled with the `no_cuda` cfg.
///
/// It only prints a notice to stderr explaining that the evaluation needs a CUDA
/// build; no checkpoint is loaded and nothing is generated.
#[cfg(no_cuda)]
fn main() {
    eprintln!("eval_arith: built without CUDA (no_cuda); run on a GPU host (dash5).");
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::path::PathBuf;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::sample::generate;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::task::{check_answer, parse_boxed_answer};
|
||||
|
||||
/// Produces `n` deterministic pseudo-random values in `[-scale, scale]`.
///
/// Same deterministic LCG init scheme as bin/train.rs / bin/greedy_sample.rs (the
/// values are overwritten by the loaded checkpoint; init just shapes the tensors).
///
/// The same `(n, seed, scale)` always yields the same vector, and a shorter request
/// with the same seed is a prefix of a longer one. `n == 0` returns an empty vector.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Derive the starting LCG state from the seed with one multiply-add.
    let mut state = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    (0..n)
        .map(|_| {
            // One LCG step per element; wrapping arithmetic is the intended overflow.
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Keep the top 31 bits, normalise by 2^31, then recentre around zero and
            // stretch to the requested amplitude.
            (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
        })
        .collect()
}
|
||||
|
||||
/// Reads the value of the command-line flag `name` and parses it as `T`.
///
/// The value is the argument immediately following the first occurrence of `name`.
/// Returns `default` when the flag is absent. When the flag is present but its value
/// is missing or does not parse as `T`, a warning is written to stderr and `default`
/// is returned, so a typo such as `--layers 2x` no longer falls back silently.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let idx = match args.iter().position(|a| a == name) {
        Some(idx) => idx,
        None => return default,
    };
    match args.get(idx + 1) {
        Some(raw) => match raw.parse::<T>() {
            Ok(value) => value,
            Err(_) => {
                eprintln!("eval_arith: invalid value {raw:?} for {name}; using the default");
                default
            }
        },
        None => {
            eprintln!("eval_arith: {name} given without a value; using the default");
            default
        }
    }
}
|
||||
|
||||
/// Returns the raw string following the first occurrence of the flag `name`.
///
/// Yields `None` when the flag is absent or when it is the last argument and so has
/// no value after it. The value is cloned; `args` is left untouched.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    let idx = args.iter().position(|a| a == name)?;
    args.get(idx + 1).cloned()
}
|
||||
|
||||
/// Expands the two-character escapes `\n` and `\t` in `s` into a real newline and tab.
///
/// Only these two escapes are recognised; every other character, including a lone
/// backslash, is copied through unchanged.
#[cfg(not(no_cuda))]
fn decode_escapes(s: &str) -> String {
    s.replace("\\n", "\n").replace("\\t", "\t")
}
|
||||
|
||||
/// Isolates the first answer "turn" from a generated continuation.
///
/// The model keeps generating past the answer (no EOS stop in the sampler), so the
/// text is cut at the first `<|endoftext|>` and then at the first newline. The
/// arithmetic answer is a single line, so this isolates it. The returned slice
/// borrows from `continuation`; a continuation that starts with a newline or with
/// `<|endoftext|>` yields an empty segment.
#[cfg(not(no_cuda))]
fn first_answer_segment(continuation: &str) -> &str {
    // `split` always yields at least one piece, so the fallbacks are never taken;
    // they only avoid an `unwrap`.
    let turn = continuation
        .split("<|endoftext|>")
        .next()
        .unwrap_or(continuation);
    turn.split('\n').next().unwrap_or(turn)
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let ckpt = positionals
|
||||
.first()
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.expect("usage: eval_arith <ckpt> <tokenizer.json> [flags]");
|
||||
let tok_path = positionals
|
||||
.get(1)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let max_new = flag(&args, "--max-tokens", 48usize);
|
||||
let n_show = flag(&args, "--show", 8usize);
|
||||
let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required");
|
||||
let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required");
|
||||
|
||||
// Prompts: skip the `#` header / blank lines and decode escaped newlines so the
|
||||
// count and order line up with the gold file.
|
||||
let prompts: Vec<String> = std::fs::read_to_string(&prompts_file)
|
||||
.unwrap_or_else(|e| panic!("read prompts {prompts_file}: {e}"))
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|l| !l.is_empty() && !l.starts_with('#'))
|
||||
.map(decode_escapes)
|
||||
.collect();
|
||||
let golds: Vec<i64> = std::fs::read_to_string(&gold_file)
|
||||
.unwrap_or_else(|e| panic!("read gold {gold_file}: {e}"))
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(|l| l.parse::<i64>().expect("gold line not an integer"))
|
||||
.collect();
|
||||
assert_eq!(
|
||||
prompts.len(),
|
||||
golds.len(),
|
||||
"prompt/gold count mismatch ({} vs {})",
|
||||
prompts.len(),
|
||||
golds.len()
|
||||
);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(&tok_path);
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
|
||||
.with_kv_heads(kv_heads);
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
||||
|
||||
println!(
|
||||
"eval_arith: ckpt {} | {} prompts | max_new {}",
|
||||
ckpt.display(),
|
||||
prompts.len(),
|
||||
max_new
|
||||
);
|
||||
|
||||
let (mut n_boxed, mut n_correct) = (0usize, 0usize);
|
||||
let mut shown = 0usize;
|
||||
for (prompt, &gold) in prompts.iter().zip(&golds) {
|
||||
let ids: Vec<i32> = tok.encode(prompt).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 7u64;
|
||||
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
|
||||
let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
let seg = first_answer_segment(&cont);
|
||||
if parse_boxed_answer(seg).is_some() {
|
||||
n_boxed += 1;
|
||||
}
|
||||
let ok = check_answer(seg, gold);
|
||||
if ok {
|
||||
n_correct += 1;
|
||||
}
|
||||
if shown < n_show {
|
||||
let q = prompt.replace('\n', " ");
|
||||
println!(" [{}] gold={gold} got={seg:?} {}", q, if ok { "OK" } else { "x" });
|
||||
shown += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let n = prompts.len() as f64;
|
||||
println!(
|
||||
"RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)",
|
||||
n_boxed,
|
||||
prompts.len(),
|
||||
100.0 * n_boxed as f64 / n,
|
||||
n_correct,
|
||||
prompts.len(),
|
||||
100.0 * n_correct as f64 / n,
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user