model: PyTorch parity harness (weight dump + equivalent torch model)

parity_dump.rs (#[ignore] fixture generator) dumps the model's exact
weights, ids, forward logits, loss, and per-param grads after one
backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W
convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA),
runs fwd+bwd, and compares logits + every grad within rtol.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:07:30 +08:00
parent e3912c2380
commit 3366f30c4d
2 changed files with 336 additions and 0 deletions

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python3
"""PyTorch parity check for the xtrain tiny transformer (Phase T5).
Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL
model in PyTorch (same x@W convention, same RoPE rotate_half + position=row,
same RMSNorm, SwiGLU, causal mask, per-head SDPA), runs forward + one backward,
and compares the forward logits and every parameter's gradient against the Rust
values within a relative tolerance.
Usage: python3 parity.py /tmp/xtrain_parity
"""
import sys
import os
import math
import torch
# Directory holding the dump files: the first command-line argument when one
# is given, otherwise the default location.
if len(sys.argv) > 1:
    DIR = sys.argv[1]
else:
    DIR = "/tmp/xtrain_parity"
def read_vec(name):
    """Read a text tensor dump from ``DIR/name`` into a float64 tensor.

    The file holds one value per line, optionally preceded by a
    ``# shape d0,d1,...`` header. When the header lists at least one
    dimension the values are reshaped to it; otherwise (no header, or a
    header with no dimensions, as written for an empty shape) the flat 1-D
    tensor is returned.

    Raises:
        ValueError: if a non-empty, non-header line is not a float literal.
    """
    path = os.path.join(DIR, name)
    shape = None
    vals = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith("# shape"):
                # strip() removes the trailing space of an empty "# shape "
                # header, leaving only two tokens; treat that as "no dims"
                # rather than indexing past the end of the split.
                fields = line.split()
                dims = fields[2] if len(fields) > 2 else ""
                shape = [int(x) for x in dims.split(",") if x]
            elif line:
                vals.append(float(line))
    t = torch.tensor(vals, dtype=torch.float64)
    if shape:
        t = t.reshape(shape)
    return t
def read_cfg():
    """Parse ``DIR/config.txt`` into a dict mapping string keys to string values.

    Each line must split into exactly two whitespace-separated tokens
    (``key value``); values are returned unconverted.
    """
    cfg_path = os.path.join(DIR, "config.txt")
    with open(cfg_path) as handle:
        rows = (row.split() for row in handle)
        return {key: value for key, value in rows}
def read_ids(name):
    """Return the whitespace-separated integers stored in ``DIR/name`` as a list."""
    with open(os.path.join(DIR, name)) as handle:
        tokens = handle.read().split()
    return list(map(int, tokens))
# Hyperparameters come from config.txt as strings and are converted here.
cfg = read_cfg()
DIM, NL, NH, HD = (
    int(cfg[key]) for key in ("dim", "n_layers", "n_heads", "head_dim")
)
EPS, THETA = (float(cfg[key]) for key in ("eps", "rope_theta"))
# Input token ids and their targets; the sequence length is taken from the ids.
ids, targets = read_ids("ids.txt"), read_ids("targets.txt")
SEQ = len(ids)
# Registry of every loaded parameter, keyed by name. Values are float64 leaf
# tensors with requires_grad set, so .grad is populated by backward().
P = {}


def load(name):
    """Load ``w_<name>.txt`` as a grad-requiring leaf tensor and register it in ``P``."""
    P[name] = read_vec(f"w_{name}.txt").clone().requires_grad_(True)
    return P[name]
def rms_norm(x, gamma):
    """RMSNorm over the last axis: ``x * rsqrt(mean(x^2) + EPS) * gamma``.

    Only the root-mean-square is normalised; no mean is subtracted.
    """
    mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + EPS)
    return x * inv_rms * gamma
def rope(x):
    """Apply rotary position embedding in the rotate-half layout.

    ``x`` is indexed as [seq, nh, hd]; the position of a row is its index
    along the first axis. The first ``HD // 2`` features of each head are
    rotated against the remaining ones, using angle ``pos * THETA**(-2i/HD)``
    for feature pair ``i``.
    """
    half = HD // 2
    pair_idx = torch.arange(half, dtype=torch.float64)
    inv_freq = THETA ** (-(2.0 * pair_idx) / HD)  # [half]
    positions = torch.arange(SEQ, dtype=torch.float64)  # [seq]
    angles = torch.outer(positions, inv_freq)  # [seq, half]
    # Insert a head axis so cos/sin broadcast over [seq, nh, half].
    cos = torch.cos(angles).unsqueeze(1)
    sin = torch.sin(angles).unsqueeze(1)
    lo = x[..., :half]
    hi = x[..., half:]
    rotated = torch.empty_like(x)
    rotated[..., :half] = lo * cos - hi * sin
    rotated[..., half:] = hi * cos + lo * sin
    return rotated
# Load every parameter. The load order (embed, final_norm, lm_head, then the
# layers in index order) determines the iteration order of P.
emb = load("embed")
final_norm = load("final_norm")
lm_head = load("lm_head")
LAYER_PARAM_NAMES = ("attn_norm", "wq", "wk", "wv", "wo",
                     "ffn_norm", "w_gate", "w_up", "w_down")
layers = [
    {pname: load(f"l{layer_no}_{pname}") for pname in LAYER_PARAM_NAMES}
    for layer_no in range(NL)
]

token_idx = torch.tensor(ids, dtype=torch.long)
# Additive causal mask: -1e9 strictly above the diagonal, 0 elsewhere.
causal_mask = torch.triu(
    torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
# The attention scale depends only on HD, so compute it once.
attn_scale = 1.0 / math.sqrt(HD)

h = emb[token_idx]  # [seq, dim]
for layer in layers:
    # Attention sub-block: pre-norm, per-head scaled dot-product, residual add.
    normed = rms_norm(h, layer["attn_norm"])
    q, k, v = (
        (normed @ layer[wname]).reshape(SEQ, NH, HD)
        for wname in ("wq", "wk", "wv")
    )
    q = rope(q).transpose(0, 1)  # [nh, seq, hd]
    k = rope(k).transpose(0, 1)
    v = v.transpose(0, 1)
    scores = (q @ k.transpose(-1, -2)) * attn_scale + causal_mask  # [nh, seq, seq]
    probs = torch.softmax(scores, dim=-1)
    ctx = (probs @ v).transpose(0, 1).reshape(SEQ, DIM)  # [seq, dim]
    h = h + ctx @ layer["wo"]
    # MLP sub-block: pre-norm, silu(gate) * up, down projection, residual add.
    normed = rms_norm(h, layer["ffn_norm"])
    gate = normed @ layer["w_gate"]
    up = normed @ layer["w_up"]
    h = h + (torch.nn.functional.silu(gate) * up) @ layer["w_down"]

h = rms_norm(h, final_norm)
logits = h @ lm_head  # [seq, vocab]
target_idx = torch.tensor(targets, dtype=torch.long)
loss = torch.nn.functional.cross_entropy(logits, target_idx, reduction="mean")
# One backward pass populates .grad on every tensor registered in P.
loss.backward()
# ---- Compare ----
def relerr(a, b):
    """Return the largest element-wise relative error of ``a`` against ``b``.

    Both tensors are compared in float64. The denominator is ``|b|`` floored
    at 1e-6, so entries where the reference is (near) zero do not divide by
    zero. The result is a Python float.
    """
    actual = a.double()
    reference = b.double()
    floor = reference.abs().clamp(min=1e-6)
    ratio = torch.abs(actual - reference) / floor
    return torch.max(ratio).item()
# Reference values dumped by the Rust side.
ref_logits = read_vec("logits.txt")
ref_loss = read_vec("loss.txt").item()
# The loss is reported for information only; pass/fail is decided by the
# logits and the per-parameter gradients below.
print(f"loss: rust={ref_loss:.6e} torch={loss.item():.6e} "
      f"relerr={abs(loss.item()-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
le = relerr(logits.detach(), ref_logits)
print(f"logits: max relerr = {le:.2e}")
RTOL = 2e-2
worst = le
worst_name = "logits"
fails = []
# NaN compares False against everything, so "err > RTOL" would let a NaN error
# slip through as a pass. "not (err <= RTOL)" fails NaN as well as any error
# above the tolerance.
if not (le <= RTOL):
    fails.append(("logits", le))
for name, t in P.items():
    ref_g = read_vec(f"g_{name}.txt")
    ge = relerr(t.grad, ref_g)
    # A NaN error outranks every finite one; once the worst is NaN it stays.
    if not math.isnan(worst) and (math.isnan(ge) or ge > worst):
        worst, worst_name = ge, f"grad[{name}]"
    if not (ge <= RTOL):
        fails.append((f"grad[{name}]", ge))
print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})")
if fails:
    print("FAIL:")
    for n, e in fails:
        print(f" {n}: relerr={e:.3e}")
    sys.exit(1)
print("PARITY OK: forward logits + all param grads within rtol")

View File

@@ -0,0 +1,160 @@
// PyTorch parity, step 1 of 2: dump the Rust tiny-transformer's exact weights,
// inputs, forward logits, loss, and per-parameter gradients (after one backward)
// to a directory, so an equivalent PyTorch model (tests/parity.py) can be built
// from the SAME weights and the forward + grads compared within rtol.
//
// Run: XTRAIN_PARITY_DIR=/tmp/xtrain_parity cargo test -p xtrain-model \
// --test parity_dump -- --nocapture --ignored
// then: python3 crates/xtrain-model/tests/parity.py /tmp/xtrain_parity
//
// Marked #[ignore] (it's a fixture generator, not a pass/fail assertion) and
// gated #![cfg(not(no_cuda))].
#![cfg(not(no_cuda))]
use std::fs;
use std::io::Write;
use std::path::PathBuf;
use xtrain_cuda::device;
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
use xtrain_tensor::Device;
/// Produces `n` deterministic pseudo-random values in `[-scale, scale]`.
///
/// The seed is mixed once, then a 64-bit multiply-add generator is stepped
/// once per element; the top 31 bits of each state are mapped to a unit
/// fraction and recentred around zero. The same `(n, seed, scale)` always
/// yields the same vector.
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
    // Mix the seed once before stepping the generator.
    let mut lcg = seed
        .wrapping_mul(2862933555777941757)
        .wrapping_add(3037000493);
    let mut values = Vec::with_capacity(n);
    for _ in 0..n {
        lcg = lcg
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Top 31 bits as a fraction of 2^31, then shifted and scaled.
        let unit = (lcg >> 33) as f32 / (1u64 << 31) as f32;
        values.push((unit - 0.5) * 2.0 * scale);
    }
    values
}
/// Writes `data` to `dir/name` as text: a `# shape d0,d1,...` header line
/// followed by one value per line in `{:.8e}` scientific notation.
///
/// `dir` is a `&Path`, so callers may pass either `&PathBuf` or `&Path`.
/// Output is buffered and flushed explicitly, because a `BufWriter` dropped
/// without flushing discards any write error.
///
/// # Panics
///
/// Panics if the file cannot be created or any write/flush fails.
fn write_vec(dir: &std::path::Path, name: &str, data: &[f32], shape: &[usize]) {
    let file = fs::File::create(dir.join(name)).expect("failed to create dump file");
    let mut out = std::io::BufWriter::new(file);
    let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
    writeln!(out, "# shape {}", dims.join(",")).expect("failed to write shape header");
    for v in data {
        writeln!(out, "{v:.8e}").expect("failed to write value");
    }
    out.flush().expect("failed to flush dump file");
}
/// Fixture generator: builds the tiny transformer with deterministic weights
/// and writes its config, ids, targets, weights, forward logits, loss and
/// per-parameter gradients (after one backward) into the output directory.
#[test]
#[ignore = "fixture generator for PyTorch parity; run with --ignored"]
fn dump_for_parity() {
    assert!(device::device_count().unwrap() > 0, "no CUDA device");
    device::set_device(0).unwrap();
    let device = Device::Cuda(0);

    // Output directory: $XTRAIN_PARITY_DIR when set, else the default path.
    let out_dir = std::env::var("XTRAIN_PARITY_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| PathBuf::from("/tmp/xtrain_parity"));
    fs::create_dir_all(&out_dir).unwrap();
    // Creates (or truncates) a file inside the output directory.
    let create = |file: &str| fs::File::create(out_dir.join(file)).unwrap();

    // Fixed config and token ids, hard-coded so every run dumps the same data.
    let mut cfg = Config::tiny();
    cfg.vocab = 12;
    let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
    let seq = ids.len();

    // Deterministic init: the seed advances once per parameter. 1-D params
    // get 0.02-scale noise shifted by +1.0; all others get 0.08-scale noise.
    let mut seed = 1u64;
    let model = TinyTransformer::new(cfg, device, |shape| {
        seed = seed.wrapping_add(1);
        let numel: usize = shape.iter().product();
        match shape.len() {
            1 => fill(numel, seed, 0.02)
                .into_iter()
                .map(|v| v + 1.0)
                .collect(),
            _ => fill(numel, seed, 0.08),
        }
    });

    // config.txt: one "key value" pair per line.
    {
        let mut out = create("config.txt");
        writeln!(out, "vocab {}", cfg.vocab).unwrap();
        writeln!(out, "dim {}", cfg.dim).unwrap();
        writeln!(out, "n_layers {}", cfg.n_layers).unwrap();
        writeln!(out, "n_heads {}", cfg.n_heads).unwrap();
        writeln!(out, "head_dim {}", cfg.head_dim).unwrap();
        writeln!(out, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
        writeln!(out, "eps {:e}", cfg.eps).unwrap();
        writeln!(out, "rope_theta {:e}", cfg.rope_theta).unwrap();
        writeln!(out, "seq {seq}").unwrap();
    }

    // ids.txt / targets.txt: one integer per line.
    for (file, values) in [("ids.txt", &ids), ("targets.txt", &targets)] {
        let mut out = create(file);
        for v in values {
            writeln!(out, "{v}").unwrap();
        }
    }

    // Weights, in the stable order produced by param_names.
    let names = param_names(&cfg);
    let params = model.params();
    assert_eq!(names.len(), params.len(), "param name/count mismatch");
    for (name, param) in names.iter().zip(&params) {
        let dims = param.value().shape().to_vec();
        write_vec(
            &out_dir,
            &format!("w_{name}.txt"),
            &param_to_host(param),
            &dims,
        );
    }

    // Forward logits.
    let ids_t = ids_tensor(&ids, device);
    let targets_t = ids_tensor(&targets, device);
    let logits = model.forward(&ids_t);
    write_vec(
        &out_dir,
        "logits.txt",
        &param_to_host(&logits),
        logits.value().shape(),
    );

    // Loss value (written without a shape header).
    let loss = model.loss(&ids_t, &targets_t);
    let loss_val = param_to_host(&loss)[0];
    {
        let mut out = create("loss.txt");
        writeln!(out, "{loss_val:.8e}").unwrap();
    }

    // One backward pass, then dump each parameter's gradient from the host.
    loss.backward();
    for (name, param) in names.iter().zip(&params) {
        let grad = param.grad().expect("param has no grad");
        let host_grad = grad.to_device(Device::Cpu);
        write_vec(
            &out_dir,
            &format!("g_{name}.txt"),
            host_grad.as_slice::<f32>(),
            grad.shape(),
        );
    }
    println!("parity: dumped to {} (loss={loss_val:.6e})", out_dir.display());
}
/// Returns the parameter names in dump order: `embed`, then for each layer
/// index `l` in `0..cfg.n_layers` the nine `l{l}_<param>` names, then
/// `final_norm` and `lm_head`.
fn param_names(cfg: &Config) -> Vec<String> {
    // Per-layer parameter suffixes, in the order they are emitted.
    const LAYER_PARAMS: [&str; 9] = [
        "attn_norm",
        "wq",
        "wk",
        "wv",
        "wo",
        "ffn_norm",
        "w_gate",
        "w_up",
        "w_down",
    ];
    let per_layer = (0..cfg.n_layers).flat_map(|l| {
        LAYER_PARAMS
            .iter()
            .map(move |p| format!("l{l}_{p}"))
    });
    std::iter::once("embed".to_string())
        .chain(per_layer)
        .chain(["final_norm".to_string(), "lm_head".to_string()])
        .collect()
}