Files
Gahow Wang 5b7dde1736 test: bf16 test reads f32-cast logits (forward now returns bf16)
The `keep bf16 logits` change made forward_batched return bf16 logits
in bf16 mode; the bf16 test's host read must cast to f32 first.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-16 14:29:24 +08:00

152 lines
5.7 KiB
Rust

// T12 bf16 mixed-precision correctness gate (on-GPU, no PyTorch).
//
// The SAME model (identical fp32 master weights) run in fp32 vs bf16 compute
// mode must agree within a LOOSE bf16 tolerance (bf16 = 7-bit mantissa ≈ 2-3
// decimal digits → ~1e-2 relative error is expected and acceptable), both for
// the forward loss/logits AND every parameter's gradient. We also assert no
// NaN/Inf leaks and that the bf16-mode grads are fp32 (the cast op upcasts the
// bf16 weight grad back to the fp32 master, so AdamW/clip/DDP stay fp32).
//
// This is the "bf16 within looser tol vs fp32 reference" gate; the short-run
// convergence comparison is the train_loop-level bench on dash5.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
/// Produce `n` deterministic pseudo-random values in `[-scale, scale]`.
///
/// The stream comes from a 64-bit linear congruential generator whose start
/// state is derived from `seed`, so the same `(n, seed, scale)` always yields
/// the same vector and a shorter request is a prefix of a longer one.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // 2^31 as a float: the top 31 bits of the state divided by this land in [0, 1].
    let denom = (1u64 << 31) as f32;
    // Scramble the caller's seed once so nearby seeds start far apart.
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep only the high bits; shift to a zero-centred value and apply `scale`.
        let unit = (lcg >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
/// Build a `TinyTransformer` whose initial weights are fully deterministic.
///
/// Every tensor requested by the constructor gets its own `fill` stream: the
/// seed counter is bumped before each request (so the first tensor uses seed 2).
/// Calling `build` twice with the same `cfg` therefore replays the same
/// sequence of seeds and yields the same values for the same requested shapes.
///
/// * Rank-1 shapes are filled with scale 0.02 and shifted by +1.0.
/// * Every other shape is filled zero-centred with scale 0.08.
fn build(cfg: Config, device: Device) -> TinyTransformer {
    let mut next_seed = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        next_seed = next_seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        match shape.len() {
            // 1-D parameters: small noise around 1.0.
            1 => fill(numel, next_seed, 0.02)
                .into_iter()
                .map(|v| v + 1.0)
                .collect(),
            // Everything else: small zero-centred noise.
            _ => fill(numel, next_seed, 0.08),
        }
    })
}
/// Move `t` to the CPU and return its elements as an owned `Vec<f32>`.
///
/// The data is read via `as_slice::<f32>()`, so callers pass f32 tensors
/// (the bf16 logits are cast with `to_dtype(DType::F32)` before coming here).
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let on_cpu = t.to_device(Device::Cpu);
    let values = on_cpu.as_slice::<f32>();
    values.to_vec()
}
/// The same deterministically-initialised model, run once in the default
/// (fp32) mode and once with `with_compute_dtype(DType::BF16)`, must agree
/// within a loose bf16 tolerance on the loss, the logits and every parameter
/// gradient.
///
/// Both sides are checked for NaN/Inf and for matching sizes before any error
/// statistic is computed: `zip` would silently truncate a length mismatch and
/// `f32::max` silently ignores NaN, so a broken reference could otherwise let
/// the gate pass.
#[test]
fn bf16_matches_fp32_within_loose_tol() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // A few layers so the bf16 rounding accumulates through depth (not just a
    // single matmul).
    let mut cfg = Config::tiny();
    cfg.vocab = 32;
    cfg.n_layers = 3;
    let batch = 2usize;
    let seq = 8usize;
    // Fixed, in-vocab token patterns for inputs and targets.
    let seqs: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let tgts: Vec<Vec<i32>> = (0..batch)
        .map(|b| {
            (0..seq)
                .map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
                .collect()
        })
        .collect();
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    // fp32 reference.
    let fp32 = build(cfg, device);
    let f_logits = host(&fp32.forward_batched(&ids, batch).value());
    let f_loss = fp32.loss_batched(&ids, &tgt, batch);
    let f_loss_val = host(&f_loss.value())[0];
    f_loss.backward();
    let f_params = fp32.params();

    // bf16 — SAME init (build re-runs the same deterministic fill). The forward
    // now returns bf16 logits (CE upcasts internally); cast to f32 to read.
    let bf16 = build(cfg, device).with_compute_dtype(DType::BF16);
    let b_logits = host(
        &bf16
            .forward_batched(&ids, batch)
            .value()
            .to_dtype(DType::F32),
    );
    let b_loss = bf16.loss_batched(&ids, &tgt, batch);
    let b_loss_val = host(&b_loss.value())[0];
    b_loss.backward();
    let b_params = bf16.params();

    // The reference itself must be sane, otherwise every comparison below is
    // meaningless (and NaN would be dropped by the `f32::max` reductions).
    assert!(
        f_logits.iter().all(|v| v.is_finite()) && f_loss_val.is_finite(),
        "fp32 reference forward produced non-finite values"
    );
    // No NaN/Inf in the bf16 forward.
    assert!(
        b_logits.iter().all(|v| v.is_finite()) && b_loss_val.is_finite(),
        "bf16 forward produced non-finite values"
    );

    // Forward loss within loose bf16 tol.
    let loss_rel = (b_loss_val - f_loss_val).abs() / f_loss_val.abs().max(1e-4);
    println!("bf16 vs fp32: loss {b_loss_val:.5} vs {f_loss_val:.5} (rel {loss_rel:.3e})");
    assert!(
        loss_rel < 2e-2,
        "bf16 loss too far from fp32: {loss_rel:.3e}"
    );

    // Logits: bf16 has ~2-3 decimal digits → compare on a robust (percentile)
    // basis, requiring the bulk (p99) to be within 5e-2 and the mean error small.
    let n = f_logits.len();
    assert!(n > 0, "fp32 forward returned no logits");
    assert_eq!(
        n,
        b_logits.len(),
        "fp32 and bf16 logits differ in element count"
    );
    // Relative error with the denominator floored at 1.0 so near-zero logits
    // do not blow the ratio up.
    let mut rels: Vec<f32> = f_logits
        .iter()
        .zip(&b_logits)
        .map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
        .collect();
    // total_cmp is a total order, so the sort itself can never panic.
    rels.sort_unstable_by(f32::total_cmp);
    let p99 = rels[((n as f32 * 0.99) as usize).min(n - 1)];
    let mean: f32 = rels.iter().sum::<f32>() / n as f32;
    println!("bf16 vs fp32 logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
    assert!(mean < 1e-2, "bf16 logits mean rel err too high: {mean:.3e}");
    assert!(p99 < 5e-2, "bf16 logits p99 rel err too high: {p99:.3e}");

    // Gradients: fp32 master grads must be fp32 (cast op upcast), finite, and
    // within loose bf16 tol of the fp32 reference (mean over each param tensor).
    assert_eq!(
        f_params.len(),
        b_params.len(),
        "fp32 and bf16 models expose different parameter counts"
    );
    let mut worst_param_mean = 0.0f32;
    for (fp, bp) in f_params.iter().zip(&b_params) {
        let bg = bp.grad().expect("bf16 grad");
        assert_eq!(bg.dtype(), DType::F32, "bf16-mode grad must be fp32 master");
        let fg = host(&fp.grad().expect("fp32 grad"));
        let bg = host(&bg);
        assert_eq!(
            fg.len(),
            bg.len(),
            "fp32 and bf16 grads differ in element count"
        );
        assert!(fg.iter().all(|v| v.is_finite()), "fp32 grad has non-finite");
        assert!(bg.iter().all(|v| v.is_finite()), "bf16 grad has non-finite");
        // Scale-relative mean error over the tensor (robust to a few small entries).
        let scale = fg.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
        let abs_err_sum = fg.iter().zip(&bg).map(|(f, b)| (f - b).abs()).sum::<f32>();
        // `.max(1)` keeps an empty tensor at 0.0 error instead of 0/0 = NaN.
        let mean_err: f32 = abs_err_sum / fg.len().max(1) as f32 / scale;
        // f32::max ignores NaN, so reject it explicitly rather than lose it.
        assert!(
            !mean_err.is_nan(),
            "bf16 vs fp32 grad error is NaN for a parameter tensor"
        );
        worst_param_mean = worst_param_mean.max(mean_err);
    }
    println!("bf16 vs fp32 grads: worst per-tensor scaled-mean err = {worst_param_mean:.3e}");
    assert!(
        worst_param_mean < 3e-2,
        "bf16 grads too far from fp32: {worst_param_mean:.3e}"
    );
}