docs: Phase T12 — bf16 mixed precision design

docs/11-bf16-mixed-precision.md: the AMP split (bf16 linears +
activations, fp32 master / norms / softmax / RoPE / CE, no loss
scaling), the cast-op bridge, module layout, and the dual
verification gate (fp32 unchanged + bf16 looser-tol + convergence +
mem/throughput). Memory/throughput before->after to be filled from
the dash5 bench.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 14:15:02 +08:00
parent 0a2a4dcaa8
commit 30db62d8f2

View File

@@ -0,0 +1,83 @@
# Phase T12: bf16 混合精度fp32 master— Design Document
> KI-2 的具体落地。v4dim768, fp32在单卡 32GB 下 per-rank batch 32global 256**OOM**,被迫降到 batch 16 训练。bf16 把激活显存减半(找回 batch-256 甜点区),并在 dim768 这个已 compute-bound 的规模上加速 tensor-core GEMM。附带收益xserv 推理是 **BF16-only**bf16 训练让闭环更贴。
## Goal
在**不动 fp32 路径任何数值**的前提下,新增一个 **opt-in 的 bf16 混合精度模式**(标准 AMPfp32 master weights
1. **正确性硬闸门**fp32 全套回归T3 GEMM / T4 12 算子 grad-check / T5 结构+overfit+PyTorch 对拍 / T6 AdamW+checkpoint / T8 DDP / T10 batched / xserv 闭环)在**同样紧容差**下保持绿。bf16 是**加法、可选**的,绝不扰动 fp32。
2. **bf16 正确性**bf16 前向/梯度在**更松的 bf16 容差**≈23 位十进制有效数字 → ~1e-2 相对误差)内对住 fp32 参考;一段**短 bf16 训练收敛对住 fp32**loss 曲线接近、无 NaN/发散)。
3. **显存+吞吐payoff**dim768 bf16 能跑 per-rank batch 32解 OOM测 dim768 bf16 vs fp32 的显存+tok/s。
## 什么是 bf16、什么是 fp32标准 AMP split
| 组件 | 精度 | 理由 |
|---|---|---|
| **master weights** + AdamW state(m/v) + 优化器更新 | **fp32** | 小步长更新需要 fp32 精度,否则被 bf16 的 8-bit 尾数吃掉 |
| **linear GEMM**q/k/v/o、gate/up/down、lm_head | **bf16 in/out + fp32 accum** | compute+memory 主体tensor-core 走 `cublasGemmEx``CUDA_R_16BF` in/out`CUBLAS_COMPUTE_32F` 累加) |
| **激活流**残差流、attention Q/K/V/probs/out、MLP 中间) | **bf16** | 激活显存减半——这是解 OOM 的关键,不只是 GEMM 提速 |
| **RMSNorm / QK-norm** | **fp32**bf16→fp32 算 reduction→bf16 | 求和/rsqrt 数值敏感 |
| **softmaxattention/ RoPE / cross-entropy** | **fp32** | softmax 的 exp/求和、CE 的 log、RoPE 的 sin/cos 都数值敏感 |
| **梯度 → AdamW** | **fp32** | 见下「cast 算子」——grad 在 fp32 master leaf 上累加AdamW/clip/DDP all-reduce 全程 fp32、**完全不改** |
**无 loss scaling**bf16 是 8-bit 指数(与 fp32 同动态范围),不像 fp165-bit 指数易下溢)。所以梯度不会下溢到 0**不需要** loss scaling。
## Module Layoutsurgicalfp32 路径逐字节不动)
核心思路:**所有 op 按 `self.dtype()` 分派**。fp32 分支跑原 kernel一字不改bf16 分支是新增代码。
### 1. `xtrain-tensor::dtype` — 加 `BF16`
- `DType::BF16``size_bytes()=2``half::bf16` 实现 `TensorDType``half` crate 已是依赖)。
### 2. `xtrain-cuda` — bf16 GEMM + cast kernel
- `ffi.rs`:声明 `cublasGemmEx` / `cublasGemmStridedBatchedEx`void* 指针、`a_type/b_type/c_type/compute_type`),常量 `CUDA_R_16BF=14``CUDA_R_32F=0``CUBLAS_COMPUTE_32F=68`(数值同 xserv `gemm.rs`)。
- `cublas.rs``gemm_ex(...)` / `gemm_ex_strided_batched(...)`——和 `sgemm` 同样的 row-major⟺col-major 转置代数,只是走 `GemmEx`、in/out=bf16、accum=fp32。
- `csrc/ops/cast.cu` + ffi`launch_cast_f32_to_bf16` / `launch_cast_bf16_to_f32`(逐元素 `__float2bfloat16` / `__bfloat162float`)。
### 3. `xtrain-tensor::tensor` — dtype-polymorphic ops
- `to_dtype(target)`f32↔bf16 castCUDA同 dtype 直接 clone。
- `matmul` / `matmul_backward` / `attention` / `attention_backward`:按 dtype 分派——fp32 走原 `sgemm`**不动**bf16 走 `gemm_ex`,输出同 dtype。
- 逐元素 op`add`/`mul`/`silu`/`scale`…)+ `embedding`:允许 bf16 输入。逐元素 kernel 对 bf16 走「load→fp32→算→store bf16」新增 bf16 kernel或对 norm/softmax/CE 在 wrapper 里 upcast→fp32 kernel→downcast。**fp32 调用走原 f32 kernel 不变。**
### 4. `xtrain-autodiff::ops` — `cast` 算子 + bf16 透传
- **`cast(x, target_dtype)`**:前向 `x.to_dtype(target)`**反向把 grad cast 回 `x` 的 dtype**。这是 AMP 的关键钩子:
- fp32 master weight leaf → `cast(w, BF16)` 喂给 matmulmatmul 的 bf16 grad 经 cast 反向**升回 fp32**,累加在 fp32 leaf 上。
-`.grad()`**fp32**AdamW / `clip_grad_norm_gpu` / DDP `all_reduce_average_grads` **一行不改**,全程 fp32 master。
- 其它 op 的 backward 自然按张量 dtype 流转softmax/rms_norm wrapper 内部 upcast→fp32→downcast对外是 bf16
### 5. `xtrain-model::TinyTransformer` — `compute_dtype` 开关
- `new_amp(cfg, device, dtype, init)``forward_batched``compute_dtype: DType`(默认 `F32` = 原路径,逐字节同)。
- bf16 模式embedding 输出 `cast→bf16` 进入残差流;每个 weight matmul 前 `cast(w, BF16)`norm/softmax/rope 对 bf16 激活自动 fp32 内算;最后 logits `cast→fp32` 给 cross_entropy。**fp32 模式 `compute_dtype==F32` 时跳过所有 castgraph 与 T10/T11 完全一致。**
### 6. `xtrain-train` / `xtrain-distributed` — `--bf16` flag
- `TrainConfig`/`DdpConfig``compute_dtype: DType``train.rs`/`train_ddp.rs``--bf16` flag。
- AdamW / clip / checkpoint / DDP all-reduce **不改**master 永远 fp32grad 永远 fp32
## Key Design Decisions
- **cast 算子承载 fp32 master ↔ bf16 compute 的桥**:不需要在优化器里维护一份独立的 bf16 weight 副本——fp32 leaf 即 master前向临时 cast 出 bf16反向 grad 自动升回 fp32。最小改动、零优化器侵入。
- **按 dtype 分派而非新类型**fp32 路径走的还是同一个函数的 `F32` 分支 → 原 kernel、原 cuBLAS 调用、原 launch 顺序,数值逐字节不变(满足硬闸门)。
- **norm/softmax/CE 不写 bf16 reduction kernel**wrapper 里 `to_dtype(F32)` → 复用现有 fp32 kernel → `to_dtype(BF16)`。多两个 cast launch但**复用已验证的 fp32 数值**,且这些不是显存/算力大头。
- **无 loss scaling**bf16 8-bit 指数,省掉 fp16 那套 scale/unscale/inf-check。
## 验证方法(双闸门)
### 闸门 ① fp32 不回归hard gate
全套现有测试在原紧容差下保持绿bf16 是 opt-in默认 dtype=F32
- `cargo test` 全 crategrad-checkrel ≤2e-2、structural、GEMM 对 cuBLAS~1e-7、batched==looped、overfit 27/27、AdamW GPU bit-exact + host 对 torch、checkpoint 逐位、DDP loss 对单卡 <1e-6、**PyTorch 对拍**loss/logits/grad)。
- **xserv 闭环**v4 ckptfp32 重导 safetensors md5 一致 + xserv 贪心逐 token 对住
### 闸门 ② bf16 正确性 + 收敛
- **bf16 looser-tol 数值**同一组随机权重/输入bf16 forward logits bf16 grad fp32 参考在 **rel ~1e-2**bf16 23 位有效数字
- **短训练收敛**dim768或缩小代理bf16 跑数百步loss 曲线对住 fp32 NaN/发散end loss 接近
### 闸门 ③ 显存 + 吞吐payoff
- **dim768 bf16 能跑 per-rank batch 32**v4 OOM 的触发点)。
- dim768 **bf16 vs fp32** 的峰值显存 + steady-state tok/s预期显存↓、tok/s↑)。
## 实测结果dash5, 8× RTX 5090, sm_120
> 见 `docs/known-issues.md` KI-2 的 before→after 表fp32 batch32 OOM → bf16 batch32 fit显存 A→Btok/s A→B
</content>
</invoke>