model: silence torch parity warning (read loss before backward)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-15 16:09:30 +08:00
parent 3366f30c4d
commit 603c85e1e0

View File

@@ -134,6 +134,7 @@ logits = h @ lm_head # [seq, vocab]
loss = torch.nn.functional.cross_entropy(
logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
loss_val = loss.item()
loss.backward()
# ---- Compare ----
@@ -147,8 +148,8 @@ def relerr(a, b):
ref_logits = read_vec("logits.txt")
ref_loss = read_vec("loss.txt").item()
print(f"loss: rust={ref_loss:.6e} torch={loss.item():.6e} "
f"relerr={abs(loss.item()-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
print(f"loss: rust={ref_loss:.6e} torch={loss_val:.6e} "
f"relerr={abs(loss_val-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
le = relerr(logits.detach(), ref_logits)
print(f"logits: max relerr = {le:.2e}")