Files
xtrain/docs/05-training-loop.md
Gahow Wang 29b4d30b6c docs: Phase T6 — training loop
Design doc for the T6 training stack: Goal / Module Layout / Key Design
Decisions (AdamW math + decoupled WD, LR schedule, global-norm grad clip with
batch averaging, checkpoint format, data pipeline + xserv tokenizer reuse,
sampler) / 验证方法 (AdamW parity, checkpoint round-trip, real training, host
unit tests).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:30:14 +08:00

115 lines
9.5 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T6: Training Loop + AdamW + Real Training — Design Document
## Goal
在 T5 的 `TinyTransformer``params()` / `forward` / `loss` + `Var::{value,grad,set_value,zero_grad}`)之上,搭起**真正的训练栈**,并在**真实文本语料**上把 loss 训下来:
1. **手写 AdamW**per-param 一/二阶矩m、vbias correction**decoupled weight decay**,对拍 `torch.optim.AdamW` 数值一致。
2. **训练 loop**:语料 → 采样定长序列 → `forward(loss)``backward`**global-norm grad clip****AdamW step**`zero_grad`**LR schedule**warmup + cosine周期性 loss 日志;**checkpoint 存/取**。
3. **采样器**greedy / temperature训练中/后吐文本看「在不在学」。
4. **数据****复用 xserv 的 GPT-2 BPE** tokenizerpath-dep语料 = **TinyStories** 子集。
**不做**(留后续 Phase性能cuBLAS 切换 / bf16 / 激活重计算 = T7、分布式NCCL 数据并行 = T8。本 Phase 只要**正确性 + 清晰的学习信号**,训练预算有界(几分钟 / 几千步,非完全收敛)。
## Module Layout
```
crates/xtrain-optim/ # 新 crate优化器
├── build.rs # 检测 nvcc → no_cuda cfg逐 crate
├── src/lib.rs # AdamWstep_host(纯 host 数学) + step(&[Var]) GPU 包装
└── tests/adamw_host.rs # host 单测:对独立参考递推 + 纯 decay 边界(本地可跑,无 GPU
crates/xtrain-train/ # 新 crate训练基建 + 入口
├── build.rs # 检测 nvcc → no_cuda cfg
├── Cargo.toml # path-dep: ../../../xserv/crates/xserv-tokenizer本地/dash5 都解析)
├── src/
│ ├── lib.rs # 模块导出host-only 与 GPU 件分门控)
│ ├── schedule.rs # LrSchedulewarmup + cosinehost-only可本地单测
│ ├── clip.rs # global L2 norm + clip_scalehost 数学)+ clip_grad_norm(&[Var])GPU 门控)
│ ├── data.rs # Corpusload tokenizer+语料 → token 流 → sample(input,target) 窗口
│ ├── checkpoint.rs # save / load_into按 params() 顺序 dump/reloadGPU 门控)
│ ├── sample.rs # generategreedy / temperature 自回归采样GPU 门控)
│ ├── train_loop.rs # TrainConfig + train():把以上接到 model+AdamWGPU 门控)
│ └── bin/train.rs # 真训练入口load 数据 → train → checkpoint → 采样
└── tests/
├── adamw_parity_dump.rs # AdamW 对拍 fixture固定 init 跑 N 步 AdamWdump loss 轨迹 + 终参
├── adamw_parity.py # 等价 PyTorch 模型 + torch.optim.AdamW对比轨迹 + 终参
├── checkpoint_roundtrip.rs # 训几步→save→载入新模型→logits/loss 逐位一致
└── real_training.rs # TinyStories 有界训练loss 大幅下降 + 采样在学
data/tinystories-valid-3mb.txt # 语料子集committed~3MBTinyStories-valid 前 3MB整故事截断
```
**为什么拆两个 crate**:对齐 xserv 的分层(优化器与训练编排分开)。`xtrain-optim` 只管参数更新数学;`xtrain-train` 管数据/调度/checkpoint/采样/loop。AdamW 数学独立可测,不依赖 model。
**host / GPU 门控约定**(沿用全仓):纯算术(`LrSchedule`、grad-norm 数学、AdamW 的 `step_host`**始终编译**,本地 `cargo check` + 单测即可验证;凡 round-trip GPU 张量的(`step(&[Var])``clip_grad_norm(&[Var])`、checkpoint、采样、loop一律 `#[cfg(not(no_cuda))]`,链接+实跑在 dash5。每 crate 的 `build.rs` 各自检测 nvcccfg 不跨 crate 传播)。
## Key Design Decisions
### AdamW手写数学 + decoupled weight decay
`t`1-indexed参数 `θ`、梯度 `g`
```text
m ← β1·m + (1β1)·g
v ← β2·v + (1β2)·g²
m̂ ← m / (1 β1ᵗ) (bias correction)
v̂ ← v / (1 β2ᵗ)
θ ← θ lr·( m̂ / (√v̂ + ε) + wd·θ )
```
- **decoupled weight decay**Loshchilov & Hutter 2019`wd·θ` 直接作用在参数上,**不**并进梯度(不进入自适应 `√v̂` 分母)——这正是 `torch.optim.AdamW` 的定义区别于「L2 正则把 `wd·θ` 加到 `g`」的 Adam。
- 默认超参对齐 PyTorchβ1=0.9β2=0.999,ε=1e-8。
- **状态 keyed by 参数在 `params()` 中的下标**(稳定序),首次 `step` 惰性按各参数 numel 分配 `m,v``t` 全局共享(所有参数同一 bias correction和 PyTorch 一致)。
**实现分层**`step_host(lr, &mut [Vec<f32>], &[Vec<f32>])` 是纯 host f32 数学(无 GPU、无 autograd本地单测`step(lr, &[Var])` 把每参数的 `value()`/`grad()` 拉到 host、调 `step_host``set_value` 写回。这条路子host 算优化器)对 tiny 模型完全够用,且让 AdamW 数学**脱离 GPU 可严格对拍**——T6 是正确性 Phase不做 GPU 优化器 kernel那是性能向超范围`lr` 每步传入,给 schedule 留口。
### LR schedulewarmup + cosine
`step ∈ [0,warmup)` 线性 `0→max_lr``[warmup,total)` cosine `max_lr→min_lr``≥total` 钳到 `min_lr`。纯函数(只吃 step 下标),本地单测形状。
### grad clipglobal L2 norm+ batch 平均)
跨**所有**参数梯度联合算 L2 norm`torch.nn.utils.clip_grad_norm_``total > max_norm` 则全体 `×(max_norm/total)`
模型是**单序列**(无 batch 维),一个 `batch_size` 的「batch」靠**跑 `batch_size` 次 forward+backward**、让 tape 的 fan-out 规则**把梯度 SUM** 起来实现。为得到 batch 均值梯度clip 这一趟 host pass 里**先 `×1/batch_size`** 再算 norm/裁剪——`clip_grad_norm(params, max_norm, pre_scale)` 把「平均」与「裁剪」融成一次 host 往返省一趟拷贝。batch 是 T7/边角关切,这里只求正确。
### checkpoint 格式
`params()` 顺序 dump 每个参数的 value 到扁平二进制:
```text
magic u32 = "XTRT" | version u32 | n_params u32
×n_params: ndim u32 | dims[ndim] u32 | data[Πdims] f32 (小端)
```
不存架构/config——调用方用同一 `Config` 重建模型再 `load_into`round-trip 与 resume 都自知 config`load_into` 校验 magic/version/数量/逐参数 shape按各参数 device 写回 `set_value`。f32 精确往返 → 重载后 forward 逐位一致(同 kernel 同输入)。
### 数据管线 + tokenizer 复用
- **tokenizer = 复用 xserv 的 from-scratch GPT-2/Qwen BPE**`Cargo.toml` path-dep `../../../xserv/crates/xserv-tokenizer`,该相对路径在本地 `~/projects` 与 dash5 `/opt/wjh/projects` 都解析Cargo 按目标 crate 自身的 workspacexserv 的)解析它的 `serde/regex` 依赖,不需要 xtrain 复制 workspace dep。加载 `/opt/wjh/models/gpt2/tokenizer.json`
- **语料 = TinyStories 子集**dash5 经 `hf-mirror.com``TinyStories-valid.txt` 前 ~3MBHF 直连不可达proxy 脚本只起后台 SOCKShf-mirror 直连 200committed 进 `data/``Corpus::load` 整篇 tokenize 成一条 token 流TinyStories 用 `<|endoftext|>` 分故事GPT-2 BPE 正好出成单个 special token文档边界保留range 下载会掐头去尾,故先丢首个不完整行、截到最后一个 `<|endoftext|>`,只训整故事。`sample(seq)` 随机取窗口 `[s,s+seq+1)` → input `[s,s+seq)` / target 右移一位next-tokenLCG 种子可复现,不引 RNG crate。
### 采样器
模型单序列、RoPE pos=行号,故自回归生成**每步对增长前缀重跑 forward、取末行 logits**最简正确法KV cache 是推理/性能向,超范围)。`temperature==0` greedy argmax否则按 `softmax(logits/T)` 采样。
### 训练 loop`train`
每步:采 `batch_size` 序列各自 forward `loss` + backwardtape SUM 梯度)→ `clip_grad_norm(×1/batch + 裁剪)``AdamW::step(lr)` → 全参数 `zero_grad`;按 `log_every``loss/lr/gnorm/tok-s`,按 `ckpt_every` 存 checkpoint返回逐步 loss 轨迹。
## 验证方法(验收)
GPU 测试全部 `#[cfg(not(no_cuda))]` 门控,在 dash5 实跑 capture
1. **AdamW 对拍 PyTorch**(严格正确性):同一 tiny 模型 + 相同 initRust AdamW 与 `torch.optim.AdamW`lr/wd/betas/eps 全对齐)各跑 N 步固定 batch → **loss 轨迹**与**终参**逐项 rtol 内一致。
- fixture`cargo test -p xtrain-train --test adamw_parity_dump -- --ignored --nocapture`
- 对比:`python3 crates/xtrain-train/tests/adamw_parity.py /tmp/xtrain_adamw`
2. **checkpoint round-trip**:训几步 → save → 载入**全新 init 的模型** → 固定输入 logits/loss 逐位一致(且证明载入前新模型确实不同)。
- `cargo test -p xtrain-train --test checkpoint_roundtrip`
3. **真训练**端到端学习信号TinyStories 上有界训练(几百~几千步)→ loss 大幅下降 + greedy 采样显出英文结构(非乱码)。
- `cargo test -p xtrain-train --release --test real_training -- --ignored --nocapture`
-`cargo run -p xtrain-train --release --bin train -- <tokenizer.json> <corpus.txt> [steps] [ckpt]`
4. **host 单测**本地即跑AdamW 数学对独立参考递推、LR schedule 形状、grad-norm/clip 数学。
- `cargo test -p xtrain-optim -p xtrain-train`