The `keep bf16 logits` change made forward_batched return bf16 logits in bf16 mode; the bf16 test's host read must cast to f32 first. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
152 lines
5.7 KiB
Rust
152 lines
5.7 KiB
Rust
// T12 bf16 mixed-precision correctness gate (on-GPU, no PyTorch).
|
|
//
|
|
// The SAME model (identical fp32 master weights) run in fp32 vs bf16 compute
|
|
// mode must agree within a LOOSE bf16 tolerance (bf16 = 7-bit mantissa ≈ 2-3
|
|
// decimal digits → ~1e-2 relative error is expected and acceptable), both for
|
|
// the forward loss/logits AND every parameter's gradient. We also assert no
|
|
// NaN/Inf leaks and that the bf16-mode grads are fp32 (the cast op upcasts the bf16
|
|
// weight grad back to the fp32 master, so AdamW/clip/DDP stay fp32).
|
|
//
|
|
// This is the "bf16 within looser tol vs fp32 reference" gate; the short-run
|
|
// convergence comparison is the train_loop-level bench on dash5.
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_cuda::device;
|
|
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
|
use xtrain_tensor::{DType, Device};
|
|
|
|
/// Deterministic pseudo-random values in `[-scale, scale]`, `n` of them, driven by `seed`.
///
/// The seed is first scrambled into an initial state. Each output then advances a 64-bit
/// LCG (Knuth's MMIX multiplier/increment) and maps its top 31 bits onto `[0, 1]`.
/// The same `(n, seed, scale)` always yields the same vector, which is what lets two
/// models be built with identical weights.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;

    let start = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    // 2^31: the top 31 bits of the state divided by this land in [0, 1].
    let denom = (1u64 << 31) as f32;

    (0..n)
        .scan(start, |state, _| {
            *state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
            let unit = (*state >> 33) as f32 / denom;
            // Recentre [0, 1] onto [-1, 1], then scale.
            Some((unit - 0.5) * 2.0 * scale)
        })
        .collect()
}
|
|
|
|
fn build(cfg: Config, device: Device) -> TinyTransformer {
|
|
let mut seed = 1u64;
|
|
TinyTransformer::new(cfg, device, |shape| {
|
|
seed = seed.wrapping_add(1);
|
|
let n: usize = shape.iter().product();
|
|
if shape.len() == 1 {
|
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
|
} else {
|
|
fill(n, seed, 0.08)
|
|
}
|
|
})
|
|
}
|
|
|
|
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
|
t.to_device(Device::Cpu).as_slice::<f32>().to_vec()
|
|
}
|
|
|
|
#[test]
fn bf16_matches_fp32_within_loose_tol() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // A few layers / heads so the bf16 rounding accumulates through the depth
    // the real model has (not just a single matmul).
    let mut cfg = Config::tiny();
    cfg.vocab = 32;
    cfg.n_layers = 3;
    let batch = 2usize;
    let seq = 8usize;

    // Deterministic input ids and targets, all reduced modulo the vocab size.
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    // fp32 reference.
    let fp32 = build(cfg, device);
    let f_logits = host(&fp32.forward_batched(&ids, batch).value());
    let f_loss = fp32.loss_batched(&ids, &tgt, batch);
    let f_loss_val = host(&f_loss.value())[0];
    f_loss.backward();
    let f_params = fp32.params();

    // The reference itself must be clean. A NaN here would otherwise slip through
    // the comparisons below: `f32::max` discards a NaN operand, and NaN relative
    // errors would break the sort.
    assert!(
        f_logits.iter().all(|v| v.is_finite()) && f_loss_val.is_finite(),
        "fp32 reference forward produced non-finite values"
    );

    // bf16 — SAME init (build re-runs the same deterministic fill). The forward
    // now returns bf16 logits (CE upcasts internally); cast to f32 to read.
    let bf16 = build(cfg, device).with_compute_dtype(DType::BF16);
    let b_logits = host(
        &bf16
            .forward_batched(&ids, batch)
            .value()
            .to_dtype(DType::F32),
    );
    let b_loss = bf16.loss_batched(&ids, &tgt, batch);
    let b_loss_val = host(&b_loss.value())[0];
    b_loss.backward();
    let b_params = bf16.params();

    // No NaN/Inf in the bf16 forward.
    assert!(
        b_logits.iter().all(|v| v.is_finite()) && b_loss_val.is_finite(),
        "bf16 forward produced non-finite values"
    );

    // Forward loss within loose bf16 tol.
    let loss_rel = (b_loss_val - f_loss_val).abs() / f_loss_val.abs().max(1e-4);
    println!("bf16 vs fp32: loss {b_loss_val:.5} vs {f_loss_val:.5} (rel {loss_rel:.3e})");
    assert!(
        loss_rel < 2e-2,
        "bf16 loss too far from fp32: {loss_rel:.3e}"
    );

    // Logits: bf16 keeps only ~2-3 decimal digits, so single entries may be off by
    // a few 1e-2. Require a small MEAN relative error (< 1e-2) and bound the tail
    // via the p99 (< 5e-2). The error is taken relative to max(|f|, 1), so
    // near-zero logits are compared absolutely.
    // `zip` would silently truncate a shape mismatch, so check lengths first.
    assert_eq!(
        f_logits.len(),
        b_logits.len(),
        "fp32/bf16 logits length mismatch"
    );
    assert!(!f_logits.is_empty(), "forward produced no logits");
    let n = f_logits.len();
    let mut rels: Vec<f32> = f_logits
        .iter()
        .zip(&b_logits)
        .map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
        .collect();
    rels.sort_by(f32::total_cmp);
    let p99 = rels[(n as f32 * 0.99) as usize];
    let mean: f32 = rels.iter().sum::<f32>() / n as f32;
    println!("bf16 vs fp32 logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
    assert!(mean < 1e-2, "bf16 logits mean rel err too high: {mean:.3e}");
    assert!(p99 < 5e-2, "bf16 logits p99 rel err too high: {p99:.3e}");

    // Gradients. In bf16 mode the master-weight grads must still be fp32 (the cast
    // op upcasts them) and finite. They must also be within loose bf16 tol of the
    // fp32 reference, measured per param tensor as the mean abs error scaled by
    // that tensor's max |fp32 grad|.
    assert_eq!(
        f_params.len(),
        b_params.len(),
        "fp32/bf16 param count mismatch"
    );
    let mut worst_param_mean = 0.0f32;
    for (fp, bp) in f_params.iter().zip(&b_params) {
        let bg = bp.grad().expect("bf16 grad");
        assert_eq!(bg.dtype(), DType::F32, "bf16-mode grad must be fp32 master");
        let fg = host(&fp.grad().expect("fp32 grad"));
        let bg = host(&bg);
        assert_eq!(fg.len(), bg.len(), "fp32/bf16 grad length mismatch");
        assert!(fg.iter().all(|v| v.is_finite()), "fp32 grad has non-finite");
        assert!(bg.iter().all(|v| v.is_finite()), "bf16 grad has non-finite");
        // Scale-relative mean error over the tensor (robust to a few small entries).
        let scale = fg.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
        let mean_err: f32 =
            fg.iter().zip(&bg).map(|(f, b)| (f - b).abs()).sum::<f32>() / fg.len() as f32 / scale;
        worst_param_mean = worst_param_mean.max(mean_err);
    }
    println!("bf16 vs fp32 grads: worst per-tensor scaled-mean err = {worst_param_mean:.3e}");
    assert!(
        worst_param_mean < 3e-2,
        "bf16 grads too far from fp32: {worst_param_mean:.3e}"
    );
}
|