12 KiB
Phase T15: Grouped-Query Attention (GQA) — Design Document
Goal
到 T14 为止,xtrain 的 attention 都是 MHA(num_kv_heads = num_heads)——每个
query 头有自己独立的 K/V 头。导出 xserv 时 num_key_value_heads = num_attention_heads
(退化 GQA,docs/08)。
T15 做真正的 grouped-query attention:num_kv_heads < num_heads,K/V 只投影到
num_kv_heads · head_dim,每个 KV 头被一组 group = num_heads / num_kv_heads 个 query
头共享(repeat_kv / broadcast)。GQA 是现代 LLM(Llama-2-70B、Qwen2/3、Mistral)的标配
——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 group 倍压缩,几乎不掉质量。
硬闸门是诚实正确性,重点在 repeat_kv 的反向梯度累加:一个 KV 头被 group 个 query 头
共享,反向时这 group 个 query 头各自对该 KV 头的梯度必须正确求和回到那一个共享 KV 头上。
这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。
GQA 必须同时接进 T14 的 fused flash kernel(优先)与旧 composed/batched SDPA 路径,且
num_kv_heads == num_heads(group = 1)时与现有 MHA 路径逐位一致(回归保护)。
什么是 GQA
MHA:num_heads 个 query 头,每个配一个独立 K/V 头。
MQA(multi-query):所有 query 头共享一个 K/V 头(极端)。
GQA:折中——num_kv_heads 个 K/V 头,每个被 group = num_heads/num_kv_heads 个相邻 query
头共享。num_kv_heads = num_heads 退化为 MHA,num_kv_heads = 1 退化为 MQA。
num_heads = 8, num_kv_heads = 2 → group = 4
q heads: 0 1 2 3 4 5 6 7
kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group(相邻分组,连续)
分组约定必须与 xserv repeat_kv 一致(闭环命门):xserv 的 repeat_kv
(crates/xserv-model/src/qwen3.rs)把 kv 头 kvh 复制到目标头 dst = kvh*group + r
(r∈[0,group)),即query 头 qh 读 kv 头 qh/group,组内 query 头连续。xtrain 的
repeat_kv 用同一映射,否则导出的 q_proj 行块与 kv 头对不上 → 闭环必崩。
Module Layout(surgical:复用已验证的两条 SDPA,GQA = 头维 broadcast op)
csrc/ops/repeat_kv.cu # 新:repeat_kv fwd(头块 gather)+ bwd(组内 group 行求和,无 atomic,确定性)
crates/xtrain-cuda/
├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明(no_cuda 门控)
└── build.rs # +repeat_kv.cu
crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward([B*kvh,S,hd]→[B*nh,S,hd];bf16 upcast→f32→downcast)
crates/xtrain-autodiff/
├── src/ops.rs # +ops::repeat_kv 节点(fwd broadcast,bwd 组内求和)
└── tests/autograd.rs # +repeat_kv grad-check(含 group>1 的多组梯度累加)
crates/xtrain-model/
├── src/config.rs # +num_kv_heads 字段(默认 = n_heads → MHA);from_arch 加形参;num_params 计 K/V 投影按 kv_dim
├── src/model.rs # wk/wv 投影到 kv_dim;attention() 在 SDPA 前对 K/V 做 ops::repeat_kv;两条路径都吃到 GQA
└── tests/gqa.rs # 新:GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
crates/xtrain-train/src/bin/train.rs # +--kv-heads flag
crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flag(DDP 路径)
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-heads;config.json 写真 num_key_value_heads
crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQA(kv 投影 + repeat_kv)
Key Design Decisions
① GQA = K/V 头维 broadcast op,喂给未改动的两条 SDPA(不写第三套 attention)
T14 已经有两条逐位/数值都验证过的 SDPA:composed(ops::attention)与 fused flash
(ops::flash_attention),二者都吃 [bh, S, hd](bh = batch·heads)。GQA 的本质只是「K/V
比 Q 少 group 倍头,用前把每个 kv 头复制 group 份」。所以最外科的做法:
- wk/wv 投影到
kv_dim = num_kv_heads · head_dim,按[B, num_kv, S, hd] → [B·num_kv, S, hd]排好(和 Q 的[B·nh, S, hd]同流水线,只是头数不同)。 - 在调 SDPA 之前,对 K、V 各做一个新 autograd op
ops::repeat_kv,把[B·num_kv, S, hd]broadcast 成[B·nh, S, hd](输出行b·nh + qh= 输入行b·num_kv + qh/group的拷贝)。 - 之后
ops::attention/ops::flash_attention一字不改——它们看到的就是满头的[B·nh, S, hd],GQA 对两条路径同时、免费生效。flash kernel / composed kernel 都不用碰。
为什么不在 kernel 内做 GQA:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed
的两次 strided GEMM 各算 GQA stride,且两套都要重测——是「第三套 attention 改动」。而 broadcast-op
方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check
的小 op,正确性风险隔离在一处;(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的反向,
干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 [B·nh,S,hd](多 group 倍)——本规模
(训练、seq 不极端)可接受;真要省这份显存是 follow-up(kernel 内 GQA 读取),记进逃生舱不在 T15 做。
备选(不采纳):flash/composed kernel 内直接按
kv_head = q_head/group索引 K/V。省 broadcast 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。 escape hatch:先 broadcast-op 把正确性 + 闭环钉死,kernel-内 GQA(省显存)留 follow-up。
② repeat_kv 的反向 = 组内求和(确定性,无 atomic)
repeat_kv 前向:out[b·nh + qh] = in[b·num_kv + qh/group](按 S·hd 整行拷贝)。
反向是它的转置:一个 kv 头收到它那 group 个 query 头的梯度之和:
din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]
这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上不用 atomicAdd:每个输入
(kv-head, 元素)由唯一一个 block 负责,它串行累加自己那 group 个连续源行——天然 race-free
且run-to-run 确定(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。group=1 时
反向退化为单行拷贝(identity)。
autograd 层面其实也可以靠引擎的扇出 SUM(把一个 kv Var 喂给 group 个下游),但那样图里要 显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的 broadcast op,fwd/bwd 各一发 kernel,最简且能单独 grad-check。
③ num_kv_heads 进 Config(它改模型尺寸/导出),默认 = n_heads → 退化 MHA
不同于 T14 的 use_flash(运行时旗标,不进 Config),num_kv_heads 改 K/V 投影的形状、改参数量、
改导出的 num_key_value_heads——它是架构的一部分,必须进 Config 并落进 checkpoint/导出。
Config加num_kv_heads: usize;from_arch加该形参;Config::tiny()默认num_kv_heads = n_heads(MHA)。约束:num_heads % num_kv_heads == 0(断言)。num_params():K/V 投影从2·dim·dim改成2·dim·(num_kv_heads·head_dim);QK-norm 的k_norm仍是[head_dim](per-head,作用在单个 head 向量上,与头数无关)→ 不变。num_kv_heads == n_heads时group=1:ops::repeat_kv是 identity(fwd 单行拷贝、bwd 单行 拷贝),wk/wv 形状回到[dim,dim]→ 整条图与 T14 的 MHA 路径逐位一致(回归保护闸门)。
wk/wv 形状从
[dim,dim]变成[dim, kv_dim]:Block里 wk/wv 的mk(&[dim, kv_dim]),params()/block_params()顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是 wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(transpose读v.value().shape())。
④ bf16 / recompute / dropout / DDP 全部自动兼容
- bf16:
Tensor::repeat_kv沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel → downcast;kernel 只一份 f32。ops::repeat_kv的 fwd/bwd 都在 SDPA 之前/之后,dtype 与 K/V 流一致。 - recompute(T13):repeat_kv 在
block_forward内、attention()里,重算段重跑attention()自然重跑 repeat_kv(无额外状态,确定性)→ 梯度仍逐位一致。 - dropout(T18):dropout 接在 attn/mlp 子块输出,与 attention 内部的 repeat_kv 正交,不交互。
- DDP:repeat_kv 不引入新参数;wk/wv 变小(kv_dim)只是参数张量小一圈,
params()泛化迭代- all-reduce 照旧;跨 rank 一致性不受影响。
⑤ 导出 xserv:写真 num_key_value_heads,分组约定对齐 repeat_kv
export_safetensors.rs 的 config.json 把 num_key_value_heads 从「= num_attention_heads」改成
真 cfg.num_kv_heads;--kv-heads flag 传入(须与训练 ckpt 一致)。q/k/v_proj 各自按其 shape
转置导出(k/v_proj 现在是 [kv_dim, dim],xserv loader 期望的 GQA 形状)。xserv 的 repeat_kv
用 dst = kvh·group + r 分组,与 ① 的 xtrain 约定逐头对齐 → 同一份权重在两侧前向数学一致,
闭环(贪心逐 token 一致)成立。
验证方法
全部 #![cfg(not(no_cuda))] 门控,本地 cargo check/fmt,构建+实跑在 dash5(8× RTX 5090)。
1. 正确性(硬闸门全绿,dash5 实跑 capture)
- repeat_kv finite-diff grad-check(
autograd.rs::repeat_kv_grad):核心闸门——group>1(如 bh: 2 kv 头 → 6 q 头)下 grad-checkdin,验证「多组 q 头梯度求和到一个 kv 头」的反向。 外加group=1identity 自检。 - GQA flash==composed(
gqa.rs):真 GQA 配置(num_kv_heads < n_heads,如 8 头/2 kv 头)下, flash on/off 两个同 init 模型,断 forward logits / loss / 每参数梯度一致(fp32 紧容差 + bf16 舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。 - group=1 与 MHA 逐位一致(
gqa.rs):num_kv_heads = n_heads的模型对 T14 的 MHA 模型, forward + 每参数梯度|Δ|=0(回归保护)。 - PyTorch GQA 对拍 B>1(
parity_dump.rs+parity.py):等价 PyTorch 模型加 GQA(k/v 投影到 kv_dim +repeat_interleave(group)分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部 参数梯度(composed 与 flash 两条都跑,共用同一 oracle)。 - 小 GQA 配置短训收敛:一个真 GQA 小模型短训,loss 单调降、无 NaN、采样连贯。
- 全回归套开/关:autograd / structural / batched==looped / bf16 / recompute(逐位)/ overfit 27/27 /
AdamW(GPU bit-exact + host 对 torch)/ DDP loss-match + 跨 rank(
--test-threads=1)/ flash / grad_accum / dropout / xserv 闭环 md5。MHA 默认(kv=heads)图不变 → 不回归。
2. 闭环(payoff)—— 真 GQA 端到端
导出一个 num_key_value_heads < num_attention_heads 的 GQA checkpoint → xserv 加载 → 贪心生成
对 xtrain 自身逐 token 一致(BF16 推理 vs f32 训练,与 v1–v8 同款判据)。这是 GQA 真正落地的证明:
训练侧的分组、导出的分组、推理侧 xserv 的 repeat_kv 分组三方对齐。
实测结果(dash5)
待 dash5 实跑回填(gate 表 + 数字)。