docs: post-training stack design — SFT → KV-cache → DPO → GRPO (docs/18)
Design doc for a from-scratch post-training infra on top of xtrain. Ladder: SFT (have it) → DPO → reward model (optional) → GRPO, each rung one new post-training systems concept + a hard correctness gate (grad-check, PyTorch parity, degenerate checks, a falsifiable 'it learns' signal). Decisions aligned with the user (D1-D4): - D1 scope: DPO → GRPO, reward model optional. - D2 reward: rule-based / verifiable first; learned RM deferred. - D3 rollout: build the KV-cache incremental-decode engine UP FRONT (not naive-first) as the foundational milestone before DPO/GRPO. - D4 task: a verifiable task (arithmetic/format) with deterministic exact-match reward, for a clean RL signal. Locked milestone order: M1 SFT task baseline → M2 KV-cache decode engine (token-identical gate) → M3 DPO → M4 GRPO → M5 optional reward model. Status: design only, no implementation yet. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
docs/18-post-training-rl-sft.md
@@ -0,0 +1,284 @@
# Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document

> Status: **DESIGN — decisions locked, pending go-ahead to implement.** Nothing
> implemented yet. This doc proposes the scope, the staged build, the new infra pieces,
> and the correctness gates for a standard post-training stack on top of the xtrain
> training framework. Decisions D1–D4 are resolved (see "Resolved decisions"):
> **DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode
> engine built up front · a verifiable task as the optimization/eval target.**

## Goal

Build a **standard, from-scratch post-training infrastructure** — the systems layer that
turns a pretrained base LM into an aligned chat model — and use it to run chat
alignment. The deliverable that matters here is the **infra and the lessons**, not the
end-to-end chat quality (see the project's learning-axis framing). Each stage should
teach exactly one new post-training systems concept and ship with a hard correctness
gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical
default paths, profile-first).

Concretely we want to be able to answer, with our own code:

- How does **offline preference optimization (DPO)** differ from SFT in the training
  loop — what is the reference model, why two forwards, what is the loss?
- How does a **reward model** turn preferences into a scalar signal?
- How does **online RL (GRPO)** actually run — the rollout engine, reward scoring,
  group-relative advantage, the clipped policy-gradient update, the KL leash?
- Where are the **memory and throughput** pressure points that make post-training infra
  different from pretraining infra (multiple models resident, generation in the loop)?

## Baseline: what already exists vs. what is missing

What the framework already gives us (verified in code, reused as-is):

| capability | where | reuse for post-training |
|---|---|---|
| batched forward → logits `[B*S, vocab]` | `model.rs::forward_batched` | logprob extraction for DPO/RM/GRPO |
| cross-entropy with **ignore-index −100** | `ops.rs::cross_entropy`, `nn.cu` | assistant-only / completion-only masking |
| assistant-only **SFT** (TSV, masked labels) | `data.rs::load_sft_tsv_cached` (commit `fbf4ac2`) | SFT chat baseline = DPO init + reference |
| bf16 mixed precision, fp32 master | `with_compute_dtype` | policy + frozen reference both bf16 compute |
| recompute / flash / grad-accum | `with_recompute` / `with_flash` / `--accum-steps` | bound activation memory with 2–3 models resident |
| DDP (thread + process-per-GPU) | `xtrain-distributed` | data-parallel post-training |
| AdamW + clip + LR sched + checkpoint | `xtrain-optim`, `checkpoint.rs`, `schedule.rs` | unchanged optimizer path |
| single-seq greedy/temperature sampling | `sample.rs::generate` | **slow** rollout fallback (no KV cache) |

What is **missing** and must be built (these are the actual lessons):

1. **Per-sequence completion logprob** — a way to read `Σ log πθ(y_t | x, y_<t)` over the
   completion tokens of a sequence. CE gives a *mean* scalar; DPO/GRPO need a *per-sequence
   masked sum*. New op or thin wrapper over the CE per-row machinery.
2. **Frozen reference model** held in memory alongside the trainable policy (no grad, no
   optimizer), or its logprobs precomputed and cached.
3. **Pairwise preference loss** (DPO) and **Bradley-Terry ranking loss** (RM).
4. **Reward head** — a `[dim,1]` scalar head reading the last non-pad position (RM only).
5. **Rollout / generation engine** — batched autoregressive sampling. Current `generate`
   is single-sequence and re-runs the full forward each step (no KV cache). Online RL needs
   batched rollouts; a real **KV-cache incremental-decode engine** is the centerpiece infra
   build.
6. **GRPO machinery** — group sampling, group-relative advantage, clipped PG loss, KL
   penalty, the actor-learner loop.
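
To make item 1 concrete, here is a minimal plain-Python reference for the per-sequence masked sum, of the kind a parity test could compare against. The nested-list layout (`[B][S][V]`) and the helper `log_softmax` are illustrative assumptions, not the xtrain op or its tensor layout.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]

def seq_logprob(logits, targets, mask):
    """Per-sequence masked sum of log pi(target).

    logits:  [B][S][V] floats
    targets: [B][S] token ids
    mask:    [B][S] 0/1, where 1 marks a completion token
    Returns a list of B floats (one summed log-prob per sequence).
    """
    out = []
    for seq_logits, seq_targets, seq_mask in zip(logits, targets, mask):
        total = 0.0
        for row, tgt, keep in zip(seq_logits, seq_targets, seq_mask):
            if keep:  # prompt positions are skipped entirely
                total += log_softmax(row)[tgt]
        out.append(total)
    return out

# Uniform logits over a 4-token vocab: every token has log-prob log(1/4).
logits = [[[0.0] * 4 for _ in range(3)]]
targets = [[1, 2, 3]]
mask = [[0, 1, 1]]  # first position belongs to the prompt
result = seq_logprob(logits, targets, mask)
```

With two unmasked positions the sum is `2 * log(1/4)`, whereas a mean-reduced cross-entropy would report `log 4` regardless of completion length; that difference is why the masked sum needs its own op.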
## The post-training landscape — where the infra lives

```
        data                        models in memory        new systems concept
SFT     (prompt, answer)            policy                  loss masking (have it)
DPO     (prompt, chosen, reject)    policy + ref(frozen)    dual forward, pairwise logσ loss
RM      (prompt, chosen, reject)    reward model            scalar head, ranking loss
PPO     prompts + reward source     policy+ref+RM+critic    rollout + GAE + clipped PG (4 models)
GRPO    prompts + reward source     policy+ref(+RM)         rollout + group baseline + clipped PG
```

The pedagogical ladder is **SFT → DPO → (RM) → GRPO**. DPO is the cheapest "real" alignment
method (no generation, no reward model, reuses the training loop almost verbatim) and is the
right first rung. GRPO is chosen over PPO as the online-RL rung because it **drops the value
critic** (group-relative advantage replaces the learned baseline) — that removes a whole
model and the GAE machinery while still teaching the complete online-RL loop. PPO is noted
as an optional later extension, not a primary target.

## Proposed scope & sequencing (recommended path)

> ✅ **DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward
> model) optional.** With D3 locked to "KV-cache engine up front", the engine becomes a
> foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective
> build order: **P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional)** (see
> "Milestones").

### Stage P0 — SFT chat baseline (light; mostly reuse)

Goal: a clean SFT checkpoint to serve as **both the DPO/GRPO init and the frozen
reference**. With D4 = verifiable task, P0 SFT teaches the **task format** (e.g. arithmetic
prompts → a parseable answer such as `\boxed{N}`) so the model emits checker-readable
completions; the same template is reused by rollout and eval. The current SFT (commit
`fbf4ac2`) already does single-turn assistant-only masking; P0 only adds what alignment
needs:

- a fixed **chat template** (the `User:/Assistant:` + `<|endoftext|>` format already used,
  promoted to a documented constant shared by SFT data prep, rollout, and eval),
- optional **multi-turn masking** (supervise every assistant turn, mask user turns),
- optional **sequence packing** (concatenate examples to fill `seq`, reset attention/RoPE
  per example — note `forward_batched` already isolates sequences, so packing = careful
  index bookkeeping, not new attention code).

Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak
loss across example boundaries. **Hypothesis:** a documented chat template + multi-turn mask
gives a reproducible SFT reference without changing the training numerics for single-turn data
(bit-identical to `fbf4ac2` on single-turn input).
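
The masking gate can be illustrated with a small label-building sketch. The function `build_labels`, the choice to shift labels at data-prep time, and the token ids are assumptions made for illustration; the actual loader in `data.rs` may arrange this differently. Only the ignore-index value −100 comes from the table above.

```python
IGNORE_INDEX = -100  # the ignore-index the cross-entropy op is documented to use

def build_labels(prompt_ids, completion_ids):
    """Return (input_ids, labels) for next-token training with the prompt masked.

    labels[t] is the target predicted from position t, i.e. token t+1.
    It is supervised only when token t+1 is a completion token.
    """
    input_ids = prompt_ids + completion_ids
    shifted = input_ids[1:] + [IGNORE_INDEX]  # last position has no next token
    labels = []
    for t, tgt in enumerate(shifted):
        next_is_completion = len(prompt_ids) <= t + 1 < len(input_ids)
        labels.append(tgt if next_is_completion else IGNORE_INDEX)
    return input_ids, labels

# Hypothetical ids: a 3-token prompt and a 3-token completion ending in EOS.
prompt = [10, 11, 12]
completion = [20, 21, 50256]
input_ids, labels = build_labels(prompt, completion)
supervised = sum(label != IGNORE_INDEX for label in labels)
```

A unit test in this style would assert that the number of supervised positions equals the completion length, which is the property the gate names.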
### Stage P1 — DPO (offline preference optimization) ⭐ first real method

New infra:

1. **Preference data — constructed from the verifiable checker (D4).** On a verifiable task
   there is no off-the-shelf preference set, so we build pairs: sample several completions
   per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone),
   score each with the rule-based checker, take a **correct** completion as `chosen` and an
   **incorrect** one as `rejected`. This is a one-time offline data-prep step; DPO training
   itself is then static. Tokenize each as `template(prompt) + completion + EOS`; build a
   completion mask (prompt = masked).
2. **`seq_logprob(logits, target_ids, mask) → [B]`**: per-sequence sum of
   `log softmax(logits)[target]` over masked positions. Implement by reusing the CE per-row
   path (CE per-row = `−log πθ(target)`), summing `−per_row` over the mask. Add a grad-checked
   op so the backward is exact.
3. **Frozen reference** `πref`: load the SFT checkpoint into a second model in **eval/no-grad**
   bf16. Its logprobs are **constants** in the loss. Optimization to teach: **precompute and
   cache reference logprobs** once over the dataset → the reference model need not stay
   resident during training (one model in memory, like SFT).
4. **DPO loss** (Rafailov et al.): with
   `Δ = β[(logπθ(yw|x) − logπref(yw|x)) − (logπθ(yl|x) − logπref(yl|x))]`,
   `L = −log σ(Δ)`. Only `πθ` terms carry gradient.
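
As a reference for item 4, the loss for a single pair can be sketched in plain Python from four per-sequence log-probs. The function names and the example values are illustrative; the gradient coefficient follows from differentiating `−log σ(Δ)` and is included only to show why `β → 0` kills the gradient.

```python
import math

def log_sigmoid(x):
    """Stable log(sigmoid(x)) that avoids overflow for large |x|."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def dpo_loss(pi_w, pi_l, ref_w, ref_l, beta):
    """DPO loss for one (chosen, rejected) pair of summed log-probs.

    Returns (loss, delta) with delta = beta * [(pi_w - ref_w) - (pi_l - ref_l)].
    """
    delta = beta * ((pi_w - ref_w) - (pi_l - ref_l))
    return -log_sigmoid(delta), delta

def dpo_grad_coeff(delta, beta):
    """dL/d(pi_w) = -beta * sigmoid(-delta); dL/d(pi_l) is its negation."""
    return -beta / (1.0 + math.exp(delta))

# Policy equal to reference at init: delta = 0 and loss = log 2.
loss_init, delta_init = dpo_loss(-12.0, -15.0, -12.0, -15.0, beta=0.1)

# Policy has moved toward the chosen completion: loss falls below log 2.
loss_better, delta_better = dpo_loss(-11.0, -16.0, -12.0, -15.0, beta=0.1)
```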

Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached
logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably.

Correctness gates:

- `seq_logprob` finite-difference grad-check (tiny model).
- DPO-loss + grad **PyTorch parity** (the project's standard gate).
- **Degenerate checks**: `πθ == πref` at init ⇒ `Δ = 0`, `L = log 2`, implicit reward 0;
  `β → 0` ⇒ gradient → 0.
- **Health metric**: chosen−rejected **reward margin** rises over training; accuracy
  (margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a
  sufficient signal).
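
The health metric can be computed from the same four log-probs as the loss. The sketch below assumes the usual DPO implicit reward `β(logπθ − logπref)`; the function name `dpo_health` and the batch values are hypothetical.

```python
def dpo_health(pi_w, pi_l, ref_w, ref_l, beta):
    """Mean implicit-reward margin and preference accuracy over a batch.

    Each argument is a list of per-sequence summed log-probs; the implicit
    reward of a completion is beta * (log pi_theta - log pi_ref).
    """
    margins = [
        beta * ((pw - rw) - (pl - rl))
        for pw, pl, rw, rl in zip(pi_w, pi_l, ref_w, ref_l)
    ]
    mean_margin = sum(margins) / len(margins)
    accuracy = sum(m > 0 for m in margins) / len(margins)
    return mean_margin, accuracy

# Two hypothetical pairs: the first is ranked correctly, the second is not.
mean_margin, accuracy = dpo_health(
    pi_w=[-10.0, -9.0],
    pi_l=[-14.0, -8.0],
    ref_w=[-11.0, -9.0],
    ref_l=[-13.0, -9.0],
    beta=0.5,
)
```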

Application: chat alignment via DPO on English preference pairs. This is the **offline
chat-alignment deliverable**.

### Stage P2 — Reward model (Bradley-Terry) — OPTIONAL

> ✅ **DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first.** GRPO
> brings up on the deterministic checker; a learned reward model is **deferred/optional** (only
> if we later want general-chat GRPO). So this whole stage is optional and not on the critical
> path.

New infra: a **scalar reward head** (`[dim,1]`) reading the hidden state at the last
non-pad position; **ranking loss** `−log σ(r(x,yw) − r(x,yl))`. Reuses the preference data
and the dual-sequence forward from P1.
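
A small sketch of the two pieces named above: locating the last non-pad position and evaluating the ranking loss on two scalar rewards. Right padding and the helper names are assumptions made for illustration.

```python
import math

def last_non_pad_index(token_ids, pad_id):
    """Index of the last real token, assuming right padding (an assumption here)."""
    for i in range(len(token_ids) - 1, -1, -1):
        if token_ids[i] != pad_id:
            return i
    raise ValueError("sequence is all padding")

def bt_ranking_loss(r_w, r_l):
    """-log sigmoid(r_w - r_l), written as softplus(-(r_w - r_l)) for stability."""
    d = r_w - r_l
    if d >= 0:
        return math.log1p(math.exp(-d))
    return -d + math.log1p(math.exp(d))

def pairwise_accuracy(pairs):
    """Fraction of (r_w, r_l) pairs where the chosen reward is strictly higher."""
    return sum(r_w > r_l for r_w, r_l in pairs) / len(pairs)

tie_loss = bt_ranking_loss(0.0, 0.0)      # untrained head: no preference
good_loss = bt_ranking_loss(2.0, 0.0)     # chosen scored higher
bad_loss = bt_ranking_loss(0.0, 2.0)      # chosen scored lower
acc = pairwise_accuracy([(1.0, 0.0), (0.2, 0.5), (3.0, -1.0), (0.0, 0.0)])
```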

Gates: ranking-loss grad-check; held-out **pairwise accuracy** (`r_w > r_l`); a frozen RM
loads/serves the scalar correctly.

### Stage P3 — GRPO (online RL, critic-free) ⭐ the deep infra lesson

This is the centerpiece. It introduces **generation inside the training loop**.

**(a) Rollout / generation engine — built up front (its own milestone).**

> ✅ **DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine
> up front**, as a foundational milestone *before* DPO/GRPO, rather than starting naive. It is
> then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted:
> front-loads the single hardest build and delays the first alignment result, in exchange for
> a real generation engine and a clean, isolated infra lesson.

The engine: per-layer **K/V cache**, **single-token incremental forward** (process the prompt
once to fill the cache, then decode one token at a time), **batched ragged decode** (B prompts
× G samples; sequences hit EOS at different lengths → finished-mask / left-padding /
compaction). The current attention assumes a full causal window over `seq`; incremental decode
needs a **decode-time attention path**: query length 1 against the cached K/V of the `t`
preceding tokens plus the new token's own K/V, with the new token at 0-indexed RoPE position
`t`. This reuses the composed SDPA shapes (one-row query), so it can land as a
distinct code path without disturbing the training attention (flash/GQA/composed unchanged).

Hard gate (the centerpiece correctness lesson): **KV-cache decode == full-recompute decode,
token-identical** greedy output — the same byte-/token-identical discipline the project uses
for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs.
per-token decode) is recorded here, before any rollout optimization (profile-first).
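
The shape of the token-identical gate can be shown on a toy model. The sketch below is a single-layer, single-head attention "LM" in plain Python with made-up deterministic weights, no RoPE, and no batching, so every name in it is hypothetical. It only demonstrates the idea: prefill the cache once, append one K/V row per decoded token, and require greedy output equal to the full-recompute path.

```python
import math

VOCAB, DIM = 8, 4

def _mat(rows, cols, seed):
    # Deterministic made-up weights in [-0.5, 0.5); no RNG library needed.
    return [[(((seed + 31 * r + 17 * c) * 2654435761) % 1000) / 1000.0 - 0.5
             for c in range(cols)] for r in range(rows)]

EMB = _mat(VOCAB, DIM, 1)
WQ, WK, WV = _mat(DIM, DIM, 2), _mat(DIM, DIM, 3), _mat(DIM, DIM, 4)

def _vecmat(x, w):
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def _attend(q, keys, values):
    # One query row against all cached rows (softmax-weighted sum of values).
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(DIM) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(DIM)]

def _next_token(ctx):
    # Tied unembedding and greedy argmax; the lowest id wins ties.
    logits = [sum(c * e for c, e in zip(ctx, EMB[t])) for t in range(VOCAB)]
    return max(range(VOCAB), key=lambda t: logits[t])

def generate_full(prompt, steps):
    """Baseline: recompute K/V for the whole sequence at every step."""
    seq = list(prompt)
    for _ in range(steps):
        xs = [EMB[t] for t in seq]
        keys = [_vecmat(x, WK) for x in xs]
        values = [_vecmat(x, WV) for x in xs]
        seq.append(_next_token(_attend(_vecmat(xs[-1], WQ), keys, values)))
    return seq

def generate_cached(prompt, steps):
    """Incremental decode: prefill once, then append one K/V row per token."""
    seq, keys, values = list(prompt), [], []
    for t in prompt[:-1]:  # prefill everything except the last prompt token
        keys.append(_vecmat(EMB[t], WK))
        values.append(_vecmat(EMB[t], WV))
    for _ in range(steps):
        x = EMB[seq[-1]]  # only the newest token is projected
        keys.append(_vecmat(x, WK))
        values.append(_vecmat(x, WV))
        seq.append(_next_token(_attend(_vecmat(x, WQ), keys, values)))
    return seq

full = generate_full([1, 2, 3], steps=6)
cached = generate_cached([1, 2, 3], steps=6)
```

In this toy both paths perform the same floating-point operations in the same order, so equality is exact. A real engine with different kernels for prefill and decode may not have that property, which is why the gate is stated on greedy tokens.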

**(b) Reward scoring.** Rule-based verifiable reward first (e.g., exact-match on a synthetic
arithmetic/format task) or RM from P2. Returns a scalar per completion.
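
A rule-based checker of this kind can be very small. The sketch below assumes integer answers in the `\boxed{N}` format mentioned in P0 and, as a design assumption, scores the last boxed value in the completion; the function name is hypothetical.

```python
import re

# Matches a literal \boxed{...} containing an optionally negative integer.
_BOXED = re.compile(r"\\boxed\{(-?\d+)\}")

def exact_match_reward(completion, expected):
    """Return 1.0 if the last boxed integer equals `expected`, else 0.0."""
    matches = _BOXED.findall(completion)
    if not matches:
        return 0.0  # unparseable output earns nothing
    return 1.0 if int(matches[-1]) == expected else 0.0

r_correct = exact_match_reward(r"17 + 25 = \boxed{42}", 42)
r_wrong = exact_match_reward(r"17 + 25 = \boxed{41}", 42)
r_unparsed = exact_match_reward("the answer is 42", 42)
```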

**(c) Group-relative advantage.** Sample `G` completions per prompt; advantage
`A_i = (r_i − mean(r_group)) / (std(r_group) + ε)`. No critic, no GAE.
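
The formula above, written out for one prompt's group. The use of the population standard deviation (dividing by `G`) and the default `eps` are assumptions; the text does not fix either.

```python
import math

def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one prompt's group of G completions."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)  # population variance
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]

mixed = group_advantages([1.0, 0.0, 0.0, 0.0])    # one correct out of four
uniform = group_advantages([1.0, 1.0, 1.0, 1.0])  # no contrast in the group
single = group_advantages([1.0])                  # G = 1
```

When every completion in a group earns the same reward, or when `G = 1`, all advantages are zero and the group contributes no policy-gradient signal.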

**(d) Clipped policy-gradient loss with KL leash.** Per completion token,
`ρ_t = exp(logπθ_t − logπθ_old_t)` (old = policy at rollout time), token loss
`−min(ρ_t A, clip(ρ_t, 1±ε) A) + βKL(πθ‖πref)`, masked to completion tokens. Here `ε` is the
clip range, distinct from the small `ε` in the advantage denominator in (c). KL via the k3
estimator.
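
A per-token sketch of the loss in (d). The k3 estimator is taken in its usual form `r − 1 − log r` with `r = πref/πθ` evaluated on a token sampled from the policy. Averaging over completion tokens, the parameter names, and the example values are assumptions made for illustration.

```python
import math

def grpo_token_loss(logp, logp_old, logp_ref, adv, clip_eps, beta):
    """Clipped surrogate plus KL penalty for one completion token."""
    ratio = math.exp(logp - logp_old)
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    surrogate = min(ratio * adv, clipped * adv)
    # k3 estimate of KL(pi_theta || pi_ref): r - 1 - log r, r = pi_ref / pi_theta.
    log_r = logp_ref - logp
    k3 = math.exp(log_r) - 1.0 - log_r  # non-negative for every log_r
    return -surrogate + beta * k3

def grpo_sequence_loss(logps, logps_old, logps_ref, adv, mask, clip_eps, beta):
    """Mean token loss over completion positions (mask value 1)."""
    terms = [
        grpo_token_loss(lp, lo, lr, adv, clip_eps, beta)
        for lp, lo, lr, keep in zip(logps, logps_old, logps_ref, mask)
        if keep
    ]
    return sum(terms) / len(terms)

# On-policy and at the reference: ratio = 1, KL = 0, loss = -advantage.
on_policy = grpo_token_loss(-1.0, -1.0, -1.0, adv=0.5, clip_eps=0.2, beta=0.1)

# Ratio of 2 with positive advantage is clipped to 1 + clip_eps.
clipped_case = grpo_token_loss(math.log(2.0), 0.0, 0.0, adv=1.0, clip_eps=0.2, beta=0.0)

# Zero advantage leaves only the KL term.
kl_only = grpo_token_loss(-1.0, -1.0, -2.0, adv=0.0, clip_eps=0.2, beta=0.1)
```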

**(e) Actor-learner loop.** sample prompt batch → rollout G each → score → advantage →
capture `πθ_old` logprobs → K inner epochs of clipped PG updates → repeat. Reference `πref`
fixed throughout.

Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations.
Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded
(tokens/s of generation vs. update) **before** any rollout optimization, per the project's
profile-first rule.

Correctness gates:

- PG-loss finite-diff grad-check.
- **Degenerate checks**: `G = 1` ⇒ advantage 0 ⇒ no PG signal, only KL; `ε → ∞` ⇒ vanilla PG;
  `β = 0` ⇒ no KL term.
- (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a
  prerequisite of GRPO.)
- **Synthetic RL overfit**: on a tiny verifiable task with a known optimum, mean reward must
  rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal
  that the loop is correct, independent of fuzzy chat quality).
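
The synthetic-overfit idea can be tried on something far smaller than a language model. The sketch below is a hypothetical 4-armed bandit with a softmax policy, trained with group-relative advantages and a plain on-policy policy-gradient step. It omits clipping and the KL term (with one update per rollout the ratio is 1), so it illustrates the gate's shape rather than the GRPO loop itself; all names and hyperparameters are made up.

```python
import math
import random

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def train_bandit(num_arms=4, target=2, group_size=8, steps=200, lr=1.0, seed=0):
    """Overfit a softmax policy to the single rewarded arm; return final probs."""
    rng = random.Random(seed)
    logits = [0.0] * num_arms
    for _ in range(steps):
        probs = softmax(logits)
        # "Rollout": sample a group of actions from the current policy.
        actions = rng.choices(range(num_arms), weights=probs, k=group_size)
        # Deterministic verifiable reward: 1 for the target arm, else 0.
        rewards = [1.0 if a == target else 0.0 for a in actions]
        mean = sum(rewards) / group_size
        std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / group_size)
        advs = [(r - mean) / (std + 1e-6) for r in rewards]
        # Policy-gradient ascent: d log p(a) / d logit_j = 1[j == a] - p_j.
        grad = [0.0] * num_arms
        for a, adv in zip(actions, advs):
            for j in range(num_arms):
                indicator = 1.0 if j == a else 0.0
                grad[j] += adv * (indicator - probs[j]) / group_size
        logits = [z + lr * g for z, g in zip(logits, grad)]
    return softmax(logits)

final_probs = train_bandit()
```

The known optimum here is probability 1 on the target arm, so the falsifiable check is that `final_probs[2]` ends close to 1. A sign error in the advantage or the gradient drives it toward 0 instead.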
## Evaluation

- **Offline (DPO/RM)**: reward margin, preference accuracy, KL drift from reference, plus the
  fixed chat-prompt generation suite (`scripts/chat_alpha_fixed_prompts.txt`) judged before/
  after — reusing and extending the doc-13 recommendation for a generation-based eval harness
  (exact-match math, code syntax, stop-token, refusal appropriateness, corruption).
- **Online (GRPO)**: mean reward curve, KL-to-reference, response length, the verifiable-task
  pass rate, and the same fixed-prompt suite.
- **Selection by generation eval, not loss** — the recurring doc-13/v11 lesson: lower
  post-training loss did not mean better generations.

## Memory & throughput budget (8× RTX 5090, 1.05B model, indicative)

- Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy
  training state (weights, master copy, optimizer moments, grads) alone ~17 GB before
  activations. Recompute + grad-accum keep activations
  small; this is why post-training reuses the Phase-1/2 memory levers unchanged.
- DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits.
- GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new
  variable. **Generation, not the update, is expected to be the throughput bottleneck** — to be
  measured, not assumed.
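
The figures above follow from bytes-per-parameter arithmetic, reproduced here as a check. The sketch assumes decimal gigabytes (1 GB = 1e9 bytes), bf16 gradients, and fp32 AdamW moments, which is what the listed numbers imply; activations and cache memory are deliberately left out.

```python
N_PARAMS = 1_050_000_000  # the 1.05B model named in the heading

BYTES_PER_PARAM = {
    "bf16 params": 2,
    "fp32 master": 4,
    "AdamW m + v (fp32)": 8,
    "grads (bf16)": 2,
}

def to_gb(n_bytes):
    """Decimal gigabytes, matching the rounded figures in the list above."""
    return n_bytes / 1e9

policy_gb = {name: to_gb(N_PARAMS * b) for name, b in BYTES_PER_PARAM.items()}
policy_total_gb = sum(policy_gb.values())
reference_gb = to_gb(N_PARAMS * 2)  # frozen bf16 copy, no optimizer state

dpo_total_gb = policy_total_gb + reference_gb       # reference resident
grpo_total_gb = policy_total_gb + 2 * reference_gb  # reference plus a learned RM
```

Under these assumptions the policy state is 16.8 GB, DPO with a resident reference is 18.9 GB, and GRPO with both a reference and a learned reward model is 21.0 GB, all before activations and the K/V cache.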
## Correctness-gate philosophy (unchanged from Phase 1/2)

Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity
on loss + grads where applicable, (3) explicit degenerate-case checks (β→0, G=1,
ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic
RL overfit), and (5) **no change to the default training path** when post-training flags are
off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference
gates as flash/GQA.

## Risks & tradeoffs

- **Rollout engine is the long pole.** A correct KV-cache incremental-decode path is a real
  build (decode-time attention, ragged batch). Mitigation, per D3: build it up front as an
  isolated, separately-gated milestone (M2), gated token-identical against full-recompute
  decode before DPO or GRPO depend on it.
- **RL is finicky.** KL leash, advantage normalization, clip range, reward hacking. Mitigation:
  synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward.
- **Reward-model noise** can mislead GRPO. Mitigation: rule-based reward first.
- **Tokenizer (KI-4)** — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here.
- **Two/three resident models** raise memory; bounded by recompute/accum and (for DPO) reference
  logprob caching.

## Resolved decisions (aligned 2026-06-29)

- **D1 — Scope & sequencing → DPO → GRPO, reward model optional.**
- **D2 — Online-RL reward source → rule-based / verifiable reward first** (RM deferred/optional).
- **D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front** (not
  naive-first), as a foundational milestone before DPO/GRPO.
- **D4 — Alignment task / eval target → a verifiable task** (arithmetic/format/GSM8K-style) with
  a deterministic exact-match reward, for a clean, falsifiable RL signal.

## Milestones (locked order)

1. **M1 — P0 SFT task baseline.** Chat template + assistant-only masking on the verifiable
   task; produces the reference + init checkpoint. Gate: masking unit test; single-turn
   bit-identical to `fbf4ac2`.
2. **M2 — KV-cache decode engine** (D3, up front). Per-layer K/V cache + incremental
   decode-time attention + batched ragged decode. Gate: **token-identical to full-recompute
   greedy**; record decode throughput baseline.
3. **M3 — P1 DPO.** Verifiable-checker pair construction (via M2) → `seq_logprob` op
   (grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO
   training loop → run + reward-margin / preference-accuracy curve.
4. **M4 — P3 GRPO.** Group rollout (M2) + rule-based reward + group-relative advantage +
   clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; **synthetic
   verifiable-task RL-overfit** (mean reward → known optimum) → verifiable-task GRPO run.
5. **M5 (optional) — P2 reward model.** Scalar head + ranking loss + pairwise-accuracy gate;
   enables GRPO-with-RM for general chat.

> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.