model: add per-head QK-norm (Qwen3-compat) for xserv export
xserv's Qwen3 forward unconditionally applies per-head RMSNorm to Q and K (q_norm/k_norm, shape [head_dim]) before RoPE — even gamma=1 is a real RMS divide, not identity. xtrain never had this, so an exact xserv<->xtrain loop was structurally impossible. Add it (reusing the 2D rms_norm op on the [seq*nh, hd] head rows, inserted between reshape and rope to mirror qwen3.rs's order) so the trained model is genuinely Qwen3-compatible. params() inserts q_norm,k_norm after wv; num_params() counts them; the PyTorch parity refs (parity.py / adamw_parity.py) + their name lists add the same step so the dumps stay self-consistent. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -42,6 +42,7 @@ impl Config {
|
||||
/// Total learnable parameter count (for logging / sanity).
|
||||
pub fn num_params(&self) -> usize {
|
||||
let per_layer = 2 * self.dim // 2 rmsnorm gammas
|
||||
+ 2 * self.head_dim // q/k per-head norm gammas
|
||||
+ 3 * self.dim * self.dim // q/k/v proj
|
||||
+ self.dim * self.dim // out proj
|
||||
+ 2 * self.dim * self.ffn_hidden // gate/up proj
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
//!
|
||||
//! A from-scratch decoder built entirely from the [`xtrain_autodiff`] op set:
|
||||
//! token embedding → `n_layers` × {pre-RMSNorm → multi-head causal attention
|
||||
//! (RoPE) → residual; pre-RMSNorm → SwiGLU MLP → residual} → final RMSNorm →
|
||||
//! LM-head matmul. The forward builds an autograd graph; calling `.backward()`
|
||||
//! on the cross-entropy loss fills every parameter's `.grad()`.
|
||||
//! (per-head QK-norm + RoPE) → residual; pre-RMSNorm → SwiGLU MLP → residual} →
|
||||
//! final RMSNorm → LM-head matmul. The forward builds an autograd graph; calling
|
||||
//! `.backward()` on the cross-entropy loss fills every parameter's `.grad()`.
|
||||
//! Per-head QK-norm (Qwen3-style) makes the architecture xserv-compatible (T9).
|
||||
//!
|
||||
//! Conventions (matching the engine, not HuggingFace):
|
||||
//! - Linear weights are `[in, out]` and applied as `x @ W` (no transpose), since
|
||||
|
||||
@@ -13,6 +13,8 @@ struct Block {
|
||||
wq: Var, // [dim, dim]
|
||||
wk: Var, // [dim, dim]
|
||||
wv: Var, // [dim, dim]
|
||||
q_norm: Var, // [head_dim] — per-head QK-norm (Qwen3-style)
|
||||
k_norm: Var, // [head_dim]
|
||||
wo: Var, // [dim, dim]
|
||||
ffn_norm: Var, // [dim]
|
||||
w_gate: Var, // [dim, ffn_hidden]
|
||||
@@ -52,6 +54,8 @@ impl TinyTransformer {
|
||||
wq: mk(&[cfg.dim, cfg.dim]),
|
||||
wk: mk(&[cfg.dim, cfg.dim]),
|
||||
wv: mk(&[cfg.dim, cfg.dim]),
|
||||
q_norm: mk(&[cfg.head_dim]),
|
||||
k_norm: mk(&[cfg.head_dim]),
|
||||
wo: mk(&[cfg.dim, cfg.dim]),
|
||||
ffn_norm: mk(&[cfg.dim]),
|
||||
w_gate: mk(&[cfg.dim, cfg.ffn_hidden]),
|
||||
@@ -87,6 +91,8 @@ impl TinyTransformer {
|
||||
b.wq.clone(),
|
||||
b.wk.clone(),
|
||||
b.wv.clone(),
|
||||
b.q_norm.clone(),
|
||||
b.k_norm.clone(),
|
||||
b.wo.clone(),
|
||||
b.ffn_norm.clone(),
|
||||
b.w_gate.clone(),
|
||||
@@ -136,22 +142,29 @@ impl TinyTransformer {
|
||||
// Project, then lay out as per-head [seq, head_dim] tensors.
|
||||
// [seq,dim] @ [dim,dim] = [seq,dim]
|
||||
// reshape [seq, nh, hd]
|
||||
// qk-norm per-head RMSNorm over hd (Qwen3-style; Q/K only, before RoPE)
|
||||
// rope (kernel expects exactly [tokens, heads, head_dim])
|
||||
// transpose [nh, seq, hd] → split into nh × [seq, hd]
|
||||
let to_heads = |proj: Var, rope: bool| -> Vec<Var> {
|
||||
let to_heads = |proj: Var, norm: Option<&Var>| -> Vec<Var> {
|
||||
let r = ops::reshape(&proj, &[seq, nh, hd]);
|
||||
let r = if rope {
|
||||
let r = match norm {
|
||||
// Per-head RMSNorm: flatten the (seq,nh) head rows, norm over hd,
|
||||
// restore. RoPE follows on the normed Q/K (mirrors xserv qwen3.rs).
|
||||
Some(gamma) => {
|
||||
let flat = ops::reshape(&r, &[seq * nh, hd]);
|
||||
let normed = ops::rms_norm(&flat, gamma, self.cfg.eps);
|
||||
let r = ops::reshape(&normed, &[seq, nh, hd]);
|
||||
ops::rope(&r, self.cfg.rope_theta)
|
||||
} else {
|
||||
r
|
||||
}
|
||||
None => r,
|
||||
};
|
||||
let t = ops::transpose_3d01(&r); // [nh, seq, hd]
|
||||
ops::split_heads(&t)
|
||||
};
|
||||
|
||||
let q = to_heads(ops::matmul(x, &b.wq), true);
|
||||
let k = to_heads(ops::matmul(x, &b.wk), true);
|
||||
let v = to_heads(ops::matmul(x, &b.wv), false);
|
||||
let q = to_heads(ops::matmul(x, &b.wq), Some(&b.q_norm));
|
||||
let k = to_heads(ops::matmul(x, &b.wk), Some(&b.k_norm));
|
||||
let v = to_heads(ops::matmul(x, &b.wv), None);
|
||||
|
||||
// Per-head scaled-dot-product attention with causal mask.
|
||||
let heads_out: Vec<Var> = (0..nh)
|
||||
|
||||
@@ -98,7 +98,7 @@ lm_head = load("lm_head")
|
||||
layers = []
|
||||
for l in range(NL):
|
||||
layers.append({p: load(f"l{l}_{p}") for p in
|
||||
["attn_norm", "wq", "wk", "wv", "wo",
|
||||
["attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo",
|
||||
"ffn_norm", "w_gate", "w_up", "w_down"]})
|
||||
|
||||
idx = torch.tensor(ids, dtype=torch.long)
|
||||
@@ -111,6 +111,9 @@ for L in layers:
|
||||
q = (x @ L["wq"]).reshape(SEQ, NH, HD)
|
||||
k = (x @ L["wk"]).reshape(SEQ, NH, HD)
|
||||
v = (x @ L["wv"]).reshape(SEQ, NH, HD)
|
||||
# Per-head QK-norm (Qwen3-style), before RoPE.
|
||||
q = rms_norm(q, L["q_norm"])
|
||||
k = rms_norm(k, L["k_norm"])
|
||||
q = rope(q).transpose(0, 1) # [nh, seq, hd]
|
||||
k = rope(k).transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
|
||||
@@ -145,6 +145,8 @@ fn param_names(cfg: &Config) -> Vec<String> {
|
||||
"wq",
|
||||
"wk",
|
||||
"wv",
|
||||
"q_norm",
|
||||
"k_norm",
|
||||
"wo",
|
||||
"ffn_norm",
|
||||
"w_gate",
|
||||
|
||||
@@ -74,7 +74,7 @@ SEQ = len(ids)
|
||||
|
||||
NAMES = ["embed"]
|
||||
for l in range(NL):
|
||||
for p in ["attn_norm", "wq", "wk", "wv", "wo",
|
||||
for p in ["attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo",
|
||||
"ffn_norm", "w_gate", "w_up", "w_down"]:
|
||||
NAMES.append(f"l{l}_{p}")
|
||||
NAMES += ["final_norm", "lm_head"]
|
||||
@@ -115,6 +115,9 @@ def forward():
|
||||
q = (x @ P[f"l{l}_wq"]).reshape(SEQ, NH, HD)
|
||||
k = (x @ P[f"l{l}_wk"]).reshape(SEQ, NH, HD)
|
||||
v = (x @ P[f"l{l}_wv"]).reshape(SEQ, NH, HD)
|
||||
# Per-head QK-norm (Qwen3-style), before RoPE.
|
||||
q = rms_norm(q, P[f"l{l}_q_norm"])
|
||||
k = rms_norm(k, P[f"l{l}_k_norm"])
|
||||
q = rope(q).transpose(0, 1)
|
||||
k = rope(k).transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
|
||||
@@ -156,6 +156,8 @@ fn param_names(cfg: &Config) -> Vec<String> {
|
||||
"wq",
|
||||
"wk",
|
||||
"wv",
|
||||
"q_norm",
|
||||
"k_norm",
|
||||
"wo",
|
||||
"ffn_norm",
|
||||
"w_gate",
|
||||
|
||||
Reference in New Issue
Block a user