Files
xtrain/crates/xtrain-model/tests/parity.py
Gahow Wang 3366f30c4d model: PyTorch parity harness (weight dump + equivalent torch model)
parity_dump.rs (#[ignore] fixture generator) dumps the model's exact
weights, ids, forward logits, loss, and per-param grads after one
backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W
convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA),
runs fwd+bwd, and compares logits + every grad within rtol.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-15 16:07:30 +08:00

177 lines
4.9 KiB
Python

#!/usr/bin/env python3
"""PyTorch parity check for the xtrain tiny transformer (Phase T5).
Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL
model in PyTorch (same x@W convention, same RoPE rotate_half + position=row,
same RMSNorm, SwiGLU, causal mask, per-head SDPA), runs forward + one backward,
and compares the forward logits and every parameter's gradient against the Rust
values within a relative tolerance.
Usage: python3 parity.py /tmp/xtrain_parity
"""
import sys
import os
import math
import torch
# Directory holding the fixtures written by parity_dump.rs; the first CLI
# argument overrides the default location.
DIR = sys.argv[1] if sys.argv[1:] else "/tmp/xtrain_parity"
def read_vec(name, directory=None):
    """Read a dumped tensor from ``<directory>/<name>`` as float64.

    File format: one numeric value per line, plus an optional header line
    ``# shape d0,d1,...`` giving the shape to reshape the values into.
    Blank lines and any other ``#`` comment lines are ignored. Without a
    (non-empty) shape header the result stays 1-D.

    Args:
        name: file name inside the fixture directory.
        directory: fixture directory; defaults to the module-level ``DIR``.

    Returns:
        A ``torch.float64`` tensor, reshaped when a shape header is present.
    """
    path = os.path.join(DIR if directory is None else directory, name)
    shape = None
    vals = []
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith("# shape"):
                # A header with no dims (e.g. a scalar) yields [] rather than
                # an IndexError from indexing the missing third field.
                fields = line.split()
                dims = fields[2] if len(fields) > 2 else ""
                shape = [int(d) for d in dims.split(",") if d.strip()]
            elif line.startswith("#"):
                continue  # unrelated comment/metadata line
            else:
                vals.append(float(line))
    t = torch.tensor(vals, dtype=torch.float64)
    if shape:
        t = t.reshape(shape)
    return t
def read_cfg(directory=None):
    """Parse ``config.txt`` into a dict mapping keys to raw string values.

    Each meaningful line is ``<key> <value>``. Blank lines and lines starting
    with ``#`` are skipped. Values are returned unconverted; callers cast them.

    Args:
        directory: fixture directory; defaults to the module-level ``DIR``.

    Raises:
        ValueError: if a non-blank, non-comment line is not exactly two fields.
    """
    path = os.path.join(DIR if directory is None else directory, "config.txt")
    cfg = {}
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, 1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected '<key> <value>', got {line!r}")
            key, value = parts
            cfg[key] = value
    return cfg
def read_ids(name, directory=None):
    """Read whitespace-separated integer ids from ``<directory>/<name>``.

    Args:
        name: file name inside the fixture directory.
        directory: fixture directory; defaults to the module-level ``DIR``.

    Returns:
        The ids as a list of ints (empty for an empty file).
    """
    path = os.path.join(DIR if directory is None else directory, name)
    with open(path, encoding="utf-8") as f:
        return [int(tok) for tok in f.read().split()]
# ---- Model config and token fixtures ----
cfg = read_cfg()
DIM = int(cfg["dim"])
NL = int(cfg["n_layers"])
NH = int(cfg["n_heads"])
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])
ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
SEQ = len(ids)

# Fail fast with a readable message on inconsistent fixtures. Each of these is
# required by the model below (head merge reshapes [seq, nh*hd] -> [seq, dim],
# RoPE splits head_dim into two equal halves, cross_entropy pairs each row of
# logits with one target); otherwise they surface as opaque tensor errors.
if SEQ == 0:
    sys.exit("parity: ids.txt is empty")
if len(targets) != SEQ:
    sys.exit(f"parity: {len(targets)} targets for {SEQ} ids")
if NH * HD != DIM:
    sys.exit(f"parity: n_heads*head_dim = {NH * HD} != dim = {DIM}")
if HD % 2:
    sys.exit(f"parity: head_dim must be even for rotate_half RoPE, got {HD}")

# Load params as leaf tensors requiring grad (float64 for a clean reference).
P = {}
def load(name):
    """Read ``w_<name>.txt`` as a float64 leaf tensor with requires_grad=True.

    The tensor is also recorded in the module-level dict ``P`` under ``name``
    so that its ``.grad`` can be checked against ``g_<name>.txt`` after the
    backward pass.
    """
    param = read_vec(f"w_{name}.txt").clone()
    param.requires_grad_(True)
    P[name] = param
    return param
def rms_norm(x, gamma, eps=None):
    """RMSNorm over the last dim: ``x / sqrt(mean(x**2) + eps) * gamma``.

    Unlike LayerNorm there is no mean subtraction.

    Args:
        x: input tensor; normalized along its last dimension.
        gamma: per-feature scale, broadcast against ``x``.
        eps: stabilizer added inside the sqrt; defaults to the module-level
            ``EPS`` read from config.txt.
    """
    if eps is None:
        eps = EPS
    ms = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(ms + eps) * gamma
def rope(x, theta=None):
    """Apply rotary position embedding (rotate_half form) to ``x``.

    ``x`` is indexed ``[seq, ..., head_dim]`` (the model passes
    ``[seq, nh, hd]``). Row ``r`` has position ``r`` (token index), matching
    the kernel. The first and second halves of the last dim form the pairs
    ``(x0_j, x1_j)``, rotated by angle ``pos * theta ** (-2j / head_dim)``.

    Sequence length and head_dim are taken from ``x`` itself rather than the
    module globals, so any such tensor works.

    Args:
        x: tensor with at least 2 dims and an even last dim.
        theta: RoPE base; defaults to the module-level ``THETA``.

    Returns:
        A new tensor with the same shape and dtype as ``x``. It is built with
        ``torch.cat`` rather than in-place writes, so autograd flows cleanly.

    Raises:
        ValueError: if ``x`` has fewer than 2 dims or an odd last dim.
    """
    if theta is None:
        theta = THETA
    if x.dim() < 2:
        raise ValueError(f"rope: expected [seq, ..., head_dim], got {tuple(x.shape)}")
    seq_len, head_dim = x.shape[0], x.shape[-1]
    if head_dim % 2:
        raise ValueError(f"rope: head_dim must be even, got {head_dim}")
    half = head_dim // 2
    i = torch.arange(half, dtype=torch.float64, device=x.device)
    freq = theta ** (-(2.0 * i) / head_dim)  # [half]
    pos = torch.arange(seq_len, dtype=torch.float64, device=x.device).reshape(seq_len, 1)
    ang = pos * freq  # [seq, half]
    # Angle tables are computed in float64, then broadcast over every middle
    # dim (e.g. heads): [seq, 1, ..., 1, half].
    table_shape = (seq_len,) + (1,) * (x.dim() - 2) + (half,)
    c = torch.cos(ang).to(x.dtype).reshape(table_shape)
    s = torch.sin(ang).to(x.dtype).reshape(table_shape)
    x0 = x[..., :half]
    x1 = x[..., half:]
    return torch.cat((x0 * c - x1 * s, x1 * c + x0 * s), dim=-1)
# ---- Parameters (each load() also registers the tensor in P) ----
emb = load("embed")
final_norm = load("final_norm")
lm_head = load("lm_head")
_LAYER_PARAM_NAMES = ("attn_norm", "wq", "wk", "wv", "wo",
                      "ffn_norm", "w_gate", "w_up", "w_down")
layers = [{p: load(f"l{li}_{p}") for p in _LAYER_PARAM_NAMES}
          for li in range(NL)]

# ---- Forward ----
idx = torch.tensor(ids, dtype=torch.long)
# Additive causal mask: -1e9 strictly above the diagonal (key index > query index).
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
ATTN_SCALE = 1.0 / math.sqrt(HD)


def _attention(x, layer):
    """Causal multi-head self-attention of pre-normed ``x`` [seq, dim] -> [seq, dim]."""
    q = (x @ layer["wq"]).reshape(SEQ, NH, HD)
    k = (x @ layer["wk"]).reshape(SEQ, NH, HD)
    v = (x @ layer["wv"]).reshape(SEQ, NH, HD)
    # Move heads to the front for batched matmuls: [nh, seq, hd].
    q = rope(q).transpose(0, 1)
    k = rope(k).transpose(0, 1)
    v = v.transpose(0, 1)
    scores = (q @ k.transpose(-1, -2)) * ATTN_SCALE + mask  # [nh, seq, seq]
    probs = torch.softmax(scores, dim=-1)
    merged = (probs @ v).transpose(0, 1).reshape(SEQ, DIM)  # [seq, dim]
    return merged @ layer["wo"]


def _mlp(x, layer):
    """SwiGLU feed-forward: ``(silu(x @ w_gate) * (x @ w_up)) @ w_down``."""
    gate = x @ layer["w_gate"]
    up = x @ layer["w_up"]
    return (torch.nn.functional.silu(gate) * up) @ layer["w_down"]


h = emb[idx]  # [seq, dim]
for layer in layers:
    # Pre-norm residual blocks.
    h = h + _attention(rms_norm(h, layer["attn_norm"]), layer)
    h = h + _mlp(rms_norm(h, layer["ffn_norm"]), layer)
h = rms_norm(h, final_norm)
logits = h @ lm_head  # [seq, vocab]
target_ids = torch.tensor(targets, dtype=torch.long)
loss = torch.nn.functional.cross_entropy(logits, target_ids, reduction="mean")
loss.backward()
# ---- Compare ----
def relerr(a, b):
a = a.double()
b = b.double()
denom = b.abs().clamp(min=1e-6)
return ((a - b).abs() / denom).max().item()
ref_logits = read_vec("logits.txt")
ref_loss = read_vec("loss.txt").item()
torch_loss = loss.item()
loss_err = abs(torch_loss - ref_loss) / max(abs(ref_loss), 1e-6)
print(f"loss: rust={ref_loss:.6e} torch={torch_loss:.6e} "
      f"relerr={loss_err:.2e}")
le = relerr(logits.detach(), ref_logits)
print(f"logits: max relerr = {le:.2e}")
RTOL = 2e-2


def _err_key(e):
    """Map NaN to +inf so a NaN error always fails and ranks as the worst."""
    return float("inf") if math.isnan(e) else e


worst = _err_key(le)
worst_name = "logits"
fails = []
# The loss is gated too: a targets/reduction mismatch would otherwise only be
# caught indirectly through the gradients.
if _err_key(loss_err) > RTOL:
    fails.append(("loss", loss_err))
if _err_key(le) > RTOL:
    fails.append(("logits", le))
for name, t in P.items():
    if t.grad is None:
        # Never reached by backward, so it cannot match the dumped gradient.
        ge = float("inf")
    else:
        ge = _err_key(relerr(t.grad, read_vec(f"g_{name}.txt")))
    if ge > worst:
        worst, worst_name = ge, f"grad[{name}]"
    if ge > RTOL:
        fails.append((f"grad[{name}]", ge))
print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})")
if fails:
    print("FAIL:")
    for n, e in fails:
        print(f" {n}: relerr={e:.3e}")
    sys.exit(1)
print("PARITY OK: loss + forward logits + all param grads within rtol")