# Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document

> Status: **DESIGN — decisions locked, pending go-ahead to implement.** Nothing
> implemented yet. This doc proposes the scope, the staged build, the new infra pieces,
> and the correctness gates for a standard post-training stack on top of the xtrain
> training framework. Decisions D1–D4 are resolved (see "Resolved decisions"):
> **DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode
> engine built up front · a verifiable task as the optimization/eval target.**

## Goal

Build a **standard, from-scratch post-training infrastructure** — the systems layer that
turns a pretrained base LM into an aligned chat model — and use it to run chat
alignment. The deliverable that matters here is the **infra and the lessons**, not the
end-to-end chat quality (see the project's learning-axis framing). Each stage should
teach exactly one new post-training systems concept and ship with a hard correctness
gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical
default paths, profile-first).

Concretely we want to be able to answer, with our own code:

- How does **offline preference optimization (DPO)** differ from SFT in the training
loop — what is the reference model, why two forwards, what is the loss?
- How does a **reward model** turn preferences into a scalar signal?
- How does **online RL (GRPO)** actually run — the rollout engine, reward scoring,
group-relative advantage, the clipped policy-gradient update, the KL leash?
- Where are the **memory and throughput** pressure points that make post-training infra
different from pretraining infra (multiple models resident, generation in the loop)?

## Baseline: what already exists vs. what is missing

What the framework already gives us (verified in code, reused as-is):

| capability | where | reuse for post-training |
|---|---|---|
| batched forward → logits `[B*S, vocab]` | `model.rs::forward_batched` | logprob extraction for DPO/RM/GRPO |
| cross-entropy with **ignore-index −100** | `ops.rs::cross_entropy`, `nn.cu` | assistant-only / completion-only masking |
| assistant-only **SFT** (TSV, masked labels) | `data.rs::load_sft_tsv_cached` (commit `fbf4ac2`) | SFT chat baseline = DPO init + reference |
| bf16 mixed precision, fp32 master | `with_compute_dtype` | policy + frozen reference both bf16 compute |
| recompute / flash / grad-accum | `with_recompute` / `with_flash` / `--accum-steps` | bound activation memory with 2–3 models resident |
| DDP (thread + process-per-GPU) | `xtrain-distributed` | data-parallel post-training |
| AdamW + clip + LR sched + checkpoint | `xtrain-optim`, `checkpoint.rs`, `schedule.rs` | unchanged optimizer path |
| single-seq greedy/temperature sampling | `sample.rs::generate` | **slow** rollout fallback (no KV cache) |

What is **missing** and must be built (these are the actual lessons):

1. **Per-sequence completion logprob** — a way to read `Σ log πθ(y_t | x, y_<t)` over the
completion tokens of a sequence. CE gives a *mean* scalar; DPO/GRPO need a *per-sequence
masked sum*. New op or thin wrapper over the CE per-row machinery.
2. **Frozen reference model** held in memory alongside the trainable policy (no grad, no
optimizer), or its logprobs precomputed and cached.
3. **Pairwise preference loss** (DPO) and **Bradley-Terry ranking loss** (RM).
4. **Reward head** — a `[dim,1]` scalar head reading the last non-pad position (RM only).
5. **Rollout / generation engine** — batched autoregressive sampling. Current `generate`
is single-sequence and re-runs the full forward each step (no KV cache). Online RL needs
batched rollouts; a real **KV-cache incremental-decode engine** is the centerpiece infra
build.
6. **GRPO machinery** — group sampling, group-relative advantage, clipped PG loss, KL
penalty, the actor-learner loop.

## The post-training landscape — where the infra lives

```
       data                          models in memory       new systems concept
SFT    (prompt, answer)              policy                 loss masking (have it)
DPO    (prompt, chosen, reject)      policy + ref(frozen)   dual forward, pairwise logσ loss
RM     (prompt, chosen, reject)      reward model           scalar head, ranking loss
PPO    prompts + reward source       policy+ref+RM+critic   rollout + GAE + clipped PG (4 models)
GRPO   prompts + reward source       policy+ref(+RM)        rollout + group baseline + clipped PG
```

The pedagogical ladder is **SFT → DPO → (RM) → GRPO**. DPO is the cheapest "real" alignment
method (no generation, no reward model, reuses the training loop almost verbatim) and is the
right first rung. GRPO is chosen over PPO as the online-RL rung because it **drops the value
critic** (group-relative advantage replaces the learned baseline) — that removes a whole
model and the GAE machinery while still teaching the complete online-RL loop. PPO is noted
as an optional later extension, not a primary target.

## Proposed scope & sequencing (recommended path)

> ✅ **DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward
> model) optional.** With D3 locked to "KV-cache engine up front", the engine becomes a
> foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective
> build order: **P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional)** (see
> "Milestones").

### Stage P0 — SFT chat baseline (light; mostly reuse)

Goal: a clean SFT checkpoint to serve as **both the DPO/GRPO init and the frozen
reference**. With D4 = verifiable task, P0 SFT teaches the **task format** (e.g. arithmetic
prompts → a parseable answer such as `\boxed{N}`) so the model emits checker-readable
completions; the same template is reused by rollout and eval. The current SFT (commit
`fbf4ac2`) already does single-turn assistant-only masking; P0 only adds what alignment
needs:

- a fixed **chat template** (the `User:/Assistant:` + `<|endoftext|>` format already used,
promoted to a documented constant shared by SFT data prep, rollout, and eval),
- optional **multi-turn masking** (supervise every assistant turn, mask user turns),
- optional **sequence packing** (concatenate examples to fill `seq`, reset attention/RoPE
per example — note `forward_batched` already isolates sequences, so packing = careful
index bookkeeping, not new attention code).

Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak
loss across example boundaries. **Hypothesis:** a documented chat template + multi-turn mask
gives a reproducible SFT reference without changing the training numerics for single-turn data
(bit-identical to `fbf4ac2` on single-turn input).

### Stage P1 — DPO (offline preference optimization) ⭐ first real method

New infra:

1. **Preference data — constructed from the verifiable checker (D4).** On a verifiable task
there is no off-the-shelf preference set, so we build pairs: sample several completions
per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone),
score each with the rule-based checker, take a **correct** completion as `chosen` and an
**incorrect** one as `rejected`. This is a one-time offline data-prep step; DPO training
itself is then static. Tokenize each as `template(prompt) + completion + EOS`; build a
completion mask (prompt = masked).
2. **`seq_logprob(logits, target_ids, mask) → [B]`**: per-sequence sum of
`log softmax(logits)[target]` over masked positions. Implement by reusing the CE per-row
path (CE per-row = `−log πθ(target)`), summing `−per_row` over the mask. Add a grad-checked
op so the backward is exact.
3. **Frozen reference** `πref`: load the SFT checkpoint into a second model in **eval/no-grad**
bf16. Its logprobs are **constants** in the loss. Optimization to teach: **precompute and
cache reference logprobs** once over the dataset → the reference model need not stay
resident during training (one model in memory, like SFT).
4. **DPO loss** (Rafailov et al.): with
`Δ = β[(logπθ(yw|x) − logπref(yw|x)) − (logπθ(yl|x) − logπref(yl|x))]`,
`L = −log σ(Δ)`. Only `πθ` terms carry gradient.

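Item 2 can be sketched on the CPU to make the masked sum concrete. This is a minimal sketch, not the crate's op: it assumes row-major `[batch*seq, vocab]` logits, folds the completion mask into the targets as the CE ignore-index `-100` (instead of a separate `mask` argument), and accumulates in `f64`; `log_softmax_at` is a hypothetical helper.

```rust
/// CE ignore-index: target rows with this value (prompt or pad) contribute nothing.
const IGNORE: i64 = -100;

/// Numerically stable `log softmax(row)[target]`.
fn log_softmax_at(row: &[f32], target: usize) -> f64 {
    let max = row.iter().fold(f32::NEG_INFINITY, |m, &z| m.max(z)) as f64;
    let sum_exp: f64 = row.iter().map(|&z| (z as f64 - max).exp()).sum();
    row[target] as f64 - max - sum_exp.ln()
}

/// Per-sequence masked sum of token log-probs.
/// `logits`: `[batch*seq, vocab]` row-major; `targets`: `[batch*seq]`, IGNORE on masked rows.
fn seq_logprob(logits: &[f32], targets: &[i64], batch: usize, seq: usize, vocab: usize) -> Vec<f64> {
    (0..batch)
        .map(|b| {
            (b * seq..(b + 1) * seq)
                .filter(|&r| targets[r] != IGNORE)
                .map(|r| log_softmax_at(&logits[r * vocab..(r + 1) * vocab], targets[r] as usize))
                .sum::<f64>()
        })
        .collect()
}
```

With uniform logits every supervised token scores `−ln V`, so a sequence's value is `−(completion tokens)·ln V`; rows marked `-100` add nothing.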
Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached
logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably.

Correctness gates:
- `seq_logprob` finite-difference grad-check (tiny model).
- DPO-loss + grad **PyTorch parity** (the project's standard gate).
- **Degenerate checks**: `πθ == πref` at init ⇒ `Δ = 0`, `L = log 2`, implicit reward 0;
`β → 0` ⇒ gradient → 0.
- **Health metric**: chosen−rejected **reward margin** rises over training; accuracy
(margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a
sufficient signal).

Application: chat alignment via DPO on English preference pairs. This is the **offline
chat-alignment deliverable**.

### Stage P2 — Reward model (Bradley-Terry) — OPTIONAL

> ✅ **DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first.** GRPO
> brings up on the deterministic checker; a learned reward model is **deferred/optional** (only
> if we later want general-chat GRPO). So this whole stage is optional and not on the critical
> path.

New infra: a **scalar reward head** (`[dim,1]`) reading the hidden state at the last
non-pad position; **ranking loss** `−log σ(r(x,yw) − r(x,yl))`. Reuses the preference data
and the dual-sequence forward from P1.

Gates: ranking-loss grad-check; held-out **pairwise accuracy** (`r_w > r_l`); a frozen RM
loads/serves the scalar correctly.

### Stage P3 — GRPO (online RL, critic-free) ⭐ the deep infra lesson

This is the centerpiece. It introduces **generation inside the training loop**.

**(a) Rollout / generation engine — built up front (its own milestone).**

> ✅ **DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine
> up front**, as a foundational milestone *before* DPO/GRPO, rather than starting naive. It is
> then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted:
> front-loads the single hardest build and delays the first alignment result, in exchange for
> a real generation engine and a clean, isolated infra lesson.

The engine: per-layer **K/V cache**, **single-token incremental forward** (process the prompt
once to fill the cache, then decode one token at a time), **batched ragged decode** (B prompts
× G samples; sequences hit EOS at different lengths → finished-mask / left-padding /
compaction). The current attention assumes a full causal window over `seq`; incremental decode
needs a **decode-time attention path** — query length 1 against cached K/V of length `t`, with
RoPE position = `t`. This reuses the composed SDPA shapes (one-row query), so it can land as a
distinct code path without disturbing the training attention (flash/GQA/composed unchanged).

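Why the decode path needs no causal mask can be checked numerically: the newest token's query is the last row, and the causal mask already lets the last row see every earlier key. The sketch below is a single-head CPU version in `f64` with hypothetical names (`attend`, `causal_attention`, `decode_attention`). It assumes RoPE has already been applied to the cached keys and omits batching, GQA and the bf16 path.

```rust
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Softmax-weighted sum of `v[..n]` for one query against keys `k[..n]`.
fn attend(q: &[f64], k: &[Vec<f64>], v: &[Vec<f64>], n: usize, scale: f64) -> Vec<f64> {
    let scores: Vec<f64> = (0..n).map(|j| scale * dot(q, &k[j])).collect();
    let max = scores.iter().fold(f64::NEG_INFINITY, |m, &s| m.max(s));
    let weights: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
    let z: f64 = weights.iter().sum();
    let mut out = vec![0.0; v[0].len()];
    for j in 0..n {
        for (o, x) in out.iter_mut().zip(&v[j]) {
            *o += weights[j] / z * x;
        }
    }
    out
}

/// Training-style causal attention: query i attends to keys 0..=i.
fn causal_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>], scale: f64) -> Vec<Vec<f64>> {
    (0..q.len()).map(|i| attend(&q[i], k, v, i + 1, scale)).collect()
}

/// Decode-time attention: one new query against all `t` cached keys, no mask.
fn decode_attention(q_new: &[f64], k_cache: &[Vec<f64>], v_cache: &[Vec<f64>], scale: f64) -> Vec<f64> {
    attend(q_new, k_cache, v_cache, k_cache.len(), scale)
}
```

In this scalar formulation the equality is exact by construction. A GPU path built from a different GEMM composition may reorder the reductions, which is why a tolerance-based gate is the natural form there.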
Hard gate (the centerpiece correctness lesson): **KV-cache decode == full-recompute decode,
token-identical** greedy output — the same byte-/token-identical discipline the project uses
for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs.
per-token decode) is recorded here, before any rollout optimization (profile-first).

**(b) Reward scoring.** Rule-based verifiable reward first (e.g., exact-match on a synthetic
arithmetic/format task) or RM from P2. Returns a scalar per completion.

**(c) Group-relative advantage.** Sample `G` completions per prompt; advantage
`A_i = (r_i − mean(r_group)) / (std(r_group) + ε)`. No critic, no GAE.

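The advantage formula above is small enough to state directly in code. This is a minimal sketch that assumes the population standard deviation (dividing by `G`); whether the implementation divides by `G` or `G − 1` is not specified here, and `group_advantages` is a hypothetical name.

```rust
/// Group-relative advantages for the G rewards sampled from one prompt:
/// A_i = (r_i - mean) / (std + eps), with std the population standard deviation.
fn group_advantages(rewards: &[f64], eps: f64) -> Vec<f64> {
    let g = rewards.len() as f64;
    let mean = rewards.iter().sum::<f64>() / g;
    let var = rewards.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / g;
    let std = var.sqrt();
    rewards.iter().map(|r| (r - mean) / (std + eps)).collect()
}
```

Two cases give a group no signal. One is `G = 1`, the degenerate gate. The other is a group where every sample earns the same reward, for example all wrong on a hard prompt.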
**(d) Clipped policy-gradient loss with KL leash.** Per completion token,
`ρ_t = exp(logπθ_t − logπθ_old_t)` (old = policy at rollout time), token loss
`−min(ρ_t A, clip(ρ_t, 1±ε) A) + βKL(πθ‖πref)`, masked to completion tokens. KL via the k3
estimator.

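The per-token loss can be written out on scalars. This sketch makes three assumptions and is not the `clipped_pg_loss` op. First, `token_loss` and `sequence_loss` are hypothetical names. Second, a per-sequence mean over completion tokens is assumed. Third, the k3 estimator is taken as `exp(logπref − logπθ) − 1 − (logπref − logπθ)`, which is non-negative and is zero when the policy matches the reference.

```rust
/// Clipped policy-gradient loss plus k3 KL penalty for one completion token.
fn token_loss(logp: f64, logp_old: f64, logp_ref: f64, adv: f64, eps: f64, beta: f64) -> f64 {
    // Importance ratio against the policy that generated the rollout.
    let ratio = (logp - logp_old).exp();
    let unclipped = ratio * adv;
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps) * adv;
    // Pessimistic (min) surrogate, negated into a loss.
    let pg = -unclipped.min(clipped);
    // k3 estimate of KL(πθ ‖ πref) from a token sampled under πθ.
    let log_r = logp_ref - logp;
    let k3 = log_r.exp() - 1.0 - log_r;
    pg + beta * k3
}

/// Mean token loss over the completion tokens (`completion[i] == true`) of one sequence.
fn sequence_loss(
    logp: &[f64],
    logp_old: &[f64],
    logp_ref: &[f64],
    completion: &[bool],
    adv: f64,
    eps: f64,
    beta: f64,
) -> f64 {
    let mut total = 0.0;
    let mut n = 0usize;
    for i in 0..logp.len() {
        if completion[i] {
            total += token_loss(logp[i], logp_old[i], logp_ref[i], adv, eps, beta);
            n += 1;
        }
    }
    if n == 0 { 0.0 } else { total / n as f64 }
}
```

With `A > 0` the clip caps how far a single update can push up a token the policy already favours. With `ε = ∞` the same function reduces to the unclipped ratio-weighted policy gradient, which is the `ε → ∞` degenerate gate.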
**(e) Actor-learner loop.** sample prompt batch → rollout G each → score → advantage →
capture `πθ_old` logprobs → K inner epochs of clipped PG updates → repeat. Reference `πref`
fixed throughout.

Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations.
Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded
(tokens/s of generation vs. update) **before** any rollout optimization, per the project's
profile-first rule.

Correctness gates:
- PG-loss finite-diff grad-check.
- **Degenerate checks**: `G = 1` ⇒ advantage 0 ⇒ no PG signal, only KL; `ε → ∞` ⇒ vanilla PG;
`β = 0` ⇒ no KL term.
- (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a
prerequisite of GRPO.)
- **Synthetic RL overfit**: on a tiny verifiable task with a known optimum, mean reward must
rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal
that the loop is correct, independent of fuzzy chat quality).

## Evaluation

- **Offline (DPO/RM)**: reward margin, preference accuracy, KL drift from reference, plus the
fixed chat-prompt generation suite (`scripts/chat_alpha_fixed_prompts.txt`) judged
before/after — reusing and extending the doc-13 recommendation for a generation-based eval harness
(exact-match math, code syntax, stop-token, refusal appropriateness, corruption).
- **Online (GRPO)**: mean reward curve, KL-to-reference, response length, the verifiable-task
pass rate, and the same fixed-prompt suite.
- **Selection by generation eval, not loss** — the recurring doc-13/v11 lesson: lower
post-training loss did not mean better generations.

## Memory & throughput budget (8× RTX 5090, 1.05B model, indicative)

- Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy
optimizer state alone ~17 GB before activations. Recompute + grad-accum keep activations
small; this is why post-training reuses the Phase-1/2 memory levers unchanged.
- DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits.
- GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new
variable. **Generation, not the update, is expected to be the throughput bottleneck** — to be
measured, not assumed.

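The arithmetic behind these numbers can be checked quickly. This sketch assumes 2 bytes per parameter for bf16 weights, 4 for the fp32 master, 8 for fp32 AdamW m and v, and 2 for bf16 grads, all in decimal gigabytes. Activations and rollout buffers are deliberately left out, and the function names are illustrative.

```rust
/// Decimal gigabytes for `params` parameters at `bytes_per_param` bytes each.
fn gb(params: f64, bytes_per_param: f64) -> f64 {
    params * bytes_per_param / 1e9
}

/// Trainable policy: bf16 weights + fp32 master + fp32 AdamW m/v + bf16 grads.
fn policy_state_gb(params: f64) -> f64 {
    gb(params, 2.0) + gb(params, 4.0) + gb(params, 8.0) + gb(params, 2.0)
}

/// Weight memory before activations: the policy plus `frozen` bf16-only models
/// (the reference, and optionally a learned reward model).
fn resident_weights_gb(params: f64, frozen: usize) -> f64 {
    policy_state_gb(params) + frozen as f64 * gb(params, 2.0)
}
```

At 1.05B parameters the policy state is 16.8 GB (the "~17 GB" above), and the reference adds 2.1 GB. Policy, reference and reward model together stay near 21 GB of a 32 GB card before activations.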
## Correctness-gate philosophy (unchanged from Phase 1/2)

Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity
on loss + grads where applicable, (3) explicit degenerate-case value checks (β→0, G=1,
ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic
RL overfit), and (5) **no change to the default training path** when post-training flags are
off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference
gates as flash/GQA.

## Risks & tradeoffs

- **Rollout engine is the long pole.** A correct KV-cache incremental-decode path is a real
build (decode-time attention, ragged batch). Mitigation: per D3, build it up front as its own
isolated, separately-gated milestone (M2), gated on token-identity with full-recompute decode.
- **RL is finicky.** KL leash, advantage normalization, clip range, reward hacking. Mitigation:
synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward.
- **Reward-model noise** can mislead GRPO. Mitigation: rule-based reward first.
- **Tokenizer (KI-4)** — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here.
- **Two/three resident models** raise memory; bounded by recompute/accum and (for DPO) reference
logprob caching.

## Resolved decisions (aligned 2026-06-29)

- **D1 — Scope & sequencing → DPO → GRPO, reward model optional.**
- **D2 — Online-RL reward source → rule-based / verifiable reward first** (RM deferred/optional).
- **D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front** (not
naive-first), as a foundational milestone before DPO/GRPO.
- **D4 — Alignment task / eval target → a verifiable task** (arithmetic/format/GSM8K-style) with
a deterministic exact-match reward, for a clean, falsifiable RL signal.

## Milestones (locked order)

1. **M1 — P0 SFT task baseline.** Chat template + assistant-only masking on the verifiable
task; produces the reference + init checkpoint. Gate: masking unit test; single-turn
bit-identical to `fbf4ac2`.
2. **M2 — KV-cache decode engine** (D3, up front). Per-layer K/V cache + incremental
decode-time attention + batched ragged decode. Gate: **token-identical to full-recompute
greedy**; record decode throughput baseline.
3. **M3 — P1 DPO.** Verifiable-checker pair construction (via M2) → `seq_logprob` op
(grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO
training loop → run + reward-margin / preference-accuracy curve.
4. **M4 — P3 GRPO.** Group rollout (M2) + rule-based reward + group-relative advantage +
clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; **synthetic
verifiable-task RL-overfit** (mean reward → known optimum) → verifiable-task GRPO run.
5. **M5 (optional) — P2 reward model.** Scalar head + ranking loss + pairwise-accuracy gate;
enables GRPO-with-RM for general chat.

> Each milestone is one design+gate cycle; results get appended here (like the run docs) and a
> row in `docs/evolution.md` (algorithm/infra dimensions) when it lands.

## Implementation log

### M1 — SFT task baseline (landed)

The verifiable task and its data pipeline are implemented and verified host-side (no CUDA
needed); the SFT run + eval ran on dash5 (1×5090). **Result: SFT moves answer-format
adherence 0% → 100%, with arithmetic correctness 8% — exactly the intended split (SFT buys
the format; correctness is M3/M4's job).**

**Verifiable task (the spec, in one Rust module — `crates/xtrain-train/src/task.rs`):**

- Two-operand integer arithmetic, ops `+ − ×`; operands `[0,999]` for `+/−`, `[0,99]` for `×`
(modest products); subtraction may be negative. (Ranges enlarged from the first cut to keep
the unique-key space ≫ requested rows — see the saturation guard below.)
- User turn: `What is A op B?`. SFT target: `A op B = \boxed{N}.` — teaches the answer FORMAT;
the checker reads only `\boxed{}`, so arithmetic *correctness* is what M3/M4 improve.
- Rule-based reward: `parse_boxed_answer` (takes the LAST `\boxed{int}`) + `check_answer`
(exact match vs. gold). This is the single shared checker reused by M3 (pair construction)
and M4 (GRPO reward).
- Why this task: trivial deterministic checker, freely scalable difficulty, and it directly
probes the base model's known arithmetic weakness (v12 SFT failed `12 * 13`).

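A minimal re-sketch of the checker contract described above follows; it is not the `task.rs` code. It scans every `\boxed{...}`, keeps the last one whose contents parse as an integer, and compares that value exactly against the gold integer. How unparseable boxes are treated, and both signatures, are assumptions.

```rust
/// Returns the integer inside the LAST `\boxed{...}` whose contents parse as an integer.
fn parse_boxed_answer(text: &str) -> Option<i64> {
    const OPEN: &str = "\\boxed{";
    let mut last = None;
    let mut rest = text;
    while let Some(start) = rest.find(OPEN) {
        let after = &rest[start + OPEN.len()..];
        let Some(end) = after.find('}') else { break };
        if let Ok(n) = after[..end].trim().parse::<i64>() {
            last = Some(n);
        }
        rest = &after[end + 1..];
    }
    last
}

/// Rule-based reward: exact match of the parsed answer against the gold integer.
fn check_answer(completion: &str, gold: i64) -> bool {
    parse_boxed_answer(completion) == Some(gold)
}
```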
**Data generator (`crates/xtrain-train/src/bin/gen_arith_task.rs`, pure host bin):**
writes `arith_sft.tsv` (`user<TAB>assistant` for `--sft-tsv`), `arith_eval_prompts.txt`
(`greedy_sample --prompts-file` format), and `arith_eval_gold.txt` (parallel gold ints).
Train rows are deduped; eval is held out from train (no leakage). A **saturation guard**
(`unique_space()` + `assert need·5 ≤ space·4`) rejects requests that approach the unique-key
space, since deduped train + disjoint eval near saturation get pathologically slow (or, for
the disjoint-eval loop, never terminate). With the shipped defaults the space is ~2.01M keys,
so a 20 000 + 500 request is a tiny fraction (gen runs in ~0.2 s).

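The guard's arithmetic can be sketched using the operand ranges from the task spec. The key is assumed to be the ordered triple `(A, op, B)`, which is what reproduces the ~2.01M figure. `Op`, `unique_space` and `check_saturation` are illustrative, not the generator's code.

```rust
#[derive(Clone, Copy)]
enum Op {
    Add,
    Sub,
    Mul,
}

/// Largest operand per op: [0, 999] for + and -, [0, 99] for x.
fn operand_max(op: Op) -> u64 {
    match op {
        Op::Add | Op::Sub => 999,
        Op::Mul => 99,
    }
}

/// Number of distinct ordered (A, op, B) keys.
fn unique_space() -> u64 {
    [Op::Add, Op::Sub, Op::Mul]
        .iter()
        .map(|&op| (operand_max(op) + 1).pow(2))
        .sum()
}

/// Reject requests above 80% of the key space: need·5 ≤ space·4.
fn check_saturation(n_train: u64, n_eval: u64) -> Result<(), String> {
    let need = n_train + n_eval;
    let space = unique_space();
    if need * 5 <= space * 4 {
        Ok(())
    } else {
        Err(format!("requested {need} unique problems but only {space} exist; stay at or below 80%"))
    }
}
```

A request of 20 000 + 500 is about 1% of the space, so a fresh random key almost never collides. Near the 80% line most draws land on a key that is already used, which is the slowdown the guard prevents.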
**Scorer (`crates/xtrain-train/src/bin/eval_arith.rs`):** loads a checkpoint, greedily
generates a continuation per held-out prompt, isolates the first answer segment (cut at the
first `<|endoftext|>` then first newline), and reports two signals via the shared checker —
**format** (fraction emitting any `\boxed{int}`) and **correctness** (exact-match vs. gold).
This is the reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). It uses the *naive*
no-KV-cache sampler (full forward per token), so even 100 prompts is slow — concrete
motivation for M2 (the KV-cache decode engine).

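The segment-isolation and scoring steps can be sketched as follows. The checker is passed in as a closure so the sketch does not restate `parse_boxed_answer`. `first_answer_segment` and `score` are hypothetical names, and the return value is assumed to be `(format_count, correct_count)`.

```rust
const EOT: &str = "<|endoftext|>";

/// Cut a raw continuation at the first end-of-text marker, then at the first newline.
fn first_answer_segment(continuation: &str) -> &str {
    let upto_eot = continuation.split(EOT).next().unwrap_or("");
    upto_eot.split('\n').next().unwrap_or("")
}

/// Returns (format, correct): how many segments contain a parseable answer,
/// and how many of those equal the gold value.
fn score<F: Fn(&str) -> Option<i64>>(continuations: &[&str], gold: &[i64], parse: F) -> (usize, usize) {
    let mut format = 0;
    let mut correct = 0;
    for (c, &g) in continuations.iter().zip(gold) {
        if let Some(answer) = parse(first_answer_segment(c)) {
            format += 1;
            if answer == g {
                correct += 1;
            }
        }
    }
    (format, correct)
}
```

Cutting at the first newline means a later, different `\boxed{}` that the model emits after rambling on is never scored.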
**Masking made testable:** the assistant-only label masking in `load_sft_tsv_cached` was
extracted into a pure `sft_row(prompt_ids, answer_ids)` helper (behavior-preserving — the
single-turn path is bit-identical to `fbf4ac2`).

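This sketch shows what such a pure helper computes. It assumes the helper returns next-token `(inputs, labels)` with the labels already shifted by one, and that tokens are `u32`. The source does not say whether the real `sft_row` does the shift itself or leaves it to the loss.

```rust
const IGNORE_INDEX: i64 = -100;

/// Builds one next-token-prediction row: inputs are tokens[..n-1], labels[i] is
/// tokens[i+1], and every label that is itself a prompt token is masked to IGNORE_INDEX.
fn sft_row(prompt_ids: &[u32], answer_ids: &[u32]) -> (Vec<u32>, Vec<i64>) {
    let full: Vec<u32> = prompt_ids.iter().chain(answer_ids).copied().collect();
    if full.is_empty() {
        return (Vec::new(), Vec::new());
    }
    let inputs = full[..full.len() - 1].to_vec();
    let labels = (1..full.len())
        .map(|j| if j < prompt_ids.len() { IGNORE_INDEX } else { full[j] as i64 })
        .collect();
    (inputs, labels)
}
```

Only the answer tokens are supervised. The last prompt position predicts the first answer token, and all earlier prompt targets are `-100`.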
**Gate (verified locally in `no_cuda` mode):** `cargo test -p xtrain-train --lib` → 9/9 pass,
including `sft_row` masks prompt→`-100` / supervises answer, the SFT-target self-consistency
invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism.
A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0
train/eval leakage.

**Run (dash5, 1×5090, from the v12 1.05B base):**
1. dataset: `gen_arith_task --n 20000 --eval 500 --seed 1 --out-dir <dir>` → 20 000 train +
500 held-out eval, 0 leakage.
2. SFT: `train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt> --heads 52
--head-dim 32 --kv-heads 13 --layers 22 --ffn 6656 --bf16 --recompute --flash --seq 256
--batch 16 --steps 250 --max-lr 1e-4 --min-lr 1e-5 --ckpt arith_sft_v12.ckpt` → the P0
reference/init checkpoint. Train loss 4.68 → ~0.34, best val 0.386, no OOM, ~4.3K tok/s.
3. eval: `eval_arith <ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt
--gold-file <dir>/arith_eval_gold.txt --max-tokens 32`, base vs. SFT, on 100 held-out prompts.

**M1 result (100 held-out prompts, greedy, max_new 32):**

| checkpoint | format (`\boxed{}`) | correct (exact-match) |
|---------------------|----------------------|-----------------------|
| v12 base (pre-SFT) | 0 / 100 (0%) | 0 / 100 (0%) |
| arith SFT | **100 / 100 (100%)** | 8 / 100 (8%) |

The base model never emits the format — it answers `"I don't know."` / restates the question
and stops. SFT moves format **0% → 100%**: every completion cleanly restates the equation and
boxes an integer (`46 * 80 = \boxed{3380}.`). Correctness is only **8%**: the format is fully
learned but the *arithmetic* is the base model's own weak capability — e.g. it boxes 3380 for
gold 3680, −10 for gold 5; it does get some right (`895 − 353 = \boxed{542}.` ✓). That residual
gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to close.

**Gate met:** format 0% → 100% confirms the assistant-only SFT path is wired end-to-end; the
held-out correct > 0 confirms the checker + eval harness score real matches (not just format).
M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic
skill, which is downstream by design.

### M2a — KV-cache incremental-decode engine (single sequence, landed)

The decode engine (D3, built up front) that replaces the naive sampler — which re-runs the
full forward over the growing prefix every step (O(t²), a fresh autograd graph per token). Two
forward-only primitives + a raw-Tensor per-token block forward, each gated in isolation.

**Primitives (`xtrain-tensor`, both forward-only):**
- `Tensor::rope_at(theta, pos0)` — RoPE at a token's *absolute* position (`pos = pos0 + row`,
no modulo), vs the training `rope` (`pos = row % period`) which is left untouched (new CUDA
kernel `rope_at_k` → no training-path risk). Cached K is stored post-RoPE, so it must match
what the full forward produced at that position. **Gate:** bit-identical to the full-sequence
rope's row `t` (`integration::rope_at_matches_full_rope_row`).
- `Tensor::decode_attention(k, v, scale)` — single-query × cached-K/V SDPA (`[bh,1,hd]` vs
`[bh,t,hd]`, no causal mask: the one query sees all cached keys). Composed from the existing
strided batched GEMM + plain softmax — **no new kernel**. **Gate:** equals the full causal
attention's last query row, max |Δ| 6e-8 (`integration::decode_attention_matches_…`).

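The two RoPE position rules, and the bit-identity gate between them, can be illustrated in a CPU sketch. It assumes interleaved (even, odd) rotation pairs and frequencies `theta^(-2i/hd)`; the real kernels may pair dimensions differently. `rope_row`, `rope_full` and `rope_at` here are illustrative, not the `xtrain-tensor` API.

```rust
/// Rotate interleaved (even, odd) pairs of one row by angle `pos * theta^(-2i/hd)`.
fn rope_row(x: &[f32], pos: usize, theta: f32) -> Vec<f32> {
    let hd = x.len();
    let mut out = x.to_vec();
    for i in 0..hd / 2 {
        let exponent = -((2 * i) as f32) / hd as f32;
        let freq = theta.powf(exponent);
        let (s, c) = (pos as f32 * freq).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        out[2 * i] = a * c - b * s;
        out[2 * i + 1] = a * s + b * c;
    }
    out
}

/// Training-style rule: row r of a packed [rows, hd] block uses position r % period.
fn rope_full(x: &[Vec<f32>], theta: f32, period: usize) -> Vec<Vec<f32>> {
    x.iter()
        .enumerate()
        .map(|(r, row)| rope_row(row, r % period, theta))
        .collect()
}

/// Decode-style rule: row r uses the absolute position pos0 + r, with no modulo.
fn rope_at(x: &[Vec<f32>], theta: f32, pos0: usize) -> Vec<Vec<f32>> {
    x.iter()
        .enumerate()
        .map(|(r, row)| rope_row(row, pos0 + r, theta))
        .collect()
}
```

For a single sequence, `rope_at` on row `t` with `pos0 = t` runs exactly the same arithmetic as row `t` of the full-sequence rule, so the two results are bit-identical. The modulo only matters when several sequences are packed into one block of rows.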
**Engine (`xtrain-model/src/decode.rs`, `generate_greedy_cached`):** per-layer K/V cache +
single-token incremental forward. Prefill = the first `prompt.len()` decode steps (one code
path). Mirrors `model::block_forward` at the raw-Tensor level (no autograd tape — inference
needs no grads), pulling weights via the public `params()` stable order (no model-internal
visibility changes). The cache is host-accumulated token-major f32, rebuilt per step — the
honest M2a baseline; M2b moves it device-side + adds batched ragged decode.

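A toy model can show the engine's control flow and the token-identity gate. The flow is: prefill runs as ordinary decode steps, then greedy decoding continues from the cache. Everything here is hypothetical. `ToyModel`'s "cache" is a running prefix summary standing in for per-layer K/V, chosen only so that the incremental path and the full-recompute path compute the same function.

```rust
/// Stand-in for the per-layer K/V cache: a running summary of the prefix.
struct Cache {
    len: u64,
    h: u64,
}

struct ToyModel {
    vocab: usize,
}

impl ToyModel {
    fn logits_from(&self, h: u64) -> Vec<f32> {
        let mut logits = vec![0.0; self.vocab];
        logits[(h % self.vocab as u64) as usize] = 1.0;
        logits
    }

    /// Naive path: recompute from the whole prefix on every call.
    fn forward_full(&self, prefix: &[u32]) -> Vec<f32> {
        let h: u64 = prefix.iter().enumerate().map(|(i, &t)| (i as u64 + 1) * t as u64).sum();
        self.logits_from(h)
    }

    /// Incremental path: fold one token into the cache, return next-token logits.
    fn decode_step(&self, token: u32, cache: &mut Cache) -> Vec<f32> {
        cache.len += 1;
        cache.h += cache.len * token as u64;
        self.logits_from(cache.h)
    }
}

fn argmax(logits: &[f32]) -> u32 {
    let mut best = 0;
    for (i, &x) in logits.iter().enumerate() {
        if x > logits[best] {
            best = i;
        }
    }
    best as u32
}

fn greedy_naive(m: &ToyModel, prompt: &[u32], max_new: usize, eos: u32) -> Vec<u32> {
    let mut seq = prompt.to_vec();
    let mut out = Vec::new();
    for _ in 0..max_new {
        let next = argmax(&m.forward_full(&seq));
        seq.push(next);
        out.push(next);
        if next == eos {
            break;
        }
    }
    out
}

fn greedy_cached(m: &ToyModel, prompt: &[u32], max_new: usize, eos: u32) -> Vec<u32> {
    assert!(!prompt.is_empty(), "prefill needs at least one token");
    let mut cache = Cache { len: 0, h: 0 };
    // Prefill = the first prompt.len() decode steps (same code path as decoding).
    let mut logits = m.decode_step(prompt[0], &mut cache);
    for &tok in &prompt[1..] {
        logits = m.decode_step(tok, &mut cache);
    }
    let mut out = Vec::new();
    for _ in 0..max_new {
        let next = argmax(&logits);
        out.push(next);
        if next == eos {
            break;
        }
        logits = m.decode_step(next, &mut cache);
    }
    out
}
```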
**Gate (the M2 centerpiece — token-identical):** KV-cache greedy decode is byte-for-byte the
same token sequence as the naive full-recompute greedy. Verified two ways:
- `xtrain-train/tests/decode_kv.rs` — small GQA model (8 query / 2 kv heads), F32, 24 generated
tokens, exact token-equality. (Unit gate runs F32: a random model's near-uniform logits make
argmax fragile to ~1e-6, so the tightest path is used; the trained model below has peaked
logits → robust.)
- v12 1.05B SFT checkpoint: `eval_arith --cached` produces the **identical** eval outcome to the
naive run (format 100/100, correct 8/100) and byte-identical completions.

**Throughput baseline (v12 1.05B, batch 1, F32, profile-first — measured, not assumed):** the
cache win is **sequence-length-dependent**, which is the honest systems finding here:

| max_new | naive | kv-cache | note |
|---------|-------|----------|------|
| 32 | 108 tok/s | 111 tok/s | ~1.0× — both **launch/overhead-bound** at short seq |
| 128 | 69 tok/s | **133 tok/s** | **~1.9×** — naive's O(t²) recompute starts to bite |
| 256 | **OOM** | 129 tok/s | naive rebuilds the O(seq²) graph every step → OOM |

Cached throughput stays ~constant (per-token cost is dominated by the fixed-size weight GEMMs;
attention over the cache grows with `t` but stays small at these lengths); naive **decays**
(108→69 tok/s, O(t)/token) and eventually **OOMs** (the full autograd graph per step). So at the
short arithmetic-eval lengths the cache is overhead-bound and gives ~nothing — it matters for
**long rollouts** (DPO pair-generation, GRPO completions), exactly where M3/M4 use it. (M2a's
per-layer host round-trip is part of why short-seq is overhead-bound; M2b's device-side cache
targets it.) This is the same measure-first lesson as T17 (process-per-GPU throughput-neutral):
the win is real but only in the regime that actually stresses the bottleneck.

### M3 — DPO (offline preference optimization, landed; honest negative result)

The first real alignment method. Infra landed and gated; the empirical finding is that DPO
**does not improve held-out arithmetic correctness on this task** — a genuine, on-theme negative
result (the design doc's "RL is finicky" risk, made concrete).

**Two new autograd ops (`xtrain-autodiff`, both reuse the CE kernel — no new CUDA):**
- `seq_logprob(logits, target)` = `Σ log πθ(target)` over non-ignored positions (the per-sequence
logprob DPO compares). `= −Σ per_row` of cross_entropy (ignored rows already 0, like
SFT masking); backward = `cross_entropy_backward(probs, target, −upstream)` (SUM, no mean).
**Gate:** finite-diff grad-check with a `-100` completion mask.
- `dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β)` = `−log σ(Δ)` with the
two policy logprobs as parents (ref logprobs constant). **Gate:** grad-check both parents +
degenerate points (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0).

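The `dpo_loss` gates can be reproduced on scalars. The sketch uses the analytic gradients `∂L/∂lpθ_chosen = −β·σ(−Δ)` and `∂L/∂lpθ_rejected = +β·σ(−Δ)`, which follow from `L = −log σ(Δ)`, and checks them against a central finite difference. The function names and the step size are illustrative.

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// L = -log σ(Δ) = softplus(-Δ), written in an overflow-safe form.
fn dpo_loss(lp_w: f64, lp_l: f64, ref_w: f64, ref_l: f64, beta: f64) -> f64 {
    let delta = beta * ((lp_w - ref_w) - (lp_l - ref_l));
    let x = -delta;
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}

/// Analytic gradients with respect to the two policy log-probs (the only parents).
fn dpo_grads(lp_w: f64, lp_l: f64, ref_w: f64, ref_l: f64, beta: f64) -> (f64, f64) {
    let delta = beta * ((lp_w - ref_w) - (lp_l - ref_l));
    let g = beta * sigmoid(-delta);
    (-g, g)
}

/// Central finite difference of a scalar function.
fn finite_diff(f: impl Fn(f64) -> f64, x: f64) -> f64 {
    let h = 1e-5;
    (f(x + h) - f(x - h)) / (2.0 * h)
}
```

At policy == reference, `σ(0) = 1/2` gives exactly the `∓β/2` gradients and the `ln 2` loss used as the log2-at-init check.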
**Pair construction (`gen_dpo_pairs`, aligned decision):** chosen = gold answer; rejected = the
SFT model's own **greedy** (KV-cache engine, M2a) completion when it's a format-valid WRONG
boxed answer — a hard negative in the model's distribution. Since SFT is ~8% correct (M1),
greedy is wrong ~92% of the time, so this is fast and deterministic; ~8% of prompts are skipped
(greedy correct). 1500 pairs generated (158 skipped) in ~8 min.

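The pairing rule reduces to a small filter. This sketch uses hypothetical names (`Pair`, `make_pair`) and passes the checker in as a closure. The chosen text is assumed to be the task's SFT target string `A op B = \boxed{N}.`.

```rust
#[derive(Debug, PartialEq)]
struct Pair {
    prompt: String,
    chosen: String,
    rejected: String,
}

/// chosen = gold-format answer; rejected = the model's greedy completion, kept only
/// when it is format-valid (parses) AND wrong. Otherwise the prompt is skipped.
fn make_pair<F: Fn(&str) -> Option<i64>>(
    prompt: &str,
    gold_target: &str,
    gold: i64,
    greedy: &str,
    parse: F,
) -> Option<Pair> {
    match parse(greedy) {
        Some(answer) if answer != gold => Some(Pair {
            prompt: prompt.to_string(),
            chosen: gold_target.to_string(),
            rejected: greedy.to_string(),
        }),
        _ => None, // greedy was correct, or had no parseable answer
    }
}
```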
**Training (`train_dpo`):** loads the SFT ckpt as policy AND frozen reference; **precomputes the
reference logprobs once** (while policy == reference) and caches them — one resident model. Each
step forwards the policy on chosen + rejected, `seq_logprob` each, minimises `dpo_loss`; the two
forwards share params so backward accumulates both branches. Loss **starts at exactly log2**
(Δ=0 at init) — a built-in correctness check that fired correctly. Tracks reward margin +
preference accuracy.

**Result (v12 1.05B, 1500 pairs, β=0.1; 100 held-out prompts, vs the SFT baseline format
100/100, correct 8/100):**

| run | reward margin | pref-acc | format | correct |
|---------------------------|---------------|----------|--------|---------|
| SFT (baseline) | — | — | 100/100 | 8/100 |
| DPO lr 5e-7 × 300 | +0.78 | ~82% | 100/100 | 7/100 |
| DPO lr 5e-7 × 800 | +1.25 | ~82% | 100/100 | 5/100 |
| DPO lr 1e-6 × 2000 | **+34.2** | ~76% | **0/100** | 0/100 |

The reward margin and preference accuracy rise cleanly (the loss IS being optimized — the infra
is correct), but the implicit reward **does not transfer to held-out correctness**: it stays
~5–8% (all within the ~2.7% std-error of 100 prompts — statistically flat), and pushing harder
**over-optimizes to collapse** (margin +34 = huge KL from the reference → the model emits
garbage, `46 * 80 = CRAFTIE SERIES SERIES…`, format 0%).

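The ~2.7% figure is the binomial standard error at the SFT baseline rate p = 0.08 over n = 100 held-out prompts:

```math
\mathrm{SE} = \sqrt{\frac{p(1-p)}{n}} = \sqrt{\frac{0.08 \times 0.92}{100}} \approx 0.027
```

So the 5 to 8% spread across the DPO runs is about one standard error wide.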
**The lesson (why):** chosen and rejected differ only in the final number tokens, so DPO raises
|
||
`log p(correct) − log p(wrong)` for the *specific* training pairs — it **reweights the existing
|
||
distribution, it does not install the capability**. The base model has no arithmetic algorithm,
|
||
so preferring correct-vs-wrong final answers on seen pairs cannot generalize to unseen problems;
|
||
and the only way to drive the margin far is to globally distort the distribution → incoherence.
|
||
**DPO works when the chosen is already plausible under the policy; it cannot manufacture
|
||
knowledge the model lacks.** This is the precise motivation for **M4 GRPO**: optimize the *actual
|
||
verifiable reward* online (sample → check → reinforce what is genuinely correct), rather than a
|
||
fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric
|
||
is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init,
|
||
margin/acc rise); the correctness flatness is the reported finding, not a bug.
|
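The first sentence can be made concrete with a short sketch. It assumes `seq_logprob` sums per-token log-probs over the completion. Take two completions of the same prompt: every token before their first divergence sees the same causal context, so it gets the same log-prob in both. Those terms cancel in the margin. The helper names below are hypothetical.

```rust
/// Sequence log-prob as a sum of per-token log-probs (assumed definition of `seq_logprob`).
fn seq_logprob(token_logps: &[f64]) -> f64 {
    token_logps.iter().sum()
}

/// Margin log p(chosen) - log p(rejected) computed over full completions.
fn full_margin(chosen_logps: &[f64], rejected_logps: &[f64]) -> f64 {
    seq_logprob(chosen_logps) - seq_logprob(rejected_logps)
}

/// The same margin using only the tokens from the first divergence onward.
/// Under a causal LM, tokens before the divergence see identical context and
/// therefore carry identical log-probs in both sequences, so they cancel.
fn suffix_margin(
    chosen_ids: &[u32],
    chosen_logps: &[f64],
    rejected_ids: &[u32],
    rejected_logps: &[f64],
) -> f64 {
    let k = chosen_ids
        .iter()
        .zip(rejected_ids)
        .take_while(|(a, b)| a == b)
        .count();
    seq_logprob(&chosen_logps[k..]) - seq_logprob(&rejected_logps[k..])
}
```

When chosen and rejected differ only in the final number tokens, `k` covers the whole shared prefix. The DPO loss, and so its gradient, then depends only on the answer tokens' log-probs.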

### M4 — GRPO (online RL, critic-free, landed; infra + two honest systems walls)

The centerpiece: generation INSIDE the training loop. The infra is built and gated. The run surfaces
two concrete systems findings (the memory long pole and the rollout long pole, both flagged in the
design doc's Risks section) and hits the same capability wall as M3.

**Task made learnable first (per the agreed decision "easier task → then M4"):** the v12 SFT
model scores ~8% on the hard task *and* on easy problems; it learned format, not arithmetic. So
the easy task (operands ≤20, ops `+ − ×`) was re-SFT'd from the v12 base → **held-out 18.7%**
(100% format), a baseline with enough reward variance for GRPO. Note: even easy arithmetic plateaus at
~19% held-out (250 and 600 SFT steps give identical results). A 1B web-text model does not generalize
the add/sub algorithm from ~550 examples; it memorizes the training set (550 of 982 total problems seen).

**New op (`xtrain-autodiff`, reuses the CE kernel + one new primitive):**
- `clipped_pg_loss(logits, target, logp_old, logp_ref, A, ε, β)`: per completion token,
`ρ_t = exp(logπθ_t − logp_old_t)` and `L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL`, where KL is the
k3 estimator (`r − log r − 1`, with `r = π_ref/π_θ`), masked to completion tokens. Backward reuses
`(probs − onehot)` + `scale_rows`. `scale_rows` is a new ~5-line per-row scale kernel, needed because
the per-token coefficient varies, which CE-backward's single scalar can't express. **Gate:** grad-check
the active PG path + the A=0 (KL-only) path; degenerate value checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL.
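The sketch below covers the per-token loss and the coefficient `dL/dlogπθ_t` that a per-row scale would apply. It assumes masking has already been applied (the slices hold completion tokens only) and that each sample has a single advantage. The names are hypothetical, and the real op works on GPU logits. Because `dlogπθ_t/dlogits_t = onehot − probs`, the logits gradient of row t is `−coef_t · (probs − onehot)`. That is exactly a per-row rescale of the CE-backward term.

```rust
/// Scalar sketch of the clipped policy-gradient loss with a k3 KL penalty,
/// over the completion tokens of one sample (mask already applied).
fn clipped_pg_loss(logp: &[f64], logp_old: &[f64], logp_ref: &[f64], adv: f64, eps: f64, beta: f64) -> f64 {
    let n = logp.len() as f64;
    let (mut pg, mut kl) = (0.0, 0.0);
    for t in 0..logp.len() {
        let rho = (logp[t] - logp_old[t]).exp(); // importance ratio
        let clipped = rho.clamp(1.0 - eps, 1.0 + eps);
        pg += (rho * adv).min(clipped * adv);
        let log_r = logp_ref[t] - logp[t]; // log(π_ref / π_θ)
        kl += log_r.exp() - log_r - 1.0; // k3 estimator, always >= 0
    }
    -pg / n + beta * kl / n
}

/// dL/dlogp_t for the loss above. The logits gradient of row t is
/// -coef_t * (probs - onehot), i.e. a per-row rescale of the CE-backward term.
fn clipped_pg_grad(logp: &[f64], logp_old: &[f64], logp_ref: &[f64], adv: f64, eps: f64, beta: f64) -> Vec<f64> {
    let n = logp.len() as f64;
    (0..logp.len())
        .map(|t| {
            let rho = (logp[t] - logp_old[t]).exp();
            let clipped = rho.clamp(1.0 - eps, 1.0 + eps);
            // Gradient flows only through the unclipped branch when it is the min.
            let d_pg = if rho * adv <= clipped * adv { rho * adv } else { 0.0 };
            let d_kl = 1.0 - (logp_ref[t] - logp[t]).exp();
            (-d_pg + beta * d_kl) / n
        })
        .collect()
}
```

The scalar analogue of the grad-check gate compares `clipped_pg_grad` against central finite differences of `clipped_pg_loss`, at a point away from the clip boundaries. Setting `eps = f64::INFINITY` and `beta = 0.0` reduces the loss to `−mean(ρ)·A`.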

**Loop (`train_grpo`):** each step does the following:
- sample B prompts and roll out G completions per prompt;
- score each completion (reward 0/1);
- compute the group-relative advantage `A=(r−mean)/(std+ε)` (no critic; all-correct and all-wrong
groups have zero advantage and are skipped);
- capture `logπθ_old`/`logπref` per token;
- run K inner clipped-PG epochs.

Rollout uses the M2 KV-cache engine with **temperature sampling** (added in M4), which produces
single-row `[1,vocab]` logits per decode step instead of the naive sampler's `[seq,vocab]`.
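The two per-group pieces of the loop can be sketched as follows. The assumptions: the standard deviation is the population one (the doc does not say which), and a group whose rewards are all equal returns `None`, meaning skipped. The sampler takes a caller-supplied uniform draw `u` because Rust's standard library has no RNG. The function names are illustrative.

```rust
/// Group-relative advantages for one prompt's G rewards. Returns None when every
/// reward is equal (all-correct or all-wrong), since the group carries no signal.
/// Assumption: population std (divide by G).
fn group_advantages(rewards: &[f64], eps: f64) -> Option<Vec<f64>> {
    let g = rewards.len() as f64;
    let mean = rewards.iter().sum::<f64>() / g;
    let var = rewards.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / g;
    if var == 0.0 {
        return None; // zero advantage everywhere: skip the group
    }
    let std = var.sqrt();
    Some(rewards.iter().map(|r| (r - mean) / (std + eps)).collect())
}

/// Sample a token id from one [vocab] row of logits at temperature `temp`,
/// given a uniform draw `u` in [0, 1) from the caller.
fn sample_with_temperature(logits: &[f32], temp: f32, u: f32) -> usize {
    // Subtract the max before exponentiating for numerical stability.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits.iter().map(|&z| ((z - max) / temp).exp()).collect();
    let total: f32 = weights.iter().sum();
    // Walk the cumulative distribution until it passes u.
    let mut acc = 0.0;
    for (i, w) in weights.iter().enumerate() {
        acc += w / total;
        if u < acc {
            return i;
        }
    }
    logits.len() - 1 // guard against rounding leaving acc slightly below 1
}
```

The advantages within a kept group sum to zero. Correct samples are pushed up and wrong ones down relative to their siblings, with no learned baseline.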

**Systems wall #1 — memory (the design doc's "two/three resident models"):** KL-leash GRPO needs
the policy plus a frozen reference. Two 1.05B fp32-master models plus AdamW m/v come to ≈ 21 GB of
fixed state. With training activations on top, the run OOMs intermittently on a 32 GB 5090
(fragmentation tips it over). To get a run that completes, `β=0` (pure PG) drops the reference model
(−4.2 GB). So the *principled* KL-leash version is memory-bound at this model size on this hardware:
a real, reported constraint, not a bug.
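The fixed-state arithmetic, as a sketch: one fp32 copy of 1.05B parameters is 1.05e9 × 4 B = 4.2 GB, so the ~21 GB total corresponds to five such copies. The breakdown in the comments (policy master, fp32 grads, AdamW m and v, reference) is an assumed accounting that matches the total. The doc itself does not itemize it.

```rust
/// Bytes in one fp32 copy of a model with `params` parameters.
fn fp32_copy_bytes(params: f64) -> f64 {
    params * 4.0
}

/// Fixed (activation-free) state in GB for a given number of fp32 copies.
fn fixed_state_gb(params: f64, copies: u32) -> f64 {
    fp32_copy_bytes(params) * copies as f64 / 1e9
}

/// Assumed accounting (hypothetical, chosen to match the ~21 GB total):
/// policy master + fp32 grads + AdamW m + AdamW v, plus the reference when beta > 0.
fn grpo_fixed_state_gb(params: f64, with_reference: bool) -> f64 {
    let policy_copies = 4; // master, grads, m, v
    let reference_copies = if with_reference { 1 } else { 0 };
    fixed_state_gb(params, policy_copies + reference_copies)
}
```

Under this accounting, `grpo_fixed_state_gb(1.05e9, true)` is 21.0 GB and `grpo_fixed_state_gb(1.05e9, false)` is 16.8 GB. That leaves roughly 11 GB versus 15 GB of a nominal 32 GB for activations and allocator slack.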

**Systems wall #2 — rollout (the design doc's "rollout is the long pole"):** the naive sampler's
growing `[seq,vocab]` allocations fragment the caching allocator over a long rollout and end in an
OOM. The cached temperature rollout (single-row logits) is lighter, but single-sequence cached decode
is slow (the M2a host round-trip), so rollout still dominates wall-clock time (~16 s/step at G=6, B=6).
Batched ragged decode (M2b) is the real fix and is deferred until it is load-bearing.

**Result (easy task, β=0, G=6, B=6, 40 steps, lr 5e-7; 150 held-out prompts, vs SFT 28/150 = 18.7%):**
- Mean rollout reward fluctuates between ~0.58 and ~0.81 (noisy, and inflated by train-set overlap in
the sampled problems).
- **Format stays at 100%** (no collapse even without the KL leash, at this gentle lr).
- **Held-out is 30/150 = 20.0%**, i.e. `+1.3 pp`. That is within the ~3 pp standard error of 150
prompts and therefore **statistically flat**: the same wall as M3 DPO.
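A slightly stricter version of the same check compares the two rates directly. It treats the SFT and GRPO evaluations as independent samples, which is an approximation: both score the same 150 prompts, and a paired test on per-prompt outcomes would be tighter. The standard error of the difference then combines both runs' variances. Here is a sketch with a hypothetical helper.

```rust
/// z-score for the change from run A to run B, each scored on `n` prompts,
/// using an unpaired two-proportion standard error (normal approximation).
fn diff_z(correct_a: u32, correct_b: u32, n: u32) -> f64 {
    let n = n as f64;
    let (pa, pb) = (correct_a as f64 / n, correct_b as f64 / n);
    // Variance of a difference of independent rates is the sum of the variances.
    let se = (pa * (1.0 - pa) / n + pb * (1.0 - pb) / n).sqrt();
    (pb - pa) / se
}
```

`diff_z(28, 30, 150)` comes out at about 0.29. That is far from the ~2 usually read as significant, which agrees with the flat reading.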

**The consistent M3+M4 lesson:** on a task where the base model lacks the underlying capability,
**neither offline preference optimization (DPO) nor online RL (GRPO) moves held-out correctness**.
Each optimizes its objective (margin / reward) on the *training distribution* it can reach (here
inflated by memorization), but neither can install a *generalizable* algorithm the model never had. RL
reinforces what the model already does; it does not teach arithmetic. M4's gate is met: the infra is
correct (PG/KL grad-checks + degenerate checks, the loop runs, reward signal + KL leash wired, format
held). The held-out flatness and the two memory/throughput walls are the reported findings. The honest
end state of the post-training arc is **a complete, correctness-gated SFT → KV-cache → DPO → GRPO
stack**: the infrastructure learned in full, with measured, honest limits on what alignment can do for
a capability the base model lacks.

### M2b — batched KV-cache decode (landed; completes the M2 engine, fixes the rollout long pole)

Built after M4 (where the rollout long pole bit hardest): decode the **G samples of one prompt in
lockstep**, one forward per step over the whole group, for G× fewer kernel launches. This is the fix
deferred from M2a.

**One new primitive:** `rope_pos(x, positions[])`, RoPE with a *per-row* absolute position (a new
forward-only kernel). It is needed because the G batched rows share one decode position, while M2a's
`rope_at` assigns `pos0 + row`, which is wrong for a batch whose rows all sit at one position.
**Gate:** bit-identical to the full RoPE for positions `[0..n]`, and to `rope_at(P)` applied per row for
a uniform `P`.
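Here is a CPU sketch of the difference between the two position schemes. It assumes the interleaved-pair RoPE convention and base 10000; the actual kernel's pairing and base may differ, and the functions are illustrative, not the `xtrain` kernels. Both gates become direct equalities. Per-row positions `0..n` reproduce `rope_at(0)`, and a uniform position `P` reproduces `rope_at(P)` applied to each row separately.

```rust
/// Rotate one row of `hd` features in place at absolute position `pos`
/// (assumed interleaved pairs (x[2i], x[2i+1]) and base 10000).
fn rope_row(row: &mut [f32], pos: usize) {
    let hd = row.len();
    for i in 0..hd / 2 {
        let freq = 10000f32.powf(-2.0 * i as f32 / hd as f32);
        let (s, c) = (pos as f32 * freq).sin_cos();
        let (a, b) = (row[2 * i], row[2 * i + 1]);
        row[2 * i] = a * c - b * s;
        row[2 * i + 1] = a * s + b * c;
    }
}

/// M2a-style: row r of the [rows, hd] block sits at position pos0 + r.
fn rope_at(x: &mut [f32], hd: usize, pos0: usize) {
    for (r, row) in x.chunks_mut(hd).enumerate() {
        rope_row(row, pos0 + r);
    }
}

/// M2b-style: row r sits at positions[r] (all equal during batched decode).
fn rope_pos(x: &mut [f32], hd: usize, positions: &[usize]) {
    for (row, &p) in x.chunks_mut(hd).zip(positions) {
        rope_row(row, p);
    }
}
```

Applying `rope_at(x, hd, P)` to the whole G-row batch would instead rotate row `r` to `P + r`. That is the M2a behavior that breaks batched decode.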

**Engine (`generate_cached_batch`):** `BatchKVCache` carries a G dimension. It is accumulated on the
host as `[T, G·num_kv, hd]`, then laid out as `[G·num_kv, T, hd]`. The batched `decode_step` threads G
through embed / projections / QK-norm / `rope_pos` / cache. Two M2a pieces drop in unchanged:
`decode_attention` is already batch-agnostic (`bh = G·nh`), and `repeat_kv(nh, batch=G)` broadcasts per
group. There is no finished-mask yet (all G rows generate `max_new` tokens; the caller cuts at EOS) and
no ragged-length prompts; both are perf-only follow-ups.
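The indexing that has to stay consistent can be sketched on the host. There are two pieces: the `[T, BH, hd]` → `[BH, T, hd]` relayout, and the GQA mapping from a query row of the `G·nh` batch to its K/V row among `G·num_kv`. The sequence-major ordering of the batched head dimension (`row = g·heads + h`) is an assumption, as are the helper names.

```rust
/// Host-side relayout: [T, BH, hd] (appended one token at a time) -> [BH, T, hd].
fn transpose_t_bh(src: &[f32], t: usize, bh: usize, hd: usize) -> Vec<f32> {
    let mut dst = vec![0.0; src.len()];
    for ti in 0..t {
        for b in 0..bh {
            for d in 0..hd {
                dst[(b * t + ti) * hd + d] = src[(ti * bh + b) * hd + d];
            }
        }
    }
    dst
}

/// K/V row read by query row `q` when G sequences are batched under GQA:
/// q = g*nh + h maps to g*num_kv + h / (nh / num_kv).
fn kv_row_for_query(q: usize, nh: usize, num_kv: usize) -> usize {
    let (g, h) = (q / nh, q % nh);
    g * num_kv + h / (nh / num_kv)
}
```

Consider the `tests/decode_batch.rs` shape (8 query / 2 KV heads). Query rows `8g..8g+8` only ever read K/V rows `2g` and `2g+1`, which is the no-cross-row-contamination property the gate pins.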

**Gate (token-identical):** all G **greedy** rows are byte-identical to the single-sequence decode
(`tests/decode_batch.rs`, 8 query / 2 KV heads, so the `repeat_kv` batching is exercised). This pins
that G-way batching indexes each sequence's K/V with no cross-row contamination.

**Throughput (v12 1.05B, G=6, B=6, easy task, rollout wired into `train_grpo`):** ~8.5 s/step vs
~14–16 s/step for the single-seq cached rollout, a **~1.7×** step speedup including rollout. It falls
short of the full G× for two reasons:
- the per-token-logp forwards and the PG update also take time;
- the M2a per-layer **host round-trip** is still there (now carrying G× the data in one transfer, not
removed).

The full device-side cache (no host round-trip) is the remaining decode-engine optimization.
Batching also **stabilises memory**: one batched forward per step instead of G separate allocations
that fragmented the caching allocator (the M4 OOM). So M2b closes the decode-engine milestone (M2a
single-seq + M2b batched). It turns the rollout long pole from "OOM/unbounded" into a bounded,
measured ~1.7× win, with the device cache as the named next lever.

### M2c — device-side KV cache (landed; the bottleneck moved, a profile-first finding)

The named M2b follow-up: keep K/V on the GPU (`[bh,T,hd]`, an `Option<Tensor>` per layer) and grow it
by one token per step via a new `cat_seq` kernel (concat along the seq dim). This removes both the
M2a/M2b per-layer **host round-trip** (`to_cpu`/`from_slice`/re-upload) *and* the `transpose_3d01`.
Single-seq and batched decode were both refactored onto it, which is cleaner than the host `Vec` +
rebuild.
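The following is a host-side reference for what `cat_seq` computes, usable as the oracle in a `cat_seq == host concat` check. It is an illustrative sketch over flat `f32` buffers, not the kernel. During decode `tb = 1`, so each head's existing `T` rows are followed by the new token's row.

```rust
/// Concatenate along the sequence dim: `a` is [bh, ta, hd], `b` is [bh, tb, hd],
/// the result is [bh, ta + tb, hd].
fn cat_seq(a: &[f32], b: &[f32], bh: usize, ta: usize, tb: usize, hd: usize) -> Vec<f32> {
    assert_eq!(a.len(), bh * ta * hd);
    assert_eq!(b.len(), bh * tb * hd);
    let mut out = Vec::with_capacity(bh * (ta + tb) * hd);
    for h in 0..bh {
        // Head h: its existing rows, then its new rows.
        out.extend_from_slice(&a[h * ta * hd..(h + 1) * ta * hd]);
        out.extend_from_slice(&b[h * tb * hd..(h + 1) * tb * hd]);
    }
    out
}
```

Like any concat, each append in this sketch copies the whole existing cache, so per-step cost grows with `T`. A preallocated `[bh, T_max, hd]` buffer written at an offset is the usual way to avoid that, should it ever show up in a profile.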

**Gates hold:** `cat_seq` == host concat; single-seq `decode_kv` and G-way `decode_batch` are both
still **token-identical**; the GQA training path is unaffected.

**The finding (why this is a measure-first lesson, not a speedup story):** removing the host
round-trip buys **~10%** on *pure* single-seq decode (133 → 147 tok/s @128) but **does not move the
GRPO step** (~8.5 s/step, unchanged). The explanation: after M2b batching, the rollout is no longer the
step's bottleneck. The per-sample **`per_token_logp` captures** (2 forwards/sample) and the
**PG-update** forwards+backwards (`model.forward`, full-sequence, per sample) now dominate. So the
long pole **shifted** from the rollout to the training-side forwards (cf. T11/T17/M2a: profile
before optimizing; the bottleneck you fixed is not the one that remains). The device cache is
still a real, correctness-gated improvement (cleaner code, less PCIe traffic, ~10% decode). The honest
headline is that the *next* lever is **ragged batched prefill of the per-sample forwards**, not the
cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache), all
token-identical-gated; the post-training stack remains complete with its bottleneck mapped.

### M2d — batch the GRPO training-side forwards (landed; the lever M2c named, + a decomposition correction)

M2c named the next lever: **ragged batched prefill of the per-sample training-side forwards**. Those
forwards are the two phases that, per step, run one single-sequence `forward` per sample:
- the `per_token_logp` **captures** (logπ_old policy + logπ_ref reference);
- the inner **clipped-PG** forward/backward passes.

M2d packs all `N = B·G` ragged samples of a step into ONE `forward_batched`.

**The enabling property — right-padding is free under causal attention.** Pad each ragged completion
on the RIGHT to the batch's `Lmax`. Every real token sits at an earlier position than the trailing
pad, and causal masking forbids attending forward. So the logits of every real row are mathematically
identical to the unpadded single-sequence forward (in practice, equal up to floating-point reduction
order). The pad rows are garbage but masked out (`target = -100`). This is exactly why training engines
pad-and-mask rather than run ragged. Two new pieces:
- `per_token_logp_batched` (`crates/xtrain-train/src/grpo_batch.rs`): right-pad → one
`forward_batched(batch = N)` → slice each sample's logπ back to its real length.
- `ops::clipped_pg_loss_batched` (`crates/xtrain-autodiff/src/ops.rs`): like the per-sample
`clipped_pg_loss`, but takes a **per-row** `advantage[t]` (the owning sample's `A`) and a **per-row**
`weight[t]` (the full normaliser; the caller passes `1/(N·n_s)`, with `n_s` the completion length of
sample `s`). It does NOT compute its own `1/n_tokens`, so folding in `weight = 1/(N·n_s)` reproduces the
looped `Σ_s (1/N)(1/n_s)…` exactly up to floating-point summation order (the per-row CE backward is
row-local). A `--micro` knob packs in chunks to bound the `[chunk·Lmax, vocab]` logits memory. The
weight uses the GLOBAL `N`, so chunked grad-accumulation is exact.

Both `train_grpo` and the bench call these shared helpers.
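A toy run shows why right-padding is free. It uses single-head causal attention on the CPU, applied once to the real rows and once to a right-padded copy. The masked scores are `−inf`, so pad positions get exactly zero weight, and in this sketch the real rows come out bit-identical. A GPU kernel can still differ in the last bits, as the fp32 gate shows. Presumably that is because the padded shape changes tiling and reduction order. The sketch also makes one assumption visible: pad rows must hold finite values, since a zero weight times a NaN is still NaN. The names are illustrative.

```rust
/// Single-head causal attention over `l` rows of dimension `d` (row-major q, k, v).
fn causal_attention(q: &[f64], k: &[f64], v: &[f64], l: usize, d: usize) -> Vec<f64> {
    let scale = 1.0 / (d as f64).sqrt();
    let mut out = vec![0.0; l * d];
    for i in 0..l {
        // Row i may attend to positions 0..=i only; later positions get -inf.
        let scores: Vec<f64> = (0..l)
            .map(|j| {
                if j > i {
                    f64::NEG_INFINITY
                } else {
                    (0..d).map(|c| q[i * d + c] * k[j * d + c]).sum::<f64>() * scale
                }
            })
            .collect();
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let w: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect(); // exp(-inf) = 0
        let z: f64 = w.iter().sum();
        for j in 0..l {
            for c in 0..d {
                out[i * d + c] += w[j] / z * v[j * d + c];
            }
        }
    }
    out
}

/// Right-pad a row-major [len, d] buffer to [lmax, d] with a finite filler value.
fn right_pad(x: &[f64], d: usize, lmax: usize, fill: f64) -> Vec<f64> {
    let mut p = x.to_vec();
    p.resize(lmax * d, fill);
    p
}
```

Slicing the first `len · d` values of the padded output back out is the analogue of `per_token_logp_batched` cutting each sample's logπ to its real length.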

**Correctness gates (tight fp32 tolerances, not bf16-noisy):**
- `xtrain-model::forward_batched_ragged_matches_looped`: `forward_batched` on right-padded ragged
sequences matches the per-sequence single-seq forward on the real rows, **max|Δlogit| = 3.7e-7 (fp32)
and 0.0 (bf16)**, for both the composed and the flash paths. Pins "right-pad is free".
- `xtrain-autodiff::clipped_pg_loss_batched_matches_looped`: batched op == looped
`Σ_s (1/N)·clipped_pg_loss_s`, **loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32)**.

Together, these establish that the batched GRPO step matches the looped step to within fp32 rounding.
End-to-end: a short SFT (v12 base, 150 steps, arithmetic task) followed by 12 steps of `train_grpo`
runs clean. There is **no OOM** (1B master + AdamW + batched activations fit with `micro=16`), mean
reward rises, and the batched inner loop executes.
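The normalisation argument can be checked with scalar per-token losses standing in for the real terms. In the sketch, `looped` is the reference `Σ_s (1/N)(1/n_s) Σ_t ℓ_{s,t}`. `packed_chunked` gives each packed row the weight `1/(N·n_s)` and sums chunk by chunk, as micro-batched grad-accumulation would. The helper names are hypothetical; only the weighting scheme comes from the `clipped_pg_loss_batched` description.

```rust
/// Looped reference: (1/N) * sum over samples of (mean per-token loss of that sample).
fn looped(per_token: &[Vec<f64>]) -> f64 {
    let n = per_token.len() as f64;
    per_token
        .iter()
        .map(|s| s.iter().sum::<f64>() / s.len() as f64 / n)
        .sum()
}

/// Packed: every row carries weight 1/(N * n_s); samples are processed in chunks of
/// `micro` and accumulated into one total. `global_n` must be the step's total N.
fn packed_chunked(per_token: &[Vec<f64>], micro: usize, global_n: usize) -> f64 {
    let mut total = 0.0;
    for chunk in per_token.chunks(micro) {
        let rows = chunk.iter().flat_map(|s| {
            let w = 1.0 / (global_n as f64 * s.len() as f64); // weight[t] for this sample's rows
            s.iter().map(move |&l| w * l)
        });
        total += rows.sum::<f64>(); // plain weighted sum: no per-chunk normalisation
    }
    total
}
```

Suppose each chunk were normalised as if it were the whole step, i.e. `packed_chunked(chunk, micro, chunk.len())` per chunk. That inflates every sample's weight and breaks the equivalence as soon as there is more than one chunk, which is why the weight uses the global `N`.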

**Throughput (bench `bin/bench_grpo_batch`, v12 1.05B, N=48 ragged, micro=16, β=0, weight-independent):**

| phase (per step)         | looped (single-seq) | batched (M2d) | speedup  |
|--------------------------|---------------------|---------------|----------|
| capture `per_token_logp` | 622 ms              | 71 ms         | 8.7×     |
| inner clipped-PG fwd+bwd | 1907 ms             | 208 ms        | 9.2×     |
| **training forwards**    | **2526 ms**         | **280 ms**    | **9.0×** |

**The decomposition correction (the honest finding).** M2c claimed "the per-sample training forwards
now dominate the step." The clean per-component bench falsifies the strong form. The training
forwards were **~2.5 s of the ~8.5 s step (~30%)**: substantial, and worth the 9× win. But the
**rollout (`generate_cached_batch`, ~6 s) was always the larger share.** After M2d cuts the training
forwards to ~0.28 s, the step is **~95% rollout**: the rollout, already the long pole, is now nearly
the whole step. So M2d removes the training-forward overhang (a real 9× on its component, gated
against the looped step) and confirms the measure-first lesson once more. The next **step-level**
lever is **full B×G rollout batching**. Today only the `G` samples of each prompt decode in lockstep
(M2b); the `B` prompts are still sequential. M2d closes the "ragged batched per-sample forwards"
lever M2c named. The post-training stack stays complete, now with the step decomposition measured,
not asserted.
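The step-level arithmetic is Amdahl's law. The sketch below uses the section's numbers: an 8.5 s step, with the training forwards going from 2526 ms to 280 ms. It makes the simplifying assumption that everything outside the training forwards, essentially the ~6 s rollout, is unchanged by M2d. The struct is illustrative.

```rust
/// Step-time decomposition before/after speeding up one component (Amdahl's law).
/// All times in seconds; the rest of the step is assumed unchanged.
struct Decomposition {
    step: f64,
    component_before: f64,
    component_after: f64,
}

impl Decomposition {
    /// Time outside the optimized component (here: mostly rollout).
    fn rest(&self) -> f64 {
        self.step - self.component_before
    }
    fn new_step(&self) -> f64 {
        self.rest() + self.component_after
    }
    fn component_share_before(&self) -> f64 {
        self.component_before / self.step
    }
    fn rest_share_after(&self) -> f64 {
        self.rest() / self.new_step()
    }
    fn step_speedup(&self) -> f64 {
        self.step / self.new_step()
    }
    /// Upper bound if the component took zero time.
    fn max_step_speedup(&self) -> f64 {
        self.step / self.rest()
    }
}
```

With `Decomposition { step: 8.5, component_before: 2.526, component_after: 0.280 }`:
- the training forwards' share before M2d is ~30%;
- the rollout share after is ~95%;
- the step speedup is roughly 1.36×.

Even an infinitely fast training side would cap the step at about 8.5 / 5.97 ≈ 1.42×, which is why the next step-level lever has to be the rollout.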