Files
xtrain/docs/04-tiny-transformer.md
Gahow Wang 8565565647 docs: Phase T5 — tiny transformer
Goal / Module Layout / Key Design Decisions (multi-head layout via
reshape+transpose_3d01+split/merge_heads, embedding gather/scatter-add,
x@W convention, causal mask, params API, overfit methodology) / 验证方法
with the dash5 results (grad-checks, overfit 2.82->0.004, PyTorch parity).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:09:30 +08:00

158 lines
8.9 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T5: Tiny Transformer (fwd+bwd) — Design Document
## Goal
在 T4 的 autograd 引擎(`Var` tape + 11 个算子)之上,**组装一个 tiny 现代架构 decoder**RoPE + RMSNorm + SwiGLU + 多头因果 attention跑通 **char-level bring-up**,并以「**把一个小 batch overfit 到 loss≈0**」证明整条 forward+backward 计算图正确。
明确范围T5 只做这些):
1. **补 T4 留的缺口算子**为正式可微节点:`embedding`(按 token id gather / scatter-add 反向)、`reshape``transpose`2D + 3D 轴(0,1)),并加 `split_heads`/`merge_heads` 处理多头布局。
2. **tiny transformer**token embedding → `n_layers` × {pre-RMSNorm → 多头因果 attention → 残差pre-RMSNorm → SwiGLU MLP → 残差} → final RMSNorm → LM head → cross-entropy。
3. **overfit 验证**:极简手写 GD step`p -= lr·grad`)记住一个固定小 batch。
**不做**(留 T6真训练 loop、AdamW、LR schedule、grad clip、checkpoint、TinyStories/tokenizer。overfit 不需要优化器,一行 `p -= lr·grad` 足够。
## Module Layout
```
csrc/ops/model.cu # 新embedding gather/scatter-add + 3D 轴(0,1) transpose kernel
# reshape 是纯 metadata无 kernel
crates/xtrain-cuda/
├── build.rs # 新增 model.cu
└── src/ffi.rs # 新增 launch_embedding_fwd/bwd + launch_transpose_3d01no_cuda 门控)
crates/xtrain-tensor/
└── src/tensor.rs # 新增 reshape / embedding(+bwd) / transpose_3d01no_cuda 门控)
crates/xtrain-autodiff/
├── src/tape.rs # 新增 Var::zero_grad / set_value供手写 GD step
├── src/ops.rs # 新增节点 embedding / reshape / transpose_3d01 / transpose_2d
│ # / split_heads(->Vec<Var>) / merge_heads
└── tests/structural.rs # 新增:上述结构算子各自 grad-check
crates/xtrain-model/ # 新 crate模型本体
├── build.rs # 检测 nvcc → no_cuda cfg逐 crate
├── src/
│ ├── lib.rs # 导出 Config / TinyTransformer / 辅助
│ ├── config.rs # 超参host-onlyno_cuda 也编)
│ └── model.rs # TinyTransformer参数容器 + forward 图no_cuda 门控)
└── tests/
├── overfit.rs # 端到端char-level + 手写 GD → loss≈0
├── parity_dump.rs # PyTorch 对拍 fixturedump 权重/logits/grad
└── parity.py # 等价 PyTorch 模型,对比 forward + 每参数 grad
```
为什么新开 `xtrain-model` crate而非塞进 autodiff对齐 xserv 的「模型层独立于算子层」分层autodiff 只管 `Var`/`ops`model 在其上拼网络,职责清晰。
## Key Design Decisions
### 约定(对齐引擎,不照搬 HuggingFace
- **线性层权重 `[in, out]`,按 `x @ W`**。引擎的 GEMM 是裸 `A @ B`,用 `[in,out]` 省掉每次投影的 transposeQwen 的 `[out,in]` + `x@Wᵀ` 是推理侧权重布局的产物,训练侧自定义即可,回流 xserv 时在 T9 转置导出)。
- **`dim = n_heads · head_dim`**(无独立 attention 投影维度tiny 配置 `dim=32, n_layers=2, n_heads=2, head_dim=16, ffn_hidden=64`
- **RoPE position = token 行号**kernel 内建约定);`rope` 只接受 `[tokens, heads, head_dim]` 布局。
- **因果 mask = 加性常量 `[seq,seq]`**:对角线及以下为 0以上为 1e9softmax 前 `add` 进 scores。引擎没有 masking 算子,用一个常量 leaf + `add` 即可(该 leaf 无下游、其 grad 不被读取,无害)。
### 多头布局reshape + transpose_3d01 + split/merge_heads
引擎的 matmul/softmax 都是 2D 的,多头 attention 因此**逐头做 2D SDPA**。布局流水线:
```
proj = x @ Wq # [seq, dim]
reshape → [seq, nh, hd]
rope (kernel 要求正是 [tokens, heads, head_dim])
transpose_3d01 → [nh, seq, hd] # 让每个 head 的 [seq,hd] 块连续
split_heads → nh × [seq, hd] # 每头一个 Var
# 逐头:
scores = (q_i @ k_iᵀ)·(1/√hd) + mask # [seq,seq]
probs = softmax(scores)
out_i = probs @ v_i # [seq, hd]
merge_heads(out) → [nh, seq, hd]
transpose_3d01 → [seq, nh, hd]
reshape → [seq, dim]
@ Wo # 输出投影
```
- `reshape` 是**纯 metadata**(连续张量改 shape/strides共享 storage无 kernel、无数据搬运`[seq, nh·hd] ↔ [seq, nh, hd]` 正好是它。反向 reshape 回去。
- `transpose_3d01``[a,b,c]→[b,a,c]`)有 kernel**自反**:反向是对 grad 再做一次同样的 transpose。
- `split_heads``[nh,seq,hd]` 在该布局下每个 head 块连续,前向把每块拷成独立连续张量返回 `Vec<Var>`;反向把每头 grad 散回零初始化的 `[nh,seq,hd]`,由引擎的**扇出 SUM** 累加。`merge_heads` 是逆操作。
### embedding 前向 gather / 反向 scatter-add
```
out[s,:] = table[ids[s], :] # table:[vocab,dim], ids:[seq] I32
dtable[ids[s],:] += dout[s,:] # 反向:原子 scatter-add
```
- ids 是**常量**(不是 `Var`),只有 table 参与求导。
- 反向必须 **atomicAdd**:多个位置可能映射到同一 id梯度要累加grad-check 测试里特意放了重复 id `[0,3,1,3,2,0]`)。`dtable``zeros` 再原子累加。
### Block 结构pre-norm + residualQwen3 风格)
```
h = embedding(ids)
for block:
h = h + attention(rms_norm(h, γ_attn)) # 残差
h = h + swiglu_mlp(rms_norm(h, γ_ffn)) # 残差
logits = rms_norm(h, γ_final) @ lm_head
loss = cross_entropy(logits, targets)
```
SwiGLU MLP`down( silu(x@W_gate) ∘ (x@W_up) )`,复用 T4 的 `swiglu = mul(silu(g), u)`
### 参数 API为 T6 优化器准备)
- `TinyTransformer::params() -> Vec<Var>`稳定顺序的全部可学习叶子embedding / 各 block 的 9 个 / final_norm / lm_head
- `Var::set_value(t)`原地更新参数值GD/AdamW 用),保持叶子身份在多 step 间稳定。
- `Var::zero_grad()`:清梯度。**关键**:每个 forward 建新图但叶子复用,上一 step 的 grad 不清会被 SUM 累加 → 每 step 后必须 zero。
- `param_to_host(&Var)`:把参数搬回 host `Vec<f32>`GD step / 对拍导出)。
手写 GD stepoverfit 用):
```rust
for p in params {
if let Some(g) = p.grad() { p.set_value(p.value().add(&g.scale(-lr))); }
p.zero_grad();
}
```
### overfit 方法学
一个固定的小 batchchar-level 文本 → 字符表 → `(input, shifted target)`),反复跑 forward→backward→GD。**只要 fwd+bwd 全对,模型会记住这个 batchloss → ~0**;任何一个 backward 错了loss 会停在某个台阶下不去。这是比单算子 grad-check 更强的端到端信号。验收同时检查 greedy argmax 是否完全复现 target 序列。
## 验证方法
全部 `#![cfg(not(no_cuda))]` 门控,本地只 `cargo check`/`fmt`,构建+实跑在 dash58× RTX 5090, sm_120
1. **结构算子 grad-check**`tests/structural.rs`):沿用 T4 的 `L = sum(W∘out)` 有限差分 harness`embedding`(含重复 id`reshape``transpose_3d01``transpose_2d``split/merge_heads` 往返各做一遍。
2. **overfit**`tests/overfit.rs`tiny 模型 + 手写 GD断言 `loss → <0.05` 且 greedy argmax 全对。
3. **PyTorch 对拍**`parity_dump.rs` + `parity.py`Rust dump 出权重/ids/logits/loss/每参数 grad一次 backwardPython 用**完全等价**的 PyTorch 模型(同 `x@W` 约定、同 RoPE rotate_half pos=行号、同 RMSNorm/SwiGLU/因果 SDPA跑 fwd+bwd对比 forward logits + 21 个参数 grad 的相对误差rtol=2e-2
### dash5 实测结果
```
# 结构算子 grad-checkmax rel-err
embedding dTable 3.5e-5 reshape dX 3.4e-4
transpose_2d dX 2.3e-5 transpose_3d01 dX 1.9e-4
split/merge_heads dX 3.9e-5 5/5 通过)
# overfitlr=0.3, 200 steps, vocab=16, seq=27, params=21664
step 0: loss = 2.821415
step 20: loss = 0.341899
step 40: loss = 0.099285
...
step 199: loss = 0.004009
start 2.821415 → final 0.004009 ; greedy match 27/27 ✅
# PyTorch 对拍rtol=2e-2
loss: rust=2.505827e0 torch=2.505827e0 relerr=1.4e-8
logits: max relerr = 9.5e-5
params checked: 21 worst = grad[l0_wo] @ 1.1e-3 → PARITY OK ✅
# 无回归T4 autograd 12/12 仍全绿
```
forward 与 PyTorch 对到 ~1e-4、每参数 grad 对到 ≤1.1e-3overfit loss 从 2.82 跌到 0.004 且完全复现 target——三层证据单算子 finite-diff、端到端 overfit、PyTorch 对拍)一致确认整条 fwd+bwd 正确。