diff --git a/crates/xtrain-model/src/config.rs b/crates/xtrain-model/src/config.rs index 8d52109..d5aac59 100644 --- a/crates/xtrain-model/src/config.rs +++ b/crates/xtrain-model/src/config.rs @@ -42,6 +42,7 @@ impl Config { /// Total learnable parameter count (for logging / sanity). pub fn num_params(&self) -> usize { let per_layer = 2 * self.dim // 2 rmsnorm gammas + + 2 * self.head_dim // q/k per-head norm gammas + 3 * self.dim * self.dim // q/k/v proj + self.dim * self.dim // out proj + 2 * self.dim * self.ffn_hidden // gate/up proj diff --git a/crates/xtrain-model/src/lib.rs b/crates/xtrain-model/src/lib.rs index 61d5e5d..43ba6a1 100644 --- a/crates/xtrain-model/src/lib.rs +++ b/crates/xtrain-model/src/lib.rs @@ -2,9 +2,10 @@ //! //! A from-scratch decoder built entirely from the [`xtrain_autodiff`] op set: //! token embedding → `n_layers` × {pre-RMSNorm → multi-head causal attention -//! (RoPE) → residual; pre-RMSNorm → SwiGLU MLP → residual} → final RMSNorm → -//! LM-head matmul. The forward builds an autograd graph; calling `.backward()` -//! on the cross-entropy loss fills every parameter's `.grad()`. +//! (per-head QK-norm + RoPE) → residual; pre-RMSNorm → SwiGLU MLP → residual} → +//! final RMSNorm → LM-head matmul. The forward builds an autograd graph; calling +//! `.backward()` on the cross-entropy loss fills every parameter's `.grad()`. +//! Per-head QK-norm (Qwen3-style) makes the architecture xserv-compatible (T9). //! //! Conventions (matching the engine, not HuggingFace): //! - Linear weights are `[in, out]` and applied as `x @ W` (no transpose), since diff --git a/crates/xtrain-model/src/model.rs b/crates/xtrain-model/src/model.rs index 807d2c3..25cd7b3 100644 --- a/crates/xtrain-model/src/model.rs +++ b/crates/xtrain-model/src/model.rs @@ -13,6 +13,8 @@ struct Block { wq: Var, // [dim, dim] wk: Var, // [dim, dim] wv: Var, // [dim, dim] + q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style) + k_norm: Var, // [head_dim] wo: Var, // [dim, dim] ffn_norm: Var, // [dim] w_gate: Var, // [dim, ffn_hidden] @@ -52,6 +54,8 @@ impl TinyTransformer { wq: mk(&[cfg.dim, cfg.dim]), wk: mk(&[cfg.dim, cfg.dim]), wv: mk(&[cfg.dim, cfg.dim]), + q_norm: mk(&[cfg.head_dim]), + k_norm: mk(&[cfg.head_dim]), wo: mk(&[cfg.dim, cfg.dim]), ffn_norm: mk(&[cfg.dim]), w_gate: mk(&[cfg.dim, cfg.ffn_hidden]), @@ -87,6 +91,8 @@ impl TinyTransformer { b.wq.clone(), b.wk.clone(), b.wv.clone(), + b.q_norm.clone(), + b.k_norm.clone(), b.wo.clone(), b.ffn_norm.clone(), b.w_gate.clone(), @@ -136,22 +142,29 @@ impl TinyTransformer { // Project, then lay out as per-head [seq, head_dim] tensors. // [seq,dim] @ [dim,dim] = [seq,dim] // reshape [seq, nh, hd] + // qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE) // rope (kernel expects exactly [tokens, heads, head_dim]) // transpose [nh, seq, hd] → split into nh × [seq, hd] - let to_heads = |proj: Var, rope: bool| -> Vec { + let to_heads = |proj: Var, norm: Option<&Var>| -> Vec { let r = ops::reshape(&proj, &[seq, nh, hd]); - let r = if rope { - ops::rope(&r, self.cfg.rope_theta) - } else { - r + let r = match norm { + // Per-head RMSNorm: flatten the (seq,nh) head rows, norm over hd, + // restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs). + Some(gamma) => { + let flat = ops::reshape(&r, &[seq * nh, hd]); + let normed = ops::rms_norm(&flat, gamma, self.cfg.eps); + let r = ops::reshape(&normed, &[seq, nh, hd]); + ops::rope(&r, self.cfg.rope_theta) + } + None => r, }; let t = ops::transpose_3d01(&r); // [nh, seq, hd] ops::split_heads(&t) }; - let q = to_heads(ops::matmul(x, &b.wq), true); - let k = to_heads(ops::matmul(x, &b.wk), true); - let v = to_heads(ops::matmul(x, &b.wv), false); + let q = to_heads(ops::matmul(x, &b.wq), Some(&b.q_norm)); + let k = to_heads(ops::matmul(x, &b.wk), Some(&b.k_norm)); + let v = to_heads(ops::matmul(x, &b.wv), None); // Per-head scaled-dot-product attention with causal mask. let heads_out: Vec = (0..nh) diff --git a/crates/xtrain-model/tests/parity.py b/crates/xtrain-model/tests/parity.py index c4b3354..249388d 100644 --- a/crates/xtrain-model/tests/parity.py +++ b/crates/xtrain-model/tests/parity.py @@ -98,7 +98,7 @@ lm_head = load("lm_head") layers = [] for l in range(NL): layers.append({p: load(f"l{l}_{p}") for p in - ["attn_norm", "wq", "wk", "wv", "wo", + ["attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo", "ffn_norm", "w_gate", "w_up", "w_down"]}) idx = torch.tensor(ids, dtype=torch.long) @@ -111,6 +111,9 @@ for L in layers: q = (x @ L["wq"]).reshape(SEQ, NH, HD) k = (x @ L["wk"]).reshape(SEQ, NH, HD) v = (x @ L["wv"]).reshape(SEQ, NH, HD) + # Per-head QK-norm (Qwen3-style), before RoPE. + q = rms_norm(q, L["q_norm"]) + k = rms_norm(k, L["k_norm"]) q = rope(q).transpose(0, 1) # [nh, seq, hd] k = rope(k).transpose(0, 1) v = v.transpose(0, 1) diff --git a/crates/xtrain-model/tests/parity_dump.rs b/crates/xtrain-model/tests/parity_dump.rs index e2c921f..3181148 100644 --- a/crates/xtrain-model/tests/parity_dump.rs +++ b/crates/xtrain-model/tests/parity_dump.rs @@ -145,6 +145,8 @@ fn param_names(cfg: &Config) -> Vec { "wq", "wk", "wv", + "q_norm", + "k_norm", "wo", "ffn_norm", "w_gate", diff --git a/crates/xtrain-train/tests/adamw_parity.py b/crates/xtrain-train/tests/adamw_parity.py index 24ccf2c..3e49188 100644 --- a/crates/xtrain-train/tests/adamw_parity.py +++ b/crates/xtrain-train/tests/adamw_parity.py @@ -74,7 +74,7 @@ SEQ = len(ids) NAMES = ["embed"] for l in range(NL): - for p in ["attn_norm", "wq", "wk", "wv", "wo", + for p in ["attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo", "ffn_norm", "w_gate", "w_up", "w_down"]: NAMES.append(f"l{l}_{p}") NAMES += ["final_norm", "lm_head"] @@ -115,6 +115,9 @@ def forward(): q = (x @ P[f"l{l}_wq"]).reshape(SEQ, NH, HD) k = (x @ P[f"l{l}_wk"]).reshape(SEQ, NH, HD) v = (x @ P[f"l{l}_wv"]).reshape(SEQ, NH, HD) + # Per-head QK-norm (Qwen3-style), before RoPE. + q = rms_norm(q, P[f"l{l}_q_norm"]) + k = rms_norm(k, P[f"l{l}_k_norm"]) q = rope(q).transpose(0, 1) k = rope(k).transpose(0, 1) v = v.transpose(0, 1) diff --git a/crates/xtrain-train/tests/adamw_parity_dump.rs b/crates/xtrain-train/tests/adamw_parity_dump.rs index e8f89ce..d3806a8 100644 --- a/crates/xtrain-train/tests/adamw_parity_dump.rs +++ b/crates/xtrain-train/tests/adamw_parity_dump.rs @@ -156,6 +156,8 @@ fn param_names(cfg: &Config) -> Vec { "wq", "wk", "wv", + "q_norm", + "k_norm", "wo", "ffn_norm", "w_gate",