Initial project scaffold
This commit is contained in:
39
tests/test_shapes.py
Normal file
39
tests/test_shapes.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from reference.torch_attention import torch_attention
|
||||
from reference.torch_matmul import torch_matmul
|
||||
from reference.torch_online_softmax import torch_online_softmax
|
||||
from reference.torch_row_softmax import torch_row_softmax
|
||||
from reference.torch_vector_add import torch_vector_add
|
||||
|
||||
|
||||
def test_vector_add_shape():
|
||||
x = torch.randn(11)
|
||||
y = torch.randn(11)
|
||||
assert torch_vector_add(x, y).shape == x.shape
|
||||
|
||||
|
||||
def test_row_softmax_shape():
|
||||
x = torch.randn(4, 9)
|
||||
assert torch_row_softmax(x).shape == x.shape
|
||||
|
||||
|
||||
def test_matmul_shape():
|
||||
a = torch.randn(5, 7)
|
||||
b = torch.randn(7, 3)
|
||||
assert torch_matmul(a, b).shape == (5, 3)
|
||||
|
||||
|
||||
def test_online_softmax_shape():
|
||||
x = torch.randn(3, 13)
|
||||
assert torch_online_softmax(x).shape == x.shape
|
||||
|
||||
|
||||
def test_attention_shape():
|
||||
q = torch.randn(2, 4, 8, 16)
|
||||
k = torch.randn(2, 4, 8, 16)
|
||||
v = torch.randn(2, 4, 8, 16)
|
||||
assert torch_attention(q, k, v).shape == q.shape
|
||||
|
||||
Reference in New Issue
Block a user