Initial project scaffold
This commit is contained in:
69
kernels/cuda/binding/binding.cpp
Normal file
69
kernels/cuda/binding/binding.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace kernel_lab {

// Host-side dispatch wrappers. Each function validates its tensor arguments
// with the project's check helpers (check_cuda_pair / LAB_CHECK_* come from
// the included common/cuda_utils headers) and then forwards to the matching
// *_cuda kernel launcher declared in those headers.

// Element-wise vector addition.
// Preconditions enforced here: x and y pass check_cuda_pair and have
// identical shapes (LAB_CHECK_SAME_SHAPE).
torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
  check_cuda_pair(x, y);
  LAB_CHECK_SAME_SHAPE(x, y);
  return vector_add_cuda(x, y);
}

// Row-wise softmax.
// Preconditions enforced here: x is a CUDA tensor and contiguous.
// NOTE(review): presumably x is 2D with one softmax per row — the shape
// contract is not visible in this file; confirm against row_softmax_cuda.
torch::Tensor row_softmax_dispatch(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  return row_softmax_cuda(x);
}

// Tiled matrix multiplication.
// NOTE(review): only device placement is validated (check_cuda_pair); there
// is no inner-dimension compatibility check and, unlike the softmax
// dispatchers, no contiguity check — verify tiled_matmul_cuda validates
// these itself before adding checks here.
torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
  check_cuda_pair(a, b);
  return tiled_matmul_cuda(a, b);
}

// Online (single-pass) softmax variant; same enforced preconditions as
// row_softmax_dispatch: CUDA-resident and contiguous.
torch::Tensor online_softmax_dispatch(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  return online_softmax_cuda(x);
}

// Flash-attention forward pass over (q, k, v); `causal` is forwarded to the
// kernel (presumably toggles causal masking — confirm in the kernel source).
// NOTE(review): only device placement is checked for q/k/v — no contiguity
// or cross-tensor shape checks, unlike the other dispatchers. Verify
// flash_attention_fwd_cuda either supports strided inputs or validates
// internally.
torch::Tensor flash_attention_fwd_dispatch(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal) {
  LAB_CHECK_CUDA(q);
  LAB_CHECK_CUDA(k);
  LAB_CHECK_CUDA(v);
  return flash_attention_fwd_cuda(q, k, v, causal);
}

}  // namespace kernel_lab
|
||||
|
||||
// Operator schema registration for the `kernel_lab` namespace.
// Declares each op's TorchScript schema; the CUDA implementations are bound
// separately in TORCH_LIBRARY_IMPL so the dispatcher can route by device.
// Note: schema default values use Python literals, hence `causal=False`.
TORCH_LIBRARY(kernel_lab, m) {
  m.def("vector_add(Tensor x, Tensor y) -> Tensor");
  m.def("row_softmax(Tensor x) -> Tensor");
  m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
  m.def("online_softmax(Tensor x) -> Tensor");
  m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
}
|
||||
|
||||
// Binds the dispatch wrappers as the CUDA-backend implementations of the
// schemas declared in TORCH_LIBRARY. Calls on CPU tensors will fail at the
// dispatcher level since no CPU kernel is registered.
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
  m.impl("vector_add", &kernel_lab::vector_add_dispatch);
  m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
  m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
  m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
  m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
}
|
||||
|
||||
// Direct pybind11 bindings of the dispatch wrappers, in addition to the
// torch.ops registration above. These call the wrappers straight through,
// bypassing the PyTorch dispatcher (so device routing/autograd integration
// provided by torch.ops does not apply on this path).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
  m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
  m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
  m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
  m.def(
      "flash_attention_fwd_dispatch",
      &kernel_lab::flash_attention_fwd_dispatch,
      "Flash attention forward dispatch");
}
|
||||
Reference in New Issue
Block a user