docs: M2d — ragged-batching lever, 9× measured, step bottleneck → rollout

Records the M2d lever (batch the GRPO training-side forwards), the right-pad-is-free
insight, both exact gates, the end-to-end no-OOM smoke, and the 9× throughput.

The honest decomposition correction: M2c claimed the training forwards "dominate" the
step; the clean per-component bench falsifies the strong form — they were ~2.5 s of
the ~8.5 s step (~30%), worth the 9×, but the rollout (~6 s) was always the larger
share. After M2d the step is ~95% rollout, so the next step-level lever is full B×G
rollout batching (today only the G samples of each prompt decode in lockstep; the B
prompts are still sequential). Same measure-first lesson, once more.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 23:03:28 +08:00
parent 0e82b2438e
commit 4379868f2d
2 changed files with 53 additions and 0 deletions

View File

@@ -107,6 +107,8 @@ Phase 1/2 把**预训练全栈**学完后Phase 3 转向**后训练 infra**
**M2cdevice 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**K/V device `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它( host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concatdecode_kv 单序列 + decode_batch 批量仍 **token-identical**GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133147 tok/s@128), **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward( T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实闸门齐全的改进(更干净 PCIe解码 +10%),但下一杠杆是 **per-sample forward ragged 批量**而非 cacheM2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache), token-identical-gated;后训练栈完整瓶颈已测绘
**M2d批量 GRPO 训练侧 forward,已落地,M2c 点名的杠杆 + 一处 decomposition 纠正)**M2c 点名的下一杠杆——把每步 `N=B·G` ragged 样本的训练侧 forward(`per_token_logp` 捕获 + inner clipped-PG fwd/bwd)打包进**一次 `forward_batched`**。**使能性质 = causal 下右 padding 免费**:真 completion 行位置早于尾部 pad,causal 禁止前向 attend,故真行 logits 与单序列 forward **逐位相同**,pad 行垃圾被 `target=-100` 屏蔽——这正是训练引擎 pad-and-mask 而非跑 ragged 的原因两件新东西:`per_token_logp_batched`( pad 一次 `forward_batched(N)` 按真长切片)、`ops::clipped_pg_loss_batched`(per-row `advantage[t]` + per-row `weight[t]`,caller `1/(N·n_s)`,op 不再自算 `1/n_tokens` 折进 weight 即与 looped `Σ_s (1/N)(1/n_s)…` **逐位等价**;`--micro` 分块界定 `[chunk·Lmax,vocab]` logits 显存,weight 用全局 N 故分块梯度累积精确)。**两道精确闸门**:`forward_batched_ragged_matches_looped`( pad 批量 forward == 单序列,fp32 max|Δ|=3.7e-7bf16 **0.0**,composed+flash)+ `clipped_pg_loss_batched_matches_looped`(批量 op == looped,loss Δ=1.5e-8/grad 7.5e-9,f32),复合即证端到端等价;端到端短 SFT`train_grpo` 12 ** OOM**(1B master+AdamW+批量激活 micro=16 容得下)、批量 inner 执行。**吞吐(bench,v12 1.05B,N=48,micro16,权重无关)**:capture 62271ms(8.7×)、inner 1907208ms(9.2×)、**训练侧 forward 合计 2526280ms(9.0×)**。**Decomposition 纠正(诚实发现)**:M2c "训练侧 forward 主导 step",干净分量 bench 证伪强形式——训练侧 forward **~8.5s step 里的 ~2.5s(~30%)**,可观值这 9×, **rollout(`generate_cached_batch` ~6s)一直是更大头**;M2d 把训练侧砍到 ~0.28s ,step **~95% rollout**,长杆又摆回 rollout。⇒ M2d 拔掉训练侧 forward 这块 overhang(分量级精确 9×),再次印证 measure-first:**step 级下一杠杆 = B×G rollout 批量**(今天只有每 prompt G 同步B prompt 仍串行)。后训练栈保持完整,step decomposition 现为**实测**而非断言
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)
- **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。