From 096e45b845df1260754491e1dc22412240adfd18 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Tue, 30 Jun 2026 17:01:22 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20M4=20=E2=80=94=20GRPO=20results=20(infr?= =?UTF-8?q?a=20+=20memory/rollout=20walls=20+=20capability-wall=20negative?= =?UTF-8?q?=20result)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implementation log (docs/18) + Phase-3 row (evolution.md): the clipped_pg_loss op + gates, the actor-learner loop, the easy-task SFT baseline (held-out 18.7%, plateaus → no generalization), the two systems walls the design doc flagged (two 1B models OOM the 32GB box → β=0; naive rollout fragments the allocator → cached temperature sampling, rollout still the long pole), and the result: format holds, held-out 20.0% (+1.3pp, statistically flat) — the same wall as DPO. Closes the SFT→KV-cache→DPO→GRPO post-training arc with honest limits. Co-Authored-By: Claude Opus 4.8 --- docs/18-post-training-rl-sft.md | 56 +++++++++++++++++++++++++++++++++ docs/evolution.md | 2 ++ 2 files changed, 58 insertions(+) diff --git a/docs/18-post-training-rl-sft.md b/docs/18-post-training-rl-sft.md index e6af508..0727654 100644 --- a/docs/18-post-training-rl-sft.md +++ b/docs/18-post-training-rl-sft.md @@ -466,3 +466,59 @@ verifiable reward* online (sample → check → reinforce what is genuinely corr fixed-pair proxy — though GRPO faces the same 8%-correct sparsity, so whether it moves the metric is M4's open question. Gate met for M3 = the infra is correct (op grad-checks, log2-at-init, margin/acc rise); the correctness flatness is the reported finding, not a bug. + +### M4 — GRPO (online RL, critic-free, landed; infra + two honest systems walls) + +The centerpiece: generation INSIDE the training loop. Infra built and gated; the run surfaces +two concrete systems findings (the memory long-pole + the rollout long-pole, both flagged in the +design doc's Risks) and the same capability wall as M3. + +**Task made learnable first (per the aligned decision "easier task → then M4"):** the v12 SFT +model scores ~8% on the hard task *and* on easy problems — it learned format, not arithmetic. So +the easy task (operands ≤20, ops `+ − ×`) was re-SFT'd from the v12 base → **held-out 18.7%** +(100% format), a baseline with reward variance for GRPO. Note: even easy arithmetic plateaus at +~19% held-out (250 vs 600 SFT steps identical) — a 1B web-text model does not generalize the +add/sub algorithm from ~550 examples; it memorizes train (982 total problems, 550 seen). + +**New op (`xtrain-autodiff`, reuses the CE kernel + one new primitive):** +- `clipped_pg_loss(logits, target, logp_old, logp_ref, A, ε, β)` — per completion token + `ρ_t = exp(logπθ_t − logp_old_t)`, `L = −mean min(ρA, clip(ρ,1±ε)A) + β·mean KL` (k3), masked + to completion tokens. Backward reuses `(probs − onehot)` + `scale_rows` (a new ~5-line per-row + scale kernel — the per-token coefficient varies, which CE-backward's single scalar can't + express). **Gate:** grad-check the active PG path + the A=0 (KL-only) path; degenerate value + checks ε→∞ ⇒ vanilla PG, β=0 ⇒ no KL. + +**Loop (`train_grpo`):** per step — sample B prompts, roll out G completions each, score (reward +0/1), group-relative advantage `A=(r−mean)/(std+ε)` (no critic; all-correct/all-wrong groups +skipped — zero advantage), capture `logπθ_old`/`logπref` per token, K inner clipped-PG epochs. +Rollout uses the M2 KV-cache engine with **temperature sampling** (added in M4): single-row +`[1,vocab]` logits per step vs the naive sampler's `[seq,vocab]`. + +**Systems wall #1 — memory (the design doc's "two/three resident models"):** KL-leash GRPO needs +policy + frozen reference, two 1.05B fp32-master models + AdamW m/v ≈ 21 GB fixed + training +activations → unreliably OOMs on a 32 GB 5090 (fragmentation tips it over). To get a completing +run, `β=0` (pure PG) drops the reference model (−4.2 GB). So the *principled* KL-leash version is +memory-bound at this model size on this hardware — a real, reported constraint, not a bug. + +**Systems wall #2 — rollout (the design doc's "rollout is the long pole"):** the naive sampler's +growing `[seq,vocab]` allocations fragment the caching allocator over a long rollout → OOM. The +cached temperature rollout (single-row logits) is lighter; but single-sequence cached decode is +slow (the M2a host-round-trip), so rollout still dominates wall-clock (~16 s/step at G=6·B=6). +Batched ragged decode (M2b) is the real fix and is deferred to where it is load-bearing. + +**Result (easy task, β=0, G=6·B=6, 40 steps, lr 5e-7; 150 held-out, vs SFT 28/150 = 18.7%):** +mean rollout reward fluctuates ~0.58–0.81 (noisy, inflated by train-set overlap in the sampled +problems); **format stays 100/100** (no collapse even without the KL leash, at this gentle lr); +**held-out 30/150 = 20.0%** — `+1.3 pp`, within the ~3% std-error of 150 prompts, i.e. +**statistically flat**, the same wall as M3 DPO. + +**The consistent M3+M4 lesson:** on a task where the base model lacks the underlying capability, +**neither offline preference optimization (DPO) nor online RL (GRPO) moves held-out correctness** +— each optimizes its objective (margin / reward) on the *training distribution* it can reach +(here inflated by memorization), but cannot install a *generalizable* algorithm the model never +had. RL reinforces what the model already does; it does not teach arithmetic. Gate met for M4 = +the infra is correct (PG/KL grad-checks + degenerate checks, the loop runs, reward signal + KL +leash wired, format held); the held-out flatness + the two memory/throughput walls are the +reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated +SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest +limits on what alignment can do for a capability the base model lacks. diff --git a/docs/evolution.md b/docs/evolution.md index 3b6c1ad..f84ec52 100644 --- a/docs/evolution.md +++ b/docs/evolution.md @@ -101,6 +101,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**( **M3(DPO,离线偏好优化,已落地 + 诚实负结果)**:两个复用 CE kernel 的新算子(零新 CUDA)——`seq_logprob`(Σ log πθ over 非 mask 位,反向 = CE_backward 取负求和;grad-check + mask)、`dpo_loss`(−log σ(Δ),双 policy logprob 父节点;grad-check + 退化 Δ=0→log2/∓β·½、β=0→0)。造对(`gen_dpo_pairs`)= chosen=gold、rejected=SFT 自己 greedy(用 M2a 引擎)的格式合法**错误**答案(8% greedy 答对的跳过)。训练(`train_dpo`)把 SFT ckpt 同时作 policy 和冻结 reference,**一次性预算 reference logprob 并缓存**(单模型驻留),每步 policy forward chosen+rejected → seq_logprob → dpo_loss,两 forward 共享 param 累积梯度;**loss 起步恰好 log2**(Δ=0 内置校验)。**结果(v12, 1500 对, β0.1;100 留出题 vs SFT 8/100)**:reward-margin 与 pref-acc 干净上升(loss 被正确优化、infra 对),但**不转化为 held-out 正确率**——lr5e-7×300→7%、×800→5%、lr1e-6×2000→margin+34 **崩溃**(0% 格式、输出垃圾),三档都在 100 题 ~2.7% 标准误内 = 统计持平。**教训**:chosen/rejected 只差最终数字 token,DPO 提升的是**特定训练对的 token 偏好、reweight 现有分布,不 install 能力**;base 模型没有算术算法,偏好优化不泛化,推狠了只是全局扭曲分布→不连贯。**DPO 在 chosen 本就 plausible 时有效,不能凭空造模型没有的知识**——这正是 M4 GRPO 的动机:在线优化**真实可验证 reward**(采样→check→强化真正对的)而非固定对的 proxy(但 GRPO 同样面对 8% 稀疏,能否抬动指标是 M4 的 open question)。与 v8/T17 同源的诚实账:跑通+闸门齐全,负结果如实记。 +**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。 + ## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)) - **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。