Initial project scaffold
This commit is contained in:
17
kernels/cuda/include/common.h
Normal file
17
kernels/cuda/include/common.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);
|
||||
torch::Tensor row_softmax_cuda(torch::Tensor x);
|
||||
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);
|
||||
torch::Tensor online_softmax_cuda(torch::Tensor x);
|
||||
torch::Tensor flash_attention_fwd_cuda(
|
||||
torch::Tensor q,
|
||||
torch::Tensor k,
|
||||
torch::Tensor v,
|
||||
bool causal);
|
||||
|
||||
} // namespace kernel_lab
|
||||
15
kernels/cuda/include/cuda_utils.h
Normal file
15
kernels/cuda/include/cuda_utils.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
// Argument-validation helpers built on TORCH_CHECK.
//
// Each macro stringizes its argument(s) with `#`, so the failure message names
// the expression exactly as it was spelled at the call site. Arguments are
// parenthesized in the condition, so arbitrary expressions may be passed.

// Fails unless `(x).is_cuda()` is true.
#define LAB_CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")

// Fails unless `(x).is_contiguous()` is true.
#define LAB_CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")

// Fails unless `(x).sizes()` compares equal to `(y).sizes()`.
#define LAB_CHECK_SAME_SHAPE(x, y) TORCH_CHECK((x).sizes() == (y).sizes(), #x " and " #y " must have the same shape")
|
||||
|
||||
// Validates a pair of tensors before they are handed to a CUDA kernel.
//
// Checks, in this order: `x` is a CUDA tensor, `y` is a CUDA tensor, `x` is
// contiguous, `y` is contiguous. Each check goes through TORCH_CHECK, so the
// first failing condition is the one reported; the message names the offending
// parameter (`x` or `y`).
//
// Shapes are NOT compared here; callers that need matching shapes must use
// LAB_CHECK_SAME_SHAPE themselves.
inline void check_cuda_pair(const torch::Tensor& x, const torch::Tensor& y) {
    // Device checks first, then layout checks.
    LAB_CHECK_CUDA(x);
    LAB_CHECK_CUDA(y);
    LAB_CHECK_CONTIGUOUS(x);
    LAB_CHECK_CONTIGUOUS(y);
}
|
||||
|
||||
Reference in New Issue
Block a user