from __future__ import annotations import pytest import torch from kernels.triton.tiled_matmul import triton_tiled_matmul from reference.torch_matmul import torch_matmul def _run_impl_or_skip(fn, *args): try: return fn(*args) except NotImplementedError: pytest.skip("implementation is still TODO") except RuntimeError as exc: pytest.skip(str(exc)) @pytest.mark.reference def test_tiled_matmul_reference_matches_torch(): a = torch.randn(8, 16) b = torch.randn(16, 12) out = torch_matmul(a, b) torch.testing.assert_close(out, a @ b) @pytest.mark.triton_required @pytest.mark.skeleton def test_triton_tiled_matmul_if_available(): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") a = torch.randn(32, 48, device="cuda") b = torch.randn(48, 40, device="cuda") out = _run_impl_or_skip(triton_tiled_matmul, a, b) torch.testing.assert_close(out, a @ b, atol=1e-3, rtol=1e-3)