// Viewer metadata (not part of the header itself): 16 lines, 513 B, C.
#pragma once
|
|
|
|
#include <torch/extension.h>
|
|
|
|
/// Checks that the given tensor expression reports `is_cuda()`.
/// If it does not, TORCH_CHECK is handed the message
/// "<argument text> must be a CUDA tensor", where the argument text is the
/// expression exactly as written at the call site.
#define LAB_CHECK_CUDA(tensor) \
    TORCH_CHECK((tensor).is_cuda(), #tensor " must be a CUDA tensor")
|
|
/// Checks that the given tensor expression reports `is_contiguous()`.
/// If it does not, TORCH_CHECK is handed the message
/// "<argument text> must be contiguous", where the argument text is the
/// expression exactly as written at the call site.
#define LAB_CHECK_CONTIGUOUS(tensor) \
    TORCH_CHECK((tensor).is_contiguous(), #tensor " must be contiguous")
|
|
/// Checks that two tensor expressions return equal `sizes()`.
/// If they differ, TORCH_CHECK is handed the message
/// "<first> and <second> must have the same shape", where each placeholder is
/// the corresponding argument exactly as written at the call site.
#define LAB_CHECK_SAME_SHAPE(lhs, rhs) \
    TORCH_CHECK((lhs).sizes() == (rhs).sizes(), \
                #lhs " and " #rhs " must have the same shape")
|
|
|
|
/// @brief Verifies that two tensors are both CUDA tensors and both contiguous.
///
/// The checks run in a fixed order (x is CUDA, y is CUDA, x is contiguous,
/// y is contiguous), so when several conditions are violated the first one in
/// that order determines the reported message. Messages name the tensors
/// "x" and "y", i.e. the parameter names of this function.
///
/// Shapes are NOT compared here; callers that need matching shapes must use
/// LAB_CHECK_SAME_SHAPE separately.
///
/// @param x First tensor to validate.
/// @param y Second tensor to validate.
/// @throws Whatever TORCH_CHECK raises on a failed condition (c10::Error in
///         PyTorch).
inline void check_cuda_pair(const torch::Tensor& x, const torch::Tensor& y) {
    // Device placement is validated for both tensors before layout.
    LAB_CHECK_CUDA(x);
    LAB_CHECK_CUDA(y);

    LAB_CHECK_CONTIGUOUS(x);
    LAB_CHECK_CONTIGUOUS(y);
}
|
|
|