Initial project scaffold
tests/test_correctness.py (Normal file, 45 lines added)
@@ -0,0 +1,45 @@
from __future__ import annotations

import math

import torch

from reference.torch_attention import torch_attention
from reference.torch_matmul import torch_matmul
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from reference.torch_vector_add import torch_vector_add


def test_vector_add_matches_torch():
    x = torch.randn(257)
    y = torch.randn(257)
    torch.testing.assert_close(torch_vector_add(x, y), x + y)


def test_row_softmax_matches_torch():
    x = torch.randn(32, 65)
    torch.testing.assert_close(torch_row_softmax(x), torch.softmax(x, dim=1))


def test_matmul_matches_torch():
    a = torch.randn(16, 24)
    b = torch.randn(24, 8)
    torch.testing.assert_close(torch_matmul(a, b), a @ b)


def test_online_softmax_matches_torch():
    x = torch.randn(12, 33)
    torch.testing.assert_close(
        torch_online_softmax(x), torch.softmax(x, dim=1), atol=1e-5, rtol=1e-5
    )


def test_attention_matches_manual_formula():
    q = torch.randn(1, 2, 8, 16)
    k = torch.randn(1, 2, 8, 16)
    v = torch.randn(1, 2, 8, 16)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    expected = torch.matmul(torch.softmax(scores, dim=-1), v)
    torch.testing.assert_close(torch_attention(q, k, v), expected, atol=1e-5, rtol=1e-5)
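The last test spells out the expected attention math inline, so a minimal sketch of a torch_attention consistent with that formula looks like this (an assumption for orientation only; reference/torch_attention.py itself is not part of this diff):

import math

import torch


def attention_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Sketch only: scaled dot-product attention over the last two dims,
    # matching the test's manual formula softmax(q @ k^T / sqrt(d)) @ v.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    return torch.matmul(torch.softmax(scores, dim=-1), v)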
tests/test_extension_import.py (Normal file, 18 lines added)
@@ -0,0 +1,18 @@
from __future__ import annotations

import pytest
import torch

from tools.lab_extension import build_extension


@pytest.mark.cuda_required
@pytest.mark.skeleton
def test_extension_can_build_or_skip():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    ext = build_extension(verbose=False)
    if ext is None:
        pytest.skip("extension build/load is unavailable in this environment")
    assert hasattr(torch.ops, "kernel_lab")
    assert hasattr(torch.ops.kernel_lab, "vector_add")
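The test only requires that build_extension return either a loaded extension or None. A plausible sketch of such a helper, assuming it wraps torch.utils.cpp_extension.load (the real tools/lab_extension.py and its source paths are not shown in this commit):

from __future__ import annotations

import torch
from torch.utils import cpp_extension


def build_extension(verbose: bool = False):
    # Sketch only: JIT-build the CUDA extension, returning None when the
    # device or toolchain is unavailable so callers can skip gracefully.
    if not torch.cuda.is_available():
        return None
    try:
        # A TORCH_LIBRARY(kernel_lab, ...) block in the C++/CUDA sources would
        # register the ops the test probes via torch.ops.kernel_lab.vector_add.
        return cpp_extension.load(
            name="kernel_lab",
            sources=["csrc/kernel_lab.cu"],  # hypothetical source path
            verbose=verbose,
        )
    except (RuntimeError, OSError):
        return None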
tests/test_numerics.py (Normal file, 28 lines added)
@@ -0,0 +1,28 @@
from __future__ import annotations

import torch

from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax


def test_row_softmax_handles_large_values():
    x = torch.tensor([[10000.0, 10001.0, 9999.0]], dtype=torch.float32)
    out = torch_row_softmax(x)
    torch.testing.assert_close(out.sum(dim=1), torch.ones(1), atol=1e-6, rtol=1e-6)
    assert torch.isfinite(out).all()


def test_online_softmax_handles_large_negative_values():
    x = torch.tensor([[-10000.0, -9998.0, -9999.0]], dtype=torch.float32)
    out = torch_online_softmax(x)
    torch.testing.assert_close(out.sum(dim=1), torch.ones(1), atol=1e-6, rtol=1e-6)
    assert torch.isfinite(out).all()


def test_row_and_online_softmax_agree():
    x = torch.randn(10, 40) * 8.0
    torch.testing.assert_close(
        torch_row_softmax(x), torch_online_softmax(x), atol=1e-5, rtol=1e-5
    )
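These tests pin down the property that makes an online softmax worth having: a running max plus a rescaled running sum means exp() never sees a large positive argument, so inputs near 10000.0 stay finite. A minimal sketch of that one-pass recurrence (illustrative names; not the actual reference/torch_online_softmax.py):

import torch


def online_softmax_sketch(x: torch.Tensor) -> torch.Tensor:
    rows, cols = x.shape
    out = torch.empty_like(x)
    for r in range(rows):
        m = torch.tensor(float("-inf"))  # running max seen so far
        s = torch.tensor(0.0)            # running sum of exp(x - m)
        for c in range(cols):
            m_new = torch.maximum(m, x[r, c])
            # Rescale the old sum to the new max before adding the new term.
            s = s * torch.exp(m - m_new) + torch.exp(x[r, c] - m_new)
            m = m_new
        out[r] = torch.exp(x[r] - m) / s
    return out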
tests/test_shapes.py (Normal file, 39 lines added)
@@ -0,0 +1,39 @@
from __future__ import annotations

import torch

from reference.torch_attention import torch_attention
from reference.torch_matmul import torch_matmul
from reference.torch_online_softmax import torch_online_softmax
from reference.torch_row_softmax import torch_row_softmax
from reference.torch_vector_add import torch_vector_add


def test_vector_add_shape():
    x = torch.randn(11)
    y = torch.randn(11)
    assert torch_vector_add(x, y).shape == x.shape


def test_row_softmax_shape():
    x = torch.randn(4, 9)
    assert torch_row_softmax(x).shape == x.shape


def test_matmul_shape():
    a = torch.randn(5, 7)
    b = torch.randn(7, 3)
    assert torch_matmul(a, b).shape == (5, 3)


def test_online_softmax_shape():
    x = torch.randn(3, 13)
    assert torch_online_softmax(x).shape == x.shape


def test_attention_shape():
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    assert torch_attention(q, k, v).shape == q.shape
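Shape preservation for the softmaxes follows from row-wise operations alone. For completeness, here is a sketch of a numerically stable row softmax that would satisfy both the shape tests above and the large-value numerics tests (assumed; the reference module is not included in this diff):

import torch


def row_softmax_sketch(x: torch.Tensor) -> torch.Tensor:
    # Subtracting the per-row max keeps exp() from overflowing at inputs
    # like 10001.0; the output keeps x's shape and each row sums to 1.
    shifted = x - x.amax(dim=1, keepdim=True)
    e = shifted.exp()
    return e / e.sum(dim=1, keepdim=True)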