"""Tests that the reference Torch implementations match the torch built-ins."""
from __future__ import annotations
|
|
|
|
import math
|
|
|
|
import torch
|
|
|
|
from reference.torch_attention import torch_attention
|
|
from reference.torch_matmul import torch_matmul
|
|
from reference.torch_online_softmax import torch_online_softmax
|
|
from reference.torch_row_softmax import torch_row_softmax
|
|
from reference.torch_vector_add import torch_vector_add
|
|
|
|
|
|
def test_vector_add_matches_torch():
    """torch_vector_add must equal eager elementwise `x + y`.

    A seeded local Generator makes the inputs deterministic, so a failure
    is reproducible run-to-run (unseeded torch.randn is not) without
    mutating the global RNG state other tests may rely on.
    """
    gen = torch.Generator().manual_seed(0)
    # 257 — presumably chosen to be non-power-of-two to exercise a ragged
    # tail in the reference kernel; TODO confirm against the implementation.
    x = torch.randn(257, generator=gen)
    y = torch.randn(257, generator=gen)
    torch.testing.assert_close(torch_vector_add(x, y), x + y)
|
|
|
|
|
|
def test_row_softmax_matches_torch():
    """torch_row_softmax must equal torch.softmax along dim=1.

    Seeded local Generator keeps the input deterministic so failures are
    reproducible, without touching the global RNG.
    """
    gen = torch.Generator().manual_seed(0)
    # 65 columns — presumably a non-power-of-two width to stress the
    # reference row reduction; TODO confirm.
    x = torch.randn(32, 65, generator=gen)
    torch.testing.assert_close(torch_row_softmax(x), torch.softmax(x, dim=1))
|
|
|
|
|
|
def test_matmul_matches_torch():
    """torch_matmul must equal the `@` operator on a small 2-D problem.

    Seeded local Generator keeps operands deterministic so failures are
    reproducible, without touching the global RNG.
    """
    gen = torch.Generator().manual_seed(0)
    a = torch.randn(16, 24, generator=gen)
    b = torch.randn(24, 8, generator=gen)
    torch.testing.assert_close(torch_matmul(a, b), a @ b)
|
|
|
|
|
|
def test_online_softmax_matches_torch():
    """torch_online_softmax must match torch.softmax within loose tolerances.

    The online (streaming) formulation accumulates in a different order than
    torch.softmax, hence the explicit atol/rtol. With random *unseeded* input
    a sample near the tolerance boundary could flake; seeding a local
    Generator makes the test deterministic and any failure reproducible.
    """
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(12, 33, generator=gen)
    torch.testing.assert_close(
        torch_online_softmax(x), torch.softmax(x, dim=1), atol=1e-5, rtol=1e-5
    )
|
|
|
|
|
|
def test_attention_matches_manual_formula():
    """torch_attention must match softmax(QK^T / sqrt(d)) @ V.

    Shapes are (batch=1, heads=2, seq=8, head_dim=16). The expected value is
    computed with the standard scaled-dot-product formula spelled out inline.
    Seeded local Generator keeps q/k/v deterministic so any tolerance-level
    failure is reproducible, without mutating the global RNG.
    """
    gen = torch.Generator().manual_seed(0)
    q = torch.randn(1, 2, 8, 16, generator=gen)
    k = torch.randn(1, 2, 8, 16, generator=gen)
    v = torch.randn(1, 2, 8, 16, generator=gen)
    # Scale by sqrt(head_dim) before the row softmax, per the attention formula.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    expected = torch.matmul(torch.softmax(scores, dim=-1), v)
    torch.testing.assert_close(torch_attention(q, k, v), expected, atol=1e-5, rtol=1e-5)
|
|
|