gqa: fix kv-proj shape test param indices (embed,attn_norm precede wq)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -251,10 +251,11 @@ fn gqa_kv_proj_shape() {
|
||||
let cfg = gqa_cfg();
|
||||
let m = build(cfg, device, DType::F32, false);
|
||||
let p = m.params();
|
||||
// params order: embed, then per block [attn_norm, wq, wk, wv, q_norm, k_norm, wo, ...]
|
||||
let wq = p[1].value().shape().to_vec();
|
||||
let wk = p[2].value().shape().to_vec();
|
||||
let wv = p[3].value().shape().to_vec();
|
||||
// params order: embed[0], then block 0 = [attn_norm[1], wq[2], wk[3], wv[4],
|
||||
// q_norm[5], k_norm[6], wo[7], ...]
|
||||
let wq = p[2].value().shape().to_vec();
|
||||
let wk = p[3].value().shape().to_vec();
|
||||
let wv = p[4].value().shape().to_vec();
|
||||
assert_eq!(wq, vec![cfg.dim, cfg.dim], "wq must be [dim,dim]");
|
||||
assert_eq!(wk, vec![cfg.dim, cfg.kv_dim()], "wk must be [dim,kv_dim]");
|
||||
assert_eq!(wv, vec![cfg.dim, cfg.kv_dim()], "wv must be [dim,kv_dim]");
|
||||
|
||||
Reference in New Issue
Block a user