Files
xtrain/docs/11-bf16-mixed-precision.md
Gahow Wang 320c1ae4fb perf: KI-2 FIXED — dim768 bf16 fits batch 32, tok/s 31.5K→40.8K
bf16 mixed precision (fp32 master) solves the v4 dim768 fp32 batch-32
OOM and speeds up the now-compute-bound dim768 GEMMs (dash5 1× RTX
5090 32GB, dim768/18L/24h×32 ffn2048 seq256, steady-state):

  config          batch  peak mem   tok/s   fits 32GB
  fp32            16      27.2 GB    31.5K   yes
  bf16            16      19.3 GB    35.5K   yes   (-29% mem / +13% tok/s)
  fp32            32      —          —       OOM
  bf16            32      31.1 GB    40.8K   yes   (+29% vs fp32-b16)

Verified on dash5: fp32 suite green at tight tol + xserv export md5
bit-identical to registry; bf16 looser-tol (loss 1.2e-4, logits p99
6.8e-3, grad 1.0e-2) + 150-step convergence tracks fp32 (3.984 vs
3.988); 2-GPU bf16 DDP at per-rank batch 32 trains cleanly.

Mark KI-2 FIXED; fill docs/11 results.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:28:20 +08:00

99 lines
8.5 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T12: bf16 混合精度fp32 master— Design Document
> KI-2 的具体落地。v4dim768, fp32在单卡 32GB 下 per-rank batch 32global 256**OOM**,被迫降到 batch 16 训练。bf16 把激活显存减半(找回 batch-256 甜点区),并在 dim768 这个已 compute-bound 的规模上加速 tensor-core GEMM。附带收益xserv 推理是 **BF16-only**bf16 训练让闭环更贴。
## Goal
在**不动 fp32 路径任何数值**的前提下,新增一个 **opt-in 的 bf16 混合精度模式**(标准 AMPfp32 master weights
1. **正确性硬闸门**fp32 全套回归T3 GEMM / T4 12 算子 grad-check / T5 结构+overfit+PyTorch 对拍 / T6 AdamW+checkpoint / T8 DDP / T10 batched / xserv 闭环)在**同样紧容差**下保持绿。bf16 是**加法、可选**的,绝不扰动 fp32。
2. **bf16 正确性**bf16 前向/梯度在**更松的 bf16 容差**≈23 位十进制有效数字 → ~1e-2 相对误差)内对住 fp32 参考;一段**短 bf16 训练收敛对住 fp32**loss 曲线接近、无 NaN/发散)。
3. **显存+吞吐payoff**dim768 bf16 能跑 per-rank batch 32解 OOM测 dim768 bf16 vs fp32 的显存+tok/s。
## 什么是 bf16、什么是 fp32标准 AMP split
| 组件 | 精度 | 理由 |
|---|---|---|
| **master weights** + AdamW state(m/v) + 优化器更新 | **fp32** | 小步长更新需要 fp32 精度,否则被 bf16 的 8-bit 尾数吃掉 |
| **linear GEMM**q/k/v/o、gate/up/down、lm_head | **bf16 in/out + fp32 accum** | compute+memory 主体tensor-core 走 `cublasGemmEx``CUDA_R_16BF` in/out`CUBLAS_COMPUTE_32F` 累加) |
| **激活流**残差流、attention Q/K/V/probs/out、MLP 中间) | **bf16** | 激活显存减半——这是解 OOM 的关键,不只是 GEMM 提速 |
| **RMSNorm / QK-norm** | **fp32**bf16→fp32 算 reduction→bf16 | 求和/rsqrt 数值敏感 |
| **softmaxattention/ RoPE / cross-entropy** | **fp32** | softmax 的 exp/求和、CE 的 log、RoPE 的 sin/cos 都数值敏感 |
| **梯度 → AdamW** | **fp32** | 见下「cast 算子」——grad 在 fp32 master leaf 上累加AdamW/clip/DDP all-reduce 全程 fp32、**完全不改** |
**无 loss scaling**bf16 是 8-bit 指数(与 fp32 同动态范围),不像 fp165-bit 指数易下溢)。所以梯度不会下溢到 0**不需要** loss scaling。
## Module Layoutsurgicalfp32 路径逐字节不动)
核心思路:**所有 op 按 `self.dtype()` 分派**。fp32 分支跑原 kernel一字不改bf16 分支是新增代码。
### 1. `xtrain-tensor::dtype` — 加 `BF16`
- `DType::BF16``size_bytes()=2``half::bf16` 实现 `TensorDType``half` crate 已是依赖)。
### 2. `xtrain-cuda` — bf16 GEMM + cast kernel
- `ffi.rs`:声明 `cublasGemmEx` / `cublasGemmStridedBatchedEx`void* 指针、`a_type/b_type/c_type/compute_type`),常量 `CUDA_R_16BF=14``CUDA_R_32F=0``CUBLAS_COMPUTE_32F=68`(数值同 xserv `gemm.rs`)。
- `cublas.rs``gemm_ex(...)` / `gemm_ex_strided_batched(...)`——和 `sgemm` 同样的 row-major⟺col-major 转置代数,只是走 `GemmEx`、in/out=bf16、accum=fp32。
- `csrc/ops/cast.cu` + ffi`launch_cast_f32_to_bf16` / `launch_cast_bf16_to_f32`(逐元素 `__float2bfloat16` / `__bfloat162float`)。
### 3. `xtrain-tensor::tensor` — dtype-polymorphic ops
- `to_dtype(target)`f32↔bf16 castCUDA同 dtype 直接 clone。
- `matmul` / `matmul_backward` / `attention` / `attention_backward`:按 dtype 分派——fp32 走原 `sgemm`**不动**bf16 走 `gemm_ex`,输出同 dtype。
- 逐元素 op`add`/`mul`/`silu`/`scale`…)+ `embedding`:允许 bf16 输入。逐元素 kernel 对 bf16 走「load→fp32→算→store bf16」新增 bf16 kernel或对 norm/softmax/CE 在 wrapper 里 upcast→fp32 kernel→downcast。**fp32 调用走原 f32 kernel 不变。**
### 4. `xtrain-autodiff::ops` — `cast` 算子 + bf16 透传
- **`cast(x, target_dtype)`**:前向 `x.to_dtype(target)`**反向把 grad cast 回 `x` 的 dtype**。这是 AMP 的关键钩子:
- fp32 master weight leaf → `cast(w, BF16)` 喂给 matmulmatmul 的 bf16 grad 经 cast 反向**升回 fp32**,累加在 fp32 leaf 上。
-`.grad()`**fp32**AdamW / `clip_grad_norm_gpu` / DDP `all_reduce_average_grads` **一行不改**,全程 fp32 master。
- 其它 op 的 backward 自然按张量 dtype 流转softmax/rms_norm wrapper 内部 upcast→fp32→downcast对外是 bf16
### 5. `xtrain-model::TinyTransformer` — `compute_dtype` 开关
- `new_amp(cfg, device, dtype, init)``forward_batched``compute_dtype: DType`(默认 `F32` = 原路径,逐字节同)。
- bf16 模式embedding 输出 `cast→bf16` 进入残差流;每个 weight matmul 前 `cast(w, BF16)`norm/softmax/rope 对 bf16 激活自动 fp32 内算;最后 logits `cast→fp32` 给 cross_entropy。**fp32 模式 `compute_dtype==F32` 时跳过所有 castgraph 与 T10/T11 完全一致。**
### 6. `xtrain-train` / `xtrain-distributed` — `--bf16` flag
- `TrainConfig`/`DdpConfig``compute_dtype: DType``train.rs`/`train_ddp.rs``--bf16` flag。
- AdamW / clip / checkpoint / DDP all-reduce **不改**master 永远 fp32grad 永远 fp32
## Key Design Decisions
- **cast 算子承载 fp32 master ↔ bf16 compute 的桥**:不需要在优化器里维护一份独立的 bf16 weight 副本——fp32 leaf 即 master前向临时 cast 出 bf16反向 grad 自动升回 fp32。最小改动、零优化器侵入。
- **按 dtype 分派而非新类型**fp32 路径走的还是同一个函数的 `F32` 分支 → 原 kernel、原 cuBLAS 调用、原 launch 顺序,数值逐字节不变(满足硬闸门)。
- **norm/softmax/CE 不写 bf16 reduction kernel**wrapper 里 `to_dtype(F32)` → 复用现有 fp32 kernel → `to_dtype(BF16)`。多两个 cast launch但**复用已验证的 fp32 数值**,且这些不是显存/算力大头。
- **无 loss scaling**bf16 8-bit 指数,省掉 fp16 那套 scale/unscale/inf-check。
## 验证方法(双闸门)
### 闸门 ① fp32 不回归hard gate
全套现有测试在原紧容差下保持绿bf16 是 opt-in默认 dtype=F32
- `cargo test` 全 crategrad-checkrel ≤2e-2、structural、GEMM 对 cuBLAS~1e-7、batched==looped、overfit 27/27、AdamW GPU bit-exact + host 对 torch、checkpoint 逐位、DDP loss 对单卡 <1e-6、**PyTorch 对拍**loss/logits/grad)。
- **xserv 闭环**v4 ckptfp32 重导 safetensors md5 一致 + xserv 贪心逐 token 对住
### 闸门 ② bf16 正确性 + 收敛
- **bf16 looser-tol 数值**同一组随机权重/输入bf16 forward logits bf16 grad fp32 参考在 **rel ~1e-2**bf16 23 位有效数字
- **短训练收敛**dim768或缩小代理bf16 跑数百步loss 曲线对住 fp32 NaN/发散end loss 接近
### 闸门 ③ 显存 + 吞吐payoff
- **dim768 bf16 能跑 per-rank batch 32**v4 OOM 的触发点)。
- dim768 **bf16 vs fp32** 的峰值显存 + steady-state tok/s预期显存↓、tok/s↑)。
## 实测结果dash5, 1× RTX 5090 32GB, sm_120
**闸门 ① fp32 不回归**全套测试在原紧容差绿autograd 15 / structural 5 / GEMM 5 / batched==looped / overfit 27/27 / AdamW GPU bit-exact + host torch / checkpoint 逐位 / DDP 2)。**xserv 闭环**v3 ckpt T12 代码重导 `model.safetensors` registry **md5 逐位一致**`b04fc9f9a0c9af04c47d9ca649aea12e`)——export/fp32 数值零漂移
**闸门 ② bf16 正确性 + 收敛**
- **looser-toltests/bf16.rs** fp32 master fp32 vs bf16——loss rel `1.2e-4`logits mean rel `2.0e-3` / p99 `6.8e-3`grad worst scaled-mean `1.0e-2` NaNgrad fp32master 未动)。
- **收敛**dim768 短训 150 bf16-b16 loss 轨迹对住 fp32-b16step50 `4.40` vs `4.40`step149 `3.984` vs `3.988`单调降无发散
**闸门 ③ 显存 + 吞吐dim768/18L/24h×32 ffn2048 seq256, steady-state**
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|---|---|---|---|---|
| fp32 | 16 (v4 fallback) | 27.2 GB | 31.5K | |
| **bf16** | 16 | **19.3 GB29%** | **35.5K+13%** | |
| fp32 | 32 | | | **OOM** |
| **bf16** | **32甜点区** | **31.1 GB** | **40.8K** | **解 OOM** |
batchbf16 显存 29% / tok/s +13%**bf16 fp32-batch32 OOM**batch32 40.8K tok/s+29% vs fp32-b16)。KI-2 **FIXED**
</content>
</invoke>