from __future__ import annotations import torch from reference.torch_attention import torch_attention from reference.torch_matmul import torch_matmul from reference.torch_online_softmax import torch_online_softmax from reference.torch_row_softmax import torch_row_softmax from reference.torch_vector_add import torch_vector_add def test_vector_add_shape(): x = torch.randn(11) y = torch.randn(11) assert torch_vector_add(x, y).shape == x.shape def test_row_softmax_shape(): x = torch.randn(4, 9) assert torch_row_softmax(x).shape == x.shape def test_matmul_shape(): a = torch.randn(5, 7) b = torch.randn(7, 3) assert torch_matmul(a, b).shape == (5, 3) def test_online_softmax_shape(): x = torch.randn(3, 13) assert torch_online_softmax(x).shape == x.shape def test_attention_shape(): q = torch.randn(2, 4, 8, 16) k = torch.randn(2, 4, 8, 16) v = torch.randn(2, 4, 8, 16) assert torch_attention(q, k, v).shape == q.shape