bf16 mixed precision (fp32 master) solves the v4 dim768 fp32 batch-32 OOM and speeds up the now-compute-bound dim768 GEMMs (dash5 1× RTX 5090 32GB, dim768/18L/24h×32 ffn2048 seq256, steady-state): config batch peak mem tok/s fits 32GB fp32 16 27.2 GB 31.5K yes bf16 16 19.3 GB 35.5K yes (-29% mem / +13% tok/s) fp32 32 — — OOM bf16 32 31.1 GB 40.8K yes (+29% vs fp32-b16) Verified on dash5: fp32 suite green at tight tol + xserv export md5 bit-identical to registry; bf16 looser-tol (loss 1.2e-4, logits p99 6.8e-3, grad 1.0e-2) + 150-step convergence tracks fp32 (3.984 vs 3.988); 2-GPU bf16 DDP at per-rank batch 32 trains cleanly. Mark KI-2 FIXED; fill docs/11 results. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
99 lines
8.5 KiB
Markdown
99 lines
8.5 KiB
Markdown
# Phase T12: bf16 混合精度(fp32 master)— Design Document
|
||
|
||
> KI-2 的具体落地。v4(dim768, fp32)在单卡 32GB 下 per-rank batch 32(global 256)**OOM**,被迫降到 batch 16 训练。bf16 把激活显存减半(找回 batch-256 甜点区),并在 dim768 这个已 compute-bound 的规模上加速 tensor-core GEMM。附带收益:xserv 推理是 **BF16-only**,bf16 训练让闭环更贴。
|
||
|
||
## Goal
|
||
|
||
在**不动 fp32 路径任何数值**的前提下,新增一个 **opt-in 的 bf16 混合精度模式**(标准 AMP,fp32 master weights):
|
||
|
||
1. **正确性硬闸门**:fp32 全套回归(T3 GEMM / T4 12 算子 grad-check / T5 结构+overfit+PyTorch 对拍 / T6 AdamW+checkpoint / T8 DDP / T10 batched / xserv 闭环)在**同样紧容差**下保持绿。bf16 是**加法、可选**的,绝不扰动 fp32。
|
||
2. **bf16 正确性**:bf16 前向/梯度在**更松的 bf16 容差**(≈2–3 位十进制有效数字 → ~1e-2 相对误差)内对住 fp32 参考;一段**短 bf16 训练收敛对住 fp32**(loss 曲线接近、无 NaN/发散)。
|
||
3. **显存+吞吐(payoff)**:dim768 bf16 能跑 per-rank batch 32(解 OOM);测 dim768 bf16 vs fp32 的显存+tok/s。
|
||
|
||
## 什么是 bf16、什么是 fp32(标准 AMP split)
|
||
|
||
| 组件 | 精度 | 理由 |
|
||
|---|---|---|
|
||
| **master weights** + AdamW state(m/v) + 优化器更新 | **fp32** | 小步长更新需要 fp32 精度,否则被 bf16 的 8-bit 尾数吃掉 |
|
||
| **linear GEMM**(q/k/v/o、gate/up/down、lm_head) | **bf16 in/out + fp32 accum** | compute+memory 主体;tensor-core 走 `cublasGemmEx`(`CUDA_R_16BF` in/out,`CUBLAS_COMPUTE_32F` 累加) |
|
||
| **激活流**(残差流、attention Q/K/V/probs/out、MLP 中间) | **bf16** | 激活显存减半——这是解 OOM 的关键,不只是 GEMM 提速 |
|
||
| **RMSNorm / QK-norm** | **fp32**(bf16→fp32 算 reduction→bf16) | 求和/rsqrt 数值敏感 |
|
||
| **softmax(attention)/ RoPE / cross-entropy** | **fp32** | softmax 的 exp/求和、CE 的 log、RoPE 的 sin/cos 都数值敏感 |
|
||
| **梯度 → AdamW** | **fp32** | 见下「cast 算子」——grad 在 fp32 master leaf 上累加,AdamW/clip/DDP all-reduce 全程 fp32、**完全不改** |
|
||
|
||
**无 loss scaling**:bf16 是 8-bit 指数(与 fp32 同动态范围),不像 fp16(5-bit 指数易下溢)。所以梯度不会下溢到 0,**不需要** loss scaling。
|
||
|
||
## Module Layout(surgical:fp32 路径逐字节不动)
|
||
|
||
核心思路:**所有 op 按 `self.dtype()` 分派**。fp32 分支跑原 kernel(一字不改);bf16 分支是新增代码。
|
||
|
||
### 1. `xtrain-tensor::dtype` — 加 `BF16`
|
||
- `DType::BF16`,`size_bytes()=2`,`half::bf16` 实现 `TensorDType`(`half` crate 已是依赖)。
|
||
|
||
### 2. `xtrain-cuda` — bf16 GEMM + cast kernel
|
||
- `ffi.rs`:声明 `cublasGemmEx` / `cublasGemmStridedBatchedEx`(void* 指针、`a_type/b_type/c_type/compute_type`),常量 `CUDA_R_16BF=14`、`CUDA_R_32F=0`、`CUBLAS_COMPUTE_32F=68`(数值同 xserv `gemm.rs`)。
|
||
- `cublas.rs`:`gemm_ex(...)` / `gemm_ex_strided_batched(...)`——和 `sgemm` 同样的 row-major⟺col-major 转置代数,只是走 `GemmEx`、in/out=bf16、accum=fp32。
|
||
- `csrc/ops/cast.cu` + ffi:`launch_cast_f32_to_bf16` / `launch_cast_bf16_to_f32`(逐元素 `__float2bfloat16` / `__bfloat162float`)。
|
||
|
||
### 3. `xtrain-tensor::tensor` — dtype-polymorphic ops
|
||
- `to_dtype(target)`:f32↔bf16 cast(CUDA),同 dtype 直接 clone。
|
||
- `matmul` / `matmul_backward` / `attention` / `attention_backward`:按 dtype 分派——fp32 走原 `sgemm`(**不动**),bf16 走 `gemm_ex`,输出同 dtype。
|
||
- 逐元素 op(`add`/`mul`/`silu`/`scale`…)+ `embedding`:允许 bf16 输入。逐元素 kernel 对 bf16 走「load→fp32→算→store bf16」(新增 bf16 kernel)或对 norm/softmax/CE 在 wrapper 里 upcast→fp32 kernel→downcast。**fp32 调用走原 f32 kernel 不变。**
|
||
|
||
### 4. `xtrain-autodiff::ops` — `cast` 算子 + bf16 透传
|
||
- **`cast(x, target_dtype)`**:前向 `x.to_dtype(target)`;**反向把 grad cast 回 `x` 的 dtype**。这是 AMP 的关键钩子:
|
||
- fp32 master weight leaf → `cast(w, BF16)` 喂给 matmul;matmul 的 bf16 grad 经 cast 反向**升回 fp32**,累加在 fp32 leaf 上。
|
||
- ⟹ `.grad()` 是 **fp32**,AdamW / `clip_grad_norm_gpu` / DDP `all_reduce_average_grads` **一行不改**,全程 fp32 master。
|
||
- 其它 op 的 backward 自然按张量 dtype 流转(softmax/rms_norm wrapper 内部 upcast→fp32→downcast,对外是 bf16)。
|
||
|
||
### 5. `xtrain-model::TinyTransformer` — `compute_dtype` 开关
|
||
- `new_amp(cfg, device, dtype, init)` 或 `forward_batched` 接 `compute_dtype: DType`(默认 `F32` = 原路径,逐字节同)。
|
||
- bf16 模式:embedding 输出 `cast→bf16` 进入残差流;每个 weight matmul 前 `cast(w, BF16)`;norm/softmax/rope 对 bf16 激活自动 fp32 内算;最后 logits `cast→fp32` 给 cross_entropy。**fp32 模式 `compute_dtype==F32` 时跳过所有 cast,graph 与 T10/T11 完全一致。**
|
||
|
||
### 6. `xtrain-train` / `xtrain-distributed` — `--bf16` flag
|
||
- `TrainConfig`/`DdpConfig` 加 `compute_dtype: DType`;`train.rs`/`train_ddp.rs` 加 `--bf16` flag。
|
||
- AdamW / clip / checkpoint / DDP all-reduce **不改**(master 永远 fp32,grad 永远 fp32)。
|
||
|
||
## Key Design Decisions
|
||
|
||
- **cast 算子承载 fp32 master ↔ bf16 compute 的桥**:不需要在优化器里维护一份独立的 bf16 weight 副本——fp32 leaf 即 master,前向临时 cast 出 bf16,反向 grad 自动升回 fp32。最小改动、零优化器侵入。
|
||
- **按 dtype 分派而非新类型**:fp32 路径走的还是同一个函数的 `F32` 分支 → 原 kernel、原 cuBLAS 调用、原 launch 顺序,数值逐字节不变(满足硬闸门)。
|
||
- **norm/softmax/CE 不写 bf16 reduction kernel**:wrapper 里 `to_dtype(F32)` → 复用现有 fp32 kernel → `to_dtype(BF16)`。多两个 cast launch,但**复用已验证的 fp32 数值**,且这些不是显存/算力大头。
|
||
- **无 loss scaling**:bf16 8-bit 指数,省掉 fp16 那套 scale/unscale/inf-check。
|
||
|
||
## 验证方法(双闸门)
|
||
|
||
### 闸门 ① fp32 不回归(hard gate)
|
||
全套现有测试在原紧容差下保持绿(bf16 是 opt-in,默认 dtype=F32):
|
||
- `cargo test` 全 crate:grad-check(rel ≤2e-2)、structural、GEMM 对 cuBLAS(~1e-7)、batched==looped、overfit 27/27、AdamW GPU bit-exact + host 对 torch、checkpoint 逐位、DDP loss 对单卡 <1e-6、**PyTorch 对拍**(loss/logits/grad)。
|
||
- **xserv 闭环**:v4 ckpt(fp32 训)重导 safetensors md5 一致 + xserv 贪心逐 token 对住。
|
||
|
||
### 闸门 ② bf16 正确性 + 收敛
|
||
- **bf16 looser-tol 数值**:同一组随机权重/输入,bf16 forward logits 与 bf16 grad 对 fp32 参考在 **rel ~1e-2**(bf16 2–3 位有效数字)内。
|
||
- **短训练收敛**:dim768(或缩小代理)bf16 跑数百步,loss 曲线对住 fp32,无 NaN/发散,end loss 接近。
|
||
|
||
### 闸门 ③ 显存 + 吞吐(payoff)
|
||
- **dim768 bf16 能跑 per-rank batch 32**(v4 OOM 的触发点)。
|
||
- 测 dim768 **bf16 vs fp32** 的峰值显存 + steady-state tok/s(预期:显存↓、tok/s↑)。
|
||
|
||
## 实测结果(dash5, 1× RTX 5090 32GB, sm_120)
|
||
|
||
**闸门 ① fp32 不回归**:全套测试在原紧容差绿(autograd 15 / structural 5 / GEMM 5 / batched==looped / overfit 27/27 / AdamW GPU bit-exact + host 对 torch / checkpoint 逐位 / DDP 2)。**xserv 闭环**:v3 ckpt 用 T12 代码重导 `model.safetensors` 与 registry **md5 逐位一致**(`b04fc9f9a0c9af04c47d9ca649aea12e`)——export/fp32 数值零漂移。
|
||
|
||
**闸门 ② bf16 正确性 + 收敛**:
|
||
- **looser-tol(tests/bf16.rs)**:同 fp32 master 跑 fp32 vs bf16——loss rel `1.2e-4`、logits mean rel `2.0e-3` / p99 `6.8e-3`、grad worst scaled-mean `1.0e-2`,无 NaN,grad 仍 fp32(master 未动)。
|
||
- **收敛**:dim768 短训 150 步,bf16-b16 loss 轨迹对住 fp32-b16(step50 `4.40` vs `4.40`、step149 `3.984` vs `3.988`),单调降、无发散。
|
||
|
||
**闸门 ③ 显存 + 吞吐(dim768/18L/24h×32 ffn2048 seq256, steady-state)**:
|
||
|
||
| config | per-rank batch | 峰值显存 | tok/s | fits 32GB? |
|
||
|---|---|---|---|---|
|
||
| fp32 | 16 (v4 fallback) | 27.2 GB | 31.5K | ✅ |
|
||
| **bf16** | 16 | **19.3 GB(−29%)** | **35.5K(+13%)** | ✅ |
|
||
| fp32 | 32 | — | — | ❌ **OOM** |
|
||
| **bf16** | **32(甜点区)** | **31.1 GB** | **40.8K** | ✅ **解 OOM** |
|
||
|
||
→ 同 batch:bf16 显存 −29% / tok/s +13%;**bf16 解 fp32-batch32 OOM**,batch32 达 40.8K tok/s(+29% vs fp32-b16)。KI-2 标 **FIXED**。
|
||
</content>
|
||
</invoke>
|