from __future__ import annotations import pytest import torch from kernels.triton.vector_add import triton_vector_add from reference.torch_vector_add import torch_vector_add def _run_impl_or_skip(fn, *args): try: return fn(*args) except NotImplementedError: pytest.skip("implementation is still TODO") except RuntimeError as exc: pytest.skip(str(exc)) @pytest.mark.reference def test_vector_add_reference_matches_torch(): x = torch.randn(257) y = torch.randn(257) out = torch_vector_add(x, y) torch.testing.assert_close(out, x + y) @pytest.mark.triton_required @pytest.mark.skeleton def test_triton_vector_add_if_available(): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") x = torch.randn(513, device="cuda") y = torch.randn(513, device="cuda") out = _run_impl_or_skip(triton_vector_add, x, y) torch.testing.assert_close(out, x + y)