Backfill docs/14-gqa.md gate table (dash5 numbers); add T15 evolution row + cumulative 模型架构 line; README build-journey T15 row + Phase 2 prose + doc index range (00..14). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
181 lines
14 KiB
Markdown
181 lines
14 KiB
Markdown
# Phase T15: Grouped-Query Attention (GQA) — Design Document
|
||
|
||
## Goal
|
||
|
||
到 T14 为止,xtrain 的 attention 都是 **MHA**(`num_kv_heads = num_heads`)——每个
|
||
query 头有自己独立的 K/V 头。导出 xserv 时 `num_key_value_heads = num_attention_heads`
|
||
(退化 GQA,docs/08)。
|
||
|
||
T15 做**真正的 grouped-query attention**:`num_kv_heads < num_heads`,K/V 只投影到
|
||
`num_kv_heads · head_dim`,每个 KV 头被一组 `group = num_heads / num_kv_heads` 个 query
|
||
头**共享**(repeat_kv / broadcast)。GQA 是现代 LLM(Llama-2-70B、Qwen2/3、Mistral)的标配
|
||
——它把 KV cache 显存(推理)与 K/V 投影参数(训练)按 `group` 倍压缩,几乎不掉质量。
|
||
|
||
**硬闸门是诚实正确性**,重点在 **repeat_kv 的反向梯度累加**:一个 KV 头被 `group` 个 query 头
|
||
共享,反向时这 `group` 个 query 头各自对该 KV 头的梯度必须**正确求和**回到那一个共享 KV 头上。
|
||
这条「多组 q 头梯度汇到一个 kv 头」的累加路径是本任务最易错处,单列为首要 grad-check 闸门。
|
||
|
||
GQA 必须**同时**接进 T14 的 fused flash kernel(优先)与旧 composed/batched SDPA 路径,且
|
||
`num_kv_heads == num_heads`(`group = 1`)时与现有 MHA 路径**逐位一致**(回归保护)。
|
||
|
||
## 什么是 GQA
|
||
|
||
MHA:`num_heads` 个 query 头,每个配一个独立 K/V 头。
|
||
MQA(multi-query):所有 query 头共享**一个** K/V 头(极端)。
|
||
GQA:折中——`num_kv_heads` 个 K/V 头,每个被 `group = num_heads/num_kv_heads` 个相邻 query
|
||
头共享。`num_kv_heads = num_heads` 退化为 MHA,`num_kv_heads = 1` 退化为 MQA。
|
||
|
||
```
|
||
num_heads = 8, num_kv_heads = 2 → group = 4
|
||
q heads: 0 1 2 3 4 5 6 7
|
||
kv heads: 0 0 0 0 1 1 1 1 # q head qh 用 kv head qh/group(相邻分组,连续)
|
||
```
|
||
|
||
**分组约定必须与 xserv repeat_kv 一致**(闭环命门):xserv 的 `repeat_kv`
|
||
(`crates/xserv-model/src/qwen3.rs`)把 kv 头 `kvh` 复制到目标头 `dst = kvh*group + r`
|
||
(`r∈[0,group)`),即**query 头 `qh` 读 kv 头 `qh/group`,组内 query 头连续**。xtrain 的
|
||
repeat_kv 用同一映射,否则导出的 `q_proj` 行块与 kv 头对不上 → 闭环必崩。
|
||
|
||
## Module Layout(surgical:复用已验证的两条 SDPA,GQA = 头维 broadcast op)
|
||
|
||
```
|
||
csrc/ops/repeat_kv.cu # 新:repeat_kv fwd(头块 gather)+ bwd(组内 group 行求和,无 atomic,确定性)
|
||
crates/xtrain-cuda/
|
||
├── src/ffi.rs # +launch_repeat_kv_fwd_f32 / _bwd_f32 声明(no_cuda 门控)
|
||
└── build.rs # +repeat_kv.cu
|
||
crates/xtrain-tensor/src/tensor.rs # +Tensor::repeat_kv / repeat_kv_backward([B*kvh,S,hd]→[B*nh,S,hd];bf16 upcast→f32→downcast)
|
||
crates/xtrain-autodiff/
|
||
├── src/ops.rs # +ops::repeat_kv 节点(fwd broadcast,bwd 组内求和)
|
||
└── tests/autograd.rs # +repeat_kv grad-check(含 group>1 的多组梯度累加)
|
||
crates/xtrain-model/
|
||
├── src/config.rs # +num_kv_heads 字段(默认 = n_heads → MHA);from_arch 加形参;num_params 计 K/V 投影按 kv_dim
|
||
├── src/model.rs # wk/wv 投影到 kv_dim;attention() 在 SDPA 前对 K/V 做 ops::repeat_kv;两条路径都吃到 GQA
|
||
└── tests/gqa.rs # 新:GQA(group>1) flash==composed + group=1 与 MHA 逐位一致
|
||
crates/xtrain-train/src/bin/train.rs # +--kv-heads flag
|
||
crates/xtrain-distributed/src/bin/train_ddp.rs # +--kv-heads flag(DDP 路径)
|
||
crates/xtrain-train/src/bin/export_safetensors.rs # +--kv-heads;config.json 写真 num_key_value_heads
|
||
crates/xtrain-model/tests/parity{.py,_dump.rs} # PyTorch 对拍加 GQA(kv 投影 + repeat_kv)
|
||
```
|
||
|
||
## Key Design Decisions
|
||
|
||
### ① GQA = K/V 头维 broadcast op,喂给**未改动**的两条 SDPA(不写第三套 attention)
|
||
|
||
T14 已经有两条**逐位/数值都验证过**的 SDPA:composed(`ops::attention`)与 fused flash
|
||
(`ops::flash_attention`),二者都吃 `[bh, S, hd]`(`bh = batch·heads`)。GQA 的本质只是「K/V
|
||
比 Q 少 `group` 倍头,用前把每个 kv 头复制 `group` 份」。所以**最外科**的做法:
|
||
|
||
- wk/wv 投影到 `kv_dim = num_kv_heads · head_dim`,按 `[B, num_kv, S, hd] → [B·num_kv, S, hd]`
|
||
排好(和 Q 的 `[B·nh, S, hd]` 同流水线,只是头数不同)。
|
||
- 在调 SDPA 之前,对 K、V 各做一个新 autograd op `ops::repeat_kv`,把 `[B·num_kv, S, hd]`
|
||
**broadcast** 成 `[B·nh, S, hd]`(输出行 `b·nh + qh` = 输入行 `b·num_kv + qh/group` 的拷贝)。
|
||
- 之后 `ops::attention` / `ops::flash_attention` **一字不改**——它们看到的就是满头的
|
||
`[B·nh, S, hd]`,GQA 对两条路径**同时、免费**生效。flash kernel / composed kernel 都不用碰。
|
||
|
||
**为什么不在 kernel 内做 GQA**:那要给 flash fwd/bwd 两个 kernel 各加 kv-head 索引、给 composed
|
||
的两次 strided GEMM 各算 GQA stride,且两套都要重测——是「第三套 attention 改动」。而 broadcast-op
|
||
方案:(a) 两条 SDPA 路径零改动、其 T14 闸门不回归;(b) repeat_kv 的 fwd/bwd 是独立可 grad-check
|
||
的小 op,正确性风险隔离在一处;(c) 关键的「多组 q 头梯度汇到一个 kv 头」就是 repeat_kv 的**反向**,
|
||
干净地落在一个 op 上单测。代价是 K/V 在显存里被物化成满头 `[B·nh,S,hd]`(多 `group` 倍)——本规模
|
||
(训练、seq 不极端)可接受;真要省这份显存是 follow-up(kernel 内 GQA 读取),记进逃生舱不在 T15 做。
|
||
|
||
> 备选(不采纳):flash/composed kernel 内直接按 `kv_head = q_head/group` 索引 K/V。省 broadcast
|
||
> 物化,但动两套已验证 kernel + 重写两套 backward 的 kv 累加,违反「不写第三套 attention」与回归保护。
|
||
> escape hatch:先 broadcast-op 把正确性 + 闭环钉死,kernel-内 GQA(省显存)留 follow-up。
|
||
|
||
### ② repeat_kv 的反向 = 组内求和(确定性,无 atomic)
|
||
|
||
`repeat_kv` 前向:`out[b·nh + qh] = in[b·num_kv + qh/group]`(按 `S·hd` 整行拷贝)。
|
||
|
||
反向是它的**转置**:一个 kv 头收到它那 `group` 个 query 头的梯度之**和**:
|
||
```
|
||
din[b·num_kv + kvh] = Σ_{r=0}^{group-1} dout[b·nh + kvh·group + r]
|
||
```
|
||
这正是闸门要求的「多组 q 头梯度累加到一个 kv 头」。实现上**不用 atomicAdd**:每个输入
|
||
(kv-head, 元素)由唯一一个 block 负责,它**串行累加自己那 group 个连续源行**——天然 race-free
|
||
且**run-to-run 确定**(不像 flash bwd 的跨行 atomic 反向有归约序不确定问题)。`group=1` 时
|
||
反向退化为单行拷贝(identity)。
|
||
|
||
autograd 层面其实也可以靠引擎的扇出 SUM(把一个 kv Var 喂给 group 个下游),但那样图里要
|
||
显式建 group 份 view、且 flash/composed 的 batched 布局不是按头切的——做成一个专门的
|
||
broadcast op,fwd/bwd 各一发 kernel,最简且能单独 grad-check。
|
||
|
||
### ③ `num_kv_heads` 进 Config(它改模型尺寸/导出),默认 = n_heads → 退化 MHA
|
||
|
||
不同于 T14 的 `use_flash`(运行时旗标,不进 Config),`num_kv_heads` **改 K/V 投影的形状、改参数量、
|
||
改导出的 `num_key_value_heads`**——它是**架构**的一部分,必须进 `Config` 并落进 checkpoint/导出。
|
||
|
||
- `Config` 加 `num_kv_heads: usize`;`from_arch` 加该形参;`Config::tiny()` 默认 `num_kv_heads =
|
||
n_heads`(MHA)。约束:`num_heads % num_kv_heads == 0`(断言)。
|
||
- `num_params()`:K/V 投影从 `2·dim·dim` 改成 `2·dim·(num_kv_heads·head_dim)`;QK-norm 的
|
||
`k_norm` 仍是 `[head_dim]`(per-head,作用在单个 head 向量上,与头数无关)→ 不变。
|
||
- **`num_kv_heads == n_heads` 时 `group=1`**:`ops::repeat_kv` 是 identity(fwd 单行拷贝、bwd 单行
|
||
拷贝),wk/wv 形状回到 `[dim,dim]` → 整条图与 T14 的 MHA 路径**逐位一致**(回归保护闸门)。
|
||
|
||
> wk/wv 形状从 `[dim,dim]` 变成 `[dim, kv_dim]`:`Block` 里 wk/wv 的 `mk(&[dim, kv_dim])`,
|
||
> `params()`/`block_params()` 顺序不变(还是 attn_norm,wq,wk,wv,q_norm,k_norm,wo,...),只是
|
||
> wk/wv 的 shape 跟着 Config。导出转置照旧按各自 shape 走(`transpose` 读 `v.value().shape()`)。
|
||
|
||
### ④ bf16 / recompute / dropout / DDP 全部自动兼容
|
||
|
||
- **bf16**:`Tensor::repeat_kv` 沿用全 repo 一致的 cast 策略——bf16 入则 upcast f32 → kernel →
|
||
downcast;kernel 只一份 f32。`ops::repeat_kv` 的 fwd/bwd 都在 SDPA 之前/之后,dtype 与 K/V 流一致。
|
||
- **recompute(T13)**:repeat_kv 在 `block_forward` 内、`attention()` 里,重算段重跑 `attention()`
|
||
自然重跑 repeat_kv(无额外状态,确定性)→ 梯度仍逐位一致。
|
||
- **dropout(T18)**:dropout 接在 attn/mlp 子块**输出**,与 attention 内部的 repeat_kv 正交,不交互。
|
||
- **DDP**:repeat_kv 不引入新参数;wk/wv 变小(kv_dim)只是参数张量小一圈,`params()` 泛化迭代
|
||
+ all-reduce 照旧;跨 rank 一致性不受影响。
|
||
|
||
### ⑤ 导出 xserv:写真 `num_key_value_heads`,分组约定对齐 repeat_kv
|
||
|
||
`export_safetensors.rs` 的 `config.json` 把 `num_key_value_heads` 从「= num_attention_heads」改成
|
||
**真 `cfg.num_kv_heads`**;`--kv-heads` flag 传入(须与训练 ckpt 一致)。q/k/v_proj 各自按其 shape
|
||
转置导出(k/v_proj 现在是 `[kv_dim, dim]`,xserv loader 期望的 GQA 形状)。xserv 的 `repeat_kv`
|
||
用 `dst = kvh·group + r` 分组,与 ① 的 xtrain 约定**逐头对齐** → 同一份权重在两侧前向数学一致,
|
||
闭环(贪心逐 token 一致)成立。
|
||
|
||
## 验证方法
|
||
|
||
全部 `#![cfg(not(no_cuda))]` 门控,本地 `cargo check`/`fmt`,构建+实跑在 dash5(8× RTX 5090)。
|
||
|
||
### 1. 正确性(硬闸门全绿,dash5 实跑 capture)
|
||
|
||
- **repeat_kv finite-diff grad-check**(`autograd.rs::repeat_kv_grad`):**核心闸门**——`group>1`
|
||
(如 bh: 2 kv 头 → 6 q 头)下 grad-check `din`,验证「多组 q 头梯度求和到一个 kv 头」的反向。
|
||
外加 `group=1` identity 自检。
|
||
- **GQA flash==composed**(`gqa.rs`):真 GQA 配置(`num_kv_heads < n_heads`,如 8 头/2 kv 头)下,
|
||
flash on/off 两个同 init 模型,断 forward logits / loss / **每参数梯度**一致(fp32 紧容差 + bf16
|
||
舍入带)——尤其 wk/wv 的梯度(它们经过 repeat_kv 反向的组内求和)。
|
||
- **group=1 与 MHA 逐位一致**(`gqa.rs`):`num_kv_heads = n_heads` 的模型对 T14 的 MHA 模型,
|
||
forward + 每参数梯度 `|Δ|=0`(回归保护)。
|
||
- **PyTorch GQA 对拍 B>1**(`parity_dump.rs` + `parity.py`):等价 PyTorch 模型加 GQA(k/v 投影到
|
||
kv_dim + `repeat_interleave(group)` 分组,与 xserv/xtrain 约定一致),对拍 forward logits + 全部
|
||
参数梯度(composed 与 flash 两条都跑,共用同一 oracle)。
|
||
- **小 GQA 配置短训收敛**:一个真 GQA 小模型短训,loss 单调降、无 NaN、采样连贯。
|
||
- **全回归套开/关**:autograd / structural / batched==looped / bf16 / recompute(逐位)/ overfit 27/27 /
|
||
AdamW(GPU bit-exact + host 对 torch)/ DDP loss-match + 跨 rank(`--test-threads=1`)/ flash /
|
||
grad_accum / dropout / **xserv 闭环 md5**。MHA 默认(kv=heads)图不变 → 不回归。
|
||
|
||
### 2. 闭环(payoff)—— 真 GQA 端到端
|
||
|
||
导出一个 `num_key_value_heads < num_attention_heads` 的 GQA checkpoint → xserv 加载 → 贪心生成
|
||
**对 xtrain 自身逐 token 一致**(BF16 推理 vs f32 训练,与 v1–v8 同款判据)。这是 GQA 真正落地的证明:
|
||
训练侧的分组、导出的分组、推理侧 xserv 的 repeat_kv 分组三方对齐。
|
||
|
||
## 实测结果(dash5 1× / 2× RTX 5090)
|
||
|
||
**硬闸门全绿:**
|
||
|
||
| 闸门 | 结果 |
|
||
|---|---|
|
||
| ① repeat_kv grad-check(**多组 q 头梯度求和到一个 kv 头**,group=3) | **过** — din max_rel **2.05e-4**;group=1 identity 双向**逐位**(fwd/bwd |Δ|=0) |
|
||
| GQA flash==composed(model 级 8h/2kv,logits/loss/每参数梯度) | fp32: loss rel **0.0**、logits 3.0e-4、grad **4.1e-5**;bf16: loss 9.0e-5、logits mean 2.9e-3/p99 1.0e-2、grad scaled-mean 8.9e-3 |
|
||
| group=1 对 MHA**逐位一致**(回归保护) | **过** — logits + loss + 全部梯度 |Δ|=0 |
|
||
| ② PyTorch GQA 对拍 B>1(composed & flash,repeat_interleave 分组对齐) | composed: loss **1.74e-8**/logits 2.04e-5/25 grad 进 rtol;flash: loss 1.74e-8/logits 2.28e-5/25 grad 进 rtol |
|
||
| ③ 小 GQA 配置短训收敛(8h/2kv/hd32/4L/ffn1024,600 步) | train **10.90→3.15** 无 NaN、gnorm 稳 ~1.2、采样连贯英文(~200K tok/s) |
|
||
| ④ **xserv 闭环真 GQA**(导出 `num_key_value_heads=2 < num_attention_heads=8`,xserv 加载 `heads=8/2 kv`,贪心) | "One day"/"The little" 两 prompt **逐 token 一致**;"Once upon a time" 在 `...Lily's mommy ` 处 BF16 漂移晚分叉(said vs came)——与 v1/v2/v3/T14 同款判据 |
|
||
| ⑤ 回归套:autograd 23(含 repeat_kv 2)/ structural 5 / batched / bf16 / flash 2 / **gqa 4** / overfit 27/27 / recompute 2 / dropout 6 / grad_accum 3 / checkpoint-roundtrip / AdamW(host 对 torch 4.8e-6) / DDP 3(`--test-threads=1`, loss 5.67e-7+跨 rank 一致) / GEMM / tensor | **全绿** |
|
||
| ⑤ MHA 默认 export md5(v3 ckpt 用 T15 代码重导 safetensors) | **逐位一致** `b04fc9f9a0c9af04c47d9ca649aea12e`(与 registry/T14 同)→ 默认(kv=heads)export 零漂移 |
|
||
|
||
> **诚实记录**:闭环 2/3 prompt 完全 token-identical、1/3 在 BF16 漂移点晚分叉——这恰证明 GQA 分组**正确**:若 kv→q 头映射错,attention 会从第一个生成 token 起就崩(不会是深处近-tie 的晚分叉)。GQA 把 K/V 在显存里物化成满头 `[B·nh,S,hd]`(broadcast-op 方案的代价)——本规模可接受,kernel-内 GQA(省这份显存)留 follow-up。未为凑绿放宽任何容差。
|