Goal / Module Layout / Key Design Decisions (multi-head layout via reshape+transpose_3d01+split/merge_heads, embedding gather/scatter-add, x@W convention, causal mask, params API, overfit methodology) / 验证方法 with the dash5 results (grad-checks, overfit 2.82->0.004, PyTorch parity). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
158 lines
8.9 KiB
Markdown
158 lines
8.9 KiB
Markdown
# Phase T5: Tiny Transformer (fwd+bwd) — Design Document
|
||
|
||
## Goal
|
||
|
||
在 T4 的 autograd 引擎(`Var` tape + 11 个算子)之上,**组装一个 tiny 现代架构 decoder**(RoPE + RMSNorm + SwiGLU + 多头因果 attention),跑通 **char-level bring-up**,并以「**把一个小 batch overfit 到 loss≈0**」证明整条 forward+backward 计算图正确。
|
||
|
||
明确范围(T5 只做这些):
|
||
|
||
1. **补 T4 留的缺口算子**为正式可微节点:`embedding`(按 token id gather / scatter-add 反向)、`reshape`、`transpose`(2D + 3D 轴(0,1)),并加 `split_heads`/`merge_heads` 处理多头布局。
|
||
2. **tiny transformer**:token embedding → `n_layers` × {pre-RMSNorm → 多头因果 attention → 残差;pre-RMSNorm → SwiGLU MLP → 残差} → final RMSNorm → LM head → cross-entropy。
|
||
3. **overfit 验证**:极简手写 GD step(`p -= lr·grad`)记住一个固定小 batch。
|
||
|
||
**不做**(留 T6):真训练 loop、AdamW、LR schedule、grad clip、checkpoint、TinyStories/tokenizer。overfit 不需要优化器,一行 `p -= lr·grad` 足够。
|
||
|
||
## Module Layout
|
||
|
||
```
|
||
csrc/ops/model.cu # 新:embedding gather/scatter-add + 3D 轴(0,1) transpose kernel
|
||
# (reshape 是纯 metadata,无 kernel)
|
||
|
||
crates/xtrain-cuda/
|
||
├── build.rs # 新增 model.cu
|
||
└── src/ffi.rs # 新增 launch_embedding_fwd/bwd + launch_transpose_3d01(no_cuda 门控)
|
||
|
||
crates/xtrain-tensor/
|
||
└── src/tensor.rs # 新增 reshape / embedding(+bwd) / transpose_3d01(no_cuda 门控)
|
||
|
||
crates/xtrain-autodiff/
|
||
├── src/tape.rs # 新增 Var::zero_grad / set_value(供手写 GD step)
|
||
├── src/ops.rs # 新增节点 embedding / reshape / transpose_3d01 / transpose_2d
|
||
│ # / split_heads(->Vec<Var>) / merge_heads
|
||
└── tests/structural.rs # 新增:上述结构算子各自 grad-check
|
||
|
||
crates/xtrain-model/ # 新 crate:模型本体
|
||
├── build.rs # 检测 nvcc → no_cuda cfg(逐 crate)
|
||
├── src/
|
||
│ ├── lib.rs # 导出 Config / TinyTransformer / 辅助
|
||
│ ├── config.rs # 超参(host-only,no_cuda 也编)
|
||
│ └── model.rs # TinyTransformer:参数容器 + forward 图(no_cuda 门控)
|
||
└── tests/
|
||
├── overfit.rs # 端到端:char-level + 手写 GD → loss≈0
|
||
├── parity_dump.rs # PyTorch 对拍 fixture(dump 权重/logits/grad)
|
||
└── parity.py # 等价 PyTorch 模型,对比 forward + 每参数 grad
|
||
```
|
||
|
||
为什么新开 `xtrain-model` crate(而非塞进 autodiff):对齐 xserv 的「模型层独立于算子层」分层;autodiff 只管 `Var`/`ops`,model 在其上拼网络,职责清晰。
|
||
|
||
## Key Design Decisions
|
||
|
||
### 约定(对齐引擎,不照搬 HuggingFace)
|
||
|
||
- **线性层权重 `[in, out]`,按 `x @ W`**。引擎的 GEMM 是裸 `A @ B`,用 `[in,out]` 省掉每次投影的 transpose(Qwen 的 `[out,in]` + `x@Wᵀ` 是推理侧权重布局的产物,训练侧自定义即可,回流 xserv 时在 T9 转置导出)。
|
||
- **`dim = n_heads · head_dim`**(无独立 attention 投影维度),tiny 配置 `dim=32, n_layers=2, n_heads=2, head_dim=16, ffn_hidden=64`。
|
||
- **RoPE position = token 行号**(kernel 内建约定);`rope` 只接受 `[tokens, heads, head_dim]` 布局。
|
||
- **因果 mask = 加性常量 `[seq,seq]`**:对角线及以下为 0,以上为 −1e9,softmax 前 `add` 进 scores。引擎没有 masking 算子,用一个常量 leaf + `add` 即可(该 leaf 无下游、其 grad 不被读取,无害)。
|
||
|
||
### 多头布局:reshape + transpose_3d01 + split/merge_heads
|
||
|
||
引擎的 matmul/softmax 都是 2D 的,多头 attention 因此**逐头做 2D SDPA**。布局流水线:
|
||
|
||
```
|
||
proj = x @ Wq # [seq, dim]
|
||
reshape → [seq, nh, hd]
|
||
rope (kernel 要求正是 [tokens, heads, head_dim])
|
||
transpose_3d01 → [nh, seq, hd] # 让每个 head 的 [seq,hd] 块连续
|
||
split_heads → nh × [seq, hd] # 每头一个 Var
|
||
|
||
# 逐头:
|
||
scores = (q_i @ k_iᵀ)·(1/√hd) + mask # [seq,seq]
|
||
probs = softmax(scores)
|
||
out_i = probs @ v_i # [seq, hd]
|
||
|
||
merge_heads(out) → [nh, seq, hd]
|
||
transpose_3d01 → [seq, nh, hd]
|
||
reshape → [seq, dim]
|
||
@ Wo # 输出投影
|
||
```
|
||
|
||
- `reshape` 是**纯 metadata**(连续张量改 shape/strides,共享 storage,无 kernel、无数据搬运);`[seq, nh·hd] ↔ [seq, nh, hd]` 正好是它。反向 reshape 回去。
|
||
- `transpose_3d01`(`[a,b,c]→[b,a,c]`)有 kernel,**自反**:反向是对 grad 再做一次同样的 transpose。
|
||
- `split_heads`:`[nh,seq,hd]` 在该布局下每个 head 块连续,前向把每块拷成独立连续张量返回 `Vec<Var>`;反向把每头 grad 散回零初始化的 `[nh,seq,hd]`,由引擎的**扇出 SUM** 累加。`merge_heads` 是逆操作。
|
||
|
||
### embedding 前向 gather / 反向 scatter-add
|
||
|
||
```
|
||
out[s,:] = table[ids[s], :] # table:[vocab,dim], ids:[seq] I32
|
||
dtable[ids[s],:] += dout[s,:] # 反向:原子 scatter-add
|
||
```
|
||
|
||
- ids 是**常量**(不是 `Var`),只有 table 参与求导。
|
||
- 反向必须 **atomicAdd**:多个位置可能映射到同一 id,梯度要累加(grad-check 测试里特意放了重复 id `[0,3,1,3,2,0]`)。`dtable` 先 `zeros` 再原子累加。
|
||
|
||
### Block 结构(pre-norm + residual,Qwen3 风格)
|
||
|
||
```
|
||
h = embedding(ids)
|
||
for block:
|
||
h = h + attention(rms_norm(h, γ_attn)) # 残差
|
||
h = h + swiglu_mlp(rms_norm(h, γ_ffn)) # 残差
|
||
logits = rms_norm(h, γ_final) @ lm_head
|
||
loss = cross_entropy(logits, targets)
|
||
```
|
||
|
||
SwiGLU MLP:`down( silu(x@W_gate) ∘ (x@W_up) )`,复用 T4 的 `swiglu = mul(silu(g), u)`。
|
||
|
||
### 参数 API(为 T6 优化器准备)
|
||
|
||
- `TinyTransformer::params() -> Vec<Var>`:稳定顺序的全部可学习叶子(embedding / 各 block 的 9 个 / final_norm / lm_head)。
|
||
- `Var::set_value(t)`:原地更新参数值(GD/AdamW 用),保持叶子身份在多 step 间稳定。
|
||
- `Var::zero_grad()`:清梯度。**关键**:每个 forward 建新图但叶子复用,上一 step 的 grad 不清会被 SUM 累加 → 每 step 后必须 zero。
|
||
- `param_to_host(&Var)`:把参数搬回 host `Vec<f32>`(GD step / 对拍导出)。
|
||
|
||
手写 GD step(overfit 用):
|
||
```rust
|
||
for p in params {
|
||
if let Some(g) = p.grad() { p.set_value(p.value().add(&g.scale(-lr))); }
|
||
p.zero_grad();
|
||
}
|
||
```
|
||
|
||
### overfit 方法学
|
||
|
||
一个固定的小 batch(char-level 文本 → 字符表 → `(input, shifted target)`),反复跑 forward→backward→GD。**只要 fwd+bwd 全对,模型会记住这个 batch,loss → ~0**;任何一个 backward 错了,loss 会停在某个台阶下不去。这是比单算子 grad-check 更强的端到端信号。验收同时检查 greedy argmax 是否完全复现 target 序列。
|
||
|
||
## 验证方法
|
||
|
||
全部 `#![cfg(not(no_cuda))]` 门控,本地只 `cargo check`/`fmt`,构建+实跑在 dash5(8× RTX 5090, sm_120)。
|
||
|
||
1. **结构算子 grad-check**(`tests/structural.rs`):沿用 T4 的 `L = sum(W∘out)` 有限差分 harness,对 `embedding`(含重复 id)、`reshape`、`transpose_3d01`、`transpose_2d`、`split/merge_heads` 往返各做一遍。
|
||
2. **overfit**(`tests/overfit.rs`):tiny 模型 + 手写 GD,断言 `loss → <0.05` 且 greedy argmax 全对。
|
||
3. **PyTorch 对拍**(`parity_dump.rs` + `parity.py`):Rust dump 出权重/ids/logits/loss/每参数 grad(一次 backward),Python 用**完全等价**的 PyTorch 模型(同 `x@W` 约定、同 RoPE rotate_half pos=行号、同 RMSNorm/SwiGLU/因果 SDPA)跑 fwd+bwd,对比 forward logits + 21 个参数 grad 的相对误差(rtol=2e-2)。
|
||
|
||
### dash5 实测结果
|
||
|
||
```
|
||
# 结构算子 grad-check(max rel-err)
|
||
embedding dTable 3.5e-5 reshape dX 3.4e-4
|
||
transpose_2d dX 2.3e-5 transpose_3d01 dX 1.9e-4
|
||
split/merge_heads dX 3.9e-5 (5/5 通过)
|
||
|
||
# overfit(lr=0.3, 200 steps, vocab=16, seq=27, params=21664)
|
||
step 0: loss = 2.821415
|
||
step 20: loss = 0.341899
|
||
step 40: loss = 0.099285
|
||
...
|
||
step 199: loss = 0.004009
|
||
start 2.821415 → final 0.004009 ; greedy match 27/27 ✅
|
||
|
||
# PyTorch 对拍(rtol=2e-2)
|
||
loss: rust=2.505827e0 torch=2.505827e0 relerr=1.4e-8
|
||
logits: max relerr = 9.5e-5
|
||
params checked: 21 worst = grad[l0_wo] @ 1.1e-3 → PARITY OK ✅
|
||
|
||
# 无回归:T4 autograd 12/12 仍全绿
|
||
```
|
||
|
||
forward 与 PyTorch 对到 ~1e-4、每参数 grad 对到 ≤1.1e-3,overfit loss 从 2.82 跌到 0.004 且完全复现 target——三层证据(单算子 finite-diff、端到端 overfit、PyTorch 对拍)一致确认整条 fwd+bwd 正确。
|