Initial project scaffold
This commit is contained in:
27
tasks/06_pytorch_custom_op/opcheck_test.py
Normal file
27
tasks/06_pytorch_custom_op/opcheck_test.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tools.lab_extension import build_extension
|
||||
|
||||
|
||||
@pytest.mark.cuda_required
|
||||
@pytest.mark.skeleton
|
||||
def test_vector_add_opcheck_if_available():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is not available")
|
||||
ext = build_extension(verbose=False)
|
||||
if ext is None or not hasattr(torch.ops, "kernel_lab"):
|
||||
pytest.skip("extension is unavailable")
|
||||
if not hasattr(torch.library, "opcheck"):
|
||||
pytest.skip("torch.library.opcheck is unavailable")
|
||||
|
||||
x = torch.randn(32, device="cuda")
|
||||
y = torch.randn(32, device="cuda")
|
||||
try:
|
||||
torch.ops.kernel_lab.vector_add(x, y)
|
||||
except Exception as exc:
|
||||
pytest.skip(f"operator is not implemented yet: {exc}")
|
||||
|
||||
torch.library.opcheck(torch.ops.kernel_lab.vector_add, (x, y))
|
||||
Reference in New Issue
Block a user