from __future__ import annotations

import pytest
import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention


def _run_impl_or_skip(fn, *args, **kwargs):
    """Call ``fn`` and skip the test if the implementation is missing or unsupported."""
    try:
        return fn(*args, **kwargs)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_attention_reference_small_shape():
    q = torch.randn(1, 2, 8, 16)
    k = torch.randn(1, 2, 8, 16)
    v = torch.randn(1, 2, 8, 16)
    out = torch_attention(q, k, v, causal=False)
    # Compare the reference implementation against PyTorch's built-in SDPA.
    expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
    torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.reference
def test_attention_reference_causal_small_shape():
    q = torch.randn(1, 1, 8, 16)
    k = torch.randn(1, 1, 8, 16)
    v = torch.randn(1, 1, 8, 16)
    out = torch_attention(q, k, v, causal=True)
    expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_flash_attention_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    q = torch.randn(1, 2, 16, 32, device="cuda")
    k = torch.randn(1, 2, 16, 32, device="cuda")
    v = torch.randn(1, 2, 16, 32, device="cuda")
    out = _run_impl_or_skip(triton_flash_attention_fwd, q, k, v, causal=False)
    expected = torch_attention(q, k, v, causal=False)
    # The Triton kernel is checked against the reference with looser tolerances.
    torch.testing.assert_close(out, expected, atol=2e-3, rtol=2e-3)