From 1574e21d8958bcbb9348e7df53fe1c93a03d0729 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 30 Jun 2026 11:13:19 +0800 Subject: [PATCH] =?UTF-8?q?post-train:=20M1=20=E2=80=94=20verifiable-arith?= =?UTF-8?q?=20eval=20scorer=20+=20SFT=20format-baseline=20result?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit eval_arith: load ckpt, greedy-generate per held-out prompt, parse \boxed{} via the shared task checker, report format(boxed) + correctness pass-rates. Reused as the verifiable-eval harness for M3 (DPO) / M4 (GRPO). M1 result (100 held-out prompts, v12 1.05B base): SFT moves answer-format adherence 0% -> 100%, arithmetic correctness 8% -- the intended split (SFT buys the format; correctness is the verifiable-reward job of M3/M4). Logged in docs/18 implementation log + a Phase-3 row in docs/evolution.md. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-train/src/bin/eval_arith.rs | 189 ++++++++++++++++++++++ docs/18-post-training-rl-sft.md | 62 +++++-- docs/evolution.md | 11 ++ 3 files changed, 248 insertions(+), 14 deletions(-) create mode 100644 crates/xtrain-train/src/bin/eval_arith.rs diff --git a/crates/xtrain-train/src/bin/eval_arith.rs b/crates/xtrain-train/src/bin/eval_arith.rs new file mode 100644 index 0000000..bdba382 --- /dev/null +++ b/crates/xtrain-train/src/bin/eval_arith.rs @@ -0,0 +1,189 @@ +//! Verifiable-task eval (post-training, M1+). Load a checkpoint, greedily generate an +//! answer for each held-out arithmetic prompt, parse the `\boxed{}` answer, and report +//! the exact-match pass-rate against the gold file. Two signals are printed: +//! **format** (fraction that emitted any boxed integer) and **correctness** (fraction +//! whose boxed answer matches gold). This is the M1 format-baseline metric and the +//! reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). +//! +//! eval_arith --heads 52 --head-dim 32 --kv-heads 13 \ +//! --layers 22 --ffn 6656 \ +//! --prompts-file /arith_eval_prompts.txt \ +//! --gold-file /arith_eval_gold.txt --max-tokens 48 --show 8 + +#[cfg(no_cuda)] +fn main() { + eprintln!("eval_arith: built without CUDA (no_cuda); run on a GPU host (dash5)."); +} + +#[cfg(not(no_cuda))] +use std::path::PathBuf; +#[cfg(not(no_cuda))] +use xtrain_cuda::device; +#[cfg(not(no_cuda))] +use xtrain_model::{Config, TinyTransformer}; +#[cfg(not(no_cuda))] +use xtrain_tensor::Device; +#[cfg(not(no_cuda))] +use xtrain_train::sample::generate; +#[cfg(not(no_cuda))] +use xtrain_train::task::{check_answer, parse_boxed_answer}; + +// Same deterministic LCG init scheme as bin/train.rs / bin/greedy_sample.rs (the +// values are overwritten by the loaded checkpoint; init just shapes the tensors). +#[cfg(not(no_cuda))] +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} + +#[cfg(not(no_cuda))] +fn flag(args: &[String], name: &str, default: T) -> T { + args.iter() + .position(|a| a == name) + .and_then(|i| args.get(i + 1)) + .and_then(|s| s.parse().ok()) + .unwrap_or(default) +} + +#[cfg(not(no_cuda))] +fn flag_value(args: &[String], name: &str) -> Option { + args.iter() + .position(|a| a == name) + .and_then(|i| args.get(i + 1)) + .cloned() +} + +#[cfg(not(no_cuda))] +fn decode_escapes(s: &str) -> String { + s.replace("\\n", "\n").replace("\\t", "\t") +} + +/// The model keeps generating past the answer (no EOS stop in the sampler), so keep +/// only the first answer "turn": cut at the first `<|endoftext|>` and then at the +/// first newline. The arithmetic answer is a single line, so this isolates it. +#[cfg(not(no_cuda))] +fn first_answer_segment(continuation: &str) -> &str { + let s = continuation + .split("<|endoftext|>") + .next() + .unwrap_or(continuation); + s.split('\n').next().unwrap_or(s) +} + +#[cfg(not(no_cuda))] +fn main() { + use xserv_tokenizer::Tokenizer; + + let args: Vec = std::env::args().collect(); + let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect(); + let ckpt = positionals + .first() + .map(|s| PathBuf::from(s.as_str())) + .expect("usage: eval_arith [flags]"); + let tok_path = positionals + .get(1) + .map(|s| PathBuf::from(s.as_str())) + .unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json")); + + let n_heads = flag(&args, "--heads", 52usize); + let head_dim = flag(&args, "--head-dim", 32usize); + let n_layers = flag(&args, "--layers", 22usize); + let ffn = flag(&args, "--ffn", 6656usize); + let kv_heads = flag(&args, "--kv-heads", n_heads); + let max_new = flag(&args, "--max-tokens", 48usize); + let n_show = flag(&args, "--show", 8usize); + let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required"); + let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required"); + + // Prompts: skip the `#` header / blank lines and decode escaped newlines so the + // count and order line up with the gold file. + let prompts: Vec = std::fs::read_to_string(&prompts_file) + .unwrap_or_else(|e| panic!("read prompts {prompts_file}: {e}")) + .lines() + .map(str::trim) + .filter(|l| !l.is_empty() && !l.starts_with('#')) + .map(decode_escapes) + .collect(); + let golds: Vec = std::fs::read_to_string(&gold_file) + .unwrap_or_else(|e| panic!("read gold {gold_file}: {e}")) + .lines() + .map(str::trim) + .filter(|l| !l.is_empty()) + .map(|l| l.parse::().expect("gold line not an integer")) + .collect(); + assert_eq!( + prompts.len(), + golds.len(), + "prompt/gold count mismatch ({} vs {})", + prompts.len(), + golds.len() + ); + + assert!(device::device_count().unwrap() > 0, "no CUDA device"); + device::set_device(0).unwrap(); + let device = Device::Cuda(0); + + let tok = Tokenizer::from_file(&tok_path); + let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn) + .with_kv_heads(kv_heads); + let mut seed = 1u64; + let model = TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + fill(n, seed, 0.04) + } + }); + xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint"); + + println!( + "eval_arith: ckpt {} | {} prompts | max_new {}", + ckpt.display(), + prompts.len(), + max_new + ); + + let (mut n_boxed, mut n_correct) = (0usize, 0usize); + let mut shown = 0usize; + for (prompt, &gold) in prompts.iter().zip(&golds) { + let ids: Vec = tok.encode(prompt).into_iter().map(|t| t as i32).collect(); + let mut rng = 7u64; + let out = generate(&model, device, &ids, max_new, 0.0, &mut rng); + let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::>()); + let seg = first_answer_segment(&cont); + if parse_boxed_answer(seg).is_some() { + n_boxed += 1; + } + let ok = check_answer(seg, gold); + if ok { + n_correct += 1; + } + if shown < n_show { + let q = prompt.replace('\n', " "); + println!(" [{}] gold={gold} got={seg:?} {}", q, if ok { "OK" } else { "x" }); + shown += 1; + } + } + + let n = prompts.len() as f64; + println!( + "RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)", + n_boxed, + prompts.len(), + 100.0 * n_boxed as f64 / n, + n_correct, + prompts.len(), + 100.0 * n_correct as f64 / n, + ); +} diff --git a/docs/18-post-training-rl-sft.md b/docs/18-post-training-rl-sft.md index f3b60a2..c671fc4 100644 --- a/docs/18-post-training-rl-sft.md +++ b/docs/18-post-training-rl-sft.md @@ -285,16 +285,18 @@ gates as flash/GQA. ## Implementation log -### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box) +### M1 — SFT task baseline (landed) The verifiable task and its data pipeline are implemented and verified host-side (no CUDA -needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and -runs on dash5. +needed); the SFT run + eval ran on dash5 (1×5090). **Result: SFT moves answer-format +adherence 0% → 100%, with arithmetic correctness 8% — exactly the intended split (SFT buys +the format; correctness is M3/M4's job).** **Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):** -- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,99]` for `+/−`, `[0,12]` for `×` - (modest products); subtraction may be negative. +- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,999]` for `+/−`, `[0,99]` for `×` + (modest products); subtraction may be negative. (Ranges enlarged from the first cut to keep + the unique-key space ≫ requested rows — see the saturation guard below.) - User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT; the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve. - Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer` @@ -306,7 +308,19 @@ runs on dash5. **Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):** writes `arith_sft.tsv` (`userassistant` for `--sft-tsv`), `arith_eval_prompts.txt` (`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints). -Train rows are deduped; eval is held out from train (no leakage). +Train rows are deduped; eval is held out from train (no leakage). A **saturation guard** +(`unique_space()` + `assert need·5 ≤ space·4`) rejects requests that approach the unique-key +space, since deduped train + disjoint eval near saturation get pathologically slow (or, for +the disjoint-eval loop, never terminate). With the shipped defaults the space is ~2.01M keys, +so a 20 000 + 500 request is a tiny fraction (gen runs in ~0.2 s). + +**Scorer (`crates/xtrain-train/src/bin/eval_arith.rs`):** loads a checkpoint, greedily +generates a continuation per held-out prompt, isolates the first answer segment (cut at the +first `<|endoftext|>` then first newline), and reports two signals via the shared checker — +**format** (fraction emitting any `\boxed{int}`) and **correctness** (exact-match vs. gold). +This is the reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). It uses the *naive* +no-KV-cache sampler (full forward per token), so even 100 prompts is slow — concrete +motivation for M2 (the KV-cache decode engine). **Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the @@ -318,11 +332,31 @@ invariant (always checker-correct over 2000 samples), parser edge cases, and see A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0 train/eval leakage. -**Pending on dash5 (needs GPU):** -1. generate the dataset: `gen_arith_task --n --eval --seed 1 --out-dir `; -2. SFT from the v12 base: `train /arith_sft.tsv --sft-tsv --init-ckpt - --bf16 --recompute --flash --ckpt ` → the P0 reference/init checkpoint; -3. format eval: `greedy_sample --prompts-file /arith_eval_prompts.txt`, - then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a - one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only - buys the format; correctness is M3/M4's job. +**Run (dash5, 1×5090, from the v12 1.05B base):** +1. dataset: `gen_arith_task --n 20000 --eval 500 --seed 1 --out-dir ` → 20 000 train + + 500 held-out eval, 0 leakage. +2. SFT: `train /arith_sft.tsv --sft-tsv --init-ckpt --heads 52 + --head-dim 32 --kv-heads 13 --layers 22 --ffn 6656 --bf16 --recompute --flash --seq 256 + --batch 16 --steps 250 --max-lr 1e-4 --min-lr 1e-5 --ckpt arith_sft_v12.ckpt` → the P0 + reference/init checkpoint. Train loss 4.68 → ~0.34, best val 0.386, no OOM, ~4.3K tok/s. +3. eval: `eval_arith --prompts-file /arith_eval_prompts.txt + --gold-file /arith_eval_gold.txt --max-tokens 32`, base vs. SFT, on 100 held-out prompts. + +**M1 result (100 held-out prompts, greedy, max_new 32):** + +| checkpoint | format (`\boxed{}`) | correct (exact-match) | +|---------------------|----------------------|-----------------------| +| v12 base (pre-SFT) | 0 / 100 (0%) | 0 / 100 (0%) | +| arith SFT | **100 / 100 (100%)** | 8 / 100 (8%) | + +The base model never emits the format — it answers `"I don't know."` / restates the question +and stops. SFT moves format **0% → 100%**: every completion cleanly restates the equation and +boxes an integer (`46 * 80 = \boxed{3380}.`). Correctness is only **8%**: the format is fully +learned but the *arithmetic* is the base model's own weak capability — e.g. it boxes 3380 for +gold 3680, −10 for gold 5; it does get some right (`895 − 353 = \boxed{542}.` ✓). That residual +gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to close. + +**Gate met:** format 0% → 100% confirms the assistant-only SFT path is wired end-to-end; the +held-out correct > 0 confirms the checker + eval harness score real matches (not just format). +M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic +skill, which is downstream by design. diff --git a/docs/evolution.md b/docs/evolution.md index 355836d..e97e278 100644 --- a/docs/evolution.md +++ b/docs/evolution.md @@ -86,6 +86,17 @@ scaling 科学线(v0–v8)收官后,项目重启回到本职「学训练 > 📌 两条 integration 发现(非回归,pre-existing,记账):① **DDP 三个测试并行会争 2 卡 deadlock** → 文档/测试用 `--test-threads=1`(或标 serial)跑。② **fresh-train md5 run-to-run 不定**——反向 atomicAdd 归约序非确定 → 有效的确定性闸门是**导出(export)重确定性**(同 ckpt 重导 safetensors md5 逐位一致),**不是** fresh-train 复现。 +## 三·六、Phase 3 后训练栈(SFT → KV-cache → DPO → GRPO,详见 [18-post-training-rl-sft.md](18-post-training-rl-sft.md)) + +Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(对齐方向)。锁定路线 DPO→GRPO(reward model 可选)、**rule-based 可验证 reward 优先**、**KV-cache 增量解码引擎前置自建**、任务取**可验证算术**(确定性 exact-match,给 RL 干净可证伪信号)。里程碑 M1(SFT baseline)→ M2(KV-cache 解码引擎,token-identical 闸门)→ M3(DPO)→ M4(GRPO)→ M5(可选 RM)。按维度落点: + +- **算法**:后训练损失族——SFT(assistant-only masking,已有)→ DPO(`seq_logprob` 算子 + Bradley-Terry/σ(Δ) 偏好损失,frozen reference)→ GRPO(group-relative advantage,无 critic + clipped PG + KL leash)。每条沿用 Phase 1/2 闸门规矩:新损失/算子有限差分 grad-check + PyTorch parity + 退化检查(β→0 / G=1 / ε→∞ / ref==policy)+ 一条可证伪「真在学」信号(reward margin↑ / 合成 RL overfit)。 +- **Infra**:**KV-cache 增量解码引擎(M2,前置)**是这一阶段的硬核——per-layer K/V cache + 单 token 增量 forward(prompt 灌一次 cache 后逐 token 解码)+ ragged 批量解码。硬闸门 = **解码逐 token 等价于全重算 greedy**(同 xserv 导出闭环的逐位纪律),并先记解码吞吐 baseline(profile-first)。它是 DPO 造对 + GRPO rollout 的共享底座。 +- **数据集**:可验证任务自带数据生成器——两操作数整数算术(`+ − ×`),rule-based checker 读 `\boxed{}` 做 exact-match,是 M1 SFT 数据 + M3 造对 + M4 GRPO reward 的单一共享 spec。 +- **模型架构**:复用 v12 1.05B 基座,不动架构。 + +**M1(SFT task baseline,已落地)**:可验证算术任务 + 数据生成器 + 评分器一套,host-side 9/9 单测过(masking、SFT-target 自洽 2000 样、parser 边界、种子确定性)。dash5 单卡从 v12 基座 SFT(loss 4.68→~0.34,best val 0.386)。**100 留出题 eval:格式 `\boxed{}` 习得率 base 0% → SFT 100%;算术正确率 8%。**——SFT 只买**格式**(0%→100% 干净落地),算术正确性是 base 模型本身弱项(如 `46*80` 框成 3380),正是 M3/M4 的可验证 reward 要去补的残差。一条诚实账:M1 用的是**朴素无 KV-cache 采样器**(每 token 全量 forward),100 题已经很慢——这正是 M2 解码引擎前置的动机。 + ## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)) - **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。