# Phase T15: Grouped-Query Attention (GQA) — Design Document ## Goal 到 T14 为止,xtrain 的 attention 都是 **MHA**(`num_kv_heads = num_heads`)——每个 query 头有自己独立的 K/V 头。导出 xserv 时 `num_key_value_heads = num_attention_heads` (退化 GQA,docs/08)。 T15 做**真正的 grouped-query attention**:`num_kv_heads < num_heads`,K/V 只投影到 `num_kv_heads · head_dim`,每个 KV 头被一组 `group = num_heads / num_kv_heads` 个 query 头**共享**(repeat_kv / broadcast)。GQA 是现代 LLM(Llama-2-70B、Qwen2/3、Mistral)的标配 ——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 `group` 倍压缩,几乎不掉质量。 **硬闸门是诚实正确性**,重点在 **repeat_kv 的反向梯度累加**:一个 KV 头被 `group` 个 query 头 共享,反向时这 `group` 个 query 头各自对该 KV 头的梯度必须**正确求和**回到那一个共享 KV 头上。 这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。 GQA 必须**同时**接进 T14 的 fused flash kernel(优先)与旧 composed/batched SDPA 路径,且 `num_kv_heads == num_heads`(`group = 1`)时与现有 MHA 路径**逐位一致**(回归保护)。 ## 什么是 GQA MHA:`num_heads` 个 query 头,每个配一个独立 K/V 头。 MQA(multi-query):所有 query 头共享**一个** K/V 头(极端)。 GQA:折中——`num_kv_heads` 个 K/V 头,每个被 `group = num_heads/num_kv_heads` 个相邻 query 头共享。`num_kv_heads = num_heads` 退化为 MHA,`num_kv_heads = 1` 退化为 MQA。 ``` num_heads = 8, num_kv_heads = 2 → group = 4 q heads: 0 1 2 3 4 5 6 7 kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group(相邻分组,连续) ``` **分组约定必须与 xserv repeat_kv 一致**(闭环命门):xserv 的 `repeat_kv` (`crates/xserv-model/src/qwen3.rs`)把 kv 头 `kvh` 复制到目标头 `dst = kvh*group + r` (`r∈[0,group)`),即**query 头 `qh` 读 kv 头 `qh/group`,组内 query 头连续**。xtrain 的 repeat_kv 用同一映射,否则导出的 `q_proj` 行块与 kv 头对不上 → 闭环必崩。 ## Module Layout(surgical:复用已验证的两条 SDPA,GQA = 头维 broadcast op) ``` csrc/ops/repeat_kv.cu # 新:repeat_kv fwd(头块 gather)+ bwd(组内 group 行求和,无 atomic,确定性) crates/xtrain-cuda/ ├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明(no_cuda 门控) └── build.rs # +repeat_kv.cu crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward([B*kvh,S,hd]→[B*nh,S,hd];bf16 upcast→f32→downcast) crates/xtrain-autodiff/ ├── src/ops.rs # +ops::repeat_kv 节点(fwd broadcast,bwd 组内求和) └── tests/autograd.rs # +repeat_kv grad-check(含 group>1 的多组梯度累加) crates/xtrain-model/ ├── src/config.rs # +num_kv_heads 字段(默认 = n_heads → MHA);from_arch 加形参;num_params 计 K/V 投影按 kv_dim ├── src/model.rs # wk/wv 投影到 kv_dim;attention() 在 SDPA 前对 K/V 做 ops::repeat_kv;两条路径都吃到 GQA └── tests/gqa.rs # 新:GQA(group>1) flash==composed + group=1 与 MHA 逐位一致 crates/xtrain-train/src/bin/train.rs # +--kv-heads flag crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flag(DDP 路径) crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-heads;config.json 写真 num_key_value_heads crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQA(kv 投影 + repeat_kv) ``` ## Key Design Decisions ### ① GQA = K/V 头维 broadcast op,喂给**未改动**的两条 SDPA(不写第三套 attention) T14 已经有两条**逐位/数值都验证过**的 SDPA:composed(`ops::attention`)与 fused flash (`ops::flash_attention`),二者都吃 `[bh, S, hd]`(`bh = batch·heads`)。GQA 的本质只是「K/V 比 Q 少 `group` 倍头,用前把每个 kv 头复制 `group` 份」。所以**最外科**的做法: - wk/wv 投影到 `kv_dim = num_kv_heads · head_dim`,按 `[B, num_kv, S, hd] → [B·num_kv, S, hd]` 排好(和 Q 的 `[B·nh, S, hd]` 同流水线,只是头数不同)。 - 在调 SDPA 之前,对 K、V 各做一个新 autograd op `ops::repeat_kv`,把 `[B·num_kv, S, hd]` **broadcast** 成 `[B·nh, S, hd]`(输出行 `b·nh + qh` = 输入行 `b·num_kv + qh/group` 的拷贝)。 - 之后 `ops::attention` / `ops::flash_attention` **一字不改**——它们看到的就是满头的 `[B·nh, S, hd]`,GQA 对两条路径**同时、免费**生效。flash kernel / composed kernel 都不用碰。 **为什么不在 kernel 内做 GQA**:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed 的两次 strided GEMM 各算 GQA stride,且两套都要重测——是「第三套 attention 改动」。而 broadcast-op 方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check 的小 op,正确性风险隔离在一处;(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的**反向**, 干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 `[B·nh,S,hd]`(多 `group` 倍)——本规模 (训练、seq 不极端)可接受;真要省这份显存是 follow-up(kernel 内 GQA 读取),记进逃生舱不在 T15 做。 > 备选(不采纳):flash/composed kernel 内直接按 `kv_head = q_head/group` 索引 K/V。省 broadcast > 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。 > escape hatch:先 broadcast-op 把正确性 + 闭环钉死,kernel-内 GQA(省显存)留 follow-up。 ### ② repeat_kv 的反向 = 组内求和(确定性,无 atomic) `repeat_kv` 前向:`out[b·nh + qh] = in[b·num_kv + qh/group]`(按 `S·hd` 整行拷贝)。 反向是它的**转置**:一个 kv 头收到它那 `group` 个 query 头的梯度之**和**: ``` din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r] ``` 这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上**不用 atomicAdd**:每个输入 (kv-head, 元素)由唯一一个 block 负责,它**串行累加自己那 group 个连续源行**——天然 race-free 且**run-to-run 确定**(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。`group=1` 时 反向退化为单行拷贝(identity)。 autograd 层面其实也可以靠引擎的扇出 SUM(把一个 kv Var 喂给 group 个下游),但那样图里要 显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的 broadcast op,fwd/bwd 各一发 kernel,最简且能单独 grad-check。 ### ③ `num_kv_heads` 进 Config(它改模型尺寸/导出),默认 = n_heads → 退化 MHA 不同于 T14 的 `use_flash`(运行时旗标,不进 Config),`num_kv_heads` **改 K/V 投影的形状、改参数量、 改导出的 `num_key_value_heads`**——它是**架构**的一部分,必须进 `Config` 并落进 checkpoint/导出。 - `Config` 加 `num_kv_heads: usize`;`from_arch` 加该形参;`Config::tiny()` 默认 `num_kv_heads = n_heads`(MHA)。约束:`num_heads % num_kv_heads == 0`(断言)。 - `num_params()`:K/V 投影从 `2·dim·dim` 改成 `2·dim·(num_kv_heads·head_dim)`;QK-norm 的 `k_norm` 仍是 `[head_dim]`(per-head,作用在单个 head 向量上,与头数无关)→ 不变。 - **`num_kv_heads == n_heads` 时 `group=1`**:`ops::repeat_kv` 是 identity(fwd 单行拷贝、bwd 单行 拷贝),wk/wv 形状回到 `[dim,dim]` → 整条图与 T14 的 MHA 路径**逐位一致**(回归保护闸门)。 > wk/wv 形状从 `[dim,dim]` 变成 `[dim, kv_dim]`:`Block` 里 wk/wv 的 `mk(&[dim, kv_dim])`, > `params()`/`block_params()` 顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是 > wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(`transpose` 读 `v.value().shape()`)。 ### ④ bf16 / recompute / dropout / DDP 全部自动兼容 - **bf16**:`Tensor::repeat_kv` 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel → downcast;kernel 只一份 f32。`ops::repeat_kv` 的 fwd/bwd 都在 SDPA 之前/之后,dtype 与 K/V 流一致。 - **recompute(T13)**:repeat_kv 在 `block_forward` 内、`attention()` 里,重算段重跑 `attention()` 自然重跑 repeat_kv(无额外状态,确定性)→ 梯度仍逐位一致。 - **dropout(T18)**:dropout 接在 attn/mlp 子块**输出**,与 attention 内部的 repeat_kv 正交,不交互。 - **DDP**:repeat_kv 不引入新参数;wk/wv 变小(kv_dim)只是参数张量小一圈,`params()` 泛化迭代 + all-reduce 照旧;跨 rank 一致性不受影响。 ### ⑤ 导出 xserv:写真 `num_key_value_heads`,分组约定对齐 repeat_kv `export_safetensors.rs` 的 `config.json` 把 `num_key_value_heads` 从「= num_attention_heads」改成 **真 `cfg.num_kv_heads`**;`--kv-heads` flag 传入(须与训练 ckpt 一致)。q/k/v_proj 各自按其 shape 转置导出(k/v_proj 现在是 `[kv_dim, dim]`,xserv loader 期望的 GQA 形状)。xserv 的 `repeat_kv` 用 `dst = kvh·group + r` 分组,与 ① 的 xtrain 约定**逐头对齐** → 同一份权重在两侧前向数学一致, 闭环(贪心逐 token 一致)成立。 ## 验证方法 全部 `#![cfg(not(no_cuda))]` 门控,本地 `cargo check`/`fmt`,构建+实跑在 dash5(8× RTX 5090)。 ### 1. 正确性(硬闸门全绿,dash5 实跑 capture) - **repeat_kv finite-diff grad-check**(`autograd.rs::repeat_kv_grad`):**核心闸门**——`group>1` (如 bh: 2 kv 头 → 6 q 头)下 grad-check `din`,验证「多组 q 头梯度求和到一个 kv 头」的反向。 外加 `group=1` identity 自检。 - **GQA flash==composed**(`gqa.rs`):真 GQA 配置(`num_kv_heads < n_heads`,如 8 头/2 kv 头)下, flash on/off 两个同 init 模型,断 forward logits / loss / **每参数梯度**一致(fp32 紧容差 + bf16 舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。 - **group=1 与 MHA 逐位一致**(`gqa.rs`):`num_kv_heads = n_heads` 的模型对 T14 的 MHA 模型, forward + 每参数梯度 `|Δ|=0`(回归保护)。 - **PyTorch GQA 对拍 B>1**(`parity_dump.rs` + `parity.py`):等价 PyTorch 模型加 GQA(k/v 投影到 kv_dim + `repeat_interleave(group)` 分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部 参数梯度(composed 与 flash 两条都跑,共用同一 oracle)。 - **小 GQA 配置短训收敛**:一个真 GQA 小模型短训,loss 单调降、无 NaN、采样连贯。 - **全回归套开/关**:autograd / structural / batched==looped / bf16 / recompute(逐位)/ overfit 27/27 / AdamW(GPU bit-exact + host 对 torch)/ DDP loss-match + 跨 rank(`--test-threads=1`)/ flash / grad_accum / dropout / **xserv 闭环 md5**。MHA 默认(kv=heads)图不变 → 不回归。 ### 2. 闭环(payoff)—— 真 GQA 端到端 导出一个 `num_key_value_heads < num_attention_heads` 的 GQA checkpoint → xserv 加载 → 贪心生成 **对 xtrain 自身逐 token 一致**(BF16 推理 vs f32 训练,与 v1–v8 同款判据)。这是 GQA 真正落地的证明: 训练侧的分组、导出的分组、推理侧 xserv 的 repeat_kv 分组三方对齐。 ## 实测结果(dash5 1× / 2× RTX 5090) **硬闸门全绿:** | 闸门 | 结果 | |---|---| | ① repeat_kv grad-check(**多组 q 头梯度求和到一个 kv 头**,group=3) | **过** — din max_rel **2.05e-4**;group=1 identity 双向**逐位**(fwd/bwd |Δ|=0) | | GQA flash==composed(model 级 8h/2kv,logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 3.0e-4、grad **4.1e-5**;bf16: loss 9.0e-5、logits mean 2.9e-3/p99 1.0e-2、grad scaled-mean 8.9e-3 | | group=1 对 MHA**逐位一致**(回归保护) | **过** — logits + loss + 全部梯度 |Δ|=0 | | ② PyTorch GQA 对拍 B>1(composed & flash,repeat_interleave 分组对齐) | composed: loss **1.74e-8**/logits 2.04e-5/25 grad 进 rtol;flash: loss 1.74e-8/logits 2.28e-5/25 grad 进 rtol | | ③ 小 GQA 配置短训收敛(8h/2kv/hd32/4L/ffn1024,600 步) | train **10.90→3.15** 无 NaN、gnorm 稳 ~1.2、采样连贯英文(~200K tok/s) | | ④ **xserv 闭环真 GQA**(导出 `num_key_value_heads=2 < num_attention_heads=8`,xserv 加载 `heads=8/2 kv`,贪心) | "One day"/"The little" 两 prompt **逐 token 一致**;"Once upon a time" 在 `...Lily's mommy ` 处 BF16 漂移晚分叉(said vs came)——与 v1/v2/v3/T14 同款判据 | | ⑤ 回归套:autograd 23(含 repeat_kv 2)/ structural 5 / batched / bf16 / flash 2 / **gqa 4** / overfit 27/27 / recompute 2 / dropout 6 / grad_accum 3 / checkpoint-roundtrip / AdamW(host 对 torch 4.8e-6) / DDP 3(`--test-threads=1`, loss 5.67e-7+跨 rank 一致) / GEMM / tensor | **全绿** | | ⑤ MHA 默认 export md5(v3 ckpt 用 T15 代码重导 safetensors) | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry/T14 同)→ 默认(kv=heads)export 零漂移 | > **诚实记录**:闭环 2/3 prompt 完全 token-identical、1/3 在 BF16 漂移点晚分叉——这恰证明 GQA 分组**正确**:若 kv→q 头映射错,attention 会从第一个生成 token 起就崩(不会是深处近-tie 的晚分叉)。GQA 把 K/V 在显存里物化成满头 `[B·nh,S,hd]`(broadcast-op 方案的代价)——本规模可接受,kernel-内 GQA(省这份显存)留 follow-up。未为凑绿放宽任何容差。