Goal / Module Layout / Key Design Decisions (multi-head layout via reshape+transpose_3d01+split/merge_heads, embedding gather/scatter-add, x@W convention, causal mask, params API, overfit methodology) / 验证方法 with the dash5 results (grad-checks, overfit 2.82->0.004, PyTorch parity). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
8.9 KiB
Phase T5: Tiny Transformer (fwd+bwd) — Design Document
Goal
在 T4 的 autograd 引擎(Var tape + 11 个算子)之上,组装一个 tiny 现代架构 decoder(RoPE + RMSNorm + SwiGLU + 多头因果 attention),跑通 char-level bring-up,并以「把一个小 batch overfit 到 loss≈0」证明整条 forward+backward 计算图正确。
明确范围(T5 只做这些):
- 补 T4 留的缺口算子为正式可微节点:
embedding(按 token id gather / scatter-add 反向)、reshape、transpose(2D + 3D 轴(0,1)),并加split_heads/merge_heads处理多头布局。 - tiny transformer:token embedding →
n_layers× {pre-RMSNorm → 多头因果 attention → 残差;pre-RMSNorm → SwiGLU MLP → 残差} → final RMSNorm → LM head → cross-entropy。 - overfit 验证:极简手写 GD step(
p -= lr·grad)记住一个固定小 batch。
不做(留 T6):真训练 loop、AdamW、LR schedule、grad clip、checkpoint、TinyStories/tokenizer。overfit 不需要优化器,一行 p -= lr·grad 足够。
Module Layout
csrc/ops/model.cu # 新:embedding gather/scatter-add + 3D 轴(0,1) transpose kernel
# (reshape 是纯 metadata,无 kernel)
crates/xtrain-cuda/
├── build.rs # 新增 model.cu
└── src/ffi.rs # 新增 launch_embedding_fwd/bwd + launch_transpose_3d01(no_cuda 门控)
crates/xtrain-tensor/
└── src/tensor.rs # 新增 reshape / embedding(+bwd) / transpose_3d01(no_cuda 门控)
crates/xtrain-autodiff/
├── src/tape.rs # 新增 Var::zero_grad / set_value(供手写 GD step)
├── src/ops.rs # 新增节点 embedding / reshape / transpose_3d01 / transpose_2d
│ # / split_heads(->Vec<Var>) / merge_heads
└── tests/structural.rs # 新增:上述结构算子各自 grad-check
crates/xtrain-model/ # 新 crate:模型本体
├── build.rs # 检测 nvcc → no_cuda cfg(逐 crate)
├── src/
│ ├── lib.rs # 导出 Config / TinyTransformer / 辅助
│ ├── config.rs # 超参(host-only,no_cuda 也编)
│ └── model.rs # TinyTransformer:参数容器 + forward 图(no_cuda 门控)
└── tests/
├── overfit.rs # 端到端:char-level + 手写 GD → loss≈0
├── parity_dump.rs # PyTorch 对拍 fixture(dump 权重/logits/grad)
└── parity.py # 等价 PyTorch 模型,对比 forward + 每参数 grad
为什么新开 xtrain-model crate(而非塞进 autodiff):对齐 xserv 的「模型层独立于算子层」分层;autodiff 只管 Var/ops,model 在其上拼网络,职责清晰。
Key Design Decisions
约定(对齐引擎,不照搬 HuggingFace)
- 线性层权重
[in, out],按x @ W。引擎的 GEMM 是裸A @ B,用[in,out]省掉每次投影的 transpose(Qwen 的[out,in]+x@Wᵀ是推理侧权重布局的产物,训练侧自定义即可,回流 xserv 时在 T9 转置导出)。 dim = n_heads · head_dim(无独立 attention 投影维度),tiny 配置dim=32, n_layers=2, n_heads=2, head_dim=16, ffn_hidden=64。- RoPE position = token 行号(kernel 内建约定);
rope只接受[tokens, heads, head_dim]布局。 - 因果 mask = 加性常量
[seq,seq]:对角线及以下为 0,以上为 −1e9,softmax 前add进 scores。引擎没有 masking 算子,用一个常量 leaf +add即可(该 leaf 无下游、其 grad 不被读取,无害)。
多头布局:reshape + transpose_3d01 + split/merge_heads
引擎的 matmul/softmax 都是 2D 的,多头 attention 因此逐头做 2D SDPA。布局流水线:
proj = x @ Wq # [seq, dim]
reshape → [seq, nh, hd]
rope (kernel 要求正是 [tokens, heads, head_dim])
transpose_3d01 → [nh, seq, hd] # 让每个 head 的 [seq,hd] 块连续
split_heads → nh × [seq, hd] # 每头一个 Var
# 逐头:
scores = (q_i @ k_iᵀ)·(1/√hd) + mask # [seq,seq]
probs = softmax(scores)
out_i = probs @ v_i # [seq, hd]
merge_heads(out) → [nh, seq, hd]
transpose_3d01 → [seq, nh, hd]
reshape → [seq, dim]
@ Wo # 输出投影
reshape是纯 metadata(连续张量改 shape/strides,共享 storage,无 kernel、无数据搬运);[seq, nh·hd] ↔ [seq, nh, hd]正好是它。反向 reshape 回去。transpose_3d01([a,b,c]→[b,a,c])有 kernel,自反:反向是对 grad 再做一次同样的 transpose。split_heads:[nh,seq,hd]在该布局下每个 head 块连续,前向把每块拷成独立连续张量返回Vec<Var>;反向把每头 grad 散回零初始化的[nh,seq,hd],由引擎的扇出 SUM 累加。merge_heads是逆操作。
embedding 前向 gather / 反向 scatter-add
out[s,:] = table[ids[s], :] # table:[vocab,dim], ids:[seq] I32
dtable[ids[s],:] += dout[s,:] # 反向:原子 scatter-add
- ids 是常量(不是
Var),只有 table 参与求导。 - 反向必须 atomicAdd:多个位置可能映射到同一 id,梯度要累加(grad-check 测试里特意放了重复 id
[0,3,1,3,2,0])。dtable先zeros再原子累加。
Block 结构(pre-norm + residual,Qwen3 风格)
h = embedding(ids)
for block:
h = h + attention(rms_norm(h, γ_attn)) # 残差
h = h + swiglu_mlp(rms_norm(h, γ_ffn)) # 残差
logits = rms_norm(h, γ_final) @ lm_head
loss = cross_entropy(logits, targets)
SwiGLU MLP:down( silu(x@W_gate) ∘ (x@W_up) ),复用 T4 的 swiglu = mul(silu(g), u)。
参数 API(为 T6 优化器准备)
TinyTransformer::params() -> Vec<Var>:稳定顺序的全部可学习叶子(embedding / 各 block 的 9 个 / final_norm / lm_head)。Var::set_value(t):原地更新参数值(GD/AdamW 用),保持叶子身份在多 step 间稳定。Var::zero_grad():清梯度。关键:每个 forward 建新图但叶子复用,上一 step 的 grad 不清会被 SUM 累加 → 每 step 后必须 zero。param_to_host(&Var):把参数搬回 hostVec<f32>(GD step / 对拍导出)。
手写 GD step(overfit 用):
for p in params {
if let Some(g) = p.grad() { p.set_value(p.value().add(&g.scale(-lr))); }
p.zero_grad();
}
overfit 方法学
一个固定的小 batch(char-level 文本 → 字符表 → (input, shifted target)),反复跑 forward→backward→GD。只要 fwd+bwd 全对,模型会记住这个 batch,loss → ~0;任何一个 backward 错了,loss 会停在某个台阶下不去。这是比单算子 grad-check 更强的端到端信号。验收同时检查 greedy argmax 是否完全复现 target 序列。
验证方法
全部 #![cfg(not(no_cuda))] 门控,本地只 cargo check/fmt,构建+实跑在 dash5(8× RTX 5090, sm_120)。
- 结构算子 grad-check(
tests/structural.rs):沿用 T4 的L = sum(W∘out)有限差分 harness,对embedding(含重复 id)、reshape、transpose_3d01、transpose_2d、split/merge_heads往返各做一遍。 - overfit(
tests/overfit.rs):tiny 模型 + 手写 GD,断言loss → <0.05且 greedy argmax 全对。 - PyTorch 对拍(
parity_dump.rs+parity.py):Rust dump 出权重/ids/logits/loss/每参数 grad(一次 backward),Python 用完全等价的 PyTorch 模型(同x@W约定、同 RoPE rotate_half pos=行号、同 RMSNorm/SwiGLU/因果 SDPA)跑 fwd+bwd,对比 forward logits + 21 个参数 grad 的相对误差(rtol=2e-2)。
dash5 实测结果
# 结构算子 grad-check(max rel-err)
embedding dTable 3.5e-5 reshape dX 3.4e-4
transpose_2d dX 2.3e-5 transpose_3d01 dX 1.9e-4
split/merge_heads dX 3.9e-5 (5/5 通过)
# overfit(lr=0.3, 200 steps, vocab=16, seq=27, params=21664)
step 0: loss = 2.821415
step 20: loss = 0.341899
step 40: loss = 0.099285
...
step 199: loss = 0.004009
start 2.821415 → final 0.004009 ; greedy match 27/27 ✅
# PyTorch 对拍(rtol=2e-2)
loss: rust=2.505827e0 torch=2.505827e0 relerr=1.4e-8
logits: max relerr = 9.5e-5
params checked: 21 worst = grad[l0_wo] @ 1.1e-3 → PARITY OK ✅
# 无回归:T4 autograd 12/12 仍全绿
forward 与 PyTorch 对到 ~1e-4、每参数 grad 对到 ≤1.1e-3,overfit loss 从 2.82 跌到 0.004 且完全复现 target——三层证据(单算子 finite-diff、端到端 overfit、PyTorch 对拍)一致确认整条 fwd+bwd 正确。