train: --bf16 flag (fp32-master AMP) + bf16 correctness test
- TinyTransformer::with_compute_dtype(BF16): embedding stays fp32 master then casts to bf16; each linear casts its fp32 weight to bf16 on the fly; logits cast back to fp32 for cross-entropy. Default F32 reproduces the v0-v4 forward graph bit-for-bit. - --bf16 flag on bin/train and bin/train_ddp (off by default). - tests/bf16.rs: same fp32 master weights run fp32 vs bf16; assert loss/logits/grads within a loose bf16 tol, no NaN, and grads are fp32 (master untouched). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -82,6 +82,9 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
|
||||
// activations. Opt-in; default fp32 reproduces v0–v4 numerics.
|
||||
let bf16 = args.iter().any(|a| a == "--bf16");
|
||||
let ckpt: Option<PathBuf> = args
|
||||
.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -161,12 +164,22 @@ fn main() {
|
||||
eval every {eval_every}"
|
||||
);
|
||||
|
||||
if bf16 {
|
||||
println!("bf16 mixed precision: ON (fp32 master weights)");
|
||||
}
|
||||
let results = launch(
|
||||
&devices,
|
||||
&train_corpus,
|
||||
valid.as_ref(),
|
||||
&dcfg,
|
||||
move |device| build_model(cfg, device),
|
||||
move |device| {
|
||||
let m = build_model(cfg, device);
|
||||
if bf16 {
|
||||
m.with_compute_dtype(xtrain_tensor::DType::BF16)
|
||||
} else {
|
||||
m
|
||||
}
|
||||
},
|
||||
);
|
||||
let r0 = &results[0];
|
||||
let start = r0.losses.first().copied().unwrap_or(0.0);
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
use crate::config::Config;
|
||||
use xtrain_autodiff::ops;
|
||||
use xtrain_autodiff::tape::Var;
|
||||
use xtrain_tensor::{Device, Tensor};
|
||||
use xtrain_tensor::{DType, Device, Tensor};
|
||||
|
||||
/// One decoder block's learnable tensors.
|
||||
struct Block {
|
||||
@@ -30,6 +30,13 @@ pub struct TinyTransformer {
|
||||
blocks: Vec<Block>,
|
||||
final_norm: Var, // [dim]
|
||||
lm_head: Var, // [dim, vocab]
|
||||
/// Compute dtype for the forward graph (Phase T12). `F32` (default) = the
|
||||
/// original path, bit-identical to T10/T11. `BF16` = mixed precision: the
|
||||
/// parameter leaves stay fp32 (master), but each linear's weight is cast to
|
||||
/// bf16 on the fly and the activation stream flows bf16 (see
|
||||
/// `docs/11-bf16-mixed-precision.md`). The cast op's backward upcasts the bf16
|
||||
/// weight grad back to fp32, so AdamW/clip/DDP stay fp32 and unchanged.
|
||||
compute_dtype: DType,
|
||||
}
|
||||
|
||||
impl TinyTransformer {
|
||||
@@ -71,6 +78,7 @@ impl TinyTransformer {
|
||||
blocks,
|
||||
final_norm,
|
||||
lm_head,
|
||||
compute_dtype: DType::F32,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +86,35 @@ impl TinyTransformer {
|
||||
&self.cfg
|
||||
}
|
||||
|
||||
/// Set the forward compute dtype (Phase T12). `BF16` enables mixed precision
|
||||
/// (fp32 master weights, bf16 linears + activations); `F32` (the default) is
|
||||
/// the unchanged full-precision path. Builder-style so existing call sites
|
||||
/// that don't opt in keep the fp32 numerics bit-for-bit.
|
||||
pub fn with_compute_dtype(mut self, dtype: DType) -> Self {
|
||||
assert!(
|
||||
matches!(dtype, DType::F32 | DType::BF16),
|
||||
"compute_dtype must be F32 or BF16"
|
||||
);
|
||||
self.compute_dtype = dtype;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn compute_dtype(&self) -> DType {
|
||||
self.compute_dtype
|
||||
}
|
||||
|
||||
/// Project `x` (activation, in the compute dtype) by weight `w` (an fp32
|
||||
/// master leaf). In bf16 mode the weight is cast to bf16 via the autograd
|
||||
/// `cast` op (whose backward upcasts the grad to fp32); in fp32 mode this is
|
||||
/// just `matmul(x, w)`. The activation `x` already carries `compute_dtype`.
|
||||
fn linear(&self, x: &Var, w: &Var) -> Var {
|
||||
match self.compute_dtype {
|
||||
DType::F32 => ops::matmul(x, w),
|
||||
DType::BF16 => ops::matmul(x, &ops::cast(w, DType::BF16)),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// All learnable parameters, in a stable order. The optimizer (a hand-written
|
||||
/// GD step in T5, AdamW in T6) iterates this; each holds its `.grad()` after
|
||||
/// `backward()`.
|
||||
@@ -127,21 +164,42 @@ impl TinyTransformer {
|
||||
);
|
||||
let seq = total / batch;
|
||||
|
||||
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim]
|
||||
// Embedding gathers from the fp32 master table; in bf16 mode cast the
|
||||
// activation stream to bf16 here (norms are cast to bf16 gammas too).
|
||||
let mut h = ops::embedding(&self.embed, ids); // [batch*seq, dim], fp32
|
||||
if self.compute_dtype == DType::BF16 {
|
||||
h = ops::cast(&h, DType::BF16);
|
||||
}
|
||||
for b in &self.blocks {
|
||||
// --- Attention sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(&h, &b.attn_norm, self.cfg.eps);
|
||||
let normed = ops::rms_norm(&h, &self.norm_gamma(&b.attn_norm), self.cfg.eps);
|
||||
let attn = self.attention(b, &normed, batch, seq);
|
||||
h = ops::add(&h, &attn);
|
||||
|
||||
// --- MLP sub-block (pre-norm + residual) ---
|
||||
let normed = ops::rms_norm(&h, &b.ffn_norm, self.cfg.eps);
|
||||
let normed = ops::rms_norm(&h, &self.norm_gamma(&b.ffn_norm), self.cfg.eps);
|
||||
let mlp = self.swiglu_mlp(b, &normed);
|
||||
h = ops::add(&h, &mlp);
|
||||
}
|
||||
|
||||
let h = ops::rms_norm(&h, &self.final_norm, self.cfg.eps);
|
||||
ops::matmul(&h, &self.lm_head) // [batch*seq, vocab]
|
||||
let h = ops::rms_norm(&h, &self.norm_gamma(&self.final_norm), self.cfg.eps);
|
||||
// lm_head matmul in compute dtype; cast logits back to fp32 for CE.
|
||||
let logits = self.linear(&h, &self.lm_head); // [batch*seq, vocab]
|
||||
if self.compute_dtype == DType::BF16 {
|
||||
ops::cast(&logits, DType::F32)
|
||||
} else {
|
||||
logits
|
||||
}
|
||||
}
|
||||
|
||||
/// A norm/QK-norm gamma in the compute dtype. fp32 master leaf → bf16 (cast
|
||||
/// op, grad upcast) in bf16 mode; identity in fp32 mode.
|
||||
fn norm_gamma(&self, gamma: &Var) -> Var {
|
||||
match self.compute_dtype {
|
||||
DType::F32 => gamma.clone(),
|
||||
DType::BF16 => ops::cast(gamma, DType::BF16),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cross-entropy mean loss of `forward(ids)` against `targets` (`[seq]` I32).
|
||||
@@ -186,7 +244,7 @@ impl TinyTransformer {
|
||||
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
|
||||
Some(gamma) => {
|
||||
let flat = ops::reshape(&r, &[total * nh, hd]);
|
||||
let normed = ops::rms_norm(&flat, gamma, self.cfg.eps);
|
||||
let normed = ops::rms_norm(&flat, &self.norm_gamma(gamma), self.cfg.eps);
|
||||
let r = ops::reshape(&normed, &[total, nh, hd]);
|
||||
ops::rope(&r, self.cfg.rope_theta, seq)
|
||||
}
|
||||
@@ -197,9 +255,9 @@ impl TinyTransformer {
|
||||
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
|
||||
};
|
||||
|
||||
let q = to_bh(ops::matmul(x, &b.wq), Some(&b.q_norm));
|
||||
let k = to_bh(ops::matmul(x, &b.wk), Some(&b.k_norm));
|
||||
let v = to_bh(ops::matmul(x, &b.wv), None);
|
||||
let q = to_bh(self.linear(x, &b.wq), Some(&b.q_norm));
|
||||
let k = to_bh(self.linear(x, &b.wk), Some(&b.k_norm));
|
||||
let v = to_bh(self.linear(x, &b.wv), None);
|
||||
|
||||
// Fused batched causal SDPA over all B*nh (sequence,head) blocks at once
|
||||
// (2 batched GEMMs + 1 causal-softmax kernel; no per-head/per-seq loop).
|
||||
@@ -210,15 +268,15 @@ impl TinyTransformer {
|
||||
let out = ops::reshape(&out, &[batch, nh, seq, hd]);
|
||||
let out = ops::transpose_4d12(&out); // [B, S, nh, hd]
|
||||
let concat = ops::reshape(&out, &[total, nh * hd]); // [B*S, dim]
|
||||
ops::matmul(&concat, &b.wo) // out projection
|
||||
self.linear(&concat, &b.wo) // out projection
|
||||
}
|
||||
|
||||
/// SwiGLU MLP: `down( silu(gate(x)) ∘ up(x) )`. `x`:[batch*seq,dim].
|
||||
fn swiglu_mlp(&self, b: &Block, x: &Var) -> Var {
|
||||
let gate = ops::matmul(x, &b.w_gate); // [seq, ffn_hidden]
|
||||
let up = ops::matmul(x, &b.w_up); // [seq, ffn_hidden]
|
||||
let gate = self.linear(x, &b.w_gate); // [seq, ffn_hidden]
|
||||
let up = self.linear(x, &b.w_up); // [seq, ffn_hidden]
|
||||
let act = ops::swiglu(&gate, &up); // silu(gate) ∘ up
|
||||
ops::matmul(&act, &b.w_down) // [seq, dim]
|
||||
self.linear(&act, &b.w_down) // [seq, dim]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
145
crates/xtrain-model/tests/bf16.rs
Normal file
145
crates/xtrain-model/tests/bf16.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
// T12 bf16 mixed-precision correctness gate (on-GPU, no PyTorch).
|
||||
//
|
||||
// The SAME model (identical fp32 master weights) run in fp32 vs bf16 compute
|
||||
// mode must agree within a LOOSE bf16 tolerance (bf16 = 7-bit mantissa ≈ 2-3
|
||||
// decimal digits → ~1e-2 relative error is expected and acceptable), both for
|
||||
// the forward loss/logits AND every parameter's gradient. We also assert no
|
||||
// NaN/Inf leaks and that the fp32 grads are fp32 (the cast op upcast the bf16
|
||||
// weight grad back to the fp32 master, so AdamW/clip/DDP stay fp32).
|
||||
//
|
||||
// This is the "bf16 within looser tol vs fp32 reference" gate; the short-run
|
||||
// convergence comparison is the train_loop-level bench on dash5.
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||||
use xtrain_tensor::{DType, Device};
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
||||
let mut seed = 1u64;
|
||||
TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||||
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bf16_matches_fp32_within_loose_tol() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
// A few layers / heads so the bf16 rounding accumulates through the depth
|
||||
// the real model has (not just a single matmul).
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 32;
|
||||
cfg.n_layers = 3;
|
||||
let batch = 2usize;
|
||||
let seq = 8usize;
|
||||
|
||||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||||
.map(|b| {
|
||||
(0..seq)
|
||||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
let ids = batched_ids_tensor(&seqs, device);
|
||||
let tgt = batched_ids_tensor(&tgts, device);
|
||||
|
||||
// fp32 reference.
|
||||
let fp32 = build(cfg, device);
|
||||
let f_logits = host(&fp32.forward_batched(&ids, batch).value());
|
||||
let f_loss = fp32.loss_batched(&ids, &tgt, batch);
|
||||
let f_loss_val = host(&f_loss.value())[0];
|
||||
f_loss.backward();
|
||||
let f_params = fp32.params();
|
||||
|
||||
// bf16 — SAME init (build re-runs the same deterministic fill).
|
||||
let bf16 = build(cfg, device).with_compute_dtype(DType::BF16);
|
||||
let b_logits = host(&bf16.forward_batched(&ids, batch).value());
|
||||
let b_loss = bf16.loss_batched(&ids, &tgt, batch);
|
||||
let b_loss_val = host(&b_loss.value())[0];
|
||||
b_loss.backward();
|
||||
let b_params = bf16.params();
|
||||
|
||||
// No NaN/Inf in the bf16 forward.
|
||||
assert!(
|
||||
b_logits.iter().all(|v| v.is_finite()) && b_loss_val.is_finite(),
|
||||
"bf16 forward produced non-finite values"
|
||||
);
|
||||
|
||||
// Forward loss within loose bf16 tol.
|
||||
let loss_rel = (b_loss_val - f_loss_val).abs() / f_loss_val.abs().max(1e-4);
|
||||
println!("bf16 vs fp32: loss {b_loss_val:.5} vs {f_loss_val:.5} (rel {loss_rel:.3e})");
|
||||
assert!(
|
||||
loss_rel < 2e-2,
|
||||
"bf16 loss too far from fp32: {loss_rel:.3e}"
|
||||
);
|
||||
|
||||
// Logits: bf16 has ~2-3 decimal digits → compare on a robust (median-style)
|
||||
// basis, requiring the bulk to be within ~3e-2 and the mean error small.
|
||||
let n = f_logits.len();
|
||||
let mut rels: Vec<f32> = f_logits
|
||||
.iter()
|
||||
.zip(&b_logits)
|
||||
.map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
|
||||
.collect();
|
||||
rels.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
let p99 = rels[(n as f32 * 0.99) as usize];
|
||||
let mean: f32 = rels.iter().sum::<f32>() / n as f32;
|
||||
println!("bf16 vs fp32 logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
|
||||
assert!(mean < 1e-2, "bf16 logits mean rel err too high: {mean:.3e}");
|
||||
assert!(p99 < 5e-2, "bf16 logits p99 rel err too high: {p99:.3e}");
|
||||
|
||||
// Gradients: fp32 master grads must be fp32 (cast op upcast), finite, and
|
||||
// within loose bf16 tol of the fp32 reference (mean over each param tensor).
|
||||
let mut worst_param_mean = 0.0f32;
|
||||
for (fp, bp) in f_params.iter().zip(&b_params) {
|
||||
let bg = bp.grad().expect("bf16 grad");
|
||||
assert_eq!(bg.dtype(), DType::F32, "bf16-mode grad must be fp32 master");
|
||||
let fg = host(&fp.grad().expect("fp32 grad"));
|
||||
let bg = host(&bg);
|
||||
assert!(bg.iter().all(|v| v.is_finite()), "bf16 grad has non-finite");
|
||||
// Scale-relative mean error over the tensor (robust to a few small entries).
|
||||
let scale = fg.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
|
||||
let mean_err: f32 =
|
||||
fg.iter().zip(&bg).map(|(f, b)| (f - b).abs()).sum::<f32>() / fg.len() as f32 / scale;
|
||||
worst_param_mean = worst_param_mean.max(mean_err);
|
||||
}
|
||||
println!("bf16 vs fp32 grads: worst per-tensor scaled-mean err = {worst_param_mean:.3e}");
|
||||
assert!(
|
||||
worst_param_mean < 3e-2,
|
||||
"bf16 grads too far from fp32: {worst_param_mean:.3e}"
|
||||
);
|
||||
}
|
||||
@@ -31,6 +31,8 @@ use xtrain_cuda::device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_model::{Config, TinyTransformer};
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::DType;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_tensor::Device;
|
||||
#[cfg(not(no_cuda))]
|
||||
use xtrain_train::data::Corpus;
|
||||
@@ -107,6 +109,9 @@ fn main() {
|
||||
let val_tokens: usize = flag(&args, "--val-tokens", 0);
|
||||
let eval_every: usize = flag(&args, "--eval-every", 0);
|
||||
let eval_batches: usize = flag(&args, "--eval-batches", 64);
|
||||
// bf16 mixed precision (Phase T12): fp32 master weights, bf16 linears +
|
||||
// activations. Opt-in; default fp32 reproduces v0–v4 numerics.
|
||||
let bf16 = args.iter().any(|a| a == "--bf16");
|
||||
let ckpt: PathBuf = PathBuf::from(
|
||||
args.iter()
|
||||
.position(|a| a == "--ckpt")
|
||||
@@ -155,7 +160,7 @@ fn main() {
|
||||
);
|
||||
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
let mut model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
@@ -166,6 +171,10 @@ fn main() {
|
||||
fill(n, seed, 0.04)
|
||||
}
|
||||
});
|
||||
if bf16 {
|
||||
model = model.with_compute_dtype(DType::BF16);
|
||||
println!("bf16 mixed precision: ON (fp32 master weights)");
|
||||
}
|
||||
|
||||
// Eval-only mode: load a checkpoint and score it on the held-out val set, then
|
||||
// exit. Used to put an EXISTING model (e.g. v0) and a new one on the same
|
||||
|
||||
Reference in New Issue
Block a user