Design doc for the T6 training stack: Goal / Module Layout / Key Design Decisions (AdamW math + decoupled WD, LR schedule, global-norm grad clip with batch averaging, checkpoint format, data pipeline + xserv tokenizer reuse, sampler) / 验证方法 (AdamW parity, checkpoint round-trip, real training, host unit tests). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
9.5 KiB
Phase T6: Training Loop + AdamW + Real Training — Design Document
Goal
在 T5 的 TinyTransformer(params() / forward / loss + Var::{value,grad,set_value,zero_grad})之上,搭起真正的训练栈,并在真实文本语料上把 loss 训下来:
- 手写 AdamW:per-param 一/二阶矩(m、v),bias correction,decoupled weight decay,对拍
torch.optim.AdamW数值一致。 - 训练 loop:语料 → 采样定长序列 →
forward(loss)→backward→ global-norm grad clip → AdamW step →zero_grad;LR schedule(warmup + cosine);周期性 loss 日志;checkpoint 存/取。 - 采样器:greedy / temperature,训练中/后吐文本看「在不在学」。
- 数据:复用 xserv 的 GPT-2 BPE tokenizer(path-dep),语料 = TinyStories 子集。
不做(留后续 Phase):性能(cuBLAS 切换 / bf16 / 激活重计算 = T7)、分布式(NCCL 数据并行 = T8)。本 Phase 只要正确性 + 清晰的学习信号,训练预算有界(几分钟 / 几千步,非完全收敛)。
Module Layout
crates/xtrain-optim/ # 新 crate:优化器
├── build.rs # 检测 nvcc → no_cuda cfg(逐 crate)
├── src/lib.rs # AdamW:step_host(纯 host 数学) + step(&[Var]) GPU 包装
└── tests/adamw_host.rs # host 单测:对独立参考递推 + 纯 decay 边界(本地可跑,无 GPU)
crates/xtrain-train/ # 新 crate:训练基建 + 入口
├── build.rs # 检测 nvcc → no_cuda cfg
├── Cargo.toml # path-dep: ../../../xserv/crates/xserv-tokenizer(本地/dash5 都解析)
├── src/
│ ├── lib.rs # 模块导出(host-only 与 GPU 件分门控)
│ ├── schedule.rs # LrSchedule:warmup + cosine(host-only,可本地单测)
│ ├── clip.rs # global L2 norm + clip_scale(host 数学)+ clip_grad_norm(&[Var])(GPU 门控)
│ ├── data.rs # Corpus:load tokenizer+语料 → token 流 → sample(input,target) 窗口
│ ├── checkpoint.rs # save / load_into:按 params() 顺序 dump/reload(GPU 门控)
│ ├── sample.rs # generate:greedy / temperature 自回归采样(GPU 门控)
│ ├── train_loop.rs # TrainConfig + train():把以上接到 model+AdamW(GPU 门控)
│ └── bin/train.rs # 真训练入口:load 数据 → train → checkpoint → 采样
└── tests/
├── adamw_parity_dump.rs # AdamW 对拍 fixture:固定 init 跑 N 步 AdamW,dump loss 轨迹 + 终参
├── adamw_parity.py # 等价 PyTorch 模型 + torch.optim.AdamW,对比轨迹 + 终参
├── checkpoint_roundtrip.rs # 训几步→save→载入新模型→logits/loss 逐位一致
└── real_training.rs # TinyStories 有界训练:loss 大幅下降 + 采样在学
data/tinystories-valid-3mb.txt # 语料子集(committed,~3MB,TinyStories-valid 前 3MB,整故事截断)
为什么拆两个 crate:对齐 xserv 的分层(优化器与训练编排分开)。xtrain-optim 只管参数更新数学;xtrain-train 管数据/调度/checkpoint/采样/loop。AdamW 数学独立可测,不依赖 model。
host / GPU 门控约定(沿用全仓):纯算术(LrSchedule、grad-norm 数学、AdamW 的 step_host)始终编译,本地 cargo check + 单测即可验证;凡 round-trip GPU 张量的(step(&[Var])、clip_grad_norm(&[Var])、checkpoint、采样、loop)一律 #[cfg(not(no_cuda))],链接+实跑在 dash5。每 crate 的 build.rs 各自检测 nvcc(cfg 不跨 crate 传播)。
Key Design Decisions
AdamW:手写数学 + decoupled weight decay
第 t 步(1-indexed),参数 θ、梯度 g:
m ← β1·m + (1−β1)·g
v ← β2·v + (1−β2)·g²
m̂ ← m / (1 − β1ᵗ) (bias correction)
v̂ ← v / (1 − β2ᵗ)
θ ← θ − lr·( m̂ / (√v̂ + ε) + wd·θ )
- decoupled weight decay(Loshchilov & Hutter 2019):
wd·θ直接作用在参数上,不并进梯度(不进入自适应√v̂分母)——这正是torch.optim.AdamW的定义,区别于「L2 正则把wd·θ加到g」的 Adam。 - 默认超参对齐 PyTorch:β1=0.9,β2=0.999,ε=1e-8。
- 状态 keyed by 参数在
params()中的下标(稳定序),首次step惰性按各参数 numel 分配m,v。t全局共享(所有参数同一 bias correction,和 PyTorch 一致)。
实现分层:step_host(lr, &mut [Vec<f32>], &[Vec<f32>]) 是纯 host f32 数学(无 GPU、无 autograd,本地单测);step(lr, &[Var]) 把每参数的 value()/grad() 拉到 host、调 step_host、set_value 写回。这条路子(host 算优化器)对 tiny 模型完全够用,且让 AdamW 数学脱离 GPU 可严格对拍——T6 是正确性 Phase,不做 GPU 优化器 kernel(那是性能向,超范围)。lr 每步传入,给 schedule 留口。
LR schedule:warmup + cosine
step ∈ [0,warmup) 线性 0→max_lr;[warmup,total) cosine max_lr→min_lr;≥total 钳到 min_lr。纯函数(只吃 step 下标),本地单测形状。
grad clip:global L2 norm(+ batch 平均)
跨所有参数梯度联合算 L2 norm(同 torch.nn.utils.clip_grad_norm_):total > max_norm 则全体 ×(max_norm/total)。
模型是单序列(无 batch 维),一个 batch_size 的「batch」靠跑 batch_size 次 forward+backward、让 tape 的 fan-out 规则把梯度 SUM 起来实现。为得到 batch 均值梯度,clip 这一趟 host pass 里先 ×1/batch_size 再算 norm/裁剪——clip_grad_norm(params, max_norm, pre_scale) 把「平均」与「裁剪」融成一次 host 往返(省一趟拷贝)。batch 是 T7/边角关切,这里只求正确。
checkpoint 格式
按 params() 顺序 dump 每个参数的 value 到扁平二进制:
magic u32 = "XTRT" | version u32 | n_params u32
×n_params: ndim u32 | dims[ndim] u32 | data[Πdims] f32 (小端)
不存架构/config——调用方用同一 Config 重建模型再 load_into(round-trip 与 resume 都自知 config)。load_into 校验 magic/version/数量/逐参数 shape,按各参数 device 写回 set_value。f32 精确往返 → 重载后 forward 逐位一致(同 kernel 同输入)。
数据管线 + tokenizer 复用
- tokenizer = 复用 xserv 的 from-scratch GPT-2/Qwen BPE:
Cargo.tomlpath-dep../../../xserv/crates/xserv-tokenizer,该相对路径在本地~/projects与 dash5/opt/wjh/projects都解析;Cargo 按目标 crate 自身的 workspace(xserv 的)解析它的serde/regex依赖,不需要 xtrain 复制 workspace dep。加载/opt/wjh/models/gpt2/tokenizer.json。 - 语料 = TinyStories 子集:dash5 经
hf-mirror.com取TinyStories-valid.txt前 ~3MB(HF 直连不可达,proxy 脚本只起后台 SOCKS;hf-mirror 直连 200),committed 进data/。Corpus::load整篇 tokenize 成一条 token 流(TinyStories 用<|endoftext|>分故事,GPT-2 BPE 正好出成单个 special token,文档边界保留);range 下载会掐头去尾,故先丢首个不完整行、截到最后一个<|endoftext|>,只训整故事。sample(seq)随机取窗口[s,s+seq+1)→ input[s,s+seq)/ target 右移一位(next-token),LCG 种子可复现,不引 RNG crate。
采样器
模型单序列、RoPE pos=行号,故自回归生成每步对增长前缀重跑 forward、取末行 logits(最简正确法;KV cache 是推理/性能向,超范围)。temperature==0 greedy argmax,否则按 softmax(logits/T) 采样。
训练 loop(train)
每步:采 batch_size 序列各自 forward loss + backward(tape SUM 梯度)→ clip_grad_norm(×1/batch + 裁剪) → AdamW::step(lr) → 全参数 zero_grad;按 log_every 打 loss/lr/gnorm/tok-s,按 ckpt_every 存 checkpoint,返回逐步 loss 轨迹。
验证方法(验收)
GPU 测试全部 #[cfg(not(no_cuda))] 门控,在 dash5 实跑 capture:
- AdamW 对拍 PyTorch(严格正确性):同一 tiny 模型 + 相同 init,Rust AdamW 与
torch.optim.AdamW(lr/wd/betas/eps 全对齐)各跑 N 步固定 batch → loss 轨迹与终参逐项 rtol 内一致。- fixture:
cargo test -p xtrain-train --test adamw_parity_dump -- --ignored --nocapture - 对比:
python3 crates/xtrain-train/tests/adamw_parity.py /tmp/xtrain_adamw
- fixture:
- checkpoint round-trip:训几步 → save → 载入全新 init 的模型 → 固定输入 logits/loss 逐位一致(且证明载入前新模型确实不同)。
cargo test -p xtrain-train --test checkpoint_roundtrip
- 真训练(端到端学习信号):TinyStories 上有界训练(几百~几千步)→ loss 大幅下降 + greedy 采样显出英文结构(非乱码)。
cargo test -p xtrain-train --release --test real_training -- --ignored --nocapture- 或
cargo run -p xtrain-train --release --bin train -- <tokenizer.json> <corpus.txt> [steps] [ckpt]
- host 单测(本地即跑):AdamW 数学对独立参考递推、LR schedule 形状、grad-norm/clip 数学。
cargo test -p xtrain-optim -p xtrain-train