Files
kernel-lab/tasks/02_row_softmax/test_task.py
2026-04-10 13:22:19 +00:00

41 lines
1.2 KiB
Python

from __future__ import annotations
import pytest
import torch
from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_row_softmax import torch_row_softmax
def _run_impl_or_skip(fn, *args):
try:
return fn(*args)
except NotImplementedError:
pytest.skip("implementation is still TODO")
except RuntimeError as exc:
pytest.skip(str(exc))
@pytest.mark.reference
def test_row_softmax_reference_matches_torch():
    """Reference implementation must agree with torch.softmax row-wise."""
    inputs = torch.randn(8, 17)
    expected = torch.softmax(inputs, dim=1)
    actual = torch_row_softmax(inputs)
    torch.testing.assert_close(actual, expected)
@pytest.mark.reference
def test_row_softmax_reference_is_numerically_stable():
    """Extreme inputs must not overflow/underflow: each row still sums to 1.

    A naive exp/sum softmax would produce NaN/inf on magnitudes near 1000;
    assert_close rejects NaN, so this catches the unstable formulation.
    """
    extreme = torch.tensor([[1000.0, 1001.0, 1002.0], [-1000.0, -999.0, -998.0]])
    row_sums = torch_row_softmax(extreme).sum(dim=1)
    torch.testing.assert_close(row_sums, torch.ones(2), atol=1e-6, rtol=1e-6)
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_row_softmax_if_available():
    """Check the Triton kernel against torch.softmax when a GPU is present.

    Skips outright without CUDA; `_run_impl_or_skip` additionally skips
    while the kernel is a TODO stub or Triton cannot launch.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    inputs = torch.randn(16, 63, device="cuda")
    expected = torch.softmax(inputs, dim=1)
    result = _run_impl_or_skip(triton_row_softmax, inputs)
    torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4)