Files
xtrain/docs/09-batched-forward.md
Gahow Wang 9a25616a30 docs: Phase T10 — batched forward
docs/09-batched-forward.md: the launch-bound diagnosis recap, the
[B*S,dim]-flatten + fused batched-attention design (RoPE per-seq position +
causal masking inline in softmax), the attention forward/backward via
strided-batched GEMM, autograd implications, the looped-split/merge dead-end
post-mortem (1127 tok/s, host round-trips), verification methods + before→after
throughput, and the v3 recommendation (per-rank batch 16-32, single/small world
until KI-5 bucketed all-reduce lands).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 00:44:50 +08:00

188 lines
12 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T10: Batched 多序列 Forward — Design Document
## Goal
**KI-1 的根因**。v2 暴露、v3 重诊断的结论(见 [docs/known-issues.md](known-issues.md)
吞吐瓶颈**不是** all-reduce而是 **单序列模型设计 launch-bound**——T5 的模型一次只过**一条**
序列,每个 op 的 GEMM 都是 `[seq, dim]` 这种小矩阵,喂不饱 GPU"batch" 是靠**循环 B 次 forward +
让 tape SUM 梯度**伪造的,于是 GPU util 只有 015%、显存占 ~8%,加大 global batch 也只是按比例
增加串行 kernel-launchv3 实测 gbatch 32→256 仅 +1.2%)。
T10 的目标:给 model + autograd 加 **batch 维**,把一个 step 的 B 条序列**一次性**过模型,让线性层变成
**一个大 GEMM** 填满 GPU——这是 launch-bound 的根本解。**硬闸门是正确性**:所有既有 grad-check /
PyTorch 对拍 / overfit / DDP 跨 rank 一致**必须仍过**PyTorch 对拍现在用 **B>1**xserv 推理仍**逐
token 一致**;在此之上拿吞吐收益。
## 核心设计linears 摊平为 `[B*S, dim]`attention 批量化
一个关键观察:**transformer 里只有 attention 需要"序列感知"**,其余全是 per-token 的。
- **线性投影 / 逐元素 / norm / embedding / CE 天然是 2D `[rows, dim]` 上的运算**,不在乎 `rows`
`seq` 还是 `B*S`。所以把激活**摊平成 `[B*S, dim]`** 喂进去,这些 op **原样复用**
q/k/v/o、gate/up/down、lm_head 全部变成**一个大 `[B*S, dim] × [dim, out]` GEMM**(填 GPU 的收益所在);
embedding 对 `[B*S]` ids gatherrms_norm / qk_norm / swiglu / silu / 逐元素按行per-token
cross_entropy 对 `[B*S, vocab]` vs `[B*S]` targets**mean over B*S 行 == 各序列 loss 的均值**。
- **attention 是唯一序列感知的 op**reshape 成 `[B, nh, S, hd]`**每个序列内**做因果 SDPAper-seq
因果 mask**绝不跨序列 attend**),写回 `[B*S, dim]`。在 `(B, nh)` 维上批量。
- **RoPE 位置必须 per-sequence 复位**:位置 = 在**自己序列内**的下标(`row % S`**不是**摊平后的全局
行号。这是正确性陷阱显式处理kernel 加 `period` 参数)。
```text
ids [B*S]
└ embedding → [B*S, dim]
每层 block
rms_norm → [B*S, dim] (per-token)
attention:
x@Wq/Wk/Wv → [B*S, dim] (大 GEMM)
reshape → [B*S, nh, hd]
qk-norm + RoPE(period=S) → [B*S, nh, hd] (RoPE 位置 = row % S)
→ [B, nh, S, hd] → [B*nh, S, hd]
fused batched SDPA(因果) → [B*nh, S, hd] (2× 批量GEMM + 1× causal-softmax)
→ [B*S, dim]; @Wo → [B*S, dim] (大 GEMM)
+residual; rms_norm; SwiGLU MLP(大 GEMM); +residual
final rms_norm; @lm_head → [B*S, vocab]
cross_entropy([B*S, vocab], targets[B*S]) → 标量B*S 行的 mean
```
`forward(ids[seq])` 现在只是 `forward_batched(ids, batch=1)` 的特例 → **采样 / 推理路径 batch=1 不变**
## Module Layout
```
csrc/ops/attention.cu # 新causal 行 softmax含 scale + per-seq 因果 mask
csrc/ops/nn.cu # rope/rope_dx 加 period位置 = tok % period
csrc/ops/model.cu # 加 transpose_4d12[B,S,nh,hd]<->[B,nh,S,hd]
crates/xtrain-cuda/
├── src/cublas.rs # 加 sgemm_strided_batched批量 GEMM行主序同 sgemm 套路)
├── src/ffi.rs # +cublasSgemmStridedBatched / launch_softmax_causal / transpose_4d12rope 加 period
└── build.rs # +attention.cu
crates/xtrain-tensor/src/tensor.rs # 加 attention/attention_backwardfused2/4× 批量GEMM、transpose_4d12rope(+period)
crates/xtrain-autodiff/
├── src/ops.rs # 加 attention 节点、transpose_4d12 节点rope(+period)
└── tests/autograd.rs # +rope_batched / transpose_4d12 / attention(batched) grad-check
crates/xtrain-model/
├── src/model.rs # forward_batched(ids,batch)/loss_batchedattention 走 fused 批量 op删 causal_mask 叶
├── src/lib.rs # +batched_ids_tensor
├── tests/batched.rs # 新batched == looped 单序列logits + 梯度)
├── tests/parity_dump.rs + parity.py # PyTorch 对拍改 B=2,S=4per-seq RoPE + per-seq mask
crates/xtrain-train/src/train_loop.rs # 真 batch一次 batched forward/backward 替代 loop+SUMclip pre-scale 1.0
crates/xtrain-distributed/
├── src/ddp.rs # 每 rank 一次 batched forwardb_local 条clip pre-scale 1.0
└── tests/ddp_correctness.rs # 单卡基线也改 batched对齐
```
## Key Design Decisions
### ① 为什么"摊平 linears"是主要收益
GPU 填不满是因为 GEMM 太小(`[256, 384]` 这种)。摊平成 `[B*S, dim]`B=16/S=256 时左矩阵是
`[4096, 384]`——同一个 cuBLAS GEMM kernel但 M 大了 16×一次 launch 干 16× 的活,启动开销被摊薄,
SM 占用率上去。**embedding / rms_norm / silu / CE 等本就按行**,摊平后只是行数变多,**零改动复用**。
### ② RoPE 位置 per-sequence 复位(正确性陷阱)
摊平后第 `r` 行是序列 `r/S` 的第 `r%S` 个 token。RoPE 角度 = `pos * freq`**`pos` 必须是 `r%S`**
否则第 2 条及之后的序列会用错位置、与单序列结果不一致、PyTorch 对拍直接挂。kernel 加 `period`
= 序列长度 S参数`pos = tok % period``period == tokens`(单序列)时退化为原"位置=行号"。
反向同样按 `period` 取角度RoPE 是正交映射,反向 = 逆旋转,不需缓存 forward
### ③ Attention先摊平到 `[B*nh, S, hd]`fused 批量 SDPA
q/k/v 投影后是 `[B*S, nh, hd]`(顺序 b,s,head,hd。要喂批量 GEMM需排成 `[B*nh, S, hd]`(每个 batch
元素 = 一个 (b, head) 的整条序列):
```text
[B*S, nh, hd] → reshape [B, S, nh, hd] → transpose_4d12 → [B, nh, S, hd] → reshape [B*nh, S, hd]
```
`transpose_4d12`(轴 1,2 互换,新加的结构 op自反式 backward是关键。然后 **fused attention op**
```text
scores[B*nh,S,S] = Q · Kᵀ (cublasSgemmStridedBatched)
P[B*nh,S,S] = softmax(causal(scores / √hd)) (launch_softmax_causal行内做 mask+scale)
out[B*nh,S,hd] = P · V (cublasSgemmStridedBatched)
```
**整个 attention 不论 B*nh 多大,都是 3 次 kernel launch**2 批量 GEMM + 1 softmax没有 per-head /
per-seq 的 Python 循环,也没有 host 往返。
**因果 mask 内联在 softmax kernel**:第 `r` 行的 query 位置 = `r % S`,列 `j > r%S` 是未来 → 概率置 0
不需要 `[S,S]``-1e9` 加性 mask 张量(连 `causal_mask` 叶子都删了)。`1/√hd` 的 scale 也折进
softmax取 max/exp 前乘),省一个 scale pass。
> **踩坑记录(重要)**T10 的第一版 attention 用"per-(batch,head) 循环 + `split_heads`/`merge_heads`"
> 而 split/merge_heads 内部走 **host 往返**`to_device(Cpu)` 再传回。linears 确实摊平成大 GEMM 了,
> 但 B=16 时 attention 路径里成千上万次小 kernel + 几 MB 的 host memcpy **反而把吞吐拖到 1127 tok/s
> (比单序列还慢)**。教训:摊平 linears 是必要非充分——attention 的 host 往返 / launch 风暴必须一起干掉。
> fused 批量 SDPA 之后才到 **25.6K tok/s**。
### ④ Attention backward手写grad-check 兜底)
fused op 的反向是标准 attention 反传,全用批量 GEMM + 复用既有 `softmax_backward`
```text
dP = dOut · Vᵀ ; dV = Pᵀ · dOut
dScores = softmax_jacobian(P, dP) · (1/√hd) (复用行 softmax 反向mask 处 P=0 → 自动零梯度)
dQ = dScores · K ; dK = dScoresᵀ · Q
```
新加的 `attention` / `transpose_4d12` 都有 finite-diff grad-check见 ⑥)。
### ⑤ 训练 loop真 batch 替代 loop+SUM
旧:循环 B 次 `model.loss(单序列)` + `backward()`,靠 tape 把 B 份梯度 SUMclip 时 ×`1/B` 还原 batch-mean。
新:`batched_ids_tensor` 把 B 条序列摊平成 `[B*S]`**一次** `loss_batched` + **一次** `backward()`
CE 已是 B*S 行的 mean = batch-mean loss**backward 直接给出 batch-mean 梯度**,所以 **clip pre-scale = 1.0**
(不再有 loop+SUM+×1/B
**DDP 同理且保持等价**:每 rank 跑 `b_local = B_global/world` 条的**一次** batched forward → backward
梯度 = 本地 mean `Σ_local/b_local``all_reduce_average`(跨 rank 求和 /world
`Σ_global/(world·b_local) = Σ_global/B_global` = 全局 batch-mean → clip pre-scale 也 = 1.0。
单卡基线(`ddp_correctness` 里)同步改成对全局 batch 的一次 batched forward二者对齐。
## 验证方法
**双闸门,都必须过。**
### 正确性(无回归)
- **算子级 finite-diff**`xtrain-autodiff`15 个):新增 `rope_batched`per-seq 位置)、`transpose_4d12`
`attention(batched)` 的 dQ/dK/dV连同既有 12 个全过。
实测:`attn(batched) dQ 7.5e-3 / dK 1.5e-2 / dV 2.9e-4``transpose_4d12 8.2e-5``rope_batched 4.5e-4`
- **batched == looped 单序列**`xtrain-model/tests/batched.rs`同一组权重batched forward 对
"逐序列单独 forward 再拼接" 的 logits + 每参数梯度逐一对比。实测 **logits 完全一致(0.0)**
梯度 max rel **6.4e-4**loss 完全一致。
- **PyTorch 对拍 B>1**`parity.py`,改 B=2/S=4等价 PyTorch 模型per-seq RoPE `pos=row%S`、per-seq
因果 mask、QK-norm、SwiGLU对拍 forward logits + 全部 25 个参数梯度。实测 **loss relerr 5e-8、logits
6.9e-6、梯度全在 rtol 2e-2 内**。
- **overfit**27/27 token、**checkpoint 逐位**、**AdamW 对 torch**、**DDP 跨 rank 参数 bit-identical
(0.0) + DDP loss 对单卡 5.7e-7**:全过。
- **xserv 闭环**短训→导出→xserv serve对 xtrain 贪心**仍逐 token 一致**;采样路径 batch=1 仍工作。
### 吞吐(收益)
单卡 dim384/12L/12h、batch 16、seq 256back-to-back
| 路径 | tok/s | GPU util | 显存 |
|---|---|---|---|
| **before**(单序列 launch-boundKI-1 基线)| ~1653 | 015% | ~3 GB |
| T10 第一版looped split/mergehost 往返)| 1127 | ~11% | 14.8 GB |
| **after**fused 批量 attentionbatch 16| **25627** | **37% 均值 / 54% 峰** | 10.2 GB |
| **after**batch 32| **40263** | — | — |
→ 单卡相对 KI-1 基线 **~15.5×**batch 16/ **~24×**batch 32GPU util 015% → 3754%。
**DDP 4 卡dim384, per-rank batch 32, global 128**1 卡 40.3K → 4 卡 47.2K tok/sglobal
**~1.17×**。这是 batched forward **新暴露**的下一层瓶颈:单卡 compute 快了 1524×每步**对全部
67M 参数的 eager all-reduce + host 侧 optimizer/clip 同步**成了 DDP 的主导开销(不随卡数缩小)。
注意:**单卡 batch 32 = 40K tok/s 已经把 KI-1 时代的 4 卡 3163 tok/s 干翻 ~12×**——根因(单卡
launch-bound已修。DDP 近线性需要 **bucketed / 与 backward overlap 的 all-reduce**KI-1 修复项 2
此前 all-reduce 非瓶颈做了没用,现在才有意义——列为 v3+ 的 follow-up本 Phase 范围外)。
## 给 v3 的 note已解锁
batched forward 就位后v3dim512/16L、200300M tok建议**per-rank batch 1632、seq 2564 卡
global batch 64128**。按单卡 ~25K tok/sdim384dim512 略低、4 卡放大200300M tok 的训练时间从
KI-1 时代估的 1726h 压到**数小时级**。bucketed/overlapped all-reduceKI-1 修复项 2现在才有意义
作为 v3 之后的进一步优化项。