docs: M2b — batched decode results (token-identical + ~1.7x rollout, device-cache next)
Implementation log (docs/18) + Phase-3 row (evolution.md): rope_pos primitive + gate, the batched engine (decode_attention/repeat_kv reused), the token- identical batch gate, and the measured ~1.7x rollout-inclusive step speedup + memory stabilization. Closes the M2 decode engine (M2a single-seq + M2b batched); names the device-side cache as the remaining lever. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -522,3 +522,35 @@ leash wired, format held); the held-out flatness + the two memory/throughput wal
|
|||||||
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
|
reported findings. The honest end-state of the post-training arc: **a complete, correctness-gated
|
||||||
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
|
SFT → KV-cache → DPO → GRPO stack** — the infrastructure learned in full, with measured, honest
|
||||||
limits on what alignment can do for a capability the base model lacks.
|
limits on what alignment can do for a capability the base model lacks.
|
||||||
|
|
||||||
|
### M2b — batched KV-cache decode (landed; completes the M2 engine, fixes the rollout long-pole)
|
||||||
|
|
||||||
|
Built after M4 (where the rollout long-pole bit hardest): decode the **G samples of one prompt in
|
||||||
|
lockstep** — one forward per step over the whole group → G× fewer kernel launches, the deferred
|
||||||
|
fix from M2a.
|
||||||
|
|
||||||
|
**One new primitive:** `rope_pos(x, positions[])` — RoPE with a *per-row* absolute position (new
|
||||||
|
forward-only kernel), since the G batched rows share one decode position (M2a's `rope_at` does
|
||||||
|
`pos0 + row`, wrong for a batch at a single position). **Gate:** bit-identical to the full rope
|
||||||
|
for positions `[0..n]`, and to `rope_at(P)` per row for a uniform `P`.
|
||||||
|
|
||||||
|
**Engine (`generate_cached_batch`):** `BatchKVCache` carries a G dimension (`[T, G·num_kv, hd]`
|
||||||
|
host-accumulated → `[G·num_kv, T, hd]`); the batched `decode_step` threads G through embed /
|
||||||
|
projections / QK-norm / `rope_pos` / cache. Two M2a pieces drop in unchanged: `decode_attention`
|
||||||
|
is already batch-agnostic (`bh = G·nh`), and `repeat_kv(nh, batch=G)` broadcasts per group. No
|
||||||
|
finished-mask (all G generate `max_new`; the caller cuts at EOS) and no ragged-length prompts yet
|
||||||
|
— both perf-only follow-ups.
|
||||||
|
|
||||||
|
**Gate (token-identical):** all G **greedy** rows are byte-identical to the single-sequence decode
|
||||||
|
(`tests/decode_batch.rs`, 8 query / 2 kv heads → exercises the `repeat_kv` batching) — pins that
|
||||||
|
G-way batching indexes each sequence's K/V with no cross-row contamination.
|
||||||
|
|
||||||
|
**Throughput (v12 1.05B, G=6·B=6, easy task, rollout wired into `train_grpo`):** ~8.5 s/step vs
|
||||||
|
~14–16 s/step for the single-seq cached rollout — **~1.7×**, rollout-inclusive. Short of the full
|
||||||
|
G× because (a) the per-token-logp forwards + the PG update also cost, and (b) the M2a per-layer
|
||||||
|
**host round-trip** is still there (now G× the data in one transfer, not removed). The full
|
||||||
|
device-side cache (no host round-trip) is the remaining decode-engine optimization. Batching also
|
||||||
|
**stabilises memory**: one batched forward per step vs G separate allocations that fragmented the
|
||||||
|
caching allocator (the M4 OOM). So M2b closes the decode-engine milestone (M2a single-seq + M2b
|
||||||
|
batched) and turns the rollout long-pole from "OOM/unbounded" into a bounded ~1.7× win — measured,
|
||||||
|
with the device-cache as the named next lever.
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
|||||||
|
|
||||||
**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。
|
**M4(GRPO,在线 critic-free RL,已落地 + 两道诚实系统墙 + 一致负结果)**:新算子 `clipped_pg_loss`(per-token ρ + clip + k3 KL,反向用新增 `scale_rows` per-row 缩放 kernel;grad-check active+A=0 路径 + 退化 ε→∞ vanilla/β=0 无KL)。环 `train_grpo`:采 B prompt × rollout G → checker reward 0/1 → group-relative advantage `(r−mean)/(std+ε)`(无 critic,全对/全错组跳过)→ 存 πθ_old/πref per-token → K 内层 clipped-PG。rollout 用 **M2 引擎 + 新加的 temperature 采样**(单行 logits 比 naive `[seq,vocab]` 轻)。**先把任务改简单**:v12 SFT 在硬/易题都 ~8-9%(只会格式不会算术)→ 在 easy(操作数≤20)上从 v12 base 重训 SFT → held-out **18.7%**;但 250/600 步同样 18.7% = 1B web-text 模型从 ~550 例**不泛化加减法、只记 train**。**两道系统墙(设计文档 Risks 预言)**:① 显存——KL-leash 要 policy+reference 两个 1B fp32-master+Adam≈21GB,加激活在 32GB 5090 上不稳定 OOM → 只能 `β=0`(去掉 reference)跑完;② rollout 长杆——naive 采样增长序列撑碎 allocator,cached 采样更轻但单序列慢仍主导墙钟(~16s/step)。**结果**(easy, β=0, G6·B6, 40步, lr5e-7;150 留出 vs SFT 18.7%):reward 噪声 ~0.58-0.81(被 train 重叠抬),**format 100/100 不崩**(温和 lr 下 β=0 也没崩),**held-out 20.0%**(+1.3pp,~3% 标准误内 = 统计持平)。**M3+M4 一致教训**:模型缺底层能力时,离线偏好(DPO)和在线 RL(GRPO)**都不抬 held-out**——各自在能触及的训练分布上优化目标(被记忆抬高),装不进可泛化算法;**RL 强化模型已会的,不教算术**。**后训练弧诚实终态 = 一套完整、闸门齐全的 SFT → KV-cache → DPO → GRPO 栈**,infra 学全,并测得对齐对"base 缺失能力"能做什么的诚实边界。
|
||||||
|
|
||||||
|
**M2b(批量 KV-cache 解码,已落地,补全 M2 引擎 + 修 rollout 长杆)**:M4 后补的 rollout 长杆修复——一个 prompt 的 **G 个样本同步解码**(每步一次 forward 跑整组 → G× 更少 kernel 启动)。一个新原语 `rope_pos`(逐 row 绝对位置 kernel,G 行共享一个解码位置;闸门 = `[0..n]` 逐位等于全 rope、统一 P 逐行等于 `rope_at(P)`,bit-identical)。引擎 `generate_cached_batch`:`BatchKVCache` 带 G 维,批量 `decode_step` 把 G 贯穿 embed/proj/QK-norm/`rope_pos`/cache;**M2a 两件零改动复用**——`decode_attention` 本就 batch-agnostic(bh=G·nh)、`repeat_kv(nh,batch=G)` 按组广播。闸门 = G 个贪心行逐字节等于单序列(`tests/decode_batch.rs`,8q/2kv 头练 repeat_kv 批量)。**吞吐**(v12, G6·B6, 接进 train_grpo):**~8.5s/step vs 单序列 ~14-16s/step ≈ 1.7×**(rollout-inclusive;未到满 G× 因 per_token_logp + PG 更新也占时间、M2a 主机往返还在);且**显存更稳**(一次批量 forward vs G 次分配撑碎 allocator 的 M4 OOM)。⇒ M2 引擎闭环(M2a 单序列 + M2b 批量),rollout 长杆从"OOM/无界"变成有界 ~1.7× 收益,device 端 cache 是点名的下一杠杆。
|
||||||
|
|
||||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||||
|
|
||||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||||
|
|||||||
Reference in New Issue
Block a user