diff --git a/crates/xtrain-autodiff/src/ops.rs b/crates/xtrain-autodiff/src/ops.rs
index 99afb3e..7e91db8 100644
--- a/crates/xtrain-autodiff/src/ops.rs
+++ b/crates/xtrain-autodiff/src/ops.rs
@@ -376,6 +376,41 @@ pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
     )
 }
 
+/// GQA repeat_kv head broadcast (Phase T15). `kv`:[batch·num_kv, seq, head_dim]
+/// (a K or V tensor) → `[batch·nh, seq, head_dim]`, each KV head broadcast to its
+/// `group = nh/num_kv` query heads (qh ← kv head qh/group, contiguous groups —
+/// matches xserv's repeat_kv). Feeds the unchanged composed/flash SDPA so GQA is
+/// "free" for both. Backward SUMS the `group` query heads sharing each KV head back
+/// onto it (the multi-group grad accumulation). `nh == num_kv` (group 1) is identity
+/// → bit-identical to the MHA path. `batch` lets the op recover num_kv from kv's bh.
+///
+/// # Panics
+/// Panics if `batch` is zero or does not divide kv's leading dimension, or if `nh`
+/// is not a multiple of the recovered `num_kv`.
+pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var {
+    let bh_kv = kv.value().shape()[0];
+    // `num_kv` is captured by the backward closure, so a truncated quotient would be
+    // handed to the backward op unchecked; reject inconsistent head counts up front.
+    assert!(
+        batch > 0 && bh_kv % batch == 0,
+        "repeat_kv: kv leading dim {bh_kv} must be divisible by a non-zero batch (got {batch})"
+    );
+    let num_kv = bh_kv / batch;
+    assert!(
+        num_kv > 0 && nh % num_kv == 0,
+        "repeat_kv: nh {nh} must be a multiple of a non-zero num_kv (got {num_kv})"
+    );
+    let out = kv.value().repeat_kv(nh, batch);
+    Var::from_op(
+        out,
+        vec![kv.clone()],
+        Box::new(move |dout, parents| {
+            let din = Tensor::repeat_kv_backward(dout, num_kv, batch);
+            Var::push_grad(&parents[0], din);
+        }),
+    )
+}
+
 /// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
 /// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
 /// scaled by the upstream scalar grad.
diff --git a/crates/xtrain-autodiff/tests/autograd.rs b/crates/xtrain-autodiff/tests/autograd.rs
index 35bfbe1..91bc788 100644
--- a/crates/xtrain-autodiff/tests/autograd.rs
+++ b/crates/xtrain-autodiff/tests/autograd.rs
@@ -776,6 +776,94 @@ fn flash_bwd_matches_composed_bwd() {
     assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
 }
 
+// ---- GQA repeat_kv head broadcast (Phase T15) ----
+//
+// repeat_kv expands K/V from [batch·num_kv, seq, hd] to [batch·nh, seq, hd]; each
+// kv head is broadcast to its `group = nh/num_kv` query heads.
The forward is a
+// gather (a linear map), so finite-diff is clean. The CRITICAL gate is the
+// BACKWARD: a kv head receives the SUM of the `group` query heads sharing it —
+// the multi-group-to-one grad accumulation GQA correctness hinges on. We grad-check
+// din against finite-diff of L = sum(W∘out) with group>1, plus assert the forward
+// actually broadcasts and that group==1 is exact identity.
+#[test]
+fn repeat_kv_grad() {
+    require_gpu();
+    // batch 2, num_kv 2 → bh_kv 4 input rows; nh 6 → group 3, bh_q 12 output rows.
+    let (batch, num_kv, nh, seq, hd) = (2usize, 2usize, 6usize, 4usize, 5usize);
+    let n_in = batch * num_kv * seq * hd;
+    let n_out = batch * nh * seq * hd;
+    let x_h = fill(n_in, 711);
+    let w = fill(n_out, 712);
+
+    let kv = Var::leaf(cuda(&x_h, &[batch * num_kv, seq, hd]));
+    let out = ops::repeat_kv(&kv, nh, batch);
+    assert_eq!(out.value().shape(), &[batch * nh, seq, hd]);
+
+    // Forward sanity: out head (b·nh + qh) must equal in head (b·num_kv + qh/group).
+    let group = nh / num_kv;
+    let out_h = out
+        .value()
+        .to_device(Device::Cpu)
+        .as_slice::<f32>()
+        .to_vec();
+    let row = seq * hd;
+    for b in 0..batch {
+        for qh in 0..nh {
+            let kvh = qh / group;
+            let o0 = (b * nh + qh) * row;
+            let i0 = (b * num_kv + kvh) * row;
+            for e in 0..row {
+                assert_eq!(out_h[o0 + e], x_h[i0 + e], "repeat_kv fwd mismatch");
+            }
+        }
+    }
+
+    scalar_loss(&out, &w).backward();
+    let din = kv.grad().unwrap().to_device(Device::Cpu);
+
+    let fwd = move |xh: &[f32], _s: &[usize]| -> f32 {
+        let kv = cuda(xh, &[batch * num_kv, seq, hd]);
+        let o = kv.repeat_kv(nh, batch);
+        weighted_sum(&o, &w)
+    };
+    // repeat_kv is exactly linear (gather/sum), so the linear-op tolerances apply.
+    report(
+        "repeat_kv din",
+        &grad_check(
+            &x_h,
+            &[batch * num_kv, seq, hd],
+            &fwd,
+            din.as_slice::<f32>(),
+            cfg_linear(),
+        ),
+    );
+}
+
+// group==1 (num_kv == nh) must be a bit-exact identity in BOTH directions — this is
+// the regression guard that makes the MHA path (kv_heads == n_heads) unchanged.
+#[test]
+fn repeat_kv_identity_group1() {
+    require_gpu();
+    let (batch, nh, seq, hd) = (2usize, 3usize, 4usize, 5usize);
+    let n = batch * nh * seq * hd;
+    let x_h = fill(n, 721);
+    let w = fill(n, 722);
+    let kv = Var::leaf(cuda(&x_h, &[batch * nh, seq, hd]));
+    let out = ops::repeat_kv(&kv, nh, batch); // group 1
+    let out_h = out
+        .value()
+        .to_device(Device::Cpu)
+        .as_slice::<f32>()
+        .to_vec();
+    assert_eq!(out_h, x_h, "group-1 repeat_kv fwd must be identity");
+    scalar_loss(&out, &w).backward();
+    let din = kv.grad().unwrap().to_device(Device::Cpu);
+    // dL/din = w exactly (identity forward → grad passes through unchanged).
+    for (g, expect) in din.as_slice::<f32>().iter().zip(&w) {
+        assert_eq!(*g, *expect, "group-1 repeat_kv bwd must be identity");
+    }
+}
+
 // ---- dropout (Phase T18) ----
 //
 // Fixed-seed finite-diff grad-check.
Under a fixed `seed` the mask is constant @@ -827,9 +915,17 @@ fn dropout_expectation_and_keep_rate() { let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64); let out_h = out.to_device(Device::Cpu); let mask_h = mask.to_device(Device::Cpu); - let mean_out: f64 = - out_h.as_slice::().iter().map(|&v| v as f64).sum::() / n as f64; - let kept = mask_h.as_slice::().iter().filter(|&&m| m != 0.0).count(); + let mean_out: f64 = out_h + .as_slice::() + .iter() + .map(|&v| v as f64) + .sum::() + / n as f64; + let kept = mask_h + .as_slice::() + .iter() + .filter(|&&m| m != 0.0) + .count(); mean_out_acc += mean_out; keep_acc += kept as f64 / n as f64; } diff --git a/crates/xtrain-cuda/build.rs b/crates/xtrain-cuda/build.rs index 02bb166..031b7df 100644 --- a/crates/xtrain-cuda/build.rs +++ b/crates/xtrain-cuda/build.rs @@ -37,6 +37,7 @@ fn main() { .file("../../csrc/ops/optim.cu") .file("../../csrc/ops/attention.cu") .file("../../csrc/ops/flash_attention.cu") + .file("../../csrc/ops/repeat_kv.cu") .file("../../csrc/ops/cast.cu") .file("../../csrc/ops/dropout.cu") .compile("xtrain_cuda_kernels"); diff --git a/crates/xtrain-cuda/src/ffi.rs b/crates/xtrain-cuda/src/ffi.rs index 65c2e1b..d757e98 100644 --- a/crates/xtrain-cuda/src/ffi.rs +++ b/crates/xtrain-cuda/src/ffi.rs @@ -296,6 +296,37 @@ unsafe extern "C" { ); } +// GQA repeat_kv head broadcast (csrc/ops/repeat_kv.cu, Phase T15). Expands a K/V +// tensor from [batch·num_kv, S, hd] to the full [batch·nh, S, hd] so the SDPA +// (composed or flash, both untouched) sees a full set of heads. Forward gathers +// (out head qh ← kv head qh/group, group = nh/num_kv); backward sums the `group` +// query heads sharing each kv head (deterministic, no atomics). All F32. +#[cfg(not(no_cuda))] +unsafe extern "C" { + // Forward: out[b·nh+qh] = in[b·num_kv + qh/group], per [S,hd] head block. 
+ pub fn launch_repeat_kv_fwd_f32( + input: *const f32, + out: *mut f32, + batch: i32, + nh: i32, + num_kv: i32, + seq: i32, + hd: i32, + s: CudaStream, + ); + // Backward: din[b·num_kv+kvh] = Σ_{r