xtrain/docs/18-post-training-rl-sft.md
Gahow Wang 99090465bf docs: M3 — DPO results (infra correct, held-out correctness flat, over-optimization collapse)
Implementation log (docs/18) + Phase-3 row (evolution.md): the two ops + gates,
pair-gen (gold chosen / sampled-wrong rejected), reference-logprob caching, the
training loop, and the honest finding — reward margin + pref-acc rise but
held-out arithmetic correctness stays ~5-8% (flat within std-error) and
over-optimizes to collapse (margin +34 → 0% format). DPO reweights, it does not
install the capability; motivates M4 GRPO (optimize the verifiable reward online).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-30 12:38:06 +08:00


Phase: Post-Training Infra — SFT / DPO / Reward Model / GRPO — Design Document

Status: DESIGN, decisions locked, pending go-ahead to implement. Nothing implemented yet. This doc proposes the scope, the staged build, the new infra pieces, and the correctness gates for a standard post-training stack on top of the xtrain training framework. Decisions D1-D4 are resolved (see "Resolved decisions"): DPO → GRPO (reward model optional) · rule-based/verifiable reward · KV-cache decode engine built up front · a verifiable task as the optimization/eval target.

Goal

Build a standard, from-scratch post-training infrastructure — the systems layer that turns a pretrained base LM into an aligned chat model — and use it to run chat alignment. The deliverable that matters here is the infra and the lessons, not the end-to-end chat quality (see the project's learning-axis framing). Each stage should teach exactly one new post-training systems concept and ship with a hard correctness gate, matching the Phase-1/Phase-2 culture (grad-checks, PyTorch parity, bit-identical default paths, profile-first).

Concretely we want to be able to answer, with our own code:

  • How does offline preference optimization (DPO) differ from SFT in the training loop — what is the reference model, why two forwards, what is the loss?
  • How does a reward model turn preferences into a scalar signal?
  • How does online RL (GRPO) actually run — the rollout engine, reward scoring, group-relative advantage, the clipped policy-gradient update, the KL leash?
  • Where are the memory and throughput pressure points that make post-training infra different from pretraining infra (multiple models resident, generation in the loop)?

Baseline: what already exists vs. what is missing

What the framework already gives us (verified in code, reused as-is):

| capability | where | reuse for post-training |
| --- | --- | --- |
| batched forward → logits [B*S, vocab] | model.rs::forward_batched | logprob extraction for DPO/RM/GRPO |
| cross-entropy with ignore-index -100 | ops.rs::cross_entropy, nn.cu | assistant-only / completion-only masking |
| assistant-only SFT (TSV, masked labels) | data.rs::load_sft_tsv_cached (commit fbf4ac2) | SFT chat baseline = DPO init + reference |
| bf16 mixed precision, fp32 master | with_compute_dtype | policy + frozen reference both bf16 compute |
| recompute / flash / grad-accum | with_recompute / with_flash / --accum-steps | bound activation memory with 2-3 models resident |
| DDP (thread + process-per-GPU) | xtrain-distributed | data-parallel post-training |
| AdamW + clip + LR sched + checkpoint | xtrain-optim, checkpoint.rs, schedule.rs | unchanged optimizer path |
| single-seq greedy/temperature sampling | sample.rs::generate | slow rollout fallback (no KV cache) |

What is missing and must be built (these are the actual lessons):

  1. Per-sequence completion logprob — a way to read Σ log πθ(y_t | x, y_<t) over the completion tokens of a sequence. CE gives a mean scalar; DPO/GRPO need a per-sequence masked sum. New op or thin wrapper over the CE per-row machinery.
  2. Frozen reference model held in memory alongside the trainable policy (no grad, no optimizer), or its logprobs precomputed and cached.
  3. Pairwise preference loss (DPO) and Bradley-Terry ranking loss (RM).
  4. Reward head — a [dim,1] scalar head reading the last non-pad position (RM only).
  5. Rollout / generation engine — batched autoregressive sampling. Current generate is single-sequence and re-runs the full forward each step (no KV cache). Online RL needs batched rollouts; a real KV-cache incremental-decode engine is the centerpiece infra build.
  6. GRPO machinery — group sampling, group-relative advantage, clipped PG loss, KL penalty, the actor-learner loop.

The post-training landscape — where the infra lives

                 data                 models in memory        new systems concept
  SFT      (prompt, answer)           policy                  loss masking (have it)
  DPO      (prompt, chosen, reject)   policy + ref(frozen)    dual forward, pairwise logσ loss
  RM       (prompt, chosen, reject)   reward model            scalar head, ranking loss
  PPO      prompts + reward source    policy+ref+RM+critic    rollout + GAE + clipped PG  (4 models)
  GRPO     prompts + reward source    policy+ref(+RM)         rollout + group baseline + clipped PG

The pedagogical ladder is SFT → DPO → (RM) → GRPO. DPO is the cheapest "real" alignment method (no generation, no reward model, reuses the training loop almost verbatim) and is the right first rung. GRPO is chosen over PPO as the online-RL rung because it drops the value critic (group-relative advantage replaces the learned baseline) — that removes a whole model and the GAE machinery while still teaching the complete online-RL loop. PPO is noted as an optional later extension, not a primary target.

DECISION D1 (scope/sequencing) — LOCKED: P0 → P1(DPO) → P3(GRPO), P2(reward model) optional. With D3 locked to "KV-cache engine up front", the engine becomes a foundational milestone that both DPO pair-generation and GRPO rollouts sit on. Effective build order: P0 → KV-cache decode engine → P1(DPO) → P3(GRPO) → P2(optional) (see "Milestones").

Stage P0 — SFT chat baseline (light; mostly reuse)

Goal: a clean SFT checkpoint to serve as both the DPO/GRPO init and the frozen reference. With D4 = verifiable task, P0 SFT teaches the task format (e.g. arithmetic prompts → a parseable answer such as \boxed{N}) so the model emits checker-readable completions; the same template is reused by rollout and eval. The current SFT (commit fbf4ac2) already does single-turn assistant-only masking; P0 only adds what alignment needs:

  • a fixed chat template (the User:/Assistant: + <|endoftext|> format already used, promoted to a documented constant shared by SFT data prep, rollout, and eval),
  • optional multi-turn masking (supervise every assistant turn, mask user turns),
  • optional sequence packing (concatenate examples to fill seq, reset attention/RoPE per example — note forward_batched already isolates sequences, so packing = careful index bookkeeping, not new attention code).

Gate: masking unit test (only assistant tokens contribute to loss); packing does not leak loss across example boundaries. Hypothesis: a documented chat template + multi-turn mask gives a reproducible SFT reference without changing the training numerics for single-turn data (bit-identical to fbf4ac2 on single-turn input).

Stage P1 — DPO (offline preference optimization) first real method

New infra:

  1. Preference data — constructed from the verifiable checker (D4). On a verifiable task there is no off-the-shelf preference set, so we build pairs: sample several completions per prompt from the P0 SFT model (using the KV-cache engine built in the prior milestone), score each with the rule-based checker, take a correct completion as chosen and an incorrect one as rejected. This is a one-time offline data-prep step; DPO training itself is then static. Tokenize each as template(prompt) + completion + EOS; build a completion mask (prompt = masked).
  2. seq_logprob(logits, target_ids, mask) → [B]: per-sequence sum of log softmax(logits)[target] over masked positions. Implement by reusing the CE per-row path (CE per-row = log πθ(target)), summing per_row over the mask. Add a grad-checked op so the backward is exact.
  3. Frozen reference πref: load the SFT checkpoint into a second model in eval/no-grad bf16. Its logprobs are constants in the loss. Optimization to teach: precompute and cache reference logprobs once over the dataset → the reference model need not stay resident during training (one model in memory, like SFT).
  4. DPO loss (Rafailov et al.): with Δ = β[(log πθ(y_w|x) − log πref(y_w|x)) − (log πθ(y_l|x) − log πref(y_l|x))], L = −log σ(Δ). Only πθ terms carry gradient.
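As a rough illustration of items 2 and 4, the per-sequence masked logprob and the DPO loss can be sketched on the host in plain Rust. This is only a sketch under stated assumptions: it works on `f64` slices instead of the framework's tensors, the function bodies are illustrative rather than the project's implementation, and the helper names `log_softmax` and `softplus` are hypothetical. It uses L = −log σ(Δ), which is the form that gives log 2 when policy and reference agree.

```rust
/// Label value marking positions that must not contribute (prompt, padding).
const IGNORE: i64 = -100;

/// Numerically stable log-softmax of one row of logits.
fn log_softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let lse = max + row.iter().map(|&z| (z - max).exp()).sum::<f64>().ln();
    row.iter().map(|&z| z - lse).collect()
}

/// Sum of log pi(target_t) over the unmasked positions of ONE sequence.
/// `logits[t]` is the vocab-sized row that predicts `targets[t]`.
fn seq_logprob(logits: &[Vec<f64>], targets: &[i64]) -> f64 {
    let mut total = 0.0;
    for (row, &t) in logits.iter().zip(targets) {
        if t != IGNORE {
            total += log_softmax(row)[t as usize];
        }
    }
    total
}

/// Stable softplus: ln(1 + e^x).
fn softplus(x: f64) -> f64 {
    if x > 0.0 {
        x + (-x).exp().ln_1p()
    } else {
        x.exp().ln_1p()
    }
}

/// DPO loss for one (chosen, rejected) pair: -log sigma(delta), where
/// delta = beta * [(lp_chosen - ref_chosen) - (lp_rejected - ref_rejected)].
/// The reference logprobs are constants; only the policy logprobs carry gradient.
fn dpo_loss(lp_chosen: f64, lp_rejected: f64, ref_chosen: f64, ref_rejected: f64, beta: f64) -> f64 {
    let delta = beta * ((lp_chosen - ref_chosen) - (lp_rejected - ref_rejected));
    softplus(-delta)
}
```

With policy equal to reference the loss evaluates to ln 2 regardless of the logprob values, which matches the degenerate check listed in the gates.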

Memory: policy (fp32 master + Adam m/v + bf16 + grads) + reference (bf16 only, or cached logprobs → zero). Recompute + accum keep activations bounded; 1B fits 32 GB comfortably.

Correctness gates:

  • seq_logprob finite-difference grad-check (tiny model).
  • DPO-loss + grad PyTorch parity (the project's standard gate).
  • Degenerate checks: πθ == πref at init ⇒ Δ = 0, L = log 2, implicit reward 0; β → 0 ⇒ gradient → 0.
  • Health metric: chosen−rejected reward margin rises over training; accuracy (margin > 0) increases. Reported, not just loss (the doc-13 lesson: val/loss alone is not a sufficient signal).

Application: chat alignment via DPO on English preference pairs. This is the offline chat-alignment deliverable.

Stage P2 — Reward model (Bradley-Terry) — OPTIONAL

DECISION D2 (reward source) — LOCKED: rule-based / verifiable reward first. GRPO brings up on the deterministic checker; a learned reward model is deferred/optional (only if we later want general-chat GRPO). So this whole stage is optional and not on the critical path.

New infra: a scalar reward head ([dim,1]) reading the hidden state at the last non-pad position; ranking loss −log σ(r(x,y_w) − r(x,y_l)). Reuses the preference data and the dual-sequence forward from P1.
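A minimal sketch of the Bradley-Terry ranking loss and the pairwise-accuracy gate, assuming the reward head has already produced one scalar per completion. The function names here are hypothetical and the code is plain host-side Rust, not the framework's autograd op.

```rust
/// Stable softplus: ln(1 + e^x).
fn softplus(x: f64) -> f64 {
    if x > 0.0 {
        x + (-x).exp().ln_1p()
    } else {
        x.exp().ln_1p()
    }
}

/// Bradley-Terry ranking loss for one pair: -log sigma(r_chosen - r_rejected).
fn ranking_loss(r_chosen: f64, r_rejected: f64) -> f64 {
    softplus(-(r_chosen - r_rejected))
}

/// Fraction of (chosen, rejected) reward pairs ranked correctly (r_w > r_l).
fn pairwise_accuracy(pairs: &[(f64, f64)]) -> f64 {
    if pairs.is_empty() {
        return 0.0;
    }
    let correct = pairs.iter().filter(|p| p.0 > p.1).count();
    correct as f64 / pairs.len() as f64
}
```

Equal rewards give a loss of ln 2, and the loss shrinks as the chosen reward pulls ahead of the rejected one.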

Gates: ranking-loss grad-check; held-out pairwise accuracy (r_w > r_l); a frozen RM loads/serves the scalar correctly.

Stage P3 — GRPO (online RL, critic-free) the deep infra lesson

This is the centerpiece. It introduces generation inside the training loop.

(a) Rollout / generation engine — built up front (its own milestone).

DECISION D3 (rollout depth) — LOCKED: build the KV-cache incremental-decode engine up front, as a foundational milestone before DPO/GRPO, rather than starting naive. It is then the shared substrate for DPO pair-generation and GRPO rollouts. Tradeoff accepted: front-loads the single hardest build and delays the first alignment result, in exchange for a real generation engine and a clean, isolated infra lesson.

The engine: per-layer K/V cache, single-token incremental forward (process the prompt once to fill the cache, then decode one token at a time), batched ragged decode (B prompts × G samples; sequences hit EOS at different lengths → finished-mask / left-padding / compaction). The current attention assumes a full causal window over seq; incremental decode needs a decode-time attention path — query length 1 against cached K/V of length t, with RoPE position = t. This reuses the composed SDPA shapes (one-row query), so it can land as a distinct code path without disturbing the training attention (flash/GQA/composed unchanged).

Hard gate (the centerpiece correctness lesson): KV-cache decode == full-recompute decode, token-identical greedy output — the same byte-/token-identical discipline the project uses for the xserv export closed loop. A throughput baseline (decode tokens/s, cache-fill vs. per-token decode) is recorded here, before any rollout optimization (profile-first).

(b) Reward scoring. Rule-based verifiable reward first (e.g., exact-match on a synthetic arithmetic/format task) or RM from P2. Returns a scalar per completion.

(c) Group-relative advantage. Sample G completions per prompt; advantage A_i = (r_i − mean(r_group)) / (std(r_group) + ε). No critic, no GAE.
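The group baseline is small enough to sketch directly. The version below assumes the population standard deviation (dividing by G, not G − 1), which the text does not specify, and the function name is hypothetical.

```rust
/// Group-relative advantages for the G completions of ONE prompt:
/// A_i = (r_i - mean) / (std + eps).
/// Assumption: population std (divide by G); the doc does not say which variant is used.
fn group_advantages(rewards: &[f64], eps: f64) -> Vec<f64> {
    let n = rewards.len() as f64;
    let mean = rewards.iter().sum::<f64>() / n;
    let var = rewards.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / n;
    let std = var.sqrt();
    rewards.iter().map(|r| (r - mean) / (std + eps)).collect()
}
```

With a binary checker reward, a group where every completion is wrong (or every one is right) yields all-zero advantages, so such prompts contribute no policy-gradient signal. The same holds for G = 1.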

(d) Clipped policy-gradient loss with KL leash. Per completion token, ρ_t = exp(log πθ_t − log πθ_old_t) (old = policy at rollout time), token loss −min(ρ_t·A, clip(ρ_t, 1±ε)·A) + β·KL(πθ‖πref), masked to completion tokens. KL via the k3 estimator.
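One way to read the per-token loss is the following host-side sketch. It assumes the loss is the negated clipped surrogate plus the β-weighted KL term, and it takes the k3 estimator to be r − 1 − ln r with r = πref/πθ evaluated on the sampled token. All names are hypothetical.

```rust
/// k3 estimator of KL(pi_theta || pi_ref) on one sampled token:
/// r - 1 - ln r with r = pi_ref / pi_theta. Always >= 0.
fn k3_kl(lp_new: f64, lp_ref: f64) -> f64 {
    let log_r = lp_ref - lp_new;
    log_r.exp() - 1.0 - log_r
}

/// Per-token GRPO loss (to be minimised):
/// -min(rho * A, clip(rho, 1 - eps, 1 + eps) * A) + beta * KL.
fn grpo_token_loss(lp_new: f64, lp_old: f64, lp_ref: f64, adv: f64, clip_eps: f64, beta: f64) -> f64 {
    let rho = (lp_new - lp_old).exp();
    let clipped = rho.clamp(1.0 - clip_eps, 1.0 + clip_eps);
    let surrogate = (rho * adv).min(clipped * adv);
    -surrogate + beta * k3_kl(lp_new, lp_ref)
}

/// Mean of the token losses over completion positions only (mask == true).
fn masked_mean(losses: &[f64], mask: &[bool]) -> f64 {
    let mut sum = 0.0;
    let mut count = 0usize;
    for (&l, &m) in losses.iter().zip(mask) {
        if m {
            sum += l;
            count += 1;
        }
    }
    if count == 0 { 0.0 } else { sum / count as f64 }
}
```

The `min` makes the clip one-sided: a positive advantage stops being rewarded once ρ exceeds 1 + ε, while a negative advantage with a large ρ is still penalised at the unclipped value.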

(e) Actor-learner loop. sample prompt batch → rollout G each → score → advantage → capture πθ_old logprobs → K inner epochs of clipped PG updates → repeat. Reference πref fixed throughout.

Memory: policy + reference (+ RM if learned). Each 1B; recompute + accum bound activations. Throughput note: rollout (generation) will dominate wall-clock — a baseline must be recorded (tokens/s of generation vs. update) before any rollout optimization, per the project's profile-first rule.

Correctness gates:

  • PG-loss finite-diff grad-check.
  • Degenerate checks: G = 1 ⇒ advantage 0 ⇒ no PG signal, only KL; ε → ∞ ⇒ vanilla PG; β = 0 ⇒ no KL term.
  • (KV-cache decode token-identical to full-recompute is gated in the engine milestone, a prerequisite of GRPO.)
  • Synthetic RL overfit: on a tiny verifiable task with a known optimum, mean reward must rise to the optimum (the RL analogue of T5's "overfit 27/27" — a hard, falsifiable signal that the loop is correct, independent of fuzzy chat quality).

Evaluation

  • Offline (DPO/RM): reward margin, preference accuracy, KL drift from reference, plus the fixed chat-prompt generation suite (scripts/chat_alpha_fixed_prompts.txt) judged before/after. This reuses and extends the doc-13 recommendation for a generation-based eval harness (exact-match math, code syntax, stop-token, refusal appropriateness, corruption).
  • Online (GRPO): mean reward curve, KL-to-reference, response length, the verifiable-task pass rate, and the same fixed-prompt suite.
  • Selection by generation eval, not loss — the recurring doc-13/v11 lesson: lower post-training loss did not mean better generations.

Memory & throughput budget (8× RTX 5090, 1.05B model, indicative)

  • Params (bf16) ~2.1 GB; fp32 master ~4.2 GB; AdamW m/v ~8.4 GB; grads ~2.1 GB → policy optimizer state alone ~17 GB before activations. Recompute + grad-accum keep activations small; this is why post-training reuses the Phase-1/2 memory levers unchanged.
  • DPO: + reference (bf16 ~2.1 GB, or 0 if logprobs cached). Fits.
  • GRPO: + reference (~2.1 GB) (+ RM ~2.1 GB if learned). Fits; rollout activations are the new variable. Generation, not the update, is expected to be the throughput bottleneck — to be measured, not assumed.
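The budget figures above follow from bytes-per-parameter arithmetic, sketched below. The sketch assumes decimal gigabytes, fp32 AdamW moments, and bf16 gradients (the combination that reproduces the ~2.1 / 4.2 / 8.4 / 2.1 GB split). The struct and function names are hypothetical.

```rust
/// Per-component memory of the trainable policy, in decimal GB.
struct PolicyMemory {
    weights_bf16_gb: f64,
    master_fp32_gb: f64,
    adam_mv_fp32_gb: f64,
    grads_bf16_gb: f64,
}

impl PolicyMemory {
    fn total_gb(&self) -> f64 {
        self.weights_bf16_gb + self.master_fp32_gb + self.adam_mv_fp32_gb + self.grads_bf16_gb
    }
}

/// Assumptions: bf16 = 2 bytes, fp32 = 4 bytes, AdamW keeps two fp32 moments
/// (m and v) per parameter, gradients are stored in bf16.
fn policy_memory(params: f64) -> PolicyMemory {
    let gb = |bytes_per_param: f64| params * bytes_per_param / 1e9;
    PolicyMemory {
        weights_bf16_gb: gb(2.0),
        master_fp32_gb: gb(4.0),
        adam_mv_fp32_gb: gb(8.0),
        grads_bf16_gb: gb(2.0),
    }
}
```

For 1.05e9 parameters this gives about 16.8 GB before activations, and each additional frozen bf16 model (reference or reward model) adds roughly another 2.1 GB.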

Correctness-gate philosophy (unchanged from Phase 1/2)

Every stage ships: (1) a finite-difference grad-check on the new loss/op, (2) PyTorch parity on loss + grads where applicable, (3) explicit degenerate-case checks (β→0, G=1, ε→∞, ref==policy), (4) a falsifiable "it actually learns" signal (reward margin up / synthetic RL overfit), and (5) no change to the default training path when post-training flags are off. New CUDA kernels (if any, e.g. decode-time attention) get the same fwd/bwd-vs-reference gates as flash/GQA.

Risks & tradeoffs

  • Rollout engine is the long pole. A correct KV-cache incremental-decode path is a real build (decode-time attention, ragged batch). Mitigation (per D3): build it up front as an isolated, separately-gated milestone, with the naive full-recompute sampler kept as the token-identical reference.
  • RL is finicky. KL leash, advantage normalization, clip range, reward hacking. Mitigation: synthetic verifiable task with a known optimum as the bring-up gate before any real chat reward.
  • Reward-model noise can mislead GRPO. Mitigation: rule-based reward first.
  • Tokenizer (KI-4) — gpt2 50257 vocab is kept for the xserv closed loop; unchanged here.
  • Two/three resident models raise memory; bounded by recompute/accum and (for DPO) reference logprob caching.

Resolved decisions (aligned 2026-06-29)

  • D1 — Scope & sequencing → DPO → GRPO, reward model optional.
  • D2 — Online-RL reward source → rule-based / verifiable reward first (RM deferred/optional).
  • D3 — Rollout engine depth → build the KV-cache incremental-decode engine up front (not naive-first), as a foundational milestone before DPO/GRPO.
  • D4 — Alignment task / eval target → a verifiable task (arithmetic/format/GSM8K-style) with a deterministic exact-match reward, for a clean, falsifiable RL signal.

Milestones (locked order)

  1. M1 — P0 SFT task baseline. Chat template + assistant-only masking on the verifiable task; produces the reference + init checkpoint. Gate: masking unit test; single-turn bit-identical to fbf4ac2.
  2. M2 — KV-cache decode engine (D3, up front). Per-layer K/V cache + incremental decode-time attention + batched ragged decode. Gate: token-identical to full-recompute greedy; record decode throughput baseline.
  3. M3 — P1 DPO. Verifiable-checker pair construction (via M2) → seq_logprob op (grad-check) → DPO loss (PyTorch parity; ref==policy and β→0 degenerate checks) → DPO training loop → run + reward-margin / preference-accuracy curve.
  4. M4 — P3 GRPO. Group rollout (M2) + rule-based reward + group-relative advantage + clipped PG with KL leash. Gate: PG grad-check; G=1/ε→∞/β=0 degenerate checks; synthetic verifiable-task RL-overfit (mean reward → known optimum) → verifiable-task GRPO run.
  5. M5 (optional) — P2 reward model. Scalar head + ranking loss + pairwise-accuracy gate; enables GRPO-with-RM for general chat.

Each milestone is one design+gate cycle; results get appended here (like the run docs) and a row in docs/evolution.md (algorithm/infra dimensions) when it lands.

Implementation log

M1 — SFT task baseline (landed)

The verifiable task and its data pipeline are implemented and verified host-side (no CUDA needed); the SFT run + eval ran on dash5 (1×5090). Result: SFT moves answer-format adherence 0% → 100%, with arithmetic correctness 8% — exactly the intended split (SFT buys the format; correctness is M3/M4's job).

Verifiable task (the spec, in one Rust module — crates/xtrain-train/src/task.rs):

  • Two-operand integer arithmetic, ops +, −, ×; operands [0,999] for +/−, [0,99] for × (modest products); subtraction may be negative. (Ranges enlarged from the first cut to keep the unique-key space ≫ requested rows; see the saturation guard below.)
  • User turn: What is A op B?. SFT target: A op B = \boxed{N}. — teaches the answer FORMAT; the checker reads only \boxed{}, so arithmetic correctness is what M3/M4 improve.
  • Rule-based reward: parse_boxed_answer (takes the LAST \boxed{int}) + check_answer (exact match vs. gold). This is the single shared checker reused by M3 (pair construction) and M4 (GRPO reward).
  • Why this task: trivial deterministic checker, freely scalable difficulty, and it directly probes the base model's known arithmetic weakness (v12 SFT failed 12 * 13).
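The checker described above could look roughly like the sketch below. The function names come from the text, but the bodies are illustrative and not the code in task.rs. In particular, "LAST \boxed{int}" is interpreted here as the last occurrence whose content parses as a signed integer, which is an assumption.

```rust
/// Return the integer inside the LAST `\boxed{...}` whose content parses as i64.
/// Assumption: non-integer boxes are skipped rather than invalidating the answer.
fn parse_boxed_answer(text: &str) -> Option<i64> {
    const OPEN: &str = "\\boxed{";
    let mut last = None;
    let mut rest = text;
    while let Some(start) = rest.find(OPEN) {
        let after = &rest[start + OPEN.len()..];
        if let Some(end) = after.find('}') {
            if let Ok(n) = after[..end].trim().parse::<i64>() {
                last = Some(n);
            }
        }
        // Continue scanning after this opening marker.
        rest = after;
    }
    last
}

/// Exact-match reward: true only if a boxed integer exists and equals the gold answer.
fn check_answer(text: &str, gold: i64) -> bool {
    parse_boxed_answer(text) == Some(gold)
}
```

Under this reading a completion such as `46 * 80 = \boxed{3380}.` counts as format-valid but fails the check against gold 3680, which is the format-versus-correctness split the eval reports.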

Data generator (crates/xtrain-train/src/bin/gen_arith_task.rs, pure host bin): writes arith_sft.tsv (user<TAB>assistant for --sft-tsv), arith_eval_prompts.txt (greedy_sample --prompts-file format), and arith_eval_gold.txt (parallel gold ints). Train rows are deduped; eval is held out from train (no leakage). A saturation guard (unique_space() + assert need·5 ≤ space·4) rejects requests that approach the unique-key space, since deduped train + disjoint eval near saturation get pathologically slow (or, for the disjoint-eval loop, never terminate). With the shipped defaults the space is ~2.01M keys, so a 20 000 + 500 request is a tiny fraction (gen runs in ~0.2 s).
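The saturation guard's arithmetic can be sketched as follows. It assumes a key is an ordered (A, op, B) triple and that "need" is train rows plus eval rows. The ordered-triple reading is the count that reproduces the ~2.01M figure. `check_saturation` is a hypothetical name, and the real generator asserts rather than returning a `Result`.

```rust
/// Number of distinct (A, op, B) keys under the stated operand ranges:
/// + and - use A, B in [0, 999]; x uses A, B in [0, 99].
fn unique_space() -> u64 {
    let add_sub = 2 * 1000 * 1000;
    let mul = 100 * 100;
    add_sub + mul
}

/// Reject requests above 80% of the key space (need * 5 <= space * 4),
/// where rejection sampling for unique rows becomes very slow.
fn check_saturation(train_rows: u64, eval_rows: u64) -> Result<(), String> {
    let need = train_rows + eval_rows;
    let space = unique_space();
    if need * 5 <= space * 4 {
        Ok(())
    } else {
        Err(format!("requested {need} unique rows, key space is only {space}"))
    }
}
```

The shipped 20 000 + 500 request uses about 1% of the space, far below the 80% cutoff.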

Scorer (crates/xtrain-train/src/bin/eval_arith.rs): loads a checkpoint, greedily generates a continuation per held-out prompt, isolates the first answer segment (cut at the first <|endoftext|> then first newline), and reports two signals via the shared checker — format (fraction emitting any \boxed{int}) and correctness (exact-match vs. gold). This is the reusable verifiable-eval harness for M3 (DPO) / M4 (GRPO). It uses the naive no-KV-cache sampler (full forward per token), so even 100 prompts is slow — concrete motivation for M2 (the KV-cache decode engine).
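The segment-isolation step (cut at the first `<|endoftext|>`, then at the first newline) is easy to get subtly wrong, so here is a small sketch of one way to do it. The function name is hypothetical and the actual eval_arith code may differ.

```rust
/// Keep only the first answer segment of a generated continuation:
/// everything before the first `<|endoftext|>`, then before the first newline.
fn first_answer_segment(completion: &str) -> &str {
    let before_eos = completion.split("<|endoftext|>").next().unwrap_or("");
    before_eos.split('\n').next().unwrap_or("")
}
```

Cutting before the checker runs matters because the checker takes the last boxed integer: without the cut, a model that keeps generating a second question and answer after the stop token would be scored on the wrong box.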

Masking made testable: the assistant-only label masking in load_sft_tsv_cached was extracted into a pure sft_row(prompt_ids, answer_ids) helper (behavior-preserving — the single-turn path is bit-identical to fbf4ac2).
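A pure helper of this kind might look like the sketch below. The return type, the token id type, and the choice to align labels one-to-one with tokens (leaving the next-token shift to the loss) are all assumptions. Only the name `sft_row` and the prompt → -100 behaviour come from the text.

```rust
/// Label value ignored by the cross-entropy loss.
const IGNORE_INDEX: i64 = -100;

/// Build one SFT row: inputs are prompt + answer, labels supervise only the answer.
/// Assumption: labels are aligned with tokens; any next-token shift happens downstream.
fn sft_row(prompt_ids: &[u32], answer_ids: &[u32]) -> (Vec<u32>, Vec<i64>) {
    let mut input = Vec::with_capacity(prompt_ids.len() + answer_ids.len());
    input.extend_from_slice(prompt_ids);
    input.extend_from_slice(answer_ids);

    // Prompt positions are masked; answer positions carry their own token id.
    let mut labels = vec![IGNORE_INDEX; prompt_ids.len()];
    labels.extend(answer_ids.iter().map(|&t| i64::from(t)));

    (input, labels)
}
```

Because the helper touches no files and no device, the masking invariant can be unit-tested without CUDA.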

Gate (verified locally in no_cuda mode): cargo test -p xtrain-train --lib → 9/9 pass, including sft_row masks prompt→-100 / supervises answer, the SFT-target self-consistency invariant (always checker-correct over 2000 samples), parser edge cases, and seed determinism. A 200/50 generation run confirmed clean 2-column TSV, correct gold (incl. negatives), and 0 train/eval leakage.

Run (dash5, 1×5090, from the v12 1.05B base):

  1. dataset: gen_arith_task --n 20000 --eval 500 --seed 1 --out-dir <dir> → 20 000 train + 500 held-out eval, 0 leakage.
  2. SFT: train <tok> <dir>/arith_sft.tsv --sft-tsv --init-ckpt <v12-base.ckpt> --heads 52 --head-dim 32 --kv-heads 13 --layers 22 --ffn 6656 --bf16 --recompute --flash --seq 256 --batch 16 --steps 250 --max-lr 1e-4 --min-lr 1e-5 --ckpt arith_sft_v12.ckpt → the P0 reference/init checkpoint. Train loss 4.68 → ~0.34, best val 0.386, no OOM, ~4.3K tok/s.
  3. eval: eval_arith <ckpt> <tok> <arch> --prompts-file <dir>/arith_eval_prompts.txt --gold-file <dir>/arith_eval_gold.txt --max-tokens 32, base vs. SFT, on 100 held-out prompts.

M1 result (100 held-out prompts, greedy, max_new 32):

| checkpoint | format (\boxed{}) | correct (exact-match) |
| --- | --- | --- |
| v12 base (pre-SFT) | 0 / 100 (0%) | 0 / 100 (0%) |
| arith SFT | 100 / 100 (100%) | 8 / 100 (8%) |

The base model never emits the format: it answers "I don't know." or restates the question and stops. SFT moves format 0% → 100%: every completion cleanly restates the equation and boxes an integer (46 * 80 = \boxed{3380}.). Correctness is only 8%: the format is fully learned but the arithmetic is the base model's own weak capability, e.g. it boxes 3380 for gold 3680, 10 for gold 5; it does get some right (895 - 353 = \boxed{542}. ✓). That residual gap is exactly what the verifiable reward in M3 (DPO) / M4 (GRPO) is built to close.

Gate met: format 0% → 100% confirms the assistant-only SFT path is wired end-to-end; the held-out correct > 0 confirms the checker + eval harness score real matches (not just format). M1 delivers the format floor + the reusable task spec / checker / eval harness — not arithmetic skill, which is downstream by design.

M2a — KV-cache incremental-decode engine (single sequence, landed)

This is the decode engine (D3, built up front) that replaces the naive sampler, which re-runs the full forward over the growing prefix every step (O(t²), a fresh autograd graph per token). It consists of two forward-only primitives plus a raw-Tensor per-token block forward, each gated in isolation.

Primitives (xtrain-tensor, both forward-only):

  • Tensor::rope_at(theta, pos0) — RoPE at a token's absolute position (pos = pos0 + row, no modulo), vs the training rope (pos = row % period) which is left untouched (new CUDA kernel rope_at_k → no training-path risk). Cached K is stored post-RoPE, so it must match what the full forward produced at that position. Gate: bit-identical to the full-sequence rope's row t (integration::rope_at_matches_full_rope_row).
  • Tensor::decode_attention(k, v, scale) — single-query × cached-K/V SDPA ([bh,1,hd] vs [bh,t,hd], no causal mask: the one query sees all cached keys). Composed from the existing strided batched GEMM + plain softmax — no new kernel. Gate: equals the full causal attention's last query row, max |Δ| 6e-8 (integration::decode_attention_matches_…).
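The second gate rests on a simple identity: the last row of causal attention over t tokens equals a single query attending, unmasked, over the t cached keys and values. A single-head host-side sketch of that identity follows. It uses `f64` vectors and hypothetical function names, and does not reflect the GEMM-based implementation.

```rust
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Stable softmax; a score of -inf gets weight exactly 0.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scores.iter().map(|&s| (s - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn weighted_sum(weights: &[f64], values: &[Vec<f64>]) -> Vec<f64> {
    let mut out = vec![0.0; values[0].len()];
    for (w, row) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(row) {
            *o += w * x;
        }
    }
    out
}

/// Training-style attention: every query row i sees keys j <= i (causal mask).
fn causal_attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>], scale: f64) -> Vec<Vec<f64>> {
    (0..q.len())
        .map(|i| {
            let scores: Vec<f64> = (0..k.len())
                .map(|j| if j <= i { dot(&q[i], &k[j]) * scale } else { f64::NEG_INFINITY })
                .collect();
            weighted_sum(&softmax(&scores), v)
        })
        .collect()
}

/// Decode-time attention: one query against all cached K/V, no mask needed
/// because the cache only ever holds positions <= the current one.
fn decode_attention(q_last: &[f64], k_cache: &[Vec<f64>], v_cache: &[Vec<f64>], scale: f64) -> Vec<f64> {
    let scores: Vec<f64> = k_cache.iter().map(|kj| dot(q_last, kj) * scale).collect();
    weighted_sum(&softmax(&scores), v_cache)
}
```

The same identity holds at any earlier step t when the cache is truncated to its first t + 1 entries, which is why prefill can reuse the decode path one token at a time.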

Engine (xtrain-model/src/decode.rs, generate_greedy_cached): per-layer K/V cache + single-token incremental forward. Prefill = the first prompt.len() decode steps (one code path). Mirrors model::block_forward at the raw-Tensor level (no autograd tape — inference needs no grads), pulling weights via the public params() stable order (no model-internal visibility changes). The cache is host-accumulated token-major f32, rebuilt per step — the honest M2a baseline; M2b moves it device-side + adds batched ragged decode.

Gate (the M2 centerpiece — token-identical): KV-cache greedy decode is byte-for-byte the same token sequence as the naive full-recompute greedy. Verified two ways:

  • xtrain-train/tests/decode_kv.rs — small GQA model (8 query / 2 kv heads), F32, 24 generated tokens, exact token-equality. (Unit gate runs F32: a random model's near-uniform logits make argmax fragile to ~1e-6, so the tightest path is used; the trained model below has peaked logits → robust.)
  • v12 1.05B SFT checkpoint: eval_arith --cached produces the identical eval outcome to the naive run (format 100/100, correct 8/100) and byte-identical completions.

Throughput baseline (v12 1.05B, batch 1, F32, profile-first — measured, not assumed): the cache win is sequence-length-dependent, which is the honest systems finding here:

| max_new | naive | kv-cache | note |
| --- | --- | --- | --- |
| 32 | 108 tok/s | 111 tok/s | ~1.0×; both launch/overhead-bound at short seq |
| 128 | 69 tok/s | 133 tok/s | ~1.9×; naive's O(t²) recompute starts to bite |
| 256 | OOM | 129 tok/s | naive rebuilds the O(seq²) graph every step → OOM |

Cached throughput stays ~constant (roughly O(1)/token at these lengths: the weight GEMMs do not grow with t, and the O(t) attention-over-cache term and the linearly growing cache are still small); naive decays (108→69 tok/s, re-forwarding the whole prefix costs at least O(t)/token) and eventually OOMs (the full autograd graph per step). So at the short arithmetic-eval lengths the cache is overhead-bound and gives ~nothing; it matters for long rollouts (DPO pair-generation, GRPO completions), exactly where M3/M4 use it. (M2a's per-layer host round-trip is part of why short-seq is overhead-bound; M2b's device-side cache targets it.) This is the same measure-first lesson as T17 (process-per-GPU throughput-neutral): the win is real but only in the regime that actually stresses the bottleneck.

M3 — DPO (offline preference optimization, landed; honest negative result)

The first real alignment method. Infra landed and gated; the empirical finding is that DPO does not improve held-out arithmetic correctness on this task — a genuine, on-theme negative result (the design doc's "RL is finicky" risk, made concrete).

Two new autograd ops (xtrain-autodiff, both reuse the CE kernel — no new CUDA):

  • seq_logprob(logits, target) = Σ log πθ(target) over non-ignored positions (the per-sequence logprob DPO compares). It equals −Σ per_row of cross_entropy (ignored rows already 0, like SFT masking); backward = cross_entropy_backward(probs, target, upstream) (SUM, no mean). Gate: finite-diff grad-check with a -100 completion mask.
  • dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β) = −log σ(Δ) with the two policy logprobs as parents (ref logprobs constant). Gate: grad-check both parents + degenerate points (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0).
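The dpo_loss gate can be mirrored on scalars: the analytic gradient with respect to the two policy logprobs is ∓β·σ(−Δ), and a central finite difference should agree with it. The sketch below assumes L = −log σ(Δ) and uses hypothetical helper names. It says nothing about how the xtrain-autodiff op is structured.

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// L = -log sigma(delta), delta = beta * [(lp_c - ref_c) - (lp_r - ref_r)].
/// Plain form; adequate for the moderate |delta| used in a grad-check.
fn dpo_loss(lp_c: f64, lp_r: f64, ref_c: f64, ref_r: f64, beta: f64) -> f64 {
    let delta = beta * ((lp_c - ref_c) - (lp_r - ref_r));
    -sigmoid(delta).ln()
}

/// Analytic gradients (dL/d lp_c, dL/d lp_r) = (-beta * sigma(-delta), +beta * sigma(-delta)).
fn dpo_grads(lp_c: f64, lp_r: f64, ref_c: f64, ref_r: f64, beta: f64) -> (f64, f64) {
    let delta = beta * ((lp_c - ref_c) - (lp_r - ref_r));
    let g = beta * sigmoid(-delta);
    (-g, g)
}

/// Central finite difference of a scalar function.
fn finite_diff(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h) - f(x - h)) / (2.0 * h)
}
```

At policy == reference, σ(0) = 1/2 gives gradients of exactly −β/2 and +β/2, and β = 0 zeroes both, which are the degenerate points named in the gate.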

Pair construction (gen_dpo_pairs, aligned decision): chosen = gold answer; rejected = the SFT model's own greedy (KV-cache engine, M2a) completion when it's a format-valid WRONG boxed answer — a hard negative in the model's distribution. Since SFT is ~8% correct (M1), greedy is wrong ~92% of the time, so this is fast and deterministic; ~8% of prompts are skipped (greedy correct). 1500 pairs generated (158 skipped) in ~8 min.

Training (train_dpo): loads the SFT ckpt as policy AND frozen reference; precomputes the reference logprobs once (while policy == reference) and caches them — one resident model. Each step forwards the policy on chosen + rejected, seq_logprob each, minimises dpo_loss; the two forwards share params so backward accumulates both branches. Loss starts at exactly log2 (Δ=0 at init) — a built-in correctness check that fired correctly. Tracks reward margin + preference accuracy.

Result (v12 1.05B, 1500 pairs, β=0.1; 100 held-out prompts, vs the SFT baseline format 100/100, correct 8/100):

| run | reward margin | pref-acc | format | correct |
| --- | --- | --- | --- | --- |
| SFT (baseline) | - | - | 100/100 | 8/100 |
| DPO lr 5e-7 × 300 | +0.78 | ~82% | 100/100 | 7/100 |
| DPO lr 5e-7 × 800 | +1.25 | ~82% | 100/100 | 5/100 |
| DPO lr 1e-6 × 2000 | +34.2 | ~76% | 0/100 | 0/100 |

The reward margin and preference accuracy rise cleanly (the loss IS being optimized, so the infra is correct), but the implicit reward does not transfer to held-out correctness: it stays ~5-8% (all within the ~2.7% std-error of 100 prompts, i.e. statistically flat), and pushing harder over-optimizes to collapse (margin +34 = huge KL from the reference → the model emits garbage, 46 * 80 = CRAFTIE SERIES SERIES…, format 0%).
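The "statistically flat" reading can be checked with the binomial standard error. The sketch below treats each run's 100 prompts as independent Bernoulli trials, which is a simplifying assumption: the runs share the same prompts, so a paired test would be sharper. The function names are hypothetical.

```rust
/// Standard error of a proportion estimated from n Bernoulli trials.
fn binomial_std_error(correct: u32, n: u32) -> f64 {
    let p = f64::from(correct) / f64::from(n);
    (p * (1.0 - p) / f64::from(n)).sqrt()
}

/// Standard error of the difference between two independent proportions.
fn diff_std_error(correct_a: u32, correct_b: u32, n: u32) -> f64 {
    let a = binomial_std_error(correct_a, n);
    let b = binomial_std_error(correct_b, n);
    (a * a + b * b).sqrt()
}
```

At 8/100 the standard error is about 2.7 points. The largest gap in the table, 8/100 versus 5/100, is 3 points against a standard error of the difference of about 3.5 points, so none of the non-collapsed runs is distinguishable from the SFT baseline at this sample size.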

The lesson (why): chosen and rejected differ only in the final number tokens, so DPO raises log p(correct) − log p(wrong) for the specific training pairs: it reweights the existing distribution, it does not install the capability. The base model has no arithmetic algorithm, so preferring correct-vs-wrong final answers on seen pairs cannot generalize to unseen problems; and the only way to drive the margin far is to globally distort the distribution → incoherence. DPO works when the chosen is already plausible under the policy; it cannot manufacture knowledge the model lacks. This is the precise motivation for M4 GRPO: optimize the actual verifiable reward online (sample → check → reinforce what is genuinely correct), rather than a fixed-pair proxy. GRPO faces the same 8%-correct sparsity, though, so whether it moves the metric is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init, margin/acc rise); the correctness flatness is the reported finding, not a bug.