Compare commits

...

2 Commits

Author SHA1 Message Date
830d06ad01 gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:37:37 +08:00
62b1cb5dc7 docs: Phase T15 — GQA design (repeat_kv broadcast op + backward grad-sum)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:30:34 +08:00
16 changed files with 879 additions and 41 deletions

View File

@@ -376,6 +376,27 @@ pub fn flash_attention(q: &Var, k: &Var, v: &Var, scale: f32) -> Var {
)
}
/// Autograd wrapper around the GQA head broadcast (Phase T15).
///
/// `kv` is a K or V tensor laid out as `[batch·num_kv, seq, head_dim]`; the
/// result is `[batch·nh, seq, head_dim]`, where query head `qh` reads KV head
/// `qh / group` with `group = nh / num_kv` (contiguous groups — matches xserv's
/// repeat_kv). The expanded tensor feeds the unchanged composed/flash SDPA, so
/// both paths get GQA without modification.
///
/// Backward sums the `group` query-head gradients that share a KV head back onto
/// that head. With `nh == num_kv` (group 1) the op is an identity, which keeps the
/// MHA path bit-identical.
///
/// `batch` is required because the head count is folded into dim 0: `num_kv` is
/// recovered as `kv.shape[0] / batch`.
pub fn repeat_kv(kv: &Var, nh: usize, batch: usize) -> Var {
    // Resolve the KV head count up front: the backward closure only receives
    // `dout`, whose leading dim is batch·nh, so num_kv must be captured here.
    let kv_heads = kv.value().shape()[0] / batch;
    let expanded = kv.value().repeat_kv(nh, batch);
    Var::from_op(
        expanded,
        vec![kv.clone()],
        Box::new(move |grad_out, inputs| {
            let grad_kv = Tensor::repeat_kv_backward(grad_out, kv_heads, batch);
            Var::push_grad(&inputs[0], grad_kv);
        }),
    )
}
/// Cross-entropy mean loss over logits `x:[rows,cols]` with one I32 target per
/// row. Returns a scalar [`Var`]. Backward: `dx = (probs - onehot)/rows`,
/// scaled by the upstream scalar grad.

View File

@@ -776,6 +776,94 @@ fn flash_bwd_matches_composed_bwd() {
assert!(rv < 2e-2, "dV diverges: {rv:.3e}");
}
// ---- GQA repeat_kv head broadcast (Phase T15) ----
//
// repeat_kv expands K/V from [batch·num_kv, seq, hd] to [batch·nh, seq, hd]; each
// kv head is broadcast to its `group = nh/num_kv` query heads. The forward is a
// gather (a linear map), so finite-diff is clean. The CRITICAL gate is the
// BACKWARD: a kv head receives the SUM of the `group` query heads sharing it —
// the multi-group-to-one grad accumulation GQA correctness hinges on. We grad-check
// din against finite-diff of L = sum(W∘out) with group>1, plus assert the forward
// actually broadcasts and that group==1 is exact identity.
#[test]
fn repeat_kv_grad() {
    require_gpu();
    // Shapes: batch 2 × num_kv 2 = 4 input head blocks; nh 6 gives group 3 and
    // batch 2 × nh 6 = 12 output head blocks.
    let (batch, num_kv, nh, seq, hd) = (2usize, 2usize, 6usize, 4usize, 5usize);
    let n_in = batch * num_kv * seq * hd;
    let n_out = batch * nh * seq * hd;
    let x_h = fill(n_in, 711);
    let w = fill(n_out, 712);

    let kv = Var::leaf(cuda(&x_h, &[batch * num_kv, seq, hd]));
    let out = ops::repeat_kv(&kv, nh, batch);
    assert_eq!(out.value().shape(), &[batch * nh, seq, hd]);

    // Forward check: output head block (b·nh + qh) must be an exact copy of input
    // head block (b·num_kv + qh/group).
    let head_elems = seq * hd;
    let group = nh / num_kv;
    let expanded = out
        .value()
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();
    for b in 0..batch {
        for qh in 0..nh {
            let dst = (b * nh + qh) * head_elems;
            let src = (b * num_kv + qh / group) * head_elems;
            let got = &expanded[dst..dst + head_elems];
            let want = &x_h[src..src + head_elems];
            for (g, e) in got.iter().zip(want) {
                assert_eq!(*g, *e, "repeat_kv fwd mismatch");
            }
        }
    }

    // Analytic gradient of L = sum(W∘out) with respect to the kv input.
    scalar_loss(&out, &w).backward();
    let din = kv.grad().unwrap().to_device(Device::Cpu);

    // Same scalar loss evaluated without autograd, for the finite-difference side.
    let fwd = move |xh: &[f32], _s: &[usize]| -> f32 {
        let probe = cuda(xh, &[batch * num_kv, seq, hd]);
        let broadcast = probe.repeat_kv(nh, batch);
        weighted_sum(&broadcast, &w)
    };

    // repeat_kv is exactly linear (gather/sum), so the linear-op tolerances apply.
    let check = grad_check(
        &x_h,
        &[batch * num_kv, seq, hd],
        &fwd,
        din.as_slice::<f32>(),
        cfg_linear(),
    );
    report("repeat_kv din", &check);
}
// group==1 (num_kv == nh) must be a bit-exact identity in BOTH directions — this is
// the regression guard that makes the MHA path (kv_heads == n_heads) unchanged.
#[test]
fn repeat_kv_identity_group1() {
    require_gpu();
    let (batch, nh, seq, hd) = (2usize, 3usize, 4usize, 5usize);
    let numel = batch * nh * seq * hd;
    let x_h = fill(numel, 721);
    let w = fill(numel, 722);

    // The input already has nh heads per batch entry, so group = nh/num_kv = 1.
    let kv = Var::leaf(cuda(&x_h, &[batch * nh, seq, hd]));
    let out = ops::repeat_kv(&kv, nh, batch);

    let fwd_host = out
        .value()
        .to_device(Device::Cpu)
        .as_slice::<f32>()
        .to_vec();
    assert_eq!(fwd_host, x_h, "group-1 repeat_kv fwd must be identity");

    scalar_loss(&out, &w).backward();
    let din = kv.grad().unwrap().to_device(Device::Cpu);
    // With an identity forward, dL/din for L = sum(W∘out) is exactly W.
    let grads = din.as_slice::<f32>();
    for (got, want) in grads.iter().zip(&w) {
        assert_eq!(*got, *want, "group-1 repeat_kv bwd must be identity");
    }
}
// ---- dropout (Phase T18) ----
//
// Fixed-seed finite-diff grad-check. Under a fixed `seed` the mask is constant
@@ -827,9 +915,17 @@ fn dropout_expectation_and_keep_rate() {
let (out, mask) = x.dropout(p, 0x5EED_0000 + t as u64);
let out_h = out.to_device(Device::Cpu);
let mask_h = mask.to_device(Device::Cpu);
let mean_out: f64 =
out_h.as_slice::<f32>().iter().map(|&v| v as f64).sum::<f64>() / n as f64;
let kept = mask_h.as_slice::<f32>().iter().filter(|&&m| m != 0.0).count();
let mean_out: f64 = out_h
.as_slice::<f32>()
.iter()
.map(|&v| v as f64)
.sum::<f64>()
/ n as f64;
let kept = mask_h
.as_slice::<f32>()
.iter()
.filter(|&&m| m != 0.0)
.count();
mean_out_acc += mean_out;
keep_acc += kept as f64 / n as f64;
}

View File

@@ -37,6 +37,7 @@ fn main() {
.file("../../csrc/ops/optim.cu")
.file("../../csrc/ops/attention.cu")
.file("../../csrc/ops/flash_attention.cu")
.file("../../csrc/ops/repeat_kv.cu")
.file("../../csrc/ops/cast.cu")
.file("../../csrc/ops/dropout.cu")
.compile("xtrain_cuda_kernels");

View File

@@ -296,6 +296,37 @@ unsafe extern "C" {
);
}
// GQA repeat_kv head broadcast (csrc/ops/repeat_kv.cu, Phase T15). Expands a K/V
// tensor from [batch·num_kv, S, hd] to the full [batch·nh, S, hd] so the SDPA
// (composed or flash, both untouched) sees a full set of heads. Forward gathers
// (out head qh ← kv head qh/group, group = nh/num_kv); backward sums the `group`
// query heads sharing each kv head (deterministic, no atomics). All F32.
// Raw FFI declarations for the repeat_kv launchers; compiled only when the
// `no_cuda` cfg is not set. Both take the same extents (`batch`, `nh`, `num_kv`,
// `seq`, `hd`, all i32) plus a stream handle `s`, and operate on F32 buffers.
// NOTE(review): the pointers are presumably device pointers to contiguous
// buffers — confirm against csrc/ops/repeat_kv.cu.
#[cfg(not(no_cuda))]
unsafe extern "C" {
// Forward gather: out[b·nh+qh] = in[b·num_kv + qh/group], per [S,hd] head block,
// with group = nh/num_kv. `input` is [batch·num_kv, S, hd]; `out` is
// [batch·nh, S, hd].
pub fn launch_repeat_kv_fwd_f32(
input: *const f32,
out: *mut f32,
batch: i32,
nh: i32,
num_kv: i32,
seq: i32,
hd: i32,
s: CudaStream,
);
// Backward group-sum: din[b·num_kv+kvh] = Σ_{r<group} dout[b·nh + kvh·group + r].
// `dout` is [batch·nh, S, hd]; `din` is [batch·num_kv, S, hd].
pub fn launch_repeat_kv_bwd_f32(
dout: *const f32,
din: *mut f32,
batch: i32,
nh: i32,
num_kv: i32,
seq: i32,
hd: i32,
s: CudaStream,
);
}
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
// the global grad-norm reduction + in-place rescale (Phase T7).
#[cfg(not(no_cuda))]

View File

@@ -61,6 +61,8 @@ fn main() {
let head_dim = flag(&args, "--head-dim", 16usize);
let n_layers = flag(&args, "--layers", 4usize);
let ffn = flag(&args, "--ffn", 64usize);
// GQA (Phase T15): num K/V heads (must divide --heads). Default = --heads (MHA).
let kv_heads = flag(&args, "--kv-heads", n_heads);
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
let dim_flag = flag(&args, "--dim", 0usize);
if dim_flag != 0 && dim_flag != n_heads * head_dim {
@@ -137,13 +139,14 @@ fn main() {
(corpus, None)
};
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
println!(
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
"model: dim {} layers {} heads {} kv_heads {} head_dim {} ffn {} → core {:.3}M params \
(+ embed/lm {:.2}M = {:.2}M total)",
cfg.dim,
cfg.n_layers,
cfg.n_heads,
cfg.num_kv_heads,
cfg.head_dim,
cfg.ffn_hidden,
cfg.core_params() as f32 / 1e6,

View File

@@ -10,8 +10,15 @@ pub struct Config {
pub dim: usize,
/// Number of decoder blocks.
pub n_layers: usize,
/// Number of attention heads.
/// Number of attention (query) heads.
pub n_heads: usize,
/// Number of key/value heads (Phase T15, GQA). Each KV head is shared by a
/// group of `n_heads / num_kv_heads` query heads (repeat_kv). Must divide
/// `n_heads`. `num_kv_heads == n_heads` (the default) = MHA, bit-identical to
/// the pre-T15 path; `num_kv_heads < n_heads` = real grouped-query attention,
/// shrinking the K/V projections to `num_kv_heads * head_dim` and exported as a
/// real `num_key_value_heads`.
pub num_kv_heads: usize,
/// Per-head dimension (`dim / n_heads`).
pub head_dim: usize,
/// SwiGLU hidden width (gate/up project to this, down projects back).
@@ -37,6 +44,7 @@ impl Config {
dim: n_heads * head_dim,
n_layers: 2,
n_heads,
num_kv_heads: n_heads, // default = MHA
head_dim,
ffn_hidden: 64,
eps: 1e-5,
@@ -62,6 +70,7 @@ impl Config {
dim: n_heads * head_dim,
n_layers,
n_heads,
num_kv_heads: n_heads, // default = MHA; set via with_kv_heads for GQA
head_dim,
ffn_hidden,
eps: 1e-5,
@@ -70,6 +79,27 @@ impl Config {
}
}
/// Set the number of K/V heads (Phase T15, GQA). Builder-style so existing
/// `from_arch` call sites stay MHA unless they opt in. Asserts `num_kv_heads`
/// divides `n_heads`.
pub fn with_kv_heads(mut self, num_kv_heads: usize) -> Self {
assert!(num_kv_heads > 0, "num_kv_heads must be > 0");
assert_eq!(
self.n_heads % num_kv_heads,
0,
"n_heads {} not divisible by num_kv_heads {num_kv_heads}",
self.n_heads
);
self.num_kv_heads = num_kv_heads;
self
}
/// Width of the K/V projections: `num_kv_heads * head_dim`. Equal to `dim` for
/// MHA and smaller than `dim` once GQA shrinks the KV head count.
pub fn kv_dim(&self) -> usize {
    let (kv_heads, per_head) = (self.num_kv_heads, self.head_dim);
    kv_heads * per_head
}
/// Transformer-core parameter count: everything except the token embedding and
/// the LM head (the two `vocab × dim` tables). This is the figure the scaling
/// ladder is sized against — the 50257-vocab embed+lm_head adds a fixed ~25M on
@@ -82,7 +112,8 @@ impl Config {
pub fn num_params(&self) -> usize {
let per_layer = 2 * self.dim // 2 rmsnorm gammas
+ 2 * self.head_dim // q/k per-head norm gammas
+ 3 * self.dim * self.dim // q/k/v proj
+ self.dim * self.dim // q proj [dim,dim]
+ 2 * self.dim * self.kv_dim() // k/v proj [dim,kv_dim] (GQA: smaller)
+ self.dim * self.dim // out proj
+ 2 * self.dim * self.ffn_hidden // gate/up proj
+ self.ffn_hidden * self.dim; // down proj

View File

@@ -13,8 +13,8 @@ use xtrain_tensor::{DType, Device, Tensor};
struct Block {
attn_norm: Var, // [dim]
wq: Var, // [dim, dim]
wk: Var, // [dim, dim]
wv: Var, // [dim, dim]
wk: Var, // [dim, kv_dim] — kv_dim = num_kv_heads·head_dim (GQA; = dim for MHA)
wv: Var, // [dim, kv_dim]
q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style)
k_norm: Var, // [head_dim]
wo: Var, // [dim, dim]
@@ -91,8 +91,9 @@ impl TinyTransformer {
.map(|_| Block {
attn_norm: mk(&[cfg.dim]),
wq: mk(&[cfg.dim, cfg.dim]),
wk: mk(&[cfg.dim, cfg.dim]),
wv: mk(&[cfg.dim, cfg.dim]),
// GQA (T15): K/V project to num_kv_heads·head_dim (= dim when MHA).
wk: mk(&[cfg.dim, cfg.kv_dim()]),
wv: mk(&[cfg.dim, cfg.kv_dim()]),
q_norm: mk(&[cfg.head_dim]),
k_norm: mk(&[cfg.head_dim]),
wo: mk(&[cfg.dim, cfg.dim]),
@@ -435,37 +436,47 @@ fn attention(
wo: &Var,
) -> Var {
let (nh, hd) = (cfg.n_heads, cfg.head_dim);
let num_kv = cfg.num_kv_heads; // GQA (T15): K/V have fewer heads than Q
let total = batch * seq;
let bh = batch * nh;
let scale = 1.0 / (hd as f32).sqrt();
// Project, qk-norm + RoPE, then lay out as a batched [B*nh, seq, hd] tensor.
// [B*S,dim] @ [dim,dim] = [B*S,dim]
// reshape [B*S, nh, hd]
// Project, qk-norm + RoPE, then lay out as a batched [B*heads, seq, hd] tensor.
// `heads` = nh for Q, num_kv for K/V (GQA; equal for MHA).
// [B*S,dim] @ [dim,heads*hd] = [B*S, heads*hd]
// reshape [B*S, heads, hd]
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
// rope [B*S, nh, hd] with per-sequence position (period = seq)
// reshape [B, S, nh, hd] → transpose(1,2) → [B, nh, S, hd] → [B*nh, S, hd]
let to_bh = |proj: Var, norm: Option<&Var>| -> Var {
let r = ops::reshape(&proj, &[total, nh, hd]);
// rope [B*S, heads, hd] with per-sequence position (period = seq)
// reshape [B, S, heads, hd] → transpose(1,2) → [B, heads, S, hd] → [B*heads, S, hd]
let to_bh = |proj: Var, heads: usize, norm: Option<&Var>| -> Var {
let r = ops::reshape(&proj, &[total, heads, hd]);
let r = match norm {
// Per-head RMSNorm: flatten the (B*S,nh) head rows, norm over hd,
// Per-head RMSNorm: flatten the (B*S,heads) head rows, norm over hd,
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
Some(gamma) => {
let flat = ops::reshape(&r, &[total * nh, hd]);
let flat = ops::reshape(&r, &[total * heads, hd]);
let normed = ops::rms_norm(&flat, &norm_gamma(cdt, gamma), cfg.eps);
let r = ops::reshape(&normed, &[total, nh, hd]);
let r = ops::reshape(&normed, &[total, heads, hd]);
ops::rope(&r, cfg.rope_theta, seq)
}
None => r,
};
let r = ops::reshape(&r, &[batch, seq, nh, hd]);
let t = ops::transpose_4d12(&r); // [B, nh, S, hd]
ops::reshape(&t, &[bh, seq, hd]) // [B*nh, S, hd]
let r = ops::reshape(&r, &[batch, seq, heads, hd]);
let t = ops::transpose_4d12(&r); // [B, heads, S, hd]
ops::reshape(&t, &[batch * heads, seq, hd]) // [B*heads, S, hd]
};
let q = to_bh(linear(cdt, x, wq), Some(q_norm));
let k = to_bh(linear(cdt, x, wk), Some(k_norm));
let v = to_bh(linear(cdt, x, wv), None);
let q = to_bh(linear(cdt, x, wq), nh, Some(q_norm));
// K/V are laid out with num_kv heads, then repeat_kv-broadcast to nh heads so
// the SDPA below (composed or flash, both unchanged) sees a full head set. The
// broadcast's backward sums each KV head's group of query-head grads (GQA). For
// MHA (num_kv == nh) repeat_kv is identity → bit-identical to the pre-T15 path.
let k = to_bh(linear(cdt, x, wk), num_kv, Some(k_norm));
let v = to_bh(linear(cdt, x, wv), num_kv, None);
let (k, v) = if num_kv == nh {
(k, v)
} else {
(ops::repeat_kv(&k, nh, batch), ops::repeat_kv(&v, nh, batch))
};
// Causal SDPA over all B*nh (sequence,head) blocks. `flash` (T14) picks the
// single fused flash kernel (online softmax, no materialized [bh,S,S] scores);

View File

@@ -0,0 +1,268 @@
// T15 GQA correctness gate. Real grouped-query attention (num_kv_heads <
// num_heads): K/V project to num_kv_heads·head_dim and are repeat_kv-broadcast to
// the full head set before the SDPA. This test pins three things:
//
// 1. GQA flash == GQA composed (forward logits + loss + EVERY param grad) — the
// repeat_kv broadcast feeds both SDPA paths unchanged, so they must agree; in
// particular the wk/wv grads (which flow back through repeat_kv's group-sum)
// must match. Parameterised over fp32 (tight) and bf16 (rounding band).
// 2. group==1 (num_kv_heads == n_heads) is BIT-IDENTICAL to the pre-T15 MHA path
// (a model with num_kv_heads explicitly == n_heads vs the default config):
// forward logits + every grad |Δ|=0. The regression guard.
// 3. wk/wv really shrank to [dim, kv_dim] under GQA (shape check).
#![cfg(not(no_cuda))]
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, batched_ids_tensor};
use xtrain_tensor::{DType, Device};
/// Deterministic pseudo-random fill: `n` values in `[-scale, scale]`.
///
/// The seed is scrambled once, then a 64-bit linear congruential generator is
/// stepped per element; the bits above bit 32 are mapped to `[0, 1]`, centred on
/// zero and scaled.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // 31 high bits → unit interval, then recentre to [-1, 1] and scale.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
/// Construct a `TinyTransformer` with deterministic weights, then apply the
/// requested compute dtype and flash setting.
///
/// A seed counter is bumped before every parameter draw, so each tensor gets its
/// own `fill` stream and two calls with the same `cfg` yield identical weights.
fn build(cfg: Config, device: Device, dtype: DType, flash: bool) -> TinyTransformer {
    let mut next_seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        next_seed = next_seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        match shape.len() {
            // 1-D parameters are drawn tightly around 1.0 (small noise + 1).
            1 => fill(numel, next_seed, 0.02)
                .into_iter()
                .map(|v| v + 1.0)
                .collect(),
            // Everything else is zero-centred with a wider spread.
            _ => fill(numel, next_seed, 0.08),
        }
    });
    model.with_compute_dtype(dtype).with_flash(flash)
}
/// Copy a tensor to the host as a flat `Vec<f32>`, converting to F32 first so
/// callers can compare values regardless of the tensor's dtype.
fn host(t: &xtrain_tensor::Tensor) -> Vec<f32> {
    let on_cpu = t.to_dtype(DType::F32).to_device(Device::Cpu);
    on_cpu.as_slice::<f32>().to_vec()
}
// The GQA configuration under test: 8 query heads sharing 2 kv heads (group 4),
// vocab 16, 3 layers. The tests run it at seq=40 > FA_TILE=32 so the flash
// online-softmax tile path is exercised too.
fn gqa_cfg() -> Config {
    // Only head_dim and ffn_hidden are inherited from tiny(); the head counts,
    // vocab and depth are set explicitly here.
    let tiny = Config::tiny();
    Config::from_arch(16, 8, tiny.head_dim, 3, tiny.ffn_hidden).with_kv_heads(2)
}
/// Deterministic `(input ids, target ids)` for `batch` sequences of length `seq`.
/// Every id is reduced modulo `cfg.vocab`, so all values are valid token indices.
fn ids_targets(cfg: &Config, batch: usize, seq: usize) -> (Vec<Vec<i32>>, Vec<Vec<i32>>) {
    let vocab = cfg.vocab;
    // id(b, i) = (b·row_step + i·col_step + offset) mod vocab
    let pattern = |row_step: usize, col_step: usize, offset: usize| -> Vec<Vec<i32>> {
        (0..batch)
            .map(|b| {
                (0..seq)
                    .map(|i| ((b * row_step + i * col_step + offset) % vocab) as i32)
                    .collect()
            })
            .collect()
    };
    (pattern(7, 3, 1), pattern(5, 2, 2))
}
/// Run one forward + loss + backward pass with flash off and one with flash on,
/// using identically initialised models and the same batch (3 sequences × 40
/// tokens).
///
/// Returns `(off_logits, off_loss, off_grads, on_logits, on_loss, on_grads)`,
/// all copied to the host as f32; the grads are in `params()` order.
#[allow(clippy::type_complexity)]
fn run_both(
    cfg: Config,
    dtype: DType,
) -> (Vec<f32>, f32, Vec<Vec<f32>>, Vec<f32>, f32, Vec<Vec<f32>>) {
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let (batch, seq) = (3usize, 40usize);
    let (seqs, tgts) = ids_targets(&cfg, batch, seq);
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    // One full pass for a given flash setting; `missing_grad` is the panic text
    // used if any parameter ends up without a gradient.
    let pass = |flash: bool, missing_grad: &str| -> (Vec<f32>, f32, Vec<Vec<f32>>) {
        let model = build(cfg, device, dtype, flash);
        let logits = host(&model.forward_batched(&ids, batch).value());
        let loss = model.loss_batched(&ids, &tgt, batch);
        let loss_val = host(&loss.value())[0];
        loss.backward();
        let grads = model
            .params()
            .iter()
            .map(|p| host(&p.grad().expect(missing_grad)))
            .collect();
        (logits, loss_val, grads)
    };

    let (off_logits, off_loss_val, off_grads) = pass(false, "off grad");
    let (on_logits, on_loss_val, on_grads) = pass(true, "on grad");
    (
        off_logits,
        off_loss_val,
        off_grads,
        on_logits,
        on_loss_val,
        on_grads,
    )
}
// GQA flash vs composed: same SDPA math on the same repeat_kv-broadcast K/V → fp32
// agrees to reduction-order, bf16 to its rounding band.
//
// The max-relative-error metrics below are folded with `f32::max`, which returns
// the non-NaN operand — a NaN logit or grad would therefore leave the metric at 0
// and pass. Likewise `zip` silently stops at the shorter side. Both are guarded
// explicitly before any comparison.
#[test]
fn gqa_flash_matches_composed_fp32() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    let cfg = gqa_cfg();
    assert!(cfg.num_kv_heads < cfg.n_heads, "test must be real GQA");
    let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(cfg, DType::F32);

    assert_eq!(
        off_l.len(),
        on_l.len(),
        "[GQA F32] logits length mismatch"
    );
    assert!(
        off_l.iter().chain(&on_l).all(|v| v.is_finite()),
        "[GQA F32] non-finite logits"
    );
    let logit_rel = off_l
        .iter()
        .zip(&on_l)
        .map(|(a, b)| (a - b).abs() / a.abs().max(1e-4))
        .fold(0.0f32, f32::max);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!(
        "[GQA F32] flash on/off: loss {off_loss:.6}/{on_loss:.6} (rel {loss_rel:.2e}), \
         logits max rel {logit_rel:.2e}"
    );
    assert!(
        logit_rel < 1e-3,
        "[GQA F32] logits diverged: {logit_rel:.2e}"
    );
    assert!(loss_rel < 1e-3, "[GQA F32] loss diverged: {loss_rel:.2e}");

    assert_eq!(
        off_g.len(),
        on_g.len(),
        "[GQA F32] param-grad count mismatch"
    );
    let mut worst = 0.0f32;
    for (a_g, b_g) in off_g.iter().zip(&on_g) {
        assert_eq!(a_g.len(), b_g.len(), "[GQA F32] grad length mismatch");
        assert!(
            a_g.iter().chain(b_g).all(|v| v.is_finite()),
            "[GQA F32] non-finite grad"
        );
        for (a, b) in a_g.iter().zip(b_g) {
            worst = worst.max((a - b).abs() / a.abs().max(1e-3));
        }
    }
    println!("[GQA F32] flash on/off grad max rel = {worst:.3e}");
    assert!(worst < 2e-2, "[GQA F32] grads diverged: {worst:.3e}");
}
// bf16 variant of the flash-vs-composed check: loss within 2e-2 relative, logits
// judged by mean/p99 relative error, grads by a per-tensor scaled mean error.
//
// `f32::max` returns the non-NaN operand, so a NaN grad would be dropped from
// `worst` and the test would pass; `zip` also truncates to the shorter side.
// Finite-ness and lengths are therefore asserted before the metrics are computed.
#[test]
fn gqa_flash_matches_composed_bf16() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    let (off_l, off_loss, off_g, on_l, on_loss, on_g) = run_both(gqa_cfg(), DType::BF16);
    let loss_rel = (off_loss - on_loss).abs() / off_loss.abs().max(1e-4);
    println!("[GQA BF16] flash on/off: loss {off_loss:.5}/{on_loss:.5} (rel {loss_rel:.3e})");
    assert!(loss_rel < 2e-2, "[GQA BF16] loss diverged: {loss_rel:.3e}");

    assert!(!off_l.is_empty(), "[GQA BF16] no logits produced");
    assert_eq!(
        off_l.len(),
        on_l.len(),
        "[GQA BF16] logits length mismatch"
    );
    assert!(
        off_l.iter().chain(&on_l).all(|v| v.is_finite()),
        "[GQA BF16] non-finite logits"
    );
    let n = off_l.len();
    let mut rels: Vec<f32> = off_l
        .iter()
        .zip(&on_l)
        .map(|(f, b)| (b - f).abs() / f.abs().max(1.0))
        .collect();
    // Inputs are finite, so no NaN can appear here; total_cmp avoids the unwrap.
    rels.sort_by(f32::total_cmp);
    let mean: f32 = rels.iter().sum::<f32>() / n as f32;
    let p99 = rels[(n as f32 * 0.99) as usize];
    println!("[GQA BF16] logits: mean rel {mean:.3e}, p99 rel {p99:.3e}");
    assert!(
        mean < 1e-2,
        "[GQA BF16] logits mean rel too high: {mean:.3e}"
    );
    assert!(p99 < 5e-2, "[GQA BF16] logits p99 rel too high: {p99:.3e}");

    assert_eq!(
        off_g.len(),
        on_g.len(),
        "[GQA BF16] param-grad count mismatch"
    );
    let mut worst = 0.0f32;
    for (a_g, b_g) in off_g.iter().zip(&on_g) {
        assert_eq!(a_g.len(), b_g.len(), "[GQA BF16] grad length mismatch");
        assert!(
            a_g.iter().chain(b_g).all(|v| v.is_finite()),
            "[GQA BF16] non-finite grad"
        );
        // Normalise by the tensor's largest |grad| so tensors of different
        // magnitude are comparable.
        let scale = a_g.iter().map(|v| v.abs()).fold(0.0f32, f32::max).max(1e-6);
        let mean_err: f32 =
            a_g.iter().zip(b_g).map(|(f, b)| (f - b).abs()).sum::<f32>() / a_g.len() as f32 / scale;
        worst = worst.max(mean_err);
    }
    println!("[GQA BF16] grads: worst per-tensor scaled-mean err = {worst:.3e}");
    assert!(worst < 3e-2, "[GQA BF16] grads diverged: {worst:.3e}");
}
// REGRESSION GUARD: with num_kv_heads == n_heads (group 1) the model must be
// BIT-IDENTICAL to the pre-T15 MHA path. One model uses the default config
// (num_kv_heads == n_heads, so repeat_kv is never invoked); the other sets
// num_kv_heads = n_heads explicitly via with_kv_heads. Both run the composed path,
// so logits, loss and every grad are compared for exact equality, not a tolerance.
#[test]
fn gqa_group1_bit_identical_to_mha() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // vocab 16, 4 heads, 3 layers; head_dim and ffn_hidden come from tiny().
    let tiny = Config::tiny();
    let base = Config::from_arch(16, 4, tiny.head_dim, 3, tiny.ffn_hidden);
    // Already the default value, but this routes through the with_kv_heads
    // builder; the two configs are equal and must behave identically.
    let explicit = base.with_kv_heads(base.n_heads);
    assert_eq!(base.num_kv_heads, explicit.num_kv_heads);

    let (batch, seq) = (2usize, 8usize);
    let (seqs, tgts) = ids_targets(&base, batch, seq);
    let ids = batched_ids_tensor(&seqs, device);
    let tgt = batched_ids_tensor(&tgts, device);

    // Forward + loss + backward on a freshly built composed-path F32 model.
    let run = |cfg: Config| -> (Vec<f32>, f32, Vec<Vec<f32>>) {
        let model = build(cfg, device, DType::F32, false);
        let logits = host(&model.forward_batched(&ids, batch).value());
        let loss = model.loss_batched(&ids, &tgt, batch);
        let loss_val = host(&loss.value())[0];
        loss.backward();
        let grads = model
            .params()
            .iter()
            .map(|p| host(&p.grad().unwrap()))
            .collect();
        (logits, loss_val, grads)
    };

    let (mha_logits, mha_loss, mha_grads) = run(base);
    let (g1_logits, g1_loss, g1_grads) = run(explicit);
    assert_eq!(
        mha_logits, g1_logits,
        "group-1 logits must be bit-identical to MHA"
    );
    assert_eq!(
        mha_loss, g1_loss,
        "group-1 loss must be bit-identical to MHA"
    );
    for (mha, g1) in mha_grads.iter().zip(&g1_grads) {
        assert_eq!(mha, g1, "group-1 grad must be bit-identical to MHA");
    }
    println!("[GQA group1] bit-identical to MHA: logits + loss + all grads |Δ|=0");
}
// Shape check: under GQA the K/V projections shrink to [dim, kv_dim]
// (kv_dim = num_kv_heads·head_dim) while the Q projection stays [dim, dim].
#[test]
fn gqa_kv_proj_shape() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);
    let cfg = gqa_cfg();
    let model = build(cfg, device, DType::F32, false);
    let params = model.params();
    // params order: embed, then per block [attn_norm, wq, wk, wv, q_norm, k_norm, wo, ...]
    // → indices 1..=3 are the first block's wq, wk, wv.
    let shape_of = |idx: usize| params[idx].value().shape().to_vec();
    let (wq, wk, wv) = (shape_of(1), shape_of(2), shape_of(3));
    let (dim, kv_dim) = (cfg.dim, cfg.kv_dim());
    assert_eq!(wq, vec![dim, dim], "wq must be [dim,dim]");
    assert_eq!(wk, vec![dim, kv_dim], "wk must be [dim,kv_dim]");
    assert_eq!(wv, vec![dim, kv_dim], "wv must be [dim,kv_dim]");
    println!(
        "[GQA shapes] wq {:?} wk {:?} wv {:?} (kv_dim {})",
        wq,
        wk,
        wv,
        cfg.kv_dim()
    );
}

View File

@@ -52,6 +52,10 @@ cfg = read_cfg()
DIM = int(cfg["dim"])
NL = int(cfg["n_layers"])
NH = int(cfg["n_heads"])
# GQA (T15): num_kv_heads <= n_heads; each kv head shared by group query heads.
# Default to NH (MHA) for fixtures dumped before the field existed.
NKV = int(cfg.get("num_kv_heads", str(NH)))
GROUP = NH // NKV
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])
@@ -114,17 +118,23 @@ for L in layers:
# Attention
x = rms_norm(h, L["attn_norm"])
q = (x @ L["wq"]).reshape(B * SEQ, NH, HD)
k = (x @ L["wk"]).reshape(B * SEQ, NH, HD)
v = (x @ L["wv"]).reshape(B * SEQ, NH, HD)
# GQA: K/V project to NKV heads, then repeat each kv head GROUP times to NH.
k = (x @ L["wk"]).reshape(B * SEQ, NKV, HD)
v = (x @ L["wv"]).reshape(B * SEQ, NKV, HD)
# Per-head QK-norm (Qwen3-style), before RoPE.
q = rms_norm(q, L["q_norm"])
k = rms_norm(k, L["k_norm"])
q = rope(q) # [B*SEQ, nh, hd]
k = rope(k)
# Reshape to [B, NH, SEQ, HD] so attention runs within each sequence.
k = rope(k) # [B*SEQ, nkv, hd]
# Reshape to [B, *, SEQ, HD]; broadcast kv heads to NH (repeat_interleave along
# the head axis: kv head kvh → query heads [kvh*GROUP, (kvh+1)*GROUP), matching
# xtrain repeat_kv + xserv repeat_kv).
q = q.reshape(B, SEQ, NH, HD).transpose(1, 2) # [B, nh, seq, hd]
k = k.reshape(B, SEQ, NH, HD).transpose(1, 2)
v = v.reshape(B, SEQ, NH, HD).transpose(1, 2)
k = k.reshape(B, SEQ, NKV, HD).transpose(1, 2) # [B, nkv, seq, hd]
v = v.reshape(B, SEQ, NKV, HD).transpose(1, 2)
if GROUP > 1:
k = k.repeat_interleave(GROUP, dim=1) # [B, nh, seq, hd]
v = v.repeat_interleave(GROUP, dim=1)
scale = 1.0 / math.sqrt(HD)
scores = (q @ k.transpose(-1, -2)) * scale + mask # [B, nh, seq, seq]
probs = torch.softmax(scores, dim=-1)

View File

@@ -58,8 +58,20 @@ fn dump_for_parity() {
// sequence-major to [B*S]=8 ids. Per-sequence RoPE position (resets at the
// sequence boundary) + per-sequence causal masking (no cross-sequence
// attention) are both checked against PyTorch.
// Default: tiny MHA (2 heads). With XTRAIN_PARITY_KV_HEADS=k set, dump a real
// GQA config (8 query heads / k kv heads) so parity.py checks GQA at B>1 — the
// kv-projection shapes + the repeat_kv group-sum backward against PyTorch.
let mut cfg = Config::tiny();
cfg.vocab = 12;
if let Ok(kv) = std::env::var("XTRAIN_PARITY_KV_HEADS") {
let kv: usize = kv.parse().expect("XTRAIN_PARITY_KV_HEADS");
cfg = Config::from_arch(cfg.vocab, 8, cfg.head_dim, cfg.n_layers, cfg.ffn_hidden)
.with_kv_heads(kv);
println!(
"parity: GQA config (n_heads {} kv_heads {})",
cfg.n_heads, cfg.num_kv_heads
);
}
let batch = 2usize;
let seq = 4usize;
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6]; // [B*S], sequence-major
@@ -92,6 +104,7 @@ fn dump_for_parity() {
writeln!(f, "dim {}", cfg.dim).unwrap();
writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
writeln!(f, "num_kv_heads {}", cfg.num_kv_heads).unwrap();
writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
writeln!(f, "eps {:e}", cfg.eps).unwrap();

View File

@@ -1291,6 +1291,97 @@ impl Tensor {
(dq.to_dtype(dt), dk.to_dtype(dt), dv.to_dtype(dt))
}
// --- GQA repeat_kv head broadcast (the T15 op) ---
/// GQA repeat_kv (Phase T15). `self` is a K or V tensor `[batch·num_kv, seq,
/// head_dim]`; returns `[batch·nh, seq, head_dim]` with each KV head broadcast
/// to its `group = nh/num_kv` query heads (query head `qh` ← kv head `qh/group`,
/// query heads contiguous within a group — matching xserv's repeat_kv). When
/// `nh == num_kv` (group 1) this is a plain copy (identity → MHA bit-identical).
/// bf16 callers are upcast→f32→kernel→downcast (the attention stack's fp32 policy).
///
/// # Panics
/// - `self` is not 3-D, or its dtype is neither F32 nor BF16;
/// - `batch` is 0, does not divide `shape[0]`, or leaves zero kv heads;
/// - `nh` is not a multiple of the recovered `num_kv`;
/// - any extent handed to the kernel does not fit in `i32`.
#[cfg(not(no_cuda))]
pub fn repeat_kv(&self, nh: usize, batch: usize) -> Self {
    assert_eq!(
        self.ndim(),
        3,
        "repeat_kv input must be [batch·num_kv,seq,hd]"
    );
    let (bh_kv, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
    // Reject zero divisors explicitly instead of hitting a bare arithmetic panic.
    assert!(batch > 0, "repeat_kv: batch must be > 0");
    assert_eq!(
        bh_kv % batch,
        0,
        "repeat_kv: bh_kv {bh_kv} not div by batch {batch}"
    );
    let num_kv = bh_kv / batch;
    assert!(
        num_kv > 0,
        "repeat_kv: input has no kv heads (bh_kv {bh_kv}, batch {batch})"
    );
    assert_eq!(
        nh % num_kv,
        0,
        "repeat_kv: nh {nh} not divisible by num_kv {num_kv}"
    );
    if self.dtype == DType::BF16 {
        return self
            .to_dtype(DType::F32)
            .repeat_kv(nh, batch)
            .to_dtype(DType::BF16);
    }
    assert_eq!(self.dtype, DType::F32, "repeat_kv supports F32/BF16");

    // The launcher takes i32 extents; a truncating `as` cast would silently hand
    // the kernel wrong sizes, so narrow with a check.
    // NOTE(review): the kernel's own index arithmetic over batch·nh·seq·hd is not
    // visible here — confirm it cannot overflow for large tensors.
    let to_i32 = |v: usize, what: &str| -> i32 {
        i32::try_from(v)
            .unwrap_or_else(|_| panic!("repeat_kv: {what} {v} does not fit in i32"))
    };
    let (batch_i, nh_i, num_kv_i, seq_i, hd_i) = (
        to_i32(batch, "batch"),
        to_i32(nh, "nh"),
        to_i32(num_kv, "num_kv"),
        to_i32(seq, "seq"),
        to_i32(hd, "hd"),
    );

    let out = Tensor::zeros(&[batch * nh, seq, hd], DType::F32, self.device());
    // SAFETY: `self` is F32 with shape [batch·num_kv, seq, hd] (asserted above)
    // and `out` was just allocated as F32 [batch·nh, seq, hd] on the same device,
    // matching the extents passed to the launcher.
    // NOTE(review): assumes both buffers are contiguous and that a null stream
    // selects the default stream — confirm against the launcher.
    unsafe {
        xtrain_cuda::ffi::launch_repeat_kv_fwd_f32(
            self.data_ptr() as *const f32,
            out.data_ptr() as *mut f32,
            batch_i,
            nh_i,
            num_kv_i,
            seq_i,
            hd_i,
            std::ptr::null_mut(),
        );
    }
    out
}
/// Backward of [`repeat_kv`](Self::repeat_kv). `dout`:[batch·nh,seq,hd] →
/// `din`:[batch·num_kv,seq,hd], summing the `group` query heads that share each
/// KV head (the multi-group grad accumulation GQA correctness hinges on).
/// Deterministic (no atomics: each kv element serially sums its group rows).
/// BF16 `dout` is upcast to F32 for the kernel and the result cast back.
///
/// # Panics
/// - `dout` is not 3-D, or its dtype is neither F32 nor BF16;
/// - `batch` or `num_kv` is 0, or `batch` does not divide `dout.shape[0]`;
/// - the recovered `nh` is not a multiple of `num_kv`;
/// - any extent handed to the kernel does not fit in `i32`.
#[cfg(not(no_cuda))]
pub fn repeat_kv_backward(dout: &Tensor, num_kv: usize, batch: usize) -> Self {
    assert_eq!(
        dout.ndim(),
        3,
        "repeat_kv_backward dout must be [batch·nh,seq,hd]"
    );
    let (bh_q, seq, hd) = (dout.shape[0], dout.shape[1], dout.shape[2]);
    // Reject zero divisors explicitly instead of hitting a bare arithmetic panic.
    assert!(batch > 0, "repeat_kv_backward: batch must be > 0");
    assert!(num_kv > 0, "repeat_kv_backward: num_kv must be > 0");
    assert_eq!(bh_q % batch, 0, "repeat_kv_backward: bh_q not div by batch");
    let nh = bh_q / batch;
    assert_eq!(
        nh % num_kv,
        0,
        "repeat_kv_backward: nh not divisible by num_kv"
    );
    let dt = dout.dtype;
    if dt == DType::BF16 {
        return Tensor::repeat_kv_backward(&dout.to_dtype(DType::F32), num_kv, batch)
            .to_dtype(DType::BF16);
    }
    assert_eq!(dt, DType::F32, "repeat_kv_backward supports F32/BF16");

    // The launcher takes i32 extents; narrow with a check rather than a
    // truncating `as` cast.
    // NOTE(review): the kernel's own index arithmetic over batch·nh·seq·hd is not
    // visible here — confirm it cannot overflow for large tensors.
    let to_i32 = |v: usize, what: &str| -> i32 {
        i32::try_from(v)
            .unwrap_or_else(|_| panic!("repeat_kv_backward: {what} {v} does not fit in i32"))
    };
    let (batch_i, nh_i, num_kv_i, seq_i, hd_i) = (
        to_i32(batch, "batch"),
        to_i32(nh, "nh"),
        to_i32(num_kv, "num_kv"),
        to_i32(seq, "seq"),
        to_i32(hd, "hd"),
    );

    let din = Tensor::zeros(&[batch * num_kv, seq, hd], DType::F32, dout.device());
    // SAFETY: `dout` is F32 with shape [batch·nh, seq, hd] (asserted above) and
    // `din` was just allocated as F32 [batch·num_kv, seq, hd] on the same device,
    // matching the extents passed to the launcher.
    // NOTE(review): assumes both buffers are contiguous and that a null stream
    // selects the default stream — confirm against the launcher.
    unsafe {
        xtrain_cuda::ffi::launch_repeat_kv_bwd_f32(
            dout.data_ptr() as *const f32,
            din.data_ptr() as *mut f32,
            batch_i,
            nh_i,
            num_kv_i,
            seq_i,
            hd_i,
            std::ptr::null_mut(),
        );
    }
    din
}
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).

View File

@@ -174,7 +174,7 @@ fn config_json(cfg: &Config) -> String {
ffn = cfg.ffn_hidden,
layers = cfg.n_layers,
heads = cfg.n_heads,
kv_heads = cfg.n_heads, // xtrain is MHA → kv heads == query heads
kv_heads = cfg.num_kv_heads, // GQA (T15): real num_key_value_heads (= n_heads for MHA)
head_dim = cfg.head_dim,
eps = cfg.eps,
theta = cfg.rope_theta,
@@ -206,6 +206,8 @@ fn main() {
let head_dim = flag(&args, "--head-dim", 16usize);
let n_layers = flag(&args, "--layers", 4usize);
let ffn = flag(&args, "--ffn", 64usize);
// GQA (Phase T15): num K/V heads (must match the trained ckpt; default = --heads).
let kv_heads = flag(&args, "--kv-heads", n_heads);
assert!(device::device_count().unwrap() > 0, "no CUDA device");
device::set_device(0).unwrap();
@@ -214,15 +216,16 @@ fn main() {
// Size the model from the arch flags + gpt2 vocab; must match the checkpoint.
let tok = Tokenizer::from_file(&tok_path);
let vocab = tok.vocab_size();
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
let cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
println!(
"export: ckpt {}{} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
"export: ckpt {}{} (vocab {}, dim {}, layers {}, heads {}, kv_heads {}, head_dim {})",
ckpt.display(),
out_dir.display(),
cfg.vocab,
cfg.dim,
cfg.n_layers,
cfg.n_heads,
cfg.num_kv_heads,
cfg.head_dim,
);

View File

@@ -72,6 +72,8 @@ fn main() {
let head_dim = flag(&args, "--head-dim", 16usize);
let n_layers = flag(&args, "--layers", 4usize);
let ffn = flag(&args, "--ffn", 64usize);
// GQA (Phase T15): num K/V heads (must match the ckpt; default = --heads).
let kv_heads = flag(&args, "--kv-heads", n_heads);
let max_new = flag(&args, "--max-tokens", 40usize);
assert!(device::device_count().unwrap() > 0, "no CUDA device");
@@ -79,14 +81,16 @@ fn main() {
let device = Device::Cuda(0);
let tok = Tokenizer::from_file(&tok_path);
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn);
let cfg = Config::from_arch(tok.vocab_size(), n_heads, head_dim, n_layers, ffn)
.with_kv_heads(kv_heads);
println!(
"greedy_sample: ckpt {} (vocab {}, dim {}, layers {}, heads {}, head_dim {})",
"greedy_sample: ckpt {} (vocab {}, dim {}, layers {}, heads {}, kv_heads {}, head_dim {})",
ckpt.display(),
cfg.vocab,
cfg.dim,
cfg.n_layers,
cfg.n_heads,
cfg.num_kv_heads,
cfg.head_dim,
);

View File

@@ -88,6 +88,8 @@ fn main() {
let head_dim = flag(&args, "--head-dim", 16usize);
let n_layers = flag(&args, "--layers", 4usize);
let ffn = flag(&args, "--ffn", 64usize);
// GQA (Phase T15): num K/V heads (must divide --heads). Default = --heads (MHA).
let kv_heads = flag(&args, "--kv-heads", n_heads);
// `--dim` is informational; dim is always n_heads*head_dim. Warn on mismatch.
let dim_flag = flag(&args, "--dim", 0usize);
if dim_flag != 0 && dim_flag != n_heads * head_dim {
@@ -160,14 +162,16 @@ fn main() {
(corpus, None)
};
let mut cfg = Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn);
let mut cfg =
Config::from_arch(vocab, n_heads, head_dim, n_layers, ffn).with_kv_heads(kv_heads);
cfg.dropout = dropout;
println!(
"model: dim {} layers {} heads {} head_dim {} ffn {} → core {:.3}M params \
"model: dim {} layers {} heads {} kv_heads {} head_dim {} ffn {} → core {:.3}M params \
(+ embed/lm {:.2}M = {:.2}M total)",
cfg.dim,
cfg.n_layers,
cfg.n_heads,
cfg.num_kv_heads,
cfg.head_dim,
cfg.ffn_hidden,
cfg.core_params() as f32 / 1e6,

84
csrc/ops/repeat_kv.cu Normal file
View File

@@ -0,0 +1,84 @@
// repeat_kv: the grouped-query-attention (GQA) head broadcast (Phase T15).
//
// GQA projects K/V to fewer heads than Q (num_kv_heads < num_heads); each KV head
// is SHARED by a group of `group = num_heads / num_kv_heads` query heads. Before
// the SDPA (composed or fused-flash, both untouched), we expand the KV tensor from
// [B·num_kv, S, hd] to the full [B·nh, S, hd] so the existing per-(batch,head)
// attention sees a full set of heads. GQA is then "free" for both SDPA paths.
//
// Layout: K/V are [bh_kv, S, hd] = [B·num_kv, S, hd] row-major, contiguous; the
// output is [bh_q, S, hd] = [B·nh, S, hd]. The head ordering matches xserv's
// repeat_kv (crates/xserv-model/src/qwen3.rs): query head qh reads kv head
// qh/group, query heads CONTIGUOUS within a group (dst = kvh*group + r). So:
//
// out[b·nh + qh, :, :] = in[b·num_kv + qh/group, :, :]
//
// Forward is a gather (each output row copies one input row). Backward is its
// transpose: a kv head receives the SUM of the `group` query heads that share it
// din[b·num_kv + kvh, e] = Σ_{r∈[0,group)} dout[b·nh + kvh·group + r, e]
// — the multi-group-to-one grad accumulation GQA's correctness hinges on. Each
// input element is owned by exactly one thread that serially sums its `group`
// source rows: race-free, NO atomics, run-to-run DETERMINISTIC. group==1 makes
// both directions a plain copy (identity → bit-identical to the MHA path).
//
// All F32, contiguous. (bf16 callers upcast → f32 on the Rust side and downcast
// the f32 result, mirroring the rest of the attention stack's fp32 policy.)
#include <math.h>
extern "C" {
// Forward gather. One block per output row (batch, query-head): blockIdx.x ranges
// over B·nh, and the block's threads step through that row's S·hd elements in
// strides of blockDim.x, copying each one from the kv-head source row.
//   b = out_bh / nh,  qh = out_bh % nh,  source row = b·num_kv + qh/group.
//
// Parameters:
//   in      source K or V, addressed as [B·num_kv] rows of S·hd floats
//   out     destination, addressed as [B·nh] rows of S·hd floats
//   nh      query heads per batch element
//   num_kv  kv heads per batch element
//   group   query heads sharing one kv head (the launcher passes nh / num_kv)
//   seq,hd  sequence length and head dim; seq·hd is the size of one row
//
// Offsets are `long long`: plain `long` is only 32 bits on LLP64 hosts (Windows),
// where bh·S·hd can overflow it for large shapes.
__global__ void repeat_kv_fwd_k(const float* in, float* out, int nh, int num_kv,
                                int group, int seq, int hd) {
    const long long row_elems = (long long)seq * hd;  // S·hd per head block
    const int out_bh = blockIdx.x;                    // over B·nh
    const int b = out_bh / nh;
    const int qh = out_bh % nh;
    const int kvh = qh / group;                       // kv head shared by this query head
    const float* src = in + ((long long)b * num_kv + kvh) * row_elems;
    float* dst = out + (long long)out_bh * row_elems;
    for (long long e = threadIdx.x; e < row_elems; e += blockDim.x) dst[e] = src[e];
}
// Host launcher for repeat_kv_fwd_k: expands `in` ([batch·num_kv] rows of seq·hd
// floats) into `out` ([batch·nh] rows) on stream `s` (an opaque cudaStream_t).
// Launches one block per output row with 32..256 threads.
//
// Degenerate or inconsistent shapes are a no-op instead of an invalid launch:
//   - any non-positive dim: nothing to copy (also avoids nh / 0 and a zero-sized
//     grid, which CUDA rejects as an invalid configuration);
//   - nh not a multiple of num_kv: qh/group could reach num_kv and the kernel
//     would read past the batch's kv rows.
// NOTE(review): there is no error channel here; the caller is expected to have
// validated nh % num_kv == 0 before reaching this point.
void launch_repeat_kv_fwd_f32(const float* in, float* out, int batch, int nh,
                              int num_kv, int seq, int hd, void* s) {
    if (batch <= 0 || nh <= 0 || num_kv <= 0 || seq <= 0 || hd <= 0) return;
    if (nh % num_kv != 0) return;
    const int group = nh / num_kv;
    // 64-bit product: seq * hd is only used to size the block, but must not wrap.
    const long long row_elems = (long long)seq * hd;
    int blk = row_elems < 256 ? (int)row_elems : 256;
    blk = (blk + 31) / 32 * 32;  // round up to whole warps -> 32..256
    repeat_kv_fwd_k<<<batch * nh, blk, 0, (cudaStream_t)s>>>(in, out, nh, num_kv,
                                                             group, seq, hd);
}
// Backward sum (transpose of the forward gather). One block per input row
// (batch, kv-head): blockIdx.x ranges over B·num_kv, and the block's threads step
// through that row's S·hd elements in strides of blockDim.x. Each element is
// written by exactly one thread, which adds up the `group` contiguous query-head
// rows of `dout` that share this kv head — no atomics.
//
// Parameters:
//   dout    upstream gradient, addressed as [B·nh] rows of S·hd floats
//   din     gradient w.r.t. the un-repeated K/V, [B·num_kv] rows; OVERWRITTEN
//           (not accumulated into)
//   nh, num_kv, group, seq, hd   as in the forward kernel
//
// Offsets are `long long`: plain `long` is only 32 bits on LLP64 hosts (Windows),
// where bh·S·hd can overflow it for large shapes.
__global__ void repeat_kv_bwd_k(const float* dout, float* din, int nh, int num_kv,
                                int group, int seq, int hd) {
    const long long row_elems = (long long)seq * hd;
    const int in_bh = blockIdx.x;  // over B·num_kv
    const int b = in_bh / num_kv;
    const int kvh = in_bh % num_kv;
    const int qh0 = kvh * group;   // first query head sharing this kv head
    float* dst = din + (long long)in_bh * row_elems;
    const float* base = dout + ((long long)b * nh + qh0) * row_elems;
    for (long long e = threadIdx.x; e < row_elems; e += blockDim.x) {
        float acc = 0.0f;
        // Fixed r = 0..group-1 order: the float sum is identical on every run.
        for (int r = 0; r < group; ++r) acc += base[(long long)r * row_elems + e];
        dst[e] = acc;
    }
}
// Host launcher for repeat_kv_bwd_k: reduces `dout` ([batch·nh] rows of seq·hd
// floats) into `din` ([batch·num_kv] rows) on stream `s` (an opaque cudaStream_t).
// Launches one block per input (kv) row with 32..256 threads.
//
// Degenerate or inconsistent shapes are a no-op instead of an invalid launch:
//   - any non-positive dim: nothing to reduce (also avoids nh / 0 and a
//     zero-sized grid, which CUDA rejects as an invalid configuration);
//   - nh not a multiple of num_kv: group would truncate and the gradients of the
//     trailing query heads would be silently dropped.
// NOTE(review): there is no error channel here; the caller is expected to have
// validated nh % num_kv == 0 before reaching this point.
void launch_repeat_kv_bwd_f32(const float* dout, float* din, int batch, int nh,
                              int num_kv, int seq, int hd, void* s) {
    if (batch <= 0 || nh <= 0 || num_kv <= 0 || seq <= 0 || hd <= 0) return;
    if (nh % num_kv != 0) return;
    const int group = nh / num_kv;
    // 64-bit product: seq * hd is only used to size the block, but must not wrap.
    const long long row_elems = (long long)seq * hd;
    int blk = row_elems < 256 ? (int)row_elems : 256;
    blk = (blk + 31) / 32 * 32;  // round up to whole warps -> 32..256
    repeat_kv_bwd_k<<<batch * num_kv, blk, 0, (cudaStream_t)s>>>(dout, din, nh,
                                                                 num_kv, group,
                                                                 seq, hd);
}
} // extern "C"

167
docs/14-gqa.md Normal file
View File

@@ -0,0 +1,167 @@
# Phase T15: Grouped-Query Attention (GQA) — Design Document
## Goal
到 T14 为止，xtrain 的 attention 都是 **MHA**（`num_kv_heads = num_heads`）——每个
query 头有自己独立的 K/V 头。导出 xserv 时 `num_key_value_heads = num_attention_heads`
（退化 GQA，docs/08）。
T15 做**真正的 grouped-query attention**：`num_kv_heads < num_heads`，K/V 只投影到
`num_kv_heads · head_dim`，每个 KV 头被一组 `group = num_heads / num_kv_heads` 个 query
头**共享**（repeat_kv / broadcast）。GQA 是现代 LLM（Llama-2-70B、Qwen2/3、Mistral）的标配
——它把 KV cache 显存（推理）与 K/V 投影参数（训练）按 `group` 倍压缩，几乎不掉质量。
**硬闸门是诚实正确性**,重点在 **repeat_kv 的反向梯度累加**:一个 KV 头被 `group` 个 query 头
共享,反向时这 `group` 个 query 头各自对该 KV 头的梯度必须**正确求和**回到那一个共享 KV 头上。
这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。
GQA 必须**同时**接进 T14 的 fused flash kernel（优先）与旧 composed/batched SDPA 路径，且
`num_kv_heads == num_heads`（`group = 1`）时与现有 MHA 路径**逐位一致**（回归保护）。
## 什么是 GQA
MHA：`num_heads` 个 query 头，每个配一个独立 K/V 头。
MQA（multi-query）：所有 query 头共享**一个** K/V 头（极端）。
GQA：折中——`num_kv_heads` 个 K/V 头，每个被 `group = num_heads/num_kv_heads` 个相邻 query
头共享。`num_kv_heads = num_heads` 退化为 MHA，`num_kv_heads = 1` 退化为 MQA。
```
num_heads = 8, num_kv_heads = 2 → group = 4
q heads: 0 1 2 3 4 5 6 7
kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group（相邻分组，连续）
```
**分组约定必须与 xserv repeat_kv 一致**（闭环命门）：xserv 的 `repeat_kv`
（`crates/xserv-model/src/qwen3.rs`）把 kv 头 `kvh` 复制到目标头 `dst = kvh*group + r`
（`r∈[0,group)`），即**query 头 `qh` 读 kv 头 `qh/group`，组内 query 头连续**。xtrain 的
repeat_kv 用同一映射，否则导出的 `q_proj` 行块与 kv 头对不上 → 闭环必崩。
## Module Layout（surgical：复用已验证的两条 SDPA，GQA = 头维 broadcast op）
```
csrc/ops/repeat_kv.cu # 新：repeat_kv fwd（头块 gather）+ bwd（组内 group 行求和，无 atomic，确定性）
crates/xtrain-cuda/
├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明（no_cuda 门控）
└── build.rs # +repeat_kv.cu
crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward（[B*kvh,S,hd]→[B*nh,S,hd]，bf16 upcast→f32→downcast）
crates/xtrain-autodiff/
├── src/ops.rs # +ops::repeat_kv 节点（fwd broadcast，bwd 组内求和）
└── tests/autograd.rs # +repeat_kv grad-check（含 group>1 的多组梯度累加）
crates/xtrain-model/
├── src/config.rs # +num_kv_heads 字段（默认 = n_heads → MHA）；经 with_kv_heads 链式设置（from_arch 签名不变）；num_params 计 K/V 投影按 kv_dim
├── src/model.rs # wk/wv 投影到 kv_dim；attention() 在 SDPA 前对 K/V 做 ops::repeat_kv（两条路径都吃到 GQA）
└── tests/gqa.rs # 新：GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
crates/xtrain-train/src/bin/train.rs # +--kv-heads flag
crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flag（DDP 路径）
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-heads；config.json 写真 num_key_value_heads
crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQA（kv 投影 + repeat_kv）
```
## Key Design Decisions
### ① GQA = K/V 头维 broadcast op，喂给**未改动**的两条 SDPA（不写第三套 attention）
T14 已经有两条**逐位/数值都验证过**的 SDPA：composed（`ops::attention`）与 fused flash
（`ops::flash_attention`），二者都吃 `[bh, S, hd]`（`bh = batch·heads`）。GQA 的本质只是「K/V
比 Q 少 `group` 倍头，用前把每个 kv 头复制 `group` 份」。所以**最外科**的做法：
- wk/wv 投影到 `kv_dim = num_kv_heads · head_dim`，按 `[B, num_kv, S, hd] → [B·num_kv, S, hd]`
排好（和 Q 的 `[B·nh, S, hd]` 同流水线，只是头数不同）。
- 在调 SDPA 之前，对 K、V 各做一个新 autograd op `ops::repeat_kv`，把 `[B·num_kv, S, hd]`
**broadcast** 成 `[B·nh, S, hd]`（输出行 `b·nh + qh` = 输入行 `b·num_kv + qh/group` 的拷贝）。
- 之后 `ops::attention` / `ops::flash_attention` **一字不改**——它们看到的就是满头的
`[B·nh, S, hd]`，GQA 对两条路径**同时、免费**生效。flash kernel / composed kernel 都不用碰。
**为什么不在 kernel 内做 GQA**：那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed
的两次 strided GEMM 各算 GQA stride，且两套都要重测——是「第三套 attention 改动」。而 broadcast-op
方案：(a) 两条 SDPA 路径零改动、其 T14 闸门不回归；(b) repeat_kv 的 fwd/bwd 是独立可 grad-check
的小 op，正确性风险隔离在一处；(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的**反向**，
干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 `[B·nh,S,hd]`（多 `group` 倍）——本规模
（训练、seq 不极端）可接受；真要省这份显存是 follow-up（kernel 内 GQA 读取），记进逃生舱，不在 T15 做。
> 备选（不采纳）：flash/composed kernel 内直接按 `kv_head = q_head/group` 索引 K/V。省 broadcast
> 物化，但动两套已验证 kernel + 重写两套 backward 的 kv 累加，违反「不写第三套 attention」与回归保护。
> escape hatch：先 broadcast-op 把正确性 + 闭环钉死，kernel-内 GQA（省显存）留 follow-up。
### ② repeat_kv 的反向 = 组内求和（确定性，无 atomic）
`repeat_kv` 前向：`out[b·nh + qh] = in[b·num_kv + qh/group]`（按 `S·hd` 整行拷贝）。
反向是它的**转置**：一个 kv 头收到它那 `group` 个 query 头的梯度之**和**：
```
din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]
```
这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上**不用 atomicAdd**：每个输入
（kv-head, 元素）由唯一一个线程负责，它**串行累加自己那 group 个连续源行**——天然 race-free
且**run-to-run 确定**（不像 flash bwd 的跨行 atomic 反向有归约序不确定问题）。`group=1` 时
反向退化为单行拷贝（identity）。
autograd 层面其实也可以靠引擎的扇出 SUM（把一个 kv Var 喂给 group 个下游），但那样图里要
显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的
broadcast op（fwd/bwd 各一发 kernel）最简，且能单独 grad-check。
### ③ `num_kv_heads` 进 Config（它改模型尺寸/导出），默认 = n_heads → 退化 MHA
不同于 T14 的 `use_flash`（运行时旗标，不进 Config），`num_kv_heads` **改 K/V 投影的形状、改参数量、
改导出的 `num_key_value_heads`**——它是**架构**的一部分，必须进 `Config` 并落进 checkpoint/导出。
- `Config` 加 `num_kv_heads: usize`，由 `with_kv_heads(..)` 链式设置（`from_arch` 签名不变）；`Config::tiny()` 默认 `num_kv_heads =
n_heads`（MHA）。约束：`num_heads % num_kv_heads == 0`（断言）。
- `num_params()`：K/V 投影从 `2·dim·dim` 改成 `2·dim·(num_kv_heads·head_dim)`；QK-norm 的
`k_norm` 仍是 `[head_dim]`（per-head，作用在单个 head 向量上，与头数无关）→ 不变。
- **`num_kv_heads == n_heads` 时 `group=1`**：`ops::repeat_kv` 是 identity（fwd 单行拷贝、bwd 单行
拷贝），wk/wv 形状回到 `[dim,dim]` → 整条图与 T14 的 MHA 路径**逐位一致**（回归保护闸门）。
> wk/wv 形状从 `[dim,dim]` 变成 `[dim, kv_dim]`：`Block` 里 wk/wv 的 `mk(&[dim, kv_dim])`；
> `params()`/`block_params()` 顺序不变（还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...），只是
> wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走（`transpose` 读 `v.value().shape()`）。
### ④ bf16 / recompute / dropout / DDP 全部自动兼容
- **bf16**：`Tensor::repeat_kv` 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel →
downcast（kernel 只一份 f32）。`ops::repeat_kv` 的 fwd/bwd 都在 SDPA 之前/之后，dtype 与 K/V 流一致。
- **recompute（T13）**：repeat_kv 在 `block_forward` 内、`attention()` 里，重算段重跑 `attention()`
自然重跑 repeat_kv（无额外状态，确定性）→ 梯度仍逐位一致。
- **dropout（T18）**：dropout 接在 attn/mlp 子块**输出**，与 attention 内部的 repeat_kv 正交，不交互。
- **DDP**：repeat_kv 不引入新参数（wk/wv 变小到 kv_dim，只是参数张量小一圈）；`params()` 泛化迭代 +
all-reduce 照旧；跨 rank 一致性不受影响。
### ⑤ 导出 xserv：写真 `num_key_value_heads`，分组约定对齐 repeat_kv
`export_safetensors.rs` 的 `config.json` 把 `num_key_value_heads` 从「= num_attention_heads」改成
**真 `cfg.num_kv_heads`**，由 `--kv-heads` flag 传入（须与训练 ckpt 一致）。q/k/v_proj 各自按其 shape
转置导出（k/v_proj 现在是 `[kv_dim, dim]`，即 xserv loader 期望的 GQA 形状）。xserv 的 `repeat_kv`
用 `dst = kvh·group + r` 分组，与 ① 的 xtrain 约定**逐头对齐** → 同一份权重在两侧前向数学一致，
闭环（贪心逐 token 一致）成立。
## 验证方法
全部 `#![cfg(not(no_cuda))]` 门控，本地 `cargo check`/`fmt`，构建+实跑在 dash5（8× RTX 5090）。
### 1. 正确性硬闸门（全绿，dash5 实跑 capture）
- **repeat_kv finite-diff grad-check**（`autograd.rs::repeat_kv_grad`）：**核心闸门**——`group>1`
（如 bh: 2 kv 头 → 6 q 头）下 grad-check `din`，验证「多组 q 头梯度求和到一个 kv 头」的反向。
外加 `group=1` identity 自检。
- **GQA flash==composed**（`gqa.rs`）：真 GQA 配置（`num_kv_heads < n_heads`，如 8 头/2 kv 头）下，
flash on/off 两个同 init 模型，断 forward logits / loss / **每参数梯度**一致（fp32 紧容差 + bf16
舍入带）——尤其 wk/wv 的梯度（它们经过 repeat_kv 反向的组内求和）。
- **group=1 与 MHA 逐位一致**（`gqa.rs`）：`num_kv_heads = n_heads` 的模型对 T14 的 MHA 模型，
forward + 每参数梯度 `|Δ|=0`（回归保护）。
- **PyTorch GQA 对拍 B>1**（`parity_dump.rs` + `parity.py`）：等价 PyTorch 模型加 GQA（k/v 投影到
kv_dim + `repeat_interleave(group)` 分组，与 xserv/xtrain 约定一致），对拍 forward logits + 全部
参数梯度；composed 与 flash 两条都跑，共用同一 oracle。
- **小 GQA 配置短训收敛**：一个真 GQA 小模型短训，loss 单调降、无 NaN、采样连贯。
- **全回归套开/关**：autograd / structural / batched==looped / bf16 / recompute（逐位）/ overfit 27/27 /
AdamW（GPU bit-exact + host 对 torch）/ DDP（loss-match + 跨 rank，`--test-threads=1`）/ flash /
grad_accum / dropout / **xserv 闭环 md5**。MHA 默认（kv=heads）图不变 → 不回归。
### 2. 闭环（payoff）—— 真 GQA 端到端
导出一个 `num_key_value_heads < num_attention_heads` 的 GQA checkpoint → xserv 加载 → 贪心生成 →
**对 xtrain 自身逐 token 一致**（BF16 推理 vs f32 训练，v1–v8 同款判据）。这是 GQA 真正落地的证明：
训练侧的分组、导出的分组、推理侧 xserv repeat_kv 分组三方对齐。
## 实测结果（dash5）
> 待 dash5 实跑回填（gate 表 + 数字）。