Initial project scaffold
This commit is contained in:
28
tests/test_numerics.py
Normal file
28
tests/test_numerics.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from reference.torch_online_softmax import torch_online_softmax
|
||||
from reference.torch_row_softmax import torch_row_softmax
|
||||
|
||||
|
||||
def test_row_softmax_handles_large_values():
|
||||
x = torch.tensor([[10000.0, 10001.0, 9999.0]], dtype=torch.float32)
|
||||
out = torch_row_softmax(x)
|
||||
torch.testing.assert_close(out.sum(dim=1), torch.ones(1), atol=1e-6, rtol=1e-6)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
def test_online_softmax_handles_large_negative_values():
|
||||
x = torch.tensor([[-10000.0, -9998.0, -9999.0]], dtype=torch.float32)
|
||||
out = torch_online_softmax(x)
|
||||
torch.testing.assert_close(out.sum(dim=1), torch.ones(1), atol=1e-6, rtol=1e-6)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
def test_row_and_online_softmax_agree():
|
||||
x = torch.randn(10, 40) * 8.0
|
||||
torch.testing.assert_close(
|
||||
torch_row_softmax(x), torch_online_softmax(x), atol=1e-5, rtol=1e-5
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user