bf16 mixed precision (fp32 master) solves the v4 dim768 fp32 batch-32 OOM and speeds up the now-compute-bound dim768 GEMMs (dash5 1× RTX 5090 32GB, dim768/18L/24h×32 ffn2048 seq256, steady-state): config batch peak mem tok/s fits 32GB fp32 16 27.2 GB 31.5K yes bf16 16 19.3 GB 35.5K yes (-29% mem / +13% tok/s) fp32 32 — — OOM bf16 32 31.1 GB 40.8K yes (+29% vs fp32-b16) Verified on dash5: fp32 suite green at tight tol + xserv export md5 bit-identical to registry; bf16 looser-tol (loss 1.2e-4, logits p99 6.8e-3, grad 1.0e-2) + 150-step convergence tracks fp32 (3.984 vs 3.988); 2-GPU bf16 DDP at per-rank batch 32 trains cleanly. Mark KI-2 FIXED; fill docs/11 results. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
8.5 KiB
Phase T12: bf16 混合精度(fp32 master)— Design Document
KI-2 的具体落地。v4(dim768, fp32)在单卡 32GB 下 per-rank batch 32(global 256)OOM,被迫降到 batch 16 训练。bf16 把激活显存减半(找回 batch-256 甜点区),并在 dim768 这个已 compute-bound 的规模上加速 tensor-core GEMM。附带收益:xserv 推理是 BF16-only,bf16 训练让闭环更贴。
Goal
在不动 fp32 路径任何数值的前提下,新增一个 opt-in 的 bf16 混合精度模式(标准 AMP,fp32 master weights):
- 正确性硬闸门:fp32 全套回归(T3 GEMM / T4 12 算子 grad-check / T5 结构+overfit+PyTorch 对拍 / T6 AdamW+checkpoint / T8 DDP / T10 batched / xserv 闭环)在同样紧容差下保持绿。bf16 是加法、可选的,绝不扰动 fp32。
- bf16 正确性:bf16 前向/梯度在更松的 bf16 容差(≈2–3 位十进制有效数字 → ~1e-2 相对误差)内对住 fp32 参考;一段短 bf16 训练收敛对住 fp32(loss 曲线接近、无 NaN/发散)。
- 显存+吞吐(payoff):dim768 bf16 能跑 per-rank batch 32(解 OOM);测 dim768 bf16 vs fp32 的显存+tok/s。
什么是 bf16、什么是 fp32(标准 AMP split)
| 组件 | 精度 | 理由 |
|---|---|---|
| master weights + AdamW state(m/v) + 优化器更新 | fp32 | 小步长更新需要 fp32 精度,否则被 bf16 的 8-bit 尾数吃掉 |
| linear GEMM(q/k/v/o、gate/up/down、lm_head) | bf16 in/out + fp32 accum | compute+memory 主体;tensor-core 走 cublasGemmEx(CUDA_R_16BF in/out,CUBLAS_COMPUTE_32F 累加) |
| 激活流(残差流、attention Q/K/V/probs/out、MLP 中间) | bf16 | 激活显存减半——这是解 OOM 的关键,不只是 GEMM 提速 |
| RMSNorm / QK-norm | fp32(bf16→fp32 算 reduction→bf16) | 求和/rsqrt 数值敏感 |
| softmax(attention)/ RoPE / cross-entropy | fp32 | softmax 的 exp/求和、CE 的 log、RoPE 的 sin/cos 都数值敏感 |
| 梯度 → AdamW | fp32 | 见下「cast 算子」——grad 在 fp32 master leaf 上累加,AdamW/clip/DDP all-reduce 全程 fp32、完全不改 |
无 loss scaling:bf16 是 8-bit 指数(与 fp32 同动态范围),不像 fp16(5-bit 指数易下溢)。所以梯度不会下溢到 0,不需要 loss scaling。
Module Layout(surgical:fp32 路径逐字节不动)
核心思路:所有 op 按 self.dtype() 分派。fp32 分支跑原 kernel(一字不改);bf16 分支是新增代码。
1. xtrain-tensor::dtype — 加 BF16
DType::BF16,size_bytes()=2,half::bf16实现TensorDType(halfcrate 已是依赖)。
2. xtrain-cuda — bf16 GEMM + cast kernel
ffi.rs:声明cublasGemmEx/cublasGemmStridedBatchedEx(void* 指针、a_type/b_type/c_type/compute_type),常量CUDA_R_16BF=14、CUDA_R_32F=0、CUBLAS_COMPUTE_32F=68(数值同 xservgemm.rs)。cublas.rs:gemm_ex(...)/gemm_ex_strided_batched(...)——和sgemm同样的 row-major⟺col-major 转置代数,只是走GemmEx、in/out=bf16、accum=fp32。csrc/ops/cast.cu+ ffi:launch_cast_f32_to_bf16/launch_cast_bf16_to_f32(逐元素__float2bfloat16/__bfloat162float)。
3. xtrain-tensor::tensor — dtype-polymorphic ops
to_dtype(target):f32↔bf16 cast(CUDA),同 dtype 直接 clone。matmul/matmul_backward/attention/attention_backward:按 dtype 分派——fp32 走原sgemm(不动),bf16 走gemm_ex,输出同 dtype。- 逐元素 op(
add/mul/silu/scale…)+embedding:允许 bf16 输入。逐元素 kernel 对 bf16 走「load→fp32→算→store bf16」(新增 bf16 kernel)或对 norm/softmax/CE 在 wrapper 里 upcast→fp32 kernel→downcast。fp32 调用走原 f32 kernel 不变。
4. xtrain-autodiff::ops — cast 算子 + bf16 透传
cast(x, target_dtype):前向x.to_dtype(target);反向把 grad cast 回x的 dtype。这是 AMP 的关键钩子:- fp32 master weight leaf →
cast(w, BF16)喂给 matmul;matmul 的 bf16 grad 经 cast 反向升回 fp32,累加在 fp32 leaf 上。 - ⟹
.grad()是 fp32,AdamW /clip_grad_norm_gpu/ DDPall_reduce_average_grads一行不改,全程 fp32 master。
- fp32 master weight leaf →
- 其它 op 的 backward 自然按张量 dtype 流转(softmax/rms_norm wrapper 内部 upcast→fp32→downcast,对外是 bf16)。
5. xtrain-model::TinyTransformer — compute_dtype 开关
new_amp(cfg, device, dtype, init)或forward_batched接compute_dtype: DType(默认F32= 原路径,逐字节同)。- bf16 模式:embedding 输出
cast→bf16进入残差流;每个 weight matmul 前cast(w, BF16);norm/softmax/rope 对 bf16 激活自动 fp32 内算;最后 logitscast→fp32给 cross_entropy。fp32 模式compute_dtype==F32时跳过所有 cast,graph 与 T10/T11 完全一致。
6. xtrain-train / xtrain-distributed — --bf16 flag
TrainConfig/DdpConfig加compute_dtype: DType;train.rs/train_ddp.rs加--bf16flag。- AdamW / clip / checkpoint / DDP all-reduce 不改(master 永远 fp32,grad 永远 fp32)。
Key Design Decisions
- cast 算子承载 fp32 master ↔ bf16 compute 的桥:不需要在优化器里维护一份独立的 bf16 weight 副本——fp32 leaf 即 master,前向临时 cast 出 bf16,反向 grad 自动升回 fp32。最小改动、零优化器侵入。
- 按 dtype 分派而非新类型:fp32 路径走的还是同一个函数的
F32分支 → 原 kernel、原 cuBLAS 调用、原 launch 顺序,数值逐字节不变(满足硬闸门)。 - norm/softmax/CE 不写 bf16 reduction kernel:wrapper 里
to_dtype(F32)→ 复用现有 fp32 kernel →to_dtype(BF16)。多两个 cast launch,但复用已验证的 fp32 数值,且这些不是显存/算力大头。 - 无 loss scaling:bf16 8-bit 指数,省掉 fp16 那套 scale/unscale/inf-check。
验证方法(双闸门)
闸门 ① fp32 不回归(hard gate)
全套现有测试在原紧容差下保持绿(bf16 是 opt-in,默认 dtype=F32):
cargo test全 crate:grad-check(rel ≤2e-2)、structural、GEMM 对 cuBLAS(~1e-7)、batched==looped、overfit 27/27、AdamW GPU bit-exact + host 对 torch、checkpoint 逐位、DDP loss 对单卡 <1e-6、PyTorch 对拍(loss/logits/grad)。- xserv 闭环:v4 ckpt(fp32 训)重导 safetensors md5 一致 + xserv 贪心逐 token 对住。
闸门 ② bf16 正确性 + 收敛
- bf16 looser-tol 数值:同一组随机权重/输入,bf16 forward logits 与 bf16 grad 对 fp32 参考在 rel ~1e-2(bf16 2–3 位有效数字)内。
- 短训练收敛:dim768(或缩小代理)bf16 跑数百步,loss 曲线对住 fp32,无 NaN/发散,end loss 接近。
闸门 ③ 显存 + 吞吐(payoff)
- dim768 bf16 能跑 per-rank batch 32(v4 OOM 的触发点)。
- 测 dim768 bf16 vs fp32 的峰值显存 + steady-state tok/s(预期:显存↓、tok/s↑)。
实测结果(dash5, 1× RTX 5090 32GB, sm_120)
闸门 ① fp32 不回归:全套测试在原紧容差绿(autograd 15 / structural 5 / GEMM 5 / batched==looped / overfit 27/27 / AdamW GPU bit-exact + host 对 torch / checkpoint 逐位 / DDP 2)。xserv 闭环:v3 ckpt 用 T12 代码重导 model.safetensors 与 registry md5 逐位一致(b04fc9f9a0c9af04c47d9ca649aea12e)——export/fp32 数值零漂移。
闸门 ② bf16 正确性 + 收敛:
- looser-tol(tests/bf16.rs):同 fp32 master 跑 fp32 vs bf16——loss rel
1.2e-4、logits mean rel2.0e-3/ p996.8e-3、grad worst scaled-mean1.0e-2,无 NaN,grad 仍 fp32(master 未动)。 - 收敛:dim768 短训 150 步,bf16-b16 loss 轨迹对住 fp32-b16(step50
4.40vs4.40、step1493.984vs3.988),单调降、无发散。
闸门 ③ 显存 + 吞吐(dim768/18L/24h×32 ffn2048 seq256, steady-state):
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|---|---|---|---|---|
| fp32 | 16 (v4 fallback) | 27.2 GB | 31.5K | ✅ |
| bf16 | 16 | 19.3 GB(−29%) | 35.5K(+13%) | ✅ |
| fp32 | 32 | — | — | ❌ OOM |
| bf16 | 32(甜点区) | 31.1 GB | 40.8K | ✅ 解 OOM |
→ 同 batch:bf16 显存 −29% / tok/s +13%;bf16 解 fp32-batch32 OOM,batch32 达 40.8K tok/s(+29% vs fp32-b16)。KI-2 标 FIXED。