Files
xtrain/docs/11-bf16-mixed-precision.md
Gahow Wang 320c1ae4fb perf: KI-2 FIXED — dim768 bf16 fits batch 32, tok/s 31.5K→40.8K
bf16 mixed precision (fp32 master) solves the v4 dim768 fp32 batch-32
OOM and speeds up the now-compute-bound dim768 GEMMs (dash5 1× RTX
5090 32GB, dim768/18L/24h×32 ffn2048 seq256, steady-state):

  config          batch  peak mem   tok/s   fits 32GB
  fp32            16      27.2 GB    31.5K   yes
  bf16            16      19.3 GB    35.5K   yes   (-29% mem / +13% tok/s)
  fp32            32      —          —       OOM
  bf16            32      31.1 GB    40.8K   yes   (+29% vs fp32-b16)

Verified on dash5: fp32 suite green at tight tol + xserv export md5
bit-identical to registry; bf16 looser-tol (loss 1.2e-4, logits p99
6.8e-3, grad 1.0e-2) + 150-step convergence tracks fp32 (3.984 vs
3.988); 2-GPU bf16 DDP at per-rank batch 32 trains cleanly.

Mark KI-2 FIXED; fill docs/11 results.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:28:20 +08:00

8.5 KiB
Raw Permalink Blame History

Phase T12: bf16 混合精度fp32 master— Design Document

KI-2 的具体落地。v4dim768, fp32在单卡 32GB 下 per-rank batch 32global 256OOM,被迫降到 batch 16 训练。bf16 把激活显存减半(找回 batch-256 甜点区),并在 dim768 这个已 compute-bound 的规模上加速 tensor-core GEMM。附带收益xserv 推理是 BF16-onlybf16 训练让闭环更贴。

Goal

不动 fp32 路径任何数值的前提下,新增一个 opt-in 的 bf16 混合精度模式(标准 AMPfp32 master weights

  1. 正确性硬闸门fp32 全套回归T3 GEMM / T4 12 算子 grad-check / T5 结构+overfit+PyTorch 对拍 / T6 AdamW+checkpoint / T8 DDP / T10 batched / xserv 闭环)在同样紧容差下保持绿。bf16 是加法、可选的,绝不扰动 fp32。
  2. bf16 正确性bf16 前向/梯度在更松的 bf16 容差≈23 位十进制有效数字 → ~1e-2 相对误差)内对住 fp32 参考;一段短 bf16 训练收敛对住 fp32loss 曲线接近、无 NaN/发散)。
  3. 显存+吞吐payoffdim768 bf16 能跑 per-rank batch 32解 OOM测 dim768 bf16 vs fp32 的显存+tok/s。

什么是 bf16、什么是 fp32标准 AMP split

组件 精度 理由
master weights + AdamW state(m/v) + 优化器更新 fp32 小步长更新需要 fp32 精度,否则被 bf16 的 8-bit 尾数吃掉
linear GEMMq/k/v/o、gate/up/down、lm_head bf16 in/out + fp32 accum compute+memory 主体tensor-core 走 cublasGemmExCUDA_R_16BF in/outCUBLAS_COMPUTE_32F 累加)
激活流残差流、attention Q/K/V/probs/out、MLP 中间) bf16 激活显存减半——这是解 OOM 的关键,不只是 GEMM 提速
RMSNorm / QK-norm fp32bf16→fp32 算 reduction→bf16 求和/rsqrt 数值敏感
softmaxattention/ RoPE / cross-entropy fp32 softmax 的 exp/求和、CE 的 log、RoPE 的 sin/cos 都数值敏感
梯度 → AdamW fp32 见下「cast 算子」——grad 在 fp32 master leaf 上累加AdamW/clip/DDP all-reduce 全程 fp32、完全不改

无 loss scalingbf16 是 8-bit 指数(与 fp32 同动态范围),不像 fp165-bit 指数易下溢)。所以梯度不会下溢到 0不需要 loss scaling。

Module Layoutsurgicalfp32 路径逐字节不动)

核心思路:所有 op 按 self.dtype() 分派。fp32 分支跑原 kernel一字不改bf16 分支是新增代码。

1. xtrain-tensor::dtype — 加 BF16

  • DType::BF16size_bytes()=2half::bf16 实现 TensorDTypehalf crate 已是依赖)。

2. xtrain-cuda — bf16 GEMM + cast kernel

  • ffi.rs:声明 cublasGemmEx / cublasGemmStridedBatchedExvoid* 指针、a_type/b_type/c_type/compute_type),常量 CUDA_R_16BF=14CUDA_R_32F=0CUBLAS_COMPUTE_32F=68(数值同 xserv gemm.rs)。
  • cublas.rsgemm_ex(...) / gemm_ex_strided_batched(...)——和 sgemm 同样的 row-major⟺col-major 转置代数,只是走 GemmEx、in/out=bf16、accum=fp32。
  • csrc/ops/cast.cu + ffilaunch_cast_f32_to_bf16 / launch_cast_bf16_to_f32(逐元素 __float2bfloat16 / __bfloat162float)。

3. xtrain-tensor::tensor — dtype-polymorphic ops

  • to_dtype(target)f32↔bf16 castCUDA同 dtype 直接 clone。
  • matmul / matmul_backward / attention / attention_backward:按 dtype 分派——fp32 走原 sgemm不动bf16 走 gemm_ex,输出同 dtype。
  • 逐元素 opadd/mul/silu/scale…)+ embedding:允许 bf16 输入。逐元素 kernel 对 bf16 走「load→fp32→算→store bf16」新增 bf16 kernel或对 norm/softmax/CE 在 wrapper 里 upcast→fp32 kernel→downcast。fp32 调用走原 f32 kernel 不变。

4. xtrain-autodiff::opscast 算子 + bf16 透传

  • cast(x, target_dtype):前向 x.to_dtype(target)反向把 grad cast 回 x 的 dtype。这是 AMP 的关键钩子:
    • fp32 master weight leaf → cast(w, BF16) 喂给 matmulmatmul 的 bf16 grad 经 cast 反向升回 fp32,累加在 fp32 leaf 上。
    • .grad()fp32AdamW / clip_grad_norm_gpu / DDP all_reduce_average_grads 一行不改,全程 fp32 master。
  • 其它 op 的 backward 自然按张量 dtype 流转softmax/rms_norm wrapper 内部 upcast→fp32→downcast对外是 bf16

5. xtrain-model::TinyTransformercompute_dtype 开关

  • new_amp(cfg, device, dtype, init)forward_batchedcompute_dtype: DType(默认 F32 = 原路径,逐字节同)。
  • bf16 模式embedding 输出 cast→bf16 进入残差流;每个 weight matmul 前 cast(w, BF16)norm/softmax/rope 对 bf16 激活自动 fp32 内算;最后 logits cast→fp32 给 cross_entropy。fp32 模式 compute_dtype==F32 时跳过所有 castgraph 与 T10/T11 完全一致。

6. xtrain-train / xtrain-distributed--bf16 flag

  • TrainConfig/DdpConfigcompute_dtype: DTypetrain.rs/train_ddp.rs--bf16 flag。
  • AdamW / clip / checkpoint / DDP all-reduce 不改master 永远 fp32grad 永远 fp32

Key Design Decisions

  • cast 算子承载 fp32 master ↔ bf16 compute 的桥:不需要在优化器里维护一份独立的 bf16 weight 副本——fp32 leaf 即 master前向临时 cast 出 bf16反向 grad 自动升回 fp32。最小改动、零优化器侵入。
  • 按 dtype 分派而非新类型fp32 路径走的还是同一个函数的 F32 分支 → 原 kernel、原 cuBLAS 调用、原 launch 顺序,数值逐字节不变(满足硬闸门)。
  • norm/softmax/CE 不写 bf16 reduction kernelwrapper 里 to_dtype(F32) → 复用现有 fp32 kernel → to_dtype(BF16)。多两个 cast launch复用已验证的 fp32 数值,且这些不是显存/算力大头。
  • 无 loss scalingbf16 8-bit 指数,省掉 fp16 那套 scale/unscale/inf-check。

验证方法(双闸门)

闸门 ① fp32 不回归hard gate

全套现有测试在原紧容差下保持绿bf16 是 opt-in默认 dtype=F32

  • cargo test 全 crategrad-checkrel ≤2e-2、structural、GEMM 对 cuBLAS~1e-7、batched==looped、overfit 27/27、AdamW GPU bit-exact + host 对 torch、checkpoint 逐位、DDP loss 对单卡 <1e-6、PyTorch 对拍loss/logits/grad
  • xserv 闭环v4 ckptfp32 训)重导 safetensors md5 一致 + xserv 贪心逐 token 对住。

闸门 ② bf16 正确性 + 收敛

  • bf16 looser-tol 数值:同一组随机权重/输入bf16 forward logits 与 bf16 grad 对 fp32 参考在 rel ~1e-2bf16 23 位有效数字)内。
  • 短训练收敛dim768或缩小代理bf16 跑数百步loss 曲线对住 fp32无 NaN/发散end loss 接近。

闸门 ③ 显存 + 吞吐payoff

  • dim768 bf16 能跑 per-rank batch 32v4 OOM 的触发点)。
  • 测 dim768 bf16 vs fp32 的峰值显存 + steady-state tok/s预期显存↓、tok/s↑

实测结果dash5, 1× RTX 5090 32GB, sm_120

闸门 ① fp32 不回归全套测试在原紧容差绿autograd 15 / structural 5 / GEMM 5 / batched==looped / overfit 27/27 / AdamW GPU bit-exact + host 对 torch / checkpoint 逐位 / DDP 2xserv 闭环v3 ckpt 用 T12 代码重导 model.safetensors 与 registry md5 逐位一致b04fc9f9a0c9af04c47d9ca649aea12e——export/fp32 数值零漂移。

闸门 ② bf16 正确性 + 收敛

  • looser-toltests/bf16.rs:同 fp32 master 跑 fp32 vs bf16——loss rel 1.2e-4、logits mean rel 2.0e-3 / p99 6.8e-3、grad worst scaled-mean 1.0e-2,无 NaNgrad 仍 fp32master 未动)。
  • 收敛dim768 短训 150 步bf16-b16 loss 轨迹对住 fp32-b16step50 4.40 vs 4.40、step149 3.984 vs 3.988),单调降、无发散。

闸门 ③ 显存 + 吞吐dim768/18L/24h×32 ffn2048 seq256, steady-state

config per-rank batch 峰值显存 tok/s fits 32GB?
fp32 16 (v4 fallback) 27.2 GB 31.5K
bf16 16 19.3 GB29% 35.5K+13%
fp32 32 OOM
bf16 32甜点区 31.1 GB 40.8K 解 OOM

→ 同 batchbf16 显存 29% / tok/s +13%bf16 解 fp32-batch32 OOMbatch32 达 40.8K tok/s+29% vs fp32-b16。KI-2 标 FIXED