Files
xtrain/docs/09-batched-forward.md
Gahow Wang 9a25616a30 docs: Phase T10 — batched forward
docs/09-batched-forward.md: the launch-bound diagnosis recap, the
[B*S,dim]-flatten + fused batched-attention design (RoPE per-seq position +
causal masking inline in softmax), the attention forward/backward via
strided-batched GEMM, autograd implications, the looped-split/merge dead-end
post-mortem (1127 tok/s, host round-trips), verification methods + before→after
throughput, and the v3 recommendation (per-rank batch 16-32, single/small world
until KI-5 bucketed all-reduce lands).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:50 +08:00

12 KiB
Raw Permalink Blame History

Phase T10: Batched 多序列 Forward — Design Document

Goal

KI-1 的根因。v2 暴露、v3 重诊断的结论(见 docs/known-issues.md 吞吐瓶颈不是 all-reduce而是 单序列模型设计 launch-bound——T5 的模型一次只过一条 序列,每个 op 的 GEMM 都是 [seq, dim] 这种小矩阵,喂不饱 GPU"batch" 是靠循环 B 次 forward + 让 tape SUM 梯度伪造的,于是 GPU util 只有 015%、显存占 ~8%,加大 global batch 也只是按比例 增加串行 kernel-launchv3 实测 gbatch 32→256 仅 +1.2%)。

T10 的目标:给 model + autograd 加 batch 维,把一个 step 的 B 条序列一次性过模型,让线性层变成 一个大 GEMM 填满 GPU——这是 launch-bound 的根本解。硬闸门是正确性:所有既有 grad-check / PyTorch 对拍 / overfit / DDP 跨 rank 一致必须仍过PyTorch 对拍现在用 B>1xserv 推理仍逐 token 一致;在此之上拿吞吐收益。

核心设计linears 摊平为 [B*S, dim]attention 批量化

一个关键观察:transformer 里只有 attention 需要"序列感知",其余全是 per-token 的。

  • 线性投影 / 逐元素 / norm / embedding / CE 天然是 2D [rows, dim] 上的运算,不在乎 rowsseq 还是 B*S。所以把激活摊平成 [B*S, dim] 喂进去,这些 op 原样复用 q/k/v/o、gate/up/down、lm_head 全部变成一个大 [B*S, dim] × [dim, out] GEMM(填 GPU 的收益所在); embedding 对 [B*S] ids gatherrms_norm / qk_norm / swiglu / silu / 逐元素按行per-token cross_entropy 对 [B*S, vocab] vs [B*S] targetsmean over B*S 行 == 各序列 loss 的均值
  • attention 是唯一序列感知的 opreshape 成 [B, nh, S, hd]每个序列内做因果 SDPAper-seq 因果 mask绝不跨序列 attend),写回 [B*S, dim]。在 (B, nh) 维上批量。
  • RoPE 位置必须 per-sequence 复位:位置 = 在自己序列内的下标(row % S不是摊平后的全局 行号。这是正确性陷阱显式处理kernel 加 period 参数)。
ids [B*S]
  └ embedding           → [B*S, dim]
  每层 block
    rms_norm            → [B*S, dim]            (per-token)
    attention:
      x@Wq/Wk/Wv        → [B*S, dim]            (大 GEMM)
      reshape           → [B*S, nh, hd]
      qk-norm + RoPE(period=S) → [B*S, nh, hd]  (RoPE 位置 = row % S)
      → [B, nh, S, hd] → [B*nh, S, hd]
      fused batched SDPA(因果) → [B*nh, S, hd]   (2× 批量GEMM + 1× causal-softmax)
      → [B*S, dim]; @Wo → [B*S, dim]            (大 GEMM)
    +residual; rms_norm; SwiGLU MLP(大 GEMM); +residual
  final rms_norm; @lm_head → [B*S, vocab]
cross_entropy([B*S, vocab], targets[B*S])       → 标量B*S 行的 mean

forward(ids[seq]) 现在只是 forward_batched(ids, batch=1) 的特例 → 采样 / 推理路径 batch=1 不变

Module Layout

csrc/ops/attention.cu                    # 新causal 行 softmax含 scale + per-seq 因果 mask
csrc/ops/nn.cu                           # rope/rope_dx 加 period位置 = tok % period
csrc/ops/model.cu                        # 加 transpose_4d12[B,S,nh,hd]<->[B,nh,S,hd]
crates/xtrain-cuda/
├── src/cublas.rs                        # 加 sgemm_strided_batched批量 GEMM行主序同 sgemm 套路)
├── src/ffi.rs                           # +cublasSgemmStridedBatched / launch_softmax_causal / transpose_4d12rope 加 period
└── build.rs                             # +attention.cu
crates/xtrain-tensor/src/tensor.rs       # 加 attention/attention_backwardfused2/4× 批量GEMM、transpose_4d12rope(+period)
crates/xtrain-autodiff/
├── src/ops.rs                           # 加 attention 节点、transpose_4d12 节点rope(+period)
└── tests/autograd.rs                    # +rope_batched / transpose_4d12 / attention(batched) grad-check
crates/xtrain-model/
├── src/model.rs                         # forward_batched(ids,batch)/loss_batchedattention 走 fused 批量 op删 causal_mask 叶
├── src/lib.rs                           # +batched_ids_tensor
├── tests/batched.rs                     # 新batched == looped 单序列logits + 梯度)
├── tests/parity_dump.rs + parity.py     # PyTorch 对拍改 B=2,S=4per-seq RoPE + per-seq mask
crates/xtrain-train/src/train_loop.rs    # 真 batch一次 batched forward/backward 替代 loop+SUMclip pre-scale 1.0
crates/xtrain-distributed/
├── src/ddp.rs                           # 每 rank 一次 batched forwardb_local 条clip pre-scale 1.0
└── tests/ddp_correctness.rs             # 单卡基线也改 batched对齐

Key Design Decisions

① 为什么"摊平 linears"是主要收益

GPU 填不满是因为 GEMM 太小([256, 384] 这种)。摊平成 [B*S, dim]B=16/S=256 时左矩阵是 [4096, 384]——同一个 cuBLAS GEMM kernel但 M 大了 16×一次 launch 干 16× 的活,启动开销被摊薄, SM 占用率上去。embedding / rms_norm / silu / CE 等本就按行,摊平后只是行数变多,零改动复用

② RoPE 位置 per-sequence 复位(正确性陷阱)

摊平后第 r 行是序列 r/S 的第 r%S 个 token。RoPE 角度 = pos * freqpos 必须是 r%S 否则第 2 条及之后的序列会用错位置、与单序列结果不一致、PyTorch 对拍直接挂。kernel 加 period = 序列长度 S参数pos = tok % periodperiod == tokens(单序列)时退化为原"位置=行号"。 反向同样按 period 取角度RoPE 是正交映射,反向 = 逆旋转,不需缓存 forward

③ Attention先摊平到 [B*nh, S, hd]fused 批量 SDPA

q/k/v 投影后是 [B*S, nh, hd](顺序 b,s,head,hd。要喂批量 GEMM需排成 [B*nh, S, hd](每个 batch 元素 = 一个 (b, head) 的整条序列):

[B*S, nh, hd] → reshape [B, S, nh, hd] → transpose_4d12 → [B, nh, S, hd] → reshape [B*nh, S, hd]

transpose_4d12(轴 1,2 互换,新加的结构 op自反式 backward是关键。然后 fused attention op

scores[B*nh,S,S] = Q · Kᵀ                       (cublasSgemmStridedBatched)
P[B*nh,S,S]      = softmax(causal(scores / √hd)) (launch_softmax_causal行内做 mask+scale)
out[B*nh,S,hd]   = P · V                         (cublasSgemmStridedBatched)

整个 attention 不论 B*nh 多大,都是 3 次 kernel launch2 批量 GEMM + 1 softmax没有 per-head / per-seq 的 Python 循环,也没有 host 往返。

因果 mask 内联在 softmax kernel:第 r 行的 query 位置 = r % S,列 j > r%S 是未来 → 概率置 0 不需要 [S,S]-1e9 加性 mask 张量(连 causal_mask 叶子都删了)。1/√hd 的 scale 也折进 softmax取 max/exp 前乘),省一个 scale pass。

踩坑记录(重要)T10 的第一版 attention 用"per-(batch,head) 循环 + split_heads/merge_heads" 而 split/merge_heads 内部走 host 往返to_device(Cpu) 再传回。linears 确实摊平成大 GEMM 了, 但 B=16 时 attention 路径里成千上万次小 kernel + 几 MB 的 host memcpy 反而把吞吐拖到 1127 tok/s (比单序列还慢)。教训:摊平 linears 是必要非充分——attention 的 host 往返 / launch 风暴必须一起干掉。 fused 批量 SDPA 之后才到 25.6K tok/s

④ Attention backward手写grad-check 兜底)

fused op 的反向是标准 attention 反传,全用批量 GEMM + 复用既有 softmax_backward

dP      = dOut · Vᵀ ;  dV = Pᵀ · dOut
dScores = softmax_jacobian(P, dP) · (1/√hd)      (复用行 softmax 反向mask 处 P=0 → 自动零梯度)
dQ      = dScores · K ;  dK = dScoresᵀ · Q

新加的 attention / transpose_4d12 都有 finite-diff grad-check见 ⑥)。

⑤ 训练 loop真 batch 替代 loop+SUM

旧:循环 B 次 model.loss(单序列) + backward(),靠 tape 把 B 份梯度 SUMclip 时 ×1/B 还原 batch-mean。 新:batched_ids_tensor 把 B 条序列摊平成 [B*S]一次 loss_batched + 一次 backward()。 CE 已是 B*S 行的 mean = batch-mean lossbackward 直接给出 batch-mean 梯度,所以 clip pre-scale = 1.0 (不再有 loop+SUM+×1/B

DDP 同理且保持等价:每 rank 跑 b_local = B_global/world 条的一次 batched forward → backward 梯度 = 本地 mean Σ_local/b_localall_reduce_average(跨 rank 求和 /worldΣ_global/(world·b_local) = Σ_global/B_global = 全局 batch-mean → clip pre-scale 也 = 1.0。 单卡基线(ddp_correctness 里)同步改成对全局 batch 的一次 batched forward二者对齐。

验证方法

双闸门,都必须过。

正确性(无回归)

  • 算子级 finite-diffxtrain-autodiff15 个):新增 rope_batchedper-seq 位置)、transpose_4d12attention(batched) 的 dQ/dK/dV连同既有 12 个全过。 实测:attn(batched) dQ 7.5e-3 / dK 1.5e-2 / dV 2.9e-4transpose_4d12 8.2e-5rope_batched 4.5e-4
  • batched == looped 单序列xtrain-model/tests/batched.rs同一组权重batched forward 对 "逐序列单独 forward 再拼接" 的 logits + 每参数梯度逐一对比。实测 logits 完全一致(0.0) 梯度 max rel 6.4e-4loss 完全一致。
  • PyTorch 对拍 B>1parity.py,改 B=2/S=4等价 PyTorch 模型per-seq RoPE pos=row%S、per-seq 因果 mask、QK-norm、SwiGLU对拍 forward logits + 全部 25 个参数梯度。实测 loss relerr 5e-8、logits 6.9e-6、梯度全在 rtol 2e-2 内
  • overfit27/27 tokencheckpoint 逐位AdamW 对 torchDDP 跨 rank 参数 bit-identical (0.0) + DDP loss 对单卡 5.7e-7:全过。
  • xserv 闭环短训→导出→xserv serve对 xtrain 贪心仍逐 token 一致;采样路径 batch=1 仍工作。

吞吐(收益)

单卡 dim384/12L/12h、batch 16、seq 256back-to-back

路径 tok/s GPU util 显存
before(单序列 launch-boundKI-1 基线) ~1653 015% ~3 GB
T10 第一版looped split/mergehost 往返) 1127 ~11% 14.8 GB
afterfused 批量 attentionbatch 16 25627 37% 均值 / 54% 峰 10.2 GB
afterbatch 32 40263

→ 单卡相对 KI-1 基线 ~15.5×batch 16/ ~24×batch 32GPU util 015% → 3754%。

DDP 4 卡dim384, per-rank batch 32, global 1281 卡 40.3K → 4 卡 47.2K tok/sglobal~1.17×。这是 batched forward 新暴露的下一层瓶颈:单卡 compute 快了 1524×每步对全部 67M 参数的 eager all-reduce + host 侧 optimizer/clip 同步成了 DDP 的主导开销(不随卡数缩小)。 注意:单卡 batch 32 = 40K tok/s 已经把 KI-1 时代的 4 卡 3163 tok/s 干翻 ~12×——根因(单卡 launch-bound已修。DDP 近线性需要 bucketed / 与 backward overlap 的 all-reduceKI-1 修复项 2 此前 all-reduce 非瓶颈做了没用,现在才有意义——列为 v3+ 的 follow-up本 Phase 范围外)。

给 v3 的 note已解锁

batched forward 就位后v3dim512/16L、200300M tok建议per-rank batch 1632、seq 2564 卡 global batch 64128。按单卡 ~25K tok/sdim384dim512 略低、4 卡放大200300M tok 的训练时间从 KI-1 时代估的 1726h 压到数小时级。bucketed/overlapped all-reduceKI-1 修复项 2现在才有意义 作为 v3 之后的进一步优化项。