Compare commits
2 Commits
7a1fba95b5
...
9c70e99ae4
| Author | SHA1 | Date | |
|---|---|---|---|
| 9c70e99ae4 | |||
| ab32168dcc |
94
crates/xtrain-train/src/bin/gen_arith_task.rs
Normal file
94
crates/xtrain-train/src/bin/gen_arith_task.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Generate the M1 verifiable-arithmetic post-training dataset. Pure host tool (no
|
||||
//! CUDA): writes
|
||||
//! <out>/arith_sft.tsv user<TAB>assistant rows for `train --sft-tsv`
|
||||
//! <out>/arith_eval_prompts.txt greedy_sample `--prompts-file` format (held out)
|
||||
//! <out>/arith_eval_gold.txt parallel gold integers for the checker
|
||||
//!
|
||||
//! Eval problems are deduped against train (no leakage). The SFT rows carry just the
|
||||
//! user/assistant content; `data::load_sft_tsv_cached` adds the `User:/Assistant:`
|
||||
//! frame + `<|endoftext|>` and masks the prompt, so the eval prompt lines here
|
||||
//! reconstruct exactly that frame (`User: <q>\nAssistant:`, literal `\n` decoded by
|
||||
//! greedy_sample).
|
||||
//!
|
||||
//! Example:
|
||||
//! cargo run -p xtrain-train --release --bin gen_arith_task -- \
|
||||
//! --n 20000 --eval 500 --seed 1 --out-dir /dashscope-tmp/wjh/xtrain_post/arith
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use xtrain_train::task::{GenConfig, Op, gen_problem};
|
||||
|
||||
/// Look up the value that follows `name` in `args` and parse it as `T`.
///
/// Returns `default` only when `name` does not occur in `args` at all.
///
/// # Panics
///
/// Panics when `name` is present but no argument follows it, or when the following
/// argument does not parse as `T`. A mistyped value such as `--n 20k` must not
/// silently fall back to the default and produce a dataset of the wrong size.
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
    let idx = match args.iter().position(|a| a == name) {
        Some(idx) => idx,
        None => return default,
    };
    let raw = args
        .get(idx + 1)
        .unwrap_or_else(|| panic!("{name} expects a value"));
    raw.parse()
        .unwrap_or_else(|_| panic!("invalid value {raw:?} for {name}"))
}
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let n_train: usize = flag(&args, "--n", 20000);
|
||||
let n_eval: usize = flag(&args, "--eval", 500);
|
||||
let seed: u64 = flag(&args, "--seed", 1);
|
||||
let max_add: i64 = flag(&args, "--max-add", 99);
|
||||
let max_mul: i64 = flag(&args, "--max-mul", 12);
|
||||
let out_dir: PathBuf = args
|
||||
.iter()
|
||||
.position(|a| a == "--out-dir")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.map(PathBuf::from)
|
||||
.expect("--out-dir <dir> is required");
|
||||
|
||||
fs::create_dir_all(&out_dir).expect("create out dir");
|
||||
let cfg = GenConfig {
|
||||
max_add,
|
||||
max_mul,
|
||||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||
};
|
||||
let mut rng = seed.max(1);
|
||||
|
||||
// Train: dedup so the same problem is not repeated and so eval can be held out.
|
||||
let mut train_keys = HashSet::new();
|
||||
let mut tsv = BufWriter::new(File::create(out_dir.join("arith_sft.tsv")).expect("create tsv"));
|
||||
while train_keys.len() < n_train {
|
||||
let p = gen_problem(&mut rng, &cfg);
|
||||
if !train_keys.insert(p.key()) {
|
||||
continue;
|
||||
}
|
||||
writeln!(tsv, "{}\t{}", p.question(), p.sft_answer()).expect("write tsv");
|
||||
}
|
||||
tsv.flush().expect("flush tsv");
|
||||
|
||||
// Eval: disjoint from train (skip any key seen in train) and from itself.
|
||||
let mut prompts =
|
||||
BufWriter::new(File::create(out_dir.join("arith_eval_prompts.txt")).expect("create eval"));
|
||||
let mut golds =
|
||||
BufWriter::new(File::create(out_dir.join("arith_eval_gold.txt")).expect("create gold"));
|
||||
writeln!(prompts, "# verifiable arithmetic eval prompts (held out from arith_sft.tsv)")
|
||||
.expect("write header");
|
||||
let mut eval_keys = HashSet::new();
|
||||
while eval_keys.len() < n_eval {
|
||||
let p = gen_problem(&mut rng, &cfg);
|
||||
if train_keys.contains(&p.key()) || !eval_keys.insert(p.key()) {
|
||||
continue;
|
||||
}
|
||||
writeln!(prompts, "User: {}\\nAssistant:", p.question()).expect("write prompt");
|
||||
writeln!(golds, "{}", p.answer()).expect("write gold");
|
||||
}
|
||||
prompts.flush().expect("flush prompts");
|
||||
golds.flush().expect("flush golds");
|
||||
|
||||
println!(
|
||||
"wrote {} train rows + {} eval prompts to {} (ops=+,-,* max_add={} max_mul={} seed={})",
|
||||
train_keys.len(),
|
||||
eval_keys.len(),
|
||||
out_dir.display(),
|
||||
max_add,
|
||||
max_mul,
|
||||
seed
|
||||
);
|
||||
}
|
||||
@@ -124,10 +124,9 @@ impl Corpus {
|
||||
let answer = format!(" {assistant}\n<|endoftext|>");
|
||||
let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
|
||||
let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
|
||||
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
|
||||
labels.extend(answer_ids.iter().copied());
|
||||
tokens.extend(prompt_ids);
|
||||
tokens.extend(answer_ids);
|
||||
let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids);
|
||||
tokens.extend(row_tokens);
|
||||
labels.extend(row_labels);
|
||||
}
|
||||
assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
|
||||
write_u16_cache(&token_cache, &tokens);
|
||||
@@ -291,6 +290,20 @@ fn decode_tsv_escapes(s: &str) -> String {
|
||||
s.replace("\\n", "\n").replace("\\t", "\t")
|
||||
}
|
||||
|
||||
/// Assemble the `(tokens, labels)` pair for one SFT example from prompt and answer
/// ids that have already been tokenized.
///
/// `tokens` is the prompt ids followed by the answer ids. `labels` has the same
/// length: each prompt position holds the ignore-index `-100` (the value
/// `cross_entropy` is documented to skip) and each answer position holds the answer
/// id itself, so only the answer (+ EOS) tokens contribute to the loss. The function
/// touches neither the tokenizer nor CUDA, which keeps the masking unit-testable.
fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
    const IGNORE_INDEX: i32 = -100;

    let tokens: Vec<i32> = prompt_ids.iter().chain(answer_ids).copied().collect();
    let labels: Vec<i32> = prompt_ids
        .iter()
        .map(|_| IGNORE_INDEX)
        .chain(answer_ids.iter().copied())
        .collect();
    (tokens, labels)
}
|
||||
|
||||
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
|
||||
/// sampling is reproducible from a single u64 seed.
|
||||
fn next_rand(state: &mut u64) -> u64 {
|
||||
@@ -299,3 +312,27 @@ fn next_rand(state: &mut u64) -> u64 {
|
||||
.wrapping_add(1442695040888963407);
|
||||
*state >> 16
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Prompt positions are masked with the ignore-index; answer positions keep
    /// their own ids as labels.
    #[test]
    fn sft_row_masks_prompt_supervises_answer() {
        let prompt_ids = [5, 6, 7];
        let answer_ids = [8, 9]; // in real use the answer ids end with the EOS token
        let (tokens, labels) = sft_row(&prompt_ids, &answer_ids);

        // Same length, prompt first, answer second.
        assert_eq!(tokens.len(), labels.len());
        assert_eq!(tokens, [5, 6, 7, 8, 9]);
        // -100 on every prompt slot, the answer id on every answer slot.
        assert_eq!(labels, [-100, -100, -100, 8, 9]);
    }

    /// An empty answer yields a fully masked row.
    #[test]
    fn sft_row_handles_empty_answer() {
        let no_answer: [i32; 0] = [];
        let (tokens, labels) = sft_row(&[1, 2], &no_answer);
        assert_eq!(tokens, [1, 2]);
        assert_eq!(labels, [-100, -100]);
    }
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
pub mod clip;
|
||||
pub mod data;
|
||||
pub mod schedule;
|
||||
pub mod task;
|
||||
|
||||
#[cfg(not(no_cuda))]
|
||||
pub mod checkpoint;
|
||||
|
||||
212
crates/xtrain-train/src/task.rs
Normal file
212
crates/xtrain-train/src/task.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
//! Verifiable arithmetic task (post-training, M1). A tiny two-operand integer
|
||||
//! arithmetic task with a deterministic, rule-based checker: the assistant must end
|
||||
//! its answer with `\boxed{N}`, and the reward is exact-match on `N`.
|
||||
//!
|
||||
//! This single module is the shared task spec for the whole post-training stack —
|
||||
//! M1 SFT-data generation, M3 DPO preference-pair construction, and M4 GRPO reward
|
||||
//! scoring all parse/score through here, so the task lives in exactly one place.
|
||||
//!
|
||||
//! Host-only (no CUDA): generation + parsing + checking are pure, so this compiles
|
||||
//! and unit-tests on a GPU-less host.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Binary operations the arithmetic task can ask about.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Op {
    /// Addition, rendered as `+`.
    Add,
    /// Subtraction, rendered as `-`.
    Sub,
    /// Multiplication, rendered as `*`.
    Mul,
}

impl Op {
    /// The single-character operator used when rendering a problem as text.
    pub fn symbol(self) -> char {
        use Op::{Add, Mul, Sub};
        match self {
            Add => '+',
            Sub => '-',
            Mul => '*',
        }
    }
}

/// Displays as the bare operator symbol (no padding or alignment is applied).
impl fmt::Display for Op {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Write::write_char(f, self.symbol())
    }
}
|
||||
|
||||
/// A single two-operand arithmetic problem.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Problem {
|
||||
pub a: i64,
|
||||
pub b: i64,
|
||||
pub op: Op,
|
||||
}
|
||||
|
||||
impl Problem {
|
||||
/// The exact integer answer (the verifiable gold label).
|
||||
pub fn answer(self) -> i64 {
|
||||
match self.op {
|
||||
Op::Add => self.a + self.b,
|
||||
Op::Sub => self.a - self.b,
|
||||
Op::Mul => self.a * self.b,
|
||||
}
|
||||
}
|
||||
|
||||
/// The user-turn question text. No template wrapping — the SFT loader
|
||||
/// (`data::load_sft_tsv_cached`) adds the `User:/Assistant:` frame.
|
||||
pub fn question(self) -> String {
|
||||
format!("What is {} {} {}?", self.a, self.op, self.b)
|
||||
}
|
||||
|
||||
/// The assistant-turn SFT target: restate the equation and end with the boxed
|
||||
/// answer. This teaches the answer FORMAT (the checker only reads `\boxed{}`);
|
||||
/// arithmetic correctness is what DPO (M3) / GRPO (M4) later improve.
|
||||
pub fn sft_answer(self) -> String {
|
||||
format!("{} {} {} = \\boxed{{{}}}.", self.a, self.op, self.b, self.answer())
|
||||
}
|
||||
|
||||
/// A stable dedup key, so eval problems can be held out from train.
|
||||
pub fn key(self) -> (i64, char, i64) {
|
||||
(self.a, self.op.symbol(), self.b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Operand-range configuration for problem sampling. Multiplication uses a smaller
|
||||
/// range (`max_mul`) so products stay modest; add/sub use `max_add`.
|
||||
#[derive(Clone)]
|
||||
pub struct GenConfig {
|
||||
pub max_add: i64,
|
||||
pub max_mul: i64,
|
||||
pub ops: Vec<Op>,
|
||||
}
|
||||
|
||||
impl Default for GenConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_add: 99,
|
||||
max_mul: 12,
|
||||
ops: vec![Op::Add, Op::Sub, Op::Mul],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample one problem deterministically from the LCG state `rng`. Operands are drawn
|
||||
/// in `[0, max]` per the op; subtraction may yield a negative answer (the checker /
|
||||
/// parser handle a leading `-`).
|
||||
pub fn gen_problem(rng: &mut u64, cfg: &GenConfig) -> Problem {
|
||||
let op = cfg.ops[(next_rand(rng) as usize) % cfg.ops.len()];
|
||||
let max = if op == Op::Mul { cfg.max_mul } else { cfg.max_add };
|
||||
let a = rand_range(rng, max);
|
||||
let b = rand_range(rng, max);
|
||||
Problem { a, b, op }
|
||||
}
|
||||
|
||||
/// Extract the integer from the last well-formed `\boxed{...}` in `text`.
///
/// Boxes are scanned left to right. Each box runs from `\boxed{` to the next `}`;
/// its contents are trimmed and parsed as `i64`. A box whose contents do not parse
/// is skipped and leaves any earlier result in place, and scanning stops at a box
/// that is never closed. Returns `None` when no box yields an integer. Taking the
/// last one lets a model that emits intermediate boxes be scored on its final answer.
pub fn parse_boxed_answer(text: &str) -> Option<i64> {
    const OPEN: &str = "\\boxed{";

    let mut cursor = text;
    let mut latest = None;
    while let Some((_, tail)) = cursor.split_once(OPEN) {
        let (body, remainder) = match tail.split_once('}') {
            Some(parts) => parts,
            // Unterminated box: nothing further can be well-formed.
            None => break,
        };
        // A parse failure keeps whatever was found before.
        latest = body.trim().parse::<i64>().ok().or(latest);
        cursor = remainder;
    }
    latest
}
|
||||
|
||||
/// Verifiable reward: does the completion's boxed answer exactly match `gold`?
|
||||
pub fn check_answer(completion: &str, gold: i64) -> bool {
|
||||
parse_boxed_answer(completion) == Some(gold)
|
||||
}
|
||||
|
||||
/// `[0, max]` inclusive draw from the LCG.
|
||||
fn rand_range(rng: &mut u64, max: i64) -> i64 {
|
||||
debug_assert!(max >= 0);
|
||||
(next_rand(rng) % (max as u64 + 1)) as i64
|
||||
}
|
||||
|
||||
/// Advance the 64-bit LCG `state` one step and return the next pseudo-random value.
///
/// Same multiplier, increment and output shift as the dataset sampler
/// (`data::next_rand`), kept local so the task module stays dependency-free and
/// host-only.
fn next_rand(state: &mut u64) -> u64 {
    *state = state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    // In a power-of-two-modulus LCG, state bit k repeats every 2^(k+1) steps. Callers
    // reduce the output with `%`, so exposing the lowest state bits (as `>> 1` did)
    // made `value % 4` cycle with period 8 and left about half of the add/sub operand
    // pairs unreachable. Discard the 16 weakest bits instead.
    *state >> 16
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Answer, question text and SFT target for hand-built problems.
    #[test]
    fn answer_question_and_sft_target() {
        let product = Problem {
            a: 12,
            b: 13,
            op: Op::Mul,
        };
        assert_eq!(product.answer(), 156);
        assert_eq!(product.question(), "What is 12 * 13?");
        assert_eq!(product.sft_answer(), "12 * 13 = \\boxed{156}.");

        // Subtraction is allowed to go negative.
        let difference = Problem {
            a: 3,
            b: 8,
            op: Op::Sub,
        };
        assert_eq!(difference.answer(), -5);
    }

    /// The parser scores the last box and rejects malformed input.
    #[test]
    fn parse_takes_last_boxed_and_handles_edges() {
        let cases: [(&str, Option<i64>); 6] = [
            ("\\boxed{3} then \\boxed{156}.", Some(156)),
            ("\\boxed{-7}", Some(-7)),
            ("\\boxed{ 42 }", Some(42)),
            ("no box here", None),
            ("\\boxed{abc}", None),
            ("\\boxed{unterminated", None),
        ];
        for (text, expected) in cases {
            assert_eq!(parse_boxed_answer(text), expected);
        }
    }

    /// The reward is exact match on the boxed integer.
    #[test]
    fn check_is_exact_match() {
        let gold = 156;
        assert!(check_answer("the result is \\boxed{156}.", gold));
        assert!(!check_answer("the result is \\boxed{155}.", gold));
        assert!(!check_answer("no boxed answer at all", gold));
    }

    /// M1 data invariant: every generated SFT target must check against its own
    /// problem's gold answer, across all ops and operands.
    #[test]
    fn sft_target_is_always_self_consistent() {
        let cfg = GenConfig::default();
        let mut rng = 12345u64;
        for _ in 0..2000 {
            let p = gen_problem(&mut rng, &cfg);
            let target = p.sft_answer();
            assert!(
                check_answer(&target, p.answer()),
                "self-inconsistent SFT target for {p:?}"
            );
        }
    }

    /// Two generators started from the same seed emit the same problem stream.
    #[test]
    fn generation_is_deterministic_from_seed() {
        let cfg = GenConfig::default();
        let mut left = 7u64;
        let mut right = 7u64;
        for _ in 0..200 {
            let lhs = gen_problem(&mut left, &cfg).key();
            let rhs = gen_problem(&mut right, &cfg).key();
            assert_eq!(lhs, rhs);
        }
    }
}
|
||||
328
docs/18-post-training-rl-sft.md
Normal file
328
docs/18-post-training-rl-sft.md
Normal file
@@ -0,0 +1,328 @@
|
||||
# Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document
|
||||
|
||||
> Status: **DESIGN — decisions locked; implementation in progress.** M1 infra has landed
> (see "Implementation log"). This doc proposes the scope, the staged build, the new infra pieces,
|
||||
> and the correctness gates for a standard post-training stack on top of the xtrain
|
||||
> training framework. Decisions D1–D4 are resolved (see "Resolved decisions"):
|
||||
> **DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode
|
||||
> engine built up front · a verifiable task as the optimization/eval target.**
|
||||
|
||||
## Goal
|
||||
|
||||
Build a **standard, from-scratch post-training infrastructure** — the systems layer that
|
||||
turns a pretrained base LM into an aligned chat model — and use it to run chat
|
||||
alignment. The deliverable that matters here is the **infra and the lessons**, not the
|
||||
end-to-end chat quality (see the project's learning-axis framing). Each stage should
|
||||
teach exactly one new post-training systems concept and ship with a hard correctness
|
||||
gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical
|
||||
default paths, profile-first).
|
||||
|
||||
Concretely we want to be able to answer, with our own code:
|
||||
|
||||
- How does **offline preference optimization (DPO)** differ from SFT in the training
|
||||
loop — what is the reference model, why two forwards, what is the loss?
|
||||
- How does a **reward model** turn preferences into a scalar signal?
|
||||
- How does **online RL (GRPO)** actually run — the rollout engine, reward scoring,
|
||||
group-relative advantage, the clipped policy-gradient update, the KL leash?
|
||||
- Where are the **memory and throughput** pressure points that make post-training infra
|
||||
different from pretraining infra (multiple models resident, generation in the loop)?
|
||||
|
||||
## Baseline: what already exists vs. what is missing
|
||||
|
||||
What the framework already gives us (verified in code, reused as-is):
|
||||
|
||||
| capability | where | reuse for post-training |
|
||||
|---|---|---|
|
||||
| batched forward → logits `[B*S, vocab]` | `model.rs::forward_batched` | logprob extraction for DPO/RM/GRPO |
|
||||
| cross-entropy with **ignore-index −100** | `ops.rs::cross_entropy`, `nn.cu` | assistant-only / completion-only masking |
|
||||
| assistant-only **SFT** (TSV, masked labels) | `data.rs::load_sft_tsv_cached` (commit `fbf4ac2`) | SFT chat baseline = DPO init + reference |
|
||||
| bf16 mixed precision, fp32 master | `with_compute_dtype` | policy + frozen reference both bf16 compute |
|
||||
| recompute / flash / grad-accum | `with_recompute` / `with_flash` / `--accum-steps` | bound activation memory with 2–3 models resident |
|
||||
| DDP (thread + process-per-GPU) | `xtrain-distributed` | data-parallel post-training |
|
||||
| AdamW + clip + LR sched + checkpoint | `xtrain-optim`, `checkpoint.rs`, `schedule.rs` | unchanged optimizer path |
|
||||
| single-seq greedy/temperature sampling | `sample.rs::generate` | **slow** rollout fallback (no KV cache) |
|
||||
|
||||
What is **missing** and must be built (these are the actual lessons):
|
||||
|
||||
1. **Per-sequence completion logprob** — a way to read `Σ log πθ(y_t | x, y_<t)` over the
|
||||
completion tokens of a sequence. CE gives a *mean* scalar; DPO/GRPO need a *per-sequence
|
||||
masked sum*. New op or thin wrapper over the CE per-row machinery.
|
||||
2. **Frozen reference model** held in memory alongside the trainable policy (no grad, no
|
||||
optimizer), or its logprobs precomputed and cached.
|
||||
3. **Pairwise preference loss** (DPO) and **Bradley-Terry ranking loss** (RM).
|
||||
4. **Reward head** — a `[dim,1]` scalar head reading the last non-pad position (RM only).
|
||||
5. **Rollout / generation engine** — batched autoregressive sampling. Current `generate`
|
||||
is single-sequence and re-runs the full forward each step (no KV cache). Online RL needs
|
||||
batched rollouts; a real **KV-cache incremental-decode engine** is the centerpiece infra
|
||||
build.
|
||||
6. **GRPO machinery** — group sampling, group-relative advantage, clipped PG loss, KL
|
||||
penalty, the actor-learner loop.
|
||||
|
||||
## The post-training landscape — where the infra lives
|
||||
|
||||
```
|
||||
data models in memory new systems concept
|
||||
SFT (prompt, answer) policy loss masking (have it)
|
||||
DPO (prompt, chosen, reject) policy + ref(frozen) dual forward, pairwise logσ loss
|
||||
RM (prompt, chosen, reject) reward model scalar head, ranking loss
|
||||
PPO prompts + reward source policy+ref+RM+critic rollout + GAE + clipped PG (4 models)
|
||||
GRPO prompts + reward source policy+ref(+RM) rollout + group baseline + clipped PG
|
||||
```
|
||||
|
||||
The pedagogical ladder is **SFT → DPO → (RM) → GRPO**. DPO is the cheapest "real" alignment
|
||||
method (no generation, no reward model, reuses the training loop almost verbatim) and is the
|
||||
right first rung. GRPO is chosen over PPO as the online-RL rung because it **drops the value
|
||||
critic** (group-relative advantage replaces the learned baseline) — that removes a whole
|
||||
model and the GAE machinery while still teaching the complete online-RL loop. PPO is noted
|
||||
as an optional later extension, not a primary target.
|
||||
|
||||
## Proposed scope & sequencing (recommended path)
|
||||
|
||||
> ✅ **DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward
|
||||
> model) optional.** With D3 locked to "KV-cache engine up front", the engine becomes a
|
||||
> foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective
|
||||
> build order: **P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional)** (see
|
||||
> "Milestones").
|
||||
|
||||
### Stage P0 — SFT chat baseline (light; mostly reuse)
|
||||
|
||||
Goal: a clean SFT checkpoint to serve as **both the DPO/GRPO init and the frozen
|
||||
reference**. With D4 = verifiable task, P0 SFT teaches the **task format** (e.g. arithmetic
|
||||
prompts → a parseable answer such as `\boxed{N}`) so the model emits checker-readable
|
||||
completions; the same template is reused by rollout and eval. The current SFT (commit
|
||||
`fbf4ac2`) already does single-turn assistant-only masking; P0 only adds what alignment
|
||||
needs:
|
||||
|
||||
- a fixed **chat template** (the `User:/Assistant:` + `<|endoftext|>` format already used,
|
||||
promoted to a documented constant shared by SFT data prep, rollout, and eval),
|
||||
- optional **multi-turn masking** (supervise every assistant turn, mask user turns),
|
||||
- optional **sequence packing** (concatenate examples to fill `seq`, reset attention/RoPE
|
||||
per example — note `forward_batched` already isolates sequences, so packing = careful
|
||||
index bookkeeping, not new attention code).
|
||||
|
||||
Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak
|
||||
loss across example boundaries. **Hypothesis:** a documented chat template + multi-turn mask
|
||||
gives a reproducible SFT reference without changing the training numerics for single-turn data
|
||||
(bit-identical to `fbf4ac2` on single-turn input).
|
||||
|
||||
### Stage P1 — DPO (offline preference optimization) ⭐ first real method
|
||||
|
||||
New infra:
|
||||
|
||||
1. **Preference data — constructed from the verifiable checker (D4).** On a verifiable task
|
||||
there is no off-the-shelf preference set, so we build pairs: sample several completions
|
||||
per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone),
|
||||
score each with the rule-based checker, take a **correct** completion as `chosen` and an
|
||||
**incorrect** one as `rejected`. This is a one-time offline data-prep step; DPO training
|
||||
itself is then static. Tokenize each as `template(prompt) + completion + EOS`; build a
|
||||
completion mask (prompt = masked).
|
||||
2. **`seq_logprob(logits, target_ids, mask) → [B]`**: per-sequence sum of
|
||||
`log softmax(logits)[target]` over masked positions. Implement by reusing the CE per-row
|
||||
path (CE per-row = `−log πθ(target)`), summing `−per_row` over the mask. Add a grad-checked
|
||||
op so the backward is exact.
|
||||
3. **Frozen reference** `πref`: load the SFT checkpoint into a second model in **eval/no-grad**
|
||||
bf16. Its logprobs are **constants** in the loss. Optimization to teach: **precompute and
|
||||
cache reference logprobs** once over the dataset → the reference model need not stay
|
||||
resident during training (one model in memory, like SFT).
|
||||
4. **DPO loss** (Rafailov et al.): with
|
||||
`Δ = β[(logπθ(yw|x) − logπref(yw|x)) − (logπθ(yl|x) − logπref(yl|x))]`,
|
||||
`L = −log σ(Δ)`. Only `πθ` terms carry gradient.
|
||||
|
||||
Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached
|
||||
logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably.
|
||||
|
||||
Correctness gates:
|
||||
- `seq_logprob` finite-difference grad-check (tiny model).
|
||||
- DPO-loss + grad **PyTorch parity** (the project's standard gate).
|
||||
- **Degenerate checks**: `πθ == πref` at init ⇒ `Δ = 0`, `L = log 2`, implicit reward 0;
|
||||
`β → 0` ⇒ gradient → 0.
|
||||
- **Health metric**: chosen−rejected **reward margin** rises over training; accuracy
|
||||
(margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a
|
||||
sufficient signal).
|
||||
|
||||
Application: DPO on the checker-built preference pairs from item 1 (a correct completion as
`chosen`, an incorrect one as `rejected`) for the verifiable task. This is the **offline
alignment deliverable**.
|
||||
|
||||
### Stage P2 — Reward model (Bradley-Terry) — OPTIONAL
|
||||
|
||||
> ✅ **DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first.** GRPO
|
||||
> brings up on the deterministic checker; a learned reward model is **deferred/optional** (only
|
||||
> if we later want general-chat GRPO). So this whole stage is optional and not on the critical
|
||||
> path.
|
||||
|
||||
New infra: a **scalar reward head** (`[dim,1]`) reading the hidden state at the last
|
||||
non-pad position; **ranking loss** `−log σ(r(x,yw) − r(x,yl))`. Reuses the preference data
|
||||
and the dual-sequence forward from P1.
|
||||
|
||||
Gates: ranking-loss grad-check; held-out **pairwise accuracy** (`r_w > r_l`); a frozen RM
|
||||
loads/serves the scalar correctly.
|
||||
|
||||
### Stage P3 — GRPO (online RL, critic-free) ⭐ the deep infra lesson
|
||||
|
||||
This is the centerpiece. It introduces **generation inside the training loop**.
|
||||
|
||||
**(a) Rollout / generation engine — built up front (its own milestone).**
|
||||
|
||||
> ✅ **DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine
|
||||
> up front**, as a foundational milestone *before* DPO/GRPO, rather than starting naive. It is
|
||||
> then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted:
|
||||
> front-loads the single hardest build and delays the first alignment result, in exchange for
|
||||
> a real generation engine and a clean, isolated infra lesson.
|
||||
|
||||
The engine: per-layer **K/V cache**, **single-token incremental forward** (process the prompt
|
||||
once to fill the cache, then decode one token at a time), **batched ragged decode** (B prompts
|
||||
× G samples; sequences hit EOS at different lengths → finished-mask / left-padding /
|
||||
compaction). The current attention assumes a full causal window over `seq`; incremental decode
|
||||
needs a **decode-time attention path** — query length 1 against cached K/V of length `t`, with
|
||||
RoPE position = `t`. This reuses the composed SDPA shapes (one-row query), so it can land as a
|
||||
distinct code path without disturbing the training attention (flash/GQA/composed unchanged).
|
||||
|
||||
Hard gate (the centerpiece correctness lesson): **KV-cache decode == full-recompute decode,
|
||||
token-identical** greedy output — the same byte-/token-identical discipline the project uses
|
||||
for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs.
|
||||
per-token decode) is recorded here, before any rollout optimization (profile-first).
|
||||
|
||||
**(b) Reward scoring.** Rule-based verifiable reward first (e.g., exact-match on a synthetic
|
||||
arithmetic/format task) or RM from P2. Returns a scalar per completion.
|
||||
|
||||
**(c) Group-relative advantage.** Sample `G` completions per prompt; advantage
|
||||
`A_i = (r_i − mean(r_group)) / (std(r_group) + δ)`, where `δ` is a small numerical-stability constant (distinct from the clip range `ε` in (d)). No critic, no GAE.
|
||||
|
||||
**(d) Clipped policy-gradient loss with KL leash.** Per completion token,
|
||||
`ρ_t = exp(logπθ_t − logπθ_old_t)` (old = policy at rollout time), token loss
|
||||
`−min(ρ_t A, clip(ρ_t, 1±ε) A) + βKL(πθ‖πref)`, masked to completion tokens. KL via the k3
|
||||
estimator.
|
||||
|
||||
**(e) Actor-learner loop.** sample prompt batch → rollout G each → score → advantage →
|
||||
capture `πθ_old` logprobs → K inner epochs of clipped PG updates → repeat. Reference `πref`
|
||||
fixed throughout.
|
||||
|
||||
Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations.
|
||||
Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded
|
||||
(tokens/s of generation vs. update) **before** any rollout optimization, per the project's
|
||||
profile-first rule.
|
||||
|
||||
Correctness gates:
|
||||
- PG-loss finite-diff grad-check.
|
||||
- **Degenerate checks**: `G = 1` ⇒ advantage 0 ⇒ no PG signal, only KL; `ε → ∞` ⇒ vanilla PG;
|
||||
`β = 0` ⇒ no KL term.
|
||||
- (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a
|
||||
prerequisite of GRPO.)
|
||||
- **Synthetic RL overfit**: on a tiny verifiable task with a known optimum, mean reward must
|
||||
rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal
|
||||
that the loop is correct, independent of fuzzy chat quality).
|
||||
|
||||
## Evaluation
|
||||
|
||||
- **Offline (DPO/RM)**: reward margin, preference accuracy, KL drift from reference, plus the
|
||||
fixed chat-prompt generation suite (`scripts/chat_alpha_fixed_prompts.txt`) judged before/
|
||||
after — reusing and extending the doc-13 recommendation for a generation-based eval harness
|
||||
(exact-match math, code syntax, stop-token, refusal appropriateness, corruption).
|
||||
- **Online (GRPO)**: mean reward curve, KL-to-reference, response length, the verifiable-task
|
||||
pass rate, and the same fixed-prompt suite.
|
||||
- **Selection by generation eval, not loss** — the recurring doc-13/v11 lesson: lower
|
||||
post-training loss did not mean better generations.
|
||||
|
||||
## Memory & throughput budget (8× RTX 5090, 1.05B model, indicative)
|
||||
|
||||
- Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy
|
||||
training state (weights + fp32 master + AdamW moments + grads) alone ~17 GB before activations. Recompute + grad-accum keep activations
|
||||
small; this is why post-training reuses the Phase-1/2 memory levers unchanged.
|
||||
- DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits.
|
||||
- GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new
|
||||
variable. **Generation, not the update, is expected to be the throughput bottleneck** — to be
|
||||
measured, not assumed.
|
||||
|
||||
## Correctness-gate philosophy (unchanged from Phase 1/2)
|
||||
|
||||
Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity
|
||||
on loss + grads where applicable, (3) explicit degenerate-case sanity checks (β→0, G=1,
|
||||
ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic
|
||||
RL overfit), and (5) **no change to the default training path** when post-training flags are
|
||||
off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference
|
||||
gates as flash/GQA.
|
||||
|
||||
## Risks & tradeoffs
|
||||
|
||||
- **Rollout engine is the long pole.** A correct KV-cache incremental-decode path is a real
|
||||
build (decode-time attention, ragged batch). Mitigation (per D3): build it up front as its own
isolated, separately-gated milestone (M2), with the existing full-recompute `generate` as the
token-identical reference.
|
||||
- **RL is finicky.** KL leash, advantage normalization, clip range, reward hacking. Mitigation:
|
||||
synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward.
|
||||
- **Reward-model noise** can mislead GRPO. Mitigation: rule-based reward first.
|
||||
- **Tokenizer (KI-4)** — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here.
|
||||
- **Two/three resident models** raise memory; bounded by recompute/accum and (for DPO) reference
|
||||
logprob caching.
|
||||
|
||||
## Resolved decisions (aligned 2026-06-29)
|
||||
|
||||
- **D1 — Scope & sequencing → DPO → GRPO, reward model optional.**
|
||||
- **D2 — Online-RL reward source → rule-based / verifiable reward first** (RM deferred/optional).
|
||||
- **D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front** (not
|
||||
naive-first), as a foundational milestone before DPO/GRPO.
|
||||
- **D4 — Alignment task / eval target → a verifiable task** (arithmetic/format/GSM8K-style) with
|
||||
a deterministic exact-match reward, for a clean, falsifiable RL signal.
|
||||
|
||||
## Milestones (locked order)
|
||||
|
||||
1. **M1 — P0 SFT task baseline.** Chat template + assistant-only masking on the verifiable
|
||||
task; produces the reference + init checkpoint. Gate: masking unit test; single-turn
|
||||
bit-identical to `fbf4ac2`.
|
||||
2. **M2 — KV-cache decode engine** (D3, up front). Per-layer K/V cache + incremental
|
||||
decode-time attention + batched ragged decode. Gate: **token-identical to full-recompute
|
||||
greedy**; record decode throughput baseline.
|
||||
3. **M3 — P1 DPO.** Verifiable-checker pair construction (via M2) → `seq_logprob` op
|
||||
(grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO
|
||||
training loop → run + reward-margin / preference-accuracy curve.
|
||||
4. **M4 — P3 GRPO.** Group rollout (M2) + rule-based reward + group-relative advantage +
|
||||
clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; **synthetic
|
||||
verifiable-task RL-overfit** (mean reward → known optimum) → verifiable-task GRPO run.
|
||||
5. **M5 (optional) — P2 reward model.** Scalar head + ranking loss + pairwise-accuracy gate;
|
||||
enables GRPO-with-RM for general chat.
|
||||
|
||||
> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
|
||||
> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.
|
||||
|
||||
## Implementation log
|
||||
|
||||
### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box)
|
||||
|
||||
The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
|
||||
needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and
|
||||
runs on dash5.
|
||||
|
||||
**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
|
||||
|
||||
- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,99]` for `+/−`, `[0,12]` for `×`
|
||||
(modest products); subtraction may be negative.
|
||||
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
|
||||
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
|
||||
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
|
||||
(exact match vs. gold). This is the single shared checker reused by M3 (pair construction)
|
||||
and M4 (GRPO reward).
|
||||
- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly
|
||||
probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`).
|
||||
|
||||
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
|
||||
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
|
||||
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
|
||||
Train rows are deduped; eval is held out from train (no leakage).
|
||||
|
||||
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
|
||||
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
|
||||
single-turn path is bit-identical to `fbf4ac2`).
|
||||
|
||||
**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass,
|
||||
including: `sft_row` masks the prompt to `-100` and supervises the answer, the SFT-target self-consistency
|
||||
invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism.
|
||||
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
|
||||
train/eval leakage.
|
||||
|
||||
**Pending on dash5 (needs GPU):**
|
||||
1. generate the dataset: `gen_arith_task --n <N> --eval <M> --seed 1 --out-dir <dir>`;
|
||||
2. SFT from the v12 base: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt>
|
||||
<arch flags> --bf16 --recompute --flash --ckpt <out.ckpt>` → the P0 reference/init checkpoint;
|
||||
3. format eval: `greedy_sample <out.ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt`,
|
||||
then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a
|
||||
one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only
|
||||
buys the format; correctness is M3/M4's job.
|
||||
Reference in New Issue
Block a user