From 39df0b40c1d354c3775f4a2557843dd80dd1821d Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Thu, 18 Jun 2026 01:38:42 +0800
Subject: [PATCH] gqa: fix kv-proj shape test param indices (embed,attn_norm precede wq)

Co-Authored-By: Claude Opus 4.8
---
 crates/xtrain-model/tests/gqa.rs | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/crates/xtrain-model/tests/gqa.rs b/crates/xtrain-model/tests/gqa.rs
index ebdc89f..d118541 100644
--- a/crates/xtrain-model/tests/gqa.rs
+++ b/crates/xtrain-model/tests/gqa.rs
@@ -251,10 +251,11 @@ fn gqa_kv_proj_shape() {
     let cfg = gqa_cfg();
     let m = build(cfg, device, DType::F32, false);
     let p = m.params();
-    // params order: embed, then per block [attn_norm, wq, wk, wv, q_norm, k_norm, wo, ...]
-    let wq = p[1].value().shape().to_vec();
-    let wk = p[2].value().shape().to_vec();
-    let wv = p[3].value().shape().to_vec();
+    // params order: embed[0], then block 0 = [attn_norm[1], wq[2], wk[3], wv[4],
+    // q_norm[5], k_norm[6], wo[7], ...]
+    let wq = p[2].value().shape().to_vec();
+    let wk = p[3].value().shape().to_vec();
+    let wv = p[4].value().shape().to_vec();
     assert_eq!(wq, vec![cfg.dim, cfg.dim], "wq must be [dim,dim]");
     assert_eq!(wk, vec![cfg.dim, cfg.kv_dim()], "wk must be [dim,kv_dim]");
     assert_eq!(wv, vec![cfg.dim, cfg.kv_dim()], "wv must be [dim,kv_dim]");