gqa: fix kv-proj shape test param indices (embed,attn_norm precede wq)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-18 01:38:42 +08:00
parent 830d06ad01
commit 39df0b40c1

View File

@@ -251,10 +251,11 @@ fn gqa_kv_proj_shape() {
let cfg = gqa_cfg();
let m = build(cfg, device, DType::F32, false);
let p = m.params();
// params order: embed, then per block [attn_norm, wq, wk, wv, q_norm, k_norm, wo, ...]
let wq = p[1].value().shape().to_vec();
let wk = p[2].value().shape().to_vec();
let wv = p[3].value().shape().to_vec();
// params order: embed[0], then block 0 = [attn_norm[1], wq[2], wk[3], wv[4],
// q_norm[5], k_norm[6], wo[7], ...]
let wq = p[2].value().shape().to_vec();
let wk = p[3].value().shape().to_vec();
let wv = p[4].value().shape().to_vec();
assert_eq!(wq, vec![cfg.dim, cfg.dim], "wq must be [dim,dim]");
assert_eq!(wk, vec![cfg.dim, cfg.kv_dim()], "wk must be [dim,kv_dim]");
assert_eq!(wv, vec![cfg.dim, cfg.kv_dim()], "wv must be [dim,kv_dim]");