post-train: M1 — verifiable-arith eval scorer + SFT format-baseline result
eval_arith: load ckpt, greedy-generate per held-out prompt, parse \boxed{}
via the shared task checker, report format(boxed) + correctness pass-rates.
Reused as the verifiable-eval harness for M3 (DPO) / M4 (GRPO).
M1 result (100 held-out prompts, v12 1.05B base): SFT moves answer-format
adherence 0% -> 100%, arithmetic correctness 8% -- the intended split (SFT
buys the format; correctness is the verifiable-reward job of M3/M4). Logged
in docs/18 implementation log + a Phase-3 row in docs/evolution.md.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
189
crates/xtrain-train/src/bin/eval_arith.rs
Normal file
189
crates/xtrain-train/src/bin/eval_arith.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! Verifiable-task eval (post-training, M1+). Load a checkpoint, greedily generate an
|
||||
//! answer for each held-out arithmetic prompt, parse the `\boxed{}` answer, and report
|
||||
//! the exact-match pass-rate against the gold file. Two signals are printed:
|
||||
//! **format** (fraction that emitted any boxed integer) and **correctness** (fraction
|
||||
//! whose boxed answer matches gold). This is the M1 format-baseline metric and the
|
||||
//! reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO).
|
||||
//!
|
||||
//! eval_arith <ckpt> <tokenizer.json> --heads 52 --head-dim 32 --kv-heads 13 \
|
||||
//! --layers 22 --ffn 6656 \
|
||||
//! --prompts-file <dir>/arith_eval_prompts.txt \
|
||||
//! --gold-file <dir>/arith_eval_gold.txt --max-tokens 48 --show 8
|
||||
|
||||
/// Stub entry point used when the crate is built with the `no_cuda` cfg: the
/// evaluator needs a GPU, so this build only tells the user where to run it.
#[cfg(no_cuda)]
fn main() {
    const NO_CUDA_NOTICE: &str =
        "eval_arith: built without CUDA (no_cuda); run on a GPU host (dash5).";
    eprintln!("{NO_CUDA_NOTICE}");
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
use std::path::PathBuf;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::sample::generate;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::task::{check_answer, parse_boxed_answer};
|
||||
|
||||
/// Deterministic pseudo-random tensor initializer.
///
/// Same LCG init scheme as bin/train.rs / bin/greedy_sample.rs. The values are
/// overwritten by the loaded checkpoint; this init only shapes the tensors.
///
/// Returns `n` values derived from `seed`, each computed as
/// `(u - 0.5) * 2.0 * scale` where `u` is the top 31 bits of the LCG state
/// divided by 2^31.
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Seed scrambling constants (applied once) and per-step LCG constants.
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Keep the high bits of the state; the operation order below is kept
        // exactly so results stay bit-identical.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
|
||||
|
||||
/// Read the typed value of the command-line flag `name` from `args`.
///
/// Returns `default` only when the flag is absent. The first occurrence of
/// `name` wins, and the argument right after it is parsed as `T`.
///
/// # Panics
///
/// Panics when the flag is present but has no following argument, or when that
/// argument does not parse as `T`. Silently falling back to the default would
/// let a typo in an architecture flag evaluate a different model shape than
/// the one the user asked for.
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let idx = match args.iter().position(|a| a == name) {
        Some(idx) => idx,
        None => return default,
    };
    match args.get(idx + 1) {
        Some(raw) => raw
            .parse()
            .unwrap_or_else(|_| panic!("flag {name}: cannot parse value {raw:?}")),
        None => panic!("flag {name} expects a value"),
    }
}
|
||||
|
||||
/// Return an owned copy of the argument that follows the first occurrence of
/// `name` in `args`, or `None` when the flag is absent or is the last argument.
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
    // Each window is (candidate flag, the argument after it); the first window
    // whose head equals `name` corresponds to the first occurrence of the flag.
    args.windows(2)
        .find(|pair| pair[0] == name)
        .map(|pair| pair[1].clone())
}
|
||||
|
||||
/// Turn the two-character sequences `\n` and `\t` (backslash + letter) in `s`
/// into a real newline / tab. Every other character, including a backslash
/// that is not followed by `n` or `t`, is copied through unchanged.
#[cfg(not(no_cuda))]
fn decode_escapes(s: &str) -> String {
    let mut decoded = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(current) = chars.next() {
        let control = match (current, chars.peek().copied()) {
            ('\\', Some('n')) => Some('\n'),
            ('\\', Some('t')) => Some('\t'),
            _ => None,
        };
        match control {
            Some(ch) => {
                // Consume the letter that completed the escape sequence.
                chars.next();
                decoded.push(ch);
            }
            None => decoded.push(current),
        }
    }
    decoded
}
|
||||
|
||||
/// Isolate the first answer "turn" of a generated continuation.
///
/// The sampler has no EOS stop, so the model keeps generating past the answer.
/// The text is truncated at the first `<|endoftext|>` marker, and what remains
/// is truncated at the first newline. The arithmetic answer is a single line,
/// so the returned slice is that line (possibly empty).
#[cfg(not(no_cuda))]
fn first_answer_segment(continuation: &str) -> &str {
    let before_eos = match continuation.find("<|endoftext|>") {
        Some(end) => &continuation[..end],
        None => continuation,
    };
    match before_eos.find('\n') {
        Some(end) => &before_eos[..end],
        None => before_eos,
    }
}
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
fn main() {
|
||||
use xserv_tokenizer::Tokenizer;
|
||||
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
|
||||
let ckpt = positionals
|
||||
.first()
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.expect("usage: eval_arith <ckpt> <tokenizer.json> [flags]");
|
||||
let tok_path = positionals
|
||||
.get(1)
|
||||
.map(|s| PathBuf::from(s.as_str()))
|
||||
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
|
||||
|
||||
let n_heads = flag(&args, "--heads", 52usize);
|
||||
let head_dim = flag(&args, "--head-dim", 32usize);
|
||||
let n_layers = flag(&args, "--layers", 22usize);
|
||||
let ffn = flag(&args, "--ffn", 6656usize);
|
||||
let kv_heads = flag(&args, "--kv-heads", n_heads);
|
||||
let max_new = flag(&args, "--max-tokens", 48usize);
|
||||
let n_show = flag(&args, "--show", 8usize);
|
||||
let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required");
|
||||
let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required");
|
||||
|
||||
// Prompts: skip the `#` header / blank lines and decode escaped newlines so the
|
||||
// count and order line up with the gold file.
|
||||
let prompts: Vec<String> = std::fs::read_to_string(&prompts_file)
|
||||
.unwrap_or_else(|e| panic!("read prompts {prompts_file}: {e}"))
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|l| !l.is_empty() && !l.starts_with('#'))
|
||||
.map(decode_escapes)
|
||||
.collect();
|
||||
let golds: Vec<i64> = std::fs::read_to_string(&gold_file)
|
||||
.unwrap_or_else(|e| panic!("read gold {gold_file}: {e}"))
|
||||
.lines()
|
||||
.map(str::trim)
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(|l| l.parse::<i64>().expect("gold line not an integer"))
|
||||
.collect();
|
||||
assert_eq!(
|
||||
prompts.len(),
|
||||
golds.len(),
|
||||
"prompt/gold count mismatch ({} vs {})",
|
||||
prompts.len(),
|
||||
golds.len()
|
||||
);
|
||||
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let tok = Tokenizer::from_file(&tok_path);
|
||||
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
|
||||
.with_kv_heads(kv_heads);
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
});
|
||||
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
|
||||
|
||||
println!(
|
||||
"eval_arith: ckpt {} | {} prompts | max_new {}",
|
||||
ckpt.display(),
|
||||
prompts.len(),
|
||||
max_new
|
||||
);
|
||||
|
||||
let (mut n_boxed, mut n_correct) = (0usize, 0usize);
|
||||
let mut shown = 0usize;
|
||||
for (prompt, &gold) in prompts.iter().zip(&golds) {
|
||||
let ids: Vec<i32> = tok.encode(prompt).into_iter().map(|t| t as i32).collect();
|
||||
let mut rng = 7u64;
|
||||
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
|
||||
let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
|
||||
let seg = first_answer_segment(&cont);
|
||||
if parse_boxed_answer(seg).is_some() {
|
||||
n_boxed += 1;
|
||||
}
|
||||
let ok = check_answer(seg, gold);
|
||||
if ok {
|
||||
n_correct += 1;
|
||||
}
|
||||
if shown < n_show {
|
||||
let q = prompt.replace('\n', " ");
|
||||
println!(" [{}] gold={gold} got={seg:?} {}", q, if ok { "OK" } else { "x" });
|
||||
shown += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let n = prompts.len() as f64;
|
||||
println!(
|
||||
"RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)",
|
||||
n_boxed,
|
||||
prompts.len(),
|
||||
100.0 * n_boxed as f64 / n,
|
||||
n_correct,
|
||||
prompts.len(),
|
||||
100.0 * n_correct as f64 / n,
|
||||
);
|
||||
}
|
||||
@@ -285,16 +285,18 @@ gates as flash/GQA.
|
||||
|
||||
## Implementation log
|
||||
|
||||
### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box)
|
||||
### M1 — SFT task baseline (landed)
|
||||
|
||||
The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
|
||||
needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and
|
||||
runs on dash5.
|
||||
needed); the SFT run + eval ran on dash5 (1×5090). **Result: SFT moves answer-format
|
||||
adherence 0% → 100%, with arithmetic correctness 8% — exactly the intended split (SFT buys
|
||||
the format; correctness is M3/M4's job).**
|
||||
|
||||
**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
|
||||
|
||||
- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,99]` for `+/−`, `[0,12]` for `×`
|
||||
(modest products); subtraction may be negative.
|
||||
- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,999]` for `+/−`, `[0,99]` for `×`
|
||||
(modest products); subtraction may be negative. (Ranges enlarged from the first cut to keep
|
||||
the unique-key space ≫ requested rows — see the saturation guard below.)
|
||||
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
|
||||
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
|
||||
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
|
||||
@@ -306,7 +308,19 @@ runs on dash5.
|
||||
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
|
||||
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
|
||||
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
|
||||
Train rows are deduped; eval is held out from train (no leakage).
|
||||
Train rows are deduped; eval is held out from train (no leakage). A **saturation guard**
|
||||
(`unique_space()` + `assert need·5 ≤ space·4`) rejects requests that approach the unique-key
|
||||
space, since deduped train + disjoint eval near saturation get pathologically slow (or, for
|
||||
the disjoint-eval loop, never terminate). With the shipped defaults the space is ~2.01M keys,
|
||||
so a 20 000 + 500 request is a tiny fraction (gen runs in ~0.2 s).
|
||||
|
||||
**Scorer (`crates/xtrain-train/src/bin/eval_arith.rs`):** loads a checkpoint, greedily
|
||||
generates a continuation per held-out prompt, isolates the first answer segment (cut at the
|
||||
first `<|endoftext|>` then first newline), and reports two signals via the shared checker —
|
||||
**format** (fraction emitting any `\boxed{int}`) and **correctness** (exact-match vs. gold).
|
||||
This is the reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). It uses the *naive*
|
||||
no-KV-cache sampler (full forward per token), so even 100 prompts is slow — concrete
|
||||
motivation for M2 (the KV-cache decode engine).
|
||||
|
||||
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
|
||||
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
|
||||
@@ -318,11 +332,31 @@ invariant (always checker-correct over 2000 samples), parser edge cases, and see
|
||||
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
|
||||
train/eval leakage.
|
||||
|
||||
**Pending on dash5 (needs GPU):**
|
||||
1. generate the dataset: `gen_arith_task --n <N> --eval <M> --seed 1 --out-dir <dir>`;
|
||||
2. SFT from the v12 base: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt>
|
||||
<arch flags> --bf16 --recompute --flash --ckpt <out.ckpt>` → the P0 reference/init checkpoint;
|
||||
3. format eval: `greedy_sample <out.ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt`,
|
||||
then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a
|
||||
one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only
|
||||
buys the format; correctness is M3/M4's job.
|
||||
**Run (dash5, 1×5090, from the v12 1.05B base):**
|
||||
1. dataset: `gen_arith_task --n 20000 --eval 500 --seed 1 --out-dir <dir>` → 20 000 train +
|
||||
500 held-out eval, 0 leakage.
|
||||
2. SFT: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt> --heads 52
|
||||
--head-dim 32 --kv-heads 13 --layers 22 --ffn 6656 --bf16 --recompute --flash --seq 256
|
||||
--batch 16 --steps 250 --max-lr 1e-4 --min-lr 1e-5 --ckpt arith_sft_v12.ckpt` → the P0
|
||||
reference/init checkpoint. Train loss 4.68 → ~0.34, best val 0.386, no OOM, ~4.3K tok/s.
|
||||
3. eval: `eval_arith <ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt
|
||||
--gold-file <dir>/arith_eval_gold.txt --max-tokens 32`, base vs. SFT, on 100 held-out prompts.
|
||||
|
||||
**M1 result (100 held-out prompts, greedy, max_new 32):**
|
||||
|
||||
| checkpoint | format (`\boxed{}`) | correct (exact-match) |
|
||||
|---------------------|----------------------|-----------------------|
|
||||
| v12 base (pre-SFT) | 0 / 100 (0%) | 0 / 100 (0%) |
|
||||
| arith SFT | **100 / 100 (100%)** | 8 / 100 (8%) |
|
||||
|
||||
The base model never emits the format — it answers `"I don't know."` / restates the question
|
||||
and stops. SFT moves format **0% → 100%**: every completion cleanly restates the equation and
|
||||
boxes an integer (`46 * 80 = \boxed{3380}.`). Correctness is only **8%**: the format is fully
|
||||
learned but the *arithmetic* is the base model's own weak capability — e.g. it boxes 3380 for
|
||||
gold 3680, −10 for gold 5; it does get some right (`895 − 353 = \boxed{542}.` ✓). That residual
|
||||
gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to close.
|
||||
|
||||
**Gate met:** format 0% → 100% confirms the assistant-only SFT path is wired end-to-end; the
|
||||
held-out correct > 0 confirms the checker + eval harness score real matches (not just format).
|
||||
M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic
|
||||
skill, which is downstream by design.
|
||||
|
||||
@@ -86,6 +86,17 @@ scaling 科学线(v0–v8)收官后,项目重启回到本职「学训练
|
||||
|
||||
> 📌 两条 integration 发现(非回归,pre-existing,记账):① **DDP 三个测试并行会争 2 卡 deadlock** → 文档/测试用 `--test-threads=1`(或标 serial)跑。② **fresh-train md5 run-to-run 不定**——反向 atomicAdd 归约序非确定 → 有效的确定性闸门是**导出(export)重确定性**(同 ckpt 重导 safetensors md5 逐位一致),**不是** fresh-train 复现。
|
||||
|
||||
## 三·六、Phase 3 后训练栈(SFT → KV-cache → DPO → GRPO,详见 [18-post-training-rl-sft.md](18-post-training-rl-sft.md))
|
||||
|
||||
Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(对齐方向)。锁定路线 DPO→GRPO(reward model 可选)、**rule-based 可验证 reward 优先**、**KV-cache 增量解码引擎前置自建**、任务取**可验证算术**(确定性 exact-match,给 RL 干净可证伪信号)。里程碑 M1(SFT baseline)→ M2(KV-cache 解码引擎,token-identical 闸门)→ M3(DPO)→ M4(GRPO)→ M5(可选 RM)。按维度落点:
|
||||
|
||||
- **算法**:后训练损失族——SFT(assistant-only masking,已有)→ DPO(`seq_logprob` 算子 + Bradley-Terry/σ(Δ) 偏好损失,frozen reference)→ GRPO(group-relative advantage,无 critic + clipped PG + KL leash)。每条沿用 Phase 1/2 闸门规矩:新损失/算子有限差分 grad-check + PyTorch parity + 退化检查(β→0 / G=1 / ε→∞ / ref==policy)+ 一条可证伪「真在学」信号(reward margin↑ / 合成 RL overfit)。
|
||||
- **Infra**:**KV-cache 增量解码引擎(M2,前置)**是这一阶段的硬核——per-layer K/V cache + 单 token 增量 forward(prompt 灌一次 cache 后逐 token 解码)+ ragged 批量解码。硬闸门 = **解码逐 token 等价于全重算 greedy**(同 xserv 导出闭环的逐位纪律),并先记解码吞吐 baseline(profile-first)。它是 DPO 造对 + GRPO rollout 的共享底座。
|
||||
- **数据集**:可验证任务自带数据生成器——两操作数整数算术(`+ − ×`),rule-based checker 读 `\boxed{}` 做 exact-match,是 M1 SFT 数据 + M3 造对 + M4 GRPO reward 的单一共享 spec。
|
||||
- **模型架构**:复用 v12 1.05B 基座,不动架构。
|
||||
|
||||
**M1(SFT task baseline,已落地)**:可验证算术任务 + 数据生成器 + 评分器一套,host-side 9/9 单测过(masking、SFT-target 自洽 2000 样、parser 边界、种子确定性)。dash5 单卡从 v12 基座 SFT(loss 4.68→~0.34,best val 0.386)。**100 留出题 eval:格式 `\boxed{}` 习得率 base 0% → SFT 100%;算术正确率 8%。**——SFT 只买**格式**(0%→100% 干净落地),算术正确性是 base 模型本身弱项(如 `46*80` 框成 3380),正是 M3/M4 的可验证 reward 要去补的残差。一条诚实账:M1 用的是**朴素无 KV-cache 采样器**(每 token 全量 forward),100 题已经很慢——这正是 M2 解码引擎前置的动机。
|
||||
|
||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||
|
||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||
|
||||
Reference in New Issue
Block a user