Files
xtrain/crates/xtrain-model/tests/gqa.rs
Gahow Wang 830d06ad01 gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:37:37 +08:00

269 lines
10 KiB
Rust

// T15 GQA correctness gate. Real grouped-query attention (num_kv_heads <
// num_heads): K/V project to num_kv_heads·head_dim and are repeat_kv-broadcast to
// the full head set before the SDPA. This test pins three things:
//
// 1. GQA flash == GQA composed (forward logits + loss + EVERY param grad) — the
// repeat_kv broadcast feeds both SDPA paths unchanged, so they must agree; in
// particular the wk/wv grads (which flow back through repeat_kv's group-sum)
// must match. Parameterised over fp32 (tight) and bf16 (rounding band).
// 2. group==1 (num_kv_heads == n_heads) is BIT-IDENTICAL to the pre-T15 MHA path
// (a model with num_kv_heads explicitly == n_heads vs the default config):
// forward logits + every grad |Δ|=0. The regression guard.
// 3. wk/wv really shrank to [dim, kv_dim] under GQA (shape check).
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
/// Deterministic pseudo-random fill: `n` values spread over `[-scale, scale]`.
///
/// The seed is scrambled once. A 64-bit LCG with Knuth's MMIX constants is then
/// stepped once per element. The top 31 bits of each state become a sample in
/// `[0, 1]`, which is recentred and scaled. This gives reproducible weights without
/// an RNG crate.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;
    let denom = (1u64 << 31) as f32;
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        let unit = (lcg >> 33) as f32 / denom;
        out.push((unit - 0.5) * 2.0 * scale);
    }
    out
}
/// Builds a `TinyTransformer` with deterministic weights and applies the requested
/// compute dtype and flash/composed SDPA switch.
///
/// Each param tensor gets its own seed (2, 3, 4, ...) in the order `TinyTransformer::new`
/// requests them. Two builds from the same config therefore start from identical
/// weights, which the flash-vs-composed comparisons rely on.
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
    let mut param_seed = 1u64;
    TinyTransformer::new(cfg, device, |shape| {
        param_seed = param_seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        match shape.len() {
            // Rank-1 params (presumably norm gains — confirm) are centred on 1.0.
            1 => fill(numel, param_seed, 0.02)
                .into_iter()
                .map(|v| v + 1.0)
                .collect(),
            // Everything else is small noise around 0.
            _ => fill(numel, param_seed, 0.08),
        }
    })
    .with_compute_dtype(dtype)
    .with_flash(flash)
}
/// Copies a tensor's contents to a host `Vec<f32>`. The tensor is cast to F32 and
/// moved to the CPU first, so bf16 and device tensors compare uniformly.
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let as_f32 = t.to_dtype(DType::F32);
    let on_cpu = as_f32.to_device(Device::Cpu);
    on_cpu.as_slice::<f32>().to_vec()
}
// Real GQA config: 8 query heads sharing 2 kv heads (group of 4), 16-token vocab,
// 3 layers. head_dim and ffn_hidden are inherited from Config::tiny().
// run_both uses seq=40, which exceeds the flash tile (FA_TILE=32 — confirm against
// the kernel), so the flash online-softmax multi-tile path is exercised as well.
fn gqa_cfg() -> Config {
    let mut arch = Config::tiny();
    arch.vocab = 16;
    arch.n_layers = 3;
    let query_heads = 8;
    let kv_heads = 2;
    Config::from_arch(
        arch.vocab,
        query_heads,
        arch.head_dim,
        arch.n_layers,
        arch.ffn_hidden,
    )
    .with_kv_heads(kv_heads)
}
/// Deterministic `(ids, targets)` for a `batch × seq` run.
///
/// Both are affine patterns in (row, position) reduced mod `cfg.vocab`, so every
/// token is in-vocab.
fn ids_targets(cfg: &Config, batch: usize, seq: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
    let vocab = cfg.vocab;
    // Row `b`, position `i` → (b·row_mul + i·pos_mul + offset) mod vocab.
    let affine = |row_mul: usize, pos_mul: usize, offset: usize| -> Vec<Vec<i32>> {
        (0..batch)
            .map(|b| {
                (0..seq)
                    .map(|i| ((b * row_mul + i * pos_mul + offset) % vocab) as i32)
                    .collect()
            })
            .collect()
    };
    (affine(7, 3, 1), affine(5, 2, 2))
}
/// Builds the model twice from `cfg` and runs one forward, loss and backward pass on
/// each, over the same fixed batch (3 × 40 tokens, CUDA device 0). The first model
/// uses composed SDPA (`flash = false`), the second uses flash (`flash = true`).
///
/// Returns `(off_logits, off_loss, off_grads, on_logits, on_loss, on_grads)`. Each is
/// a host-side f32 copy: the logits, the scalar loss, and every param grad in
/// `params()` order.
///
/// # Panics
/// If CUDA device 0 cannot be selected, or if a param has no grad after backward.
// Every caller destructures this six-tuple positionally. Keeping the tuple avoids an
// interface change, hence the allow.
#[allow(clippy::type_complexity)]
fn run_both(
    cfg: Config,
    dtype: DType,
) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let (batch, seq) = (3usize, 40usize);
    let (seqs, tgts) = ids_targets(&cfg, batch, seq);
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);
    // One full pass for a given SDPA path. Both paths share this closure, so they
    // cannot drift apart (different inputs, a forgotten backward, ...).
    let run = |flash: bool, grad_msg: &str| -> (Vec<f32>, f32, Vec<Vec<f32>>) {
        let model = build(cfg, device, dtype, flash);
        let logits = host(&model.forward_batched(&ids, batch).value());
        let loss = model.loss_batched(&ids, &tgt, batch);
        let loss_val = host(&loss.value())[0];
        loss.backward();
        let grads = model
            .params()
            .iter()
            .map(|p| host(&p.grad().expect(grad_msg)))
            .collect();
        (logits, loss_val, grads)
    };
    let (off_logits, off_loss_val, off_grads) = run(false, "off grad");
    let (on_logits, on_loss_val, on_grads) = run(true, "on grad");
    (
        off_logits,
        off_loss_val,
        off_grads,
        on_logits,
        on_loss_val,
        on_grads,
    )
}
// GQA flash vs composed: same SDPA math on the same repeat_kv-broadcast K/V, so fp32
// agrees up to reduction order. Lengths and finiteness are checked before the
// relative-error folds for two reasons:
// - `f32::max` discards NaN, so a NaN logit or grad on either path would otherwise be
//   skipped and the gate would pass.
// - `zip()` truncates silently on a length or param-count mismatch.
#[test]
fn gqa_flash_matches_composed_fp32() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    let cfg = gqa_cfg();
    assert!(cfg.num_kv_heads < cfg.n_heads, "test must be real GQA");
    let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(cfg, DType::F32);

    // Structural guards: same number of logits, params and per-param grad elements.
    assert_eq!(off_l.len(), on_l.len(), "[GQA F32] logits length mismatch");
    assert_eq!(off_g.len(), on_g.len(), "[GQA F32] param count mismatch");
    for (i, (a_g, b_g)) in off_g.iter().zip(&on_g).enumerate() {
        assert_eq!(a_g.len(), b_g.len(), "[GQA F32] grad #{i} length mismatch");
    }
    // Finiteness guards: the max-rel folds below would otherwise ignore NaN.
    assert!(
        off_loss.is_finite() && on_loss.is_finite(),
        "[GQA F32] non-finite loss: {off_loss}/{on_loss}"
    );
    assert!(
        off_l.iter().chain(&on_l).all(|v| v.is_finite()),
        "[GQA F32] non-finite logits"
    );
    assert!(
        off_g.iter().chain(&on_g).flatten().all(|v| v.is_finite()),
        "[GQA F32] non-finite grads"
    );

    let logit_rel = off_l
        .iter()
        .zip(&on_l)
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
        .fold(0.0f32, f32::max);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!(
        "[GQA F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
logits max rel {logit_rel:.2e}"
    );
    assert!(
        logit_rel < 1e-3,
        "[GQA F32] logits diverged: {logit_rel:.2e}"
    );
    assert!(loss_rel < 1e-3, "[GQA F32] loss diverged: {loss_rel:.2e}");
    // Per-element relative error, floored at 1e-3 so near-zero grads don't dominate.
    let mut worst = 0.0f32;
    for (a_g, b_g) in off_g.iter().zip(&on_g) {
        for (a, b) in a_g.iter().zip(b_g) {
            worst = worst.max((a - b).abs() / a.abs().max(1e-3));
        }
    }
    println!("[GQA F32] flash on/off grad max rel = {worst:.3e}");
    assert!(worst < 2e-2, "[GQA F32] grads diverged: {worst:.3e}");
}
// GQA flash vs composed in bf16. Agreement is bounded by bf16 rounding, so logits are
// judged on mean and p99 relative error, and grads on per-tensor scaled mean error.
// Lengths and finiteness are checked first:
// - `f32::max` discards NaN, so the grad fold would skip a NaN grad.
// - `zip()` truncates on a length mismatch.
#[test]
fn gqa_flash_matches_composed_bf16() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    let cfg = gqa_cfg();
    assert!(cfg.num_kv_heads < cfg.n_heads, "test must be real GQA");
    let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(cfg, DType::BF16);

    // Structural guards: same number of logits, params and per-param grad elements.
    assert_eq!(off_l.len(), on_l.len(), "[GQA BF16] logits length mismatch");
    assert!(!off_l.is_empty(), "[GQA BF16] no logits");
    assert_eq!(off_g.len(), on_g.len(), "[GQA BF16] param count mismatch");
    for (i, (a_g, b_g)) in off_g.iter().zip(&on_g).enumerate() {
        assert_eq!(a_g.len(), b_g.len(), "[GQA BF16] grad #{i} length mismatch");
    }
    // Finiteness guards: keeps the sort below panic-free and the max folds honest.
    assert!(
        off_loss.is_finite() && on_loss.is_finite(),
        "[GQA BF16] non-finite loss: {off_loss}/{on_loss}"
    );
    assert!(
        off_l.iter().chain(&on_l).all(|v| v.is_finite()),
        "[GQA BF16] non-finite logits"
    );
    assert!(
        off_g.iter().chain(&on_g).flatten().all(|v| v.is_finite()),
        "[GQA BF16] non-finite grads"
    );

    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!("[GQA BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
    assert!(loss_rel < 2e-2, "[GQA BF16] loss diverged: {loss_rel:.3e}");
    let n = off_l.len();
    let mut rels: Vec<f32> = off_l
        .iter()
        .zip(&on_l)
        .map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
        .collect();
    rels.sort_by(|a, b| a.partial_cmp(b).expect("rels are finite"));
    let mean: f32 = rels.iter().sum::<f32>() / n as f32;
    let p99 = rels[(n as f32 * 0.99) as usize];
    println!("[GQA BF16] logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
    assert!(
        mean < 1e-2,
        "[GQA BF16] logits mean rel too high: {mean:.3e}"
    );
    assert!(p99 < 5e-2, "[GQA BF16] logits p99 rel too high: {p99:.3e}");
    // Per tensor: mean |Δ| divided by that tensor's max |grad| (composed path).
    let mut worst = 0.0f32;
    for (a_g, b_g) in off_g.iter().zip(&on_g) {
        let scale = a_g.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
        let mean_err: f32 =
            a_g.iter().zip(b_g).map(|(f, b)| (f - b).abs()).sum::<f32>() / a_g.len() as f32 / scale;
        worst = worst.max(mean_err);
    }
    println!("[GQA BF16] grads: worst per-tensor scaled-mean err = {worst:.3e}");
    assert!(worst < 3e-2, "[GQA BF16] grads diverged: {worst:.3e}");
}
// REGRESSION GUARD for group 1 (num_kv_heads == n_heads). Two configs are compared:
// - `base` comes from `Config::from_arch` (4 heads) with its default num_kv_heads.
// - `explicit` re-applies `with_kv_heads(base.n_heads)`.
// Both report the same num_kv_heads (asserted below) and run the composed path, so
// logits, loss and every grad must agree to the BIT. This pins two things: that
// `with_kv_heads(n_heads)` is a no-op, and that the group-1 path is deterministic.
// (Per the T15 design, group 1 is the MHA path with repeat_kv as identity — confirm
// in attention().)
// Values are compared with `to_bits()` rather than `==`, because `==` would accept
// 0.0 vs -0.0. Non-finite values are rejected explicitly.
#[test]
fn gqa_group1_bit_identical_to_mha() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let mut base = Config::tiny();
    base.vocab = 16;
    base.n_layers = 3;
    let base = Config::from_arch(base.vocab, 4, base.head_dim, base.n_layers, base.ffn_hidden);
    // `explicit` sets num_kv_heads = n_heads (already the default, but exercises the
    // with_kv_heads path); they are the same config → must produce identical output.
    let explicit = base.with_kv_heads(base.n_heads);
    assert_eq!(base.num_kv_heads, explicit.num_kv_heads);
    let (batch, seq) = (2usize, 8usize);
    let (seqs, tgts) = ids_targets(&base, batch, seq);
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);
    // One composed-path pass: host copies of logits, scalar loss and all param grads.
    let run = |cfg: Config| -> (Vec<f32>, f32, Vec<Vec<f32>>) {
        let m = build(cfg, device, DType::F32, false);
        let logits = host(&m.forward_batched(&ids, batch).value());
        let loss = m.loss_batched(&ids, &tgt, batch);
        let loss_v = host(&loss.value())[0];
        loss.backward();
        let grads = m
            .params()
            .iter()
            .map(|p| host(&p.grad().unwrap()))
            .collect();
        (logits, loss_v, grads)
    };
    // Bitwise equality of two buffers. On failure it reports the first differing
    // element instead of dumping both (possibly large) vectors.
    let assert_bits_eq = |what: &str, a: &[f32], b: &[f32]| {
        assert_eq!(a.len(), b.len(), "{what}: length mismatch");
        assert!(
            a.iter().chain(b).all(|v| v.is_finite()),
            "{what}: non-finite value"
        );
        if let Some(i) = a.iter().zip(b).position(|(x, y)| x.to_bits() != y.to_bits()) {
            panic!(
                "{what} must be bit-identical to MHA: first mismatch at [{i}]: {} vs {}",
                a[i], b[i]
            );
        }
    };
    let (la, sa, ga) = run(base);
    let (lb, sb, gb) = run(explicit);
    assert_bits_eq("group-1 logits", &la, &lb);
    assert_bits_eq("group-1 loss", &[sa], &[sb]);
    assert_eq!(ga.len(), gb.len(), "group-1 param count mismatch");
    for (i, (a, b)) in ga.iter().zip(&gb).enumerate() {
        assert_bits_eq(&format!("group-1 grad #{i}"), a, b);
    }
    println!("[GQA group1] bit-identical to MHA: logits + loss + all grads |Δ|=0");
}
// Under GQA, wk/wv must be [dim, kv_dim] (= num_kv_heads·head_dim), while wq stays
// [dim, dim].
#[test]
fn gqa_kv_proj_shape() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let cfg = gqa_cfg();
    // Otherwise the shape checks below could not tell wq from wk/wv.
    assert!(cfg.kv_dim() < cfg.dim, "GQA config must shrink kv_dim below dim");
    let m = build(cfg, device, DType::F32, false);
    let p = m.params();
    // Expected params order: embed, then per block
    // [attn_norm, wq, wk, wv, q_norm, k_norm, wo, ...]. Positional indices like
    // p[1..=3] silently depend on whether a rank-1 norm gain precedes wq. Instead,
    // take the first three rank-2 params after the embedding (p[0]), which are wq,
    // wk and wv of block 0 either way.
    let mats: Vec<_> = p
        .iter()
        .skip(1)
        .map(|t| t.value().shape().to_vec())
        .filter(|shape| shape.len() == 2)
        .take(3)
        .collect();
    assert_eq!(
        mats.len(),
        3,
        "expected wq, wk, wv matrices after the embedding, got {mats:?}"
    );
    assert_eq!(mats[0], vec![cfg.dim, cfg.dim], "wq must be [dim,dim]");
    assert_eq!(mats[1], vec![cfg.dim, cfg.kv_dim()], "wk must be [dim,kv_dim]");
    assert_eq!(mats[2], vec![cfg.dim, cfg.kv_dim()], "wv must be [dim,kv_dim]");
    println!(
        "[GQA shapes] wq {:?} wk {:?} wv {:?} (kv_dim {})",
        mats[0],
        mats[1],
        mats[2],
        cfg.kv_dim()
    );
}