gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)

- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 01:37:16 +08:00
parent 62b1cb5dc7
commit 830d06ad01
15 changed files with 712 additions and 41 deletions

84
csrc/ops/repeat_kv.cu Normal file
View File

@@ -0,0 +1,84 @@
// repeat_kv: the grouped-query-attention (GQA) head broadcast (Phase T15).
//
// GQA projects K/V to fewer heads than Q (num_kv_heads < num_heads); each KV head
// is SHARED by a group of `group = num_heads / num_kv_heads` query heads. Before
// the SDPA (composed or fused-flash, both untouched), we expand the KV tensor from
// [B·num_kv, S, hd] to the full [B·nh, S, hd] so the existing per-(batch,head)
// attention sees a full set of heads. GQA is then "free" for both SDPA paths.
//
// Layout: K/V are [bh_kv, S, hd] = [B·num_kv, S, hd] row-major, contiguous; the
// output is [bh_q, S, hd] = [B·nh, S, hd]. The head ordering matches xserv's
// repeat_kv (crates/xserv-model/src/qwen3.rs): query head qh reads kv head
// qh/group, query heads CONTIGUOUS within a group (dst = kvh*group + r). So:
//
// out[b·nh + qh, :, :] = in[b·num_kv + qh/group, :, :]
//
// Forward is a gather (each output row copies one input row). Backward is its
// transpose: a kv head receives the SUM of the `group` query heads that share it
// din[b·num_kv + kvh, e] = Σ_{r∈[0,group)} dout[b·nh + kvh·group + r, e]
// — the multi-group-to-one grad accumulation GQA's correctness hinges on. Each
// input element is owned by exactly one thread that serially sums its `group`
// source rows: race-free, NO atomics, run-to-run DETERMINISTIC. group==1 makes
// both directions a plain copy (identity → bit-identical to the MHA path).
//
// All F32, contiguous. (bf16 callers upcast → f32 on the Rust side and downcast
// the f32 result, mirroring the rest of the attention stack's fp32 policy.)
#include <math.h>
extern "C" {
// Forward gather. grid-stride over the bh_q·S·hd output elements; each output
// element copies from its kv-head source row. b = (out_bh / nh), qh = out_bh % nh,
// kv source bh = b·num_kv + qh/group.
__global__ void repeat_kv_fwd_k(const float* in, float* out, int nh, int num_kv,
int group, int seq, int hd) {
long row_elems = (long)seq * hd; // S·hd per head block
// One block per (batch, query-head); threads cover the S·hd block.
int out_bh = blockIdx.x; // over B·nh
int b = out_bh / nh;
int qh = out_bh % nh;
int kvh = qh / group;
const float* src = in + ((long)b * num_kv + kvh) * row_elems;
float* dst = out + (long)out_bh * row_elems;
for (long e = threadIdx.x; e < row_elems; e += blockDim.x) dst[e] = src[e];
}
void launch_repeat_kv_fwd_f32(const float* in, float* out, int batch, int nh,
int num_kv, int seq, int hd, void* s) {
int group = nh / num_kv;
int blk = (seq * hd) < 256 ? (seq * hd) : 256;
if (blk < 32) blk = 32;
repeat_kv_fwd_k<<<batch * nh, blk, 0, (cudaStream_t)s>>>(in, out, nh, num_kv,
group, seq, hd);
}
// Backward sum. One block per (batch, kv-head); threads cover the S·hd block.
// Each owned input element sums the `group` contiguous query-head source rows.
__global__ void repeat_kv_bwd_k(const float* dout, float* din, int nh, int num_kv,
int group, int seq, int hd) {
long row_elems = (long)seq * hd;
int in_bh = blockIdx.x; // over B·num_kv
int b = in_bh / num_kv;
int kvh = in_bh % num_kv;
int qh0 = kvh * group; // first query head sharing this kv head
float* dst = din + (long)in_bh * row_elems;
const float* base = dout + ((long)b * nh + qh0) * row_elems;
for (long e = threadIdx.x; e < row_elems; e += blockDim.x) {
float acc = 0.0f;
for (int r = 0; r < group; ++r) acc += base[(long)r * row_elems + e];
dst[e] = acc;
}
}
void launch_repeat_kv_bwd_f32(const float* dout, float* din, int batch, int nh,
int num_kv, int seq, int hd, void* s) {
int group = nh / num_kv;
int blk = (seq * hd) < 256 ? (seq * hd) : 256;
if (blk < 32) blk = 32;
repeat_kv_bwd_k<<<batch * num_kv, blk, 0, (cudaStream_t)s>>>(dout, din, nh,
num_kv, group,
seq, hd);
}
} // extern "C"