Files
xtrain/crates/xtrain-autodiff/tests/autograd.rs
Gahow Wang 830d06ad01 gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:37:37 +08:00

1008 lines
33 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// GPU acceptance tests for the Phase T4 autograd engine + per-op backward.
// Pattern (from xtrain-tensor/tests/gemm.rs `run_bwd`): build a scalar loss
// L = sum(W ∘ out) with W fixed random ⇒ the upstream grad dOut = W. Run the op
// through the tape, call backward(), and grad-check each input's .grad() against
// central finite differences of L.
//
// Gated behind `not(no_cuda)`: compiles out on a GPU-less host, runs on dash5.
#![cfg(not(no_cuda))]
use xtrain_autodiff::ops;
use xtrain_autodiff::tape::Var;
use xtrain_autodiff::{GradCheckConfig, grad_check};
use xtrain_cuda::device;
use xtrain_tensor::{Device, Tensor};
// Deterministic LCG fill, the same generator the gemm tests use. Each draw keeps
// the top 31 bits of a 64-bit LCG state, scales them by 2^-31 and recentres, so
// values lie in [-0.5, 0.5]. The upper end is reachable: the `as f32` conversion
// rounds the largest 31-bit values up to exactly 2^31.
fn fill(n: usize, seed: u64) -> Vec<f32> {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    // Mix the seed into the initial LCG state.
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let denom = (1u64 << 31) as f32;
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        values.push((lcg >> 33) as f32 / denom - 0.5);
    }
    values
}
// Fail fast unless at least one CUDA device is visible, then make device 0
// current for the rest of the test.
fn require_gpu() {
    let visible = device::device_count().expect("device count");
    assert!(visible > 0, "no CUDA device");
    device::set_device(0).unwrap();
}
// Build a host tensor from `data` with the given `shape` and upload it to device 0.
fn cuda(data: &[f32], shape: &[usize]) -> Tensor {
    let host = Tensor::from_slice(data, shape);
    host.to_device(Device::Cuda(0))
}
// L = sum(W ∘ out) for fixed weights W over the op output, reduced on the host.
// `out` and `w` must hold the same number of elements. A bare `zip` would
// silently stop at the shorter one, so a shape bug would grad-check only part
// of the output instead of failing.
fn weighted_sum(out: &Tensor, w: &[f32]) -> f32 {
    let host = out.to_device(Device::Cpu);
    let vals = host.as_slice::<f32>();
    assert_eq!(
        vals.len(),
        w.len(),
        "weighted_sum: op output and weights W differ in element count"
    );
    vals.iter().zip(w).map(|(o, wi)| o * wi).sum()
}
// Grad-check tolerances.
// Ops whose forward is linear in each input taken separately (add, mul, scale,
// bias, rope) have no truncation error in a central difference. A large eps
// therefore only improves f32 resolution.
// Nonlinear ops (rms_norm, silu, softmax, cross_entropy) pick up O(eps²)
// truncation, so `cfg_nonlinear` uses a smaller eps.
// In both configs `atol` is a floor for near-zero gradients.
fn cfg_linear() -> GradCheckConfig {
    let (eps, rel_tol, atol) = (1e-2, 2e-2, 1e-3);
    GradCheckConfig { eps, rel_tol, atol }
}
// Tolerances for nonlinear forwards: a smaller eps to keep the O(eps²)
// truncation term small, and a slightly looser relative tolerance.
fn cfg_nonlinear() -> GradCheckConfig {
    let (eps, rel_tol, atol) = (1e-3, 3e-2, 1e-3);
    GradCheckConfig { eps, rel_tol, atol }
}
// Print the worst element of a grad-check result, then fail the test (dumping
// the full result) if the check did not pass.
fn report(name: &str, res: &xtrain_autodiff::GradCheckResult) {
    let (rel, num, ana, at) = (
        &res.max_rel_err,
        &res.worst_numeric,
        &res.worst_analytic,
        &res.worst_index,
    );
    println!("{name}: max_rel_err = {rel:.3e} (worst num={num:.5} ana={ana:.5} @ {at})");
    assert!(res.passed, "{name} grad-check failed: {res:?}");
}
// ---- add ----
// out = a + b. Each operand is grad-checked with the other one pinned.
#[test]
fn add_bwd() {
    require_gpu();
    let (m, n) = (8, 6);
    let a_h = fill(m * n, 1);
    let b_h = fill(m * n, 2);
    let w = fill(m * n, 3);

    // Tape pass: L = sum(W ∘ (a + b)).
    let a = Var::leaf(cuda(&a_h, &[m, n]));
    let b = Var::leaf(cuda(&b_h, &[m, n]));
    let out = ops::add(&a, &b);
    let loss = scalar_loss(&out, &w);
    loss.backward();
    let [da, db] = [&a, &b].map(|leaf| leaf.grad().unwrap().to_device(Device::Cpu));

    // dA: perturb `a`, keep `b` fixed.
    let la = {
        let (b_pinned, w_pinned) = (b_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let sum = cuda(v, s).add(&cuda(&b_pinned, &[m, n]));
            weighted_sum(&sum, &w_pinned)
        }
    };
    let res = grad_check(&a_h, &[m, n], &la, da.as_slice::<f32>(), cfg_linear());
    report("add dA", &res);

    // dB: perturb `b`, keep `a` fixed.
    let lb = {
        let (a_pinned, w_pinned) = (a_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let sum = cuda(&a_pinned, &[m, n]).add(&cuda(v, s));
            weighted_sum(&sum, &w_pinned)
        }
    };
    let res = grad_check(&b_h, &[m, n], &lb, db.as_slice::<f32>(), cfg_linear());
    report("add dB", &res);
}
// ---- mul ----
// out = a ∘ b (elementwise). This is linear in each operand when the other is
// held fixed, so the linear tolerances apply to both grads.
#[test]
fn mul_bwd() {
    require_gpu();
    let (m, n) = (8, 6);
    let a_h = fill(m * n, 11);
    let b_h = fill(m * n, 22);
    let w = fill(m * n, 33);

    let a = Var::leaf(cuda(&a_h, &[m, n]));
    let b = Var::leaf(cuda(&b_h, &[m, n]));
    let out = ops::mul(&a, &b);
    scalar_loss(&out, &w).backward();
    let [da, db] = [&a, &b].map(|leaf| leaf.grad().unwrap().to_device(Device::Cpu));

    let la = {
        let (b_pinned, w_pinned) = (b_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let prod = cuda(v, s).mul(&cuda(&b_pinned, &[m, n]));
            weighted_sum(&prod, &w_pinned)
        }
    };
    let res = grad_check(&a_h, &[m, n], &la, da.as_slice::<f32>(), cfg_linear());
    report("mul dA", &res);

    let lb = {
        let (a_pinned, w_pinned) = (a_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let prod = cuda(&a_pinned, &[m, n]).mul(&cuda(v, s));
            weighted_sum(&prod, &w_pinned)
        }
    };
    let res = grad_check(&b_h, &[m, n], &lb, db.as_slice::<f32>(), cfg_linear());
    report("mul dB", &res);
}
// ---- add_bias (broadcast) ----
// x is [m, n] and bias is [n]; the bias is broadcast across the m rows of x.
#[test]
fn add_bias_bwd() {
    require_gpu();
    let (m, n) = (10, 7);
    let x_h = fill(m * n, 5);
    let b_h = fill(n, 6);
    let w = fill(m * n, 7);

    let x = Var::leaf(cuda(&x_h, &[m, n]));
    let bias = Var::leaf(cuda(&b_h, &[n]));
    let out = ops::add_bias(&x, &bias);
    scalar_loss(&out, &w).backward();
    let [dx, dbias] = [&x, &bias].map(|leaf| leaf.grad().unwrap().to_device(Device::Cpu));

    // dX: perturb x, keep the bias fixed.
    let lx = {
        let (bias_pinned, w_pinned) = (b_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let shifted = cuda(v, s).add_bias(&cuda(&bias_pinned, &[n]));
            weighted_sum(&shifted, &w_pinned)
        }
    };
    let res = grad_check(&x_h, &[m, n], &lx, dx.as_slice::<f32>(), cfg_linear());
    report("add_bias dX", &res);

    // dBias: perturb the [n] bias, keep x fixed.
    let lb = {
        let (x_pinned, w_pinned) = (x_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let shifted = cuda(&x_pinned, &[m, n]).add_bias(&cuda(v, s));
            weighted_sum(&shifted, &w_pinned)
        }
    };
    let res = grad_check(&b_h, &[n], &lb, dbias.as_slice::<f32>(), cfg_linear());
    report("add_bias dBias", &res);
}
// ---- matmul (sanity through the Var layer; T3 already checks the kernel) ----
// out = A·B with A [m, k] and B [k, n]; linear in each operand separately.
#[test]
fn matmul_bwd() {
    require_gpu();
    let (m, k, n) = (6, 5, 4);
    let a_h = fill(m * k, 41);
    let b_h = fill(k * n, 42);
    let w = fill(m * n, 43);

    let a = Var::leaf(cuda(&a_h, &[m, k]));
    let b = Var::leaf(cuda(&b_h, &[k, n]));
    let out = ops::matmul(&a, &b);
    scalar_loss(&out, &w).backward();
    let [da, db] = [&a, &b].map(|leaf| leaf.grad().unwrap().to_device(Device::Cpu));

    let la = {
        let (b_pinned, w_pinned) = (b_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let prod = cuda(v, s).matmul(&cuda(&b_pinned, &[k, n]));
            weighted_sum(&prod, &w_pinned)
        }
    };
    let res = grad_check(&a_h, &[m, k], &la, da.as_slice::<f32>(), cfg_linear());
    report("matmul dA", &res);

    let lb = {
        let (a_pinned, w_pinned) = (a_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let prod = cuda(&a_pinned, &[m, k]).matmul(&cuda(v, s));
            weighted_sum(&prod, &w_pinned)
        }
    };
    let res = grad_check(&b_h, &[k, n], &lb, db.as_slice::<f32>(), cfg_linear());
    report("matmul dB", &res);
}
// ---- rms_norm ----
// x is [rows, cols] and gamma is [cols]. The op is nonlinear in x, so both grads
// use the nonlinear tolerances.
#[test]
fn rms_norm_bwd() {
    require_gpu();
    let (rows, cols) = (5, 16);
    let eps = 1e-5;
    let x_h = fill(rows * cols, 51);
    // fill() is in [-0.5, 0.5], so gamma lands in [0.5, 1.5], i.e. gamma ≈ 1.
    let g_h: Vec<f32> = fill(cols, 52).into_iter().map(|v| v + 1.0).collect();
    let w = fill(rows * cols, 53);

    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let gamma = Var::leaf(cuda(&g_h, &[cols]));
    let out = ops::rms_norm(&x, &gamma, eps);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);
    let dg = gamma.grad().unwrap().to_device(Device::Cpu);

    // dX: perturb x, keep gamma fixed. Only the first output of the tensor op is used.
    let lx = {
        let (g_pinned, w_pinned) = (g_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let normed = cuda(v, s).rms_norm(&cuda(&g_pinned, &[cols]), eps).0;
            weighted_sum(&normed, &w_pinned)
        }
    };
    let res = grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear());
    report("rms_norm dX", &res);

    // dGamma: perturb gamma, keep x fixed.
    let lg = {
        let (x_pinned, w_pinned) = (x_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let normed = cuda(&x_pinned, &[rows, cols]).rms_norm(&cuda(v, s), eps).0;
            weighted_sum(&normed, &w_pinned)
        }
    };
    let res = grad_check(&g_h, &[cols], &lg, dg.as_slice::<f32>(), cfg_nonlinear());
    report("rms_norm dGamma", &res);
}
// ---- silu ----
// Unary nonlinear activation on a flat [n] vector.
#[test]
fn silu_bwd() {
    require_gpu();
    let n = 64;
    let x_h = fill(n, 61);
    let w = fill(n, 62);

    let x = Var::leaf(cuda(&x_h, &[n]));
    let out = ops::silu(&x);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let w_pinned = w.clone();
    let lx = move |v: &[f32], s: &[usize]| {
        let act = cuda(v, s).silu();
        weighted_sum(&act, &w_pinned)
    };
    let res = grad_check(&x_h, &[n], &lx, dx.as_slice::<f32>(), cfg_nonlinear());
    report("silu dX", &res);
}
// ---- swiglu (composed: silu(gate) ∘ up) ----
// Nonlinear in `gate`, linear in `up`, hence the different tolerances per input.
#[test]
fn swiglu_bwd() {
    require_gpu();
    let n = 48;
    let g_h = fill(n, 71);
    let u_h = fill(n, 72);
    let w = fill(n, 73);

    let gate = Var::leaf(cuda(&g_h, &[n]));
    let up = Var::leaf(cuda(&u_h, &[n]));
    let out = ops::swiglu(&gate, &up);
    scalar_loss(&out, &w).backward();
    let [dg, du] = [&gate, &up].map(|leaf| leaf.grad().unwrap().to_device(Device::Cpu));

    // dGate: reference is silu(gate) ∘ up with `up` fixed.
    let lg = {
        let (up_pinned, w_pinned) = (u_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let gated = cuda(v, s).silu().mul(&cuda(&up_pinned, &[n]));
            weighted_sum(&gated, &w_pinned)
        }
    };
    let res = grad_check(&g_h, &[n], &lg, dg.as_slice::<f32>(), cfg_nonlinear());
    report("swiglu dGate", &res);

    // dUp: reference is silu(gate) ∘ up with `gate` fixed.
    let lu = {
        let (gate_pinned, w_pinned) = (g_h.clone(), w.clone());
        move |v: &[f32], s: &[usize]| {
            let gated = cuda(&gate_pinned, &[n]).silu().mul(&cuda(v, s));
            weighted_sum(&gated, &w_pinned)
        }
    };
    let res = grad_check(&u_h, &[n], &lu, du.as_slice::<f32>(), cfg_linear());
    report("swiglu dUp", &res);
}
// ---- rope ----
// Single sequence: the position period equals `tokens`, so positions never reset.
#[test]
fn rope_bwd() {
    require_gpu();
    let (tokens, heads, head_dim) = (4, 2, 8);
    let n = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_h = fill(n, 81);
    let w = fill(n, 82);

    let x = Var::leaf(cuda(&x_h, &[tokens, heads, head_dim]));
    let out = ops::rope(&x, theta, tokens);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let w_pinned = w.clone();
    let lx = move |v: &[f32], s: &[usize]| {
        let rotated = cuda(v, s).rope(theta, tokens);
        weighted_sum(&rotated, &w_pinned)
    };
    let shape = [tokens, heads, head_dim];
    let res = grad_check(&x_h, &shape, &lx, dx.as_slice::<f32>(), cfg_linear());
    report("rope dX", &res);
}
// ---- rope batched (per-sequence position = row % period) ----
// B sequences of length S are laid end to end (tokens = B*S) with period = S.
// Sequences 2 and 3 reuse positions 0..S, so the kernel's `tok % period` must
// restart RoPE at each sequence boundary.
#[test]
fn rope_batched_bwd() {
    require_gpu();
    let (b, s, heads, head_dim) = (3, 4, 2, 8);
    let tokens = b * s;
    let n = tokens * heads * head_dim;
    let theta = 10000.0;
    let x_h = fill(n, 83);
    let w = fill(n, 84);

    let x = Var::leaf(cuda(&x_h, &[tokens, heads, head_dim]));
    let out = ops::rope(&x, theta, s);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let w_pinned = w.clone();
    let lx = move |v: &[f32], sh: &[usize]| {
        let rotated = cuda(v, sh).rope(theta, s);
        weighted_sum(&rotated, &w_pinned)
    };
    let shape = [tokens, heads, head_dim];
    let res = grad_check(&x_h, &shape, &lx, dx.as_slice::<f32>(), cfg_linear());
    report("rope batched dX", &res);
}
// ---- softmax ----
// Softmax over a [rows, cols] input; nonlinear, so the smaller-eps config is used.
#[test]
fn softmax_bwd() {
    require_gpu();
    let (rows, cols) = (4, 10);
    let x_h = fill(rows * cols, 91);
    let w = fill(rows * cols, 92);

    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let out = ops::softmax(&x);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let w_pinned = w.clone();
    let lx = move |v: &[f32], s: &[usize]| {
        let probs = cuda(v, s).softmax();
        weighted_sum(&probs, &w_pinned)
    };
    let res = grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear());
    report("softmax dX", &res);
}
// ---- cross_entropy (scalar loss; backward = (softmax - onehot)/rows) ----
// The op already produces the scalar mean NLL, so it is backpropagated
// directly, without the W-weighted reduction the other tests use.
#[test]
fn cross_entropy_bwd() {
    require_gpu();
    let (rows, cols) = (5, 8);
    let x_h = fill(rows * cols, 101);
    // Deterministic class index per row: (3·r) mod cols.
    let targets: Vec<i32> = (0..rows).map(|r| (r * 3 % cols) as i32).collect();
    let target = Tensor::from_slice(&targets, &[rows]).to_device(Device::Cuda(0));

    let x = Var::leaf(cuda(&x_h, &[rows, cols]));
    let mean_nll = ops::cross_entropy(&x, &target);
    mean_nll.backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    // Reference loss: the mean of the per-row NLLs returned by the tensor op.
    let lx = {
        let tgt = targets.clone();
        move |v: &[f32], s: &[usize]| {
            let t = Tensor::from_slice(&tgt, &[rows]).to_device(Device::Cuda(0));
            let per_row = cuda(v, s).cross_entropy(&t).1.to_device(Device::Cpu);
            per_row.as_slice::<f32>().iter().sum::<f32>() / rows as f32
        }
    };
    let res = grad_check(&x_h, &[rows, cols], &lx, dx.as_slice::<f32>(), cfg_nonlinear());
    report("cross_entropy dX", &res);
}
// ---- FAN-OUT: a tensor feeding two consumers must SUM grads ----
// out = x∘x + x∘x is built from two separate mul nodes that both read the same
// Var x, and each mul reads x twice. With W = 1 the loss is sum(2x²), so the
// correct dL/dx is 4x. That value is only reached if all four uses accumulate.
#[test]
fn fanout_grad_accumulation() {
    require_gpu();
    let n = 12;
    let x_h = fill(n, 111);
    let w = vec![1.0f32; n];

    let x = Var::leaf(cuda(&x_h, &[n]));
    // Two distinct nodes: within-node reuse (x, x) and across-node reuse.
    let squares = [ops::mul(&x, &x), ops::mul(&x, &x)];
    let out = ops::add(&squares[0], &squares[1]); // 2x²
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let ones = w.clone();
    let lx = move |v: &[f32], s: &[usize]| {
        let xt = cuda(v, s);
        let doubled = xt.mul(&xt).add(&xt.mul(&xt));
        weighted_sum(&doubled, &ones)
    };
    // L is quadratic in x, so central differences are exact up to rounding and
    // the linear tolerances suffice.
    let res = grad_check(&x_h, &[n], &lx, dx.as_slice::<f32>(), cfg_linear());
    report("fanout dX", &res);
}
// ---- COMPOSED ATTENTION: attn = matmul(softmax(matmul(Q,Kᵀ)·scale), V) ----
// Single head, single batch, no causal mask. There is no fused kernel here: the
// backward is assembled from the test-local transpose node (`transpose_var`)
// plus the matmul, scale and softmax nodes. The test therefore checks that those
// per-op backwards compose correctly through the tape.
#[test]
fn attention_composed_bwd() {
    require_gpu();
    let (s, d) = (5, 6); // seq_len, head_dim
    // 1/sqrt(head_dim) score scaling.
    let scale = 1.0 / (d as f32).sqrt();
    let q_h = fill(s * d, 121);
    let k_h = fill(s * d, 122);
    let v_h = fill(s * d, 123);
    let w = fill(s * d, 124); // weights over the [s,d] attention output
    // Taped forward: every intermediate is a graph node.
    let attn = |q: &Var, k: &Var, v: &Var| -> Var {
        let kt = transpose_var(k); // [d,s] (manual transpose node)
        let scores = ops::scale(&ops::matmul(q, &kt), scale); // [s,s]
        let probs = ops::softmax(&scores);
        ops::matmul(&probs, v) // [s,d]
    };
    let q = Var::leaf(cuda(&q_h, &[s, d]));
    let k = Var::leaf(cuda(&k_h, &[s, d]));
    let v = Var::leaf(cuda(&v_h, &[s, d]));
    let out = attn(&q, &k, &v);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Untaped reference forward (plain Tensor ops) used as the finite-difference
    // loss. It still runs on the GPU via `cuda`. `w` is moved in here, and the
    // closure is cloned once per input below.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[s, d]);
        let kv = cuda(kh, &[s, d]);
        let vv = cuda(vh, &[s, d]);
        let scores = qv.matmul(&kv.transpose_2d()).scale(scale);
        let probs = scores.softmax();
        weighted_sum(&probs.matmul(&vv), &w)
    };
    // dQ: vary Q with K and V pinned. Softmax curvature calls for the nonlinear tolerances.
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "attn dQ",
        &grad_check(&q_h, &[s, d], &lq, dq.as_slice::<f32>(), cfg_nonlinear()),
    );
    // dK: vary K with Q and V pinned.
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "attn dK",
        &grad_check(&k_h, &[s, d], &lk, dk.as_slice::<f32>(), cfg_nonlinear()),
    );
    // dV: the output is linear in V, so the linear tolerances apply.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "attn dV",
        &grad_check(&v_h, &[s, d], &lv, dv.as_slice::<f32>(), cfg_linear()),
    );
}
// ---- transpose_4d12 ([a,b,c,d] -> [a,c,b,d]) ----
// A pure permutation of elements, so the op is exactly linear.
#[test]
fn transpose_4d12_bwd() {
    require_gpu();
    let dims = [2, 3, 4, 5];
    let n: usize = dims.iter().product();
    let x_h = fill(n, 131);
    let w = fill(n, 132);

    let x = Var::leaf(cuda(&x_h, &dims));
    let out = ops::transpose_4d12(&x);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let w_pinned = w.clone();
    let lx = move |v: &[f32], s: &[usize]| {
        let permuted = cuda(v, s).transpose_4d12();
        weighted_sum(&permuted, &w_pinned)
    };
    let res = grad_check(&x_h, &dims, &lx, dx.as_slice::<f32>(), cfg_linear());
    report("transpose_4d12 dX", &res);
}
// ---- fused batched causal attention (the T10 op) ----
// q, k and v are [bh, seq, hd]. Grad-check dq/dk/dv against finite differences
// of L = sum(W∘out). bh = 2 (e.g. batch 1 × 2 heads, or 2 sequences × 1 head)
// exercises the batched GEMM stride; the causal mask is applied inside the op.
#[test]
fn attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let n = bh * seq * hd;
    // 1/sqrt(head_dim) score scaling.
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 141);
    let k_h = fill(n, 142);
    let v_h = fill(n, 143);
    let w = fill(n, 144);
    // Tape pass through the fused op.
    let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
    let out = ops::attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Untaped reference forward through the same Tensor-level op. Only the first
    // return value is used. `w` is moved in, and the closure is cloned per input below.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[bh, seq, hd]);
        let kv = cuda(kh, &[bh, seq, hd]);
        let vv = cuda(vh, &[bh, seq, hd]);
        let (o, _) = qv.attention(&kv, &vv, scale);
        weighted_sum(&o, &w)
    };
    // dQ: vary Q with K and V pinned (softmax curvature, so nonlinear tolerances).
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "attn(batched) dQ",
        &grad_check(
            &q_h,
            &[bh, seq, hd],
            &lq,
            dq.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
    // dK: vary K with Q and V pinned.
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "attn(batched) dK",
        &grad_check(
            &k_h,
            &[bh, seq, hd],
            &lk,
            dk.as_slice::<f32>(),
            cfg_nonlinear(),
        ),
    );
    // dV: the output is linear in V, so the linear tolerances apply.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "attn(batched) dV",
        &grad_check(
            &v_h,
            &[bh, seq, hd],
            &lv,
            dv.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// ---- fused FLASH causal attention (the T14 op) ----
// Same structure and dimensions as attention_batched_bwd (bh=2, seq=5, hd=6),
// but exercises ops::flash_attention. Grad-checks dq/dk/dv against finite
// differences of L=sum(W∘out).
// This is the SINGLE-tile regime (seq<FA_TILE=32), which matches the trusted
// composed grad-check's clean near-zero behavior. The MULTI-tile online-softmax
// path (seq>FA_TILE) is instead validated against the already-grad-checked
// composed backward by `flash_bwd_matches_composed_bwd` (seq=40). That check is
// sharper than finite differences, which are unreliable on the near-zero grad
// elements a long softmax produces.
#[test]
fn flash_attention_batched_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 5, 6);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    // Scale Q/K up so the softmax is non-uniform (sharper attention). The dQ/dK
    // gradients are then well-conditioned rather than the near-zero saddle values
    // a uniform softmax produces. Those make central finite differences give
    // spurious 0.0 values or sign flips that aren't backward bugs
    // (cf. flash_bwd_matches_composed_bwd).
    let q_h: Vec<f32> = fill(n, 241).iter().map(|v| v * 2.5).collect();
    let k_h: Vec<f32> = fill(n, 242).iter().map(|v| v * 2.5).collect();
    let v_h = fill(n, 243);
    let w = fill(n, 244);
    // Tape pass through the flash op.
    let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
    let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
    let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
    let out = ops::flash_attention(&q, &k, &v, scale);
    scalar_loss(&out, &w).backward();
    let dq = q.grad().unwrap().to_device(Device::Cpu);
    let dk = k.grad().unwrap().to_device(Device::Cpu);
    let dv = v.grad().unwrap().to_device(Device::Cpu);
    // Untaped reference forward through the same flash kernel. Only the first
    // return value is used. The closure is cloned once per input below.
    let fwd = move |qh: &[f32], kh: &[f32], vh: &[f32]| -> f32 {
        let qv = cuda(qh, &[bh, seq, hd]);
        let kv = cuda(kh, &[bh, seq, hd]);
        let vv = cuda(vh, &[bh, seq, hd]);
        let (o, _) = qv.flash_attention(&kv, &vv, scale);
        weighted_sum(&o, &w)
    };
    // Attention dQ/dK carry softmax curvature. For the small grad magnitudes here,
    // a larger eps (2e-3) reduces the f32 rounding term (∝|L|/eps), which
    // otherwise dominates the O(eps²) truncation on a ~4e-4 grad.
    // (dV is exactly linear, so it uses cfg_linear.)
    let cfg_attn = GradCheckConfig {
        eps: 2e-3,
        rel_tol: 3e-2,
        atol: 1e-3,
    };
    // dQ: vary Q with K and V pinned.
    let (kf, vf, ff) = (k_h.clone(), v_h.clone(), fwd.clone());
    let lq = move |x: &[f32], _s: &[usize]| ff(x, &kf, &vf);
    report(
        "flash dQ",
        &grad_check(&q_h, &[bh, seq, hd], &lq, dq.as_slice::<f32>(), cfg_attn),
    );
    // dK: vary K with Q and V pinned.
    let (qf, vf, ff) = (q_h.clone(), v_h.clone(), fwd.clone());
    let lk = move |x: &[f32], _s: &[usize]| ff(&qf, x, &vf);
    report(
        "flash dK",
        &grad_check(&k_h, &[bh, seq, hd], &lk, dk.as_slice::<f32>(), cfg_attn),
    );
    // dV: vary V with Q and K pinned.
    let (qf, kf, ff) = (q_h.clone(), k_h.clone(), fwd.clone());
    let lv = move |x: &[f32], _s: &[usize]| ff(&qf, &kf, x);
    report(
        "flash dV",
        &grad_check(
            &v_h,
            &[bh, seq, hd],
            &lv,
            dv.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// The flash forward must equal the composed attention forward (same SDPA math).
// seq = 40 exceeds the 32-wide flash tile, so the multi-tile path runs.
// Metric: max over elements of |c - f| / (|c| + 1e-6), taken relative to the
// composed value c. The small floor means near-zero composed outputs dominate
// the max.
#[test]
fn flash_matches_composed_fwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let dims = [bh, seq, hd];
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let [q, k, v] = [341u64, 342, 343].map(|seed| cuda(&fill(n, seed), &dims));

    let (composed, _) = q.attention(&k, &v, scale);
    let (flash, _) = q.flash_attention(&k, &v, scale);
    let composed = composed.to_device(Device::Cpu);
    let flash = flash.to_device(Device::Cpu);

    let mut max_rel = 0.0f32;
    for (c, f) in composed.as_slice::<f32>().iter().zip(flash.as_slice::<f32>()) {
        max_rel = max_rel.max((c - f).abs() / (c.abs() + 1e-6));
    }
    println!("flash-vs-composed fwd max rel: {max_rel:.3e}");
    assert!(
        max_rel < 1e-4,
        "flash fwd diverges from composed: {max_rel:.3e}"
    );
}
// The flash backward must equal the (already grad-checked) composed backward.
// This is a sharper test than finite differences. Both paths share the trusted
// composed forward as the reference, so the test isolates the flash bwd
// dQ/dK/dV math from finite-difference noise on near-zero gradient elements.
// seq=40 is larger than the 32-wide flash tile, so the multi-tile path is covered.
#[test]
fn flash_bwd_matches_composed_bwd() {
    require_gpu();
    let (bh, seq, hd) = (2, 40, 16);
    let n = bh * seq * hd;
    let scale = 1.0 / (hd as f32).sqrt();
    let q_h = fill(n, 441);
    let k_h = fill(n, 442);
    let v_h = fill(n, 443);
    let w = fill(n, 444);
    // One full tape pass, returning host copies of (dQ, dK, dV). Fresh leaves
    // are built on every call, so the two runs do not accumulate into each
    // other's grads.
    let run = |flash: bool| -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let q = Var::leaf(cuda(&q_h, &[bh, seq, hd]));
        let k = Var::leaf(cuda(&k_h, &[bh, seq, hd]));
        let v = Var::leaf(cuda(&v_h, &[bh, seq, hd]));
        let out = if flash {
            ops::flash_attention(&q, &k, &v, scale)
        } else {
            ops::attention(&q, &k, &v, scale)
        };
        scalar_loss(&out, &w).backward();
        let g = |x: &Var| {
            x.grad()
                .unwrap()
                .to_device(Device::Cpu)
                .as_slice::<f32>()
                .to_vec()
        };
        (g(&q), g(&k), g(&v))
    };
    let (cq, ck, cv) = run(false);
    let (fq, fk, fv) = run(true);
    // Symmetric relative error with a 1e-4 floor, so near-zero grads from both
    // paths don't blow up the ratio.
    let maxrel = |a: &[f32], b: &[f32]| -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs() / (x.abs() + y.abs() + 1e-4))
            .fold(0.0f32, f32::max)
    };
    let (rq, rk, rv) = (maxrel(&cq, &fq), maxrel(&ck, &fk), maxrel(&cv, &fv));
    println!("flash-vs-composed bwd max rel: dQ {rq:.3e} dK {rk:.3e} dV {rv:.3e}");
    assert!(rq < 2e-2, "dQ diverges: {rq:.3e}");
    assert!(rk < 2e-2, "dK diverges: {rk:.3e}");
    assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
// ---- GQA repeat_kv head broadcast (Phase T15) ----
//
// repeat_kv expands K/V from [batch·num_kv, seq, hd] to [batch·nh, seq, hd]:
// each kv head is broadcast to its `group = nh/num_kv` query heads. The forward
// is a gather (a linear map), so finite differences are clean. The CRITICAL gate
// is the BACKWARD: a kv head receives the SUM of the `group` query-head grads
// sharing it. This many-to-one grad accumulation is what GQA correctness hinges on.
// The test grad-checks din against finite differences of L = sum(W∘out) with
// group>1. It also asserts that the forward actually broadcasts.
// (group==1 exact identity is covered by repeat_kv_identity_group1.)
#[test]
fn repeat_kv_grad() {
    require_gpu();
    // batch 2, num_kv 2 → bh_kv 4 input rows; nh 6 → group 3, bh_q 12 output rows.
    let (batch, num_kv, nh, seq, hd) = (2usize, 2usize, 6usize, 4usize, 5usize);
    let n_in = batch * num_kv * seq * hd;
    let n_out = batch * nh * seq * hd;
    let x_h = fill(n_in, 711);
    let w = fill(n_out, 712);
    let kv = Var::leaf(cuda(&x_h, &[batch * num_kv, seq, hd]));
    let out = ops::repeat_kv(&kv, nh, batch);
    assert_eq!(out.value().shape(), &[batch * nh, seq, hd]);
    // Forward sanity: out head (b·nh + qh) must equal in head (b·num_kv + qh/group).
    let group = nh / num_kv;
    let out_h = out
        .value()
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();
    // Elements per head block ([seq, hd] flattened).
    let row = seq * hd;
    for b in 0..batch {
        for qh in 0..nh {
            // Consecutive query heads map to the same kv head (qh / group).
            let kvh = qh / group;
            let o0 = (b * nh + qh) * row;
            let i0 = (b * num_kv + kvh) * row;
            for e in 0..row {
                assert_eq!(out_h[o0 + e], x_h[i0 + e], "repeat_kv fwd mismatch");
            }
        }
    }
    scalar_loss(&out, &w).backward();
    let din = kv.grad().unwrap().to_device(Device::Cpu);
    // Untaped reference: same repeat_kv through the Tensor API, weighted by W.
    let fwd = move |xh: &[f32], _s: &[usize]| -> f32 {
        let kv = cuda(xh, &[batch * num_kv, seq, hd]);
        let o = kv.repeat_kv(nh, batch);
        weighted_sum(&o, &w)
    };
    // repeat_kv is exactly linear (gather/sum), so the linear-op tolerances apply.
    report(
        "repeat_kv din",
        &grad_check(
            &x_h,
            &[batch * num_kv, seq, hd],
            &fwd,
            din.as_slice::<f32>(),
            cfg_linear(),
        ),
    );
}
// With group==1 (num_kv == nh), repeat_kv must be a bit-exact identity in BOTH
// directions. This is the regression guard that keeps the MHA path
// (kv_heads == n_heads) unchanged.
#[test]
fn repeat_kv_identity_group1() {
    require_gpu();
    let (batch, nh, seq, hd) = (2usize, 3usize, 4usize, 5usize);
    let heads = batch * nh;
    let x_h = fill(heads * seq * hd, 721);
    let w = fill(heads * seq * hd, 722);

    let kv = Var::leaf(cuda(&x_h, &[heads, seq, hd]));
    let out = ops::repeat_kv(&kv, nh, batch); // group 1
    let fwd_h: Vec<f32> = out.value().to_device(Device::Cpu).as_slice::<f32>().to_vec();
    assert_eq!(fwd_h, x_h, "group-1 repeat_kv fwd must be identity");

    scalar_loss(&out, &w).backward();
    let din = kv.grad().unwrap().to_device(Device::Cpu);
    // Identity forward, so dL/din must be exactly W.
    din.as_slice::<f32>()
        .iter()
        .zip(&w)
        .for_each(|(g, expect)| assert_eq!(*g, *expect, "group-1 repeat_kv bwd must be identity"));
}
// ---- dropout (Phase T18) ----
//
// Fixed-seed finite-difference grad-check. Under a fixed `seed` the mask does
// not depend on x (it depends only on (seed, index)). Dropout is then a fixed
// elementwise linear map `out_i = c_i·x_i`, and both ± perturbations of x_i see
// the SAME mask. The reference closure calls `ops::dropout(x, p, SEED)` with the
// same SEED as the tape pass, so every evaluation reproduces that mask.
#[test]
fn dropout_bwd() {
    require_gpu();
    const SEED: u64 = 0xD120_FE5E;
    let p = 0.3f32;
    let (m, n) = (16, 12);
    let x_h = fill(m * n, 71);
    let w = fill(m * n, 72);

    let x = Var::leaf(cuda(&x_h, &[m, n]));
    let out = ops::dropout(&x, p, SEED);
    scalar_loss(&out, &w).backward();
    let dx = x.grad().unwrap().to_device(Device::Cpu);

    let lx = {
        let w_pinned = w.clone();
        move |v: &[f32], s: &[usize]| {
            let masked = ops::dropout(&Var::leaf(cuda(v, s)), p, SEED);
            weighted_sum(&masked.value(), &w_pinned)
        }
    };
    let res = grad_check(&x_h, &[m, n], &lx, dx.as_slice::<f32>(), cfg_linear());
    report("dropout dX", &res);
}
// Inverted-dropout expectation and keep-rate check. Over a large all-ones tensor
// and a sweep of seeds, mean(dropout(x)) must track mean(x) = 1, which holds
// only if kept elements are scaled by 1/(1-p). The kept fraction (non-zero mask
// entries) must track 1-p. Both tolerances are 0.01, averaged over the trials.
#[test]
fn dropout_expectation_and_keep_rate() {
    require_gpu();
    let p = 0.25f32;
    let n = 200_000usize;
    let x = cuda(&vec![1.0f32; n], &[n]); // mean(x) = 1 → mean(out) should ≈ 1
    let trials = 8;

    let (mut mean_out_acc, mut keep_acc) = (0.0f64, 0.0f64);
    for t in 0..trials {
        let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64);
        let out_h = out.to_device(Device::Cpu);
        let mask_h = mask.to_device(Device::Cpu);
        let total: f64 = out_h.as_slice::<f32>().iter().map(|&v| f64::from(v)).sum();
        mean_out_acc += total / n as f64;
        let kept = mask_h.as_slice::<f32>().iter().filter(|&&m| m != 0.0).count();
        keep_acc += kept as f64 / n as f64;
    }

    let mean_out = mean_out_acc / trials as f64;
    let keep_rate = keep_acc / trials as f64;
    println!(
        "dropout p={p}: E[out]={mean_out:.5} (input mean 1.0), keep_rate={keep_rate:.5} (1-p={:.3})",
        1.0 - p
    );
    assert!(
        (mean_out - 1.0).abs() < 0.01,
        "E[out] {mean_out} not ≈ input mean 1.0 (inverted scaling broken)"
    );
    assert!(
        (keep_rate - (1.0 - p) as f64).abs() < 0.01,
        "keep_rate {keep_rate} not ≈ 1-p {}",
        1.0 - p
    );
}
// p=0 must be a no-op: ops::dropout short-circuits (returning x.clone(), no new
// node), so the output is bit-identical to x and the grad flows straight
// through. Both halves of that claim are checked here:
//   * forward: the tensor-level dropout with p=0 returns x bit-for-bit;
//   * backward: through the tape, dL/dx for L = sum(W ∘ dropout(x, 0)) is
//     exactly W. Previously only the forward was checked.
// The model-level bit-identity lives in xtrain-model/tests/dropout.rs.
#[test]
fn dropout_p0_is_identity() {
    require_gpu();
    let (m, n) = (8, 5);
    let x_h = fill(m * n, 91);

    // Forward (tensor level).
    let x = cuda(&x_h, &[m, n]);
    let (out, _mask) = x.dropout(0.0, 12345);
    let out_h = out.to_device(Device::Cpu);
    for (a, b) in x_h.iter().zip(out_h.as_slice::<f32>()) {
        assert_eq!(*a, *b, "p=0 dropout must be identity");
    }

    // Backward (Var level). The loss seeds a 1.0 grad that sum_all broadcasts,
    // so the mul node hands x exactly W with no rounding. The same exactness is
    // relied on in repeat_kv_identity_group1.
    let w = fill(m * n, 92);
    let xv = Var::leaf(cuda(&x_h, &[m, n]));
    let out_v = ops::dropout(&xv, 0.0, 12345);
    scalar_loss(&out_v, &w).backward();
    let dx = xv
        .grad()
        .expect("p=0 dropout must propagate a grad to x")
        .to_device(Device::Cpu);
    assert_eq!(
        dx.as_slice::<f32>().len(),
        w.len(),
        "p=0 dropout grad has the wrong element count"
    );
    for (g, expect) in dx.as_slice::<f32>().iter().zip(&w) {
        assert_eq!(*g, *expect, "p=0 dropout bwd must be identity");
    }
}
// --- test helpers ---
// Scalar loss node L = sum(W ∘ out). W is wrapped in a leaf Var shaped like
// `out`, multiplied in elementwise, and the product is reduced to a scalar.
// Because W is fixed, the upstream grad reaching `out` is W itself.
fn scalar_loss(out: &Var, w: &[f32]) -> Var {
    let shape = out.value().shape().to_vec();
    let weights = Var::leaf(cuda(w, &shape));
    let weighted = ops::mul(out, &weights);
    sum_all(&weighted)
}
// Sum-to-scalar node: out = sum(x), returned as a [1] tensor on x's device.
// Backward broadcasts the upstream scalar grad to a tensor of x's shape and
// pushes it to x. The grad is placed on x's own device, matching the forward.
// (The previous version hard-coded Cuda(0) in the backward while the forward
// followed `xv.device()`.) The reduction is done on the host in f32.
// Implemented here (test-local) since the engine's op set doesn't include a
// generic reduction; cross_entropy is the only loss op.
fn sum_all(x: &Var) -> Var {
    let xv = x.value();
    let total: f32 = xv.to_device(Device::Cpu).as_slice::<f32>().iter().sum();
    let scalar = Tensor::from_slice(&[total], &[1]).to_device(xv.device());
    let shape: Vec<usize> = xv.shape().to_vec();
    Var::from_op(
        scalar,
        vec![x.clone()],
        Box::new(move |d, parents| {
            // d is [1]: read the scalar and fill an x-shaped tensor with it.
            let dval = d.to_device(Device::Cpu).as_slice::<f32>()[0];
            let numel: usize = shape.iter().product();
            let filled = vec![dval; numel];
            let target = parents[0].value().device();
            let g = Tensor::from_slice(&filled, &shape).to_device(target);
            Var::push_grad(&parents[0], g);
        }),
    )
}
// Manual 2-D transpose node for the composed-attention test. The engine only
// ships the 4-D `ops::transpose_4d12` ([a,b,c,d] -> [a,c,b,d]), not a 2-D
// transpose. The backward of a transpose is another transpose: the upstream
// grad is routed back transposed to the single parent.
fn transpose_var(x: &Var) -> Var {
    let parents = vec![x.clone()];
    let forward = x.value().transpose_2d();
    Var::from_op(
        forward,
        parents,
        Box::new(|grad_out, inputs| {
            let grad_in = grad_out.transpose_2d();
            Var::push_grad(&inputs[0], grad_in);
        }),
    )
}