model: silence torch parity warning (read loss before backward)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -134,6 +134,7 @@ logits = h @ lm_head # [seq, vocab]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
logits, torch.tensor(targets, dtype=torch.long), reduction="mean")
|
||||
+loss_val = loss.item()
|
||||
loss.backward()
|
||||
|
||||
# ---- Compare ----
|
||||
@@ -147,8 +148,8 @@ def relerr(a, b):
|
||||
ref_logits = read_vec("logits.txt")
|
||||
ref_loss = read_vec("loss.txt").item()
|
||||
|
||||
-print(f"loss: rust={ref_loss:.6e} torch={loss.item():.6e} "
|
||||
-f"relerr={abs(loss.item()-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
|
||||
+print(f"loss: rust={ref_loss:.6e} torch={loss_val:.6e} "
|
||||
+f"relerr={abs(loss_val-ref_loss)/max(abs(ref_loss),1e-6):.2e}")
|
||||
le = relerr(logits.detach(), ref_logits)
|
||||
print(f"logits: max relerr = {le:.2e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user