From 3366f30c4d647c276ace55579107cd83c2244967 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Mon, 15 Jun 2026 16:07:30 +0800 Subject: [PATCH] model: PyTorch parity harness (weight dump + equivalent torch model) parity_dump.rs (#[ignore] fixture generator) dumps the model's exact weights, ids, forward logits, loss, and per-param grads after one backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA), runs fwd+bwd, and compares logits + every grad within rtol. Co-Authored-By: Claude Opus 4.8 --- crates/xtrain-model/tests/parity.py | 176 +++++++++++++++++++++++ crates/xtrain-model/tests/parity_dump.rs | 160 +++++++++++++++++++++ 2 files changed, 336 insertions(+) create mode 100644 crates/xtrain-model/tests/parity.py create mode 100644 crates/xtrain-model/tests/parity_dump.rs diff --git a/crates/xtrain-model/tests/parity.py b/crates/xtrain-model/tests/parity.py new file mode 100644 index 0000000..2d82b1e --- /dev/null +++ b/crates/xtrain-model/tests/parity.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +"""PyTorch parity check for the xtrain tiny transformer (Phase T5). + +Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL +model in PyTorch (same x@W convention, same RoPE rotate_half + position=row, +same RMSNorm, SwiGLU, causal mask, per-head SDPA), runs forward + one backward, +and compares the forward logits and every parameter's gradient against the Rust +values within a relative tolerance. + +Usage: python3 parity.py /tmp/xtrain_parity +""" +import sys +import os +import math +import torch + +DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/xtrain_parity" + + +def read_vec(name): + path = os.path.join(DIR, name) + shape = None + vals = [] + with open(path) as f: + for line in f: + line = line.strip() + if line.startswith("# shape"): + shape = [int(x) for x in line.split()[2].split(",") if x] + elif line: + vals.append(float(line)) + t = torch.tensor(vals, dtype=torch.float64) + if shape: + t = t.reshape(shape) + return t + + +def read_cfg(): + cfg = {} + with open(os.path.join(DIR, "config.txt")) as f: + for line in f: + k, v = line.split() + cfg[k] = v + return cfg + + +def read_ids(name): + with open(os.path.join(DIR, name)) as f: + return [int(x) for x in f.read().split()] + + +cfg = read_cfg() +DIM = int(cfg["dim"]) +NL = int(cfg["n_layers"]) +NH = int(cfg["n_heads"]) +HD = int(cfg["head_dim"]) +EPS = float(cfg["eps"]) +THETA = float(cfg["rope_theta"]) + +ids = read_ids("ids.txt") +targets = read_ids("targets.txt") +SEQ = len(ids) + +# Load params as leaf tensors requiring grad (float64 for a clean reference). +P = {} + + +def load(name): + t = read_vec(f"w_{name}.txt").clone().requires_grad_(True) + P[name] = t + return t + + +def rms_norm(x, gamma): + # y = x / sqrt(mean(x^2)+eps) * gamma (no mean subtraction) + ms = x.pow(2).mean(dim=-1, keepdim=True) + return x * torch.rsqrt(ms + EPS) * gamma + + +def rope(x): # x: [seq, nh, hd], position = token index, matching the kernel + half = HD // 2 + out = torch.empty_like(x) + i = torch.arange(half, dtype=torch.float64) + freq = THETA ** (-(2.0 * i) / HD) # [half] + pos = torch.arange(SEQ, dtype=torch.float64).reshape(SEQ, 1) # [seq,1] + ang = pos * freq # [seq, half] + c = torch.cos(ang).reshape(SEQ, 1, half) + s = torch.sin(ang).reshape(SEQ, 1, half) + x0 = x[..., :half] + x1 = x[..., half:] + out[..., :half] = x0 * c - x1 * s + out[..., half:] = x1 * c + x0 * s + return out + + +emb = load("embed") +final_norm = load("final_norm") +lm_head = load("lm_head") +layers = [] +for l in range(NL): + layers.append({p: load(f"l{l}_{p}") for p in + ["attn_norm", "wq", "wk", "wv", "wo", + "ffn_norm", "w_gate", "w_up", "w_down"]}) + +idx = torch.tensor(ids, dtype=torch.long) +mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1) + +h = emb[idx] # [seq, dim] +for L in layers: + # Attention + x = rms_norm(h, L["attn_norm"]) + q = (x @ L["wq"]).reshape(SEQ, NH, HD) + k = (x @ L["wk"]).reshape(SEQ, NH, HD) + v = (x @ L["wv"]).reshape(SEQ, NH, HD) + q = rope(q).transpose(0, 1) # [nh, seq, hd] + k = rope(k).transpose(0, 1) + v = v.transpose(0, 1) + scale = 1.0 / math.sqrt(HD) + scores = (q @ k.transpose(-1, -2)) * scale + mask # [nh, seq, seq] + probs = torch.softmax(scores, dim=-1) + out = probs @ v # [nh, seq, hd] + out = out.transpose(0, 1).reshape(SEQ, DIM) # [seq, dim] + attn = out @ L["wo"] + h = h + attn + # MLP + x = rms_norm(h, L["ffn_norm"]) + gate = x @ L["w_gate"] + up = x @ L["w_up"] + act = torch.nn.functional.silu(gate) * up + mlp = act @ L["w_down"] + h = h + mlp + +h = rms_norm(h, final_norm) +logits = h @ lm_head # [seq, vocab] + +loss = torch.nn.functional.cross_entropy( + logits, torch.tensor(targets, dtype=torch.long), reduction="mean") +loss.backward() + +# ---- Compare ---- +def relerr(a, b): + a = a.double() + b = b.double() + denom = b.abs().clamp(min=1e-6) + return ((a - b).abs() / denom).max().item() + + +ref_logits = read_vec("logits.txt") +ref_loss = read_vec("loss.txt").item() + +print(f"loss: rust={ref_loss:.6e} torch={loss.item():.6e} " + f"relerr={abs(loss.item()-ref_loss)/max(abs(ref_loss),1e-6):.2e}") +le = relerr(logits.detach(), ref_logits) +print(f"logits: max relerr = {le:.2e}") + +RTOL = 2e-2 +worst = le +worst_name = "logits" +fails = [] +if le > RTOL: + fails.append(("logits", le)) + +for name, t in P.items(): + ref_g = read_vec(f"g_{name}.txt") + ge = relerr(t.grad, ref_g) + if ge > worst: + worst, worst_name = ge, f"grad[{name}]" + if ge > RTOL: + fails.append((f"grad[{name}]", ge)) + +print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})") +if fails: + print("FAIL:") + for n, e in fails: + print(f" {n}: relerr={e:.3e}") + sys.exit(1) +print("PARITY OK: forward logits + all param grads within rtol") diff --git a/crates/xtrain-model/tests/parity_dump.rs b/crates/xtrain-model/tests/parity_dump.rs new file mode 100644 index 0000000..e2c921f --- /dev/null +++ b/crates/xtrain-model/tests/parity_dump.rs @@ -0,0 +1,160 @@ +// PyTorch parity, step 1 of 2: dump the Rust tiny-transformer's exact weights, +// inputs, forward logits, loss, and per-parameter gradients (after one backward) +// to a directory, so an equivalent PyTorch model (tests/parity.py) can be built +// from the SAME weights and the forward + grads compared within rtol. +// +// Run: XTRAIN_PARITY_DIR=/tmp/xtrain_parity cargo test -p xtrain-model \ +// --test parity_dump -- --nocapture --ignored +// then: python3 crates/xtrain-model/tests/parity.py /tmp/xtrain_parity +// +// Marked #[ignore] (it's a fixture generator, not a pass/fail assertion) and +// gated #![cfg(not(no_cuda))]. +#![cfg(not(no_cuda))] + +use std::fs; +use std::io::Write; +use std::path::PathBuf; +use xtrain_cuda::device; +use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host}; +use xtrain_tensor::Device; + +fn fill(n: usize, seed: u64, scale: f32) -> Vec { + let mut state = seed + .wrapping_mul(2862933555777941757) + .wrapping_add(3037000493); + (0..n) + .map(|_| { + state = state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + (((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale + }) + .collect() +} + +fn write_vec(dir: &PathBuf, name: &str, data: &[f32], shape: &[usize]) { + let mut f = fs::File::create(dir.join(name)).unwrap(); + let shape_str: Vec = shape.iter().map(|d| d.to_string()).collect(); + writeln!(f, "# shape {}", shape_str.join(",")).unwrap(); + for v in data { + writeln!(f, "{v:.8e}").unwrap(); + } +} + +#[test] +#[ignore = "fixture generator for PyTorch parity; run with --ignored"] +fn dump_for_parity() { + assert!(device::device_count().unwrap() > 0, "no CUDA device"); + device::set_device(0).unwrap(); + let device = Device::Cuda(0); + + let dir = PathBuf::from( + std::env::var("XTRAIN_PARITY_DIR").unwrap_or_else(|_| "/tmp/xtrain_parity".to_string()), + ); + fs::create_dir_all(&dir).unwrap(); + + // Fixed config + ids (independent of any text, for reproducibility). + let mut cfg = Config::tiny(); + cfg.vocab = 12; + let ids: Vec = vec![3, 1, 4, 1, 5, 9, 2, 6]; + let targets: Vec = vec![1, 4, 1, 5, 9, 2, 6, 0]; + let seq = ids.len(); + + // Same deterministic init as the overfit test. + let mut seed = 1u64; + let model = TinyTransformer::new(cfg, device, |shape| { + seed = seed.wrapping_add(1); + let n: usize = shape.iter().product(); + if shape.len() == 1 { + fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect() + } else { + fill(n, seed, 0.08) + } + }); + + // config + ids + { + let mut f = fs::File::create(dir.join("config.txt")).unwrap(); + writeln!(f, "vocab {}", cfg.vocab).unwrap(); + writeln!(f, "dim {}", cfg.dim).unwrap(); + writeln!(f, "n_layers {}", cfg.n_layers).unwrap(); + writeln!(f, "n_heads {}", cfg.n_heads).unwrap(); + writeln!(f, "head_dim {}", cfg.head_dim).unwrap(); + writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap(); + writeln!(f, "eps {:e}", cfg.eps).unwrap(); + writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap(); + writeln!(f, "seq {seq}").unwrap(); + } + { + let mut f = fs::File::create(dir.join("ids.txt")).unwrap(); + for v in &ids { + writeln!(f, "{v}").unwrap(); + } + let mut f = fs::File::create(dir.join("targets.txt")).unwrap(); + for v in &targets { + writeln!(f, "{v}").unwrap(); + } + } + + // Stable param order, named to match parity.py. + let names = param_names(&cfg); + let params = model.params(); + assert_eq!(names.len(), params.len(), "param name/count mismatch"); + for (name, p) in names.iter().zip(¶ms) { + let shape = p.value().shape().to_vec(); + write_vec(&dir, &format!("w_{name}.txt"), ¶m_to_host(p), &shape); + } + + // Forward logits + loss, then backward → per-param grads. + let ids_t = ids_tensor(&ids, device); + let targets_t = ids_tensor(&targets, device); + let logits = model.forward(&ids_t); + write_vec( + &dir, + "logits.txt", + ¶m_to_host(&logits), + logits.value().shape(), + ); + + let loss = model.loss(&ids_t, &targets_t); + let loss_val = param_to_host(&loss)[0]; + { + let mut f = fs::File::create(dir.join("loss.txt")).unwrap(); + writeln!(f, "{loss_val:.8e}").unwrap(); + } + loss.backward(); + for (name, p) in names.iter().zip(¶ms) { + let g = p.grad().expect("param has no grad"); + let gh = g.to_device(Device::Cpu); + write_vec( + &dir, + &format!("g_{name}.txt"), + gh.as_slice::(), + g.shape(), + ); + } + + println!("parity: dumped to {} (loss={loss_val:.6e})", dir.display()); +} + +fn param_names(cfg: &Config) -> Vec { + let mut names = vec!["embed".to_string()]; + for l in 0..cfg.n_layers { + for p in [ + "attn_norm", + "wq", + "wk", + "wv", + "wo", + "ffn_norm", + "w_gate", + "w_up", + "w_down", + ] { + names.push(format!("l{l}_{p}")); + } + } + names.push("final_norm".to_string()); + names.push("lm_head".to_string()); + names +}