Files
kernel-lab/tasks/06_pytorch_custom_op/opcheck_test.py
2026-04-10 13:15:06 +00:00

28 lines
839 B
Python

from __future__ import annotations
import pytest
import torch
from tools.lab_extension import build_extension
@pytest.mark.cuda_required
@pytest.mark.skeleton
def test_vector_add_opcheck_if_available():
    """Run ``torch.library.opcheck`` on ``kernel_lab::vector_add`` when possible.

    Skeleton-stage test: instead of failing, it skips when any prerequisite is
    missing:

    - no CUDA device;
    - ``build_extension`` returned ``None``;
    - the ``kernel_lab::vector_add`` operator is not registered;
    - the installed PyTorch has no ``torch.library.opcheck``;
    - calling the operator raises.

    Once the operator runs on two CUDA tensors of 32 random floats, the test
    hands it to ``torch.library.opcheck``.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    ext = build_extension(verbose=False)
    if ext is None:
        pytest.skip("extension is unavailable")
    # NOTE: ``torch.ops`` creates op namespaces lazily on attribute access,
    # so ``hasattr(torch.ops, "kernel_lab")`` is always True and cannot
    # detect a missing registration. Probe the operator itself instead:
    # looking up an unregistered op on a namespace raises AttributeError,
    # which ``hasattr`` reports as False.
    if not hasattr(torch.ops.kernel_lab, "vector_add"):
        pytest.skip("extension is unavailable: kernel_lab::vector_add is not registered")
    if not hasattr(torch.library, "opcheck"):
        pytest.skip("torch.library.opcheck is unavailable")

    op = torch.ops.kernel_lab.vector_add
    x = torch.randn(32, device="cuda")
    y = torch.randn(32, device="cuda")
    try:
        op(x, y)
    except Exception as exc:
        # Deliberately broad: this is a skeleton test, and an unfinished
        # kernel stub may fail with any error type. Treat any failure as
        # "not implemented yet" rather than a test failure.
        pytest.skip(f"operator is not implemented yet: {exc}")
    torch.library.opcheck(op, (x, y))