Files
xtrain/docs/14-gqa.md

12 KiB
Raw Blame History

Phase T15: Grouped-Query Attention (GQA) — Design Document

Goal

到 T14 为止xtrain 的 attention 都是 MHAnum_kv_heads = num_heads)——每个 query 头有自己独立的 K/V 头。导出 xserv 时 num_key_value_heads = num_attention_heads (退化 GQAdocs/08

T15 做真正的 grouped-query attentionnum_kv_heads < num_headsK/V 只投影到 num_kv_heads · head_dim,每个 KV 头被一组 group = num_heads / num_kv_heads 个 query 头共享repeat_kv / broadcast。GQA 是现代 LLMLlama-2-70B、Qwen2/3、Mistral的标配 ——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 group 倍压缩,几乎不掉质量。

硬闸门是诚实正确性,重点在 repeat_kv 的反向梯度累加:一个 KV 头被 group 个 query 头 共享,反向时这 group 个 query 头各自对该 KV 头的梯度必须正确求和回到那一个共享 KV 头上。 这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。

GQA 必须同时接进 T14 的 fused flash kernel优先与旧 composed/batched SDPA 路径,且 num_kv_heads == num_headsgroup = 1)时与现有 MHA 路径逐位一致(回归保护)。

什么是 GQA

MHAnum_heads 个 query 头,每个配一个独立 K/V 头。 MQAmulti-query所有 query 头共享一个 K/V 头(极端)。 GQA折中——num_kv_heads 个 K/V 头,每个被 group = num_heads/num_kv_heads 个相邻 query 头共享。num_kv_heads = num_heads 退化为 MHAnum_kv_heads = 1 退化为 MQA。

num_heads = 8, num_kv_heads = 2  →  group = 4
q heads:  0 1 2 3 4 5 6 7
kv heads: 0 0 0 0 1 1 1 1        # q head qh 用 kv head qh/group相邻分组连续

分组约定必须与 xserv repeat_kv 一致闭环命门xserv 的 repeat_kv crates/xserv-model/src/qwen3.rs)把 kv 头 kvh 复制到目标头 dst = kvh*group + r r∈[0,group)),即query 头 qh 读 kv 头 qh/group,组内 query 头连续。xtrain 的 repeat_kv 用同一映射,否则导出的 q_proj 行块与 kv 头对不上 → 闭环必崩。

Module Layoutsurgical复用已验证的两条 SDPAGQA = 头维 broadcast op

csrc/ops/repeat_kv.cu                     # 新repeat_kv fwd头块 gather+ bwd组内 group 行求和,无 atomic确定性
crates/xtrain-cuda/
├── src/ffi.rs                            # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明no_cuda 门控)
└── build.rs                              # +repeat_kv.cu
crates/xtrain-tensor/src/tensor.rs        # +Tensor::repeat_kv / repeat_kv_backward[B*kvh,S,hd]→[B*nh,S,hd]bf16 upcast→f32→downcast
crates/xtrain-autodiff/
├── src/ops.rs                            # +ops::repeat_kv 节点fwd broadcastbwd 组内求和)
└── tests/autograd.rs                     # +repeat_kv grad-check含 group>1 的多组梯度累加)
crates/xtrain-model/
├── src/config.rs                         # +num_kv_heads 字段(默认 = n_heads → MHAfrom_arch 加形参num_params 计 K/V 投影按 kv_dim
├── src/model.rs                          # wk/wv 投影到 kv_dimattention() 在 SDPA 前对 K/V 做 ops::repeat_kv两条路径都吃到 GQA
└── tests/gqa.rs                          # 新GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
crates/xtrain-train/src/bin/train.rs            # +--kv-heads flag
crates/xtrain-distributed/src/bin/train_ddp.rs  # +--kv-heads flagDDP 路径)
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-headsconfig.json 写真 num_key_value_heads
crates/xtrain-model/tests/parity{.py,_dump.rs}  # PyTorch 对拍加 GQAkv 投影 + repeat_kv

Key Design Decisions

① GQA = K/V 头维 broadcast op喂给未改动的两条 SDPA不写第三套 attention

T14 已经有两条逐位/数值都验证过的 SDPAcomposedops::attention)与 fused flash ops::flash_attention),二者都吃 [bh, S, hd]bh = batch·heads。GQA 的本质只是「K/V 比 Q 少 group 倍头,用前把每个 kv 头复制 group 份」。所以最外科的做法:

  • wk/wv 投影到 kv_dim = num_kv_heads · head_dim,按 [B, num_kv, S, hd] → [B·num_kv, S, hd] 排好(和 Q 的 [B·nh, S, hd] 同流水线,只是头数不同)。
  • 在调 SDPA 之前,对 K、V 各做一个新 autograd op ops::repeat_kv,把 [B·num_kv, S, hd] broadcast[B·nh, S, hd](输出行 b·nh + qh = 输入行 b·num_kv + qh/group 的拷贝)。
  • 之后 ops::attention / ops::flash_attention 一字不改——它们看到的就是满头的 [B·nh, S, hd]GQA 对两条路径同时、免费生效。flash kernel / composed kernel 都不用碰。

为什么不在 kernel 内做 GQA:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed 的两次 strided GEMM 各算 GQA stride且两套都要重测——是「第三套 attention 改动」。而 broadcast-op 方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check 的小 op正确性风险隔离在一处(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的反向 干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 [B·nh,S,hd](多 group 倍)——本规模 训练、seq 不极端)可接受;真要省这份显存是 follow-upkernel 内 GQA 读取),记进逃生舱不在 T15 做。

备选不采纳flash/composed kernel 内直接按 kv_head = q_head/group 索引 K/V。省 broadcast 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。 escape hatch先 broadcast-op 把正确性 + 闭环钉死kernel-内 GQA省显存留 follow-up。

② repeat_kv 的反向 = 组内求和(确定性,无 atomic

repeat_kv 前向:out[b·nh + qh] = in[b·num_kv + qh/group](按 S·hd 整行拷贝)。

反向是它的转置:一个 kv 头收到它那 group 个 query 头的梯度之

din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]

这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上不用 atomicAdd:每个输入 kv-head, 元素)由唯一一个 block 负责,它串行累加自己那 group 个连续源行——天然 race-free 且run-to-run 确定(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。group=1 时 反向退化为单行拷贝identity

autograd 层面其实也可以靠引擎的扇出 SUM把一个 kv Var 喂给 group 个下游),但那样图里要 显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的 broadcast opfwd/bwd 各一发 kernel最简且能单独 grad-check。

num_kv_heads 进 Config它改模型尺寸/导出),默认 = n_heads → 退化 MHA

不同于 T14 的 use_flash(运行时旗标,不进 Confignum_kv_heads 改 K/V 投影的形状、改参数量、 改导出的 num_key_value_heads——它是架构的一部分,必须进 Config 并落进 checkpoint/导出。

  • Confignum_kv_heads: usizefrom_arch 加该形参;Config::tiny() 默认 num_kv_heads = n_headsMHA。约束num_heads % num_kv_heads == 0(断言)。
  • num_params()K/V 投影从 2·dim·dim 改成 2·dim·(num_kv_heads·head_dim)QK-norm 的 k_norm 仍是 [head_dim]per-head作用在单个 head 向量上,与头数无关)→ 不变。
  • num_kv_heads == n_headsgroup=1ops::repeat_kv 是 identityfwd 单行拷贝、bwd 单行 拷贝wk/wv 形状回到 [dim,dim] → 整条图与 T14 的 MHA 路径逐位一致(回归保护闸门)。

wk/wv 形状从 [dim,dim] 变成 [dim, kv_dim]Block 里 wk/wv 的 mk(&[dim, kv_dim]) params()/block_params() 顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是 wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(transposev.value().shape())。

④ bf16 / recompute / dropout / DDP 全部自动兼容

  • bf16Tensor::repeat_kv 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel → downcastkernel 只一份 f32。ops::repeat_kv 的 fwd/bwd 都在 SDPA 之前/之后dtype 与 K/V 流一致。
  • recomputeT13repeat_kv 在 block_forward 内、attention() 里,重算段重跑 attention() 自然重跑 repeat_kv无额外状态确定性→ 梯度仍逐位一致。
  • dropoutT18dropout 接在 attn/mlp 子块输出,与 attention 内部的 repeat_kv 正交,不交互。
  • DDPrepeat_kv 不引入新参数wk/wv 变小kv_dim只是参数张量小一圈params() 泛化迭代
    • all-reduce 照旧;跨 rank 一致性不受影响。

⑤ 导出 xserv写真 num_key_value_heads,分组约定对齐 repeat_kv

export_safetensors.rsconfig.jsonnum_key_value_heads 从「= num_attention_heads」改成 cfg.num_kv_heads--kv-heads flag 传入(须与训练 ckpt 一致。q/k/v_proj 各自按其 shape 转置导出k/v_proj 现在是 [kv_dim, dim]xserv loader 期望的 GQA 形状。xserv 的 repeat_kvdst = kvh·group + r 分组,与 ① 的 xtrain 约定逐头对齐 → 同一份权重在两侧前向数学一致, 闭环(贪心逐 token 一致)成立。

验证方法

全部 #![cfg(not(no_cuda))] 门控,本地 cargo check/fmt,构建+实跑在 dash58× RTX 5090

1. 正确性硬闸门全绿dash5 实跑 capture

  • repeat_kv finite-diff grad-checkautograd.rs::repeat_kv_grad核心闸门——group>1 (如 bh: 2 kv 头 → 6 q 头)下 grad-check din,验证「多组 q 头梯度求和到一个 kv 头」的反向。 外加 group=1 identity 自检。
  • GQA flash==composedgqa.rs):真 GQA 配置(num_kv_heads < n_heads,如 8 头/2 kv 头)下, flash on/off 两个同 init 模型,断 forward logits / loss / 每参数梯度一致fp32 紧容差 + bf16 舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。
  • group=1 与 MHA 逐位一致gqa.rsnum_kv_heads = n_heads 的模型对 T14 的 MHA 模型, forward + 每参数梯度 |Δ|=0(回归保护)。
  • PyTorch GQA 对拍 B>1parity_dump.rs + parity.py):等价 PyTorch 模型加 GQAk/v 投影到 kv_dim + repeat_interleave(group) 分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部 参数梯度composed 与 flash 两条都跑,共用同一 oracle
  • 小 GQA 配置短训收敛:一个真 GQA 小模型短训loss 单调降、无 NaN、采样连贯。
  • 全回归套开/关autograd / structural / batched==looped / bf16 / recompute逐位/ overfit 27/27 / AdamWGPU bit-exact + host 对 torch/ DDP loss-match + 跨 rank--test-threads=1/ flash / grad_accum / dropout / xserv 闭环 md5。MHA 默认kv=heads图不变 → 不回归。

2. 闭环payoff—— 真 GQA 端到端

导出一个 num_key_value_heads < num_attention_heads 的 GQA checkpoint → xserv 加载 → 贪心生成 对 xtrain 自身逐 token 一致BF16 推理 vs f32 训练,与 v1v8 同款判据)。这是 GQA 真正落地的证明: 训练侧的分组、导出的分组、推理侧 xserv 的 repeat_kv 分组三方对齐。

实测结果dash5

待 dash5 实跑回填gate 表 + 数字)。