docs: M3 — DPO results (infra correct, held-out correctness flat, over-optimization collapse)
Implementation log (docs/18) + Phase-3 row (evolution.md): the two ops + gates, pair-gen (gold chosen / sampled-wrong rejected), reference-logprob caching, the training loop, and the honest finding — reward margin + pref-acc rise but held-out arithmetic correctness stays ~5-8% (flat within std-error) and over-optimizes to collapse (margin +34 → 0% format). DPO reweights, it does not install the capability; motivates M4 GRPO (optimize the verifiable reward online). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -410,3 +410,59 @@ short arithmetic-eval lengths the cache is overhead-bound and gives ~nothing —
|
||||
per-layer host round-trip is part of why short-seq is overhead-bound; M2b's device-side cache
|
||||
targets it.) This is the same measure-first lesson as T17 (process-per-GPU throughput-neutral):
|
||||
the win is real but only in the regime that actually stresses the bottleneck.
|
||||
|
||||
### M3 — DPO (offline preference optimization, landed; honest negative result)
|
||||
|
||||
The first real alignment method. Infra landed and gated; the empirical finding is that DPO
|
||||
**does not improve held-out arithmetic correctness on this task** — a genuine, on-theme negative
|
||||
result (the design doc's "RL is finicky" risk, made concrete).
|
||||
|
||||
**Two new autograd ops (`xtrain-autodiff`, both reuse the CE kernel — no new CUDA):**
|
||||
- `seq_logprob(logits, target)` = `Σ log πθ(target)` over non-ignored positions (the per-
|
||||
sequence logprob DPO compares). `= −Σ per_row` of cross_entropy (ignored rows already 0, like
|
||||
SFT masking); backward = `cross_entropy_backward(probs, target, −upstream)` (SUM, no mean).
|
||||
**Gate:** finite-diff grad-check with a `-100` completion mask.
|
||||
- `dpo_loss(lpθ_chosen, lpθ_rejected, lpref_chosen, lpref_rejected, β)` = `−log σ(Δ)` with the
|
||||
two policy logprobs as parents (ref logprobs constant). **Gate:** grad-check both parents +
|
||||
degenerate points (policy==ref ⇒ Δ=0, L=log2, grads ∓β/2; β=0 ⇒ grads 0).
|
||||
|
||||
**Pair construction (`gen_dpo_pairs`, aligned decision):** chosen = gold answer; rejected = the
|
||||
SFT model's own **greedy** (KV-cache engine, M2a) completion when it's a format-valid WRONG
|
||||
boxed answer — a hard negative in the model's distribution. Since SFT is ~8% correct (M1),
|
||||
greedy is wrong ~92% of the time, so this is fast and deterministic; ~8% of prompts are skipped
|
||||
(greedy correct). 1500 pairs generated (158 skipped) in ~8 min.
|
||||
|
||||
**Training (`train_dpo`):** loads the SFT ckpt as policy AND frozen reference; **precomputes the
|
||||
reference logprobs once** (while policy == reference) and caches them — one resident model. Each
|
||||
step forwards the policy on chosen + rejected, `seq_logprob` each, minimises `dpo_loss`; the two
|
||||
forwards share params so backward accumulates both branches. Loss **starts at exactly log2**
|
||||
(Δ=0 at init) — a built-in correctness check that fired correctly. Tracks reward margin +
|
||||
preference accuracy.
|
||||
|
||||
**Result (v12 1.05B, 1500 pairs, β=0.1; 100 held-out prompts, vs the SFT baseline format
|
||||
100/100, correct 8/100):**
|
||||
|
||||
| run | reward margin | pref-acc | format | correct |
|
||||
|---------------------------|---------------|----------|--------|---------|
|
||||
| SFT (baseline) | — | — | 100/100 | 8/100 |
|
||||
| DPO lr 5e-7 × 300 | +0.78 | ~82% | 100/100 | 7/100 |
|
||||
| DPO lr 5e-7 × 800 | +1.25 | ~82% | 100/100 | 5/100 |
|
||||
| DPO lr 1e-6 × 2000 | **+34.2** | ~76% | **0/100** | 0/100 |
|
||||
|
||||
The reward margin and preference accuracy rise cleanly (the loss IS being optimized — the infra
|
||||
is correct), but the implicit reward **does not transfer to held-out correctness**: it stays
|
||||
~5–8% (all within the ~2.7% std-error of 100 prompts — statistically flat), and pushing harder
|
||||
**over-optimizes to collapse** (margin +34 = huge KL from the reference → the model emits
|
||||
garbage, `46 * 80 = CRAFTIE SERIES SERIES…`, format 0%).
|
||||
|
||||
**The lesson (why):** chosen and rejected differ only in the final number tokens, so DPO raises
|
||||
`log p(correct) − log p(wrong)` for the *specific* training pairs — it **reweights the existing
|
||||
distribution, it does not install the capability**. The base model has no arithmetic algorithm,
|
||||
so preferring correct-vs-wrong final answers on seen pairs cannot generalize to unseen problems;
|
||||
and the only way to drive the margin far is to globally distort the distribution → incoherence.
|
||||
**DPO works when the chosen is already plausible under the policy; it cannot manufacture
|
||||
knowledge the model lacks.** This is the precise motivation for **M4 GRPO**: optimize the *actual
|
||||
verifiable reward* online (sample → check → reinforce what is genuinely correct), rather than a
|
||||
fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric
|
||||
is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init,
|
||||
margin/acc rise); the correctness flatness is the reported finding, not a bug.
|
||||
|
||||
@@ -99,6 +99,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
||||
|
||||
**M2a(KV-cache 增量解码引擎,单序列,已落地)**:两个 forward-only 原语 + 裸 Tensor 逐 token block forward,各自隔离闸门。`rope_at`(绝对位置 RoPE,新 kernel,不动训练 `rope` → 训练路径零风险)逐位等于全序列 rope 的对应行;`decode_attention`(单 query × cached-K/V,由现成 strided-gemm + 普通 softmax 组合,**零新 kernel**)等于全 causal attention 末行(max|Δ| 6e-8)。引擎 `generate_greedy_cached` 镜像 `block_forward` 在 Tensor 层(无 autograd tape,推理不需梯度),靠**公开 `params()` 稳定顺序**拿权重(零 model 可见性改动)。**核心闸门 = token-identical**:与朴素全重算贪心逐 token 一致(小 GQA 单测 + v12 1.05B 上 cached eval 与 naive **逐字节相同**:format 100/100, correct 8/100)。**吞吐 baseline(v12, batch1, F32,profile-first 实测)= cache 收益随序列长度而定**:max_new 32 ≈ 持平(108 vs 111,短序列 launch 开销 bound)、128 **~1.9×**(69 vs 133)、256 naive **OOM** vs cached 129 tok/s。cached 吞吐**近恒定**(O(1)/token + 恒定显存),naive **衰减**(O(t)/token,O(seq²) 图 → OOM)。⇒ 短 eval prompt overhead-bound、cache 几乎无收益,真正受益的是**长 rollout**(DPO 造对 / GRPO completion)——与 T17(process-per-GPU 吞吐中性)同一条 measure-first 教训:收益真实,但只在真正压到瓶颈的 regime 里。M2a 的 per-layer 主机往返是短序列 overhead-bound 的一部分原因,M2b(device 端 cache + 批量 ragged)针对它。
|
||||
|
||||
**M3(DPO,离线偏好优化,已落地 + 诚实负结果)**:两个复用 CE kernel 的新算子(零新 CUDA)——`seq_logprob`(Σ log πθ over 非 mask 位,反向 = CE_backward 取负求和;grad-check + mask)、`dpo_loss`(−log σ(Δ),双 policy logprob 父节点;grad-check + 退化 Δ=0→log2/∓β·½、β=0→0)。造对(`gen_dpo_pairs`)= chosen=gold、rejected=SFT 自己 greedy(用 M2a 引擎)的格式合法**错误**答案(8% greedy 答对的跳过)。训练(`train_dpo`)把 SFT ckpt 同时作 policy 和冻结 reference,**一次性预算 reference logprob 并缓存**(单模型驻留),每步 policy forward chosen+rejected → seq_logprob → dpo_loss,两 forward 共享 param 累积梯度;**loss 起步恰好 log2**(Δ=0 内置校验)。**结果(v12, 1500 对, β0.1;100 留出题 vs SFT 8/100)**:reward-margin 与 pref-acc 干净上升(loss 被正确优化、infra 对),但**不转化为 held-out 正确率**——lr5e-7×300→7%、×800→5%、lr1e-6×2000→margin+34 **崩溃**(0% 格式、输出垃圾),三档都在 100 题 ~2.7% 标准误内 = 统计持平。**教训**:chosen/rejected 只差最终数字 token,DPO 提升的是**特定训练对的 token 偏好、reweight 现有分布,不 install 能力**;base 模型没有算术算法,偏好优化不泛化,推狠了只是全局扭曲分布→不连贯。**DPO 在 chosen 本就 plausible 时有效,不能凭空造模型没有的知识**——这正是 M4 GRPO 的动机:在线优化**真实可验证 reward**(采样→check→强化真正对的)而非固定对的 proxy(但 GRPO 同样面对 8% 稀疏,能否抬动指标是 M4 的 open question)。与 v8/T17 同源的诚实账:跑通+闸门齐全,负结果如实记。
|
||||
|
||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||
|
||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||
|
||||
Reference in New Issue
Block a user