gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each kv head sums its group of query-head grads; no atomics) + Tensor/ops node. - Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim; attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical. - --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export writes real num_key_value_heads (xserv repeat_kv grouping aligned). - Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape); parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
84
csrc/ops/repeat_kv.cu
Normal file
84
csrc/ops/repeat_kv.cu
Normal file
@@ -0,0 +1,84 @@
|
||||
// repeat_kv: the grouped-query-attention (GQA) head broadcast (Phase T15).
|
||||
//
|
||||
// GQA projects K/V to fewer heads than Q (num_kv_heads < num_heads); each KV head
|
||||
// is SHARED by a group of `group = num_heads / num_kv_heads` query heads. Before
|
||||
// the SDPA (composed or fused-flash, both untouched), we expand the KV tensor from
|
||||
// [B·num_kv, S, hd] to the full [B·nh, S, hd] so the existing per-(batch,head)
|
||||
// attention sees a full set of heads. GQA is then "free" for both SDPA paths.
|
||||
//
|
||||
// Layout: K/V are [bh_kv, S, hd] = [B·num_kv, S, hd] row-major, contiguous; the
|
||||
// output is [bh_q, S, hd] = [B·nh, S, hd]. The head ordering matches xserv's
|
||||
// repeat_kv (crates/xserv-model/src/qwen3.rs): query head qh reads kv head
|
||||
// qh/group, query heads CONTIGUOUS within a group (dst = kvh*group + r). So:
|
||||
//
|
||||
// out[b·nh + qh, :, :] = in[b·num_kv + qh/group, :, :]
|
||||
//
|
||||
// Forward is a gather (each output row copies one input row). Backward is its
|
||||
// transpose: a kv head receives the SUM of the `group` query heads that share it:
|
||||
// din[b·num_kv + kvh, e] = Σ_{r∈[0,group)} dout[b·nh + kvh·group + r, e]
|
||||
// — the multi-group-to-one grad accumulation GQA's correctness hinges on. Each
|
||||
// input element is owned by exactly one thread that serially sums its `group`
|
||||
// source rows: race-free, NO atomics, run-to-run DETERMINISTIC. group==1 makes
|
||||
// both directions a plain copy (identity → bit-identical to the MHA path).
|
||||
//
|
||||
// All F32, contiguous. (bf16 callers upcast → f32 on the Rust side and downcast
|
||||
// the f32 result, mirroring the rest of the attention stack's fp32 policy.)
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Forward gather kernel.
//
// Launch layout: one block per (batch, query-head) pair, i.e. gridDim.x must be
// B·nh; the threads of a block stride over that head's S·hd elements in steps
// of blockDim.x, so any block size is valid. No shared memory is used.
//
// For block index q_block: batch = q_block / nh, query head = q_block % nh, and
// the source kv block is batch·num_kv + (query head / group). Every output
// element is written exactly once, copied from its kv-head source row.
//
// Preconditions (not checked here): nh > 0 and group > 0; `in` holds
// B·num_kv·S·hd floats and `out` holds B·nh·S·hd floats, both contiguous.
__global__ void repeat_kv_fwd_k(const float* in, float* out, int nh, int num_kv,
                                int group, int seq, int hd) {
    // Element count of one head block, widened to long before the multiply.
    const long head_elems = (long)seq * hd;

    const int q_block = blockIdx.x;          // index over B·nh
    const int batch_idx = q_block / nh;
    const int q_head = q_block % nh;

    // kv block shared by this query head.
    const long kv_block = (long)batch_idx * num_kv + q_head / group;

    const float* from = in + kv_block * head_elems;
    float* to = out + (long)q_block * head_elems;

    long i = threadIdx.x;
    while (i < head_elems) {
        to[i] = from[i];
        i += blockDim.x;
    }
}
|
||||
|
||||
// Host launcher for the forward repeat_kv gather.
//
//   in    : [batch·num_kv, seq, hd] f32, contiguous (device pointer)
//   out   : [batch·nh,     seq, hd] f32, contiguous (device pointer)
//   s     : the CUDA stream to enqueue on, passed as an opaque void*
//
// Launches repeat_kv_fwd_k with one block per (batch, query head) and up to 256
// threads per block (at least 32). The launch is asynchronous on `s`.
//
// Precondition: nh must be a multiple of num_kv (the caller is expected to
// validate this); group = nh / num_kv is the number of query heads per kv head.
//
// Empty or degenerate shapes (any dimension <= 0) are a no-op: there is nothing
// to copy, and returning early avoids both the host-side integer divide-by-zero
// in nh / num_kv and an invalid zero-sized grid launch.
void launch_repeat_kv_fwd_f32(const float* in, float* out, int batch, int nh,
                              int num_kv, int seq, int hd, void* s) {
    if (batch <= 0 || nh <= 0 || num_kv <= 0 || seq <= 0 || hd <= 0) return;

    const int group = nh / num_kv;
    if (group <= 0) return;  // nh < num_kv: the kernel would divide by zero

    // Size the block from the per-head element count, computed in 64 bits so a
    // large seq·hd cannot overflow int before the clamp.
    const long row_elems = (long)seq * hd;
    int blk = row_elems < 256 ? (int)row_elems : 256;
    if (blk < 32) blk = 32;

    repeat_kv_fwd_k<<<batch * nh, blk, 0, (cudaStream_t)s>>>(in, out, nh, num_kv,
                                                             group, seq, hd);
}
|
||||
|
||||
// Backward group-sum kernel.
//
// Launch layout: one block per (batch, kv-head) pair, i.e. gridDim.x must be
// B·num_kv; the threads of a block stride over that head's S·hd elements in
// steps of blockDim.x. No shared memory and no atomics are used.
//
// Each din element is written by exactly one thread, which adds up the `group`
// consecutive query-head rows of dout that share this kv head, in ascending
// head order starting from 0.0f (so the summation order is fixed).
//
// Preconditions (not checked here): num_kv > 0; `dout` holds B·nh·S·hd floats
// and `din` holds B·num_kv·S·hd floats, both contiguous.
__global__ void repeat_kv_bwd_k(const float* dout, float* din, int nh, int num_kv,
                                int group, int seq, int hd) {
    const long head_elems = (long)seq * hd;  // S·hd per head block

    const int kv_block = blockIdx.x;         // index over B·num_kv
    const int batch_idx = kv_block / num_kv;
    const int kv_head = kv_block % num_kv;

    // First query head of the group that reads this kv head.
    const long first_q_block = (long)batch_idx * nh + kv_head * group;

    const float* grad_out = dout + first_q_block * head_elems;
    float* grad_in = din + (long)kv_block * head_elems;

    long i = threadIdx.x;
    while (i < head_elems) {
        float total = 0.0f;
        long src = i;                        // element i of group member 0
        for (int member = 0; member < group; ++member) {
            total += grad_out[src];
            src += head_elems;               // same element, next query head
        }
        grad_in[i] = total;
        i += blockDim.x;
    }
}
|
||||
|
||||
// Host launcher for the backward repeat_kv group-sum.
//
//   dout  : [batch·nh,     seq, hd] f32, contiguous (device pointer)
//   din   : [batch·num_kv, seq, hd] f32, contiguous (device pointer)
//   s     : the CUDA stream to enqueue on, passed as an opaque void*
//
// Launches repeat_kv_bwd_k with one block per (batch, kv head) and up to 256
// threads per block (at least 32). The launch is asynchronous on `s`.
//
// Precondition: nh must be a multiple of num_kv (the caller is expected to
// validate this); group = nh / num_kv is the number of query-head gradients
// summed into each kv head.
//
// Empty or degenerate shapes (any dimension <= 0) are a no-op: there is nothing
// to accumulate, and returning early avoids both the host-side integer
// divide-by-zero in nh / num_kv and an invalid zero-sized grid launch.
void launch_repeat_kv_bwd_f32(const float* dout, float* din, int batch, int nh,
                              int num_kv, int seq, int hd, void* s) {
    if (batch <= 0 || nh <= 0 || num_kv <= 0 || seq <= 0 || hd <= 0) return;

    const int group = nh / num_kv;

    // Size the block from the per-head element count, computed in 64 bits so a
    // large seq·hd cannot overflow int before the clamp.
    const long row_elems = (long)seq * hd;
    int blk = row_elems < 256 ? (int)row_elems : 256;
    if (blk < 32) blk = 32;

    repeat_kv_bwd_k<<<batch * num_kv, blk, 0, (cudaStream_t)s>>>(dout, din, nh,
                                                                 num_kv, group,
                                                                 seq, hd);
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user