28 lines
839 B
Python
28 lines
839 B
Python
from __future__ import annotations
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from tools.lab_extension import build_extension
|
|
|
|
|
|
@pytest.mark.cuda_required
@pytest.mark.skeleton
def test_vector_add_opcheck_if_available():
    """Run ``torch.library.opcheck`` on ``kernel_lab::vector_add`` when possible.

    The test is skipped (never failed) when a prerequisite is missing:
    no CUDA device, no ``torch.library.opcheck`` in this torch build, the
    extension could not be built, the operator is not registered, or calling
    the operator on sample inputs raises (treated as "not implemented yet",
    since this is a skeleton lab test).
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    # Checked before building: it costs nothing, and there is no point paying
    # for an extension build when the test would be skipped anyway.
    if not hasattr(torch.library, "opcheck"):
        pytest.skip("torch.library.opcheck is unavailable")

    ext = build_extension(verbose=False)

    # ``hasattr(torch.ops, "kernel_lab")`` is always True because ``torch.ops``
    # lazily creates a namespace object for any attribute name. Looking up the
    # operator inside the namespace raises AttributeError when
    # ``kernel_lab::vector_add`` is not registered, so probe that instead.
    if ext is None or not hasattr(torch.ops.kernel_lab, "vector_add"):
        pytest.skip("extension is unavailable")

    x = torch.randn(32, device="cuda")
    y = torch.randn(32, device="cuda")

    # Skeleton lab: a stubbed-out operator may raise anything, so a failing
    # probe call is reported as "not implemented yet" rather than a failure.
    try:
        torch.ops.kernel_lab.vector_add(x, y)
    except Exception as exc:
        pytest.skip(f"operator is not implemented yet: {exc}")

    torch.library.opcheck(torch.ops.kernel_lab.vector_add, (x, y))
|