diff --git a/docs/04-tiny-transformer.md b/docs/04-tiny-transformer.md new file mode 100644 index 0000000..29dbd70 --- /dev/null +++ b/docs/04-tiny-transformer.md @@ -0,0 +1,157 @@ +# Phase T5: Tiny Transformer (fwd+bwd) — Design Document + +## Goal + +在 T4 的 autograd 引擎(`Var` tape + 11 个算子)之上,**组装一个 tiny 现代架构 decoder**(RoPE + RMSNorm + SwiGLU + 多头因果 attention),跑通 **char-level bring-up**,并以「**把一个小 batch overfit 到 loss≈0**」证明整条 forward+backward 计算图正确。 + +明确范围(T5 只做这些): + +1. **补 T4 留的缺口算子**为正式可微节点:`embedding`(按 token id gather / scatter-add 反向)、`reshape`、`transpose`(2D + 3D 轴(0,1)),并加 `split_heads`/`merge_heads` 处理多头布局。 +2. **tiny transformer**:token embedding → `n_layers` × {pre-RMSNorm → 多头因果 attention → 残差;pre-RMSNorm → SwiGLU MLP → 残差} → final RMSNorm → LM head → cross-entropy。 +3. **overfit 验证**:极简手写 GD step(`p -= lr·grad`)记住一个固定小 batch。 + +**不做**(留 T6):真训练 loop、AdamW、LR schedule、grad clip、checkpoint、TinyStories/tokenizer。overfit 不需要优化器,一行 `p -= lr·grad` 足够。 + +## Module Layout + +``` +csrc/ops/model.cu # 新:embedding gather/scatter-add + 3D 轴(0,1) transpose kernel + # (reshape 是纯 metadata,无 kernel) + +crates/xtrain-cuda/ +├── build.rs # 新增 model.cu +└── src/ffi.rs # 新增 launch_embedding_fwd/bwd + launch_transpose_3d01(no_cuda 门控) + +crates/xtrain-tensor/ +└── src/tensor.rs # 新增 reshape / embedding(+bwd) / transpose_3d01(no_cuda 门控) + +crates/xtrain-autodiff/ +├── src/tape.rs # 新增 Var::zero_grad / set_value(供手写 GD step) +├── src/ops.rs # 新增节点 embedding / reshape / transpose_3d01 / transpose_2d +│ # / split_heads(->Vec) / merge_heads +└── tests/structural.rs # 新增:上述结构算子各自 grad-check + +crates/xtrain-model/ # 新 crate:模型本体 +├── build.rs # 检测 nvcc → no_cuda cfg(逐 crate) +├── src/ +│ ├── lib.rs # 导出 Config / TinyTransformer / 辅助 +│ ├── config.rs # 超参(host-only,no_cuda 也编) +│ └── model.rs # TinyTransformer:参数容器 + forward 图(no_cuda 门控) +└── tests/ + ├── overfit.rs # 端到端:char-level + 手写 GD → loss≈0 + ├── parity_dump.rs # PyTorch 对拍 fixture(dump 权重/logits/grad) + └── parity.py # 等价 PyTorch 模型,对比 forward + 每参数 grad +``` + +为什么新开 `xtrain-model` crate(而非塞进 autodiff):对齐 xserv 的「模型层独立于算子层」分层;autodiff 只管 `Var`/`ops`,model 在其上拼网络,职责清晰。 + +## Key Design Decisions + +### 约定(对齐引擎,不照搬 HuggingFace) + +- **线性层权重 `[in, out]`,按 `x @ W`**。引擎的 GEMM 是裸 `A @ B`,用 `[in,out]` 省掉每次投影的 transpose(Qwen 的 `[out,in]` + `x@Wᵀ` 是推理侧权重布局的产物,训练侧自定义即可,回流 xserv 时在 T9 转置导出)。 +- **`dim = n_heads · head_dim`**(无独立 attention 投影维度),tiny 配置 `dim=32, n_layers=2, n_heads=2, head_dim=16, ffn_hidden=64`。 +- **RoPE position = token 行号**(kernel 内建约定);`rope` 只接受 `[tokens, heads, head_dim]` 布局。 +- **因果 mask = 加性常量 `[seq,seq]`**:对角线及以下为 0,以上为 −1e9,softmax 前 `add` 进 scores。引擎没有 masking 算子,用一个常量 leaf + `add` 即可(该 leaf 无下游、其 grad 不被读取,无害)。 + +### 多头布局:reshape + transpose_3d01 + split/merge_heads + +引擎的 matmul/softmax 都是 2D 的,多头 attention 因此**逐头做 2D SDPA**。布局流水线: + +``` +proj = x @ Wq # [seq, dim] + reshape → [seq, nh, hd] +rope (kernel 要求正是 [tokens, heads, head_dim]) + transpose_3d01 → [nh, seq, hd] # 让每个 head 的 [seq,hd] 块连续 +split_heads → nh × [seq, hd] # 每头一个 Var + +# 逐头: +scores = (q_i @ k_iᵀ)·(1/√hd) + mask # [seq,seq] +probs = softmax(scores) +out_i = probs @ v_i # [seq, hd] + +merge_heads(out) → [nh, seq, hd] + transpose_3d01 → [seq, nh, hd] + reshape → [seq, dim] + @ Wo # 输出投影 +``` + +- `reshape` 是**纯 metadata**(连续张量改 shape/strides,共享 storage,无 kernel、无数据搬运);`[seq, nh·hd] ↔ [seq, nh, hd]` 正好是它。反向 reshape 回去。 +- `transpose_3d01`(`[a,b,c]→[b,a,c]`)有 kernel,**自反**:反向是对 grad 再做一次同样的 transpose。 +- `split_heads`:`[nh,seq,hd]` 在该布局下每个 head 块连续,前向把每块拷成独立连续张量返回 `Vec`;反向把每头 grad 散回零初始化的 `[nh,seq,hd]`,由引擎的**扇出 SUM** 累加。`merge_heads` 是逆操作。 + +### embedding 前向 gather / 反向 scatter-add + +``` +out[s,:] = table[ids[s], :] # table:[vocab,dim], ids:[seq] I32 +dtable[ids[s],:] += dout[s,:] # 反向:原子 scatter-add +``` + +- ids 是**常量**(不是 `Var`),只有 table 参与求导。 +- 反向必须 **atomicAdd**:多个位置可能映射到同一 id,梯度要累加(grad-check 测试里特意放了重复 id `[0,3,1,3,2,0]`)。`dtable` 先 `zeros` 再原子累加。 + +### Block 结构(pre-norm + residual,Qwen3 风格) + +``` +h = embedding(ids) +for block: + h = h + attention(rms_norm(h, γ_attn)) # 残差 + h = h + swiglu_mlp(rms_norm(h, γ_ffn)) # 残差 +logits = rms_norm(h, γ_final) @ lm_head +loss = cross_entropy(logits, targets) +``` + +SwiGLU MLP:`down( silu(x@W_gate) ∘ (x@W_up) )`,复用 T4 的 `swiglu = mul(silu(g), u)`。 + +### 参数 API(为 T6 优化器准备) + +- `TinyTransformer::params() -> Vec`:稳定顺序的全部可学习叶子(embedding / 各 block 的 9 个 / final_norm / lm_head)。 +- `Var::set_value(t)`:原地更新参数值(GD/AdamW 用),保持叶子身份在多 step 间稳定。 +- `Var::zero_grad()`:清梯度。**关键**:每个 forward 建新图但叶子复用,上一 step 的 grad 不清会被 SUM 累加 → 每 step 后必须 zero。 +- `param_to_host(&Var)`:把参数搬回 host `Vec`(GD step / 对拍导出)。 + +手写 GD step(overfit 用): +```rust +for p in params { + if let Some(g) = p.grad() { p.set_value(p.value().add(&g.scale(-lr))); } + p.zero_grad(); +} +``` + +### overfit 方法学 + +一个固定的小 batch(char-level 文本 → 字符表 → `(input, shifted target)`),反复跑 forward→backward→GD。**只要 fwd+bwd 全对,模型会记住这个 batch,loss → ~0**;任何一个 backward 错了,loss 会停在某个台阶下不去。这是比单算子 grad-check 更强的端到端信号。验收同时检查 greedy argmax 是否完全复现 target 序列。 + +## 验证方法 + +全部 `#![cfg(not(no_cuda))]` 门控,本地只 `cargo check`/`fmt`,构建+实跑在 dash5(8× RTX 5090, sm_120)。 + +1. **结构算子 grad-check**(`tests/structural.rs`):沿用 T4 的 `L = sum(W∘out)` 有限差分 harness,对 `embedding`(含重复 id)、`reshape`、`transpose_3d01`、`transpose_2d`、`split/merge_heads` 往返各做一遍。 +2. **overfit**(`tests/overfit.rs`):tiny 模型 + 手写 GD,断言 `loss → <0.05` 且 greedy argmax 全对。 +3. **PyTorch 对拍**(`parity_dump.rs` + `parity.py`):Rust dump 出权重/ids/logits/loss/每参数 grad(一次 backward),Python 用**完全等价**的 PyTorch 模型(同 `x@W` 约定、同 RoPE rotate_half pos=行号、同 RMSNorm/SwiGLU/因果 SDPA)跑 fwd+bwd,对比 forward logits + 21 个参数 grad 的相对误差(rtol=2e-2)。 + +### dash5 实测结果 + +``` +# 结构算子 grad-check(max rel-err) +embedding dTable 3.5e-5 reshape dX 3.4e-4 +transpose_2d dX 2.3e-5 transpose_3d01 dX 1.9e-4 +split/merge_heads dX 3.9e-5 (5/5 通过) + +# overfit(lr=0.3, 200 steps, vocab=16, seq=27, params=21664) +step 0: loss = 2.821415 +step 20: loss = 0.341899 +step 40: loss = 0.099285 +... +step 199: loss = 0.004009 +start 2.821415 → final 0.004009 ; greedy match 27/27 ✅ + +# PyTorch 对拍(rtol=2e-2) +loss: rust=2.505827e0 torch=2.505827e0 relerr=1.4e-8 +logits: max relerr = 9.5e-5 +params checked: 21 worst = grad[l0_wo] @ 1.1e-3 → PARITY OK ✅ + +# 无回归:T4 autograd 12/12 仍全绿 +``` + +forward 与 PyTorch 对到 ~1e-4、每参数 grad 对到 ≤1.1e-3,overfit loss 从 2.82 跌到 0.004 且完全复现 target——三层证据(单算子 finite-diff、端到端 overfit、PyTorch 对拍)一致确认整条 fwd+bwd 正确。