270 lines
10 KiB
Rust
270 lines
10 KiB
Rust
// T15 GQA correctness gate. Real grouped-query attention (num_kv_heads <
// n_heads): K/V project to num_kv_heads·head_dim and are repeat_kv-broadcast to
// the full head set before the SDPA. This file pins three things:
//
// 1. GQA flash == GQA composed (forward logits + loss + EVERY param grad). The
//    repeat_kv broadcast feeds both SDPA paths unchanged, so they must agree; in
//    particular the wk/wv grads (which flow back through repeat_kv's group-sum)
//    must match. Parameterised over fp32 (tight tolerance) and bf16 (rounding band).
// 2. group == 1 (num_kv_heads == n_heads) is BIT-IDENTICAL to the pre-T15 MHA path
//    (a model with num_kv_heads explicitly == n_heads vs the default config):
//    forward logits + loss + every grad |Δ| = 0. This is the regression guard.
// 3. wk/wv really shrank to [dim, kv_dim] under GQA (shape check).
|
|
#![cfg(not(no_cuda))]
|
|
|
|
use xtrain_cuda::device;
|
|
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
|
|
use xtrain_tensor::{DType, Device};
|
|
|
|
// Deterministic pseudo-random fill: `n` values spread over [-scale, scale], driven by
// a 64-bit LCG (Knuth MMIX multiplier/increment) whose state is first scrambled from
// `seed`. Each sample uses the top 31 state bits mapped to [0, 1]; the f32 cast can
// round up to exactly 1.0, so `scale` itself is reachable.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const SEED_MUL: u64 = 2862933555777941757;
    const SEED_ADD: u64 = 3037000493;
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_ADD: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;

    let mut lcg = seed.wrapping_mul(SEED_MUL).wrapping_add(SEED_ADD);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_ADD);
        let unit = (lcg >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
|
|
|
|
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
|
|
let mut seed = 1u64;
|
|
let m = TinyTransformer::new(cfg, device, |shape| {
|
|
seed = seed.wrapping_add(1);
|
|
let n: usize = shape.iter().product();
|
|
if shape.len() == 1 {
|
|
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
|
} else {
|
|
fill(n, seed, 0.08)
|
|
}
|
|
});
|
|
m.with_compute_dtype(dtype).with_flash(flash)
|
|
}
|
|
|
|
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
|
|
t.to_dtype(DType::F32)
|
|
.to_device(Device::Cpu)
|
|
.as_slice::<f32>()
|
|
.to_vec()
|
|
}
|
|
|
|
// A real GQA config: 8 query heads, 2 kv heads → group 4. seq=40 > FA_TILE=32 so
|
|
// the flash online-softmax tile path is exercised too.
|
|
fn gqa_cfg() -> Config {
|
|
let mut cfg = Config::tiny();
|
|
cfg.vocab = 16;
|
|
cfg.n_layers = 3;
|
|
// tiny() is 2 heads; rebuild with 8 query / 2 kv heads keeping head_dim=16.
|
|
Config::from_arch(cfg.vocab, 8, cfg.head_dim, cfg.n_layers, cfg.ffn_hidden).with_kv_heads(2)
|
|
}
|
|
|
|
fn ids_targets(cfg: &Config, batch: usize, seq: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
|
|
let seqs = (0..batch)
|
|
.map(|b| {
|
|
(0..seq)
|
|
.map(|i| ((b * 7 + i * 3 + 1) % cfg.vocab) as i32)
|
|
.collect()
|
|
})
|
|
.collect();
|
|
let tgts = (0..batch)
|
|
.map(|b| {
|
|
(0..seq)
|
|
.map(|i| ((b * 5 + i * 2 + 2) % cfg.vocab) as i32)
|
|
.collect()
|
|
})
|
|
.collect();
|
|
(seqs, tgts)
|
|
}
|
|
|
|
#[allow(clippy::type_complexity)]
|
|
fn run_both(
|
|
cfg: Config,
|
|
dtype: DType,
|
|
) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
|
|
device::set_device(0).unwrap();
|
|
let device = Device::Cuda(0);
|
|
let (batch, seq) = (3usize, 40usize);
|
|
let (seqs, tgts) = ids_targets(&cfg, batch, seq);
|
|
let ids = batched_ids_tensor(&seqs, device);
|
|
let tgt = batched_ids_tensor(&tgts, device);
|
|
|
|
let off = build(cfg, device, dtype, false);
|
|
let off_logits = host(&off.forward_batched(&ids, batch).value());
|
|
let off_loss = off.loss_batched(&ids, &tgt, batch);
|
|
let off_loss_val = host(&off_loss.value())[0];
|
|
off_loss.backward();
|
|
let off_grads: Vec<Vec<f32>> = off
|
|
.params()
|
|
.iter()
|
|
.map(|p| host(&p.grad().expect("off grad")))
|
|
.collect();
|
|
|
|
let on = build(cfg, device, dtype, true);
|
|
let on_logits = host(&on.forward_batched(&ids, batch).value());
|
|
let on_loss = on.loss_batched(&ids, &tgt, batch);
|
|
let on_loss_val = host(&on_loss.value())[0];
|
|
on_loss.backward();
|
|
let on_grads: Vec<Vec<f32>> = on
|
|
.params()
|
|
.iter()
|
|
.map(|p| host(&p.grad().expect("on grad")))
|
|
.collect();
|
|
|
|
(
|
|
off_logits,
|
|
off_loss_val,
|
|
off_grads,
|
|
on_logits,
|
|
on_loss_val,
|
|
on_grads,
|
|
)
|
|
}
|
|
|
|
// GQA flash vs composed: both SDPA paths consume the same repeat_kv-broadcast K/V, so
// fp32 must agree up to reduction order (logits/loss rel < 1e-3, grads rel < 2e-2).
// Length checks stop `zip` from silently truncating a mismatch. The running max is
// NaN-propagating, because `f32::max` returns the non-NaN operand: with it, a NaN
// logit or grad would be dropped and the test would pass.
#[test]
fn gqa_flash_matches_composed_fp32() {
    // Like `f32::max`, but a NaN on either side poisons the result.
    fn nan_max(acc: f32, x: f32) -> f32 {
        if acc.is_nan() || x.is_nan() {
            f32::NAN
        } else {
            acc.max(x)
        }
    }

    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    let cfg = gqa_cfg();
    assert!(cfg.num_kv_heads < cfg.n_heads, "test must be real GQA");
    let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(cfg, DType::F32);
    assert_eq!(off_l.len(), on_l.len(), "[GQA F32] logits length mismatch");
    assert_eq!(off_g.len(), on_g.len(), "[GQA F32] param count mismatch");

    let logit_rel = off_l
        .iter()
        .zip(&on_l)
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
        .fold(0.0f32, nan_max);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!(
        "[GQA F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
         logits max rel {logit_rel:.2e}"
    );
    assert!(
        logit_rel < 1e-3,
        "[GQA F32] logits diverged: {logit_rel:.2e}"
    );
    assert!(loss_rel < 1e-3, "[GQA F32] loss diverged: {loss_rel:.2e}");

    let mut worst = 0.0f32;
    for (idx, (a_g, b_g)) in off_g.iter().zip(&on_g).enumerate() {
        assert_eq!(a_g.len(), b_g.len(), "[GQA F32] grad {idx} length mismatch");
        for (a, b) in a_g.iter().zip(b_g) {
            worst = nan_max(worst, (a - b).abs() / a.abs().max(1e-3));
        }
    }
    println!("[GQA F32] flash on/off grad max rel = {worst:.3e}");
    assert!(worst < 2e-2, "[GQA F32] grads diverged: {worst:.3e}");
}
|
|
|
|
// GQA flash vs composed in bf16. Tolerances are sized for bf16 rounding:
// - loss rel < 2e-2;
// - logits mean rel < 1e-2 and p99 rel < 5e-2, each relative to max(|ref|, 1);
// - per tensor, mean |Δgrad| divided by that tensor's max |ref grad| < 3e-2.
// NaNs are made to fail rather than be dropped. Sorting uses `total_cmp`, so NaN
// has a defined order instead of panicking in `partial_cmp().unwrap()`, and any NaN
// also makes `mean` NaN, which fails its assertion. The worst-tensor max is
// NaN-propagating, since `f32::max` silently returns the non-NaN operand. Length
// checks stop `zip` from truncating a mismatch.
#[test]
fn gqa_flash_matches_composed_bf16() {
    // Like `f32::max`, but a NaN on either side poisons the result.
    fn nan_max(acc: f32, x: f32) -> f32 {
        if acc.is_nan() || x.is_nan() {
            f32::NAN
        } else {
            acc.max(x)
        }
    }

    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(gqa_cfg(), DType::BF16);
    assert!(!off_l.is_empty(), "[GQA BF16] no logits produced");
    assert_eq!(off_l.len(), on_l.len(), "[GQA BF16] logits length mismatch");
    assert_eq!(off_g.len(), on_g.len(), "[GQA BF16] param count mismatch");

    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!("[GQA BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
    assert!(loss_rel < 2e-2, "[GQA BF16] loss diverged: {loss_rel:.3e}");

    let n = off_l.len();
    let mut rels: Vec<f32> = off_l
        .iter()
        .zip(&on_l)
        .map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
        .collect();
    rels.sort_by(f32::total_cmp);
    let mean: f32 = rels.iter().sum::<f32>() / n as f32;
    // Clamp so float rounding of n*0.99 can never index past the end.
    let p99 = rels[((n as f32 * 0.99) as usize).min(n - 1)];
    println!("[GQA BF16] logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
    assert!(
        mean < 1e-2,
        "[GQA BF16] logits mean rel too high: {mean:.3e}"
    );
    assert!(p99 < 5e-2, "[GQA BF16] logits p99 rel too high: {p99:.3e}");

    let mut worst = 0.0f32;
    for (idx, (a_g, b_g)) in off_g.iter().zip(&on_g).enumerate() {
        assert_eq!(a_g.len(), b_g.len(), "[GQA BF16] grad {idx} length mismatch");
        let scale = a_g.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
        let mean_err: f32 =
            a_g.iter().zip(b_g).map(|(f, b)| (f - b).abs()).sum::<f32>() / a_g.len() as f32 / scale;
        worst = nan_max(worst, mean_err);
    }
    println!("[GQA BF16] grads: worst per-tensor scaled-mean err = {worst:.3e}");
    assert!(worst < 3e-2, "[GQA BF16] grads diverged: {worst:.3e}");
}
|
|
|
|
// REGRESSION GUARD: num_kv_heads == n_heads (group 1) must be BIT-IDENTICAL to the
// pre-T15 MHA path. One model uses the default config (num_kv_heads == n_heads; the
// untouched path, where repeat_kv is not expected to be invoked). The other sets
// num_kv_heads = n_heads explicitly via `with_kv_heads`. Forward logits, loss and
// every grad must match to the bit. Both use the composed (flash-off) path, so this
// is exact equality, not a tolerance.
//
// The comparison is on raw f32 bit patterns. `==` on f32 treats +0.0 == -0.0 and
// reports two identical NaNs as unequal, so it is not a bit-identity check.
#[test]
fn gqa_group1_bit_identical_to_mha() {
    // Raw IEEE-754 bit patterns, so equality means bit-for-bit identity.
    fn bits(v: &[f32]) -> Vec<u32> {
        v.iter().map(|x| x.to_bits()).collect()
    }

    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // 4 heads, vocab 16, 3 layers; head_dim and ffn_hidden inherited from tiny().
    let tiny = Config::tiny();
    let base = Config::from_arch(16, 4, tiny.head_dim, 3, tiny.ffn_hidden);
    // `explicit` sets num_kv_heads = n_heads. This is already the default, but it
    // exercises the with_kv_heads path; the two configs must produce identical output.
    let explicit = base.with_kv_heads(base.n_heads);
    assert_eq!(base.num_kv_heads, explicit.num_kv_heads);

    let (batch, seq) = (2usize, 8usize);
    let (seqs, tgts) = ids_targets(&base, batch, seq);
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    let run = |cfg: Config| -> (Vec<f32>, f32, Vec<Vec<f32>>) {
        let m = build(cfg, device, DType::F32, false);
        let logits = host(&m.forward_batched(&ids, batch).value());
        let loss = m.loss_batched(&ids, &tgt, batch);
        let loss_v = host(&loss.value())[0];
        loss.backward();
        let grads = m
            .params()
            .iter()
            .map(|p| host(&p.grad().unwrap()))
            .collect();
        (logits, loss_v, grads)
    };
    let (la, sa, ga) = run(base);
    let (lb, sb, gb) = run(explicit);
    assert_eq!(bits(&la), bits(&lb), "group-1 logits must be bit-identical to MHA");
    assert_eq!(sa.to_bits(), sb.to_bits(), "group-1 loss must be bit-identical to MHA");
    assert_eq!(ga.len(), gb.len(), "group-1 param count must match MHA");
    for (a, b) in ga.iter().zip(&gb) {
        assert_eq!(bits(a), bits(b), "group-1 grad must be bit-identical to MHA");
    }
    println!("[GQA group1] bit-identical to MHA: logits + loss + all grads |Δ|=0");
}
|
|
|
|
// Under GQA, wk/wv must be [dim, kv_dim] (kv_dim = num_kv_heads·head_dim), while wq
// stays [dim, dim]. The `kv_dim < dim` guard stops the shape check from passing
// vacuously if `kv_dim()` ever regressed to returning `dim`; the projections must
// really have shrunk.
#[test]
fn gqa_kv_proj_shape() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let cfg = gqa_cfg();
    assert!(
        cfg.kv_dim() < cfg.dim,
        "GQA kv_dim ({}) must be smaller than dim ({})",
        cfg.kv_dim(),
        cfg.dim
    );
    let m = build(cfg, device, DType::F32, false);
    let p = m.params();
    // NOTE(review): hard-coded param order; confirm against TinyTransformer::params if
    // the layout changes. Order: embed[0], then block 0 = [attn_norm[1], wq[2], wk[3],
    // wv[4], q_norm[5], k_norm[6], wo[7], ...]
    let wq = p[2].value().shape().to_vec();
    let wk = p[3].value().shape().to_vec();
    let wv = p[4].value().shape().to_vec();
    assert_eq!(wq, vec![cfg.dim, cfg.dim], "wq must be [dim,dim]");
    assert_eq!(wk, vec![cfg.dim, cfg.kv_dim()], "wk must be [dim,kv_dim]");
    assert_eq!(wv, vec![cfg.dim, cfg.kv_dim()], "wv must be [dim,kv_dim]");
    println!(
        "[GQA shapes] wq {:?} wk {:?} wv {:?} (kv_dim {})",
        wq,
        wk,
        wv,
        cfg.kv_dim()
    );
}
|