Files
xtrain/docs/04-tiny-transformer.md
Gahow Wang 8565565647 docs: Phase T5 — tiny transformer
Goal / Module Layout / Key Design Decisions (multi-head layout via
reshape+transpose_3d01+split/merge_heads, embedding gather/scatter-add,
x@W convention, causal mask, params API, overfit methodology) / 验证方法
with the dash5 results (grad-checks, overfit 2.82->0.004, PyTorch parity).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:09:30 +08:00

8.9 KiB
Raw Permalink Blame History

Phase T5: Tiny Transformer (fwd+bwd) — Design Document

Goal

在 T4 的 autograd 引擎(Var tape + 11 个算子)之上,组装一个 tiny 现代架构 decoderRoPE + RMSNorm + SwiGLU + 多头因果 attention跑通 char-level bring-up,并以「把一个小 batch overfit 到 loss≈0」证明整条 forward+backward 计算图正确。

明确范围T5 只做这些):

  1. 补 T4 留的缺口算子为正式可微节点:embedding(按 token id gather / scatter-add 反向)、reshapetranspose2D + 3D 轴(0,1)),并加 split_heads/merge_heads 处理多头布局。
  2. tiny transformertoken embedding → n_layers × {pre-RMSNorm → 多头因果 attention → 残差pre-RMSNorm → SwiGLU MLP → 残差} → final RMSNorm → LM head → cross-entropy。
  3. overfit 验证:极简手写 GD stepp -= lr·grad)记住一个固定小 batch。

不做(留 T6真训练 loop、AdamW、LR schedule、grad clip、checkpoint、TinyStories/tokenizer。overfit 不需要优化器,一行 p -= lr·grad 足够。

Module Layout

csrc/ops/model.cu                      # 新embedding gather/scatter-add + 3D 轴(0,1) transpose kernel
                                       #     reshape 是纯 metadata无 kernel

crates/xtrain-cuda/
├── build.rs                           # 新增 model.cu
└── src/ffi.rs                         # 新增 launch_embedding_fwd/bwd + launch_transpose_3d01no_cuda 门控)

crates/xtrain-tensor/
└── src/tensor.rs                      # 新增 reshape / embedding(+bwd) / transpose_3d01no_cuda 门控)

crates/xtrain-autodiff/
├── src/tape.rs                        # 新增 Var::zero_grad / set_value供手写 GD step
├── src/ops.rs                         # 新增节点 embedding / reshape / transpose_3d01 / transpose_2d
│                                      #         / split_heads(->Vec<Var>) / merge_heads
└── tests/structural.rs               # 新增:上述结构算子各自 grad-check

crates/xtrain-model/                  # 新 crate模型本体
├── build.rs                           # 检测 nvcc → no_cuda cfg逐 crate
├── src/
│   ├── lib.rs                         # 导出 Config / TinyTransformer / 辅助
│   ├── config.rs                      # 超参host-onlyno_cuda 也编)
│   └── model.rs                       # TinyTransformer参数容器 + forward 图no_cuda 门控)
└── tests/
    ├── overfit.rs                     # 端到端char-level + 手写 GD → loss≈0
    ├── parity_dump.rs                 # PyTorch 对拍 fixturedump 权重/logits/grad
    └── parity.py                      # 等价 PyTorch 模型,对比 forward + 每参数 grad

为什么新开 xtrain-model crate而非塞进 autodiff对齐 xserv 的「模型层独立于算子层」分层autodiff 只管 Var/opsmodel 在其上拼网络,职责清晰。

Key Design Decisions

约定(对齐引擎,不照搬 HuggingFace

  • 线性层权重 [in, out],按 x @ W。引擎的 GEMM 是裸 A @ B,用 [in,out] 省掉每次投影的 transposeQwen 的 [out,in] + x@Wᵀ 是推理侧权重布局的产物,训练侧自定义即可,回流 xserv 时在 T9 转置导出)。
  • dim = n_heads · head_dim(无独立 attention 投影维度tiny 配置 dim=32, n_layers=2, n_heads=2, head_dim=16, ffn_hidden=64
  • RoPE position = token 行号kernel 内建约定);rope 只接受 [tokens, heads, head_dim] 布局。
  • 因果 mask = 加性常量 [seq,seq]:对角线及以下为 0以上为 1e9softmax 前 add 进 scores。引擎没有 masking 算子,用一个常量 leaf + add 即可(该 leaf 无下游、其 grad 不被读取,无害)。

多头布局reshape + transpose_3d01 + split/merge_heads

引擎的 matmul/softmax 都是 2D 的,多头 attention 因此逐头做 2D SDPA。布局流水线:

proj   = x @ Wq                # [seq, dim]
       reshape → [seq, nh, hd]
rope   (kernel 要求正是 [tokens, heads, head_dim])
       transpose_3d01 → [nh, seq, hd]   # 让每个 head 的 [seq,hd] 块连续
split_heads → nh × [seq, hd]            # 每头一个 Var

# 逐头:
scores = (q_i @ k_iᵀ)·(1/√hd) + mask     # [seq,seq]
probs  = softmax(scores)
out_i  = probs @ v_i                      # [seq, hd]

merge_heads(out) → [nh, seq, hd]
       transpose_3d01 → [seq, nh, hd]
       reshape → [seq, dim]
       @ Wo                                # 输出投影
  • reshape纯 metadata(连续张量改 shape/strides共享 storage无 kernel、无数据搬运[seq, nh·hd] ↔ [seq, nh, hd] 正好是它。反向 reshape 回去。
  • transpose_3d01[a,b,c]→[b,a,c])有 kernel自反:反向是对 grad 再做一次同样的 transpose。
  • split_heads[nh,seq,hd] 在该布局下每个 head 块连续,前向把每块拷成独立连续张量返回 Vec<Var>;反向把每头 grad 散回零初始化的 [nh,seq,hd],由引擎的扇出 SUM 累加。merge_heads 是逆操作。

embedding 前向 gather / 反向 scatter-add

out[s,:] = table[ids[s], :]              # table:[vocab,dim], ids:[seq] I32
dtable[ids[s],:] += dout[s,:]            # 反向:原子 scatter-add
  • ids 是常量(不是 Var),只有 table 参与求导。
  • 反向必须 atomicAdd:多个位置可能映射到同一 id梯度要累加grad-check 测试里特意放了重复 id [0,3,1,3,2,0])。dtablezeros 再原子累加。

Block 结构pre-norm + residualQwen3 风格)

h = embedding(ids)
for block:
    h = h + attention(rms_norm(h, γ_attn))     # 残差
    h = h + swiglu_mlp(rms_norm(h, γ_ffn))     # 残差
logits = rms_norm(h, γ_final) @ lm_head
loss   = cross_entropy(logits, targets)

SwiGLU MLPdown( silu(x@W_gate) ∘ (x@W_up) ),复用 T4 的 swiglu = mul(silu(g), u)

参数 API为 T6 优化器准备)

  • TinyTransformer::params() -> Vec<Var>稳定顺序的全部可学习叶子embedding / 各 block 的 9 个 / final_norm / lm_head
  • Var::set_value(t)原地更新参数值GD/AdamW 用),保持叶子身份在多 step 间稳定。
  • Var::zero_grad():清梯度。关键:每个 forward 建新图但叶子复用,上一 step 的 grad 不清会被 SUM 累加 → 每 step 后必须 zero。
  • param_to_host(&Var):把参数搬回 host Vec<f32>GD step / 对拍导出)。

手写 GD stepoverfit 用):

for p in params {
    if let Some(g) = p.grad() { p.set_value(p.value().add(&g.scale(-lr))); }
    p.zero_grad();
}

overfit 方法学

一个固定的小 batchchar-level 文本 → 字符表 → (input, shifted target)),反复跑 forward→backward→GD。只要 fwd+bwd 全对,模型会记住这个 batchloss → ~0;任何一个 backward 错了loss 会停在某个台阶下不去。这是比单算子 grad-check 更强的端到端信号。验收同时检查 greedy argmax 是否完全复现 target 序列。

验证方法

全部 #![cfg(not(no_cuda))] 门控,本地只 cargo check/fmt,构建+实跑在 dash58× RTX 5090, sm_120

  1. 结构算子 grad-checktests/structural.rs):沿用 T4 的 L = sum(W∘out) 有限差分 harnessembedding(含重复 idreshapetranspose_3d01transpose_2dsplit/merge_heads 往返各做一遍。
  2. overfittests/overfit.rstiny 模型 + 手写 GD断言 loss → <0.05 且 greedy argmax 全对。
  3. PyTorch 对拍parity_dump.rs + parity.pyRust dump 出权重/ids/logits/loss/每参数 grad一次 backwardPython 用完全等价的 PyTorch 模型(同 x@W 约定、同 RoPE rotate_half pos=行号、同 RMSNorm/SwiGLU/因果 SDPA跑 fwd+bwd对比 forward logits + 21 个参数 grad 的相对误差rtol=2e-2

dash5 实测结果

# 结构算子 grad-checkmax rel-err
embedding dTable      3.5e-5      reshape dX            3.4e-4
transpose_2d dX       2.3e-5      transpose_3d01 dX     1.9e-4
split/merge_heads dX  3.9e-5      5/5 通过)

# overfitlr=0.3, 200 steps, vocab=16, seq=27, params=21664
step   0: loss = 2.821415
step  20: loss = 0.341899
step  40: loss = 0.099285
...
step 199: loss = 0.004009
start 2.821415 → final 0.004009 ; greedy match 27/27   ✅

# PyTorch 对拍rtol=2e-2
loss:   rust=2.505827e0  torch=2.505827e0   relerr=1.4e-8
logits: max relerr = 9.5e-5
params checked: 21  worst = grad[l0_wo] @ 1.1e-3   → PARITY OK ✅

# 无回归T4 autograd 12/12 仍全绿

forward 与 PyTorch 对到 ~1e-4、每参数 grad 对到 ≤1.1e-3overfit loss 从 2.82 跌到 0.004 且完全复现 target——三层证据单算子 finite-diff、端到端 overfit、PyTorch 对拍)一致确认整条 fwd+bwd 正确。