docs: M2d — ragged-batching lever, 9× measured, step bottleneck → rollout

Records the M2d lever (batch the GRPO training-side forwards), the right-pad-is-free
insight, both exact gates, the end-to-end no-OOM smoke, and the 9× throughput.

The honest decomposition correction: M2c claimed the training forwards "dominate" the
step; the clean per-component bench falsifies the strong form — they were ~2.5 s of
the ~8.5 s step (~30%), worth the 9×, but the rollout (~6 s) was always the larger
share. After M2d the step is ~95% rollout, so the next step-level lever is full B×G
rollout batching (today only the G samples of each prompt decode in lockstep; the B
prompts are still sequential). Same measure-first lesson, once more.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-30 23:03:28 +08:00
parent 0e82b2438e
commit 4379868f2d
2 changed files with 53 additions and 0 deletions

View File

@@ -576,3 +576,54 @@ still a real, correctness-gated improvement (cleaner code, less PCIe, ~10% decod
headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**, headline is that the *next* decode lever is **ragged batched prefill of the per-sample forwards**,
not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache), not the cache. The M2 decode engine is now M2a (single-seq) + M2b (batched) + M2c (device cache),
all token-identical-gated; the post-training stack remains complete with its bottleneck mapped. all token-identical-gated; the post-training stack remains complete with its bottleneck mapped.
### M2d — batch the GRPO training-side forwards (landed; the lever M2c named, + a decomposition correction)
M2c named the next lever: **ragged batched prefill of the per-sample training-side forwards**. Those
forwards are the two phases that, per step, run one single-sequence `forward` per sample: the
`per_token_logp` **captures** (logπ_old policy + logπ_ref reference) and the inner **clipped-PG**
forward/backwards. M2d packs all `N = B·G` ragged samples of a step into ONE `forward_batched`.
**The enabling property — right-padding is free under causal attention.** Pad each ragged completion
on the RIGHT to the batch's `Lmax`. A real completion row sits at an earlier position than the
trailing pad, and causal masking forbids attending forward, so its logits are **bit-identical** to
the unpadded single-sequence forward; the pad rows are garbage but masked out (`target = -100`). This
is exactly why training engines pad-and-mask rather than run ragged. Two new pieces:
- `per_token_logp_batched` (`crates/xtrain-train/src/grpo_batch.rs`): right-pad → one
`forward_batched(batch = N)` → slice each sample's logπ back to its real length.
- `ops::clipped_pg_loss_batched` (`crates/xtrain-autodiff/src/ops.rs`): like the per-sample
`clipped_pg_loss`, but takes **per-row** `advantage[t]` (the owning sample's `A`) and **per-row**
`weight[t]` (the full normaliser; the caller passes `1/(N·n_s)`). It does NOT compute its own
`1/n_tokens`, so folding `weight = 1/(N·n_s)` reproduces the looped `Σ_s (1/N)(1/n_s)…`
**bit-for-bit** (the per-row CE backward is row-local). A `--micro` knob packs in chunks to bound
the `[chunk·Lmax, vocab]` logits memory; the weight uses the GLOBAL `N`, so chunked
grad-accumulation is exact. Both `train_grpo` and the bench call these shared helpers.
**Correctness gates (exact, not bf16-noisy):**
- `xtrain-model::forward_batched_ragged_matches_looped` — forward_batched on right-padded ragged
sequences == per-sequence single-seq forward on the real rows, **max|Δlogit| = 3.7e-7 (fp32) and
0.0 (bf16)**, both composed + flash. Pins "right-pad is free".
- `xtrain-autodiff::clipped_pg_loss_batched_matches_looped` — batched op == looped
`Σ_s (1/N)·clipped_pg_loss_s`, **loss Δ=1.5e-8, grad max|Δ|=7.5e-9 (f32)**.
Composed, these prove the batched GRPO step == the looped step. End-to-end: a short SFT (v12 base,
150 steps, arith) → `train_grpo` 12 steps runs clean — **no OOM** (1B master + AdamW + batched
activations fit with `micro=16`), mean-reward rises, the batched inner executes.
**Throughput (bench `bin/bench_grpo_batch`, v12 1.05B, N=48 ragged, micro=16, β=0, weight-independent):**
| phase (per step) | looped (single-seq) | batched (M2d) | speedup |
|-------------------------|---------------------|---------------|---------|
| capture `per_token_logp`| 622 ms | 71 ms | 8.7× |
| inner clipped-PG fwd+bwd| 1907 ms | 208 ms | 9.2× |
| **training forwards** | **2526 ms** | **280 ms** | **9.0×**|
**The decomposition correction (the honest finding).** M2c claimed "the per-sample training forwards
now dominate the step." The clean per-component bench falsifies the strong form: the training
forwards were **~2.5 s of the ~8.5 s step (~30%)** — substantial and worth the 9× win, but the
**rollout (`generate_cached_batch`, ~6 s) was always the larger share.** After M2d cuts the training
forwards to ~0.28 s, the step is **~95% rollout** — the long pole has swung back to the rollout. So
M2d removes the training-forward overhang (a real, exactly-gated 9× on its component), and re-confirms
the same measure-first lesson one more time: the next **step-level** lever is **full B×G rollout
batching** — today only the `G` samples of each prompt decode in lockstep (M2b); the `B` prompts are
still sequential. M2d closes the "ragged batched per-sample forwards" lever M2c named; the post-
training stack stays complete, now with the step decomposition measured, not asserted.

View File

@@ -107,6 +107,8 @@ Phase 1/2 把**预训练全栈**学完后Phase 3 转向**后训练 infra**
**M2cdevice 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**K/V device `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它( host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concatdecode_kv 单序列 + decode_batch 批量仍 **token-identical**GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133147 tok/s@128), **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward( T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实闸门齐全的改进(更干净 PCIe解码 +10%),但下一杠杆是 **per-sample forward ragged 批量**而非 cacheM2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache), token-identical-gated;后训练栈完整瓶颈已测绘 **M2cdevice 端 KV cache,已落地,瓶颈转移的 profile-first 发现)**K/V device `[bh,T,hd]`(每层 `Option<Tensor>`),每步用新 `cat_seq` kernel(沿 seq 拼接)append 一个 token——去掉 M2a/M2b 每层**主机往返** + `transpose_3d01`,单序列和批量都重构到它( host Vec+rebuild 干净)。闸门全保:`cat_seq`==host concatdecode_kv 单序列 + decode_batch 批量仍 **token-identical**GQA 训练路径不受影响。**发现(measure-first 的点,不是加速故事)**:去掉主机往返让**纯单序列解码 +10%**(133147 tok/s@128), **GRPO step 不动**(~8.5s/step)——因为 M2b 批量化后 rollout 已不是 step 瓶颈,**per-sample `per_token_logp` 捕获(2×/样本)+ PG 更新 forward/backward(全序列 `model.forward`)成了主导**。长杆从 rollout **转移**到训练侧 forward( T11/T17/M2a:profile 后再动手——你修的不是剩下的瓶颈)。device cache 仍是真实闸门齐全的改进(更干净 PCIe解码 +10%),但下一杠杆是 **per-sample forward ragged 批量**而非 cacheM2 引擎现 = M2a(单序列)+ M2b(批量)+ M2c(device cache), token-identical-gated;后训练栈完整瓶颈已测绘
**M2d批量 GRPO 训练侧 forward,已落地,M2c 点名的杠杆 + 一处 decomposition 纠正)**M2c 点名的下一杠杆——把每步 `N=B·G` ragged 样本的训练侧 forward(`per_token_logp` 捕获 + inner clipped-PG fwd/bwd)打包进**一次 `forward_batched`**。**使能性质 = causal 下右 padding 免费**:真 completion 行位置早于尾部 pad,causal 禁止前向 attend,故真行 logits 与单序列 forward **逐位相同**,pad 行垃圾被 `target=-100` 屏蔽——这正是训练引擎 pad-and-mask 而非跑 ragged 的原因两件新东西:`per_token_logp_batched`( pad 一次 `forward_batched(N)` 按真长切片)、`ops::clipped_pg_loss_batched`(per-row `advantage[t]` + per-row `weight[t]`,caller `1/(N·n_s)`,op 不再自算 `1/n_tokens` 折进 weight 即与 looped `Σ_s (1/N)(1/n_s)…` **逐位等价**;`--micro` 分块界定 `[chunk·Lmax,vocab]` logits 显存,weight 用全局 N 故分块梯度累积精确)。**两道精确闸门**:`forward_batched_ragged_matches_looped`( pad 批量 forward == 单序列,fp32 max|Δ|=3.7e-7bf16 **0.0**,composed+flash)+ `clipped_pg_loss_batched_matches_looped`(批量 op == looped,loss Δ=1.5e-8/grad 7.5e-9,f32),复合即证端到端等价;端到端短 SFT`train_grpo` 12 ** OOM**(1B master+AdamW+批量激活 micro=16 容得下)、批量 inner 执行。**吞吐(bench,v12 1.05B,N=48,micro16,权重无关)**:capture 62271ms(8.7×)、inner 1907208ms(9.2×)、**训练侧 forward 合计 2526280ms(9.0×)**。**Decomposition 纠正(诚实发现)**:M2c "训练侧 forward 主导 step",干净分量 bench 证伪强形式——训练侧 forward **~8.5s step 里的 ~2.5s(~30%)**,可观值这 9×, **rollout(`generate_cached_batch` ~6s)一直是更大头**;M2d 把训练侧砍到 ~0.28s ,step **~95% rollout**,长杆又摆回 rollout。⇒ M2d 拔掉训练侧 forward 这块 overhang(分量级精确 9×),再次印证 measure-first:**step 级下一杠杆 = B×G rollout 批量**(今天只有每 prompt G 同步B prompt 仍串行)。后训练栈保持完整,step decomposition 现为**实测**而非断言
## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md) ## 四、perf 杠杆台账(详见 [known-issues.md](known-issues.md)
- **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。 - **已修**KI-1 单序列 launch-boundT10)· KI-5 per-op cudaMalloc 串行T11)· KI-2 bf16/OOMT12)· KI-3 激活重计算T13解锁 dim1024v8 用上)。