docs: Phase T10 — batched forward

docs/09-batched-forward.md: the launch-bound diagnosis recap, the
[B*S,dim]-flatten + fused batched-attention design (RoPE per-seq position +
causal masking inline in softmax), the attention forward/backward via
strided-batched GEMM, autograd implications, the looped-split/merge dead-end
post-mortem (1127 tok/s, host round-trips), verification methods + before→after
throughput, and the v3 recommendation (per-rank batch 16-32, single/small world
until KI-5 bucketed all-reduce lands).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 00:44:50 +08:00
parent 4ccab0fb42
commit 9a25616a30

187
docs/09-batched-forward.md Normal file
View File

@@ -0,0 +1,187 @@
# Phase T10: Batched 多序列 Forward — Design Document
## Goal
**KI-1 的根因**。v2 暴露、v3 重诊断的结论(见 [docs/known-issues.md](known-issues.md)
吞吐瓶颈**不是** all-reduce而是 **单序列模型设计 launch-bound**——T5 的模型一次只过**一条**
序列,每个 op 的 GEMM 都是 `[seq, dim]` 这种小矩阵,喂不饱 GPU"batch" 是靠**循环 B 次 forward +
让 tape SUM 梯度**伪造的,于是 GPU util 只有 015%、显存占 ~8%,加大 global batch 也只是按比例
增加串行 kernel-launchv3 实测 gbatch 32→256 仅 +1.2%)。
T10 的目标:给 model + autograd 加 **batch 维**,把一个 step 的 B 条序列**一次性**过模型,让线性层变成
**一个大 GEMM** 填满 GPU——这是 launch-bound 的根本解。**硬闸门是正确性**:所有既有 grad-check /
PyTorch 对拍 / overfit / DDP 跨 rank 一致**必须仍过**PyTorch 对拍现在用 **B>1**xserv 推理仍**逐
token 一致**;在此之上拿吞吐收益。
## 核心设计linears 摊平为 `[B*S, dim]`attention 批量化
一个关键观察:**transformer 里只有 attention 需要"序列感知"**,其余全是 per-token 的。
- **线性投影 / 逐元素 / norm / embedding / CE 天然是 2D `[rows, dim]` 上的运算**,不在乎 `rows`
`seq` 还是 `B*S`。所以把激活**摊平成 `[B*S, dim]`** 喂进去,这些 op **原样复用**
q/k/v/o、gate/up/down、lm_head 全部变成**一个大 `[B*S, dim] × [dim, out]` GEMM**(填 GPU 的收益所在);
embedding 对 `[B*S]` ids gatherrms_norm / qk_norm / swiglu / silu / 逐元素按行per-token
cross_entropy 对 `[B*S, vocab]` vs `[B*S]` targets**mean over B*S 行 == 各序列 loss 的均值**。
- **attention 是唯一序列感知的 op**reshape 成 `[B, nh, S, hd]`**每个序列内**做因果 SDPAper-seq
因果 mask**绝不跨序列 attend**),写回 `[B*S, dim]`。在 `(B, nh)` 维上批量。
- **RoPE 位置必须 per-sequence 复位**:位置 = 在**自己序列内**的下标(`row % S`**不是**摊平后的全局
行号。这是正确性陷阱显式处理kernel 加 `period` 参数)。
```text
ids [B*S]
└ embedding → [B*S, dim]
每层 block
rms_norm → [B*S, dim] (per-token)
attention:
x@Wq/Wk/Wv → [B*S, dim] (大 GEMM)
reshape → [B*S, nh, hd]
qk-norm + RoPE(period=S) → [B*S, nh, hd] (RoPE 位置 = row % S)
→ [B, nh, S, hd] → [B*nh, S, hd]
fused batched SDPA(因果) → [B*nh, S, hd] (2× 批量GEMM + 1× causal-softmax)
→ [B*S, dim]; @Wo → [B*S, dim] (大 GEMM)
+residual; rms_norm; SwiGLU MLP(大 GEMM); +residual
final rms_norm; @lm_head → [B*S, vocab]
cross_entropy([B*S, vocab], targets[B*S]) → 标量B*S 行的 mean
```
`forward(ids[seq])` 现在只是 `forward_batched(ids, batch=1)` 的特例 → **采样 / 推理路径 batch=1 不变**
## Module Layout
```
csrc/ops/attention.cu # 新causal 行 softmax含 scale + per-seq 因果 mask
csrc/ops/nn.cu # rope/rope_dx 加 period位置 = tok % period
csrc/ops/model.cu # 加 transpose_4d12[B,S,nh,hd]<->[B,nh,S,hd]
crates/xtrain-cuda/
├── src/cublas.rs # 加 sgemm_strided_batched批量 GEMM行主序同 sgemm 套路)
├── src/ffi.rs # +cublasSgemmStridedBatched / launch_softmax_causal / transpose_4d12rope 加 period
└── build.rs # +attention.cu
crates/xtrain-tensor/src/tensor.rs # 加 attention/attention_backwardfused2/4× 批量GEMM、transpose_4d12rope(+period)
crates/xtrain-autodiff/
├── src/ops.rs # 加 attention 节点、transpose_4d12 节点rope(+period)
└── tests/autograd.rs # +rope_batched / transpose_4d12 / attention(batched) grad-check
crates/xtrain-model/
├── src/model.rs # forward_batched(ids,batch)/loss_batchedattention 走 fused 批量 op删 causal_mask 叶
├── src/lib.rs # +batched_ids_tensor
├── tests/batched.rs # 新batched == looped 单序列logits + 梯度)
├── tests/parity_dump.rs + parity.py # PyTorch 对拍改 B=2,S=4per-seq RoPE + per-seq mask
crates/xtrain-train/src/train_loop.rs # 真 batch一次 batched forward/backward 替代 loop+SUMclip pre-scale 1.0
crates/xtrain-distributed/
├── src/ddp.rs # 每 rank 一次 batched forwardb_local 条clip pre-scale 1.0
└── tests/ddp_correctness.rs # 单卡基线也改 batched对齐
```
## Key Design Decisions
### ① 为什么"摊平 linears"是主要收益
GPU 填不满是因为 GEMM 太小(`[256, 384]` 这种)。摊平成 `[B*S, dim]`B=16/S=256 时左矩阵是
`[4096, 384]`——同一个 cuBLAS GEMM kernel但 M 大了 16×一次 launch 干 16× 的活,启动开销被摊薄,
SM 占用率上去。**embedding / rms_norm / silu / CE 等本就按行**,摊平后只是行数变多,**零改动复用**。
### ② RoPE 位置 per-sequence 复位(正确性陷阱)
摊平后第 `r` 行是序列 `r/S` 的第 `r%S` 个 token。RoPE 角度 = `pos * freq`**`pos` 必须是 `r%S`**
否则第 2 条及之后的序列会用错位置、与单序列结果不一致、PyTorch 对拍直接挂。kernel 加 `period`
= 序列长度 S参数`pos = tok % period``period == tokens`(单序列)时退化为原"位置=行号"。
反向同样按 `period` 取角度RoPE 是正交映射,反向 = 逆旋转,不需缓存 forward
### ③ Attention先摊平到 `[B*nh, S, hd]`fused 批量 SDPA
q/k/v 投影后是 `[B*S, nh, hd]`(顺序 b,s,head,hd。要喂批量 GEMM需排成 `[B*nh, S, hd]`(每个 batch
元素 = 一个 (b, head) 的整条序列):
```text
[B*S, nh, hd] → reshape [B, S, nh, hd] → transpose_4d12 → [B, nh, S, hd] → reshape [B*nh, S, hd]
```
`transpose_4d12`(轴 1,2 互换,新加的结构 op自反式 backward是关键。然后 **fused attention op**
```text
scores[B*nh,S,S] = Q · Kᵀ (cublasSgemmStridedBatched)
P[B*nh,S,S] = softmax(causal(scores / √hd)) (launch_softmax_causal行内做 mask+scale)
out[B*nh,S,hd] = P · V (cublasSgemmStridedBatched)
```
**整个 attention 不论 B*nh 多大,都是 3 次 kernel launch**2 批量 GEMM + 1 softmax没有 per-head /
per-seq 的 Python 循环,也没有 host 往返。
**因果 mask 内联在 softmax kernel**:第 `r` 行的 query 位置 = `r % S`,列 `j > r%S` 是未来 → 概率置 0
不需要 `[S,S]``-1e9` 加性 mask 张量(连 `causal_mask` 叶子都删了)。`1/√hd` 的 scale 也折进
softmax取 max/exp 前乘),省一个 scale pass。
> **踩坑记录(重要)**T10 的第一版 attention 用"per-(batch,head) 循环 + `split_heads`/`merge_heads`"
> 而 split/merge_heads 内部走 **host 往返**`to_device(Cpu)` 再传回。linears 确实摊平成大 GEMM 了,
> 但 B=16 时 attention 路径里成千上万次小 kernel + 几 MB 的 host memcpy **反而把吞吐拖到 1127 tok/s
> (比单序列还慢)**。教训:摊平 linears 是必要非充分——attention 的 host 往返 / launch 风暴必须一起干掉。
> fused 批量 SDPA 之后才到 **25.6K tok/s**。
### ④ Attention backward手写grad-check 兜底)
fused op 的反向是标准 attention 反传,全用批量 GEMM + 复用既有 `softmax_backward`
```text
dP = dOut · Vᵀ ; dV = Pᵀ · dOut
dScores = softmax_jacobian(P, dP) · (1/√hd) (复用行 softmax 反向mask 处 P=0 → 自动零梯度)
dQ = dScores · K ; dK = dScoresᵀ · Q
```
新加的 `attention` / `transpose_4d12` 都有 finite-diff grad-check见 ⑥)。
### ⑤ 训练 loop真 batch 替代 loop+SUM
旧:循环 B 次 `model.loss(单序列)` + `backward()`,靠 tape 把 B 份梯度 SUMclip 时 ×`1/B` 还原 batch-mean。
新:`batched_ids_tensor` 把 B 条序列摊平成 `[B*S]`**一次** `loss_batched` + **一次** `backward()`
CE 已是 B*S 行的 mean = batch-mean loss**backward 直接给出 batch-mean 梯度**,所以 **clip pre-scale = 1.0**
(不再有 loop+SUM+×1/B
**DDP 同理且保持等价**:每 rank 跑 `b_local = B_global/world` 条的**一次** batched forward → backward
梯度 = 本地 mean `Σ_local/b_local``all_reduce_average`(跨 rank 求和 /world
`Σ_global/(world·b_local) = Σ_global/B_global` = 全局 batch-mean → clip pre-scale 也 = 1.0。
单卡基线(`ddp_correctness` 里)同步改成对全局 batch 的一次 batched forward二者对齐。
## 验证方法
**双闸门,都必须过。**
### 正确性(无回归)
- **算子级 finite-diff**`xtrain-autodiff`15 个):新增 `rope_batched`per-seq 位置)、`transpose_4d12`
`attention(batched)` 的 dQ/dK/dV连同既有 12 个全过。
实测:`attn(batched) dQ 7.5e-3 / dK 1.5e-2 / dV 2.9e-4``transpose_4d12 8.2e-5``rope_batched 4.5e-4`
- **batched == looped 单序列**`xtrain-model/tests/batched.rs`同一组权重batched forward 对
"逐序列单独 forward 再拼接" 的 logits + 每参数梯度逐一对比。实测 **logits 完全一致(0.0)**
梯度 max rel **6.4e-4**loss 完全一致。
- **PyTorch 对拍 B>1**`parity.py`,改 B=2/S=4等价 PyTorch 模型per-seq RoPE `pos=row%S`、per-seq
因果 mask、QK-norm、SwiGLU对拍 forward logits + 全部 25 个参数梯度。实测 **loss relerr 5e-8、logits
6.9e-6、梯度全在 rtol 2e-2 内**。
- **overfit**27/27 token、**checkpoint 逐位**、**AdamW 对 torch**、**DDP 跨 rank 参数 bit-identical
(0.0) + DDP loss 对单卡 5.7e-7**:全过。
- **xserv 闭环**短训→导出→xserv serve对 xtrain 贪心**仍逐 token 一致**;采样路径 batch=1 仍工作。
### 吞吐(收益)
单卡 dim384/12L/12h、batch 16、seq 256back-to-back
| 路径 | tok/s | GPU util | 显存 |
|---|---|---|---|
| **before**(单序列 launch-boundKI-1 基线)| ~1653 | 015% | ~3 GB |
| T10 第一版looped split/mergehost 往返)| 1127 | ~11% | 14.8 GB |
| **after**fused 批量 attentionbatch 16| **25627** | **37% 均值 / 54% 峰** | 10.2 GB |
| **after**batch 32| **40263** | — | — |
→ 单卡相对 KI-1 基线 **~15.5×**batch 16/ **~24×**batch 32GPU util 015% → 3754%。
**DDP 4 卡dim384, per-rank batch 32, global 128**1 卡 40.3K → 4 卡 47.2K tok/sglobal
**~1.17×**。这是 batched forward **新暴露**的下一层瓶颈:单卡 compute 快了 1524×每步**对全部
67M 参数的 eager all-reduce + host 侧 optimizer/clip 同步**成了 DDP 的主导开销(不随卡数缩小)。
注意:**单卡 batch 32 = 40K tok/s 已经把 KI-1 时代的 4 卡 3163 tok/s 干翻 ~12×**——根因(单卡
launch-bound已修。DDP 近线性需要 **bucketed / 与 backward overlap 的 all-reduce**KI-1 修复项 2
此前 all-reduce 非瓶颈做了没用,现在才有意义——列为 v3+ 的 follow-up本 Phase 范围外)。
## 给 v3 的 note已解锁
batched forward 就位后v3dim512/16L、200300M tok建议**per-rank batch 1632、seq 2564 卡
global batch 64128**。按单卡 ~25K tok/sdim384dim512 略低、4 卡放大200300M tok 的训练时间从
KI-1 时代估的 1726h 压到**数小时级**。bucketed/overlapped all-reduceKI-1 修复项 2现在才有意义
作为 v3 之后的进一步优化项。