post-train: M1 — verifiable-arith eval scorer + SFT format-baseline result

eval_arith: load ckpt, greedy-generate per held-out prompt, parse \boxed{}
via the shared task checker, report format(boxed) + correctness pass-rates.
Reused as the verifiable-eval harness for M3 (DPO) / M4 (GRPO).

M1 result (100 held-out prompts, v12 1.05B base): SFT moves answer-format
adherence 0% -> 100%, arithmetic correctness 8% -- the intended split (SFT
buys the format; correctness is the verifiable-reward job of M3/M4). Logged
in docs/18 implementation log + a Phase-3 row in docs/evolution.md.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 11:13:19 +08:00
parent cb64604496
commit 1574e21d89
3 changed files with 248 additions and 14 deletions

View File

@@ -0,0 +1,189 @@
//! Verifiable-task eval (post-training, M1+). Load a checkpoint, greedily generate an
//! answer for each held-out arithmetic prompt, parse the `\boxed{}` answer, and report
//! the exact-match pass-rate against the gold file. Two signals are printed:
//! **format** (fraction that emitted any boxed integer) and **correctness** (fraction
//! whose boxed answer matches gold). This is the M1 format-baseline metric and the
//! reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO).
//!
//! eval_arith <ckpt> <tokenizer.json> --heads 52 --head-dim 32 --kv-heads 13 \
//! --layers 22 --ffn 6656 \
//! --prompts-file <dir>/arith_eval_prompts.txt \
//! --gold-file <dir>/arith_eval_gold.txt --max-tokens 48 --show 8
#[cfg(no_cuda)]
fn main() {
eprintln!("eval_arith: built without CUDA (no_cuda); run on a GPU host (dash5).");
}
#[cfg(not(no_cuda))]
use std::path::PathBuf;
#[cfg(not(no_cuda))]
use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer};
#[cfg(not(no_cuda))]
use xtrain_tensor::Device;
#[cfg(not(no_cuda))]
use xtrain_train::sample::generate;
#[cfg(not(no_cuda))]
use xtrain_train::task::{check_answer, parse_boxed_answer};
// Same deterministic LCG init scheme as bin/train.rs / bin/greedy_sample.rs (the
// values are overwritten by the loaded checkpoint; init just shapes the tensors).
#[cfg(not(no_cuda))]
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
#[cfg(not(no_cuda))]
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|s| s.parse().ok())
.unwrap_or(default)
}
#[cfg(not(no_cuda))]
fn flag_value(args: &[String], name: &str) -> Option<String> {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.cloned()
}
#[cfg(not(no_cuda))]
fn decode_escapes(s: &str) -> String {
s.replace("\\n", "\n").replace("\\t", "\t")
}
/// The model keeps generating past the answer (no EOS stop in the sampler), so keep
/// only the first answer "turn": cut at the first `<|endoftext|>` and then at the
/// first newline. The arithmetic answer is a single line, so this isolates it.
#[cfg(not(no_cuda))]
fn first_answer_segment(continuation: &str) -> &str {
let s = continuation
.split("<|endoftext|>")
.next()
.unwrap_or(continuation);
s.split('\n').next().unwrap_or(s)
}
#[cfg(not(no_cuda))]
fn main() {
use xserv_tokenizer::Tokenizer;
let args: Vec<String> = std::env::args().collect();
let positionals: Vec<&String> = args[1..].iter().filter(|a| !a.starts_with("--")).collect();
let ckpt = positionals
.first()
.map(|s| PathBuf::from(s.as_str()))
.expect("usage: eval_arith <ckpt> <tokenizer.json> [flags]");
let tok_path = positionals
.get(1)
.map(|s| PathBuf::from(s.as_str()))
.unwrap_or_else(|| PathBuf::from("/opt/wjh/models/gpt2/tokenizer.json"));
let n_heads = flag(&args, "--heads", 52usize);
let head_dim = flag(&args, "--head-dim", 32usize);
let n_layers = flag(&args, "--layers", 22usize);
let ffn = flag(&args, "--ffn", 6656usize);
let kv_heads = flag(&args, "--kv-heads", n_heads);
let max_new = flag(&args, "--max-tokens", 48usize);
let n_show = flag(&args, "--show", 8usize);
let prompts_file = flag_value(&args, "--prompts-file").expect("--prompts-file is required");
let gold_file = flag_value(&args, "--gold-file").expect("--gold-file is required");
// Prompts: skip the `#` header / blank lines and decode escaped newlines so the
// count and order line up with the gold file.
let prompts: Vec<String> = std::fs::read_to_string(&prompts_file)
.unwrap_or_else(|e| panic!("read prompts {prompts_file}: {e}"))
.lines()
.map(str::trim)
.filter(|l| !l.is_empty() && !l.starts_with('#'))
.map(decode_escapes)
.collect();
let golds: Vec<i64> = std::fs::read_to_string(&gold_file)
.unwrap_or_else(|e| panic!("read gold {gold_file}: {e}"))
.lines()
.map(str::trim)
.filter(|l| !l.is_empty())
.map(|l| l.parse::<i64>().expect("gold line not an integer"))
.collect();
assert_eq!(
prompts.len(),
golds.len(),
"prompt/gold count mismatch ({} vs {})",
prompts.len(),
golds.len()
);
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(&tok_path);
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
.with_kv_heads(kv_heads);
let mut seed = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.04)
}
});
xtrain_train::checkpoint::load_into(&ckpt, &model.params()).expect("load checkpoint");
println!(
"eval_arith: ckpt {} | {} prompts | max_new {}",
ckpt.display(),
prompts.len(),
max_new
);
let (mut n_boxed, mut n_correct) = (0usize, 0usize);
let mut shown = 0usize;
for (prompt, &gold) in prompts.iter().zip(&golds) {
let ids: Vec<i32> = tok.encode(prompt).into_iter().map(|t| t as i32).collect();
let mut rng = 7u64;
let out = generate(&model, device, &ids, max_new, 0.0, &mut rng);
let cont = tok.decode(&out[ids.len()..].iter().map(|&t| t as u32).collect::<Vec<_>>());
let seg = first_answer_segment(&cont);
if parse_boxed_answer(seg).is_some() {
n_boxed += 1;
}
let ok = check_answer(seg, gold);
if ok {
n_correct += 1;
}
if shown < n_show {
let q = prompt.replace('\n', " ");
println!(" [{}] gold={gold} got={seg:?} {}", q, if ok { "OK" } else { "x" });
shown += 1;
}
}
let n = prompts.len() as f64;
println!(
"RESULT format(boxed)={}/{} ({:.1}%) | correct={}/{} ({:.1}%)",
n_boxed,
prompts.len(),
100.0 * n_boxed as f64 / n,
n_correct,
prompts.len(),
100.0 * n_correct as f64 / n,
);
}

View File

@@ -285,16 +285,18 @@ gates as flash/GQA.
## Implementation log
### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box)
### M1 — SFT task baseline (landed)
The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and
runs on dash5.
needed); the SFT run + eval ran on dash5 (1×5090). **Result: SFT moves answer-format
adherence 0% → 100%, with arithmetic correctness 8% — exactly the intended split (SFT buys
the format; correctness is M3/M4's job).**
**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
- Two-operand integer arithmetic, ops `+ ×`; operands `[0,99]` for `+/`, `[0,12]` for `×`
(modest products); subtraction may be negative.
- Two-operand integer arithmetic, ops `+ ×`; operands `[0,999]` for `+/`, `[0,99]` for `×`
(modest products); subtraction may be negative. (Ranges enlarged from the first cut to keep
the unique-key space ≫ requested rows — see the saturation guard below.)
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
@@ -306,7 +308,19 @@ runs on dash5.
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
Train rows are deduped; eval is held out from train (no leakage).
Train rows are deduped; eval is held out from train (no leakage). A **saturation guard**
(`unique_space()` + `assert need·5 ≤ space·4`) rejects requests that approach the unique-key
space, since deduped train + disjoint eval near saturation get pathologically slow (or, for
the disjoint-eval loop, never terminate). With the shipped defaults the space is ~2.01M keys,
so a 20 000 + 500 request is a tiny fraction (gen runs in ~0.2 s).
**Scorer (`crates/xtrain-train/src/bin/eval_arith.rs`):** loads a checkpoint, greedily
generates a continuation per held-out prompt, isolates the first answer segment (cut at the
first `<|endoftext|>` then first newline), and reports two signals via the shared checker —
**format** (fraction emitting any `\boxed{int}`) and **correctness** (exact-match vs. gold).
This is the reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). It uses the *naive*
no-KV-cache sampler (full forward per token), so even 100 prompts is slow — concrete
motivation for M2 (the KV-cache decode engine).
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
@@ -318,11 +332,31 @@ invariant (always checker-correct over 2000 samples), parser edge cases, and see
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
train/eval leakage.
**Pending on dash5 (needs GPU):**
1. generate the dataset: `gen_arith_task --n <N> --eval <M> --seed 1 --out-dir <dir>`;
2. SFT from the v12 base: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt>
<arch flags> --bf16 --recompute --flash --ckpt <out.ckpt>` → the P0 reference/init checkpoint;
3. format eval: `greedy_sample <out.ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt`,
then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a
one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only
buys the format; correctness is M3/M4's job.
**Run (dash5, 1×5090, from the v12 1.05B base):**
1. dataset: `gen_arith_task --n 20000 --eval 500 --seed 1 --out-dir <dir>` → 20 000 train +
500 held-out eval, 0 leakage.
2. SFT: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt> --heads 52
--head-dim 32 --kv-heads 13 --layers 22 --ffn 6656 --bf16 --recompute --flash --seq 256
--batch 16 --steps 250 --max-lr 1e-4 --min-lr 1e-5 --ckpt arith_sft_v12.ckpt` → the P0
reference/init checkpoint. Train loss 4.68 → ~0.34, best val 0.386, no OOM, ~4.3K tok/s.
3. eval: `eval_arith <ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt
--gold-file <dir>/arith_eval_gold.txt --max-tokens 32`, base vs. SFT, on 100 held-out prompts.
**M1 result (100 held-out prompts, greedy, max_new 32):**
| checkpoint | format (`\boxed{}`) | correct (exact-match) |
|---------------------|----------------------|-----------------------|
| v12 base (pre-SFT) | 0 / 100 (0%) | 0 / 100 (0%) |
| arith SFT | **100 / 100 (100%)** | 8 / 100 (8%) |
The base model never emits the format — it answers `"I don't know."` / restates the question
and stops. SFT moves format **0% → 100%**: every completion cleanly restates the equation and
boxes an integer (`46 * 80 = \boxed{3380}.`). Correctness is only **8%**: the format is fully
learned but the *arithmetic* is the base model's own weak capability — e.g. it boxes 3380 for
gold 3680, 10 for gold 5; it does get some right (`895 353 = \boxed{542}.` ✓). That residual
gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to close.
**Gate met:** format 0% → 100% confirms the assistant-only SFT path is wired end-to-end; the
held-out correct > 0 confirms the checker + eval harness score real matches (not just format).
M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic
skill, which is downstream by design.

View File

@@ -86,6 +86,17 @@ scaling 科学线v0v8收官后项目重启回到本职「学训练
> 📌 两条 integration 发现非回归pre-existing记账① **DDP 三个测试并行会争 2 卡 deadlock** → 文档/测试用 `--test-threads=1`(或标 serial跑。② **fresh-train md5 run-to-run 不定**——反向 atomicAdd 归约序非确定 → 有效的确定性闸门是**导出export重确定性**(同 ckpt 重导 safetensors md5 逐位一致),**不是** fresh-train 复现。
## 三·六、Phase 3 后训练栈SFT → KV-cache → DPO → GRPO详见 [18-post-training-rl-sft.md](18-post-training-rl-sft.md)
Phase 1/2 **预训练全栈**学完后Phase 3 转向**后训练 infra**对齐方向)。锁定路线 DPOGRPOreward model 可选)、**rule-based 可验证 reward 优先**、**KV-cache 增量解码引擎前置自建**、任务取**可验证算术**确定性 exact-match RL 干净可证伪信号)。里程碑 M1SFT baseline)→ M2KV-cache 解码引擎token-identical 闸门)→ M3DPO)→ M4GRPO)→ M5可选 RM)。按维度落点
- **算法**后训练损失族——SFTassistant-only masking已有)→ DPO`seq_logprob` 算子 + Bradley-Terry/σ(Δ) 偏好损失frozen reference)→ GRPOgroup-relative advantage critic + clipped PG + KL leash)。每条沿用 Phase 1/2 闸门规矩新损失/算子有限差分 grad-check + PyTorch parity + 退化检查β0 / G=1 / ε→∞ / ref==policy+ 一条可证伪真在学信号reward margin / 合成 RL overfit)。
- **Infra****KV-cache 增量解码引擎M2前置**是这一阶段的硬核——per-layer K/V cache + token 增量 forwardprompt 灌一次 cache 后逐 token 解码+ ragged 批量解码硬闸门 = **解码逐 token 等价于全重算 greedy**(同 xserv 导出闭环的逐位纪律并先记解码吞吐 baselineprofile-first)。它是 DPO 造对 + GRPO rollout 的共享底座
- **数据集**可验证任务自带数据生成器——两操作数整数算术`+ ×`rule-based checker `\boxed{}` exact-match M1 SFT 数据 + M3 造对 + M4 GRPO reward 的单一共享 spec
- **模型架构**复用 v12 1.05B 基座不动架构
**M1SFT task baseline已落地**可验证算术任务 + 数据生成器 + 评分器一套host-side 9/9 单测过maskingSFT-target 自洽 2000 parser 边界种子确定性)。dash5 单卡从 v12 基座 SFTloss 4.68→~0.34best val 0.386)。**100 留出题 eval格式 `\boxed{}` 习得率 base 0% SFT 100%算术正确率 8%。**——SFT 只买**格式**0%→100% 干净落地算术正确性是 base 模型本身弱项 `46*80` 框成 3380正是 M3/M4 的可验证 reward 要去补的残差一条诚实账M1 用的是**朴素无 KV-cache 采样器** token 全量 forward100 题已经很慢——这正是 M2 解码引擎前置的动机
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)
- **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。