- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each kv head sums its group of query-head grads; no atomics) + Tensor/ops node. - Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim; attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical. - --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export writes real num_key_value_heads (xserv repeat_kv grouping aligned). - Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape); parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
199 lines
6.2 KiB
Python
199 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
"""PyTorch parity check for the xtrain tiny transformer (Phase T5).
|
|
|
|
Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL
|
|
model in PyTorch (same x@W convention, same RoPE rotate_half + position=row,
|
|
same RMSNorm, SwiGLU, causal mask, per-head SDPA), runs forward + one backward,
|
|
and compares the forward logits and every parameter's gradient against the Rust
|
|
values within a relative tolerance.
|
|
|
|
Usage: python3 parity.py /tmp/xtrain_parity
|
|
"""
|
|
import sys
|
|
import os
|
|
import math
|
|
import torch
|
|
|
|
# Directory holding the dumped fixtures; overridable through the first CLI argument.
DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/xtrain_parity"


def read_vec(name):
    """Read a dumped tensor from ``DIR/name`` as a float64 torch tensor.

    The file is parsed line by line: a line starting with ``# shape`` carries
    the comma-separated dimensions as its third whitespace-separated field;
    every other non-empty line is one float value.

    Returns the values reshaped to the declared shape. When no shape header is
    present, or the header lists no dimensions (a scalar), the flat 1-D tensor
    of parsed values is returned unchanged.
    """
    path = os.path.join(DIR, name)
    shape = None
    vals = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line.startswith("# shape"):
                # strip() drops a trailing space, so a header with no dims
                # ("# shape ") has only two fields; treat that as an empty
                # dimension list instead of indexing past the end.
                fields = line.split()
                dims = fields[2] if len(fields) > 2 else ""
                shape = [int(x) for x in dims.split(",") if x]
            elif line:
                vals.append(float(line))
    t = torch.tensor(vals, dtype=torch.float64)
    if shape:
        t = t.reshape(shape)
    return t
|
|
|
|
|
|
def read_cfg():
    """Parse ``DIR/config.txt`` into a dict of string keys to string values.

    Each non-blank line must hold exactly two whitespace-separated fields,
    ``key value``; callers convert the values to int/float themselves.
    Blank lines (e.g. a trailing empty line) are ignored. A line with any
    other number of fields raises ValueError from the tuple unpacking.
    """
    cfg = {}
    with open(os.path.join(DIR, "config.txt")) as f:
        for line in f:
            # A blank line would otherwise fail the two-field unpack below.
            if not line.strip():
                continue
            k, v = line.split()
            cfg[k] = v
    return cfg
|
|
|
|
|
|
def read_ids(name):
    """Return the whitespace-separated integers stored in ``DIR/name`` as a list."""
    ids_path = os.path.join(DIR, name)
    with open(ids_path) as handle:
        tokens = handle.read().split()
    return list(map(int, tokens))
|
|
|
|
|
|
# Model/config constants parsed from config.txt (all values arrive as strings).
cfg = read_cfg()
DIM = int(cfg["dim"])
NL = int(cfg["n_layers"])
NH = int(cfg["n_heads"])
# GQA (T15): num_kv_heads <= n_heads; each kv head shared by group query heads.
# Default to NH (MHA) for fixtures dumped before the field existed.
NKV = int(cfg.get("num_kv_heads", str(NH)))
# Integer division below would silently floor a non-divisible head count and
# produce a wrong grouping, so reject it up front.
assert NKV > 0 and NH % NKV == 0, \
    f"n_heads {NH} not divisible by num_kv_heads {NKV}"
GROUP = NH // NKV
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])
# Batched: B sequences of length SEQ, flattened sequence-major to [B*SEQ] ids.
B = int(cfg.get("batch", "1"))
SEQ = int(cfg["seq"])

ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
assert len(ids) == B * SEQ, f"ids {len(ids)} != B*SEQ {B*SEQ}"
# The loss pairs one target with each input row.
assert len(targets) == len(ids), f"targets {len(targets)} != ids {len(ids)}"
|
|
|
|
# Registry of every loaded parameter, keyed by its short name. Parameters are
# float64 leaf tensors with gradients enabled (float64 for a clean reference).
P = {}


def load(name):
    """Load ``w_<name>.txt`` as a grad-enabled leaf tensor and record it in ``P``.

    Returns the tensor so the caller can bind it to a local name as well.
    """
    param = read_vec(f"w_{name}.txt").clone()
    param.requires_grad_(True)
    P[name] = param
    return param
|
|
|
|
|
|
def rms_norm(x, gamma):
    """RMS-normalise ``x`` over its last dimension and scale by ``gamma``.

    Computes x * rsqrt(mean(x^2) + EPS) * gamma; there is no mean subtraction.
    """
    mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + EPS)
    return x * inv_rms * gamma
|
|
|
|
|
|
def rope(x):  # x: [B*SEQ, nh, hd], position = (row % SEQ) — resets per sequence
    """Apply rotary position embedding to ``x`` using the split-half layout.

    The first HD//2 channels are paired with the last HD//2 channels; each
    pair is rotated by angle ``pos * THETA ** (-2i / HD)``, where ``pos`` is
    the row index modulo SEQ so positions restart for every sequence.
    """
    half = HD // 2
    n_rows = B * SEQ
    pair_idx = torch.arange(half, dtype=torch.float64)
    inv_freq = THETA ** (-(2.0 * pair_idx) / HD)  # [half]
    # Position within each sequence: 0..SEQ-1, repeated for every sequence.
    positions = (torch.arange(n_rows, dtype=torch.float64) % SEQ).reshape(n_rows, 1)
    angles = positions * inv_freq  # [B*SEQ, half]
    # Singleton head axis so the same angles broadcast over all heads.
    cos = torch.cos(angles).reshape(n_rows, 1, half)
    sin = torch.sin(angles).reshape(n_rows, 1, half)
    lo, hi = x[..., :half], x[..., half:]
    rotated = torch.empty_like(x)
    rotated[..., :half] = lo * cos - hi * sin
    rotated[..., half:] = hi * cos + lo * sin
    return rotated
|
|
|
|
|
|
# Model-level parameters.
emb = load("embed")
final_norm = load("final_norm")
lm_head = load("lm_head")

# Per-layer parameters, loaded from w_l<layer>_<param>.txt in this order.
_LAYER_PARAM_NAMES = (
    "attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo",
    "ffn_norm", "w_gate", "w_up", "w_down",
)
layers = []
for layer_idx in range(NL):
    layer_params = {}
    for pname in _LAYER_PARAM_NAMES:
        layer_params[pname] = load(f"l{layer_idx}_{pname}")
    layers.append(layer_params)
|
|
|
|
# ---- Forward pass, loss and backward ----
idx = torch.tensor(ids, dtype=torch.long)
# Per-sequence causal mask (broadcast over the batch); NO cross-sequence attention.
# Strictly-upper-triangular entries get a large negative additive bias.
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
# Attention score scaling is the same for every layer.
scale = 1.0 / math.sqrt(HD)

h = emb[idx]  # [B*SEQ, dim] (everything stays flattened, matching the Rust path)
for layer in layers:
    # Attention
    x = rms_norm(h, layer["attn_norm"])
    q = (x @ layer["wq"]).reshape(B * SEQ, NH, HD)
    # GQA: K/V project to NKV heads, then repeat each kv head GROUP times to NH.
    k = (x @ layer["wk"]).reshape(B * SEQ, NKV, HD)
    v = (x @ layer["wv"]).reshape(B * SEQ, NKV, HD)
    # Per-head QK-norm (Qwen3-style), before RoPE.
    q = rms_norm(q, layer["q_norm"])
    k = rms_norm(k, layer["k_norm"])
    q = rope(q)  # [B*SEQ, nh, hd]
    k = rope(k)  # [B*SEQ, nkv, hd]
    # Reshape to [B, *, SEQ, HD]; broadcast kv heads to NH (repeat_interleave along
    # the head axis: kv head kvh → query heads [kvh*GROUP, (kvh+1)*GROUP), matching
    # xtrain repeat_kv + xserv repeat_kv).
    q = q.reshape(B, SEQ, NH, HD).transpose(1, 2)  # [B, nh, seq, hd]
    k = k.reshape(B, SEQ, NKV, HD).transpose(1, 2)  # [B, nkv, seq, hd]
    v = v.reshape(B, SEQ, NKV, HD).transpose(1, 2)
    if GROUP > 1:
        k = k.repeat_interleave(GROUP, dim=1)  # [B, nh, seq, hd]
        v = v.repeat_interleave(GROUP, dim=1)
    scores = (q @ k.transpose(-1, -2)) * scale + mask  # [B, nh, seq, seq]
    probs = torch.softmax(scores, dim=-1)
    out = probs @ v  # [B, nh, seq, hd]
    # Merge heads back to one row per token. The merged width is NH*HD, which
    # equals DIM only when head_dim == dim / n_heads; using NH*HD keeps the
    # reshape valid when head_dim is configured independently.
    out = out.transpose(1, 2).reshape(B * SEQ, NH * HD)  # [B*SEQ, nh*hd]
    attn = out @ layer["wo"]
    h = h + attn
    # MLP (SwiGLU): silu(gate) * up, then down-projection, with residual.
    x = rms_norm(h, layer["ffn_norm"])
    gate = x @ layer["w_gate"]
    up = x @ layer["w_up"]
    act = torch.nn.functional.silu(gate) * up
    mlp = act @ layer["w_down"]
    h = h + mlp

h = rms_norm(h, final_norm)
logits = h @ lm_head  # [B*SEQ, vocab]

# Mean cross-entropy over all B*SEQ rows; backward fills .grad on every tensor in P.
loss = torch.nn.functional.cross_entropy(
    logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
loss_val = loss.item()
loss.backward()
|
|
|
|
# ---- Compare ----
|
|
def relerr(a, b):
    """Return the maximum element-wise relative error of ``a`` against ``b``.

    Both tensors are compared in float64. The error is |a - b| / max(|b|, 1e-6),
    so reference values near zero are measured against an absolute floor
    instead of dividing by (almost) zero.

    If any element of the error is NaN (a NaN in either input, or inf - inf),
    ``float("inf")`` is returned. A NaN result would compare False against any
    tolerance and let a broken tensor pass the check silently.
    """
    a = a.double()
    b = b.double()
    denom = b.abs().clamp(min=1e-6)
    rel = (a - b).abs() / denom
    if torch.isnan(rel).any():
        return float("inf")
    return rel.max().item()
|
|
|
|
|
|
def _nan_as_inf(err):
    """Map a NaN error to +inf so threshold and ordering checks treat it as a failure."""
    return float("inf") if math.isnan(err) else err


# Reference values produced by the Rust side.
ref_logits = read_vec("logits.txt")
ref_loss = read_vec("loss.txt").item()

# The loss is reported for information only; pass/fail is decided on logits + grads.
print(f"loss: rust={ref_loss:.6e} torch={loss_val:.6e} "
      f"relerr={abs(loss_val-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
le = _nan_as_inf(relerr(logits.detach(), ref_logits))
print(f"logits: max relerr = {le:.2e}")

RTOL = 2e-2
# Track the single worst offender and every check that exceeds RTOL.
worst = le
worst_name = "logits"
fails = []
if le > RTOL:
    fails.append(("logits", le))

for name, t in P.items():
    ref_g = read_vec(f"g_{name}.txt")
    # NaN compares False against everything, so normalise it to +inf first;
    # otherwise a NaN gradient would neither fail nor register as the worst.
    ge = _nan_as_inf(relerr(t.grad, ref_g))
    if ge > worst:
        worst, worst_name = ge, f"grad[{name}]"
    if ge > RTOL:
        fails.append((f"grad[{name}]", ge))

print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})")
if fails:
    print("FAIL:")
    for label, err in fails:
        print(f" {label}: relerr={err:.3e}")
    # Non-zero exit status signals the parity failure to the caller.
    sys.exit(1)
print("PARITY OK: forward logits + all param grads within rtol")
|