from __future__ import annotations import math import torch from reference.torch_attention import torch_attention from reference.torch_matmul import torch_matmul from reference.torch_online_softmax import torch_online_softmax from reference.torch_row_softmax import torch_row_softmax from reference.torch_vector_add import torch_vector_add def test_vector_add_matches_torch(): x = torch.randn(257) y = torch.randn(257) torch.testing.assert_close(torch_vector_add(x, y), x + y) def test_row_softmax_matches_torch(): x = torch.randn(32, 65) torch.testing.assert_close(torch_row_softmax(x), torch.softmax(x, dim=1)) def test_matmul_matches_torch(): a = torch.randn(16, 24) b = torch.randn(24, 8) torch.testing.assert_close(torch_matmul(a, b), a @ b) def test_online_softmax_matches_torch(): x = torch.randn(12, 33) torch.testing.assert_close( torch_online_softmax(x), torch.softmax(x, dim=1), atol=1e-5, rtol=1e-5 ) def test_attention_matches_manual_formula(): q = torch.randn(1, 2, 8, 16) k = torch.randn(1, 2, 8, 16) v = torch.randn(1, 2, 8, 16) scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1]) expected = torch.matmul(torch.softmax(scores, dim=-1), v) torch.testing.assert_close(torch_attention(q, k, v), expected, atol=1e-5, rtol=1e-5)