model: add per-head QK-norm (Qwen3-compat) for xserv export

xserv's Qwen3 forward unconditionally applies per-head RMSNorm to Q and K
(q_norm/k_norm, shape [head_dim]) before RoPE — even gamma=1 is a real RMS
divide, not identity. xtrain never had this, so an exact xserv<->xtrain loop
was structurally impossible. Add it (reusing the 2D rms_norm op on the
[seq*nh, hd] head rows, inserted between reshape and rope to mirror
qwen3.rs's order) so the trained model is genuinely Qwen3-compatible.

params() inserts q_norm,k_norm after wv; num_params() counts them; the
PyTorch parity refs (parity.py / adamw_parity.py) + their name lists add the
same step so the dumps stay self-consistent.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 17:33:19 +08:00
parent ad82e8bf92
commit 7a4f69e430
7 changed files with 38 additions and 13 deletions

View File

@@ -42,6 +42,7 @@ impl Config {
/// Total learnable parameter count (for logging / sanity).
pub fn num_params(&self) -> usize {
let per_layer = 2 * self.dim // 2 rmsnorm gammas
+ 2 * self.head_dim // q/k per-head norm gammas
+ 3 * self.dim * self.dim // q/k/v proj
+ self.dim * self.dim // out proj
+ 2 * self.dim * self.ffn_hidden // gate/up proj

View File

@@ -2,9 +2,10 @@
//!
//! A from-scratch decoder built entirely from the [`xtrain_autodiff`] op set:
//! token embedding → `n_layers` × {pre-RMSNorm → multi-head causal attention
//! (RoPE) → residual; pre-RMSNorm → SwiGLU MLP → residual} → final RMSNorm →
//! LM-head matmul. The forward builds an autograd graph; calling `.backward()`
//! on the cross-entropy loss fills every parameter's `.grad()`.
//! (per-head QK-norm + RoPE) → residual; pre-RMSNorm → SwiGLU MLP → residual} →
//! final RMSNorm → LM-head matmul. The forward builds an autograd graph; calling
//! `.backward()` on the cross-entropy loss fills every parameter's `.grad()`.
//! Per-head QK-norm (Qwen3-style) makes the architecture xserv-compatible (T9).
//!
//! Conventions (matching the engine, not HuggingFace):
//! - Linear weights are `[in, out]` and applied as `x @ W` (no transpose), since

View File

@@ -13,6 +13,8 @@ struct Block {
wq: Var, // [dim, dim]
wk: Var, // [dim, dim]
wv: Var, // [dim, dim]
q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style)
k_norm: Var, // [head_dim]
wo: Var, // [dim, dim]
ffn_norm: Var, // [dim]
w_gate: Var, // [dim, ffn_hidden]
@@ -52,6 +54,8 @@ impl TinyTransformer {
wq: mk(&[cfg.dim, cfg.dim]),
wk: mk(&[cfg.dim, cfg.dim]),
wv: mk(&[cfg.dim, cfg.dim]),
q_norm: mk(&[cfg.head_dim]),
k_norm: mk(&[cfg.head_dim]),
wo: mk(&[cfg.dim, cfg.dim]),
ffn_norm: mk(&[cfg.dim]),
w_gate: mk(&[cfg.dim, cfg.ffn_hidden]),
@@ -87,6 +91,8 @@ impl TinyTransformer {
b.wq.clone(),
b.wk.clone(),
b.wv.clone(),
b.q_norm.clone(),
b.k_norm.clone(),
b.wo.clone(),
b.ffn_norm.clone(),
b.w_gate.clone(),
@@ -136,22 +142,29 @@ impl TinyTransformer {
// Project, then lay out as per-head [seq, head_dim] tensors.
// [seq,dim] @ [dim,dim] = [seq,dim]
// reshape [seq, nh, hd]
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
// rope (kernel expects exactly [tokens, heads, head_dim])
// transpose [nh, seq, hd] → split into nh × [seq, hd]
let to_heads = |proj: Var, rope: bool| -> Vec<Var> {
let to_heads = |proj: Var, norm: Option<&Var>| -> Vec<Var> {
let r = ops::reshape(&proj, &[seq, nh, hd]);
let r = if rope {
ops::rope(&r, self.cfg.rope_theta)
} else {
r
let r = match norm {
// Per-head RMSNorm: flatten the (seq,nh) head rows, norm over hd,
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
Some(gamma) => {
let flat = ops::reshape(&r, &[seq * nh, hd]);
let normed = ops::rms_norm(&flat, gamma, self.cfg.eps);
let r = ops::reshape(&normed, &[seq, nh, hd]);
ops::rope(&r, self.cfg.rope_theta)
}
None => r,
};
let t = ops::transpose_3d01(&r); // [nh, seq, hd]
ops::split_heads(&t)
};
let q = to_heads(ops::matmul(x, &b.wq), true);
let k = to_heads(ops::matmul(x, &b.wk), true);
let v = to_heads(ops::matmul(x, &b.wv), false);
let q = to_heads(ops::matmul(x, &b.wq), Some(&b.q_norm));
let k = to_heads(ops::matmul(x, &b.wk), Some(&b.k_norm));
let v = to_heads(ops::matmul(x, &b.wv), None);
// Per-head scaled-dot-product attention with causal mask.
let heads_out: Vec<Var> = (0..nh)

View File

@@ -98,7 +98,7 @@ lm_head = load("lm_head")
layers = []
for l in range(NL):
layers.append({p: load(f"l{l}_{p}") for p in
["attn_norm", "wq", "wk", "wv", "wo",
["attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo",
"ffn_norm", "w_gate", "w_up", "w_down"]})
idx = torch.tensor(ids, dtype=torch.long)
@@ -111,6 +111,9 @@ for L in layers:
q = (x @ L["wq"]).reshape(SEQ, NH, HD)
k = (x @ L["wk"]).reshape(SEQ, NH, HD)
v = (x @ L["wv"]).reshape(SEQ, NH, HD)
# Per-head QK-norm (Qwen3-style), before RoPE.
q = rms_norm(q, L["q_norm"])
k = rms_norm(k, L["k_norm"])
q = rope(q).transpose(0, 1) # [nh, seq, hd]
k = rope(k).transpose(0, 1)
v = v.transpose(0, 1)

View File

@@ -145,6 +145,8 @@ fn param_names(cfg: &Config) -> Vec<String> {
"wq",
"wk",
"wv",
"q_norm",
"k_norm",
"wo",
"ffn_norm",
"w_gate",

View File

@@ -74,7 +74,7 @@ SEQ = len(ids)
NAMES = ["embed"]
for l in range(NL):
for p in ["attn_norm", "wq", "wk", "wv", "wo",
for p in ["attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo",
"ffn_norm", "w_gate", "w_up", "w_down"]:
NAMES.append(f"l{l}_{p}")
NAMES += ["final_norm", "lm_head"]
@@ -115,6 +115,9 @@ def forward():
q = (x @ P[f"l{l}_wq"]).reshape(SEQ, NH, HD)
k = (x @ P[f"l{l}_wk"]).reshape(SEQ, NH, HD)
v = (x @ P[f"l{l}_wv"]).reshape(SEQ, NH, HD)
# Per-head QK-norm (Qwen3-style), before RoPE.
q = rms_norm(q, P[f"l{l}_q_norm"])
k = rms_norm(k, P[f"l{l}_k_norm"])
q = rope(q).transpose(0, 1)
k = rope(k).transpose(0, 1)
v = v.transpose(0, 1)

View File

@@ -156,6 +156,8 @@ fn param_names(cfg: &Config) -> Vec<String> {
"wq",
"wk",
"wv",
"q_norm",
"k_norm",
"wo",
"ffn_norm",
"w_gate",