from __future__ import annotations

import pytest
import torch

from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_row_softmax import torch_row_softmax


def _run_impl_or_skip(fn, *args):
    """Run an implementation, skipping the test if it is unfinished or fails at runtime."""
    try:
        return fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_row_softmax_reference_matches_torch():
    # The reference implementation must agree with torch.softmax on an arbitrary shape.
    x = torch.randn(8, 17)
    out = torch_row_softmax(x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1))


@pytest.mark.reference
def test_row_softmax_reference_is_numerically_stable():
    # Large-magnitude inputs should not overflow; each row must still sum to 1.
    x = torch.tensor([[1000.0, 1001.0, 1002.0], [-1000.0, -999.0, -998.0]])
    out = torch_row_softmax(x)
    torch.testing.assert_close(out.sum(dim=1), torch.ones(2), atol=1e-6, rtol=1e-6)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_row_softmax_if_available():
    # The Triton kernel is only exercised on CUDA hardware; otherwise the test is skipped.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    x = torch.randn(16, 63, device="cuda")
    out = _run_impl_or_skip(triton_row_softmax, x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1), atol=1e-4, rtol=1e-4)