210 lines
7.5 KiB
Rust
210 lines
7.5 KiB
Rust
// T14 flash-attention correctness gate: the fused flash SDPA core must match the
// composed T10 path (cublasSgemmStridedBatched×2 + causal-softmax kernel) in
// forward logits, loss, AND every parameter gradient. Flash computes the SAME SDPA
// math (its online softmax simply never materializes the [bh,S,S] scores), so it
// differs from the composed path only in reduction order (in-kernel fp32 FMA vs
// cuBLAS, and the dK/dV atomicAdd order in backward). This test makes that a
// closed on-GPU loop:
//
// build two identical models (same init), one with `--flash` on and one off, run
// the SAME batched loss + backward on both, and assert that
//   1. the forward logits match within tolerance,
//   2. the loss matches,
//   3. EVERY parameter's grad matches within tolerance.
//
// Parameterised over fp32 AND bf16 (T12). bf16 only adds the bf16 rounding band on
// top: flash's bf16 path upcasts Q/K/V to fp32 for the kernel, just as the composed
// path runs its softmax in fp32, so both still share the same softmax numerics.
#![cfg(not(no_cuda))]
|
||
|
||
use xtrain_cuda::device;
|
||
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
||
use xtrain_tensor::{DType, Device};
|
||
|
||
// Deterministic pseudo-random initialiser: `n` values in [-scale, scale], fully
// determined by `seed`.
//
// The seed is scrambled once, then a 64-bit LCG is stepped per element; the top
// 31 bits of each state are mapped to [0, 1], centred on 0 and stretched by
// `scale`. Identical (n, seed, scale) always yield an identical vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut lcg: u64 = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let denom = (1u64 << 31) as f32;

    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Top 31 bits -> [0, 1]; then shift to [-0.5, 0.5] and scale.
        let unit = (lcg >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|
||
|
||
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
||
let mut seed = 1u64;
|
||
let m = TinyTransformer::new(cfg, device, |shape| {
|
||
seed = seed.wrapping_add(1);
|
||
let n: usize = shape.iter().product();
|
||
if shape.len() == 1 {
|
||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||
} else {
|
||
fill(n, seed, 0.08)
|
||
}
|
||
});
|
||
m.with_compute_dtype(dtype).with_flash(flash)
|
||
}
|
||
|
||
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
||
t.to_dtype(DType::F32)
|
||
.to_device(Device::Cpu)
|
||
.as_slice::<f32>()
|
||
.to_vec()
|
||
}
|
||
|
||
// fp32: same SDPA math, differs only by reduction order → tight per-element check.
|
||
fn run_fp32(logit_tol: f32, grad_tol: f32) {
|
||
let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::F32);
|
||
|
||
let logit_rel = off_logits
|
||
.iter()
|
||
.zip(&on_logits)
|
||
.map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
|
||
.fold(0.0f32, f32::max);
|
||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||
println!(
|
||
"[F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
|
||
logits max rel {logit_rel:.2e}"
|
||
);
|
||
assert!(
|
||
logit_rel < logit_tol,
|
||
"[F32] logits diverged: {logit_rel:.2e}"
|
||
);
|
||
assert!(loss_rel < logit_tol, "[F32] loss diverged: {loss_rel:.2e}");
|
||
|
||
let mut max_grad_rel = 0.0f32;
|
||
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
|
||
for (a, b) in off_g.iter().zip(on_g) {
|
||
max_grad_rel = max_grad_rel.max((a - b).abs() / a.abs().max(1e-3));
|
||
}
|
||
}
|
||
println!("[F32] flash on/off: grad max rel err = {max_grad_rel:.3e}");
|
||
assert!(
|
||
max_grad_rel < grad_tol,
|
||
"[F32] flash grads diverged from composed: {max_grad_rel:.3e}"
|
||
);
|
||
}
|
||
|
||
// bf16: ~2-3 decimal digits → robust comparison (mean + p99 with abs().max(1.0)
|
||
// for logits, per-tensor scale-relative mean for grads), the same convention as
|
||
// the repo's bf16.rs gate (per-element max-rel blows up on near-zero bf16 logits).
|
||
fn run_bf16() {
|
||
let (off_logits, off_loss, off_grads, on_logits, on_loss, on_grads) = run_both(DType::BF16);
|
||
|
||
let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
|
||
println!("[BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
|
||
assert!(loss_rel < 2e-2, "[BF16] loss diverged: {loss_rel:.3e}");
|
||
|
||
let n = off_logits.len();
|
||
let mut rels: Vec<f32> = off_logits
|
||
.iter()
|
||
.zip(&on_logits)
|
||
.map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
|
||
.collect();
|
||
rels.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||
let p99 = rels[(n as f32 * 0.99) as usize];
|
||
let mean: f32 = rels.iter().sum::<f32>() / n as f32;
|
||
println!("[BF16] flash on/off logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
|
||
assert!(mean < 1e-2, "[BF16] logits mean rel too high: {mean:.3e}");
|
||
assert!(p99 < 5e-2, "[BF16] logits p99 rel too high: {p99:.3e}");
|
||
|
||
let mut worst = 0.0f32;
|
||
for (off_g, on_g) in off_grads.iter().zip(&on_grads) {
|
||
let scale = off_g
|
||
.iter()
|
||
.map(|v| v.abs())
|
||
.fold(0.0f32, f32::max)
|
||
.max(1e-6);
|
||
let mean_err: f32 = off_g
|
||
.iter()
|
||
.zip(on_g)
|
||
.map(|(f, b)| (f - b).abs())
|
||
.sum::<f32>()
|
||
/ off_g.len() as f32
|
||
/ scale;
|
||
worst = worst.max(mean_err);
|
||
}
|
||
println!("[BF16] flash on/off grads: worst per-tensor scaled-mean err = {worst:.3e}");
|
||
assert!(worst < 3e-2, "[BF16] flash grads diverged: {worst:.3e}");
|
||
}
|
||
|
||
#[allow(clippy::type_complexity)]
|
||
fn run_both(dtype: DType) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
|
||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||
device::set_device(0).unwrap();
|
||
let device = Device::Cuda(0);
|
||
|
||
// seq=40 > FA_TILE=32 so the online-softmax tile-rescale path is exercised.
|
||
let mut cfg = Config::tiny();
|
||
cfg.vocab = 16;
|
||
cfg.n_layers = 4;
|
||
let batch = 3usize;
|
||
let seq = 40usize;
|
||
let seqs: Vec<Vec<i32>> = (0..batch)
|
||
.map(|b| {
|
||
(0..seq)
|
||
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
||
.collect()
|
||
})
|
||
.collect();
|
||
let tgts: Vec<Vec<i32>> = (0..batch)
|
||
.map(|b| {
|
||
(0..seq)
|
||
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
||
.collect()
|
||
})
|
||
.collect();
|
||
let ids = batched_ids_tensor(&seqs, device);
|
||
let tgt = batched_ids_tensor(&tgts, device);
|
||
|
||
// --- flash OFF (composed reference) ---
|
||
let off = build(cfg, device, dtype, false);
|
||
let off_logits = host(&off.forward_batched(&ids, batch).value());
|
||
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
||
let off_loss_val = host(&off_loss.value())[0];
|
||
off_loss.backward();
|
||
let off_grads: Vec<Vec<f32>> = off
|
||
.params()
|
||
.iter()
|
||
.map(|p| host(&p.grad().expect("off grad")))
|
||
.collect();
|
||
|
||
// --- flash ON ---
|
||
let on = build(cfg, device, dtype, true);
|
||
let on_logits = host(&on.forward_batched(&ids, batch).value());
|
||
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
||
let on_loss_val = host(&on_loss.value())[0];
|
||
on_loss.backward();
|
||
let on_grads: Vec<Vec<f32>> = on
|
||
.params()
|
||
.iter()
|
||
.map(|p| host(&p.grad().expect("on grad")))
|
||
.collect();
|
||
|
||
(
|
||
off_logits,
|
||
off_loss_val,
|
||
off_grads,
|
||
on_logits,
|
||
on_loss_val,
|
||
on_grads,
|
||
)
|
||
}
|
||
|
||
#[test]
fn flash_matches_composed_fp32() {
    // In fp32, flash and composed SDPA evaluate identical math; only the summation
    // order differs (fused-kernel FMA vs cuBLAS, dK/dV atomicAdd ordering in
    // backward). Hence a tight per-element comparison rather than a bit-exact one.
    const LOGIT_TOL: f32 = 1e-3;
    const GRAD_TOL: f32 = 2e-2;
    run_fp32(LOGIT_TOL, GRAD_TOL);
}
|
||
|
||
#[test]
fn flash_matches_composed_bf16() {
    // The bf16 (T12) composition layers bf16 rounding on top of the fp32-softmax
    // core. The comparison is therefore the robust mean/p99/scaled-mean kind from
    // the repo's bf16 convention; its thresholds are fixed inside `run_bf16`.
    run_bf16();
}
|