- repeat_kv CUDA kernel: fwd is a head-block gather; bwd is a DETERMINISTIC
group-sum (each KV head sums the grads of its group of query heads, no
atomics). Plus a Tensor/ops node for it.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim.
attention() uses repeat_kv to broadcast K/V to all n_heads heads before the
UNCHANGED composed and flash SDPA, so both paths get GQA. group=1 is the
identity, so MHA stays bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
writes the actual num_key_value_heads (matching xserv's repeat_kv grouping).
- Tests: repeat_kv grad-check (group>1 grad-sum, group=1 identity); model test
gqa.rs (GQA flash == composed in fp32/bf16, group=1 bit-identical to MHA,
kv-proj shape); parity_dump + parity.py cover the GQA path (repeat_interleave
in parity.py), selected via XTRAIN_PARITY_KV_HEADS.
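A minimal pure-Python sketch of the repeat_kv semantics from the first bullet,
assuming each head is a flat list of floats and that query head h reads KV head
h // group (the repeat_interleave ordering parity.py uses). The function names
are hypothetical; this illustrates the math, not the CUDA kernel itself.

```python
from typing import List

Head = List[float]  # one head's values, flattened (illustrative layout)


def repeat_kv_fwd(kv: List[Head], group: int) -> List[Head]:
    """Head-block gather: output head h is a copy of kv head h // group."""
    n_heads = len(kv) * group
    return [list(kv[h // group]) for h in range(n_heads)]


def repeat_kv_bwd(d_out: List[Head], group: int) -> List[Head]:
    """Deterministic group-sum: kv head j accumulates the grads of query heads
    j*group .. j*group + group - 1 in a fixed order. Each output element has a
    single writer, so no atomics are needed and results are reproducible."""
    if len(d_out) % group != 0:
        raise ValueError("number of grad heads must be a multiple of group")
    n_kv = len(d_out) // group
    d_kv = []
    for j in range(n_kv):
        # Start from the first head in the group, so group=1 is an exact copy.
        acc = list(d_out[j * group])
        for g in range(1, group):  # fixed summation order
            src = d_out[j * group + g]
            for i in range(len(acc)):
                acc[i] += src[i]
        d_kv.append(acc)
    return d_kv
```

For example, with two KV heads and group=2, the forward pass yields four heads
(0, 0, 1, 1), and the backward pass folds grads of heads 0-1 into KV head 0 and
heads 2-3 into KV head 1.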
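A sketch of the Config arithmetic behind the second bullet, assuming
head_dim = dim // n_heads and kv_dim = num_kv_heads * head_dim (the usual GQA
layout). `dim`, `head_dim`, `group`, and `kv_head_for` are hypothetical names,
not necessarily the crate's actual fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    dim: int                            # hypothetical model width
    n_heads: int
    num_kv_heads: Optional[int] = None  # None -> n_heads, i.e. plain MHA

    def __post_init__(self) -> None:
        if self.num_kv_heads is None:
            self.num_kv_heads = self.n_heads
        if self.dim % self.n_heads != 0:
            raise ValueError("dim must be divisible by n_heads")
        if self.n_heads % self.num_kv_heads != 0:
            raise ValueError("n_heads must be a multiple of num_kv_heads")

    @property
    def head_dim(self) -> int:
        return self.dim // self.n_heads

    @property
    def kv_dim(self) -> int:
        # Assumed output width of wk/wv under GQA.
        return self.num_kv_heads * self.head_dim

    @property
    def group(self) -> int:
        # Query heads per KV head; 1 means repeat_kv is the identity (MHA).
        return self.n_heads // self.num_kv_heads

    def kv_head_for(self, h: int) -> int:
        # repeat_interleave grouping: consecutive query heads share a KV head.
        return h // self.group
```

With dim=512, n_heads=8, num_kv_heads=2 this gives head_dim=64, kv_dim=128 and
group=4, so query heads 0 to 3 read KV head 0 and heads 4 to 7 read KV head 1.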

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>