diff --git a/docs/09-batched-forward.md b/docs/09-batched-forward.md new file mode 100644 index 0000000..583ba88 --- /dev/null +++ b/docs/09-batched-forward.md @@ -0,0 +1,187 @@ +# Phase T10: Batched 多序列 Forward — Design Document + +## Goal + +修 **KI-1 的根因**。v2 暴露、v3 重诊断的结论(见 [docs/known-issues.md](known-issues.md)): +吞吐瓶颈**不是** all-reduce,而是 **单序列模型设计 launch-bound**——T5 的模型一次只过**一条** +序列,每个 op 的 GEMM 都是 `[seq, dim]` 这种小矩阵,喂不饱 GPU;"batch" 是靠**循环 B 次 forward + +让 tape SUM 梯度**伪造的,于是 GPU util 只有 0–15%、显存占 ~8%,加大 global batch 也只是按比例 +增加串行 kernel-launch(v3 实测 gbatch 32→256 仅 +1.2%)。 + +T10 的目标:给 model + autograd 加 **batch 维**,把一个 step 的 B 条序列**一次性**过模型,让线性层变成 +**一个大 GEMM** 填满 GPU——这是 launch-bound 的根本解。**硬闸门是正确性**:所有既有 grad-check / +PyTorch 对拍 / overfit / DDP 跨 rank 一致**必须仍过**(PyTorch 对拍现在用 **B>1**),xserv 推理仍**逐 +token 一致**;在此之上拿吞吐收益。 + +## 核心设计:linears 摊平为 `[B*S, dim]`,attention 批量化 + +一个关键观察:**transformer 里只有 attention 需要"序列感知"**,其余全是 per-token 的。 + +- **线性投影 / 逐元素 / norm / embedding / CE 天然是 2D `[rows, dim]` 上的运算**,不在乎 `rows` 是 + `seq` 还是 `B*S`。所以把激活**摊平成 `[B*S, dim]`** 喂进去,这些 op **原样复用**: + q/k/v/o、gate/up/down、lm_head 全部变成**一个大 `[B*S, dim] × [dim, out]` GEMM**(填 GPU 的收益所在); + embedding 对 `[B*S]` ids gather;rms_norm / qk_norm / swiglu / silu / 逐元素按行(per-token); + cross_entropy 对 `[B*S, vocab]` vs `[B*S]` targets,**mean over B*S 行 == 各序列 loss 的均值**。 +- **attention 是唯一序列感知的 op**:reshape 成 `[B, nh, S, hd]`,**每个序列内**做因果 SDPA(per-seq + 因果 mask,**绝不跨序列 attend**),写回 `[B*S, dim]`。在 `(B, nh)` 维上批量。 +- **RoPE 位置必须 per-sequence 复位**:位置 = 在**自己序列内**的下标(`row % S`),**不是**摊平后的全局 + 行号。这是正确性陷阱,显式处理(kernel 加 `period` 参数)。 + +```text +ids [B*S] + └ embedding → [B*S, dim] + 每层 block: + rms_norm → [B*S, dim] (per-token) + attention: + x@Wq/Wk/Wv → [B*S, dim] (大 GEMM) + reshape → [B*S, nh, hd] + qk-norm + RoPE(period=S) → [B*S, nh, hd] (RoPE 位置 = row % S) + → [B, nh, S, hd] → [B*nh, S, hd] + fused batched SDPA(因果) → [B*nh, S, hd] (2× 批量GEMM + 1× causal-softmax) + → [B*S, dim]; @Wo → [B*S, dim] (大 GEMM) + +residual; rms_norm; SwiGLU MLP(大 GEMM); +residual + final rms_norm; @lm_head → [B*S, vocab] +cross_entropy([B*S, vocab], targets[B*S]) → 标量(B*S 行的 mean) +``` + +`forward(ids[seq])` 现在只是 `forward_batched(ids, batch=1)` 的特例 → **采样 / 推理路径 batch=1 不变**。 + +## Module Layout + +``` +csrc/ops/attention.cu # 新:causal 行 softmax(含 scale + per-seq 因果 mask) +csrc/ops/nn.cu # rope/rope_dx 加 period(位置 = tok % period) +csrc/ops/model.cu # 加 transpose_4d12([B,S,nh,hd]<->[B,nh,S,hd]) +crates/xtrain-cuda/ +├── src/cublas.rs # 加 sgemm_strided_batched(批量 GEMM,行主序同 sgemm 套路) +├── src/ffi.rs # +cublasSgemmStridedBatched / launch_softmax_causal / transpose_4d12;rope 加 period +└── build.rs # +attention.cu +crates/xtrain-tensor/src/tensor.rs # 加 attention/attention_backward(fused,2/4× 批量GEMM)、transpose_4d12;rope(+period) +crates/xtrain-autodiff/ +├── src/ops.rs # 加 attention 节点、transpose_4d12 节点;rope(+period) +└── tests/autograd.rs # +rope_batched / transpose_4d12 / attention(batched) grad-check +crates/xtrain-model/ +├── src/model.rs # forward_batched(ids,batch)/loss_batched;attention 走 fused 批量 op;删 causal_mask 叶 +├── src/lib.rs # +batched_ids_tensor +├── tests/batched.rs # 新:batched == looped 单序列(logits + 梯度) +├── tests/parity_dump.rs + parity.py # PyTorch 对拍改 B=2,S=4(per-seq RoPE + per-seq mask) +crates/xtrain-train/src/train_loop.rs # 真 batch:一次 batched forward/backward 替代 loop+SUM;clip pre-scale 1.0 +crates/xtrain-distributed/ +├── src/ddp.rs # 每 rank 一次 batched forward(b_local 条);clip pre-scale 1.0 +└── tests/ddp_correctness.rs # 单卡基线也改 batched(对齐) +``` + +## Key Design Decisions + +### ① 为什么"摊平 linears"是主要收益 + +GPU 填不满是因为 GEMM 太小(`[256, 384]` 这种)。摊平成 `[B*S, dim]` 后,B=16/S=256 时左矩阵是 +`[4096, 384]`——同一个 cuBLAS GEMM kernel,但 M 大了 16×,一次 launch 干 16× 的活,启动开销被摊薄, +SM 占用率上去。**embedding / rms_norm / silu / CE 等本就按行**,摊平后只是行数变多,**零改动复用**。 + +### ② RoPE 位置 per-sequence 复位(正确性陷阱) + +摊平后第 `r` 行是序列 `r/S` 的第 `r%S` 个 token。RoPE 角度 = `pos * freq`,**`pos` 必须是 `r%S`**, +否则第 2 条及之后的序列会用错位置、与单序列结果不一致、PyTorch 对拍直接挂。kernel 加 `period` +(= 序列长度 S)参数,`pos = tok % period`;`period == tokens`(单序列)时退化为原"位置=行号"。 +反向同样按 `period` 取角度(RoPE 是正交映射,反向 = 逆旋转,不需缓存 forward)。 + +### ③ Attention:先摊平到 `[B*nh, S, hd]`,fused 批量 SDPA + +q/k/v 投影后是 `[B*S, nh, hd]`(顺序 b,s,head,hd)。要喂批量 GEMM,需排成 `[B*nh, S, hd]`(每个 batch +元素 = 一个 (b, head) 的整条序列): + +```text +[B*S, nh, hd] → reshape [B, S, nh, hd] → transpose_4d12 → [B, nh, S, hd] → reshape [B*nh, S, hd] +``` + +`transpose_4d12`(轴 1,2 互换,新加的结构 op,自反式 backward)是关键。然后 **fused attention op**: + +```text +scores[B*nh,S,S] = Q · Kᵀ (cublasSgemmStridedBatched) +P[B*nh,S,S] = softmax(causal(scores / √hd)) (launch_softmax_causal,行内做 mask+scale) +out[B*nh,S,hd] = P · V (cublasSgemmStridedBatched) +``` + +**整个 attention 不论 B*nh 多大,都是 3 次 kernel launch**(2 批量 GEMM + 1 softmax),没有 per-head / +per-seq 的 Python 循环,也没有 host 往返。 + +**因果 mask 内联在 softmax kernel**:第 `r` 行的 query 位置 = `r % S`,列 `j > r%S` 是未来 → 概率置 0, +不需要 `[S,S]` 的 `-1e9` 加性 mask 张量(连 `causal_mask` 叶子都删了)。`1/√hd` 的 scale 也折进 +softmax(取 max/exp 前乘),省一个 scale pass。 + +> **踩坑记录(重要)**:T10 的第一版 attention 用"per-(batch,head) 循环 + `split_heads`/`merge_heads`", +> 而 split/merge_heads 内部走 **host 往返**(`to_device(Cpu)` 再传回)。linears 确实摊平成大 GEMM 了, +> 但 B=16 时 attention 路径里成千上万次小 kernel + 几 MB 的 host memcpy **反而把吞吐拖到 1127 tok/s +> (比单序列还慢)**。教训:摊平 linears 是必要非充分——attention 的 host 往返 / launch 风暴必须一起干掉。 +> fused 批量 SDPA 之后才到 **25.6K tok/s**。 + +### ④ Attention backward(手写,grad-check 兜底) + +fused op 的反向是标准 attention 反传,全用批量 GEMM + 复用既有 `softmax_backward`: + +```text +dP = dOut · Vᵀ ; dV = Pᵀ · dOut +dScores = softmax_jacobian(P, dP) · (1/√hd) (复用行 softmax 反向;mask 处 P=0 → 自动零梯度) +dQ = dScores · K ; dK = dScoresᵀ · Q +``` + +新加的 `attention` / `transpose_4d12` 都有 finite-diff grad-check(见 ⑥)。 + +### ⑤ 训练 loop:真 batch 替代 loop+SUM + +旧:循环 B 次 `model.loss(单序列)` + `backward()`,靠 tape 把 B 份梯度 SUM,clip 时 ×`1/B` 还原 batch-mean。 +新:`batched_ids_tensor` 把 B 条序列摊平成 `[B*S]`,**一次** `loss_batched` + **一次** `backward()`。 +CE 已是 B*S 行的 mean = batch-mean loss,**backward 直接给出 batch-mean 梯度**,所以 **clip pre-scale = 1.0** +(不再有 loop+SUM+×1/B)。 + +**DDP 同理且保持等价**:每 rank 跑 `b_local = B_global/world` 条的**一次** batched forward → backward +梯度 = 本地 mean `Σ_local/b_local`;`all_reduce_average`(跨 rank 求和 /world)得 +`Σ_global/(world·b_local) = Σ_global/B_global` = 全局 batch-mean → clip pre-scale 也 = 1.0。 +单卡基线(`ddp_correctness` 里)同步改成对全局 batch 的一次 batched forward,二者对齐。 + +## 验证方法 + +**双闸门,都必须过。** + +### 正确性(无回归) + +- **算子级 finite-diff**(`xtrain-autodiff`,15 个):新增 `rope_batched`(per-seq 位置)、`transpose_4d12`、 + `attention(batched)` 的 dQ/dK/dV,连同既有 12 个全过。 + 实测:`attn(batched) dQ 7.5e-3 / dK 1.5e-2 / dV 2.9e-4`,`transpose_4d12 8.2e-5`,`rope_batched 4.5e-4`。 +- **batched == looped 单序列**(`xtrain-model/tests/batched.rs`,新):同一组权重,batched forward 对 + "逐序列单独 forward 再拼接" 的 logits + 每参数梯度逐一对比。实测 **logits 完全一致(0.0)**, + 梯度 max rel **6.4e-4**,loss 完全一致。 +- **PyTorch 对拍 B>1**(`parity.py`,改 B=2/S=4):等价 PyTorch 模型(per-seq RoPE `pos=row%S`、per-seq + 因果 mask、QK-norm、SwiGLU)对拍 forward logits + 全部 25 个参数梯度。实测 **loss relerr 5e-8、logits + 6.9e-6、梯度全在 rtol 2e-2 内**。 +- **overfit**(27/27 token)、**checkpoint 逐位**、**AdamW 对 torch**、**DDP 跨 rank 参数 bit-identical + (0.0) + DDP loss 对单卡 5.7e-7**:全过。 +- **xserv 闭环**:短训→导出→xserv serve,对 xtrain 贪心**仍逐 token 一致**;采样路径 batch=1 仍工作。 + +### 吞吐(收益) + +单卡 dim384/12L/12h、batch 16、seq 256,back-to-back: + +| 路径 | tok/s | GPU util | 显存 | +|---|---|---|---| +| **before**(单序列 launch-bound,KI-1 基线)| ~1653 | 0–15% | ~3 GB | +| T10 第一版(looped split/merge,host 往返)| 1127 | ~11% | 14.8 GB | +| **after**(fused 批量 attention,batch 16)| **25627** | **37% 均值 / 54% 峰** | 10.2 GB | +| **after**(batch 32)| **40263** | — | — | + +→ 单卡相对 KI-1 基线 **~15.5×**(batch 16)/ **~24×**(batch 32),GPU util 0–15% → 37–54%。 + +**DDP 4 卡(dim384, per-rank batch 32, global 128)**:1 卡 40.3K → 4 卡 47.2K tok/s(global), +仅 **~1.17×**。这是 batched forward **新暴露**的下一层瓶颈:单卡 compute 快了 15–24×后,每步**对全部 +67M 参数的 eager all-reduce + host 侧 optimizer/clip 同步**成了 DDP 的主导开销(不随卡数缩小)。 +注意:**单卡 batch 32 = 40K tok/s 已经把 KI-1 时代的 4 卡 3163 tok/s 干翻 ~12×**——根因(单卡 +launch-bound)已修。DDP 近线性需要 **bucketed / 与 backward overlap 的 all-reduce**(KI-1 修复项 2), +此前 all-reduce 非瓶颈做了没用,现在才有意义——列为 v3+ 的 follow-up(本 Phase 范围外)。 + +## 给 v3 的 note(已解锁) + +batched forward 就位后,v3(dim512/16L、200–300M tok)建议:**per-rank batch 16–32、seq 256,4 卡 +global batch 64–128**。按单卡 ~25K tok/s(dim384,dim512 略低)、4 卡放大,200–300M tok 的训练时间从 +KI-1 时代估的 17–26h 压到**数小时级**。bucketed/overlapped all-reduce(KI-1 修复项 2)现在才有意义, +作为 v3 之后的进一步优化项。