Files

210 lines
7.5 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// T14 flash-attention correctness gate: the fused flash SDPA core must match the
// composed T10 path (cublasSgemmStridedBatched×2 + causal-softmax kernel) in
// forward logits, loss, AND every parameter gradient — flash is the SAME SDPA
// math (online softmax never materializes the [bh,S,S] scores), so it differs
// from composed only by reduction order (in-kernel fp32 FMA vs cuBLAS, and the
// dK/dV atomicAdd order in backward). This test makes that a closed on-GPU loop:
//
// build two identical models (same init), one with `--flash` on, one off, run
// the SAME batched loss + backward on both, and assert
// 1. the forward logits match within tolerance
// 2. the loss matches
// 3. EVERY parameter's grad matches within tolerance
//
// Parameterised over fp32 AND bf16 (T12). bf16 just adds the bf16 rounding band on
// top — flash's bf16 path upcasts Q/K/V to fp32 for the kernel exactly like the
// composed path's fp32 softmax, so the two are still the same softmax numerics.
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
/// Produces `n` deterministic pseudo-random values in `[-scale, scale]`.
///
/// The generator is a 64-bit linear congruential sequence: `seed` is first
/// scrambled with one multiply-add, then each output advances the state by one
/// more multiply-add and maps the top 31 bits of the state onto the range.
/// Identical `(n, seed, scale)` always yields an identical vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const STEP_MUL: u64 = 6364136223846793005;
    const STEP_ADD: u64 = 1442695040888963407;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(STEP_MUL).wrapping_add(STEP_ADD);
        // Top 31 bits scaled to roughly [0, 1], then recentred to [-1, 1].
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
/// Constructs a `TinyTransformer` with a deterministic initialisation, then
/// applies the requested compute dtype and flash setting.
///
/// Every call starts the per-tensor seed counter at the same value, so two
/// calls with the same `cfg` hand the initialiser the same seed sequence and
/// therefore produce the same initial values regardless of `flash`.
///
/// Initialisation rule, chosen per requested `shape`:
/// * rank-1 shapes: `1.0` plus noise of scale `0.02`;
/// * every other shape: zero-centred noise of scale `0.08`.
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
    let mut tensor_seed = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        // Bump before use so the first tensor is seeded with 2, the next with 3, ...
        tensor_seed = tensor_seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        let is_rank1 = shape.len() == 1;
        let scale = if is_rank1 { 0.02 } else { 0.08 };
        let mut values = fill(numel, tensor_seed, scale);
        if is_rank1 {
            for v in values.iter_mut() {
                *v += 1.0;
            }
        }
        values
    })
    .with_compute_dtype(dtype)
    .with_flash(flash)
}
/// Copies a tensor's contents into a host-side `Vec<f32>`.
///
/// The tensor is converted to `DType::F32`, moved to `Device::Cpu`, and its
/// `f32` slice is cloned into an owned vector.
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let as_f32 = t.to_dtype(DType::F32);
    let on_cpu = as_f32.to_device(Device::Cpu);
    on_cpu.as_slice::<f32>().to_vec()
}
/// fp32 gate: same SDPA math, differs only by reduction order, so a tight
/// per-element check is used.
///
/// Compares the flash-off and flash-on results returned by
/// `run_both(DType::F32)`:
/// 1. max per-element relative error of the logits (denominator floored at
///    `1e-4`) must be below `logit_tol`;
/// 2. relative error of the loss must also be below `logit_tol`;
/// 3. max per-element relative error over every parameter gradient
///    (denominator floored at `1e-3`) must be below `grad_tol`.
///
/// # Panics
/// Panics if any check fails, if the two runs disagree on the number or length
/// of logits / gradients, or if any compared error is NaN.
fn run_fp32(logit_tol: f32, grad_tol: f32) {
    // `f32::max` discards a NaN operand, so a NaN error would silently vanish
    // from a running maximum and the gate would pass. Once a NaN is seen this
    // variant keeps it, and the later `<` comparison then fails the assertion.
    fn nan_sticky_max(acc: f32, x: f32) -> f32 {
        if acc.is_nan() || x.is_nan() {
            f32::NAN
        } else {
            acc.max(x)
        }
    }

    let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::F32);

    // `zip` stops at the shorter side; a shape mismatch must not be hidden.
    assert_eq!(
        off_logits.len(),
        on_logits.len(),
        "[F32] logits length mismatch between flash off/on"
    );
    let logit_rel = off_logits
        .iter()
        .zip(&on_logits)
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
        .fold(0.0f32, nan_sticky_max);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!(
        "[F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
         logits max rel {logit_rel:.2e}"
    );
    assert!(
        logit_rel < logit_tol,
        "[F32] logits diverged: {logit_rel:.2e}"
    );
    // The loss shares the logit tolerance.
    assert!(loss_rel < logit_tol, "[F32] loss diverged: {loss_rel:.2e}");

    assert_eq!(
        off_grads.len(),
        on_grads.len(),
        "[F32] parameter count mismatch between flash off/on"
    );
    let mut max_grad_rel = 0.0f32;
    for (idx, (off_g, on_g)) in off_grads.iter().zip(&on_grads).enumerate() {
        assert_eq!(
            off_g.len(),
            on_g.len(),
            "[F32] grad length mismatch for parameter {idx}"
        );
        for (a, b) in off_g.iter().zip(on_g) {
            max_grad_rel = nan_sticky_max(max_grad_rel, (a - b).abs() / a.abs().max(1e-3));
        }
    }
    println!("[F32] flash on/off: grad max rel err = {max_grad_rel:.3e}");
    assert!(
        max_grad_rel < grad_tol,
        "[F32] flash grads diverged from composed: {max_grad_rel:.3e}"
    );
}
/// bf16 gate: ~2-3 decimal digits of precision, so a robust comparison is used
/// instead of per-element max-rel (which blows up on near-zero bf16 logits).
/// This is the same convention as the repo's bf16.rs gate.
///
/// Compares the flash-off and flash-on results returned by
/// `run_both(DType::BF16)`:
/// 1. relative loss error below `2e-2`;
/// 2. logits: per-element error relative to `max(|ref|, 1.0)`, with mean below
///    `1e-2` and 99th percentile below `5e-2`;
/// 3. grads: for each parameter, mean absolute error divided by the reference
///    tensor's max magnitude (floored at `1e-6`); the worst tensor must be
///    below `3e-2`.
///
/// # Panics
/// Panics if any check fails, if the two runs disagree on the number or length
/// of logits / gradients, if there are no logits, or if any compared error is
/// NaN.
fn run_bf16() {
    // `f32::max` discards a NaN operand; keep NaN once seen so a NaN gradient
    // error fails the final `<` assertion instead of being dropped.
    fn nan_sticky_max(acc: f32, x: f32) -> f32 {
        if acc.is_nan() || x.is_nan() {
            f32::NAN
        } else {
            acc.max(x)
        }
    }

    let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::BF16);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!("[BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
    assert!(loss_rel < 2e-2, "[BF16] loss diverged: {loss_rel:.3e}");

    // `zip` stops at the shorter side; a shape mismatch must not be hidden.
    assert_eq!(
        off_logits.len(),
        on_logits.len(),
        "[BF16] logits length mismatch between flash off/on"
    );
    let n = off_logits.len();
    // The percentile lookup below indexes into `rels`, so it needs an element.
    assert!(n > 0, "[BF16] no logits to compare");
    let mut rels: Vec<f32> = off_logits
        .iter()
        .zip(&on_logits)
        .map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
        .collect();
    // Fail with a clear message rather than an opaque unwrap panic in the sort.
    assert!(
        rels.iter().all(|r| !r.is_nan()),
        "[BF16] NaN in logits comparison"
    );
    rels.sort_by(|a, b| a.partial_cmp(b).expect("NaN ruled out above"));
    let p99 = rels[(n as f32 * 0.99) as usize];
    let mean: f32 = rels.iter().sum::<f32>() / n as f32;
    println!("[BF16] flash on/off logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
    assert!(mean < 1e-2, "[BF16] logits mean rel too high: {mean:.3e}");
    assert!(p99 < 5e-2, "[BF16] logits p99 rel too high: {p99:.3e}");

    assert_eq!(
        off_grads.len(),
        on_grads.len(),
        "[BF16] parameter count mismatch between flash off/on"
    );
    let mut worst = 0.0f32;
    for (idx, (off_g, on_g)) in off_grads.iter().zip(&on_grads).enumerate() {
        assert_eq!(
            off_g.len(),
            on_g.len(),
            "[BF16] grad length mismatch for parameter {idx}"
        );
        // An empty tensor has nothing to compare and would yield 0/0 = NaN.
        if off_g.is_empty() {
            continue;
        }
        // Per-tensor scale: largest reference magnitude, floored to avoid /0.
        let scale = off_g
            .iter()
            .map(|v| v.abs())
            .fold(0.0f32, f32::max)
            .max(1e-6);
        let mean_err: f32 = off_g
            .iter()
            .zip(on_g)
            .map(|(f, b)| (f - b).abs())
            .sum::<f32>()
            / off_g.len() as f32
            / scale;
        worst = nan_sticky_max(worst, mean_err);
    }
    println!("[BF16] flash on/off grads: worst per-tensor scaled-mean err = {worst:.3e}");
    assert!(worst < 3e-2, "[BF16] flash grads diverged: {worst:.3e}");
}
/// Runs the same batched forward + loss + backward on two identically
/// initialised models, flash off then flash on, and returns
/// `(off_logits, off_loss, off_grads, on_logits, on_loss, on_grads)` as host
/// `f32` data. Each `*_grads` holds one vector per entry of `params()`.
///
/// # Panics
/// Panics if no CUDA device is reported, if selecting device 0 fails, or if
/// any parameter has no gradient after `backward()`.
// The six-field tuple trips clippy's type-complexity lint; it is kept so both
// callers can destructure everything in a single `let`.
#[allow(clippy::type_complexity)]
fn run_both(dtype: DType) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
    let n_devices = device::device_count().unwrap();
    assert!(n_devices > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let gpu = Device::Cuda(0);

    // seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised.
    let mut cfg = Config::tiny();
    cfg.vocab = 16;
    cfg.n_layers = 4;
    let batch = 3usize;
    let seq = 40usize;
    let vocab = cfg.vocab;

    // Deterministic token grid: entry (b, i) = (b*row_mul + i*col_mul + offset) % vocab.
    let token_rows = |row_mul: usize, col_mul: usize, offset: usize| -> Vec<Vec<i32>> {
        (0..batch)
            .map(|b| {
                (0..seq)
                    .map(|i| ((b * row_mul + i * col_mul + offset) % vocab) as i32)
                    .collect()
            })
            .collect()
    };
    let seqs = token_rows(7, 3, 1);
    let tgts = token_rows(5, 2, 2);
    let ids = batched_ids_tensor(&seqs, gpu);
    let tgt = batched_ids_tensor(&tgts, gpu);

    // --- flash OFF (composed reference) ---
    let composed = build(cfg, gpu, dtype, false);
    let off_logits = host(&composed.forward_batched(&ids, batch).value());
    let composed_loss = composed.loss_batched(&ids, &tgt, batch);
    let off_loss_val = host(&composed_loss.value())[0];
    composed_loss.backward();
    let mut off_grads = Vec::new();
    for p in composed.params().iter() {
        off_grads.push(host(&p.grad().expect("off grad")));
    }

    // --- flash ON ---
    let fused = build(cfg, gpu, dtype, true);
    let on_logits = host(&fused.forward_batched(&ids, batch).value());
    let fused_loss = fused.loss_batched(&ids, &tgt, batch);
    let on_loss_val = host(&fused_loss.value())[0];
    fused_loss.backward();
    let mut on_grads = Vec::new();
    for p in fused.params().iter() {
        on_grads.push(host(&p.grad().expect("on grad")));
    }

    (
        off_logits,
        off_loss_val,
        off_grads,
        on_logits,
        on_loss_val,
        on_grads,
    )
}
/// fp32 parity between the flash and composed SDPA paths.
#[test]
fn flash_matches_composed_fp32() {
    // fp32: same SDPA math, differs only by reduction order (in-kernel fp32 FMA vs
    // cuBLAS, dK/dV atomicAdd order). Tight per-element check, not bit-exact.
    const LOGIT_TOL: f32 = 1e-3;
    const GRAD_TOL: f32 = 2e-2;
    run_fp32(LOGIT_TOL, GRAD_TOL);
}
/// bf16 parity between the flash and composed SDPA paths.
#[test]
fn flash_matches_composed_bf16() {
    // bf16 (T12 composition): the bf16 rounding band sits on top of the
    // fp32-softmax core, so the comparison is the robust mean / p99 /
    // scaled-mean one rather than a per-element maximum. All thresholds live
    // inside `run_bf16`.
    run_bf16();
}