// kernel-lab/kernels/cuda/binding/binding.cpp
#include "../include/common.h"
#include "../include/cuda_utils.h"
#include <torch/extension.h>
namespace kernel_lab {
torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
check_cuda_pair(x, y);
LAB_CHECK_SAME_SHAPE(x, y);
return vector_add_cuda(x, y);
}
torch::Tensor row_softmax_dispatch(torch::Tensor x) {
LAB_CHECK_CUDA(x);
LAB_CHECK_CONTIGUOUS(x);
return row_softmax_cuda(x);
}
torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
check_cuda_pair(a, b);
return tiled_matmul_cuda(a, b);
}
torch::Tensor online_softmax_dispatch(torch::Tensor x) {
LAB_CHECK_CUDA(x);
LAB_CHECK_CONTIGUOUS(x);
return online_softmax_cuda(x);
}
torch::Tensor flash_attention_fwd_dispatch(
torch::Tensor q,
torch::Tensor k,
torch::Tensor v,
bool causal) {
LAB_CHECK_CUDA(q);
LAB_CHECK_CUDA(k);
LAB_CHECK_CUDA(v);
return flash_attention_fwd_cuda(q, k, v, causal);
}
} // namespace kernel_lab
// Register the operator schemas; the ops become callable as
// torch.ops.kernel_lab.<name> once the extension is loaded.
TORCH_LIBRARY(kernel_lab, m) {
  m.def("vector_add(Tensor x, Tensor y) -> Tensor");
  m.def("row_softmax(Tensor x) -> Tensor");
  m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
  m.def("online_softmax(Tensor x) -> Tensor");
  m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
}
// Bind the wrappers as the CUDA backend for those schemas; the dispatcher
// routes calls here when the input tensors live on a CUDA device.
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
  m.impl("vector_add", &kernel_lab::vector_add_dispatch);
  m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
  m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
  m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
  m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
}
// Also expose the wrappers directly as module-level functions, bypassing the
// dispatcher.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
  m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
  m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
  m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
  m.def(
      "flash_attention_fwd_dispatch",
      &kernel_lab::flash_attention_fwd_dispatch,
      "Flash attention forward dispatch");
}
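The wrappers above call validation helpers (check_cuda_pair, the LAB_CHECK_* macros) and CUDA entry points (vector_add_cuda and friends) that binding.cpp pulls in from ../include/common.h and ../include/cuda_utils.h. Those headers are not shown here; the following is a minimal sketch of the declarations they would need to provide, inferred purely from the call sites above. The macro bodies, parameter names, and the split between the two headers are assumptions, not the repository's actual code.

// Hypothetical consolidated header; the real common.h / cuda_utils.h may
// differ in naming and layout.
#pragma once

#include <torch/extension.h>

// Assumed forms of the validation macros used by the single-tensor wrappers.
#define LAB_CHECK_CUDA(t) \
  TORCH_CHECK((t).is_cuda(), #t " must be a CUDA tensor")
#define LAB_CHECK_CONTIGUOUS(t) \
  TORCH_CHECK((t).is_contiguous(), #t " must be contiguous")
#define LAB_CHECK_SAME_SHAPE(a, b) \
  TORCH_CHECK((a).sizes() == (b).sizes(), #a " and " #b " must have the same shape")

namespace kernel_lab {

// Pairwise validation used by the two-tensor wrappers (assumed signature).
void check_cuda_pair(const torch::Tensor& a, const torch::Tensor& b);

// CUDA entry points, implemented in the corresponding .cu files.
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);
torch::Tensor row_softmax_cuda(torch::Tensor x);
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);
torch::Tensor online_softmax_cuda(torch::Tensor x);
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q, torch::Tensor k, torch::Tensor v, bool causal);

}  // namespace kernel_lab

From C++, the registered ops can also be reached through the dispatcher rather than the pybind exports. A sketch, assuming the extension library has already been loaded into the process (call_vector_add is an illustrative name, not part of the repository):

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

// Look up the kernel_lab::vector_add schema once and call it through a typed
// operator handle, the same route torch.ops.kernel_lab.vector_add takes.
torch::Tensor call_vector_add(torch::Tensor x, torch::Tensor y) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("kernel_lab::vector_add", "")
                       .typed<torch::Tensor(torch::Tensor, torch::Tensor)>();
  return op.call(x, y);
}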