train: --bf16 flag (fp32-master AMP) + bf16 correctness test

- TinyTransformer::with_compute_dtype(BF16): embedding stays fp32
  master then casts to bf16; each linear casts its fp32 weight to bf16
  on the fly; logits cast back to fp32 for cross-entropy. Default F32
  reproduces the v0-v4 forward graph bit-for-bit.
- --bf16 flag on bin/train and bin/train_ddp (off by default).
- tests/bf16.rs: same fp32 master weights run fp32 vs bf16; assert
  loss/logits/grads within a loose bf16 tol, no NaN, and grads are
  fp32 (master untouched).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-16 14:14:55 +08:00
parent b0086b5214
commit 0a2a4dcaa8
4 changed files with 241 additions and 16 deletions

View File

@@ -82,6 +82,9 @@ fn main() {
let val_tokens: usize = flag(&args, "--val-tokens", 0);
let eval_every: usize = flag(&args, "--eval-every", 0);
let eval_batches: usize = flag(&args, "--eval-batches", 64);
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
// activations. Opt-in; default fp32 reproduces v0v4 numerics.
let bf16 = args.iter().any(|a| a == "--bf16");
let ckpt: Option<PathBuf> = args
.iter()
.position(|a| a == "--ckpt")
@@ -161,12 +164,22 @@ fn main() {
eval every {eval_every}"
);
if bf16 {
println!("bf16 mixed precision: ON (fp32 master weights)");
}
let results = launch(
&devices,
&train_corpus,
valid.as_ref(),
&dcfg,
move |device| build_model(cfg, device),
move |device| {
let m = build_model(cfg, device);
if bf16 {
m.with_compute_dtype(xtrain_tensor::DType::BF16)
} else {
m
}
},
);
let r0 = &results[0];
let start = r0.losses.first().copied().unwrap_or(0.0);

View File

@@ -5,7 +5,7 @@
use crate::config::Config;
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_tensor::{Device, Tensor};
use xtrain_tensor::{DType, Device, Tensor};
/// One decoder block's learnable tensors.
struct Block {
@@ -30,6 +30,13 @@ pub struct TinyTransformer {
blocks: Vec<Block>,
final_norm: Var, // [dim]
lm_head: Var, // [dim, vocab]
/// Compute dtype for the forward graph (Phase T12). `F32` (default) = the
/// original path, bit-identical to T10/T11. `BF16` = mixed precision: the
/// parameter leaves stay fp32 (master), but each linear's weight is cast to
/// bf16 on the fly and the activation stream flows bf16 (see
/// `docs/11-bf16-mixed-precision.md`). The cast op's backward upcasts the bf16
/// weight grad back to fp32, so AdamW/clip/DDP stay fp32 and unchanged.
compute_dtype: DType,
}
impl TinyTransformer {
@@ -71,6 +78,7 @@ impl TinyTransformer {
blocks,
final_norm,
lm_head,
compute_dtype: DType::F32,
}
}
@@ -78,6 +86,35 @@ impl TinyTransformer {
&self.cfg
}
/// Set the forward compute dtype (Phase T12). `BF16` enables mixed precision
/// (fp32 master weights, bf16 linears + activations); `F32` (the default) is
/// the unchanged full-precision path. Builder-style so existing call sites
/// that don't opt in keep the fp32 numerics bit-for-bit.
pub fn with_compute_dtype(mut self, dtype: DType) -> Self {
assert!(
matches!(dtype, DType::F32 | DType::BF16),
"compute_dtype must be F32 or BF16"
);
self.compute_dtype = dtype;
self
}
pub fn compute_dtype(&self) -> DType {
self.compute_dtype
}
/// Project `x` (activation, in the compute dtype) by weight `w` (an fp32
/// master leaf). In bf16 mode the weight is cast to bf16 via the autograd
/// `cast` op (whose backward upcasts the grad to fp32); in fp32 mode this is
/// just `matmul(x, w)`. The activation `x` already carries `compute_dtype`.
fn linear(&self, x: &Var, w: &Var) -> Var {
match self.compute_dtype {
DType::F32 => ops::matmul(x, w),
DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)),
_ => unreachable!(),
}
}
/// All learnable parameters, in a stable order. The optimizer (a hand-written
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
/// `backward()`.
@@ -127,21 +164,42 @@ impl TinyTransformer {
);
let seq = total / batch;
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim]
// Embedding gathers from the fp32 master table; in bf16 mode cast the
// activation stream to bf16 here (norms are cast to bf16 gammas too).
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
if self.compute_dtype == DType::BF16 {
h = ops::cast(&h, DType::BF16);
}
for b in &self.blocks {
// --- Attention sub-block (pre-norm + residual) ---
let normed = ops::rms_norm(&h, &b.attn_norm, self.cfg.eps);
let normed = ops::rms_norm(&h, &self.norm_gamma(&b.attn_norm), self.cfg.eps);
let attn = self.attention(b, &normed, batch, seq);
h = ops::add(&h, &attn);
// --- MLP sub-block (pre-norm + residual) ---
let normed = ops::rms_norm(&h, &b.ffn_norm, self.cfg.eps);
let normed = ops::rms_norm(&h, &self.norm_gamma(&b.ffn_norm), self.cfg.eps);
let mlp = self.swiglu_mlp(b, &normed);
h = ops::add(&h, &mlp);
}
let h = ops::rms_norm(&h, &self.final_norm, self.cfg.eps);
ops::matmul(&h, &self.lm_head) // [batch*seq, vocab]
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
// lm_head matmul in compute dtype; cast logits back to fp32 for CE.
let logits = self.linear(&h, &self.lm_head); // [batch*seq, vocab]
if self.compute_dtype == DType::BF16 {
ops::cast(&logits, DType::F32)
} else {
logits
}
}
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast
/// op, grad upcast) in bf16 mode; identity in fp32 mode.
fn norm_gamma(&self, gamma: &Var) -> Var {
match self.compute_dtype {
DType::F32 => gamma.clone(),
DType::BF16 => ops::cast(gamma, DType::BF16),
_ => unreachable!(),
}
}
/// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
@@ -186,7 +244,7 @@ impl TinyTransformer {
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
Some(gamma) => {
let flat = ops::reshape(&r, &[total * nh, hd]);
let normed = ops::rms_norm(&flat, gamma, self.cfg.eps);
let normed = ops::rms_norm(&flat, &self.norm_gamma(gamma), self.cfg.eps);
let r = ops::reshape(&normed, &[total, nh, hd]);
ops::rope(&r, self.cfg.rope_theta, seq)
}
@@ -197,9 +255,9 @@ impl TinyTransformer {
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
};
let q = to_bh(ops::matmul(x, &b.wq), Some(&b.q_norm));
let k = to_bh(ops::matmul(x, &b.wk), Some(&b.k_norm));
let v = to_bh(ops::matmul(x, &b.wv), None);
let q = to_bh(self.linear(x, &b.wq), Some(&b.q_norm));
let k = to_bh(self.linear(x, &b.wk), Some(&b.k_norm));
let v = to_bh(self.linear(x, &b.wv), None);
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
@@ -210,15 +268,15 @@ impl TinyTransformer {
let out = ops::reshape(&out, &[batch, nh, seq, hd]);
let out = ops::transpose_4d12(&out); // [B, S, nh, hd]
let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim]
ops::matmul(&concat, &b.wo) // out projection
self.linear(&concat, &b.wo) // out projection
}
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim].
fn swiglu_mlp(&self, b: &Block, x: &Var) -> Var {
let gate = ops::matmul(x, &b.w_gate); // [seq, ffn_hidden]
let up = ops::matmul(x, &b.w_up); // [seq, ffn_hidden]
let gate = self.linear(x, &b.w_gate); // [seq, ffn_hidden]
let up = self.linear(x, &b.w_up); // [seq, ffn_hidden]
let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up
ops::matmul(&act, &b.w_down) // [seq, dim]
self.linear(&act, &b.w_down) // [seq, dim]
}
}

View File

@@ -0,0 +1,145 @@
// T12 bf16 mixed-precision correctness gate (on-GPU, no PyTorch).
//
// The SAME model (identical fp32 master weights) run in fp32 vs bf16 compute
// mode must agree within a LOOSE bf16 tolerance (bf16 = 7-bit mantissa ≈ 2-3
// decimal digits → ~1e-2 relative error is expected and acceptable), both for
// the forward loss/logits AND every parameter's gradient. We also assert no
// NaN/Inf leaks and that the fp32 grads are fp32 (the cast op upcast the bf16
// weight grad back to the fp32 master, so AdamW/clip/DDP stay fp32).
//
// This is the "bf16 within looser tol vs fp32 reference" gate; the short-run
// convergence comparison is the train_loop-level bench on dash5.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
let mut state = seed
.wrapping_mul(2862933555777941757)
.wrapping_add(3037000493);
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
})
.collect()
}
fn build(cfg: Config, device: Device) -> TinyTransformer {
let mut seed = 1u64;
TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
} else {
fill(n, seed, 0.08)
}
})
}
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
}
#[test]
fn bf16_matches_fp32_within_loose_tol() {
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
let device = Device::Cuda(0);
// A few layers / heads so the bf16 rounding accumulates through the depth
// the real model has (not just a single matmul).
let mut cfg = Config::tiny();
cfg.vocab = 32;
cfg.n_layers = 3;
let batch = 2usize;
let seq = 8usize;
let seqs: Vec<Vec<i32>> = (0..batch)
.map(|b| {
(0..seq)
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
.collect()
})
.collect();
let tgts: Vec<Vec<i32>> = (0..batch)
.map(|b| {
(0..seq)
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
.collect()
})
.collect();
let ids = batched_ids_tensor(&seqs, device);
let tgt = batched_ids_tensor(&tgts, device);
// fp32 reference.
let fp32 = build(cfg, device);
let f_logits = host(&fp32.forward_batched(&ids, batch).value());
let f_loss = fp32.loss_batched(&ids, &tgt, batch);
let f_loss_val = host(&f_loss.value())[0];
f_loss.backward();
let f_params = fp32.params();
// bf16 — SAME init (build re-runs the same deterministic fill).
let bf16 = build(cfg, device).with_compute_dtype(DType::BF16);
let b_logits = host(&bf16.forward_batched(&ids, batch).value());
let b_loss = bf16.loss_batched(&ids, &tgt, batch);
let b_loss_val = host(&b_loss.value())[0];
b_loss.backward();
let b_params = bf16.params();
// No NaN/Inf in the bf16 forward.
assert!(
b_logits.iter().all(|v| v.is_finite()) && b_loss_val.is_finite(),
"bf16 forward produced non-finite values"
);
// Forward loss within loose bf16 tol.
let loss_rel = (b_loss_val - f_loss_val).abs() / f_loss_val.abs().max(1e-4);
println!("bf16 vs fp32: loss {b_loss_val:.5} vs {f_loss_val:.5} (rel {loss_rel:.3e})");
assert!(
loss_rel < 2e-2,
"bf16 loss too far from fp32: {loss_rel:.3e}"
);
// Logits: bf16 has ~2-3 decimal digits → compare on a robust (median-style)
// basis, requiring the bulk to be within ~3e-2 and the mean error small.
let n = f_logits.len();
let mut rels: Vec<f32> = f_logits
.iter()
.zip(&b_logits)
.map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
.collect();
rels.sort_by(|a, b| a.partial_cmp(b).unwrap());
let p99 = rels[(n as f32 * 0.99) as usize];
let mean: f32 = rels.iter().sum::<f32>() / n as f32;
println!("bf16 vs fp32 logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
assert!(mean < 1e-2, "bf16 logits mean rel err too high: {mean:.3e}");
assert!(p99 < 5e-2, "bf16 logits p99 rel err too high: {p99:.3e}");
// Gradients: fp32 master grads must be fp32 (cast op upcast), finite, and
// within loose bf16 tol of the fp32 reference (mean over each param tensor).
let mut worst_param_mean = 0.0f32;
for (fp, bp) in f_params.iter().zip(&b_params) {
let bg = bp.grad().expect("bf16 grad");
assert_eq!(bg.dtype(), DType::F32, "bf16-mode grad must be fp32 master");
let fg = host(&fp.grad().expect("fp32 grad"));
let bg = host(&bg);
assert!(bg.iter().all(|v| v.is_finite()), "bf16 grad has non-finite");
// Scale-relative mean error over the tensor (robust to a few small entries).
let scale = fg.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
let mean_err: f32 =
fg.iter().zip(&bg).map(|(f, b)| (f - b).abs()).sum::<f32>() / fg.len() as f32 / scale;
worst_param_mean = worst_param_mean.max(mean_err);
}
println!("bf16 vs fp32 grads: worst per-tensor scaled-mean err = {worst_param_mean:.3e}");
assert!(
worst_param_mean < 3e-2,
"bf16 grads too far from fp32: {worst_param_mean:.3e}"
);
}

View File

@@ -31,6 +31,8 @@ use xtrain_cuda::device;
#[cfg(not(no_cuda))]
use xtrain_model::{Config, TinyTransformer};
#[cfg(not(no_cuda))]
use xtrain_tensor::DType;
#[cfg(not(no_cuda))]
use xtrain_tensor::Device;
#[cfg(not(no_cuda))]
use xtrain_train::data::Corpus;
@@ -107,6 +109,9 @@ fn main() {
let val_tokens: usize = flag(&args, "--val-tokens", 0);
let eval_every: usize = flag(&args, "--eval-every", 0);
let eval_batches: usize = flag(&args, "--eval-batches", 64);
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
// activations. Opt-in; default fp32 reproduces v0v4 numerics.
let bf16 = args.iter().any(|a| a == "--bf16");
let ckpt: PathBuf = PathBuf::from(
args.iter()
.position(|a| a == "--ckpt")
@@ -155,7 +160,7 @@ fn main() {
);
let mut seed = 1u64;
let model = TinyTransformer::new(cfg, device, |shape| {
let mut model = TinyTransformer::new(cfg, device, |shape| {
seed = seed.wrapping_add(1);
let n: usize = shape.iter().product();
if shape.len() == 1 {
@@ -166,6 +171,10 @@ fn main() {
fill(n, seed, 0.04)
}
});
if bf16 {
model = model.with_compute_dtype(DType::BF16);
println!("bf16 mixed precision: ON (fp32 master weights)");
}
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same