// repeat_kv: the grouped-query-attention (GQA) head broadcast (Phase T15).
//
// GQA projects K/V to fewer heads than Q (num_kv_heads < num_heads); each KV head
// is SHARED by a group of `group = num_heads / num_kv_heads` query heads. Before
// the SDPA (composed or fused-flash, both untouched), we expand the KV tensor from
// [B·num_kv, S, hd] to the full [B·nh, S, hd] so the existing per-(batch,head)
// attention sees a full set of heads. GQA is then "free" for both SDPA paths.
//
// Layout: K/V are [bh_kv, S, hd] = [B·num_kv, S, hd] row-major, contiguous; the
// output is [bh_q, S, hd] = [B·nh, S, hd]. The head ordering matches xserv's
// repeat_kv (crates/xserv-model/src/qwen3.rs): query head qh reads kv head
// qh/group, query heads CONTIGUOUS within a group (dst = kvh*group + r). So:
//
//   out[b·nh + qh, :, :] = in[b·num_kv + qh/group, :, :]
//
// Forward is a gather (each output row copies one input row). Backward is its
// transpose: a kv head receives the SUM of the `group` query heads that share it,
//   din[b·num_kv + kvh, e] = Σ_{r∈[0,group)} dout[b·nh + kvh·group + r, e]
// — the many-to-one grad accumulation GQA's correctness hinges on. Each input
// element is owned by exactly one thread that serially sums its `group` source
// rows in order r = 0..group-1: race-free, NO atomics, run-to-run DETERMINISTIC.
// group==1 makes both directions a plain copy (identity → bit-identical to the
// MHA path).
//
// Launch shape (both directions): one thread block per destination head block
// (grid = B·nh forward, B·num_kv backward). The block's threads step over that
// head's S·hd contiguous elements with stride blockDim.x. No shared memory.
//
// Shape contract (enforced by the launchers): batch, nh, num_kv, seq, hd > 0 and
// nh % num_kv == 0, so `group` is exact and every kv index qh/group < num_kv.
// The launchers return void and cannot report a violation. They therefore skip
// the launch (leaving the output untouched) rather than index out of bounds or
// launch an empty grid. NOTE(review): callers are expected to validate shapes
// before calling — confirm the Rust side does.
//
// All F32, contiguous. (bf16 callers upcast → f32 on the Rust side and downcast
// the f32 result, mirroring the rest of the attention stack's fp32 policy.)
#include <cuda_runtime.h>

// Threads per block for a head block of `row_elems` = S·hd elements: the row
// length clamped to [32, 256] (at least one warp, at most 256 threads). Computed
// in `long` so a large S·hd cannot overflow `int`.
static int repeat_kv_block_dim(long row_elems) {
    if (row_elems < 32) return 32;
    if (row_elems > 256) return 256;
    return (int)row_elems;
}

// True when a repeat_kv launch over these dims is non-empty and well-formed:
// every size positive and nh an exact multiple of num_kv.
static bool repeat_kv_dims_ok(int batch, int nh, int num_kv, int seq, int hd) {
    if (batch <= 0 || nh <= 0 || num_kv <= 0 || seq <= 0 || hd <= 0) return false;
    return nh % num_kv == 0;
}

extern "C" {

// Forward gather. blockIdx.x = out_bh ∈ [0, B·nh). The block's threads step over
// that head's S·hd elements, copying from kv source bh = b·num_kv + qh/group,
// where b = out_bh / nh and qh = out_bh % nh. `in` and `out` must not overlap.
__global__ void repeat_kv_fwd_k(const float* __restrict__ in, float* __restrict__ out,
                                int nh, int num_kv, int group, int seq, int hd) {
    const long row_elems = (long)seq * hd;  // S·hd elements per head block
    const int out_bh = blockIdx.x;          // over B·nh
    const int b = out_bh / nh;
    const int qh = out_bh % nh;
    const int kvh = qh / group;             // query heads are contiguous within a group
    const float* src = in + ((long)b * num_kv + kvh) * row_elems;
    float* dst = out + (long)out_bh * row_elems;
    for (long e = threadIdx.x; e < row_elems; e += blockDim.x)
        dst[e] = src[e];
}

// Host launcher for the forward gather:
//   in  [batch·num_kv, seq, hd]  →  out [batch·nh, seq, hd]
// Enqueued asynchronously on `s` (reinterpreted as a cudaStream_t). Skips the
// launch when the shape contract in the file header is not met.
void launch_repeat_kv_fwd_f32(const float* in, float* out, int batch, int nh,
                              int num_kv, int seq, int hd, void* s) {
    if (!repeat_kv_dims_ok(batch, nh, num_kv, seq, hd)) return;
    const int group = nh / num_kv;
    const int blk = repeat_kv_block_dim((long)seq * hd);
    repeat_kv_fwd_k<<<batch * nh, blk, 0, (cudaStream_t)s>>>(in, out, nh, num_kv,
                                                             group, seq, hd);
}

// Backward sum. blockIdx.x = in_bh ∈ [0, B·num_kv). Each thread owns elements e
// of that kv head block (stride blockDim.x). For each owned element it serially
// sums the `group` contiguous query-head rows b·nh + kvh·group + r, r = 0..group-1.
// One writer per din element and a fixed summation order → deterministic.
// `dout` and `din` must not overlap.
__global__ void repeat_kv_bwd_k(const float* __restrict__ dout, float* __restrict__ din,
                                int nh, int num_kv, int group, int seq, int hd) {
    const long row_elems = (long)seq * hd;  // S·hd elements per head block
    const int in_bh = blockIdx.x;           // over B·num_kv
    const int b = in_bh / num_kv;
    const int kvh = in_bh % num_kv;
    const int qh0 = kvh * group;            // first query head sharing this kv head
    float* dst = din + (long)in_bh * row_elems;
    const float* base = dout + ((long)b * nh + qh0) * row_elems;
    for (long e = threadIdx.x; e < row_elems; e += blockDim.x) {
        float acc = 0.0f;
        for (int r = 0; r < group; ++r)
            acc += base[(long)r * row_elems + e];
        dst[e] = acc;
    }
}

// Host launcher for the backward sum:
//   dout [batch·nh, seq, hd]  →  din [batch·num_kv, seq, hd]
// din is fully overwritten (not accumulated into). Enqueued asynchronously on
// `s` (reinterpreted as a cudaStream_t). Skips the launch when the shape
// contract in the file header is not met.
void launch_repeat_kv_bwd_f32(const float* dout, float* din, int batch, int nh,
                              int num_kv, int seq, int hd, void* s) {
    if (!repeat_kv_dims_ok(batch, nh, num_kv, seq, hd)) return;
    const int group = nh / num_kv;
    const int blk = repeat_kv_block_dim((long)seq * hd);
    repeat_kv_bwd_k<<<batch * num_kv, blk, 0, (cudaStream_t)s>>>(dout, din, nh, num_kv,
                                                                group, seq, hd);
}

}  // extern "C"