Files
xtrain/crates/xtrain-model/tests/parity.py
Gahow Wang 830d06ad01 gqa: real grouped-query attention (repeat_kv op + both SDPA paths + wiring + tests)
- repeat_kv CUDA kernel: fwd head-block gather, bwd DETERMINISTIC group-sum (each
  kv head sums its group of query-head grads; no atomics) + Tensor/ops node.
- Config gains num_kv_heads (default = n_heads → MHA); wk/wv project to kv_dim;
  attention() repeat_kv-broadcasts K/V to nh heads before the UNCHANGED composed
  & flash SDPA → GQA on both paths. group=1 is identity → MHA bit-identical.
- --kv-heads flag on train/train_ddp/export_safetensors/greedy_sample; export
  writes real num_key_value_heads (xserv repeat_kv grouping aligned).
- Tests: repeat_kv grad-check (group>1 grad-sum + group=1 identity); model gqa.rs
  (GQA flash==composed fp32/bf16, group=1 bit-identical to MHA, kv-proj shape);
  parity_dump+parity.py GQA path (repeat_interleave) via XTRAIN_PARITY_KV_HEADS.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 01:37:37 +08:00

199 lines
6.2 KiB
Python

#!/usr/bin/env python3
"""PyTorch parity check for the xtrain tiny transformer (Phase T5).
Loads the weights/ids dumped by tests/parity_dump.rs, rebuilds the IDENTICAL
model in PyTorch (same x@W convention, same RoPE rotate_half with
position = row % seq, same RMSNorm, SwiGLU, per-sequence causal mask, per-head
SDPA with GQA kv-head repeat), runs forward + one backward, and compares the
forward logits and every parameter's gradient against the Rust values within a
relative tolerance.
Usage: python3 parity.py /tmp/xtrain_parity
"""
import sys
import os
import math
import torch
# Directory holding the dump files: the first command-line argument when given,
# otherwise the default path shown in the usage line above.
DIR = sys.argv[1] if len(sys.argv) > 1 else "/tmp/xtrain_parity"
def read_vec(name):
    """Parse the text dump ``DIR/name`` into a float64 tensor.

    The file holds one number per non-blank line. A line starting with
    ``# shape`` carries the dimensions as a comma-separated list in its third
    whitespace-separated field; when that list is non-empty the flat values are
    reshaped to it, otherwise the result stays 1-D.
    """
    dims = None
    numbers = []
    with open(os.path.join(DIR, name)) as fh:
        for raw in fh:
            text = raw.strip()
            if not text:
                continue
            if text.startswith("# shape"):
                dims = [int(tok) for tok in text.split()[2].split(",") if tok]
            else:
                numbers.append(float(text))
    tensor = torch.tensor(numbers, dtype=torch.float64)
    # An absent header and an empty dim list both leave the tensor flat.
    return tensor.reshape(dims) if dims else tensor
def read_cfg():
    """Load ``DIR/config.txt`` as a ``{key: value}`` dict of strings.

    Every non-blank line must consist of exactly two whitespace-separated
    fields, ``<key> <value>``. Values are kept as strings; callers convert them.
    Blank lines (e.g. a trailing empty line) are ignored.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    cfg = {}
    with open(os.path.join(DIR, "config.txt")) as f:
        for lineno, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank line: nothing to record, and not worth crashing over.
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"config.txt:{lineno}: expected '<key> <value>', "
                    f"got {line.rstrip()!r}")
            key, value = fields
            cfg[key] = value
    return cfg
def read_ids(name):
    """Return the whitespace-separated integers stored in ``DIR/name`` as a list."""
    path = os.path.join(DIR, name)
    with open(path) as fh:
        tokens = fh.read().split()
    return list(map(int, tokens))
# ---- Model / batch configuration read from the dump ----
cfg = read_cfg()
DIM = int(cfg["dim"])
NL = int(cfg["n_layers"])
NH = int(cfg["n_heads"])
# GQA (T15): num_kv_heads <= n_heads; each kv head shared by group query heads.
# Default to NH (MHA) for fixtures dumped before the field existed.
NKV = int(cfg.get("num_kv_heads", str(NH)))
# GROUP is an integer division: a non-divisible head count would silently floor
# and only surface later as a shape mismatch in attention, so reject it here.
assert NKV > 0 and NH % NKV == 0, \
    f"n_heads {NH} not divisible by num_kv_heads {NKV}"
GROUP = NH // NKV
HD = int(cfg["head_dim"])
EPS = float(cfg["eps"])
THETA = float(cfg["rope_theta"])
# Batched: B sequences of length SEQ, flattened sequence-major to [B*SEQ] ids.
B = int(cfg.get("batch", "1"))
SEQ = int(cfg["seq"])
ids = read_ids("ids.txt")
targets = read_ids("targets.txt")
assert len(ids) == B * SEQ, f"ids {len(ids)} != B*SEQ {B*SEQ}"
# The loss pairs one target with each input row, so the two lists must align.
assert len(targets) == len(ids), f"targets {len(targets)} != ids {len(ids)}"
# Every parameter is loaded as a float64 leaf tensor that requires grad (float64
# for a clean reference) and is registered here under its dump name.
P = {}


def load(name):
    """Read ``w_<name>.txt``, make it a grad-requiring leaf, and record it in ``P``.

    Returns the tensor so callers can keep a direct handle to it.
    """
    param = read_vec(f"w_{name}.txt").clone()
    param.requires_grad_(True)
    P[name] = param
    return param
def rms_norm(x, gamma):
    """RMS-normalise ``x`` over its last axis, then scale by ``gamma``.

    y = x / sqrt(mean(x^2) + eps) * gamma, with no mean subtraction; ``eps`` is
    the module-level ``EPS``.
    """
    mean_sq = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + EPS)
    return x * inv_rms * gamma
def rope(x):
    """Apply rotary position embedding (rotate-half form) to ``x``.

    ``x`` is [B*SEQ, heads, hd]. The position of row r is ``r % SEQ``, so it
    resets at the start of every sequence in the flattened batch.
    """
    n_rows = B * SEQ
    half = HD // 2
    pair_idx = torch.arange(half, dtype=torch.float64)
    inv_freq = THETA ** (-(2.0 * pair_idx) / HD)  # [half]
    # Position within each sequence: rows 0..SEQ for seq 0, 0..SEQ for seq 1, ...
    positions = (torch.arange(n_rows, dtype=torch.float64) % SEQ).reshape(n_rows, 1)
    angles = positions * inv_freq  # [B*SEQ, half]
    # Singleton middle axis broadcasts the same angles over every head.
    cos = torch.cos(angles).reshape(n_rows, 1, half)
    sin = torch.sin(angles).reshape(n_rows, 1, half)
    lo, hi = x[..., :half], x[..., half:]
    return torch.cat((lo * cos - hi * sin, hi * cos + lo * sin), dim=-1)
# Parameters that are not tied to a layer.
emb = load("embed")
final_norm = load("final_norm")
lm_head = load("lm_head")

# Per-layer parameters, loaded from the dump files named l<index>_<param>.
_LAYER_PARAM_NAMES = (
    "attn_norm", "wq", "wk", "wv", "q_norm", "k_norm", "wo",
    "ffn_norm", "w_gate", "w_up", "w_down",
)
layers = [
    {pname: load(f"l{layer_idx}_{pname}") for pname in _LAYER_PARAM_NAMES}
    for layer_idx in range(NL)
]
# ---- Forward pass + one backward, mirroring the dumped model ----
idx = torch.tensor(ids, dtype=torch.long)
# Per-sequence causal mask (broadcast over the batch); NO cross-sequence attention.
# Additive form: 0 on and below the diagonal, -1e9 strictly above it.
mask = torch.triu(torch.full((SEQ, SEQ), -1.0e9, dtype=torch.float64), diagonal=1)
h = emb[idx]  # [B*SEQ, dim] (everything stays flattened, matching the Rust path)
for layer in layers:
    # Attention
    x = rms_norm(h, layer["attn_norm"])
    q = (x @ layer["wq"]).reshape(B * SEQ, NH, HD)
    # GQA: K/V project to NKV heads, then repeat each kv head GROUP times to NH.
    k = (x @ layer["wk"]).reshape(B * SEQ, NKV, HD)
    v = (x @ layer["wv"]).reshape(B * SEQ, NKV, HD)
    # Per-head QK-norm (Qwen3-style), before RoPE.
    q = rms_norm(q, layer["q_norm"])
    k = rms_norm(k, layer["k_norm"])
    q = rope(q)  # [B*SEQ, nh, hd]
    k = rope(k)  # [B*SEQ, nkv, hd]
    # Reshape to [B, *, SEQ, HD]; broadcast kv heads to NH (repeat_interleave along
    # the head axis: kv head kvh -> query heads [kvh*GROUP, (kvh+1)*GROUP), matching
    # xtrain repeat_kv + xserv repeat_kv).
    q = q.reshape(B, SEQ, NH, HD).transpose(1, 2)   # [B, nh, seq, hd]
    k = k.reshape(B, SEQ, NKV, HD).transpose(1, 2)  # [B, nkv, seq, hd]
    v = v.reshape(B, SEQ, NKV, HD).transpose(1, 2)
    if GROUP > 1:
        k = k.repeat_interleave(GROUP, dim=1)  # [B, nh, seq, hd]
        v = v.repeat_interleave(GROUP, dim=1)
    scale = 1.0 / math.sqrt(HD)
    scores = (q @ k.transpose(-1, -2)) * scale + mask  # [B, nh, seq, seq]
    probs = torch.softmax(scores, dim=-1)
    out = probs @ v  # [B, nh, seq, hd]
    # Merge heads back into one row per token. The merged width is nh*hd, which
    # equals dim only when head_dim == dim / n_heads; using nh*hd keeps the
    # reshape valid in both cases.
    out = out.transpose(1, 2).reshape(B * SEQ, NH * HD)  # [B*SEQ, nh*hd]
    attn = out @ layer["wo"]
    h = h + attn
    # MLP (SwiGLU): silu(gate) * up, then the down projection.
    x = rms_norm(h, layer["ffn_norm"])
    gate = x @ layer["w_gate"]
    up = x @ layer["w_up"]
    act = torch.nn.functional.silu(gate) * up
    mlp = act @ layer["w_down"]
    h = h + mlp
h = rms_norm(h, final_norm)
logits = h @ lm_head  # [B*SEQ, vocab]
loss = torch.nn.functional.cross_entropy(
    logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
loss_val = loss.item()
# Populates .grad on every leaf parameter registered in P.
loss.backward()
# ---- Compare ----
def relerr(a, b):
    """Return the max element-wise relative error of ``a`` against reference ``b``.

    Computed in float64 as ``max(|a - b| / max(|b|, 1e-6))``; the floor on the
    denominator keeps near-zero reference entries from blowing up the ratio.

    If any element's error is NaN (NaN in either input, or an inf/inf ratio) the
    result is ``inf`` rather than NaN: ``nan > tol`` is False, so a NaN would
    otherwise slip through a ``err > tol`` failure check unnoticed.
    """
    a = a.double()
    b = b.double()
    denom = b.abs().clamp(min=1e-6)
    err = (a - b).abs() / denom
    if torch.isnan(err).any():
        return float("inf")
    return err.max().item()
def _nan_as_inf(err):
    """Map a NaN error to +inf so the ``>`` checks below flag it as a failure."""
    return float("inf") if math.isnan(err) else err


ref_logits = read_vec("logits.txt")
ref_loss = read_vec("loss.txt").item()
# The loss is reported for information only; pass/fail is decided on the logits
# and the parameter gradients.
print(f"loss: rust={ref_loss:.6e} torch={loss_val:.6e} "
      f"relerr={abs(loss_val-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
le = _nan_as_inf(relerr(logits.detach(), ref_logits))
print(f"logits: max relerr = {le:.2e}")
RTOL = 2e-2
worst = le
worst_name = "logits"
fails = []
if le > RTOL:
    fails.append(("logits", le))
for name, t in P.items():
    if t.grad is None:
        # The parameter never received a gradient from backward(); report that
        # as a parity failure instead of crashing on a None gradient.
        ge = float("inf")
    else:
        ref_g = read_vec(f"g_{name}.txt")
        ge = _nan_as_inf(relerr(t.grad, ref_g))
    if ge > worst:
        worst, worst_name = ge, f"grad[{name}]"
    if ge > RTOL:
        fails.append((f"grad[{name}]", ge))
print(f"params checked: {len(P)} worst = {worst_name} @ {worst:.2e} (rtol={RTOL})")
if fails:
    print("FAIL:")
    for n, e in fails:
        print(f" {n}: relerr={e:.3e}")
    sys.exit(1)
print("PARITY OK: forward logits + all param grads within rtol")