# Tests: tiled matmul — Triton skeleton vs. pure-torch reference.
from __future__ import annotations

import pytest
import torch

from kernels.triton.tiled_matmul import triton_tiled_matmul
from reference.torch_matmul import torch_matmul
|
def _run_impl_or_skip(fn, *args):
    """Call *fn* with *args*, skipping the current test when it cannot run.

    A ``NotImplementedError`` marks a still-TODO skeleton implementation;
    a ``RuntimeError`` is treated as an environment/backend failure (e.g.
    a Triton launch problem) and its message becomes the skip reason.
    """
    try:
        result = fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))
    else:
        return result
|
@pytest.mark.reference
def test_tiled_matmul_reference_matches_torch():
    """The pure-torch reference implementation must match the ``@`` operator."""
    lhs = torch.randn(8, 16)
    rhs = torch.randn(16, 12)
    expected = lhs @ rhs
    torch.testing.assert_close(torch_matmul(lhs, rhs), expected)
|
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_tiled_matmul_if_available():
    """Compare the Triton tiled matmul against dense matmul on a CUDA device.

    Skips outright when no CUDA device is present; further skips (TODO
    implementation, runtime launch failure) are handled by the helper.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    lhs = torch.randn(32, 48, device="cuda")
    rhs = torch.randn(48, 40, device="cuda")
    result = _run_impl_or_skip(triton_tiled_matmul, lhs, rhs)
    # Loose tolerances: the kernel may accumulate in a different order/precision.
    torch.testing.assert_close(result, lhs @ rhs, atol=1e-3, rtol=1e-3)