docs: Phase T15 — GQA design (repeat_kv broadcast op + backward grad-sum)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 01:30:34 +08:00
parent 4b6d3e0a79
commit 62b1cb5dc7

167
docs/14-gqa.md Normal file
View File

@@ -0,0 +1,167 @@
# Phase T15: Grouped-Query Attention (GQA) — Design Document
## Goal
到 T14 为止xtrain 的 attention 都是 **MHA**`num_kv_heads = num_heads`)——每个
query 头有自己独立的 K/V 头。导出 xserv 时 `num_key_value_heads = num_attention_heads`
(退化 GQAdocs/08
T15 做**真正的 grouped-query attention**`num_kv_heads < num_heads`K/V 只投影到
`num_kv_heads · head_dim`,每个 KV 头被一组 `group = num_heads / num_kv_heads` 个 query
头**共享**repeat_kv / broadcast。GQA 是现代 LLMLlama-2-70B、Qwen2/3、Mistral的标配
——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 `group` 倍压缩,几乎不掉质量。
**硬闸门是诚实正确性**,重点在 **repeat_kv 的反向梯度累加**:一个 KV 头被 `group` 个 query 头
共享,反向时这 `group` 个 query 头各自对该 KV 头的梯度必须**正确求和**回到那一个共享 KV 头上。
这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。
GQA 必须**同时**接进 T14 的 fused flash kernel优先与旧 composed/batched SDPA 路径,且
`num_kv_heads == num_heads``group = 1`)时与现有 MHA 路径**逐位一致**(回归保护)。
## 什么是 GQA
MHA`num_heads` 个 query 头,每个配一个独立 K/V 头。
MQAmulti-query所有 query 头共享**一个** K/V 头(极端)。
GQA折中——`num_kv_heads` 个 K/V 头,每个被 `group = num_heads/num_kv_heads` 个相邻 query
头共享。`num_kv_heads = num_heads` 退化为 MHA`num_kv_heads = 1` 退化为 MQA。
```
num_heads = 8, num_kv_heads = 2 → group = 4
q heads: 0 1 2 3 4 5 6 7
kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group相邻分组连续
```
**分组约定必须与 xserv repeat_kv 一致**闭环命门xserv 的 `repeat_kv`
`crates/xserv-model/src/qwen3.rs`)把 kv 头 `kvh` 复制到目标头 `dst = kvh*group + r`
`r∈[0,group)`),即**query 头 `qh` 读 kv 头 `qh/group`,组内 query 头连续**。xtrain 的
repeat_kv 用同一映射,否则导出的 `q_proj` 行块与 kv 头对不上 → 闭环必崩。
## Module Layoutsurgical复用已验证的两条 SDPAGQA = 头维 broadcast op
```
csrc/ops/repeat_kv.cu # 新repeat_kv fwd头块 gather+ bwd组内 group 行求和,无 atomic确定性
crates/xtrain-cuda/
├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明no_cuda 门控)
└── build.rs # +repeat_kv.cu
crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward[B*kvh,S,hd]→[B*nh,S,hd]bf16 upcast→f32→downcast
crates/xtrain-autodiff/
├── src/ops.rs # +ops::repeat_kv 节点fwd broadcastbwd 组内求和)
└── tests/autograd.rs # +repeat_kv grad-check含 group>1 的多组梯度累加)
crates/xtrain-model/
├── src/config.rs # +num_kv_heads 字段(默认 = n_heads → MHAfrom_arch 加形参num_params 计 K/V 投影按 kv_dim
├── src/model.rs # wk/wv 投影到 kv_dimattention() 在 SDPA 前对 K/V 做 ops::repeat_kv两条路径都吃到 GQA
└── tests/gqa.rs # 新GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
crates/xtrain-train/src/bin/train.rs # +--kv-heads flag
crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flagDDP 路径)
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-headsconfig.json 写真 num_key_value_heads
crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQAkv 投影 + repeat_kv
```
## Key Design Decisions
### ① GQA = K/V 头维 broadcast op喂给**未改动**的两条 SDPA不写第三套 attention
T14 已经有两条**逐位/数值都验证过**的 SDPAcomposed`ops::attention`)与 fused flash
`ops::flash_attention`),二者都吃 `[bh, S, hd]``bh = batch·heads`。GQA 的本质只是「K/V
比 Q 少 `group` 倍头,用前把每个 kv 头复制 `group` 份」。所以**最外科**的做法:
- wk/wv 投影到 `kv_dim = num_kv_heads · head_dim`,按 `[B, num_kv, S, hd] → [B·num_kv, S, hd]`
排好(和 Q 的 `[B·nh, S, hd]` 同流水线,只是头数不同)。
- 在调 SDPA 之前,对 K、V 各做一个新 autograd op `ops::repeat_kv`,把 `[B·num_kv, S, hd]`
**broadcast**`[B·nh, S, hd]`(输出行 `b·nh + qh` = 输入行 `b·num_kv + qh/group` 的拷贝)。
- 之后 `ops::attention` / `ops::flash_attention` **一字不改**——它们看到的就是满头的
`[B·nh, S, hd]`GQA 对两条路径**同时、免费**生效。flash kernel / composed kernel 都不用碰。
**为什么不在 kernel 内做 GQA**:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed
的两次 strided GEMM 各算 GQA stride且两套都要重测——是「第三套 attention 改动」。而 broadcast-op
方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check
的小 op正确性风险隔离在一处(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的**反向**
干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 `[B·nh,S,hd]`(多 `group` 倍)——本规模
训练、seq 不极端)可接受;真要省这份显存是 follow-upkernel 内 GQA 读取),记进逃生舱不在 T15 做。
> 备选不采纳flash/composed kernel 内直接按 `kv_head = q_head/group` 索引 K/V。省 broadcast
> 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。
> escape hatch先 broadcast-op 把正确性 + 闭环钉死kernel-内 GQA省显存留 follow-up。
### ② repeat_kv 的反向 = 组内求和(确定性,无 atomic
`repeat_kv` 前向:`out[b·nh + qh] = in[b·num_kv + qh/group]`(按 `S·hd` 整行拷贝)。
反向是它的**转置**:一个 kv 头收到它那 `group` 个 query 头的梯度之**和**
```
din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]
```
这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上**不用 atomicAdd**:每个输入
kv-head, 元素)由唯一一个 block 负责,它**串行累加自己那 group 个连续源行**——天然 race-free
且**run-to-run 确定**(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。`group=1`
反向退化为单行拷贝identity
autograd 层面其实也可以靠引擎的扇出 SUM把一个 kv Var 喂给 group 个下游),但那样图里要
显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的
broadcast opfwd/bwd 各一发 kernel最简且能单独 grad-check。
### ③ `num_kv_heads` 进 Config它改模型尺寸/导出),默认 = n_heads → 退化 MHA
不同于 T14 的 `use_flash`(运行时旗标,不进 Config`num_kv_heads` **改 K/V 投影的形状、改参数量、
改导出的 `num_key_value_heads`**——它是**架构**的一部分,必须进 `Config` 并落进 checkpoint/导出。
- `Config``num_kv_heads: usize``from_arch` 加该形参;`Config::tiny()` 默认 `num_kv_heads =
n_heads`MHA。约束`num_heads % num_kv_heads == 0`(断言)。
- `num_params()`K/V 投影从 `2·dim·dim` 改成 `2·dim·(num_kv_heads·head_dim)`QK-norm 的
`k_norm` 仍是 `[head_dim]`per-head作用在单个 head 向量上,与头数无关)→ 不变。
- **`num_kv_heads == n_heads` 时 `group=1`**`ops::repeat_kv` 是 identityfwd 单行拷贝、bwd 单行
拷贝wk/wv 形状回到 `[dim,dim]` → 整条图与 T14 的 MHA 路径**逐位一致**(回归保护闸门)。
> wk/wv 形状从 `[dim,dim]` 变成 `[dim, kv_dim]``Block` 里 wk/wv 的 `mk(&[dim, kv_dim])`
> `params()`/`block_params()` 顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是
> wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(`transpose` 读 `v.value().shape()`)。
### ④ bf16 / recompute / dropout / DDP 全部自动兼容
- **bf16**`Tensor::repeat_kv` 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel →
downcastkernel 只一份 f32。`ops::repeat_kv` 的 fwd/bwd 都在 SDPA 之前/之后dtype 与 K/V 流一致。
- **recomputeT13**repeat_kv 在 `block_forward` 内、`attention()` 里,重算段重跑 `attention()`
自然重跑 repeat_kv无额外状态确定性→ 梯度仍逐位一致。
- **dropoutT18**dropout 接在 attn/mlp 子块**输出**,与 attention 内部的 repeat_kv 正交,不交互。
- **DDP**repeat_kv 不引入新参数wk/wv 变小kv_dim只是参数张量小一圈`params()` 泛化迭代
+ all-reduce 照旧;跨 rank 一致性不受影响。
### ⑤ 导出 xserv写真 `num_key_value_heads`,分组约定对齐 repeat_kv
`export_safetensors.rs` 的 `config.json` 把 `num_key_value_heads` 从「= num_attention_heads」改成
**真 `cfg.num_kv_heads`**`--kv-heads` flag 传入(须与训练 ckpt 一致。q/k/v_proj 各自按其 shape
转置导出k/v_proj 现在是 `[kv_dim, dim]`xserv loader 期望的 GQA 形状。xserv 的 `repeat_kv`
用 `dst = kvh·group + r` 分组,与 ① 的 xtrain 约定**逐头对齐** → 同一份权重在两侧前向数学一致,
闭环(贪心逐 token 一致)成立。
## 验证方法
全部 `#![cfg(not(no_cuda))]` 门控,本地 `cargo check`/`fmt`,构建+实跑在 dash58× RTX 5090
### 1. 正确性硬闸门全绿dash5 实跑 capture
- **repeat_kv finite-diff grad-check**`autograd.rs::repeat_kv_grad`**核心闸门**——`group>1`
(如 bh: 2 kv 头 → 6 q 头)下 grad-check `din`,验证「多组 q 头梯度求和到一个 kv 头」的反向。
外加 `group=1` identity 自检。
- **GQA flash==composed**`gqa.rs`):真 GQA 配置(`num_kv_heads < n_heads`,如 8 头/2 kv 头)下,
flash on/off 两个同 init 模型,断 forward logits / loss / **每参数梯度**一致fp32 紧容差 + bf16
舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。
- **group=1 与 MHA 逐位一致**`gqa.rs``num_kv_heads = n_heads` 的模型对 T14 的 MHA 模型,
forward + 每参数梯度 `|Δ|=0`(回归保护)。
- **PyTorch GQA 对拍 B>1**`parity_dump.rs` + `parity.py`):等价 PyTorch 模型加 GQAk/v 投影到
kv_dim + `repeat_interleave(group)` 分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部
参数梯度composed 与 flash 两条都跑,共用同一 oracle
- **小 GQA 配置短训收敛**:一个真 GQA 小模型短训loss 单调降、无 NaN、采样连贯。
- **全回归套开/关**autograd / structural / batched==looped / bf16 / recompute逐位/ overfit 27/27 /
AdamWGPU bit-exact + host 对 torch/ DDP loss-match + 跨 rank`--test-threads=1`/ flash /
grad_accum / dropout / **xserv 闭环 md5**。MHA 默认kv=heads图不变 → 不回归。
### 2. 闭环payoff—— 真 GQA 端到端
导出一个 `num_key_value_heads < num_attention_heads` GQA checkpoint xserv 加载 贪心生成
**对 xtrain 自身逐 token 一致**BF16 推理 vs f32 训练 v1v8 同款判据)。这是 GQA 真正落地的证明
训练侧的分组导出的分组推理侧 xserv repeat_kv 分组三方对齐
## 实测结果dash5
> 待 dash5 实跑回填gate 表 + 数字)。