from __future__ import annotations

import pytest
import torch

from kernels.triton.online_softmax import triton_online_softmax
from reference.torch_online_softmax import torch_online_softmax
||||
def _run_impl_or_skip(fn, *args):
|
||||
try:
|
||||
return fn(*args)
|
||||
except NotImplementedError:
|
||||
pytest.skip("implementation is still TODO")
|
||||
except RuntimeError as exc:
|
||||
pytest.skip(str(exc))


@pytest.mark.reference
def test_online_softmax_reference_matches_torch():
    """The pure-PyTorch online softmax must agree with ``torch.softmax``."""
    sample = torch.randn(6, 19)
    # Row-wise (dim=1) softmax is the ground truth for the reference impl.
    expected = torch.softmax(sample, dim=1)
    actual = torch_online_softmax(sample)
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_online_softmax_if_available():
    """The Triton kernel must match ``torch.softmax`` when a GPU is present.

    Skips outright without CUDA; a TODO stub or runtime failure inside the
    kernel is converted into a skip by ``_run_impl_or_skip``.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    sample = torch.randn(8, 97, device="cuda")
    result = _run_impl_or_skip(triton_online_softmax, sample)
    # Looser tolerances than the CPU reference: GPU kernels accumulate in a
    # different order, so bit-exact agreement is not expected.
    torch.testing.assert_close(result, torch.softmax(sample, dim=1), atol=1e-4, rtol=1e-4)