Initial project scaffold
This commit is contained in:
49
tasks/05_flash_attention_fwd/test_task.py
Normal file
49
tasks/05_flash_attention_fwd/test_task.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
|
||||
from reference.torch_attention import torch_attention
|
||||
|
||||
|
||||
def _run_impl_or_skip(fn, *args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except NotImplementedError:
|
||||
pytest.skip("implementation is still TODO")
|
||||
except RuntimeError as exc:
|
||||
pytest.skip(str(exc))
|
||||
|
||||
|
||||
@pytest.mark.reference
def test_attention_reference_small_shape():
    """Non-causal reference attention matches torch's SDPA on a tiny shape."""
    shape = (1, 2, 8, 16)  # (batch, heads, seq_len, head_dim)
    query, key, value = (torch.randn(*shape) for _ in range(3))
    actual = torch_attention(query, key, value, causal=False)
    baseline = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, is_causal=False
    )
    torch.testing.assert_close(actual, baseline, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.reference
def test_attention_reference_causal_small_shape():
    """Causal reference attention matches torch's SDPA on a tiny shape."""
    shape = (1, 1, 8, 16)  # (batch, heads, seq_len, head_dim)
    query, key, value = (torch.randn(*shape) for _ in range(3))
    actual = torch_attention(query, key, value, causal=True)
    baseline = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, is_causal=True
    )
    torch.testing.assert_close(actual, baseline, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_flash_attention_if_available():
    """Triton forward kernel agrees with the torch reference (CUDA only).

    Skips when no GPU is present, or when the kernel is still a skeleton
    (handled inside ``_run_impl_or_skip``). Tolerances are looser than the
    reference tests because the fused kernel accumulates differently.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    shape = (1, 2, 16, 32)  # (batch, heads, seq_len, head_dim)
    query, key, value = (torch.randn(*shape, device="cuda") for _ in range(3))

    actual = _run_impl_or_skip(
        triton_flash_attention_fwd, query, key, value, causal=False
    )
    baseline = torch_attention(query, key, value, causal=False)
    torch.testing.assert_close(actual, baseline, atol=2e-3, rtol=2e-3)
|
||||
Reference in New Issue
Block a user