Files
xtrain/docs/14-gqa.md
Gahow Wang 2ff4573a31 docs: T15 GQA results + evolution row (模型架构) + README build-journey row
Backfill docs/14-gqa.md gate table (dash5 numbers); add T15 evolution row +
cumulative 模型架构 line; README build-journey T15 row + Phase 2 prose + doc
index range (00..14).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:44:58 +08:00

181 lines
14 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Phase T15: Grouped-Query Attention (GQA) — Design Document
## Goal
到 T14 为止xtrain 的 attention 都是 **MHA**`num_kv_heads = num_heads`)——每个
query 头有自己独立的 K/V 头。导出 xserv 时 `num_key_value_heads = num_attention_heads`
(退化 GQAdocs/08
T15 做**真正的 grouped-query attention**`num_kv_heads < num_heads`K/V 只投影到
`num_kv_heads · head_dim`,每个 KV 头被一组 `group = num_heads / num_kv_heads` 个 query
头**共享**repeat_kv / broadcast。GQA 是现代 LLMLlama-2-70B、Qwen2/3、Mistral的标配
——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 `group` 倍压缩,几乎不掉质量。
**硬闸门是诚实正确性**,重点在 **repeat_kv 的反向梯度累加**:一个 KV 头被 `group` 个 query 头
共享,反向时这 `group` 个 query 头各自对该 KV 头的梯度必须**正确求和**回到那一个共享 KV 头上。
这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。
GQA 必须**同时**接进 T14 的 fused flash kernel优先与旧 composed/batched SDPA 路径,且
`num_kv_heads == num_heads``group = 1`)时与现有 MHA 路径**逐位一致**(回归保护)。
## 什么是 GQA
MHA`num_heads` 个 query 头,每个配一个独立 K/V 头。
MQAmulti-query所有 query 头共享**一个** K/V 头(极端)。
GQA折中——`num_kv_heads` 个 K/V 头,每个被 `group = num_heads/num_kv_heads` 个相邻 query
头共享。`num_kv_heads = num_heads` 退化为 MHA`num_kv_heads = 1` 退化为 MQA。
```
num_heads = 8, num_kv_heads = 2 → group = 4
q heads: 0 1 2 3 4 5 6 7
kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group相邻分组连续
```
**分组约定必须与 xserv repeat_kv 一致**闭环命门xserv 的 `repeat_kv`
`crates/xserv-model/src/qwen3.rs`)把 kv 头 `kvh` 复制到目标头 `dst = kvh*group + r`
`r∈[0,group)`),即**query 头 `qh` 读 kv 头 `qh/group`,组内 query 头连续**。xtrain 的
repeat_kv 用同一映射,否则导出的 `q_proj` 行块与 kv 头对不上 → 闭环必崩。
## Module Layoutsurgical复用已验证的两条 SDPAGQA = 头维 broadcast op
```
csrc/ops/repeat_kv.cu # 新repeat_kv fwd头块 gather+ bwd组内 group 行求和,无 atomic确定性
crates/xtrain-cuda/
├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明no_cuda 门控)
└── build.rs # +repeat_kv.cu
crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward[B*kvh,S,hd]→[B*nh,S,hd]bf16 upcast→f32→downcast
crates/xtrain-autodiff/
├── src/ops.rs # +ops::repeat_kv 节点fwd broadcastbwd 组内求和)
└── tests/autograd.rs # +repeat_kv grad-check含 group>1 的多组梯度累加)
crates/xtrain-model/
├── src/config.rs # +num_kv_heads 字段(默认 = n_heads → MHAfrom_arch 加形参num_params 计 K/V 投影按 kv_dim
├── src/model.rs # wk/wv 投影到 kv_dimattention() 在 SDPA 前对 K/V 做 ops::repeat_kv两条路径都吃到 GQA
└── tests/gqa.rs # 新GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
crates/xtrain-train/src/bin/train.rs # +--kv-heads flag
crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flagDDP 路径)
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-headsconfig.json 写真 num_key_value_heads
crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQAkv 投影 + repeat_kv
```
## Key Design Decisions
### ① GQA = K/V 头维 broadcast op喂给**未改动**的两条 SDPA不写第三套 attention
T14 已经有两条**逐位/数值都验证过**的 SDPAcomposed`ops::attention`)与 fused flash
`ops::flash_attention`),二者都吃 `[bh, S, hd]``bh = batch·heads`。GQA 的本质只是「K/V
比 Q 少 `group` 倍头,用前把每个 kv 头复制 `group` 份」。所以**最外科**的做法:
- wk/wv 投影到 `kv_dim = num_kv_heads · head_dim`,按 `[B, num_kv, S, hd] → [B·num_kv, S, hd]`
排好(和 Q 的 `[B·nh, S, hd]` 同流水线,只是头数不同)。
- 在调 SDPA 之前,对 K、V 各做一个新 autograd op `ops::repeat_kv`,把 `[B·num_kv, S, hd]`
**broadcast**`[B·nh, S, hd]`(输出行 `b·nh + qh` = 输入行 `b·num_kv + qh/group` 的拷贝)。
- 之后 `ops::attention` / `ops::flash_attention` **一字不改**——它们看到的就是满头的
`[B·nh, S, hd]`GQA 对两条路径**同时、免费**生效。flash kernel / composed kernel 都不用碰。
**为什么不在 kernel 内做 GQA**:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed
的两次 strided GEMM 各算 GQA stride且两套都要重测——是「第三套 attention 改动」。而 broadcast-op
方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check
的小 op正确性风险隔离在一处(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的**反向**
干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 `[B·nh,S,hd]`(多 `group` 倍)——本规模
训练、seq 不极端)可接受;真要省这份显存是 follow-upkernel 内 GQA 读取),记进逃生舱不在 T15 做。
> 备选不采纳flash/composed kernel 内直接按 `kv_head = q_head/group` 索引 K/V。省 broadcast
> 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。
> escape hatch先 broadcast-op 把正确性 + 闭环钉死kernel-内 GQA省显存留 follow-up。
### ② repeat_kv 的反向 = 组内求和(确定性,无 atomic
`repeat_kv` 前向:`out[b·nh + qh] = in[b·num_kv + qh/group]`(按 `S·hd` 整行拷贝)。
反向是它的**转置**:一个 kv 头收到它那 `group` 个 query 头的梯度之**和**
```
din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]
```
这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上**不用 atomicAdd**:每个输入
kv-head, 元素)由唯一一个 block 负责,它**串行累加自己那 group 个连续源行**——天然 race-free
且**run-to-run 确定**(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。`group=1`
反向退化为单行拷贝identity
autograd 层面其实也可以靠引擎的扇出 SUM把一个 kv Var 喂给 group 个下游),但那样图里要
显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的
broadcast opfwd/bwd 各一发 kernel最简且能单独 grad-check。
### ③ `num_kv_heads` 进 Config它改模型尺寸/导出),默认 = n_heads → 退化 MHA
不同于 T14 的 `use_flash`(运行时旗标,不进 Config`num_kv_heads` **改 K/V 投影的形状、改参数量、
改导出的 `num_key_value_heads`**——它是**架构**的一部分,必须进 `Config` 并落进 checkpoint/导出。
- `Config``num_kv_heads: usize``from_arch` 加该形参;`Config::tiny()` 默认 `num_kv_heads =
n_heads`MHA。约束`num_heads % num_kv_heads == 0`(断言)。
- `num_params()`K/V 投影从 `2·dim·dim` 改成 `2·dim·(num_kv_heads·head_dim)`QK-norm 的
`k_norm` 仍是 `[head_dim]`per-head作用在单个 head 向量上,与头数无关)→ 不变。
- **`num_kv_heads == n_heads` 时 `group=1`**`ops::repeat_kv` 是 identityfwd 单行拷贝、bwd 单行
拷贝wk/wv 形状回到 `[dim,dim]` → 整条图与 T14 的 MHA 路径**逐位一致**(回归保护闸门)。
> wk/wv 形状从 `[dim,dim]` 变成 `[dim, kv_dim]``Block` 里 wk/wv 的 `mk(&[dim, kv_dim])`
> `params()`/`block_params()` 顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是
> wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(`transpose` 读 `v.value().shape()`)。
### ④ bf16 / recompute / dropout / DDP 全部自动兼容
- **bf16**`Tensor::repeat_kv` 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel →
downcastkernel 只一份 f32。`ops::repeat_kv` 的 fwd/bwd 都在 SDPA 之前/之后dtype 与 K/V 流一致。
- **recomputeT13**repeat_kv 在 `block_forward` 内、`attention()` 里,重算段重跑 `attention()`
自然重跑 repeat_kv无额外状态确定性→ 梯度仍逐位一致。
- **dropoutT18**dropout 接在 attn/mlp 子块**输出**,与 attention 内部的 repeat_kv 正交,不交互。
- **DDP**repeat_kv 不引入新参数wk/wv 变小kv_dim只是参数张量小一圈`params()` 泛化迭代
+ all-reduce 照旧;跨 rank 一致性不受影响。
### ⑤ 导出 xserv写真 `num_key_value_heads`,分组约定对齐 repeat_kv
`export_safetensors.rs` 的 `config.json` 把 `num_key_value_heads` 从「= num_attention_heads」改成
**真 `cfg.num_kv_heads`**`--kv-heads` flag 传入(须与训练 ckpt 一致。q/k/v_proj 各自按其 shape
转置导出k/v_proj 现在是 `[kv_dim, dim]`xserv loader 期望的 GQA 形状。xserv 的 `repeat_kv`
用 `dst = kvh·group + r` 分组,与 ① 的 xtrain 约定**逐头对齐** → 同一份权重在两侧前向数学一致,
闭环(贪心逐 token 一致)成立。
## 验证方法
全部 `#![cfg(not(no_cuda))]` 门控,本地 `cargo check`/`fmt`,构建+实跑在 dash58× RTX 5090
### 1. 正确性硬闸门全绿dash5 实跑 capture
- **repeat_kv finite-diff grad-check**`autograd.rs::repeat_kv_grad`**核心闸门**——`group>1`
(如 bh: 2 kv 头 → 6 q 头)下 grad-check `din`,验证「多组 q 头梯度求和到一个 kv 头」的反向。
外加 `group=1` identity 自检。
- **GQA flash==composed**`gqa.rs`):真 GQA 配置(`num_kv_heads < n_heads`,如 8 头/2 kv 头)下,
flash on/off 两个同 init 模型,断 forward logits / loss / **每参数梯度**一致fp32 紧容差 + bf16
舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。
- **group=1 与 MHA 逐位一致**`gqa.rs``num_kv_heads = n_heads` 的模型对 T14 的 MHA 模型,
forward + 每参数梯度 `|Δ|=0`(回归保护)。
- **PyTorch GQA 对拍 B>1**`parity_dump.rs` + `parity.py`):等价 PyTorch 模型加 GQAk/v 投影到
kv_dim + `repeat_interleave(group)` 分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部
参数梯度composed 与 flash 两条都跑,共用同一 oracle
- **小 GQA 配置短训收敛**:一个真 GQA 小模型短训loss 单调降、无 NaN、采样连贯。
- **全回归套开/关**autograd / structural / batched==looped / bf16 / recompute逐位/ overfit 27/27 /
AdamWGPU bit-exact + host 对 torch/ DDP loss-match + 跨 rank`--test-threads=1`/ flash /
grad_accum / dropout / **xserv 闭环 md5**。MHA 默认kv=heads图不变 → 不回归。
### 2. 闭环payoff—— 真 GQA 端到端
导出一个 `num_key_value_heads < num_attention_heads` 的 GQA checkpoint → xserv 加载 → 贪心生成
**对 xtrain 自身逐 token 一致**BF16 推理 vs f32 训练,与 v1v8 同款判据)。这是 GQA 真正落地的证明:
训练侧的分组、导出的分组、推理侧 xserv 的 repeat_kv 分组三方对齐。
## 实测结果dash5 1× / 2× RTX 5090
**硬闸门全绿:**
| 闸门 | 结果 |
|---|---|
| ① repeat_kv grad-check**多组 q 头梯度求和到一个 kv 头**group=3 | **过** — din max_rel **2.05e-4**group=1 identity 双向**逐位**fwd/bwd |Δ|=0 |
| GQA flash==composedmodel 级 8h/2kvlogits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 3.0e-4、grad **4.1e-5**bf16: loss 9.0e-5、logits mean 2.9e-3/p99 1.0e-2、grad scaled-mean 8.9e-3 |
| group=1 对 MHA**逐位一致**(回归保护) | **过** — logits + loss + 全部梯度 |Δ|=0 |
| ② PyTorch GQA 对拍 B>1composed & flashrepeat_interleave 分组对齐) | composed: loss **1.74e-8**/logits 2.04e-5/25 grad 进 rtolflash: loss 1.74e-8/logits 2.28e-5/25 grad 进 rtol |
| ③ 小 GQA 配置短训收敛8h/2kv/hd32/4L/ffn1024600 步) | train **10.90→3.15** 无 NaN、gnorm 稳 ~1.2、采样连贯英文(~200K tok/s |
| ④ **xserv 闭环真 GQA**(导出 `num_key_value_heads=2 < num_attention_heads=8`xserv 加载 `heads=8/2 kv`,贪心) | "One day"/"The little" 两 prompt **逐 token 一致**"Once upon a time" 在 `...Lily's mommy ` 处 BF16 漂移晚分叉said vs came——与 v1/v2/v3/T14 同款判据 |
| ⑤ 回归套autograd 23含 repeat_kv 2/ structural 5 / batched / bf16 / flash 2 / **gqa 4** / overfit 27/27 / recompute 2 / dropout 6 / grad_accum 3 / checkpoint-roundtrip / AdamW(host 对 torch 4.8e-6) / DDP 3(`--test-threads=1`, loss 5.67e-7+跨 rank 一致) / GEMM / tensor | **全绿** |
| ⑤ MHA 默认 export md5v3 ckpt 用 T15 代码重导 safetensors | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry/T14 同)→ 默认kv=headsexport 零漂移 |
> **诚实记录**:闭环 2/3 prompt 完全 token-identical、1/3 在 BF16 漂移点晚分叉——这恰证明 GQA 分组**正确**:若 kv→q 头映射错attention 会从第一个生成 token 起就崩(不会是深处近-tie 的晚分叉。GQA 把 K/V 在显存里物化成满头 `[B·nh,S,hd]`broadcast-op 方案的代价)——本规模可接受kernel-内 GQA省这份显存 follow-up未为凑绿放宽任何容差