Compare commits

...

2 Commits

Author SHA1 Message Date
9c70e99ae4 post-train: M1 — verifiable arithmetic task + SFT data generator
First post-training milestone (docs/18). Lands the verifiable task + its data
pipeline, all verified host-side (no CUDA); the SFT run itself reuses the
existing --sft-tsv path on the GPU box.

- task.rs: the shared task spec — two-operand integer arithmetic, answer in
  \boxed{N}, with parse_boxed_answer + check_answer (exact-match rule-based
  reward). One module reused by M1 (SFT data), M3 (DPO pairs), M4 (GRPO reward).
- gen_arith_task bin: writes arith_sft.tsv (--sft-tsv format) + held-out
  arith_eval_prompts.txt (greedy_sample format) + arith_eval_gold.txt; train
  deduped, eval disjoint from train.
- data.rs: extract assistant-only masking into a pure, testable sft_row()
  (behavior-preserving; single-turn bit-identical to fbf4ac2).

Gate (verified locally, no_cuda): cargo test -p xtrain-train --lib = 9/9 pass
(masking, SFT-target self-consistency over 2000 samples, parser edges, seed
determinism); a 200/50 gen run = clean 2-col TSV, correct gold incl. negatives,
0 train/eval leakage. SFT training run + format-eval pending on dash5.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 22:52:25 +08:00
ab32168dcc docs: post-training stack design — SFT → KV-cache → DPO → GRPO (docs/18)
Design doc for a from-scratch post-training infra on top of xtrain. Ladder:
SFT (have it) → DPO → reward model (optional) → GRPO, each rung one new
post-training systems concept + a hard correctness gate (grad-check, PyTorch
parity, degenerate checks, a falsifiable 'it learns' signal).

Decisions aligned with the user (D1-D4):
- D1 scope: DPO → GRPO, reward model optional.
- D2 reward: rule-based / verifiable first; learned RM deferred.
- D3 rollout: build the KV-cache incremental-decode engine UP FRONT (not
  naive-first) as the foundational milestone before DPO/GRPO.
- D4 task: a verifiable task (arithmetic/format) with deterministic exact-match
  reward, for a clean RL signal.

Locked milestone order: M1 SFT task baseline → M2 KV-cache decode engine
(token-identical gate) → M3 DPO → M4 GRPO → M5 optional reward model. Status:
design only, no implementation yet.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-29 22:44:25 +08:00
5 changed files with 676 additions and 4 deletions

View File

@@ -0,0 +1,94 @@
//! Generate the M1 verifiable-arithmetic post-training dataset. Pure host tool (no
//! CUDA): writes
//! <out>/arith_sft.tsv user<TAB>assistant rows for `train --sft-tsv`
//! <out>/arith_eval_prompts.txt greedy_sample `--prompts-file` format (held out)
//! <out>/arith_eval_gold.txt parallel gold integers for the checker
//!
//! Eval problems are deduped against train (no leakage). The SFT rows carry just the
//! user/assistant content; `data::load_sft_tsv_cached` adds the `User:/Assistant:`
//! frame + `<|endoftext|>` and masks the prompt, so the eval prompt lines here
//! reconstruct exactly that frame (`User: <q>\nAssistant:`, literal `\n` decoded by
//! greedy_sample).
//!
//! Example:
//! cargo run -p xtrain-train --release --bin gen_arith_task -- \
//! --n 20000 --eval 500 --seed 1 --out-dir /dashscope-tmp/wjh/xtrain_post/arith
use std::collections::HashSet;
use std::fs::{self, File};
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use xtrain_train::task::{GenConfig, Op, gen_problem};
fn flag<T: std::str::FromStr>(args: &[String], name: &str, default: T) -> T {
args.iter()
.position(|a| a == name)
.and_then(|i| args.get(i + 1))
.and_then(|v| v.parse().ok())
.unwrap_or(default)
}
fn main() {
let args: Vec<String> = std::env::args().collect();
let n_train: usize = flag(&args, "--n", 20000);
let n_eval: usize = flag(&args, "--eval", 500);
let seed: u64 = flag(&args, "--seed", 1);
let max_add: i64 = flag(&args, "--max-add", 99);
let max_mul: i64 = flag(&args, "--max-mul", 12);
let out_dir: PathBuf = args
.iter()
.position(|a| a == "--out-dir")
.and_then(|i| args.get(i + 1))
.map(PathBuf::from)
.expect("--out-dir <dir> is required");
fs::create_dir_all(&out_dir).expect("create out dir");
let cfg = GenConfig {
max_add,
max_mul,
ops: vec![Op::Add, Op::Sub, Op::Mul],
};
let mut rng = seed.max(1);
// Train: dedup so the same problem is not repeated and so eval can be held out.
let mut train_keys = HashSet::new();
let mut tsv = BufWriter::new(File::create(out_dir.join("arith_sft.tsv")).expect("create tsv"));
while train_keys.len() < n_train {
let p = gen_problem(&mut rng, &cfg);
if !train_keys.insert(p.key()) {
continue;
}
writeln!(tsv, "{}\t{}", p.question(), p.sft_answer()).expect("write tsv");
}
tsv.flush().expect("flush tsv");
// Eval: disjoint from train (skip any key seen in train) and from itself.
let mut prompts =
BufWriter::new(File::create(out_dir.join("arith_eval_prompts.txt")).expect("create eval"));
let mut golds =
BufWriter::new(File::create(out_dir.join("arith_eval_gold.txt")).expect("create gold"));
writeln!(prompts, "# verifiable arithmetic eval prompts (held out from arith_sft.tsv)")
.expect("write header");
let mut eval_keys = HashSet::new();
while eval_keys.len() < n_eval {
let p = gen_problem(&mut rng, &cfg);
if train_keys.contains(&p.key()) || !eval_keys.insert(p.key()) {
continue;
}
writeln!(prompts, "User: {}\\nAssistant:", p.question()).expect("write prompt");
writeln!(golds, "{}", p.answer()).expect("write gold");
}
prompts.flush().expect("flush prompts");
golds.flush().expect("flush golds");
println!(
"wrote {} train rows + {} eval prompts to {} (ops=+,-,* max_add={} max_mul={} seed={})",
train_keys.len(),
eval_keys.len(),
out_dir.display(),
max_add,
max_mul,
seed
);
}

View File

@@ -124,10 +124,9 @@ impl Corpus {
let answer = format!(" {assistant}\n<|endoftext|>");
let prompt_ids: Vec<i32> = tok.encode(&prompt).into_iter().map(|t| t as i32).collect();
let answer_ids: Vec<i32> = tok.encode(&answer).into_iter().map(|t| t as i32).collect();
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
labels.extend(answer_ids.iter().copied());
tokens.extend(prompt_ids);
tokens.extend(answer_ids);
let (row_tokens, row_labels) = sft_row(&prompt_ids, &answer_ids);
tokens.extend(row_tokens);
labels.extend(row_labels);
}
assert_eq!(tokens.len(), labels.len(), "SFT tokens/labels mismatch");
write_u16_cache(&token_cache, &tokens);
@@ -291,6 +290,20 @@ fn decode_tsv_escapes(s: &str) -> String {
s.replace("\\n", "\n").replace("\\t", "\t")
}
/// Build one SFT example's `(tokens, labels)` from already-tokenized prompt/answer
/// ids: prompt tokens are masked to the ignore-index (`-100`, which `cross_entropy`
/// skips) so only the answer + EOS tokens contribute to the loss. Pure (no tokenizer
/// / no CUDA) so the assistant-only masking is unit-testable directly.
fn sft_row(prompt_ids: &[i32], answer_ids: &[i32]) -> (Vec<i32>, Vec<i32>) {
let mut tokens = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
tokens.extend_from_slice(prompt_ids);
tokens.extend_from_slice(answer_ids);
let mut labels = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
labels.extend(std::iter::repeat(-100).take(prompt_ids.len()));
labels.extend_from_slice(answer_ids);
(tokens, labels)
}
/// Tiny LCG (same constants as the model tests' deterministic fill) so dataset
/// sampling is reproducible from a single u64 seed.
fn next_rand(state: &mut u64) -> u64 {
@@ -299,3 +312,27 @@ fn next_rand(state: &mut u64) -> u64 {
.wrapping_add(1442695040888963407);
*state >> 16
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sft_row_masks_prompt_supervises_answer() {
let prompt = [5, 6, 7];
let answer = [8, 9]; // includes the EOS token in real use
let (tokens, labels) = sft_row(&prompt, &answer);
// Tokens are prompt then answer, in order.
assert_eq!(tokens, vec![5, 6, 7, 8, 9]);
// Prompt positions are ignore-index (-100); answer positions are supervised.
assert_eq!(labels, vec![-100, -100, -100, 8, 9]);
assert_eq!(tokens.len(), labels.len());
}
#[test]
fn sft_row_handles_empty_answer() {
let (tokens, labels) = sft_row(&[1, 2], &[]);
assert_eq!(tokens, vec![1, 2]);
assert_eq!(labels, vec![-100, -100]);
}
}

View File

@@ -10,6 +10,7 @@
pub mod clip;
pub mod data;
pub mod schedule;
pub mod task;
#[cfg(not(no_cuda))]
pub mod checkpoint;

View File

@@ -0,0 +1,212 @@
//! Verifiable arithmetic task (post-training, M1). A tiny two-operand integer
//! arithmetic task with a deterministic, rule-based checker: the assistant must end
//! its answer with `\boxed{N}`, and the reward is exact-match on `N`.
//!
//! This single module is the shared task spec for the whole post-training stack —
//! M1 SFT-data generation, M3 DPO preference-pair construction, and M4 GRPO reward
//! scoring all parse/score through here, so the task lives in exactly one place.
//!
//! Host-only (no CUDA): generation + parsing + checking are pure, so this compiles
//! and unit-tests on a GPU-less host.
use std::fmt;
/// The supported binary operations.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Op {
Add,
Sub,
Mul,
}
impl Op {
pub fn symbol(self) -> char {
match self {
Op::Add => '+',
Op::Sub => '-',
Op::Mul => '*',
}
}
}
impl fmt::Display for Op {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.symbol())
}
}
/// A single two-operand arithmetic problem.
#[derive(Clone, Copy, Debug)]
pub struct Problem {
pub a: i64,
pub b: i64,
pub op: Op,
}
impl Problem {
/// The exact integer answer (the verifiable gold label).
pub fn answer(self) -> i64 {
match self.op {
Op::Add => self.a + self.b,
Op::Sub => self.a - self.b,
Op::Mul => self.a * self.b,
}
}
/// The user-turn question text. No template wrapping — the SFT loader
/// (`data::load_sft_tsv_cached`) adds the `User:/Assistant:` frame.
pub fn question(self) -> String {
format!("What is {} {} {}?", self.a, self.op, self.b)
}
/// The assistant-turn SFT target: restate the equation and end with the boxed
/// answer. This teaches the answer FORMAT (the checker only reads `\boxed{}`);
/// arithmetic correctness is what DPO (M3) / GRPO (M4) later improve.
pub fn sft_answer(self) -> String {
format!("{} {} {} = \\boxed{{{}}}.", self.a, self.op, self.b, self.answer())
}
/// A stable dedup key, so eval problems can be held out from train.
pub fn key(self) -> (i64, char, i64) {
(self.a, self.op.symbol(), self.b)
}
}
/// Operand-range configuration for problem sampling. Multiplication uses a smaller
/// range (`max_mul`) so products stay modest; add/sub use `max_add`.
#[derive(Clone)]
pub struct GenConfig {
pub max_add: i64,
pub max_mul: i64,
pub ops: Vec<Op>,
}
impl Default for GenConfig {
fn default() -> Self {
Self {
max_add: 99,
max_mul: 12,
ops: vec![Op::Add, Op::Sub, Op::Mul],
}
}
}
/// Sample one problem deterministically from the LCG state `rng`. Operands are drawn
/// in `[0, max]` per the op; subtraction may yield a negative answer (the checker /
/// parser handle a leading `-`).
pub fn gen_problem(rng: &mut u64, cfg: &GenConfig) -> Problem {
let op = cfg.ops[(next_rand(rng) as usize) % cfg.ops.len()];
let max = if op == Op::Mul { cfg.max_mul } else { cfg.max_add };
let a = rand_range(rng, max);
let b = rand_range(rng, max);
Problem { a, b, op }
}
/// Parse the integer inside the LAST `\boxed{...}` in `text`. Returns `None` if there
/// is no well-formed boxed integer (no box, empty, or non-integer contents). "Last"
/// so a model that emits intermediate boxes still scores on its final answer.
pub fn parse_boxed_answer(text: &str) -> Option<i64> {
const TAG: &str = "\\boxed{";
let mut found = None;
let mut rest = text;
while let Some(i) = rest.find(TAG) {
let after = &rest[i + TAG.len()..];
match after.find('}') {
Some(j) => {
if let Ok(n) = after[..j].trim().parse::<i64>() {
found = Some(n);
}
rest = &after[j + 1..];
}
None => break,
}
}
found
}
/// Verifiable reward: does the completion's boxed answer exactly match `gold`?
pub fn check_answer(completion: &str, gold: i64) -> bool {
parse_boxed_answer(completion) == Some(gold)
}
/// `[0, max]` inclusive draw from the LCG.
fn rand_range(rng: &mut u64, max: i64) -> i64 {
debug_assert!(max >= 0);
(next_rand(rng) % (max as u64 + 1)) as i64
}
/// Same LCG constants as the dataset sampler (`data::next_rand`), kept local so the
/// task module stays dependency-free and host-only.
fn next_rand(state: &mut u64) -> u64 {
*state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
*state >> 1
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn answer_question_and_sft_target() {
let p = Problem {
a: 12,
b: 13,
op: Op::Mul,
};
assert_eq!(p.answer(), 156);
assert_eq!(p.question(), "What is 12 * 13?");
assert_eq!(p.sft_answer(), "12 * 13 = \\boxed{156}.");
let s = Problem {
a: 3,
b: 8,
op: Op::Sub,
};
assert_eq!(s.answer(), -5);
}
#[test]
fn parse_takes_last_boxed_and_handles_edges() {
assert_eq!(parse_boxed_answer("\\boxed{3} then \\boxed{156}."), Some(156));
assert_eq!(parse_boxed_answer("\\boxed{-7}"), Some(-7));
assert_eq!(parse_boxed_answer("\\boxed{ 42 }"), Some(42));
assert_eq!(parse_boxed_answer("no box here"), None);
assert_eq!(parse_boxed_answer("\\boxed{abc}"), None);
assert_eq!(parse_boxed_answer("\\boxed{unterminated"), None);
}
#[test]
fn check_is_exact_match() {
assert!(check_answer("the result is \\boxed{156}.", 156));
assert!(!check_answer("the result is \\boxed{155}.", 156));
assert!(!check_answer("no boxed answer at all", 156));
}
#[test]
fn sft_target_is_always_self_consistent() {
// The SFT target's boxed answer must always check against the problem's own
// gold — across all ops/operands. This is the M1 data invariant.
let cfg = GenConfig::default();
let mut rng = 12345u64;
for _ in 0..2000 {
let p = gen_problem(&mut rng, &cfg);
assert!(
check_answer(&p.sft_answer(), p.answer()),
"self-inconsistent SFT target for {p:?}"
);
}
}
#[test]
fn generation_is_deterministic_from_seed() {
let cfg = GenConfig::default();
let (mut r1, mut r2) = (7u64, 7u64);
for _ in 0..200 {
assert_eq!(
gen_problem(&mut r1, &cfg).key(),
gen_problem(&mut r2, &cfg).key()
);
}
}
}

View File

@@ -0,0 +1,328 @@
# Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document
> Status: **DESIGN — decisions locked, pending go-ahead to implement.** Nothing
> implemented yet. This doc proposes the scope, the staged build, the new infra pieces,
> and the correctness gates for a standard post-training stack on top of the xtrain
> training framework. Decisions D1D4 are resolved (see "Resolved decisions"):
> **DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode
> engine built up front · a verifiable task as the optimization/eval target.**
## Goal
Build a **standard, from-scratch post-training infrastructure** — the systems layer that
turns a pretrained base LM into an aligned chat model — and use it to run chat
alignment. The deliverable that matters here is the **infra and the lessons**, not the
end-to-end chat quality (see the project's learning-axis framing). Each stage should
teach exactly one new post-training systems concept and ship with a hard correctness
gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical
default paths, profile-first).
Concretely we want to be able to answer, with our own code:
- How does **offline preference optimization (DPO)** differ from SFT in the training
loop — what is the reference model, why two forwards, what is the loss?
- How does a **reward model** turn preferences into a scalar signal?
- How does **online RL (GRPO)** actually run — the rollout engine, reward scoring,
group-relative advantage, the clipped policy-gradient update, the KL leash?
- Where are the **memory and throughput** pressure points that make post-training infra
different from pretraining infra (multiple models resident, generation in the loop)?
## Baseline: what already exists vs. what is missing
What the framework already gives us (verified in code, reused as-is):
| capability | where | reuse for post-training |
|---|---|---|
| batched forward → logits `[B*S, vocab]` | `model.rs::forward_batched` | logprob extraction for DPO/RM/GRPO |
| cross-entropy with **ignore-index 100** | `ops.rs::cross_entropy`, `nn.cu` | assistant-only / completion-only masking |
| assistant-only **SFT** (TSV, masked labels) | `data.rs::load_sft_tsv_cached` (commit `fbf4ac2`) | SFT chat baseline = DPO init + reference |
| bf16 mixed precision, fp32 master | `with_compute_dtype` | policy + frozen reference both bf16 compute |
| recompute / flash / grad-accum | `with_recompute` / `with_flash` / `--accum-steps` | bound activation memory with 23 models resident |
| DDP (thread + process-per-GPU) | `xtrain-distributed` | data-parallel post-training |
| AdamW + clip + LR sched + checkpoint | `xtrain-optim`, `checkpoint.rs`, `schedule.rs` | unchanged optimizer path |
| single-seq greedy/temperature sampling | `sample.rs::generate` | **slow** rollout fallback (no KV cache) |
What is **missing** and must be built (these are the actual lessons):
1. **Per-sequence completion logprob** — a way to read `Σ log πθ(y_t | x, y_<t)` over the
completion tokens of a sequence. CE gives a *mean* scalar; DPO/GRPO need a *per-sequence
masked sum*. New op or thin wrapper over the CE per-row machinery.
2. **Frozen reference model** held in memory alongside the trainable policy (no grad, no
optimizer), or its logprobs precomputed and cached.
3. **Pairwise preference loss** (DPO) and **Bradley-Terry ranking loss** (RM).
4. **Reward head** — a `[dim,1]` scalar head reading the last non-pad position (RM only).
5. **Rollout / generation engine** — batched autoregressive sampling. Current `generate`
is single-sequence and re-runs the full forward each step (no KV cache). Online RL needs
batched rollouts; a real **KV-cache incremental-decode engine** is the centerpiece infra
build.
6. **GRPO machinery** — group sampling, group-relative advantage, clipped PG loss, KL
penalty, the actor-learner loop.
## The post-training landscape — where the infra lives
```
data models in memory new systems concept
SFT (prompt, answer) policy loss masking (have it)
DPO (prompt, chosen, reject) policy + ref(frozen) dual forward, pairwise logσ loss
RM (prompt, chosen, reject) reward model scalar head, ranking loss
PPO prompts + reward source policy+ref+RM+critic rollout + GAE + clipped PG (4 models)
GRPO prompts + reward source policy+ref(+RM) rollout + group baseline + clipped PG
```
The pedagogical ladder is **SFT → DPO → (RM) → GRPO**. DPO is the cheapest "real" alignment
method (no generation, no reward model, reuses the training loop almost verbatim) and is the
right first rung. GRPO is chosen over PPO as the online-RL rung because it **drops the value
critic** (group-relative advantage replaces the learned baseline) — that removes a whole
model and the GAE machinery while still teaching the complete online-RL loop. PPO is noted
as an optional later extension, not a primary target.
## Proposed scope & sequencing (recommended path)
> ✅ **DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward
> model) optional.** With D3 locked to "KV-cache engine up front", the engine becomes a
> foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective
> build order: **P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional)** (see
> "Milestones").
### Stage P0 — SFT chat baseline (light; mostly reuse)
Goal: a clean SFT checkpoint to serve as **both the DPO/GRPO init and the frozen
reference**. With D4 = verifiable task, P0 SFT teaches the **task format** (e.g. arithmetic
prompts → a parseable answer such as `\boxed{N}`) so the model emits checker-readable
completions; the same template is reused by rollout and eval. The current SFT (commit
`fbf4ac2`) already does single-turn assistant-only masking; P0 only adds what alignment
needs:
- a fixed **chat template** (the `User:/Assistant:` + `<|endoftext|>` format already used,
promoted to a documented constant shared by SFT data prep, rollout, and eval),
- optional **multi-turn masking** (supervise every assistant turn, mask user turns),
- optional **sequence packing** (concatenate examples to fill `seq`, reset attention/RoPE
per example — note `forward_batched` already isolates sequences, so packing = careful
index bookkeeping, not new attention code).
Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak
loss across example boundaries. **Hypothesis:** a documented chat template + multi-turn mask
gives a reproducible SFT reference without changing the training numerics for single-turn data
(bit-identical to `fbf4ac2` on single-turn input).
### Stage P1 — DPO (offline preference optimization) ⭐ first real method
New infra:
1. **Preference data — constructed from the verifiable checker (D4).** On a verifiable task
there is no off-the-shelf preference set, so we build pairs: sample several completions
per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone),
score each with the rule-based checker, take a **correct** completion as `chosen` and an
**incorrect** one as `rejected`. This is a one-time offline data-prep step; DPO training
itself is then static. Tokenize each as `template(prompt) + completion + EOS`; build a
completion mask (prompt = masked).
2. **`seq_logprob(logits, target_ids, mask) → [B]`**: per-sequence sum of
`log softmax(logits)[target]` over masked positions. Implement by reusing the CE per-row
path (CE per-row = `log πθ(target)`), summing `per_row` over the mask. Add a grad-checked
op so the backward is exact.
3. **Frozen reference** `πref`: load the SFT checkpoint into a second model in **eval/no-grad**
bf16. Its logprobs are **constants** in the loss. Optimization to teach: **precompute and
cache reference logprobs** once over the dataset → the reference model need not stay
resident during training (one model in memory, like SFT).
4. **DPO loss** (Rafailov et al.): with
`Δ = β[(logπθ(yw|x) logπref(yw|x)) (logπθ(yl|x) logπref(yl|x))]`,
`L = log σ(Δ)`. Only `πθ` terms carry gradient.
Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached
logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably.
Correctness gates:
- `seq_logprob` finite-difference grad-check (tiny model).
- DPO-loss + grad **PyTorch parity** (the project's standard gate).
- **Degenerate checks**: `πθ == πref` at init ⇒ `Δ = 0`, `L = log 2`, implicit reward 0;
`β → 0` ⇒ gradient → 0.
- **Health metric**: chosenrejected **reward margin** rises over training; accuracy
(margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a
sufficient signal).
Application: chat alignment via DPO on English preference pairs. This is the **offline
chat-alignment deliverable**.
### Stage P2 — Reward model (Bradley-Terry) — OPTIONAL
> ✅ **DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first.** GRPO
> brings up on the deterministic checker; a learned reward model is **deferred/optional** (only
> if we later want general-chat GRPO). So this whole stage is optional and not on the critical
> path.
New infra: a **scalar reward head** (`[dim,1]`) reading the hidden state at the last
non-pad position; **ranking loss** `log σ(r(x,yw) r(x,yl))`. Reuses the preference data
and the dual-sequence forward from P1.
Gates: ranking-loss grad-check; held-out **pairwise accuracy** (`r_w > r_l`); a frozen RM
loads/serves the scalar correctly.
### Stage P3 — GRPO (online RL, critic-free) ⭐ the deep infra lesson
This is the centerpiece. It introduces **generation inside the training loop**.
**(a) Rollout / generation engine — built up front (its own milestone).**
> ✅ **DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine
> up front**, as a foundational milestone *before* DPO/GRPO, rather than starting naive. It is
> then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted:
> front-loads the single hardest build and delays the first alignment result, in exchange for
> a real generation engine and a clean, isolated infra lesson.
The engine: per-layer **K/V cache**, **single-token incremental forward** (process the prompt
once to fill the cache, then decode one token at a time), **batched ragged decode** (B prompts
× G samples; sequences hit EOS at different lengths → finished-mask / left-padding /
compaction). The current attention assumes a full causal window over `seq`; incremental decode
needs a **decode-time attention path** — query length 1 against cached K/V of length `t`, with
RoPE position = `t`. This reuses the composed SDPA shapes (one-row query), so it can land as a
distinct code path without disturbing the training attention (flash/GQA/composed unchanged).
Hard gate (the centerpiece correctness lesson): **KV-cache decode == full-recompute decode,
token-identical** greedy output — the same byte-/token-identical discipline the project uses
for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs.
per-token decode) is recorded here, before any rollout optimization (profile-first).
**(b) Reward scoring.** Rule-based verifiable reward first (e.g., exact-match on a synthetic
arithmetic/format task) or RM from P2. Returns a scalar per completion.
**(c) Group-relative advantage.** Sample `G` completions per prompt; advantage
`A_i = (r_i mean(r_group)) / (std(r_group) + ε)`. No critic, no GAE.
**(d) Clipped policy-gradient loss with KL leash.** Per completion token,
`ρ_t = exp(logπθ_t logπθ_old_t)` (old = policy at rollout time), token loss
`min(ρ_t A, clip(ρ_t, 1±ε) A) + βKL(πθ‖πref)`, masked to completion tokens. KL via the k3
estimator.
**(e) Actor-learner loop.** sample prompt batch → rollout G each → score → advantage →
capture `πθ_old` logprobs → K inner epochs of clipped PG updates → repeat. Reference `πref`
fixed throughout.
Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations.
Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded
(tokens/s of generation vs. update) **before** any rollout optimization, per the project's
profile-first rule.
Correctness gates:
- PG-loss finite-diff grad-check.
- **Degenerate checks**: `G = 1` ⇒ advantage 0 ⇒ no PG signal, only KL; `ε → ∞` ⇒ vanilla PG;
`β = 0` ⇒ no KL term.
- (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a
prerequisite of GRPO.)
- **Synthetic RL overfit**: on a tiny verifiable task with a known optimum, mean reward must
rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal
that the loop is correct, independent of fuzzy chat quality).
## Evaluation
- **Offline (DPO/RM)**: reward margin, preference accuracy, KL drift from reference, plus the
fixed chat-prompt generation suite (`scripts/chat_alpha_fixed_prompts.txt`) judged before/
after — reusing and extending the doc-13 recommendation for a generation-based eval harness
(exact-match math, code syntax, stop-token, refusal appropriateness, corruption).
- **Online (GRPO)**: mean reward curve, KL-to-reference, response length, the verifiable-task
pass rate, and the same fixed-prompt suite.
- **Selection by generation eval, not loss** — the recurring doc-13/v11 lesson: lower
post-training loss did not mean better generations.
## Memory & throughput budget (8× RTX 5090, 1.05B model, indicative)
- Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy
optimizer state alone ~17 GB before activations. Recompute + grad-accum keep activations
small; this is why post-training reuses the Phase-1/2 memory levers unchanged.
- DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits.
- GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new
variable. **Generation, not the update, is expected to be the throughput bottleneck** — to be
measured, not assumed.
## Correctness-gate philosophy (unchanged from Phase 1/2)
Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity
on loss + grads where applicable, (3) explicit degenerate-case bit/again checks (β→0, G=1,
ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic
RL overfit), and (5) **no change to the default training path** when post-training flags are
off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference
gates as flash/GQA.
## Risks & tradeoffs
- **Rollout engine is the long pole.** A correct KV-cache incremental-decode path is a real
build (decode-time attention, ragged batch). Mitigation: naive rollout first; KV-cache as an
isolated, separately-gated sub-phase.
- **RL is finicky.** KL leash, advantage normalization, clip range, reward hacking. Mitigation:
synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward.
- **Reward-model noise** can mislead GRPO. Mitigation: rule-based reward first.
- **Tokenizer (KI-4)** — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here.
- **Two/three resident models** raise memory; bounded by recompute/accum and (for DPO) reference
logprob caching.
## Resolved decisions (aligned 2026-06-29)
- **D1 — Scope & sequencing → DPO → GRPO, reward model optional.**
- **D2 — Online-RL reward source → rule-based / verifiable reward first** (RM deferred/optional).
- **D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front** (not
naive-first), as a foundational milestone before DPO/GRPO.
- **D4 — Alignment task / eval target → a verifiable task** (arithmetic/format/GSM8K-style) with
a deterministic exact-match reward, for a clean, falsifiable RL signal.
## Milestones (locked order)
1. **M1 — P0 SFT task baseline.** Chat template + assistant-only masking on the verifiable
task; produces the reference + init checkpoint. Gate: masking unit test; single-turn
bit-identical to `fbf4ac2`.
2. **M2 — KV-cache decode engine** (D3, up front). Per-layer K/V cache + incremental
decode-time attention + batched ragged decode. Gate: **token-identical to full-recompute
greedy**; record decode throughput baseline.
3. **M3 — P1 DPO.** Verifiable-checker pair construction (via M2) → `seq_logprob` op
(grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO
training loop → run + reward-margin / preference-accuracy curve.
4. **M4 — P3 GRPO.** Group rollout (M2) + rule-based reward + group-relative advantage +
clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; **synthetic
verifiable-task RL-overfit** (mean reward → known optimum) → verifiable-task GRPO run.
5. **M5 (optional) — P2 reward model.** Scalar head + ranking loss + pairwise-accuracy gate;
enables GRPO-with-RM for general chat.
> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.
## Implementation log
### M1 — SFT task baseline (infra landed; SFT run pending on the GPU box)
The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
needed). The SFT training run itself is a one-liner over the existing `--sft-tsv` path and
runs on dash5.
**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
- Two-operand integer arithmetic, ops `+ ×`; operands `[0,99]` for `+/`, `[0,12]` for `×`
(modest products); subtraction may be negative.
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
(exact match vs. gold). This is the single shared checker reused by M3 (pair construction)
and M4 (GRPO reward).
- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly
probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`).
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
Train rows are deduped; eval is held out from train (no leakage).
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
single-turn path is bit-identical to `fbf4ac2`).
**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass,
including `sft_row` masks prompt→`-100` / supervises answer, the SFT-target self-consistency
invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism.
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
train/eval leakage.
**Pending on dash5 (needs GPU):**
1. generate the dataset: `gen_arith_task --n <N> --eval <M> --seed 1 --out-dir <dir>`;
2. SFT from the v12 base: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt>
<arch flags> --bf16 --recompute --flash --ckpt <out.ckpt>` → the P0 reference/init checkpoint;
3. format eval: `greedy_sample <out.ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt`,
then score completions with `check_answer` vs. `arith_eval_gold.txt` (a tiny scorer is a
one-off; can fold into M3's pipeline). Baseline pass-rate here is expected to be low — M1 only
buys the format; correctness is M3/M4's job.