docs: M2d — ragged-batching lever, 9× measured, step bottleneck → rollout
Records the M2d lever (batch the GRPO training-side forwards), the right-pad-is-free insight, both exact gates, the end-to-end no-OOM smoke, and the 9× throughput. The honest decomposition correction: M2c claimed the training forwards "dominate" the step; the clean per-component bench falsifies the strong form — they were ~2.5 s of the ~8.5 s step (~30%), worth the 9×, but the rollout (~6 s) was always the larger share. After M2d the step is ~95% rollout, so the next step-level lever is full B×G rollout batching (today only the G samples of each prompt decode in lockstep; the B prompts are still sequential). Same measure-first lesson, once more. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -576,3 +576,54 @@ still a real, correctness-gated improvement (cleaner code, less PCIe, ~10% decod
|
|||||||
headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**,
|
headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**,
|
||||||
not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache),
|
not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache),
|
||||||
all token-identical-gated; the post-training stack remains complete with its bottleneck mapped.
|
all token-identical-gated; the post-training stack remains complete with its bottleneck mapped.
|
||||||
|
|
||||||
|
### M2d — batch the GRPO training-side forwards (landed; the lever M2c named, + a decomposition correction)
|
||||||
|
|
||||||
|
M2c named the next lever: **ragged batched prefill of the per-sample training-side forwards**. Those
|
||||||
|
forwards are the two phases that, per step, run one single-sequence `forward` per sample: the
|
||||||
|
`per_token_logp` **captures** (logπ_old policy + logπ_ref reference) and the inner **clipped-PG**
|
||||||
|
forward/backwards. M2d packs all `N = B·G` ragged samples of a step into ONE `forward_batched`.
|
||||||
|
|
||||||
|
**The enabling property — right-padding is free under causal attention.** Pad each ragged completion
|
||||||
|
on the RIGHT to the batch's `Lmax`. A real completion row sits at an earlier position than the
|
||||||
|
trailing pad, and causal masking forbids attending forward, so its logits are **bit-identical** to
|
||||||
|
the unpadded single-sequence forward; the pad rows are garbage but masked out (`target = -100`). This
|
||||||
|
is exactly why training engines pad-and-mask rather than run ragged. Two new pieces:
|
||||||
|
- `per_token_logp_batched` (`crates/xtrain-train/src/grpo_batch.rs`): right-pad → one
|
||||||
|
`forward_batched(batch = N)` → slice each sample's logπ back to its real length.
|
||||||
|
- `ops::clipped_pg_loss_batched` (`crates/xtrain-autodiff/src/ops.rs`): like the per-sample
|
||||||
|
`clipped_pg_loss`, but takes **per-row** `advantage[t]` (the owning sample's `A`) and **per-row**
|
||||||
|
`weight[t]` (the full normaliser; the caller passes `1/(N·n_s)`). It does NOT compute its own
|
||||||
|
`1/n_tokens`, so folding `weight = 1/(N·n_s)` reproduces the looped `Σ_s (1/N)(1/n_s)…`
|
||||||
|
**bit-for-bit** (the per-row CE backward is row-local). A `--micro` knob packs in chunks to bound
|
||||||
|
the `[chunk·Lmax, vocab]` logits memory; the weight uses the GLOBAL `N`, so chunked
|
||||||
|
grad-accumulation is exact. Both `train_grpo` and the bench call these shared helpers.
|
||||||
|
|
||||||
|
**Correctness gates (exact, not bf16-noisy):**
|
||||||
|
- `xtrain-model::forward_batched_ragged_matches_looped` — forward_batched on right-padded ragged
|
||||||
|
sequences == per-sequence single-seq forward on the real rows, **max|Δlogit| = 3.7e-7 (fp32) and
|
||||||
|
0.0 (bf16)**, both composed + flash. Pins "right-pad is free".
|
||||||
|
- `xtrain-autodiff::clipped_pg_loss_batched_matches_looped` — batched op == looped
|
||||||
|
`Σ_s (1/N)·clipped_pg_loss_s`, **loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32)**.
|
||||||
|
Composed, these prove the batched GRPO step == the looped step. End-to-end: a short SFT (v12 base,
|
||||||
|
150 steps, arith) → `train_grpo` 12 steps runs clean — **no OOM** (1B master + AdamW + batched
|
||||||
|
activations fit with `micro=16`), mean-reward rises, the batched inner executes.
|
||||||
|
|
||||||
|
**Throughput (bench `bin/bench_grpo_batch`, v12 1.05B, N=48 ragged, micro=16, β=0, weight-independent):**
|
||||||
|
|
||||||
|
| phase (per step) | looped (single-seq) | batched (M2d) | speedup |
|
||||||
|
|-------------------------|---------------------|---------------|---------|
|
||||||
|
| capture `per_token_logp`| 622 ms | 71 ms | 8.7× |
|
||||||
|
| inner clipped-PG fwd+bwd| 1907 ms | 208 ms | 9.2× |
|
||||||
|
| **training forwards** | **2526 ms** | **280 ms** | **9.0×**|
|
||||||
|
|
||||||
|
**The decomposition correction (the honest finding).** M2c claimed "the per-sample training forwards
|
||||||
|
now dominate the step." The clean per-component bench falsifies the strong form: the training
|
||||||
|
forwards were **~2.5 s of the ~8.5 s step (~30%)** — substantial and worth the 9× win, but the
|
||||||
|
**rollout (`generate_cached_batch`, ~6 s) was always the larger share.** After M2d cuts the training
|
||||||
|
forwards to ~0.28 s, the step is **~95% rollout** — the long pole has swung back to the rollout. So
|
||||||
|
M2d removes the training-forward overhang (a real, exactly-gated 9× on its component), and re-confirms
|
||||||
|
the same measure-first lesson one more time: the next **step-level** lever is **full B×G rollout
|
||||||
|
batching** — today only the `G` samples of each prompt decode in lockstep (M2b); the `B` prompts are
|
||||||
|
still sequential. M2d closes the "ragged batched per-sample forwards" lever M2c named; the post-
|
||||||
|
training stack stays complete, now with the step decomposition measured, not asserted.
|
||||||
|
|||||||
@@ -107,6 +107,8 @@ Phase 1/2 把**预训练全栈**学完后,Phase 3 转向**后训练 infra**(
|
|||||||
|
|
||||||
**M2c(device 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**:K/V 留 device 为 `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它(比 host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concat、decode_kv 单序列 + decode_batch 批量仍 **token-identical**、GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133→147 tok/s@128),但 **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward(同 T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实、闸门齐全的改进(更干净、少 PCIe、解码 +10%),但下一杠杆是 **per-sample forward 的 ragged 批量**而非 cache。M2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache),全 token-identical-gated;后训练栈完整、瓶颈已测绘。
|
**M2c(device 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**:K/V 留 device 为 `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它(比 host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concat、decode_kv 单序列 + decode_batch 批量仍 **token-identical**、GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133→147 tok/s@128),但 **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward(同 T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实、闸门齐全的改进(更干净、少 PCIe、解码 +10%),但下一杠杆是 **per-sample forward 的 ragged 批量**而非 cache。M2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache),全 token-identical-gated;后训练栈完整、瓶颈已测绘。
|
||||||
|
|
||||||
|
**M2d(批量 GRPO 训练侧 forward,已落地,M2c 点名的杠杆 + 一处 decomposition 纠正)**:M2c 点名的下一杠杆——把每步 `N=B·G` 个 ragged 样本的训练侧 forward(`per_token_logp` 捕获 + inner clipped-PG fwd/bwd)打包进**一次 `forward_batched`**。**使能性质 = causal 下右 padding 免费**:真 completion 行位置早于尾部 pad,causal 禁止前向 attend,故真行 logits 与单序列 forward **逐位相同**,pad 行垃圾被 `target=-100` 屏蔽——这正是训练引擎 pad-and-mask 而非跑 ragged 的原因。两件新东西:`per_token_logp_batched`(右 pad → 一次 `forward_batched(N)` → 按真长切片)、`ops::clipped_pg_loss_batched`(per-row `advantage[t]` + per-row `weight[t]`,caller 传 `1/(N·n_s)`,op 不再自算 `1/n_tokens` → 折进 weight 即与 looped `Σ_s (1/N)(1/n_s)…` **逐位等价**;`--micro` 分块界定 `[chunk·Lmax,vocab]` logits 显存,weight 用全局 N 故分块梯度累积精确)。**两道精确闸门**:`forward_batched_ragged_matches_looped`(右 pad 批量 forward == 单序列,fp32 max|Δ|=3.7e-7、bf16 **0.0**,composed+flash)+ `clipped_pg_loss_batched_matches_looped`(批量 op == looped,loss Δ=1.5e-8/grad 7.5e-9,f32),复合即证端到端等价;端到端短 SFT→`train_grpo` 12 步**不 OOM**(1B master+AdamW+批量激活 micro=16 容得下)、批量 inner 执行。**吞吐(bench,v12 1.05B,N=48,micro16,权重无关)**:capture 622→71ms(8.7×)、inner 1907→208ms(9.2×)、**训练侧 forward 合计 2526→280ms(9.0×)**。**Decomposition 纠正(诚实发现)**:M2c 说"训练侧 forward 主导 step",干净分量 bench 证伪强形式——训练侧 forward 是 **~8.5s step 里的 ~2.5s(~30%)**,可观、值这 9×,但 **rollout(`generate_cached_batch` ~6s)一直是更大头**;M2d 把训练侧砍到 ~0.28s 后,step **~95% 是 rollout**,长杆又摆回 rollout。⇒ M2d 拔掉训练侧 forward 这块 overhang(分量级精确 9×),再次印证 measure-first:**step 级下一杠杆 = 全 B×G rollout 批量**(今天只有每 prompt 的 G 同步、B 个 prompt 仍串行)。后训练栈保持完整,step decomposition 现为**实测**而非断言。
|
||||||
|
|
||||||
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md))
|
||||||
|
|
||||||
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
- **已修**:KI-1 单序列 launch-bound(T10)· KI-5 per-op cudaMalloc 串行(T11)· KI-2 bf16/OOM(T12)· KI-3 激活重计算(T13,解锁 dim1024,v8 用上)。
|
||||||
|
|||||||
Reference in New Issue
Block a user