parity_dump.rs (#[ignore] fixture generator) dumps the model's exact weights, ids, forward logits, loss, and per-param grads after one backward. parity.py rebuilds the IDENTICAL model in PyTorch (same x@W convention, RoPE rotate_half pos=row, RMSNorm, SwiGLU, causal SDPA), runs fwd+bwd, and compares logits + every grad within rtol. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
177 lines
4.9 KiB
Python
177 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
"""PyTorch parity check for the xtrain tiny transformer (Phase T5).
|
|
|
|
Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL
|
|
model in PyTorch (same x@W convention, same RoPE rotate_half + position=row,
|
|
same RMSNorm, SwiGLU, causal mask, per-head SDPA), runs forward + one backward,
|
|
and compares the forward logits and every parameter's gradient against the Rust
|
|
values within a relative tolerance.
|
|
|
|
Usage: python3 parity.py /tmp/xtrain_parity
|
|
"""
|
|
import sys
|
|
import os
|
|
import math
|
|
import torch
|
|
|
|
# Directory holding the Rust-dumped fixtures; the first CLI argument overrides the default.
DIR = (sys.argv[1:] or ["/tmp/xtrain_parity"])[0]
|
|
|
|
|
|
def read_vec(name, directory=None):
    """Read one tensor dumped by tests/parity_dump.rs as float64.

    The file holds one float per line, optionally preceded by a
    ``# shape d0,d1,...`` header. When the header lists dims, the values are
    reshaped to them; otherwise a 1-D tensor is returned. Blank lines and
    other ``#`` comment lines are ignored.

    Args:
        name: file name inside the fixture directory.
        directory: fixture directory; defaults to the module-level ``DIR``.

    Returns:
        A float64 tensor.

    Raises:
        ValueError: if the header's dims do not match the number of values.
    """
    base = DIR if directory is None else directory
    path = os.path.join(base, name)
    shape = None
    vals = []
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith("# shape"):
                # Header tokens are "#", "shape", "d0,d1,...". A bare
                # "# shape" (no dims) yields [] instead of an IndexError.
                tokens = line.split()[2:]
                dims = tokens[0] if tokens else ""
                shape = [int(d) for d in dims.split(",") if d]
            elif line.startswith("#"):
                # Any other comment line is metadata, not a value.
                continue
            else:
                vals.append(float(line))
    t = torch.tensor(vals, dtype=torch.float64)
    if shape:
        expected = math.prod(shape)
        if expected != len(vals):
            raise ValueError(
                f"{path}: shape {shape} needs {expected} values, found {len(vals)}")
        t = t.reshape(shape)
    return t
|
|
|
|
|
|
def read_cfg(directory=None):
    """Parse ``config.txt`` into a dict mapping keys to raw string values.

    Each meaningful line is ``key value``. Blank lines and ``#`` comments are
    skipped, so a trailing empty line does not break the key/value unpacking.

    Args:
        directory: fixture directory; defaults to the module-level ``DIR``.

    Raises:
        ValueError: on a line that is not exactly one key and one value.
    """
    base = DIR if directory is None else directory
    path = os.path.join(base, "config.txt")
    cfg = {}
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected 'key value', got {line!r}")
            key, value = parts
            cfg[key] = value
    return cfg
|
|
|
|
|
|
def read_ids(name, directory=None):
    """Return the whitespace-separated integer token ids stored in ``name``.

    Args:
        name: file name inside the fixture directory.
        directory: fixture directory; defaults to the module-level ``DIR``.
    """
    base = DIR if directory is None else directory
    with open(os.path.join(base, name), encoding="utf-8") as f:
        return [int(tok) for tok in f.read().split()]
|
|
|
|
|
|
cfg = read_cfg()
DIM = int(cfg["dim"])
NL = int(cfg["n_layers"])
NH = int(cfg["n_heads"])
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])

ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
SEQ = len(ids)

# Fail fast with a readable message on fixtures the model below cannot
# consume. Otherwise these surface as opaque reshape / cross_entropy errors,
# or (for an empty sequence) as a NaN mean loss.
if SEQ == 0:
    raise ValueError("ids.txt is empty: nothing to run")
if len(targets) != SEQ:
    raise ValueError(
        f"targets.txt has {len(targets)} ids, expected {SEQ} (one per input id)")
if NH * HD != DIM:
    # The attention output [seq, nh, hd] is reshaped back to [seq, dim].
    raise ValueError(f"n_heads * head_dim = {NH * HD} != dim = {DIM}")
if HD % 2:
    # rotate_half RoPE splits each head into two equal halves.
    raise ValueError(f"head_dim must be even for RoPE, got {HD}")
|
|
|
|
# Registry of every model parameter by name. Each entry is a float64 leaf
# tensor with requires_grad=True, so backward() fills its .grad for comparison.
P = {}


def load(name):
    """Load ``w_<name>.txt`` as a fresh grad-tracking leaf, register it in P, return it."""
    param = read_vec(f"w_{name}.txt").clone()
    param.requires_grad_(True)
    P[name] = param
    return param
|
|
|
|
|
|
def rms_norm(x, gamma):
    """RMSNorm over the last axis: ``x / sqrt(mean(x**2) + EPS) * gamma``.

    There is no mean subtraction and no bias. ``EPS`` is the module-level
    value read from config.txt.
    """
    mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + EPS)
    return x * inv_rms * gamma
|
|
|
|
|
|
def rope(x, theta=None):
    """Apply rotate_half rotary position embedding to ``x`` of shape [seq, nh, hd].

    Row ``p`` uses position ``p`` (the token index), matching the kernel. The
    first and second halves of the last axis form the rotation pairs:
    ``(x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin)`` with angle
    ``p * theta**(-2i/hd)`` for pair ``i``.

    Sequence length and head dim are read from ``x`` itself, so any
    [seq, nh, hd] input works, not only the global SEQ/HD sizes.

    Args:
        x: tensor of shape [seq, nh, hd] with an even hd.
        theta: RoPE base; defaults to the module-level THETA from config.txt.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if the head dim is odd.
    """
    if theta is None:
        theta = THETA
    seq, head_dim = x.shape[0], x.shape[-1]
    if head_dim % 2:
        raise ValueError(f"rope needs an even head dim, got {head_dim}")
    half = head_dim // 2
    # Angles are built in float64, then cast to x's dtype so the output dtype
    # matches the input.
    i = torch.arange(half, dtype=torch.float64, device=x.device)
    freq = theta ** (-(2.0 * i) / head_dim)  # [half]
    pos = torch.arange(seq, dtype=torch.float64, device=x.device).reshape(seq, 1)
    ang = pos * freq  # [seq, half]
    c = torch.cos(ang).reshape(seq, 1, half).to(x.dtype)
    s = torch.sin(ang).reshape(seq, 1, half).to(x.dtype)
    x0, x1 = x[..., :half], x[..., half:]
    # An out-of-place concat (rather than writes into an empty_like buffer)
    # keeps the autograd graph free of in-place slice copies.
    return torch.cat((x0 * c - x1 * s, x1 * c + x0 * s), dim=-1)
|
|
|
|
|
|
# Parameter names each transformer layer owns, in the order they are loaded
# (this order also fixes the iteration order of P in the report).
_LAYER_PARAM_NAMES = ("attn_norm", "wq", "wk", "wv", "wo",
                      "ffn_norm", "w_gate", "w_up", "w_down")

emb = load("embed")
final_norm = load("final_norm")
lm_head = load("lm_head")

layers = []
for layer_idx in range(NL):
    layer_params = {}
    for pname in _LAYER_PARAM_NAMES:
        layer_params[pname] = load(f"l{layer_idx}_{pname}")
    layers.append(layer_params)
|
|
|
|
F = torch.nn.functional

idx = torch.tensor(ids, dtype=torch.long)
# Additive causal mask: -1e9 strictly above the diagonal (future positions),
# 0 on and below it.
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
# 1/sqrt(head_dim) attention scaling; the same for every layer.
scale = 1.0 / math.sqrt(HD)


def _attention(x, layer):
    """Causal multi-head self-attention of normed x [seq, dim] -> [seq, dim]."""
    q = (x @ layer["wq"]).reshape(SEQ, NH, HD)
    k = (x @ layer["wk"]).reshape(SEQ, NH, HD)
    v = (x @ layer["wv"]).reshape(SEQ, NH, HD)
    # Rotate q/k by position, then move heads to the front: [nh, seq, hd].
    q = rope(q).transpose(0, 1)
    k = rope(k).transpose(0, 1)
    v = v.transpose(0, 1)
    scores = (q @ k.transpose(-1, -2)) * scale + mask  # [nh, seq, seq]
    ctx = torch.softmax(scores, dim=-1) @ v  # [nh, seq, hd]
    merged = ctx.transpose(0, 1).reshape(SEQ, DIM)  # [seq, dim]
    return merged @ layer["wo"]


def _swiglu(x, layer):
    """SwiGLU MLP: (silu(x @ w_gate) * (x @ w_up)) @ w_down."""
    gate = x @ layer["w_gate"]
    up = x @ layer["w_up"]
    return (F.silu(gate) * up) @ layer["w_down"]


h = emb[idx]  # [seq, dim]
for layer in layers:
    # Pre-norm residual sub-blocks: attention, then MLP.
    h = h + _attention(rms_norm(h, layer["attn_norm"]), layer)
    h = h + _swiglu(rms_norm(h, layer["ffn_norm"]), layer)

logits = rms_norm(h, final_norm) @ lm_head  # [seq, vocab]

loss = F.cross_entropy(
    logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
loss.backward()
|
|
|
|
# ---- Compare ----
|
|
def relerr(a, b):
    """Return the max elementwise relative error ``|a - b| / max(|b|, 1e-6)``.

    ``b`` is the reference; both are compared in float64. A reference dumped
    without a shape header (0-D/1-D) is viewed in ``a``'s shape when the
    element counts agree, instead of being silently broadcast.

    Returns ``inf`` when any error is non-finite (NaN/inf in either input).
    ``Tensor.max`` propagates NaN, and a NaN result compares False against
    any tolerance, which would let a broken value pass. Returns 0.0 for
    empty tensors.

    Raises:
        ValueError: if the element counts differ, or both tensors are
            multi-dimensional with different shapes.
    """
    a = a.double()
    b = b.double()
    if a.shape != b.shape:
        if a.numel() != b.numel():
            raise ValueError(
                f"relerr: {a.numel()} vs {b.numel()} elements "
                f"(shapes {tuple(a.shape)} vs {tuple(b.shape)})")
        if b.dim() > 1:
            raise ValueError(
                f"relerr: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
        b = b.reshape(a.shape)
    if a.numel() == 0:
        return 0.0
    denom = b.abs().clamp(min=1e-6)
    err = ((a - b).abs() / denom).max().item()
    return err if math.isfinite(err) else float("inf")
|
|
|
|
|
|
ref_logits = read_vec("logits.txt")
ref_loss = read_vec("loss.txt").item()

RTOL = 2e-2


def _exceeds(err):
    """Return True when ``err`` is above RTOL or is NaN.

    A bare ``err > RTOL`` is False for NaN, so a NaN loss, logit or gradient
    would otherwise be reported as passing.
    """
    return not err <= RTOL


def _rank(err):
    """Sort key for choosing the worst error: NaN ranks above everything."""
    return math.inf if math.isnan(err) else err


loss_err = abs(loss.item() - ref_loss) / max(abs(ref_loss), 1e-6)
print(f"loss: rust={ref_loss:.6e} torch={loss.item():.6e} "
      f"relerr={loss_err:.2e}")

le = relerr(logits.detach(), ref_logits)
print(f"logits: max relerr = {le:.2e}")

# (name, error) for every checked quantity. The loss is now gated on RTOL
# too, not just printed.
results = [("loss", loss_err), ("logits", le)]
for name, t in P.items():
    # A parameter the loss never reached has .grad None; compare it as an
    # all-zero gradient instead of crashing inside relerr.
    grad = t.grad if t.grad is not None else torch.zeros_like(t)
    results.append((f"grad[{name}]", relerr(grad, read_vec(f"g_{name}.txt"))))

worst_name, worst = max(results, key=lambda item: _rank(item[1]))
fails = [(n, e) for n, e in results if _exceeds(e)]

print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})")
if fails:
    print("FAIL:")
    for n, e in fails:
        print(f" {n}: relerr={e:.3e}")
    sys.exit(1)
print("PARITY OK: loss + forward logits + all param grads within rtol")
|