diff --git a/docs/18-post-training-rl-sft.md b/docs/18-post-training-rl-sft.md new file mode 100644 index 0000000..4fddd11 --- /dev/null +++ b/docs/18-post-training-rl-sft.md @@ -0,0 +1,284 @@ +# Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document + +> Status: **DESIGN — decisions locked, pending go-ahead to implement.** Nothing +> implemented yet. This doc proposes the scope, the staged build, the new infra pieces, +> and the correctness gates for a standard post-training stack on top of the xtrain +> training framework. Decisions D1–D4 are resolved (see "Resolved decisions"): +> **DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode +> engine built up front · a verifiable task as the optimization/eval target.** + +## Goal + +Build a **standard, from-scratch post-training infrastructure** — the systems layer that +turns a pretrained base LM into an aligned chat model — and use it to run chat +alignment. The deliverable that matters here is the **infra and the lessons**, not the +end-to-end chat quality (see the project's learning-axis framing). Each stage should +teach exactly one new post-training systems concept and ship with a hard correctness +gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical +default paths, profile-first). + +Concretely we want to be able to answer, with our own code: + +- How does **offline preference optimization (DPO)** differ from SFT in the training + loop — what is the reference model, why two forwards, what is the loss? +- How does a **reward model** turn preferences into a scalar signal? +- How does **online RL (GRPO)** actually run — the rollout engine, reward scoring, + group-relative advantage, the clipped policy-gradient update, the KL leash? +- Where are the **memory and throughput** pressure points that make post-training infra + different from pretraining infra (multiple models resident, generation in the loop)? + +## Baseline: what already exists vs. what is missing + +What the framework already gives us (verified in code, reused as-is): + +| capability | where | reuse for post-training | +|---|---|---| +| batched forward → logits `[B*S, vocab]` | `model.rs::forward_batched` | logprob extraction for DPO/RM/GRPO | +| cross-entropy with **ignore-index −100** | `ops.rs::cross_entropy`, `nn.cu` | assistant-only / completion-only masking | +| assistant-only **SFT** (TSV, masked labels) | `data.rs::load_sft_tsv_cached` (commit `fbf4ac2`) | SFT chat baseline = DPO init + reference | +| bf16 mixed precision, fp32 master | `with_compute_dtype` | policy + frozen reference both bf16 compute | +| recompute / flash / grad-accum | `with_recompute` / `with_flash` / `--accum-steps` | bound activation memory with 2–3 models resident | +| DDP (thread + process-per-GPU) | `xtrain-distributed` | data-parallel post-training | +| AdamW + clip + LR sched + checkpoint | `xtrain-optim`, `checkpoint.rs`, `schedule.rs` | unchanged optimizer path | +| single-seq greedy/temperature sampling | `sample.rs::generate` | **slow** rollout fallback (no KV cache) | + +What is **missing** and must be built (these are the actual lessons): + +1. **Per-sequence completion logprob** — a way to read `Σ log πθ(y_t | x, y_ ✅ **DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward +> model) optional.** With D3 locked to "KV-cache engine up front", the engine becomes a +> foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective +> build order: **P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional)** (see +> "Milestones"). + +### Stage P0 — SFT chat baseline (light; mostly reuse) + +Goal: a clean SFT checkpoint to serve as **both the DPO/GRPO init and the frozen +reference**. With D4 = verifiable task, P0 SFT teaches the **task format** (e.g. arithmetic +prompts → a parseable answer such as `\boxed{N}`) so the model emits checker-readable +completions; the same template is reused by rollout and eval. The current SFT (commit +`fbf4ac2`) already does single-turn assistant-only masking; P0 only adds what alignment +needs: + +- a fixed **chat template** (the `User:/Assistant:` + `<|endoftext|>` format already used, + promoted to a documented constant shared by SFT data prep, rollout, and eval), +- optional **multi-turn masking** (supervise every assistant turn, mask user turns), +- optional **sequence packing** (concatenate examples to fill `seq`, reset attention/RoPE + per example — note `forward_batched` already isolates sequences, so packing = careful + index bookkeeping, not new attention code). + +Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak +loss across example boundaries. **Hypothesis:** a documented chat template + multi-turn mask +gives a reproducible SFT reference without changing the training numerics for single-turn data +(bit-identical to `fbf4ac2` on single-turn input). + +### Stage P1 — DPO (offline preference optimization) ⭐ first real method + +New infra: + +1. **Preference data — constructed from the verifiable checker (D4).** On a verifiable task + there is no off-the-shelf preference set, so we build pairs: sample several completions + per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone), + score each with the rule-based checker, take a **correct** completion as `chosen` and an + **incorrect** one as `rejected`. This is a one-time offline data-prep step; DPO training + itself is then static. Tokenize each as `template(prompt) + completion + EOS`; build a + completion mask (prompt = masked). +2. **`seq_logprob(logits, target_ids, mask) → [B]`**: per-sequence sum of + `log softmax(logits)[target]` over masked positions. Implement by reusing the CE per-row + path (CE per-row = `−log πθ(target)`), summing `−per_row` over the mask. Add a grad-checked + op so the backward is exact. +3. **Frozen reference** `πref`: load the SFT checkpoint into a second model in **eval/no-grad** + bf16. Its logprobs are **constants** in the loss. Optimization to teach: **precompute and + cache reference logprobs** once over the dataset → the reference model need not stay + resident during training (one model in memory, like SFT). +4. **DPO loss** (Rafailov et al.): with + `Δ = β[(logπθ(yw|x) − logπref(yw|x)) − (logπθ(yl|x) − logπref(yl|x))]`, + `L = −log σ(Δ)`. Only `πθ` terms carry gradient. + +Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached +logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably. + +Correctness gates: +- `seq_logprob` finite-difference grad-check (tiny model). +- DPO-loss + grad **PyTorch parity** (the project's standard gate). +- **Degenerate checks**: `πθ == πref` at init ⇒ `Δ = 0`, `L = log 2`, implicit reward 0; + `β → 0` ⇒ gradient → 0. +- **Health metric**: chosen−rejected **reward margin** rises over training; accuracy + (margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a + sufficient signal). + +Application: chat alignment via DPO on English preference pairs. This is the **offline +chat-alignment deliverable**. + +### Stage P2 — Reward model (Bradley-Terry) — OPTIONAL + +> ✅ **DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first.** GRPO +> brings up on the deterministic checker; a learned reward model is **deferred/optional** (only +> if we later want general-chat GRPO). So this whole stage is optional and not on the critical +> path. + +New infra: a **scalar reward head** (`[dim,1]`) reading the hidden state at the last +non-pad position; **ranking loss** `−log σ(r(x,yw) − r(x,yl))`. Reuses the preference data +and the dual-sequence forward from P1. + +Gates: ranking-loss grad-check; held-out **pairwise accuracy** (`r_w > r_l`); a frozen RM +loads/serves the scalar correctly. + +### Stage P3 — GRPO (online RL, critic-free) ⭐ the deep infra lesson + +This is the centerpiece. It introduces **generation inside the training loop**. + +**(a) Rollout / generation engine — built up front (its own milestone).** + +> ✅ **DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine +> up front**, as a foundational milestone *before* DPO/GRPO, rather than starting naive. It is +> then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted: +> front-loads the single hardest build and delays the first alignment result, in exchange for +> a real generation engine and a clean, isolated infra lesson. + +The engine: per-layer **K/V cache**, **single-token incremental forward** (process the prompt +once to fill the cache, then decode one token at a time), **batched ragged decode** (B prompts +× G samples; sequences hit EOS at different lengths → finished-mask / left-padding / +compaction). The current attention assumes a full causal window over `seq`; incremental decode +needs a **decode-time attention path** — query length 1 against cached K/V of length `t`, with +RoPE position = `t`. This reuses the composed SDPA shapes (one-row query), so it can land as a +distinct code path without disturbing the training attention (flash/GQA/composed unchanged). + +Hard gate (the centerpiece correctness lesson): **KV-cache decode == full-recompute decode, +token-identical** greedy output — the same byte-/token-identical discipline the project uses +for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs. +per-token decode) is recorded here, before any rollout optimization (profile-first). + +**(b) Reward scoring.** Rule-based verifiable reward first (e.g., exact-match on a synthetic +arithmetic/format task) or RM from P2. Returns a scalar per completion. + +**(c) Group-relative advantage.** Sample `G` completions per prompt; advantage +`A_i = (r_i − mean(r_group)) / (std(r_group) + ε)`. No critic, no GAE. + +**(d) Clipped policy-gradient loss with KL leash.** Per completion token, +`ρ_t = exp(logπθ_t − logπθ_old_t)` (old = policy at rollout time), token loss +`−min(ρ_t A, clip(ρ_t, 1±ε) A) + βKL(πθ‖πref)`, masked to completion tokens. KL via the k3 +estimator. + +**(e) Actor-learner loop.** sample prompt batch → rollout G each → score → advantage → +capture `πθ_old` logprobs → K inner epochs of clipped PG updates → repeat. Reference `πref` +fixed throughout. + +Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations. +Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded +(tokens/s of generation vs. update) **before** any rollout optimization, per the project's +profile-first rule. + +Correctness gates: +- PG-loss finite-diff grad-check. +- **Degenerate checks**: `G = 1` ⇒ advantage 0 ⇒ no PG signal, only KL; `ε → ∞` ⇒ vanilla PG; + `β = 0` ⇒ no KL term. +- (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a + prerequisite of GRPO.) +- **Synthetic RL overfit**: on a tiny verifiable task with a known optimum, mean reward must + rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal + that the loop is correct, independent of fuzzy chat quality). + +## Evaluation + +- **Offline (DPO/RM)**: reward margin, preference accuracy, KL drift from reference, plus the + fixed chat-prompt generation suite (`scripts/chat_alpha_fixed_prompts.txt`) judged before/ + after — reusing and extending the doc-13 recommendation for a generation-based eval harness + (exact-match math, code syntax, stop-token, refusal appropriateness, corruption). +- **Online (GRPO)**: mean reward curve, KL-to-reference, response length, the verifiable-task + pass rate, and the same fixed-prompt suite. +- **Selection by generation eval, not loss** — the recurring doc-13/v11 lesson: lower + post-training loss did not mean better generations. + +## Memory & throughput budget (8× RTX 5090, 1.05B model, indicative) + +- Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy + optimizer state alone ~17 GB before activations. Recompute + grad-accum keep activations + small; this is why post-training reuses the Phase-1/2 memory levers unchanged. +- DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits. +- GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new + variable. **Generation, not the update, is expected to be the throughput bottleneck** — to be + measured, not assumed. + +## Correctness-gate philosophy (unchanged from Phase 1/2) + +Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity +on loss + grads where applicable, (3) explicit degenerate-case bit/again checks (β→0, G=1, +ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic +RL overfit), and (5) **no change to the default training path** when post-training flags are +off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference +gates as flash/GQA. + +## Risks & tradeoffs + +- **Rollout engine is the long pole.** A correct KV-cache incremental-decode path is a real + build (decode-time attention, ragged batch). Mitigation: naive rollout first; KV-cache as an + isolated, separately-gated sub-phase. +- **RL is finicky.** KL leash, advantage normalization, clip range, reward hacking. Mitigation: + synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward. +- **Reward-model noise** can mislead GRPO. Mitigation: rule-based reward first. +- **Tokenizer (KI-4)** — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here. +- **Two/three resident models** raise memory; bounded by recompute/accum and (for DPO) reference + logprob caching. + +## Resolved decisions (aligned 2026-06-29) + +- **D1 — Scope & sequencing → DPO → GRPO, reward model optional.** +- **D2 — Online-RL reward source → rule-based / verifiable reward first** (RM deferred/optional). +- **D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front** (not + naive-first), as a foundational milestone before DPO/GRPO. +- **D4 — Alignment task / eval target → a verifiable task** (arithmetic/format/GSM8K-style) with + a deterministic exact-match reward, for a clean, falsifiable RL signal. + +## Milestones (locked order) + +1. **M1 — P0 SFT task baseline.** Chat template + assistant-only masking on the verifiable + task; produces the reference + init checkpoint. Gate: masking unit test; single-turn + bit-identical to `fbf4ac2`. +2. **M2 — KV-cache decode engine** (D3, up front). Per-layer K/V cache + incremental + decode-time attention + batched ragged decode. Gate: **token-identical to full-recompute + greedy**; record decode throughput baseline. +3. **M3 — P1 DPO.** Verifiable-checker pair construction (via M2) → `seq_logprob` op + (grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO + training loop → run + reward-margin / preference-accuracy curve. +4. **M4 — P3 GRPO.** Group rollout (M2) + rule-based reward + group-relative advantage + + clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; **synthetic + verifiable-task RL-overfit** (mean reward → known optimum) → verifiable-task GRPO run. +5. **M5 (optional) — P2 reward model.** Scalar head + ranking loss + pairwise-accuracy gate; + enables GRPO-with-RM for general chat. + +> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a +> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.