docs/09-batched-forward.md: the launch-bound diagnosis recap, the [B*S,dim]-flatten + fused batched-attention design (RoPE per-seq position + causal masking inline in softmax), the attention forward/backward via strided-batched GEMM, autograd implications, the looped-split/merge dead-end post-mortem (1127 tok/s, host round-trips), verification methods + before→after throughput, and the v3 recommendation (per-rank batch 16-32, single/small world until KI-5 bucketed all-reduce lands). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
12 KiB
Phase T10: Batched 多序列 Forward — Design Document
Goal
修 KI-1 的根因。v2 暴露、v3 重诊断的结论(见 docs/known-issues.md):
吞吐瓶颈不是 all-reduce,而是 单序列模型设计 launch-bound——T5 的模型一次只过一条
序列,每个 op 的 GEMM 都是 [seq, dim] 这种小矩阵,喂不饱 GPU;"batch" 是靠循环 B 次 forward +
让 tape SUM 梯度伪造的,于是 GPU util 只有 0–15%、显存占 ~8%,加大 global batch 也只是按比例
增加串行 kernel-launch(v3 实测 gbatch 32→256 仅 +1.2%)。
T10 的目标:给 model + autograd 加 batch 维,把一个 step 的 B 条序列一次性过模型,让线性层变成 一个大 GEMM 填满 GPU——这是 launch-bound 的根本解。硬闸门是正确性:所有既有 grad-check / PyTorch 对拍 / overfit / DDP 跨 rank 一致必须仍过(PyTorch 对拍现在用 B>1),xserv 推理仍逐 token 一致;在此之上拿吞吐收益。
核心设计:linears 摊平为 [B*S, dim],attention 批量化
一个关键观察:transformer 里只有 attention 需要"序列感知",其余全是 per-token 的。
- 线性投影 / 逐元素 / norm / embedding / CE 天然是 2D
[rows, dim]上的运算,不在乎rows是seq还是B*S。所以把激活摊平成[B*S, dim]喂进去,这些 op 原样复用: q/k/v/o、gate/up/down、lm_head 全部变成一个大[B*S, dim] × [dim, out]GEMM(填 GPU 的收益所在); embedding 对[B*S]ids gather;rms_norm / qk_norm / swiglu / silu / 逐元素按行(per-token); cross_entropy 对[B*S, vocab]vs[B*S]targets,mean over B*S 行 == 各序列 loss 的均值。 - attention 是唯一序列感知的 op:reshape 成
[B, nh, S, hd],每个序列内做因果 SDPA(per-seq 因果 mask,绝不跨序列 attend),写回[B*S, dim]。在(B, nh)维上批量。 - RoPE 位置必须 per-sequence 复位:位置 = 在自己序列内的下标(
row % S),不是摊平后的全局 行号。这是正确性陷阱,显式处理(kernel 加period参数)。
ids [B*S]
└ embedding → [B*S, dim]
每层 block:
rms_norm → [B*S, dim] (per-token)
attention:
x@Wq/Wk/Wv → [B*S, dim] (大 GEMM)
reshape → [B*S, nh, hd]
qk-norm + RoPE(period=S) → [B*S, nh, hd] (RoPE 位置 = row % S)
→ [B, nh, S, hd] → [B*nh, S, hd]
fused batched SDPA(因果) → [B*nh, S, hd] (2× 批量GEMM + 1× causal-softmax)
→ [B*S, dim]; @Wo → [B*S, dim] (大 GEMM)
+residual; rms_norm; SwiGLU MLP(大 GEMM); +residual
final rms_norm; @lm_head → [B*S, vocab]
cross_entropy([B*S, vocab], targets[B*S]) → 标量(B*S 行的 mean)
forward(ids[seq]) 现在只是 forward_batched(ids, batch=1) 的特例 → 采样 / 推理路径 batch=1 不变。
Module Layout
csrc/ops/attention.cu # 新:causal 行 softmax(含 scale + per-seq 因果 mask)
csrc/ops/nn.cu # rope/rope_dx 加 period(位置 = tok % period)
csrc/ops/model.cu # 加 transpose_4d12([B,S,nh,hd]<->[B,nh,S,hd])
crates/xtrain-cuda/
├── src/cublas.rs # 加 sgemm_strided_batched(批量 GEMM,行主序同 sgemm 套路)
├── src/ffi.rs # +cublasSgemmStridedBatched / launch_softmax_causal / transpose_4d12;rope 加 period
└── build.rs # +attention.cu
crates/xtrain-tensor/src/tensor.rs # 加 attention/attention_backward(fused,2/4× 批量GEMM)、transpose_4d12;rope(+period)
crates/xtrain-autodiff/
├── src/ops.rs # 加 attention 节点、transpose_4d12 节点;rope(+period)
└── tests/autograd.rs # +rope_batched / transpose_4d12 / attention(batched) grad-check
crates/xtrain-model/
├── src/model.rs # forward_batched(ids,batch)/loss_batched;attention 走 fused 批量 op;删 causal_mask 叶
├── src/lib.rs # +batched_ids_tensor
├── tests/batched.rs # 新:batched == looped 单序列(logits + 梯度)
├── tests/parity_dump.rs + parity.py # PyTorch 对拍改 B=2,S=4(per-seq RoPE + per-seq mask)
crates/xtrain-train/src/train_loop.rs # 真 batch:一次 batched forward/backward 替代 loop+SUM;clip pre-scale 1.0
crates/xtrain-distributed/
├── src/ddp.rs # 每 rank 一次 batched forward(b_local 条);clip pre-scale 1.0
└── tests/ddp_correctness.rs # 单卡基线也改 batched(对齐)
Key Design Decisions
① 为什么"摊平 linears"是主要收益
GPU 填不满是因为 GEMM 太小([256, 384] 这种)。摊平成 [B*S, dim] 后,B=16/S=256 时左矩阵是
[4096, 384]——同一个 cuBLAS GEMM kernel,但 M 大了 16×,一次 launch 干 16× 的活,启动开销被摊薄,
SM 占用率上去。embedding / rms_norm / silu / CE 等本就按行,摊平后只是行数变多,零改动复用。
② RoPE 位置 per-sequence 复位(正确性陷阱)
摊平后第 r 行是序列 r/S 的第 r%S 个 token。RoPE 角度 = pos * freq,pos 必须是 r%S,
否则第 2 条及之后的序列会用错位置、与单序列结果不一致、PyTorch 对拍直接挂。kernel 加 period
(= 序列长度 S)参数,pos = tok % period;period == tokens(单序列)时退化为原"位置=行号"。
反向同样按 period 取角度(RoPE 是正交映射,反向 = 逆旋转,不需缓存 forward)。
③ Attention:先摊平到 [B*nh, S, hd],fused 批量 SDPA
q/k/v 投影后是 [B*S, nh, hd](顺序 b,s,head,hd)。要喂批量 GEMM,需排成 [B*nh, S, hd](每个 batch
元素 = 一个 (b, head) 的整条序列):
[B*S, nh, hd] → reshape [B, S, nh, hd] → transpose_4d12 → [B, nh, S, hd] → reshape [B*nh, S, hd]
transpose_4d12(轴 1,2 互换,新加的结构 op,自反式 backward)是关键。然后 fused attention op:
scores[B*nh,S,S] = Q · Kᵀ (cublasSgemmStridedBatched)
P[B*nh,S,S] = softmax(causal(scores / √hd)) (launch_softmax_causal,行内做 mask+scale)
out[B*nh,S,hd] = P · V (cublasSgemmStridedBatched)
整个 attention 不论 B*nh 多大,都是 3 次 kernel launch(2 批量 GEMM + 1 softmax),没有 per-head / per-seq 的 Python 循环,也没有 host 往返。
因果 mask 内联在 softmax kernel:第 r 行的 query 位置 = r % S,列 j > r%S 是未来 → 概率置 0,
不需要 [S,S] 的 -1e9 加性 mask 张量(连 causal_mask 叶子都删了)。1/√hd 的 scale 也折进
softmax(取 max/exp 前乘),省一个 scale pass。
踩坑记录(重要):T10 的第一版 attention 用"per-(batch,head) 循环 +
split_heads/merge_heads", 而 split/merge_heads 内部走 host 往返(to_device(Cpu)再传回)。linears 确实摊平成大 GEMM 了, 但 B=16 时 attention 路径里成千上万次小 kernel + 几 MB 的 host memcpy 反而把吞吐拖到 1127 tok/s (比单序列还慢)。教训:摊平 linears 是必要非充分——attention 的 host 往返 / launch 风暴必须一起干掉。 fused 批量 SDPA 之后才到 25.6K tok/s。
④ Attention backward(手写,grad-check 兜底)
fused op 的反向是标准 attention 反传,全用批量 GEMM + 复用既有 softmax_backward:
dP = dOut · Vᵀ ; dV = Pᵀ · dOut
dScores = softmax_jacobian(P, dP) · (1/√hd) (复用行 softmax 反向;mask 处 P=0 → 自动零梯度)
dQ = dScores · K ; dK = dScoresᵀ · Q
新加的 attention / transpose_4d12 都有 finite-diff grad-check(见 ⑥)。
⑤ 训练 loop:真 batch 替代 loop+SUM
旧:循环 B 次 model.loss(单序列) + backward(),靠 tape 把 B 份梯度 SUM,clip 时 ×1/B 还原 batch-mean。
新:batched_ids_tensor 把 B 条序列摊平成 [B*S],一次 loss_batched + 一次 backward()。
CE 已是 B*S 行的 mean = batch-mean loss,backward 直接给出 batch-mean 梯度,所以 clip pre-scale = 1.0
(不再有 loop+SUM+×1/B)。
DDP 同理且保持等价:每 rank 跑 b_local = B_global/world 条的一次 batched forward → backward
梯度 = 本地 mean Σ_local/b_local;all_reduce_average(跨 rank 求和 /world)得
Σ_global/(world·b_local) = Σ_global/B_global = 全局 batch-mean → clip pre-scale 也 = 1.0。
单卡基线(ddp_correctness 里)同步改成对全局 batch 的一次 batched forward,二者对齐。
验证方法
双闸门,都必须过。
正确性(无回归)
- 算子级 finite-diff(
xtrain-autodiff,15 个):新增rope_batched(per-seq 位置)、transpose_4d12、attention(batched)的 dQ/dK/dV,连同既有 12 个全过。 实测:attn(batched) dQ 7.5e-3 / dK 1.5e-2 / dV 2.9e-4,transpose_4d12 8.2e-5,rope_batched 4.5e-4。 - batched == looped 单序列(
xtrain-model/tests/batched.rs,新):同一组权重,batched forward 对 "逐序列单独 forward 再拼接" 的 logits + 每参数梯度逐一对比。实测 logits 完全一致(0.0), 梯度 max rel 6.4e-4,loss 完全一致。 - PyTorch 对拍 B>1(
parity.py,改 B=2/S=4):等价 PyTorch 模型(per-seq RoPEpos=row%S、per-seq 因果 mask、QK-norm、SwiGLU)对拍 forward logits + 全部 25 个参数梯度。实测 loss relerr 5e-8、logits 6.9e-6、梯度全在 rtol 2e-2 内。 - overfit(27/27 token)、checkpoint 逐位、AdamW 对 torch、DDP 跨 rank 参数 bit-identical (0.0) + DDP loss 对单卡 5.7e-7:全过。
- xserv 闭环:短训→导出→xserv serve,对 xtrain 贪心仍逐 token 一致;采样路径 batch=1 仍工作。
吞吐(收益)
单卡 dim384/12L/12h、batch 16、seq 256,back-to-back:
| 路径 | tok/s | GPU util | 显存 |
|---|---|---|---|
| before(单序列 launch-bound,KI-1 基线) | ~1653 | 0–15% | ~3 GB |
| T10 第一版(looped split/merge,host 往返) | 1127 | ~11% | 14.8 GB |
| after(fused 批量 attention,batch 16) | 25627 | 37% 均值 / 54% 峰 | 10.2 GB |
| after(batch 32) | 40263 | — | — |
→ 单卡相对 KI-1 基线 ~15.5×(batch 16)/ ~24×(batch 32),GPU util 0–15% → 37–54%。
DDP 4 卡(dim384, per-rank batch 32, global 128):1 卡 40.3K → 4 卡 47.2K tok/s(global), 仅 ~1.17×。这是 batched forward 新暴露的下一层瓶颈:单卡 compute 快了 15–24×后,每步对全部 67M 参数的 eager all-reduce + host 侧 optimizer/clip 同步成了 DDP 的主导开销(不随卡数缩小)。 注意:单卡 batch 32 = 40K tok/s 已经把 KI-1 时代的 4 卡 3163 tok/s 干翻 ~12×——根因(单卡 launch-bound)已修。DDP 近线性需要 bucketed / 与 backward overlap 的 all-reduce(KI-1 修复项 2), 此前 all-reduce 非瓶颈做了没用,现在才有意义——列为 v3+ 的 follow-up(本 Phase 范围外)。
给 v3 的 note(已解锁)
batched forward 就位后,v3(dim512/16L、200–300M tok)建议:per-rank batch 16–32、seq 256,4 卡 global batch 64–128。按单卡 ~25K tok/s(dim384,dim512 略低)、4 卡放大,200–300M tok 的训练时间从 KI-1 时代估的 17–26h 压到数小时级。bucketed/overlapped all-reduce(KI-1 修复项 2)现在才有意义, 作为 v3 之后的进一步优化项。