docs/11-bf16-mixed-precision.md: the AMP split (bf16 linears + activations, fp32 master / norms / softmax / RoPE / CE, no loss scaling), the cast-op bridge, module layout, and the dual verification gate (fp32 unchanged + bf16 looser-tol + convergence + mem/throughput). Memory/throughput before->after to be filled from the dash5 bench. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
7.3 KiB
7.3 KiB
Phase T12: bf16 混合精度(fp32 master)— Design Document
KI-2 的具体落地。v4(dim768, fp32)在单卡 32GB 下 per-rank batch 32(global 256)OOM,被迫降到 batch 16 训练。bf16 把激活显存减半(找回 batch-256 甜点区),并在 dim768 这个已 compute-bound 的规模上加速 tensor-core GEMM。附带收益:xserv 推理是 BF16-only,bf16 训练让闭环更贴。
Goal
在不动 fp32 路径任何数值的前提下,新增一个 opt-in 的 bf16 混合精度模式(标准 AMP,fp32 master weights):
- 正确性硬闸门:fp32 全套回归(T3 GEMM / T4 12 算子 grad-check / T5 结构+overfit+PyTorch 对拍 / T6 AdamW+checkpoint / T8 DDP / T10 batched / xserv 闭环)在同样紧容差下保持绿。bf16 是加法、可选的,绝不扰动 fp32。
- bf16 正确性:bf16 前向/梯度在更松的 bf16 容差(≈2–3 位十进制有效数字 → ~1e-2 相对误差)内对住 fp32 参考;一段短 bf16 训练收敛对住 fp32(loss 曲线接近、无 NaN/发散)。
- 显存+吞吐(payoff):dim768 bf16 能跑 per-rank batch 32(解 OOM);测 dim768 bf16 vs fp32 的显存+tok/s。
什么是 bf16、什么是 fp32(标准 AMP split)
| 组件 | 精度 | 理由 |
|---|---|---|
| master weights + AdamW state(m/v) + 优化器更新 | fp32 | 小步长更新需要 fp32 精度,否则被 bf16 的 8-bit 尾数吃掉 |
| linear GEMM(q/k/v/o、gate/up/down、lm_head) | bf16 in/out + fp32 accum | compute+memory 主体;tensor-core 走 cublasGemmEx(CUDA_R_16BF in/out,CUBLAS_COMPUTE_32F 累加) |
| 激活流(残差流、attention Q/K/V/probs/out、MLP 中间) | bf16 | 激活显存减半——这是解 OOM 的关键,不只是 GEMM 提速 |
| RMSNorm / QK-norm | fp32(bf16→fp32 算 reduction→bf16) | 求和/rsqrt 数值敏感 |
| softmax(attention)/ RoPE / cross-entropy | fp32 | softmax 的 exp/求和、CE 的 log、RoPE 的 sin/cos 都数值敏感 |
| 梯度 → AdamW | fp32 | 见下「cast 算子」——grad 在 fp32 master leaf 上累加,AdamW/clip/DDP all-reduce 全程 fp32、完全不改 |
无 loss scaling:bf16 是 8-bit 指数(与 fp32 同动态范围),不像 fp16(5-bit 指数易下溢)。所以梯度不会下溢到 0,不需要 loss scaling。
Module Layout(surgical:fp32 路径逐字节不动)
核心思路:所有 op 按 self.dtype() 分派。fp32 分支跑原 kernel(一字不改);bf16 分支是新增代码。
1. xtrain-tensor::dtype — 加 BF16
DType::BF16,size_bytes()=2,half::bf16实现TensorDType(halfcrate 已是依赖)。
2. xtrain-cuda — bf16 GEMM + cast kernel
ffi.rs:声明cublasGemmEx/cublasGemmStridedBatchedEx(void* 指针、a_type/b_type/c_type/compute_type),常量CUDA_R_16BF=14、CUDA_R_32F=0、CUBLAS_COMPUTE_32F=68(数值同 xservgemm.rs)。cublas.rs:gemm_ex(...)/gemm_ex_strided_batched(...)——和sgemm同样的 row-major⟺col-major 转置代数,只是走GemmEx、in/out=bf16、accum=fp32。csrc/ops/cast.cu+ ffi:launch_cast_f32_to_bf16/launch_cast_bf16_to_f32(逐元素__float2bfloat16/__bfloat162float)。
3. xtrain-tensor::tensor — dtype-polymorphic ops
to_dtype(target):f32↔bf16 cast(CUDA),同 dtype 直接 clone。matmul/matmul_backward/attention/attention_backward:按 dtype 分派——fp32 走原sgemm(不动),bf16 走gemm_ex,输出同 dtype。- 逐元素 op(
add/mul/silu/scale…)+embedding:允许 bf16 输入。逐元素 kernel 对 bf16 走「load→fp32→算→store bf16」(新增 bf16 kernel)或对 norm/softmax/CE 在 wrapper 里 upcast→fp32 kernel→downcast。fp32 调用走原 f32 kernel 不变。
4. xtrain-autodiff::ops — cast 算子 + bf16 透传
cast(x, target_dtype):前向x.to_dtype(target);反向把 grad cast 回x的 dtype。这是 AMP 的关键钩子:- fp32 master weight leaf →
cast(w, BF16)喂给 matmul;matmul 的 bf16 grad 经 cast 反向升回 fp32,累加在 fp32 leaf 上。 - ⟹
.grad()是 fp32,AdamW /clip_grad_norm_gpu/ DDPall_reduce_average_grads一行不改,全程 fp32 master。
- fp32 master weight leaf →
- 其它 op 的 backward 自然按张量 dtype 流转(softmax/rms_norm wrapper 内部 upcast→fp32→downcast,对外是 bf16)。
5. xtrain-model::TinyTransformer — compute_dtype 开关
new_amp(cfg, device, dtype, init)或forward_batched接compute_dtype: DType(默认F32= 原路径,逐字节同)。- bf16 模式:embedding 输出
cast→bf16进入残差流;每个 weight matmul 前cast(w, BF16);norm/softmax/rope 对 bf16 激活自动 fp32 内算;最后 logitscast→fp32给 cross_entropy。fp32 模式compute_dtype==F32时跳过所有 cast,graph 与 T10/T11 完全一致。
6. xtrain-train / xtrain-distributed — --bf16 flag
TrainConfig/DdpConfig加compute_dtype: DType;train.rs/train_ddp.rs加--bf16flag。- AdamW / clip / checkpoint / DDP all-reduce 不改(master 永远 fp32,grad 永远 fp32)。
Key Design Decisions
- cast 算子承载 fp32 master ↔ bf16 compute 的桥:不需要在优化器里维护一份独立的 bf16 weight 副本——fp32 leaf 即 master,前向临时 cast 出 bf16,反向 grad 自动升回 fp32。最小改动、零优化器侵入。
- 按 dtype 分派而非新类型:fp32 路径走的还是同一个函数的
F32分支 → 原 kernel、原 cuBLAS 调用、原 launch 顺序,数值逐字节不变(满足硬闸门)。 - norm/softmax/CE 不写 bf16 reduction kernel:wrapper 里
to_dtype(F32)→ 复用现有 fp32 kernel →to_dtype(BF16)。多两个 cast launch,但复用已验证的 fp32 数值,且这些不是显存/算力大头。 - 无 loss scaling:bf16 8-bit 指数,省掉 fp16 那套 scale/unscale/inf-check。
验证方法(双闸门)
闸门 ① fp32 不回归(hard gate)
全套现有测试在原紧容差下保持绿(bf16 是 opt-in,默认 dtype=F32):
cargo test全 crate:grad-check(rel ≤2e-2)、structural、GEMM 对 cuBLAS(~1e-7)、batched==looped、overfit 27/27、AdamW GPU bit-exact + host 对 torch、checkpoint 逐位、DDP loss 对单卡 <1e-6、PyTorch 对拍(loss/logits/grad)。- xserv 闭环:v4 ckpt(fp32 训)重导 safetensors md5 一致 + xserv 贪心逐 token 对住。
闸门 ② bf16 正确性 + 收敛
- bf16 looser-tol 数值:同一组随机权重/输入,bf16 forward logits 与 bf16 grad 对 fp32 参考在 rel ~1e-2(bf16 2–3 位有效数字)内。
- 短训练收敛:dim768(或缩小代理)bf16 跑数百步,loss 曲线对住 fp32,无 NaN/发散,end loss 接近。
闸门 ③ 显存 + 吞吐(payoff)
- dim768 bf16 能跑 per-rank batch 32(v4 OOM 的触发点)。
- 测 dim768 bf16 vs fp32 的峰值显存 + steady-state tok/s(预期:显存↓、tok/s↑)。
实测结果(dash5, 8× RTX 5090, sm_120)
见
docs/known-issues.mdKI-2 的 before→after 表(fp32 batch32 OOM → bf16 batch32 fit;显存 A→B;tok/s A→B)。