model: PyTorch parity harness (weight dump + equivalent torch model)
parity_dump.rs (#[ignore] fixture generator) dumps the model's exact weights, ids, forward logits, loss, and per-param grads after one backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA), runs fwd+bwd, and compares logits + every grad within rtol. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
176
crates/xtrain-model/tests/parity.py
Normal file
176
crates/xtrain-model/tests/parity.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""PyTorch parity check for the xtrain tiny transformer (Phase T5).
|
||||
|
||||
Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL
|
||||
model in PyTorch (same x@W convention, same RoPE rotate_half + position=row,
|
||||
same RMSNorm, SwiGLU, causal mask, per-head SDPA), runs forward + one backward,
|
||||
and compares the forward logits and every parameter's gradient against the Rust
|
||||
values within a relative tolerance.
|
||||
|
||||
Usage: python3 parity.py /tmp/xtrain_parity
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
|
||||
DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/xtrain_parity"
|
||||
|
||||
|
||||
def read_vec(name):
|
||||
path = os.path.join(DIR, name)
|
||||
shape = None
|
||||
vals = []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("# shape"):
|
||||
shape = [int(x) for x in line.split()[2].split(",") if x]
|
||||
elif line:
|
||||
vals.append(float(line))
|
||||
t = torch.tensor(vals, dtype=torch.float64)
|
||||
if shape:
|
||||
t = t.reshape(shape)
|
||||
return t
|
||||
|
||||
|
||||
def read_cfg():
|
||||
cfg = {}
|
||||
with open(os.path.join(DIR, "config.txt")) as f:
|
||||
for line in f:
|
||||
k, v = line.split()
|
||||
cfg[k] = v
|
||||
return cfg
|
||||
|
||||
|
||||
def read_ids(name):
|
||||
with open(os.path.join(DIR, name)) as f:
|
||||
return [int(x) for x in f.read().split()]
|
||||
|
||||
|
||||
cfg = read_cfg()
|
||||
DIM = int(cfg["dim"])
|
||||
NL = int(cfg["n_layers"])
|
||||
NH = int(cfg["n_heads"])
|
||||
HD = int(cfg["head_dim"])
|
||||
EPS = float(cfg["eps"])
|
||||
THETA = float(cfg["rope_theta"])
|
||||
|
||||
ids = read_ids("ids.txt")
|
||||
targets = read_ids("targets.txt")
|
||||
SEQ = len(ids)
|
||||
|
||||
# Load params as leaf tensors requiring grad (float64 for a clean reference).
|
||||
P = {}
|
||||
|
||||
|
||||
def load(name):
|
||||
t = read_vec(f"w_{name}.txt").clone().requires_grad_(True)
|
||||
P[name] = t
|
||||
return t
|
||||
|
||||
|
||||
def rms_norm(x, gamma):
|
||||
# y = x / sqrt(mean(x^2)+eps) * gamma (no mean subtraction)
|
||||
ms = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
return x * torch.rsqrt(ms + EPS) * gamma
|
||||
|
||||
|
||||
def rope(x): # x: [seq, nh, hd], position = token index, matching the kernel
|
||||
half = HD // 2
|
||||
out = torch.empty_like(x)
|
||||
i = torch.arange(half, dtype=torch.float64)
|
||||
freq = THETA ** (-(2.0 * i) / HD) # [half]
|
||||
pos = torch.arange(SEQ, dtype=torch.float64).reshape(SEQ, 1) # [seq,1]
|
||||
ang = pos * freq # [seq, half]
|
||||
c = torch.cos(ang).reshape(SEQ, 1, half)
|
||||
s = torch.sin(ang).reshape(SEQ, 1, half)
|
||||
x0 = x[..., :half]
|
||||
x1 = x[..., half:]
|
||||
out[..., :half] = x0 * c - x1 * s
|
||||
out[..., half:] = x1 * c + x0 * s
|
||||
return out
|
||||
|
||||
|
||||
emb = load("embed")
|
||||
final_norm = load("final_norm")
|
||||
lm_head = load("lm_head")
|
||||
layers = []
|
||||
for l in range(NL):
|
||||
layers.append({p: load(f"l{l}_{p}") for p in
|
||||
["attn_norm", "wq", "wk", "wv", "wo",
|
||||
"ffn_norm", "w_gate", "w_up", "w_down"]})
|
||||
|
||||
idx = torch.tensor(ids, dtype=torch.long)
|
||||
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
|
||||
|
||||
h = emb[idx] # [seq, dim]
|
||||
for L in layers:
|
||||
# Attention
|
||||
x = rms_norm(h, L["attn_norm"])
|
||||
q = (x @ L["wq"]).reshape(SEQ, NH, HD)
|
||||
k = (x @ L["wk"]).reshape(SEQ, NH, HD)
|
||||
v = (x @ L["wv"]).reshape(SEQ, NH, HD)
|
||||
q = rope(q).transpose(0, 1) # [nh, seq, hd]
|
||||
k = rope(k).transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
scores = (q @ k.transpose(-1, -2)) * scale + mask # [nh, seq, seq]
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
out = probs @ v # [nh, seq, hd]
|
||||
out = out.transpose(0, 1).reshape(SEQ, DIM) # [seq, dim]
|
||||
attn = out @ L["wo"]
|
||||
h = h + attn
|
||||
# MLP
|
||||
x = rms_norm(h, L["ffn_norm"])
|
||||
gate = x @ L["w_gate"]
|
||||
up = x @ L["w_up"]
|
||||
act = torch.nn.functional.silu(gate) * up
|
||||
mlp = act @ L["w_down"]
|
||||
h = h + mlp
|
||||
|
||||
h = rms_norm(h, final_norm)
|
||||
logits = h @ lm_head # [seq, vocab]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
|
||||
loss.backward()
|
||||
|
||||
# ---- Compare ----
|
||||
def relerr(a, b):
|
||||
a = a.double()
|
||||
b = b.double()
|
||||
denom = b.abs().clamp(min=1e-6)
|
||||
return ((a - b).abs() / denom).max().item()
|
||||
|
||||
|
||||
ref_logits = read_vec("logits.txt")
|
||||
ref_loss = read_vec("loss.txt").item()
|
||||
|
||||
print(f"loss: rust={ref_loss:.6e} torch={loss.item():.6e} "
|
||||
f"relerr={abs(loss.item()-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
|
||||
le = relerr(logits.detach(), ref_logits)
|
||||
print(f"logits: max relerr = {le:.2e}")
|
||||
|
||||
RTOL = 2e-2
|
||||
worst = le
|
||||
worst_name = "logits"
|
||||
fails = []
|
||||
if le > RTOL:
|
||||
fails.append(("logits", le))
|
||||
|
||||
for name, t in P.items():
|
||||
ref_g = read_vec(f"g_{name}.txt")
|
||||
ge = relerr(t.grad, ref_g)
|
||||
if ge > worst:
|
||||
worst, worst_name = ge, f"grad[{name}]"
|
||||
if ge > RTOL:
|
||||
fails.append((f"grad[{name}]", ge))
|
||||
|
||||
print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})")
|
||||
if fails:
|
||||
print("FAIL:")
|
||||
for n, e in fails:
|
||||
print(f" {n}: relerr={e:.3e}")
|
||||
sys.exit(1)
|
||||
print("PARITY OK: forward logits + all param grads within rtol")
|
||||
160
crates/xtrain-model/tests/parity_dump.rs
Normal file
160
crates/xtrain-model/tests/parity_dump.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
// PyTorch parity, step 1 of 2: dump the Rust tiny-transformer's exact weights,
|
||||
// inputs, forward logits, loss, and per-parameter gradients (after one backward)
|
||||
// to a directory, so an equivalent PyTorch model (tests/parity.py) can be built
|
||||
// from the SAME weights and the forward + grads compared within rtol.
|
||||
//
|
||||
// Run: XTRAIN_PARITY_DIR=/tmp/xtrain_parity cargo test -p xtrain-model \
|
||||
// --test parity_dump -- --nocapture --ignored
|
||||
// then: python3 crates/xtrain-model/tests/parity.py /tmp/xtrain_parity
|
||||
//
|
||||
// Marked #[ignore] (it's a fixture generator, not a pass/fail assertion) and
|
||||
// gated #![cfg(not(no_cuda))].
|
||||
#![cfg(not(no_cuda))]
|
||||
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use xtrain_cuda::device;
|
||||
use xtrain_model::{Config, TinyTransformer, ids_tensor, param_to_host};
|
||||
use xtrain_tensor::Device;
|
||||
|
||||
fn fill(n: usize, seed: u64, scale: f32) -> Vec<f32> {
|
||||
let mut state = seed
|
||||
.wrapping_mul(2862933555777941757)
|
||||
.wrapping_add(3037000493);
|
||||
(0..n)
|
||||
.map(|_| {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
(((state >> 33) as f32 / (1u64 << 31) as f32) - 0.5) * 2.0 * scale
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn write_vec(dir: &PathBuf, name: &str, data: &[f32], shape: &[usize]) {
|
||||
let mut f = fs::File::create(dir.join(name)).unwrap();
|
||||
let shape_str: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
|
||||
writeln!(f, "# shape {}", shape_str.join(",")).unwrap();
|
||||
for v in data {
|
||||
writeln!(f, "{v:.8e}").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "fixture generator for PyTorch parity; run with --ignored"]
|
||||
fn dump_for_parity() {
|
||||
assert!(device::device_count().unwrap() > 0, "no CUDA device");
|
||||
device::set_device(0).unwrap();
|
||||
let device = Device::Cuda(0);
|
||||
|
||||
let dir = PathBuf::from(
|
||||
std::env::var("XTRAIN_PARITY_DIR").unwrap_or_else(|_| "/tmp/xtrain_parity".to_string()),
|
||||
);
|
||||
fs::create_dir_all(&dir).unwrap();
|
||||
|
||||
// Fixed config + ids (independent of any text, for reproducibility).
|
||||
let mut cfg = Config::tiny();
|
||||
cfg.vocab = 12;
|
||||
let ids: Vec<i32> = vec![3, 1, 4, 1, 5, 9, 2, 6];
|
||||
let targets: Vec<i32> = vec![1, 4, 1, 5, 9, 2, 6, 0];
|
||||
let seq = ids.len();
|
||||
|
||||
// Same deterministic init as the overfit test.
|
||||
let mut seed = 1u64;
|
||||
let model = TinyTransformer::new(cfg, device, |shape| {
|
||||
seed = seed.wrapping_add(1);
|
||||
let n: usize = shape.iter().product();
|
||||
if shape.len() == 1 {
|
||||
fill(n, seed, 0.02).iter().map(|v| v + 1.0).collect()
|
||||
} else {
|
||||
fill(n, seed, 0.08)
|
||||
}
|
||||
});
|
||||
|
||||
// config + ids
|
||||
{
|
||||
let mut f = fs::File::create(dir.join("config.txt")).unwrap();
|
||||
writeln!(f, "vocab {}", cfg.vocab).unwrap();
|
||||
writeln!(f, "dim {}", cfg.dim).unwrap();
|
||||
writeln!(f, "n_layers {}", cfg.n_layers).unwrap();
|
||||
writeln!(f, "n_heads {}", cfg.n_heads).unwrap();
|
||||
writeln!(f, "head_dim {}", cfg.head_dim).unwrap();
|
||||
writeln!(f, "ffn_hidden {}", cfg.ffn_hidden).unwrap();
|
||||
writeln!(f, "eps {:e}", cfg.eps).unwrap();
|
||||
writeln!(f, "rope_theta {:e}", cfg.rope_theta).unwrap();
|
||||
writeln!(f, "seq {seq}").unwrap();
|
||||
}
|
||||
{
|
||||
let mut f = fs::File::create(dir.join("ids.txt")).unwrap();
|
||||
for v in &ids {
|
||||
writeln!(f, "{v}").unwrap();
|
||||
}
|
||||
let mut f = fs::File::create(dir.join("targets.txt")).unwrap();
|
||||
for v in &targets {
|
||||
writeln!(f, "{v}").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Stable param order, named to match parity.py.
|
||||
let names = param_names(&cfg);
|
||||
let params = model.params();
|
||||
assert_eq!(names.len(), params.len(), "param name/count mismatch");
|
||||
for (name, p) in names.iter().zip(¶ms) {
|
||||
let shape = p.value().shape().to_vec();
|
||||
write_vec(&dir, &format!("w_{name}.txt"), ¶m_to_host(p), &shape);
|
||||
}
|
||||
|
||||
// Forward logits + loss, then backward → per-param grads.
|
||||
let ids_t = ids_tensor(&ids, device);
|
||||
let targets_t = ids_tensor(&targets, device);
|
||||
let logits = model.forward(&ids_t);
|
||||
write_vec(
|
||||
&dir,
|
||||
"logits.txt",
|
||||
¶m_to_host(&logits),
|
||||
logits.value().shape(),
|
||||
);
|
||||
|
||||
let loss = model.loss(&ids_t, &targets_t);
|
||||
let loss_val = param_to_host(&loss)[0];
|
||||
{
|
||||
let mut f = fs::File::create(dir.join("loss.txt")).unwrap();
|
||||
writeln!(f, "{loss_val:.8e}").unwrap();
|
||||
}
|
||||
loss.backward();
|
||||
for (name, p) in names.iter().zip(¶ms) {
|
||||
let g = p.grad().expect("param has no grad");
|
||||
let gh = g.to_device(Device::Cpu);
|
||||
write_vec(
|
||||
&dir,
|
||||
&format!("g_{name}.txt"),
|
||||
gh.as_slice::<f32>(),
|
||||
g.shape(),
|
||||
);
|
||||
}
|
||||
|
||||
println!("parity: dumped to {} (loss={loss_val:.6e})", dir.display());
|
||||
}
|
||||
|
||||
fn param_names(cfg: &Config) -> Vec<String> {
|
||||
let mut names = vec!["embed".to_string()];
|
||||
for l in 0..cfg.n_layers {
|
||||
for p in [
|
||||
"attn_norm",
|
||||
"wq",
|
||||
"wk",
|
||||
"wv",
|
||||
"wo",
|
||||
"ffn_norm",
|
||||
"w_gate",
|
||||
"w_up",
|
||||
"w_down",
|
||||
] {
|
||||
names.push(format!("l{l}_{p}"));
|
||||
}
|
||||
}
|
||||
names.push("final_norm".to_string());
|
||||
names.push("lm_head".to_string());
|
||||
names
|
||||
}
|
||||
Reference in New Issue
Block a user