xtrain/docs/18-post-training-rl-sft.md
Gahow Wang 4379868f2d docs: M2d — ragged-batching lever, 9× measured, step bottleneck → rollout
Records the M2d lever (batch the GRPO training-side forwards), the right-pad-is-free
insight, both exact gates, the end-to-end no-OOM smoke, and the 9× throughput.

The honest decomposition correction: M2c claimed the training forwards "dominate" the
step; the clean per-component bench falsifies the strong form — they were ~2.5 s of
the ~8.5 s step (~30%), worth the 9×, but the rollout (~6 s) was always the larger
share. After M2d the step is ~95% rollout, so the next step-level lever is full B×G
rollout batching (today only the G samples of each prompt decode in lockstep; the B
prompts are still sequential). Same measure-first lesson, once more.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 23:03:28 +08:00


# Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document
> Status: **DESIGN — decisions locked, pending go-ahead to implement.** Nothing
> implemented yet. This doc proposes the scope, the staged build, the new infra pieces,
> and the correctness gates for a standard post-training stack on top of the xtrain
> training framework. Decisions D1–D4 are resolved (see "Resolved decisions"):
> **DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode
> engine built up front · a verifiable task as the optimization/eval target.**
## Goal
Build a **standard, from-scratch post-training infrastructure** — the systems layer that
turns a pretrained base LM into an aligned chat model — and use it to run chat
alignment. The deliverable that matters here is the **infra and the lessons**, not the
end-to-end chat quality (see the project's learning-axis framing). Each stage should
teach exactly one new post-training systems concept and ship with a hard correctness
gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical
default paths, profile-first).
Concretely we want to be able to answer, with our own code:
- How does **offline preference optimization (DPO)** differ from SFT in the training
loop — what is the reference model, why two forwards, what is the loss?
- How does a **reward model** turn preferences into a scalar signal?
- How does **online RL (GRPO)** actually run — the rollout engine, reward scoring,
group-relative advantage, the clipped policy-gradient update, the KL leash?
- Where are the **memory and throughput** pressure points that make post-training infra
different from pretraining infra (multiple models resident, generation in the loop)?
## Baseline: what already exists vs. what is missing
What the framework already gives us (verified in code, reused as-is):
| capability | where | reuse for post-training |
|---|---|---|
| batched forward → logits `[B*S, vocab]` | `model.rs::forward_batched` | logprob extraction for DPO/RM/GRPO |
| cross-entropy with **ignore-index -100** | `ops.rs::cross_entropy`, `nn.cu` | assistant-only / completion-only masking |
| assistant-only **SFT** (TSV, masked labels) | `data.rs::load_sft_tsv_cached` (commit `fbf4ac2`) | SFT chat baseline = DPO init + reference |
| bf16 mixed precision, fp32 master | `with_compute_dtype` | policy + frozen reference both bf16 compute |
| recompute / flash / grad-accum | `with_recompute` / `with_flash` / `--accum-steps` | bound activation memory with 2–3 models resident |
| DDP (thread + process-per-GPU) | `xtrain-distributed` | data-parallel post-training |
| AdamW + clip + LR sched + checkpoint | `xtrain-optim`, `checkpoint.rs`, `schedule.rs` | unchanged optimizer path |
| single-seq greedy/temperature sampling | `sample.rs::generate` | **slow** rollout fallback (no KV cache) |
What is **missing** and must be built (these are the actual lessons):
1. **Per-sequence completion logprob** — a way to read `Σ log πθ(y_t | x, y_<t)` over the
completion tokens of a sequence. CE gives a *mean* scalar; DPO/GRPO need a *per-sequence
masked sum*. New op or thin wrapper over the CE per-row machinery.
2. **Frozen reference model** held in memory alongside the trainable policy (no grad, no
optimizer), or its logprobs precomputed and cached.
3. **Pairwise preference loss** (DPO) and **Bradley-Terry ranking loss** (RM).
4. **Reward head** — a `[dim,1]` scalar head reading the last non-pad position (RM only).
5. **Rollout / generation engine** — batched autoregressive sampling. Current `generate`
is single-sequence and re-runs the full forward each step (no KV cache). Online RL needs
batched rollouts; a real **KV-cache incremental-decode engine** is the centerpiece infra
build.
6. **GRPO machinery** — group sampling, group-relative advantage, clipped PG loss, KL
penalty, the actor-learner loop.
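A host-side sketch of item 1 pins down the semantics of the per-sequence masked sum. It assumes row-major `[seq, vocab]` f32 logits for one sequence and the CE convention that a `-100` target marks an ignored (prompt or pad) position. It is plain Rust over slices, not the eventual autodiff op, and `log_softmax_at` is an illustrative helper name.

```rust
/// Ignore index shared with the CE masking convention.
const IGNORE: i64 = -100;

/// Numerically stable log-softmax of one logits row, read at `target`.
fn log_softmax_at(row: &[f32], target: usize) -> f32 {
    // Subtract the row max so exp() cannot overflow.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = row.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    row[target] - log_sum_exp
}

/// Σ log πθ(target_t) over the unmasked rows of one sequence.
/// `logits` is row-major `[seq, vocab]`; rows whose target is IGNORE are skipped.
fn seq_logprob(logits: &[f32], vocab: usize, targets: &[i64]) -> f32 {
    assert_eq!(logits.len(), vocab * targets.len(), "logits must be [seq, vocab]");
    let mut total = 0.0f32;
    for (t, &tgt) in targets.iter().enumerate() {
        if tgt == IGNORE {
            continue; // prompt or pad position: contributes nothing
        }
        let row = &logits[t * vocab..(t + 1) * vocab];
        total += log_softmax_at(row, tgt as usize);
    }
    total
}
```

The difference from CE is only the reduction: CE averages the per-row terms into one scalar, while DPO and GRPO need this per-sequence *sum*, kept separate for every sequence in the batch.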
## The post-training landscape — where the infra lives
```
data models in memory new systems concept
SFT (prompt, answer) policy loss masking (have it)
DPO (prompt, chosen, reject) policy + ref(frozen) dual forward, pairwise −logσ loss
RM (prompt, chosen, reject) reward model scalar head, ranking loss
PPO prompts + reward source policy+ref+RM+critic rollout + GAE + clipped PG (4 models)
GRPO prompts + reward source policy+ref(+RM) rollout + group baseline + clipped PG
```
The pedagogical ladder is **SFT → DPO → (RM) → GRPO**. DPO is the cheapest "real" alignment
method (no generation, no reward model, reuses the training loop almost verbatim) and is the
right first rung. GRPO is chosen over PPO as the online-RL rung because it **drops the value
critic** (group-relative advantage replaces the learned baseline) — that removes a whole
model and the GAE machinery while still teaching the complete online-RL loop. PPO is noted
as an optional later extension, not a primary target.
## Proposed scope & sequencing (recommended path)
> ✅ **DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward
> model) optional.** With D3 locked to "KV-cache engine up front", the engine becomes a
> foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective
> build order: **P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional)** (see
> "Milestones").
### Stage P0 — SFT chat baseline (light; mostly reuse)
Goal: a clean SFT checkpoint to serve as **both the DPO/GRPO init and the frozen
reference**. With D4 = verifiable task, P0 SFT teaches the **task format** (e.g. arithmetic
prompts → a parseable answer such as `\boxed{N}`) so the model emits checker-readable
completions; the same template is reused by rollout and eval. The current SFT (commit
`fbf4ac2`) already does single-turn assistant-only masking; P0 only adds what alignment
needs:
- a fixed **chat template** (the `User:/Assistant:` + `<|endoftext|>` format already used,
promoted to a documented constant shared by SFT data prep, rollout, and eval),
- optional **multi-turn masking** (supervise every assistant turn, mask user turns),
- optional **sequence packing** (concatenate examples to fill `seq`, reset attention/RoPE
per example — note `forward_batched` already isolates sequences, so packing = careful
index bookkeeping, not new attention code).
Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak
loss across example boundaries. **Hypothesis:** a documented chat template + multi-turn mask
gives a reproducible SFT reference without changing the training numerics for single-turn data
(bit-identical to `fbf4ac2` on single-turn input).
### Stage P1 — DPO (offline preference optimization) ⭐ first real method
New infra:
1. **Preference data — constructed from the verifiable checker (D4).** On a verifiable task
there is no off-the-shelf preference set, so we build pairs: sample several completions
per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone),
score each with the rule-based checker, take a **correct** completion as `chosen` and an
**incorrect** one as `rejected`. This is a one-time offline data-prep step; DPO training
itself is then static. Tokenize each as `template(prompt) + completion + EOS`; build a
completion mask (prompt = masked).
2. **`seq_logprob(logits, target_ids, mask) → [B]`**: per-sequence sum of
`log softmax(logits)[target]` over masked positions. Implement by reusing the CE per-row
path (CE per-row = `log πθ(target)`), summing `per_row` over the mask. Add a grad-checked
op so the backward is exact.
3. **Frozen reference** `πref`: load the SFT checkpoint into a second model in **eval/no-grad**
bf16. Its logprobs are **constants** in the loss. Optimization to teach: **precompute and
cache reference logprobs** once over the dataset → the reference model need not stay
resident during training (one model in memory, like SFT).
4. **DPO loss** (Rafailov et al.): with
`Δ = β[(logπθ(yw|x) − logπref(yw|x)) − (logπθ(yl|x) − logπref(yl|x))]`,
`L = −log σ(Δ)`. Only `πθ` terms carry gradient.
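Per pair, items 2 through 4 reduce to a few lines of scalar math. The sketch below assumes the four sequence logprobs have already been computed (for example by a masked sum like `seq_logprob`). It returns the loss together with its gradient with respect to the two policy terms, and treats the reference terms as constants. The function names are illustrative, not the crate's API.

```rust
/// Numerically stable log(sigmoid(x)).
fn log_sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        -(-x).exp().ln_1p()
    } else {
        x - x.exp().ln_1p()
    }
}

/// DPO loss for one (chosen, rejected) pair.
/// Returns (loss, dL/d logπθ(yw|x), dL/d logπθ(yl|x)); reference logprobs get no gradient.
fn dpo_loss(lp_chosen: f64, lp_rejected: f64, ref_chosen: f64, ref_rejected: f64, beta: f64) -> (f64, f64, f64) {
    // Δ = β[(logπθ(yw) − logπref(yw)) − (logπθ(yl) − logπref(yl))]
    let delta = beta * ((lp_chosen - ref_chosen) - (lp_rejected - ref_rejected));
    let loss = -log_sigmoid(delta);
    // dL/dΔ = −σ(−Δ) = −1 / (1 + e^Δ); the chain rule gives ±β times that.
    let sig_neg = 1.0 / (1.0 + delta.exp());
    (loss, -beta * sig_neg, beta * sig_neg)
}
```

With the policy equal to the reference, Δ is 0, so this returns exactly the degenerate values the gates check: loss `log 2` and gradients `∓β/2`.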
Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached
logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably.
Correctness gates:
- `seq_logprob` finite-difference grad-check (tiny model).
- DPO-loss + grad **PyTorch parity** (the project's standard gate).
- **Degenerate checks**: `πθ == πref` at init ⇒ `Δ = 0`, `L = log 2`, implicit reward 0;
`β → 0` ⇒ gradient → 0.
- **Health metric**: chosen−rejected **reward margin** rises over training; accuracy
(margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a
sufficient signal).
Application: chat alignment via DPO on English preference pairs. This is the **offline
chat-alignment deliverable**.
### Stage P2 — Reward model (Bradley-Terry) — OPTIONAL
> ✅ **DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first.** GRPO
> brings up on the deterministic checker; a learned reward model is **deferred/optional** (only
> if we later want general-chat GRPO). So this whole stage is optional and not on the critical
> path.
New infra: a **scalar reward head** (`[dim,1]`) reading the hidden state at the last
non-pad position; **ranking loss** `−log σ(r(x,yw) − r(x,yl))`. Reuses the preference data
and the dual-sequence forward from P1.
Gates: ranking-loss grad-check; held-out **pairwise accuracy** (`r_w > r_l`); a frozen RM
loads/serves the scalar correctly.
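The P2 ranking loss and its held-out metric can be sketched in the same way. The sketch assumes scalar rewards for each (chosen, rejected) pair have already been read off the reward head. The names are hypothetical, and ties count as misses for accuracy, which is one possible convention.

```rust
/// Bradley-Terry ranking loss for one pair: −log σ(r_w − r_l), computed as softplus(−margin).
fn bt_loss(r_chosen: f64, r_rejected: f64) -> f64 {
    let m = r_chosen - r_rejected;
    // softplus(z) = max(z, 0) + ln(1 + e^{−|z|}) with z = −m; stable for large |m|.
    (-m).max(0.0) + (-m.abs()).exp().ln_1p()
}

/// Mean ranking loss over a batch of (r_chosen, r_rejected) pairs.
fn mean_bt_loss(pairs: &[(f64, f64)]) -> f64 {
    pairs.iter().map(|&(w, l)| bt_loss(w, l)).sum::<f64>() / pairs.len() as f64
}

/// Held-out pairwise accuracy: fraction of pairs with r_w > r_l (ties count as misses).
fn pairwise_accuracy(pairs: &[(f64, f64)]) -> f64 {
    let hits = pairs.iter().filter(|&&(w, l)| w > l).count();
    hits as f64 / pairs.len() as f64
}
```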
### Stage P3 — GRPO (online RL, critic-free) ⭐ the deep infra lesson
This is the centerpiece. It introduces **generation inside the training loop**.
**(a) Rollout / generation engine — built up front (its own milestone).**
> ✅ **DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine
> up front**, as a foundational milestone *before* DPO/GRPO, rather than starting naive. It is
> then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted:
> front-loads the single hardest build and delays the first alignment result, in exchange for
> a real generation engine and a clean, isolated infra lesson.
The engine: per-layer **K/V cache**, **single-token incremental forward** (process the prompt
once to fill the cache, then decode one token at a time), **batched ragged decode** (B prompts
× G samples; sequences hit EOS at different lengths → finished-mask / left-padding /
compaction). The current attention assumes a full causal window over `seq`; incremental decode
needs a **decode-time attention path** — query length 1 against cached K/V of length `t`, with
RoPE position = `t`. This reuses the composed SDPA shapes (one-row query), so it can land as a
distinct code path without disturbing the training attention (flash/GQA/composed unchanged).
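The position bookkeeping is the subtle part of that paragraph. A decoded token at step `t` must be rotated at absolute position `t`, exactly as the full forward rotated row `t`, or the cached keys stop matching. The sketch below assumes the interleaved-pair RoPE layout and a training-side `pos = row % period`. The real kernels may use a different pair layout, and all names are illustrative.

```rust
/// Rotate one head vector in place at absolute position `pos`
/// (interleaved pairs (x[2i], x[2i+1]); frequency theta^(-2i/d)).
fn rope_rotate(x: &mut [f32], theta: f32, pos: usize) {
    let d = x.len();
    for i in 0..d / 2 {
        let freq = theta.powf(-((2 * i) as f32) / d as f32);
        let (s, c) = (pos as f32 * freq).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * c - b * s;
        x[2 * i + 1] = a * s + b * c;
    }
}

/// Training-style RoPE over a packed batch: row r is at position r % period.
fn rope_full(rows: &mut [Vec<f32>], theta: f32, period: usize) {
    for (r, x) in rows.iter_mut().enumerate() {
        rope_rotate(x, theta, r % period);
    }
}

/// Decode-style RoPE: row r is at absolute position pos0 + r (no modulo).
fn rope_at(rows: &mut [Vec<f32>], theta: f32, pos0: usize) {
    for (r, x) in rows.iter_mut().enumerate() {
        rope_rotate(x, theta, pos0 + r);
    }
}
```

Rotating a single new row with `rope_at(.., pos0 = t)` reproduces row `t` of the full rotation. A modulo-based position that wraps before `t` (for example a period of 2 puts row 3 at position 1) gives a different vector, and the cached key would no longer match.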
Hard gate (the centerpiece correctness lesson): **KV-cache decode == full-recompute decode,
token-identical** greedy output — the same byte-/token-identical discipline the project uses
for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs.
per-token decode) is recorded here, before any rollout optimization (profile-first).
**(b) Reward scoring.** Rule-based verifiable reward first (e.g., exact-match on a synthetic
arithmetic/format task) or RM from P2. Returns a scalar per completion.
**(c) Group-relative advantage.** Sample `G` completions per prompt; advantage
`A_i = (r_i − mean(r_group)) / (std(r_group) + ε)`. No critic, no GAE.
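The group-relative advantage is small enough to show in full. This sketch assumes one scalar reward per completion in a single group. It uses the population standard deviation, although implementations differ on population vs. sample std. `eps` keeps all-equal groups at exactly zero advantage instead of dividing by zero.

```rust
/// Group-relative advantages A_i = (r_i − mean) / (std + eps) for one prompt's G samples.
fn group_advantages(rewards: &[f64], eps: f64) -> Vec<f64> {
    let g = rewards.len() as f64;
    let mean = rewards.iter().sum::<f64>() / g;
    // Population standard deviation over the group.
    let std = (rewards.iter().map(|&r| (r - mean).powi(2)).sum::<f64>() / g).sqrt();
    rewards.iter().map(|&r| (r - mean) / (std + eps)).collect()
}
```

The `G = 1` degenerate gate follows directly. In a one-sample group `r_i = mean`, so the advantage is zero and only the KL term can move the policy.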
**(d) Clipped policy-gradient loss with KL leash.** Per completion token,
`ρ_t = exp(logπθ_t − logπθ_old_t)` (old = policy at rollout time), token loss
`−min(ρ_t A, clip(ρ_t, 1±ε) A) + βKL(πθ‖πref)`, masked to completion tokens. KL via the k3
estimator.
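For each token, (d) combines a probability ratio, a pessimistic min over the clipped and unclipped surrogate, and the k3 KL estimate. The scalar sketch below assumes the logprobs of each sampled token have already been gathered and that the loss is averaged over completion tokens. The function names are hypothetical.

```rust
/// Loss for one completion token: −min(ρA, clip(ρ, 1−ε, 1+ε)A) + β·k3.
fn pg_token_loss(logp: f64, logp_old: f64, logp_ref: f64, adv: f64, eps: f64, beta: f64) -> f64 {
    let ratio = (logp - logp_old).exp();
    let surrogate = (ratio * adv).min(ratio.clamp(1.0 - eps, 1.0 + eps) * adv);
    // k3 estimator of KL(πθ‖πref) on a sample from πθ: e^r − r − 1 with r = logπref − logπθ (≥ 0).
    let r = logp_ref - logp;
    let kl_k3 = r.exp() - r - 1.0;
    -surrogate + beta * kl_k3
}

/// Mean loss over the completion tokens of one sequence (mask = true for completion tokens).
fn clipped_pg_loss(
    logp: &[f64],
    logp_old: &[f64],
    logp_ref: &[f64],
    mask: &[bool],
    adv: f64,
    eps: f64,
    beta: f64,
) -> f64 {
    let mut total = 0.0;
    let mut n = 0usize;
    for t in 0..logp.len() {
        if mask[t] {
            total += pg_token_loss(logp[t], logp_old[t], logp_ref[t], adv, eps, beta);
            n += 1;
        }
    }
    if n == 0 { 0.0 } else { total / n as f64 }
}
```

Setting `eps` to infinity makes the clamp a no-op (vanilla PG), and `beta = 0` drops the KL term. These are the degenerate checks in the gate list.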
**(e) Actor-learner loop.** sample prompt batch → rollout G each → score → advantage →
capture `πθ_old` logprobs → K inner epochs of clipped PG updates → repeat. Reference `πref`
fixed throughout.
Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations.
Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded
(tokens/s of generation vs. update) **before** any rollout optimization, per the project's
profile-first rule.
Correctness gates:
- PG-loss finite-diff grad-check.
- **Degenerate checks**: `G = 1` ⇒ advantage 0 ⇒ no PG signal, only KL; `ε → ∞` ⇒ vanilla PG;
`β = 0` ⇒ no KL term.
- (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a
prerequisite of GRPO.)
- **Synthetic RL overfit**: on a tiny verifiable task with a known optimum, mean reward must
rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal
that the loop is correct, independent of fuzzy chat quality).
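The synthetic-overfit gate can be rehearsed without a language model. The sketch below is a toy stand-in, not the planned task: a softmax policy over a few discrete "answers", one of which earns reward 1. It is trained with group sampling and group-relative advantages, and zero-variance groups are skipped. With one inner epoch on fresh samples the ratio is 1 and the clip never binds, so the update reduces to the plain policy gradient. The bandit setup, the xorshift RNG, and all names are illustrative assumptions.

```rust
/// Tiny deterministic PRNG (xorshift64*) so the sketch needs no external crates.
struct Rng(u64);

impl Rng {
    /// Uniform f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 ^= self.0 >> 12;
        self.0 ^= self.0 << 25;
        self.0 ^= self.0 >> 27;
        let x = self.0.wrapping_mul(0x2545_F491_4F6C_DD1D);
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}

fn softmax(logits: &[f64]) -> Vec<f64> {
    let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let e: Vec<f64> = logits.iter().map(|&l| (l - m).exp()).collect();
    let z: f64 = e.iter().sum();
    e.iter().map(|&x| x / z).collect()
}

/// Draw one index from a categorical distribution.
fn sample(probs: &[f64], rng: &mut Rng) -> usize {
    let u = rng.next_f64();
    let mut acc = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        acc += p;
        if u < acc {
            return i;
        }
    }
    probs.len() - 1 // guard against rounding in the cumulative sum
}

/// Toy GRPO on a one-step bandit: returns π(correct), i.e. the policy's expected reward.
fn toy_grpo(arms: usize, correct: usize, group: usize, steps: usize, lr: f64, seed: u64) -> f64 {
    let mut logits = vec![0.0f64; arms];
    let mut rng = Rng(seed); // seed must be non-zero for xorshift
    for _ in 0..steps {
        let probs = softmax(&logits);
        // Roll out a group of G "completions" and score them 0/1.
        let actions: Vec<usize> = (0..group).map(|_| sample(&probs, &mut rng)).collect();
        let rewards: Vec<f64> = actions.iter().map(|&a| if a == correct { 1.0 } else { 0.0 }).collect();
        let g = group as f64;
        let mean = rewards.iter().sum::<f64>() / g;
        let std = (rewards.iter().map(|&r| (r - mean).powi(2)).sum::<f64>() / g).sqrt();
        if std == 0.0 {
            continue; // all-correct or all-wrong group: zero advantage, no signal
        }
        // Ratio = 1 on fresh samples, so the clip is inactive and the surrogate's
        // gradient w.r.t. the logits is A * (onehot(a) − π).
        let mut grad = vec![0.0f64; arms];
        for (&a, &r) in actions.iter().zip(&rewards) {
            let adv = (r - mean) / (std + 1e-8);
            for (k, gk) in grad.iter_mut().enumerate() {
                let onehot = if k == a { 1.0 } else { 0.0 };
                *gk += adv * (onehot - probs[k]) / g;
            }
        }
        // Gradient ascent on the surrogate objective.
        for (l, gk) in logits.iter_mut().zip(&grad) {
            *l += lr * gk;
        }
    }
    softmax(&logits)[correct]
}
```

Because the advantages in a group sum to zero, every mixed group raises the correct arm's logit and lowers the sampled wrong arms' logits. The expected reward `π(correct)` therefore never decreases and climbs toward 1, which is the falsifiable shape the real gate looks for.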
## Evaluation
- **Offline (DPO/RM)**: reward margin, preference accuracy, KL drift from reference, plus the
fixed chat-prompt generation suite (`scripts/chat_alpha_fixed_prompts.txt`) judged before/
after — reusing and extending the doc-13 recommendation for a generation-based eval harness
(exact-match math, code syntax, stop-token, refusal appropriateness, corruption).
- **Online (GRPO)**: mean reward curve, KL-to-reference, response length, the verifiable-task
pass rate, and the same fixed-prompt suite.
- **Selection by generation eval, not loss** — the recurring doc-13/v11 lesson: lower
post-training loss did not mean better generations.
## Memory & throughput budget (8× RTX 5090, 1.05B model, indicative)
- Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy
optimizer state alone ~17 GB before activations. Recompute + grad-accum keep activations
small; this is why post-training reuses the Phase-1/2 memory levers unchanged.
- DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits.
- GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new
variable. **Generation, not the update, is expected to be the throughput bottleneck** — to be
measured, not assumed.
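Here is the first bullet's arithmetic, made explicit. The bytes-per-parameter figures are assumptions read off the numbers above: bf16 params and grads at 2 B, an fp32 master copy at 4 B, and fp32 AdamW `m` and `v` at 8 B together. Activations, allocator fragmentation, and CUDA context overhead are deliberately left out.

```rust
const GB: f64 = 1e9;

/// Per-buffer memory (decimal GB) for the trainable policy's fixed state.
fn policy_state_gb(n_params: f64) -> Vec<(&'static str, f64)> {
    vec![
        ("params (bf16)", n_params * 2.0 / GB),
        ("fp32 master", n_params * 4.0 / GB),
        ("AdamW m + v (fp32)", n_params * 8.0 / GB),
        ("grads (bf16)", n_params * 2.0 / GB),
    ]
}

/// A frozen reference kept only for bf16 inference.
fn frozen_reference_gb(n_params: f64) -> f64 {
    n_params * 2.0 / GB
}

fn total_gb(parts: &[(&str, f64)]) -> f64 {
    parts.iter().map(|&(_, gb)| gb).sum()
}
```

For 1.05B parameters this gives ~16.8 GB for the policy before activations, and ~18.9 GB with a bf16 reference resident.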
## Correctness-gate philosophy (unchanged from Phase 1/2)
Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity
on loss + grads where applicable, (3) explicit degenerate-case value checks (β→0, G=1,
ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic
RL overfit), and (5) **no change to the default training path** when post-training flags are
off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference
gates as flash/GQA.
## Risks & tradeoffs
- **Rollout engine is the long pole.** A correct KV-cache incremental-decode path is a real
build (decode-time attention, ragged batch). Mitigation: naive rollout first; KV-cache as an
isolated, separately-gated sub-phase.
- **RL is finicky.** KL leash, advantage normalization, clip range, reward hacking. Mitigation:
synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward.
- **Reward-model noise** can mislead GRPO. Mitigation: rule-based reward first.
- **Tokenizer (KI-4)** — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here.
- **Two/three resident models** raise memory; bounded by recompute/accum and (for DPO) reference
logprob caching.
## Resolved decisions (aligned 2026-06-29)
- **D1 — Scope & sequencing → DPO → GRPO, reward model optional.**
- **D2 — Online-RL reward source → rule-based / verifiable reward first** (RM deferred/optional).
- **D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front** (not
naive-first), as a foundational milestone before DPO/GRPO.
- **D4 — Alignment task / eval target → a verifiable task** (arithmetic/format/GSM8K-style) with
a deterministic exact-match reward, for a clean, falsifiable RL signal.
## Milestones (locked order)
1. **M1 — P0 SFT task baseline.** Chat template + assistant-only masking on the verifiable
task; produces the reference + init checkpoint. Gate: masking unit test; single-turn
bit-identical to `fbf4ac2`.
2. **M2 — KV-cache decode engine** (D3, up front). Per-layer K/V cache + incremental
decode-time attention + batched ragged decode. Gate: **token-identical to full-recompute
greedy**; record decode throughput baseline.
3. **M3 — P1 DPO.** Verifiable-checker pair construction (via M2) → `seq_logprob` op
(grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO
training loop → run + reward-margin / preference-accuracy curve.
4. **M4 — P3 GRPO.** Group rollout (M2) + rule-based reward + group-relative advantage +
clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; **synthetic
verifiable-task RL-overfit** (mean reward → known optimum) → verifiable-task GRPO run.
5. **M5 (optional) — P2 reward model.** Scalar head + ranking loss + pairwise-accuracy gate;
enables GRPO-with-RM for general chat.
> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.
## Implementation log
### M1 — SFT task baseline (landed)
The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
needed); the SFT run + eval ran on dash5 (1×5090). **Result: SFT moves answer-format
adherence 0% → 100%, with arithmetic correctness 8% — exactly the intended split (SFT buys
the format; correctness is M3/M4's job).**
**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**
- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,999]` for `+/−`, `[0,99]` for `×`
(modest products); subtraction may be negative. (Ranges enlarged from the first cut to keep
the unique-key space ≫ requested rows — see the saturation guard below.)
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
(exact match vs. gold). This is the single shared checker reused by M3 (pair construction)
and M4 (GRPO reward).
- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly
probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`).
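The sketch below shows one way the shared checker could work, assuming the rule stated above: the last `\boxed{int}` whose contents parse as an integer wins, and anything else counts as unparsed. The names follow the doc's `parse_boxed_answer` / `check_answer`, but the signatures here are guesses, not the actual `task.rs` API.

```rust
/// Returns the integer inside the LAST `\boxed{...}` that holds a parseable integer.
fn parse_boxed_answer(text: &str) -> Option<i64> {
    const TAG: &str = "\\boxed{";
    let mut last = None;
    let mut rest = text;
    while let Some(start) = rest.find(TAG) {
        let inner = &rest[start + TAG.len()..];
        let end = match inner.find('}') {
            Some(e) => e,
            None => break, // unterminated tag: stop scanning
        };
        if let Ok(n) = inner[..end].parse::<i64>() {
            last = Some(n);
        }
        rest = &inner[end + 1..];
    }
    last
}

/// Exact match against the gold integer (reward 1 if true, 0 otherwise).
fn check_answer(completion: &str, gold: i64) -> bool {
    parse_boxed_answer(completion) == Some(gold)
}
```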
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
Train rows are deduped; eval is held out from train (no leakage). A **saturation guard**
(`unique_space()` + `assert need·5 ≤ space·4`) rejects requests that approach the unique-key
space, since deduped train + disjoint eval near saturation get pathologically slow (or, for
the disjoint-eval loop, never terminate). With the shipped defaults the space is ~2.01M keys,
so a 20 000 + 500 request is a tiny fraction (gen runs in ~0.2 s).
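The saturation guard is simple arithmetic over the key space. Under the ranges stated above, and assuming `(A, op, B)` is an ordered key (so `3 + 5` and `5 + 3` are distinct), the space comes out to the ~2.01M keys quoted. The function names mirror the doc's, but their signatures are illustrative.

```rust
/// Distinct (A, op, B) keys: ordered operand pairs, inclusive ranges.
fn unique_space() -> u64 {
    let add_sub = 2 * 1000 * 1000; // + and −, operands in [0, 999]
    let mul = 100 * 100; // ×, operands in [0, 99]
    add_sub + mul
}

/// The `need·5 ≤ space·4` guard: refuse requests above 80% of the key space.
fn within_saturation_guard(need: u64, space: u64) -> bool {
    need * 5 <= space * 4
}
```

The shipped 20 000 + 500 request uses about 1% of the space. The guard only trips once a request approaches 80% of it, which is about 1.6M keys here.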
**Scorer (`crates/xtrain-train/src/bin/eval_arith.rs`):** loads a checkpoint, greedily
generates a continuation per held-out prompt, isolates the first answer segment (cut at the
first `<|endoftext|>` then first newline), and reports two signals via the shared checker —
**format** (fraction emitting any `\boxed{int}`) and **correctness** (exact-match vs. gold).
This is the reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). It uses the *naive*
no-KV-cache sampler (full forward per token), so even 100 prompts is slow — concrete
motivation for M2 (the KV-cache decode engine).
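The segment isolation and the two counters can be sketched as follows. The cut order (first `<|endoftext|>`, then first newline) comes from the text above. The compact `last_boxed_int` parser is a simplified stand-in for the shared checker, and the `Score` struct and `score` function are hypothetical.

```rust
/// Keep only the first answer segment: cut at the first `<|endoftext|>`, then the first newline.
fn first_answer_segment(generated: &str) -> &str {
    let before_eot = generated.split("<|endoftext|>").next().unwrap_or("");
    before_eot.split('\n').next().unwrap_or("")
}

/// Simplified stand-in for the shared checker: integer inside the last `\boxed{...}`.
fn last_boxed_int(s: &str) -> Option<i64> {
    let start = s.rfind("\\boxed{")? + "\\boxed{".len();
    let end = start + s[start..].find('}')?;
    s[start..end].parse().ok()
}

#[derive(Debug, PartialEq)]
struct Score {
    format: usize,  // completions with a parseable \boxed{int}
    correct: usize, // exact match against gold
    total: usize,
}

fn score(completions: &[&str], gold: &[i64]) -> Score {
    let mut s = Score { format: 0, correct: 0, total: completions.len() };
    for (text, &g) in completions.iter().zip(gold) {
        if let Some(ans) = last_boxed_int(first_answer_segment(text)) {
            s.format += 1;
            if ans == g {
                s.correct += 1;
            }
        }
    }
    s
}
```

In a case like `"3 + 4 = \boxed{7}.\nUser: \boxed{9}"`, the newline cut matters. Without it, a last-box rule would read the `9` from the spilled-over next turn.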
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
single-turn path is bit-identical to `fbf4ac2`).
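Here is a sketch of what a pure masking helper in the spirit of `sft_row` does. It assumes next-token labels (input position `i` predicts token `i + 1`) and the `-100` ignore index. The doc does not say whether the real helper shifts labels itself or leaves that to the loss, or how it appends EOS, so treat this signature as hypothetical.

```rust
const IGNORE: i64 = -100;

/// Build (inputs, labels) for one single-turn SFT example: labels are the next token,
/// masked to IGNORE wherever that next token still belongs to the prompt.
fn sft_row(prompt_ids: &[i64], answer_ids: &[i64]) -> (Vec<i64>, Vec<i64>) {
    assert!(!prompt_ids.is_empty() && !answer_ids.is_empty());
    let tokens: Vec<i64> = prompt_ids.iter().chain(answer_ids).copied().collect();
    let inputs = tokens[..tokens.len() - 1].to_vec();
    let labels = tokens[1..]
        .iter()
        .enumerate()
        // label i targets tokens[i + 1]; it is a prompt token iff i + 1 < prompt length
        .map(|(i, &t)| if i + 1 < prompt_ids.len() { IGNORE } else { t })
        .collect();
    (inputs, labels)
}
```

Exactly `answer_ids.len()` labels survive the mask, which is the property the masking unit test pins down.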
**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass,
including `sft_row` masks prompt→`-100` / supervises answer, the SFT-target self-consistency
invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism.
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
train/eval leakage.
**Run (dash5, 1×5090, from the v12 1.05B base):**
1. dataset: `gen_arith_task --n 20000 --eval 500 --seed 1 --out-dir <dir>` → 20 000 train +
500 held-out eval, 0 leakage.
2. SFT: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt> --heads 52
--head-dim 32 --kv-heads 13 --layers 22 --ffn 6656 --bf16 --recompute --flash --seq 256
--batch 16 --steps 250 --max-lr 1e-4 --min-lr 1e-5 --ckpt arith_sft_v12.ckpt` → the P0
reference/init checkpoint. Train loss 4.68 → ~0.34, best val 0.386, no OOM, ~4.3K tok/s.
3. eval: `eval_arith <ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt
--gold-file <dir>/arith_eval_gold.txt --max-tokens 32`, base vs. SFT, on 100 held-out prompts.
**M1 result (100 held-out prompts, greedy, max_new 32):**
| checkpoint | format (`\boxed{}`) | correct (exact-match) |
|---------------------|----------------------|-----------------------|
| v12 base (pre-SFT) | 0 / 100 (0%) | 0 / 100 (0%) |
| arith SFT | **100 / 100 (100%)** | 8 / 100 (8%) |
The base model never emits the format — it answers `"I don't know."` / restates the question
and stops. SFT moves format **0% → 100%**: every completion cleanly restates the equation and
boxes an integer (`46 * 80 = \boxed{3380}.`). Correctness is only **8%**: the format is fully
learned but the *arithmetic* is the base model's own weak capability — e.g. it boxes 3380 for
gold 3680, 10 for gold 5; it does get some right (`895 − 353 = \boxed{542}.` ✓). That residual
gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to close.
**Gate met:** format 0% → 100% confirms the assistant-only SFT path is wired end-to-end; the
held-out correct > 0 confirms the checker + eval harness score real matches (not just format).
M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic
skill, which is downstream by design.
### M2a — KV-cache incremental-decode engine (single sequence, landed)
The decode engine (D3, built up front) that replaces the naive sampler — which re-runs the
full forward over the growing prefix every step (O(t²), a fresh autograd graph per token). Two
forward-only primitives + a raw-Tensor per-token block forward, each gated in isolation.
**Primitives (`xtrain-tensor`, both forward-only):**
- `Tensor::rope_at(theta, pos0)` — RoPE at a token's *absolute* position (`pos = pos0 + row`,
no modulo), vs the training `rope` (`pos = row % period`) which is left untouched (new CUDA
kernel `rope_at_k` → no training-path risk). Cached K is stored post-RoPE, so it must match
what the full forward produced at that position. **Gate:** bit-identical to the full-sequence
rope's row `t` (`integration::rope_at_matches_full_rope_row`).
- `Tensor::decode_attention(k, v, scale)` — single-query × cached-K/V SDPA (`[bh,1,hd]` vs
`[bh,t,hd]`, no causal mask: the one query sees all cached keys). Composed from the existing
strided batched GEMM + plain softmax — **no new kernel**. **Gate:** equals the full causal
attention's last query row, max |Δ| 6e-8 (`integration::decode_attention_matches_…`).
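The `decode_attention` gate can be reproduced at toy scale. Take the new query as the last row: attending over all cached keys without a mask must reproduce the causal attention's last row, because the causal mask hides nothing from the final position. This is a single-head f32 sketch with illustrative names (no GQA, no batching, and plain loops instead of the strided GEMM).

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// softmax(scores) · V for one query row; masked scores are −∞ and get zero weight.
fn softmax_weighted_sum(scores: &[f32], v: &[Vec<f32>]) -> Vec<f32> {
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = scores.iter().map(|&s| (s - m).exp()).collect();
    let z: f32 = w.iter().sum();
    let mut out = vec![0.0f32; v[0].len()];
    for (wj, vj) in w.iter().zip(v) {
        for (o, &x) in out.iter_mut().zip(vj) {
            *o += wj / z * x;
        }
    }
    out
}

/// Decode step: one query against all `t` cached keys/values, no mask.
fn decode_attention(q: &[f32], k_cache: &[Vec<f32>], v_cache: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let scores: Vec<f32> = k_cache.iter().map(|k| scale * dot(q, k)).collect();
    softmax_weighted_sum(&scores, v_cache)
}

/// Training-style causal attention over the full sequence (row i sees keys 0..=i).
fn causal_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<Vec<f32>> {
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            let scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| if j <= i { scale * dot(qi, kj) } else { f32::NEG_INFINITY })
                .collect();
            softmax_weighted_sum(&scores, v)
        })
        .collect()
}
```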
**Engine (`xtrain-model/src/decode.rs`, `generate_greedy_cached`):** per-layer K/V cache +
single-token incremental forward. Prefill = the first `prompt.len()` decode steps (one code
path). Mirrors `model::block_forward` at the raw-Tensor level (no autograd tape — inference
needs no grads), pulling weights via the public `params()` stable order (no model-internal
visibility changes). The cache is host-accumulated token-major f32, rebuilt per step — the
honest M2a baseline; M2b moves it device-side + adds batched ragged decode.
**Gate (the M2 centerpiece — token-identical):** KV-cache greedy decode is byte-for-byte the
same token sequence as the naive full-recompute greedy. Verified two ways:
- `xtrain-train/tests/decode_kv.rs` — small GQA model (8 query / 2 kv heads), F32, 24 generated
tokens, exact token-equality. (Unit gate runs F32: a random model's near-uniform logits make
argmax fragile to ~1e-6, so the tightest path is used; the trained model below has peaked
logits → robust.)
- v12 1.05B SFT checkpoint: `eval_arith --cached` produces the **identical** eval outcome to the
naive run (format 100/100, correct 8/100) and byte-identical completions.
**Throughput baseline (v12 1.05B, batch 1, F32, profile-first — measured, not assumed):** the
cache win is **sequence-length-dependent**, which is the honest systems finding here:
| max_new | naive | kv-cache | note |
|---------|-------|----------|------|
| 32 | 108 tok/s | 111 tok/s | ~1.0× — both **launch/overhead-bound** at short seq |
| 128 | 69 tok/s | **133 tok/s** | **~1.9×** — naive's O(t²) recompute starts to bite |
| 256 | **OOM** | 129 tok/s | naive rebuilds the O(seq²) graph every step → OOM |
Cached throughput stays ~constant (O(1)/token compute + constant memory); naive **decays**
(108→69 tok/s, O(t)/token) and eventually **OOMs** (the full autograd graph per step). So at the
short arithmetic-eval lengths the cache is overhead-bound and gives ~nothing — it matters for
**long rollouts** (DPO pair-generation, GRPO completions), exactly where M3/M4 use it. (M2a's
per-layer host round-trip is part of why short-seq is overhead-bound; M2b's device-side cache
targets it.) This is the same measure-first lesson as T17 (process-per-GPU throughput-neutral):
the win is real but only in the regime that actually stresses the bottleneck.
### M3 — DPO (offline preference optimization, landed; honest negative result)
The first real alignment method. Infra landed and gated; the empirical finding is that DPO
**does not improve held-out arithmetic correctness on this task** — a genuine, on-theme negative
result (the design doc's "RL is finicky" risk, made concrete).
**Two new autograd ops (`xtrain-autodiff`, both reuse the CE kernel — no new CUDA):**
- `seq_logprob(logits, target)` = `Σ log πθ(target)` over non-ignored positions (the per-
sequence logprob DPO compares). `= −Σ per_row` of cross_entropy (ignored rows already 0, like
SFT masking); backward = `cross_entropy_backward(probs, target, upstream)` (SUM, no mean).
**Gate:** finite-diff grad-check with a `-100` completion mask.
- `dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β)` = `log σ(Δ)` with the
two policy logprobs as parents (ref logprobs constant). **Gate:** grad-check both parents +
degenerate points (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0).
**Pair construction (`gen_dpo_pairs`, aligned decision):** chosen = gold answer; rejected = the
SFT model's own **greedy** (KV-cache engine, M2a) completion when it's a format-valid WRONG
boxed answer — a hard negative in the model's distribution. Since SFT is ~8% correct (M1),
greedy is wrong ~92% of the time, so this is fast and deterministic; ~8% of prompts are skipped
(greedy correct). 1500 pairs generated (158 skipped) in ~8 min.
**Training (`train_dpo`):** loads the SFT ckpt as policy AND frozen reference; **precomputes the
reference logprobs once** (while policy == reference) and caches them — one resident model. Each
step forwards the policy on chosen + rejected, `seq_logprob` each, minimises `dpo_loss`; the two
forwards share params so backward accumulates both branches. Loss **starts at exactly log2**
(Δ=0 at init) — a built-in correctness check that fired correctly. Tracks reward margin +
preference accuracy.
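The two tracked metrics follow from DPO's implicit reward `β(logπθ − logπref)`. The sketch below assumes per-pair sequence logprobs are at hand and counts margin > 0 as a correct preference, with ties as misses. The struct and function names are hypothetical.

```rust
/// Sequence logprobs for one preference pair under the policy and the cached reference.
struct PairLogprobs {
    policy_chosen: f64,
    policy_rejected: f64,
    ref_chosen: f64,
    ref_rejected: f64,
}

/// Returns (mean reward margin, preference accuracy) over a batch.
fn margin_and_accuracy(pairs: &[PairLogprobs], beta: f64) -> (f64, f64) {
    let margins: Vec<f64> = pairs
        .iter()
        .map(|p| {
            let r_chosen = beta * (p.policy_chosen - p.ref_chosen);
            let r_rejected = beta * (p.policy_rejected - p.ref_rejected);
            r_chosen - r_rejected
        })
        .collect();
    let n = margins.len() as f64;
    let mean = margins.iter().sum::<f64>() / n;
    let acc = margins.iter().filter(|&&m| m > 0.0).count() as f64 / n;
    (mean, acc)
}
```

At initialisation the margin is exactly 0, for the same reason the loss starts at `log 2`.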
**Result (v12 1.05B, 1500 pairs, β=0.1; 100 held-out prompts, vs the SFT baseline format
100/100, correct 8/100):**
| run | reward margin | pref-acc | format | correct |
|---------------------------|---------------|----------|--------|---------|
| SFT (baseline) | — | — | 100/100 | 8/100 |
| DPO lr 5e-7 × 300 | +0.78 | ~82% | 100/100 | 7/100 |
| DPO lr 5e-7 × 800 | +1.25 | ~82% | 100/100 | 5/100 |
| DPO lr 1e-6 × 2000 | **+34.2** | ~76% | **0/100** | 0/100 |
The reward margin and preference accuracy rise cleanly (the loss IS being optimized — the infra
is correct), but the implicit reward **does not transfer to held-out correctness**: it stays
~5–8% (within about one std-error, ~2.7 pp at 100 prompts, of the 8% baseline: statistically flat), and pushing harder
**over-optimizes to collapse** (margin +34 = huge KL from the reference → the model emits
garbage, `46 * 80 = CRAFTIE SERIES SERIES…`, format 0%).
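The "statistically flat" judgement rests on the binomial standard error of a pass rate, `sqrt(p(1 − p)/n)`. Here is a quick check of the ~2.7 pp figure for 100 prompts (the helper names are illustrative).

```rust
/// Binomial standard error of a pass rate p measured on n independent prompts.
fn pass_rate_std_err(p: f64, n: usize) -> f64 {
    (p * (1.0 - p) / n as f64).sqrt()
}

/// True when two pass rates differ by less than `k` standard errors of the baseline.
fn within_k_std_err(baseline: f64, other: f64, n: usize, k: f64) -> bool {
    (other - baseline).abs() < k * pass_rate_std_err(baseline, n)
}
```

At p = 0.08 and n = 100 this gives ≈ 0.027. The 7/100 and 5/100 DPO rows sit within two standard errors of the 8/100 baseline, and only the collapsed 0/100 run falls outside.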
**The lesson (why):** chosen and rejected differ only in the final number tokens, so DPO raises
`log p(correct) − log p(wrong)` for the *specific* training pairs: it **reweights the existing
distribution, it does not install the capability**. The base model has no arithmetic algorithm,
so preferring correct-vs-wrong final answers on seen pairs cannot generalize to unseen problems;
and the only way to drive the margin far is to globally distort the distribution → incoherence.
**DPO works when the chosen is already plausible under the policy; it cannot manufacture
knowledge the model lacks.** This is the precise motivation for **M4 GRPO**: optimize the *actual
verifiable reward* online (sample → check → reinforce what is genuinely correct), rather than a
fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric
is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init,
margin/acc rise); the correctness flatness is the reported finding, not a bug.
### M4 — GRPO (online RL, critic-free, landed; infra + two honest systems walls)
The centerpiece: generation INSIDE the training loop. Infra built and gated; the run surfaces
two concrete systems findings (the memory long-pole + the rollout long-pole, both flagged in the
design doc's Risks) and the same capability wall as M3.
**Task made learnable first (per the aligned decision "easier task → then M4"):** the v12 SFT
model scores ~8% on the hard task *and* on easy problems — it learned format, not arithmetic. So
the easy task (operands ≤20, ops `+ − ×`) was re-SFT'd from the v12 base → **held-out 18.7%**
(100% format), a baseline with reward variance for GRPO. Note: even easy arithmetic plateaus at
~19% held-out (250 and 600 SFT steps score identically): a 1B web-text model does not generalize the
add/sub algorithm from ~550 examples; it memorizes train (982 total problems, 550 seen).
**New op (`xtrain-autodiff`, reuses the CE kernel + one new primitive):**
- `clipped_pg_loss(logits, target, logp_old, logp_ref, A, ε, β)` — per completion token
`ρ_t = exp(logπθ_t − logp_old_t)`, `L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL` (k3), masked
to completion tokens. Backward reuses `(probs − onehot)` + `scale_rows` (a new ~5-line per-row
scale kernel — the per-token coefficient varies, which CE-backward's single scalar can't
express). **Gate:** grad-check the active PG path + the A=0 (KL-only) path; degenerate value
checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.
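
To make the per-token formula concrete, here is a minimal CPU reference in Rust. It is a sketch, not the `xtrain-autodiff` kernel: `clipped_pg_loss_ref` is a hypothetical name, it takes the per-token log-probs of the sampled tokens (masking already applied) instead of logits, and it assumes the standard k3 estimator `exp(ref − θ) − (ref − θ) − 1` for the KL term. It also makes the degenerate checks easy to state: with a huge ε the clip never binds, and the KL term vanishes when β=0 or when the reference equals the policy.

```rust
/// Hypothetical CPU reference for the per-completion clipped PG loss (sketch only).
/// Inputs are per-token log-probs of the sampled tokens, already masked to the
/// completion. KL uses the k3 estimator (an assumption): exp(d) - d - 1, d = ref - cur.
fn clipped_pg_loss_ref(
    logp: &[f32],     // log pi_theta(y_t), current policy
    logp_old: &[f32], // log pi_old(y_t), captured before the inner epochs
    logp_ref: &[f32], // log pi_ref(y_t), frozen reference
    adv: f32,         // group-relative advantage A of this completion
    eps: f32,         // clip range (must be >= 0)
    beta: f32,        // KL coefficient; 0 disables the leash
) -> f32 {
    assert!(logp.len() == logp_old.len() && logp.len() == logp_ref.len());
    let n = logp.len() as f32;
    let (mut pg, mut kl) = (0.0f32, 0.0f32);
    for t in 0..logp.len() {
        let rho = (logp[t] - logp_old[t]).exp(); // importance ratio
        let clipped = rho.clamp(1.0 - eps, 1.0 + eps);
        pg += (rho * adv).min(clipped * adv); // pessimistic PPO surrogate
        let d = logp_ref[t] - logp[t];
        kl += d.exp() - d - 1.0; // k3: non-negative, exactly 0 when d == 0
    }
    // Maximize the surrogate => minimize its negative; add the KL penalty.
    -(pg / n) + beta * (kl / n)
}
```

With `logp_old == logp` every ρ is 1 and the loss reduces to `−A`, which is a cheap on-policy sanity check.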
**Loop (`train_grpo`):** per step — sample B prompts, roll out G completions each, score (reward
0/1), group-relative advantage `A=(r−mean)/(std+ε)` (no critic; all-correct/all-wrong groups
skipped — zero advantage), capture `logπθ_old`/`logπref` per token, K inner clipped-PG epochs.
Rollout uses the M2 KV-cache engine with **temperature sampling** (added in M4): single-row
`[1,vocab]` logits per step vs the naive sampler's `[seq,vocab]`.
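
The group-relative advantage and the skip rule can be sketched as below. `group_advantages` is a hypothetical helper; the population standard deviation (divide by G) and the value of ε are assumptions the text does not pin down.

```rust
/// Hypothetical helper: group-relative advantages for one prompt's G rewards.
/// Assumes population std (divide by G). Returns None when every reward in the
/// group is identical (all-correct / all-wrong): zero advantage, so the group is skipped.
fn group_advantages(rewards: &[f32], eps: f32) -> Option<Vec<f32>> {
    let g = rewards.len() as f32;
    let mean = rewards.iter().sum::<f32>() / g;
    let var = rewards.iter().map(|r| (r - mean).powi(2)).sum::<f32>() / g;
    let std = var.sqrt();
    if std == 0.0 {
        return None; // no within-group contrast, no learning signal
    }
    Some(rewards.iter().map(|r| (r - mean) / (std + eps)).collect())
}
```

In this sketch, returning `None` instead of a vector of zeros lets a caller drop the group before spending any capture or PG forwards on it.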
**Systems wall #1 — memory (the design doc's "two/three resident models"):** KL-leash GRPO needs
policy + frozen reference, two 1.05B fp32-master models + AdamW m/v ≈ 21 GB fixed + training
activations → OOMs intermittently on a 32 GB 5090 (fragmentation tips it over). To get a completing
run, `β=0` (pure PG) drops the reference model (4.2 GB). So the *principled* KL-leash version is
memory-bound at this model size on this hardware — a real, reported constraint, not a bug.
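
A back-of-envelope check of the ~21 GB figure. The breakdown is an assumption, not stated above: the trained policy carries fp32 weights, fp32 gradients, and AdamW m and v, while the frozen reference carries fp32 weights only; activations and allocator fragmentation are excluded. `grpo_fixed_bytes` is a hypothetical helper.

```rust
/// Hypothetical estimate of fixed (non-activation) memory for KL-leash GRPO.
/// Assumed breakdown: policy = fp32 weights + grads + AdamW m + AdamW v;
/// reference = fp32 weights only.
fn grpo_fixed_bytes(params: f64, with_reference: bool) -> f64 {
    let fp32 = 4.0; // bytes per fp32 value
    let policy = params * fp32 * 4.0; // weights, grads, m, v
    let reference = if with_reference { params * fp32 } else { 0.0 };
    policy + reference
}
```

Under this decomposition, 1.05e9 parameters give 21.0 GB with the reference and 16.8 GB without it, so `β=0` removes exactly the one 4.2 GB reference term.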
**Systems wall #2 — rollout (the design doc's "rollout is the long pole"):** the naive sampler's
growing `[seq,vocab]` allocations fragment the caching allocator over a long rollout → OOM. The
cached temperature rollout (single-row logits) is lighter; but single-sequence cached decode is
slow (the M2a host-round-trip), so rollout still dominates wall-clock (~16 s/step at G=6·B=6).
Batched ragged decode (M2b) is the real fix and is deferred to where it is load-bearing.
**Result (easy task, β=0, G=6·B=6, 40 steps, lr 5e-7; 150 held-out, vs SFT 28/150 = 18.7%):**
mean rollout reward fluctuates ~0.58–0.81 (noisy, inflated by train-set overlap in the sampled
problems); **format stays 100/100** (no collapse even without the KL leash, at this gentle lr);
**held-out 30/150 = 20.0%** (`+1.3 pp`), within the ~3 pp standard error of 150 prompts, i.e.
**statistically flat**, the same wall as M3 DPO.
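
The "statistically flat" reading follows from the binomial standard error of an accuracy measured on `n` prompts, `sqrt(p(1−p)/n)`. A small sketch with the SFT baseline `p = 28/150` (`std_error` is a hypothetical helper):

```rust
/// Binomial standard error of an accuracy p measured on n independent prompts.
fn std_error(p: f64, n: f64) -> f64 {
    (p * (1.0 - p) / n).sqrt()
}
```

It comes to about 0.032 (≈3.2 pp), larger than the observed gain of 2/150 ≈ 1.3 pp.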
**The consistent M3+M4 lesson:** on a task where the base model lacks the underlying capability,
**neither offline preference optimization (DPO) nor online RL (GRPO) moves held-out correctness**
— each optimizes its objective (margin / reward) on the *training distribution* it can reach
(here inflated by memorization), but cannot install a *generalizable* algorithm the model never
had. RL reinforces what the model already does; it does not teach arithmetic. Gate met for M4 =
the infra is correct (PG/KL grad-checks + degenerate checks, the loop runs, reward signal + KL
leash wired, format held); the held-out flatness + the two memory/throughput walls are the
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
limits on what alignment can do for a capability the base model lacks.
### M2b — batched KV-cache decode (landed; completes the M2 engine, fixes the rollout long-pole)
Built after M4 (where the rollout long-pole bit hardest): decode the **G samples of one prompt in
lockstep** — one forward per step over the whole group → G× fewer kernel launches, the deferred
fix from M2a.
**One new primitive:** `rope_pos(x, positions[])` — RoPE with a *per-row* absolute position (new
forward-only kernel), since the G batched rows share one decode position (M2a's `rope_at` does
`pos0 + row`, wrong for a batch at a single position). **Gate:** bit-identical to the full rope
for positions `[0..n]`, and to `rope_at(P)` per row for a uniform `P`.
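
A CPU sketch of the difference between a per-row-position RoPE and the contiguous `pos0 + row` variant, with both gate properties as checks. `rope_pos_ref` and `rope_at_ref` are hypothetical names; the interleaved-pair convention and base 10000 are assumptions (xtrain's kernel may use rotate-half), so this illustrates the indexing only.

```rust
/// Hypothetical CPU reference: RoPE where each row carries its own absolute position.
/// x is [rows, dim] row-major, dim even; interleaved (2i, 2i+1) pairs are assumed.
fn rope_pos_ref(x: &mut [f32], dim: usize, positions: &[usize], base: f32) {
    assert_eq!(x.len(), positions.len() * dim);
    for (row, &pos) in positions.iter().enumerate() {
        let v = &mut x[row * dim..(row + 1) * dim];
        for i in 0..dim / 2 {
            let theta = pos as f32 * base.powf(-2.0 * i as f32 / dim as f32);
            let (s, c) = theta.sin_cos();
            let (a, b) = (v[2 * i], v[2 * i + 1]);
            v[2 * i] = a * c - b * s; // rotate the pair by theta
            v[2 * i + 1] = a * s + b * c;
        }
    }
}

/// The M2a-style contiguous variant: row r sits at position pos0 + r.
fn rope_at_ref(x: &mut [f32], dim: usize, pos0: usize, base: f32) {
    let rows = x.len() / dim;
    let positions: Vec<usize> = (pos0..pos0 + rows).collect();
    rope_pos_ref(x, dim, &positions, base);
}
```

For G batched rows at one decode position `P`, `rope_pos` with `[P; G]` matches `rope_at(P)` applied to each row alone, whereas `rope_at(P)` over all G rows would rotate row `r` as if it sat at `P + r`.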
**Engine (`generate_cached_batch`):** `BatchKVCache` carries a G dimension (`[T, G·num_kv, hd]`
host-accumulated → `[G·num_kv, T, hd]`); the batched `decode_step` threads G through embed /
projections / QK-norm / `rope_pos` / cache. Two M2a pieces drop in unchanged: `decode_attention`
is already batch-agnostic (`bh = G·nh`), and `repeat_kv(nh, batch=G)` broadcasts per group. No
finished-mask (all G generate `max_new`; the caller cuts at EOS) and no ragged-length prompts yet
— both perf-only follow-ups.
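
The host-accumulated layout appends one `[G·num_kv, hd]` slab per decode step, so before attention it must be relaid out with the head dimension outermost, a swap of the two outer dimensions (presumably the role of the `transpose_3d01` that M2c later removes). A minimal sketch; `cache_to_attention_layout` is a hypothetical helper, and row-major f32 buffers with `bh = G·num_kv` are assumed.

```rust
/// Hypothetical helper: [T, bh, hd] (one slab appended per step) -> [bh, T, hd]
/// (contiguous history per head, as attention wants it). Row-major f32 assumed.
fn cache_to_attention_layout(kv: &[f32], t: usize, bh: usize, hd: usize) -> Vec<f32> {
    assert_eq!(kv.len(), t * bh * hd);
    let mut out = vec![0.0f32; kv.len()];
    for step in 0..t {
        for h in 0..bh {
            let src = (step * bh + h) * hd; // step-major source
            let dst = (h * t + step) * hd; // head-major destination
            out[dst..dst + hd].copy_from_slice(&kv[src..src + hd]);
        }
    }
    out
}
```

Because this relayout touches the whole cache on every step, it is a per-step cost that grows with `T`, one reason keeping the cache already head-major on the device is attractive.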
**Gate (token-identical):** all G **greedy** rows are byte-identical to the single-sequence decode
(`tests/decode_batch.rs`, 8 query / 2 kv heads → exercises the `repeat_kv` batching) — pins that
G-way batching indexes each sequence's K/V with no cross-row contamination.
**Throughput (v12 1.05B, G=6·B=6, easy task, rollout wired into `train_grpo`):** ~8.5 s/step vs
~14–16 s/step for the single-seq cached rollout, **~1.7×**, rollout-inclusive. Short of the full
G× because (a) the per-token-logp forwards + the PG update also cost, and (b) the M2a per-layer
**host round-trip** is still there (now G× the data in one transfer, not removed). The full
device-side cache (no host round-trip) is the remaining decode-engine optimization. Batching also
**stabilises memory**: one batched forward per step vs G separate allocations that fragmented the
caching allocator (the M4 OOM). So M2b closes the decode-engine milestone (M2a single-seq + M2b
batched) and turns the rollout long-pole from "OOM/unbounded" into a bounded ~1.7× win — measured,
with the device-cache as the named next lever.
### M2c — device-side KV cache (landed; the bottleneck moved, a profile-first finding)
The named M2b follow-up: keep K/V on the GPU (`[bh,T,hd]`, an `Option<Tensor>` per layer) and
grow it by one token per step via a new `cat_seq` kernel (concat along the seq dim) — removing the
M2a/M2b per-layer **host round-trip** (`to_cpu`/`from_slice`/re-upload) *and* the `transpose_3d01`.
Both single-seq and batched decode refactored to it (cleaner than the host `Vec` + rebuild).
**Gates hold:** `cat_seq == host concat`; `decode_kv` single-seq + `decode_batch` G-way both still
**token-identical**; GQA training path unaffected.
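
A CPU reference of what the `cat_seq == host concat` gate compares against might look like this. `cat_seq_ref` is a hypothetical name; it only handles appending a single new token (the decode case), and row-major f32 `[bh, T, hd]` buffers are assumed. In the engine this is a GPU kernel.

```rust
/// Hypothetical CPU reference for concat along the seq dim:
/// [bh, t, hd] ++ [bh, 1, hd] -> [bh, t + 1, hd], row-major f32.
fn cat_seq_ref(cache: &[f32], new: &[f32], bh: usize, t: usize, hd: usize) -> Vec<f32> {
    assert_eq!(cache.len(), bh * t * hd);
    assert_eq!(new.len(), bh * hd);
    let mut out = Vec::with_capacity(bh * (t + 1) * hd);
    for h in 0..bh {
        out.extend_from_slice(&cache[h * t * hd..(h + 1) * t * hd]); // head h's past tokens
        out.extend_from_slice(&new[h * hd..(h + 1) * hd]); // then its new token
    }
    out
}
```

Growing the cache one step at a time from empty leaves each head's history contiguous, so no transpose is needed before `decode_attention`.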
**The finding (why this is a measure-first lesson, not a speedup story):** removing the host
round-trip buys **~10%** on *pure* single-seq decode (133 → 147 tok/s @128) but **does not move the
GRPO step** (~8.5 s/step, unchanged). Because after M2b batching, the rollout is no longer the
step's bottleneck — the per-sample **`per_token_logp` captures** (2 forwards/sample) and the
**PG-update** forwards+backwards (`model.forward`, full-sequence, per sample) now dominate. So the
long pole **shifted** from the rollout to the training-side forwards (cf. T11/T17/M2a: profile
before optimizing — the bottleneck you fixed is not the one that remains). The device cache is
still a real, correctness-gated improvement (cleaner code, less PCIe, ~10% decode); the honest
headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**,
not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache),
all token-identical-gated; the post-training stack remains complete with its bottleneck mapped.
### M2d — batch the GRPO training-side forwards (landed; the lever M2c named, + a decomposition correction)
M2c named the next lever: **ragged batched prefill of the per-sample training-side forwards**. Those
forwards are the two phases that, per step, run one single-sequence `forward` per sample: the
`per_token_logp` **captures** (logπ_old policy + logπ_ref reference) and the inner **clipped-PG**
forward+backward passes. M2d packs all `N = B·G` ragged samples of a step into ONE `forward_batched`.
**The enabling property — right-padding is free under causal attention.** Pad each ragged completion
on the RIGHT to the batch's `Lmax`. A real completion row sits at an earlier position than the
trailing pad, and causal masking forbids attending forward, so its logits match the unpadded
single-sequence forward (bit-identical in bf16, within 3.7e-7 in fp32); the pad rows are garbage but masked out (`target = -100`). This
is exactly why training engines pad-and-mask rather than run ragged. Two new pieces:
- `per_token_logp_batched` (`crates/xtrain-train/src/grpo_batch.rs`): right-pad → one
`forward_batched(batch = N)` → slice each sample's logπ back to its real length.
- `ops::clipped_pg_loss_batched` (`crates/xtrain-autodiff/src/ops.rs`): like the per-sample
`clipped_pg_loss`, but takes **per-row** `advantage[t]` (the owning sample's `A`) and **per-row**
`weight[t]` (the full normaliser; the caller passes `1/(N·n_s)`). It does NOT compute its own
`1/n_tokens`, so folding `weight = 1/(N·n_s)` reproduces the looped `Σ_s (1/N)(1/n_s)…`
**bit-for-bit** (the per-row CE backward is row-local). A `--micro` knob packs in chunks to bound
the `[chunk·Lmax, vocab]` logits memory; the weight uses the GLOBAL `N`, so chunked
grad-accumulation is exact. Both `train_grpo` and the bench call these shared helpers.
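
The right-padding argument can be checked on a toy single-head causal attention in plain Rust. This is a sketch, not xtrain's `forward_batched`: `causal_attention` and `right_pad` are hypothetical helpers, Q = K = V = x with no projections, and the pad value is arbitrary. Row `i` only reads keys `0..=i`, in the same order with or without trailing pad rows, so in this toy (where the reduction order is fixed) the real rows come out bit-identical; slicing them back out is what `per_token_logp_batched` does per sample.

```rust
/// Hypothetical toy: single-head causal self-attention, q = k = v = x, x is [len, d].
fn causal_attention(x: &[f32], len: usize, d: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![0.0f32; len * d];
    for i in 0..len {
        let qi = &x[i * d..(i + 1) * d];
        // Causal mask: scores only against keys 0..=i.
        let scores: Vec<f32> = (0..=i)
            .map(|j| qi.iter().zip(&x[j * d..(j + 1) * d]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let w: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect(); // stable softmax
        let z: f32 = w.iter().sum();
        for (j, wj) in w.iter().enumerate() {
            for c in 0..d {
                out[i * d + c] += wj / z * x[j * d + c];
            }
        }
    }
    out
}

/// Hypothetical helper: right-pad a [len, d] sequence to [lmax, d] with a constant.
fn right_pad(x: &[f32], d: usize, lmax: usize, pad: f32) -> Vec<f32> {
    assert!(lmax * d >= x.len());
    let mut v = x.to_vec();
    v.resize(lmax * d, pad);
    v
}
```

Left-padding would break this: the pad rows would sit *before* the real tokens, inside every real row's causal window.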
**Correctness gates (exact, not bf16-noisy):**
- `xtrain-model::forward_batched_ragged_matches_looped` — forward_batched on right-padded ragged
sequences == per-sequence single-seq forward on the real rows, **max|Δlogit| = 3.7e-7 (fp32) and
0.0 (bf16)**, both composed + flash. Pins "right-pad is free".
- `xtrain-autodiff::clipped_pg_loss_batched_matches_looped` — batched op == looped
`Σ_s (1/N)·clipped_pg_loss_s`, **loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32)**.
Composed, these prove the batched GRPO step == the looped step. End-to-end: a short SFT (v12 base,
150 steps, arith) → `train_grpo` 12 steps runs clean — **no OOM** (1B master + AdamW + batched
activations fit with `micro=16`), the mean reward rises, and the batched inner loop executes.
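
The `1/(N·n_s)` weight folding in `clipped_pg_loss_batched`, and why chunking must use the global `N`, can be illustrated on plain per-token loss values. `looped`, `batched`, and `batched_chunked` are hypothetical helpers over arbitrary numbers, not the crate's op. Summation order differs between the forms, so in f32 they agree up to rounding (consistent with the measured loss Δ=1.5e-8).

```rust
/// Looped form: mean within each sample, then (1/N) times the sum over samples.
fn looped(per_token: &[Vec<f32>]) -> f32 {
    let n = per_token.len() as f32;
    per_token.iter().map(|s| (1.0 / n) * (s.iter().sum::<f32>() / s.len() as f32)).sum()
}

/// Batched form: every real row carries its own weight 1/(N * n_s).
fn batched(per_token: &[Vec<f32>]) -> f32 {
    let n = per_token.len() as f32;
    let mut total = 0.0f32;
    for s in per_token {
        let w = 1.0 / (n * s.len() as f32); // per-row weight from the owning sample
        for &l in s {
            total += w * l;
        }
    }
    total
}

/// Chunked (micro-batch) accumulation: weights still use the GLOBAL n, so the
/// chunk results simply add up to the full-batch value.
fn batched_chunked(per_token: &[Vec<f32>], micro: usize) -> f32 {
    let n = per_token.len() as f32;
    per_token
        .chunks(micro)
        .map(|chunk| {
            chunk
                .iter()
                .map(|s| {
                    let w = 1.0 / (n * s.len() as f32);
                    s.iter().map(|l| w * l).sum::<f32>()
                })
                .sum::<f32>()
        })
        .sum()
}
```

Using a chunk-local count in place of `n` would overweight the smaller chunks, which is why the normaliser is computed once per step.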
**Throughput (bench `bin/bench_grpo_batch`, v12 1.05B, N=48 ragged, micro=16, β=0, weight-independent):**
| phase (per step) | looped (single-seq) | batched (M2d) | speedup |
|-------------------------|---------------------|---------------|---------|
| capture `per_token_logp`| 622 ms | 71 ms | 8.7× |
| inner clipped-PG fwd+bwd| 1907 ms | 208 ms | 9.2× |
| **training forwards** | **2526 ms** | **280 ms** | **9.0×**|
**The decomposition correction (the honest finding).** M2c claimed "the per-sample training forwards
now dominate the step." The clean per-component bench falsifies the strong form: the training
forwards were **~2.5 s of the ~8.5 s step (~30%)** — substantial and worth the 9× win, but the
**rollout (`generate_cached_batch`, ~6 s) was always the larger share.** After M2d cuts the training
forwards to ~0.28 s, the step is **~95% rollout** — the long pole has swung back to the rollout. So
M2d removes the training-forward overhang (a real, exactly-gated 9× on its component), and re-confirms
the same measure-first lesson one more time: the next **step-level** lever is **full B×G rollout
batching** — today only the `G` samples of each prompt decode in lockstep (M2b); the `B` prompts are
still sequential. M2d closes the "ragged batched per-sample forwards" lever M2c named; the
post-training stack stays complete, now with the step decomposition measured, not asserted.
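
The decomposition arithmetic, as a small sketch. Assumptions: the step is rollout plus training forwards, with every other cost counted in the rollout term, and the rollout time is taken as the ~8.5 s step minus the 2526 ms training forwards from the table. `rollout_share` is a hypothetical helper.

```rust
/// Hypothetical helper: fraction of a GRPO step spent in rollout,
/// assuming step = rollout + training forwards.
fn rollout_share(rollout_s: f64, train_s: f64) -> f64 {
    rollout_s / (rollout_s + train_s)
}
```

With ~5.97 s of rollout, the training forwards are ~30% of the step before M2d (2.526 s) and the rollout is ~95.5% after it (0.280 s).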