// kernel_lab: host-side dispatch and Python/dispatcher registration for the
// CUDA kernels implemented in the companion .cu files.
#include "../include/common.h"
|
|
#include "../include/cuda_utils.h"
|
|
|
|
#include <torch/extension.h>
|
|
|
|
namespace kernel_lab {
|
|
|
|
// Host-side entry point for the element-wise vector-add CUDA kernel.
//
// Validates the operand pair before launch, then forwards to the kernel
// launcher. check_cuda_pair and LAB_CHECK_SAME_SHAPE come from the project
// headers included above; check_cuda_pair presumably verifies both tensors
// live on a CUDA device — confirm against cuda_utils.h.
torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
    check_cuda_pair(x, y);
    LAB_CHECK_SAME_SHAPE(x, y);  // element-wise add requires identical shapes
    return vector_add_cuda(x, y);
}
|
|
|
|
// Host-side entry point for the row-wise softmax CUDA kernel.
//
// Requires `x` to be a CUDA tensor with contiguous storage (both LAB_CHECK_*
// macros are project-defined validation macros from the included headers),
// then forwards to row_softmax_cuda.
torch::Tensor row_softmax_dispatch(torch::Tensor x) {
    LAB_CHECK_CUDA(x);
    // NOTE(review): the contiguity requirement suggests the kernel indexes
    // raw row-major storage — confirm against the kernel implementation.
    LAB_CHECK_CONTIGUOUS(x);
    return row_softmax_cuda(x);
}
|
|
|
|
// Host-side entry point for the tiled matrix-multiply CUDA kernel.
//
// Expects `a` (M, K) and `b` (K, N) as CUDA tensors and returns their
// (M, N) product via tiled_matmul_cuda.
torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
    check_cuda_pair(a, b);
    // Consistent with the other dispatchers in this file. NOTE(review):
    // assumes tiled_matmul_cuda indexes contiguous row-major storage, like
    // the softmax kernels — confirm against the kernel.
    LAB_CHECK_CONTIGUOUS(a);
    LAB_CHECK_CONTIGUOUS(b);
    // Fail fast with a clear message instead of letting the kernel read out
    // of bounds on incompatible operands.
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2,
                "tiled_matmul expects 2-D tensors, got ", a.dim(), "-D and ",
                b.dim(), "-D");
    TORCH_CHECK(a.size(1) == b.size(0),
                "tiled_matmul: inner dimensions must match, got ", a.size(1),
                " and ", b.size(0));
    return tiled_matmul_cuda(a, b);
}
|
|
|
|
// Host-side entry point for the online (single-pass) softmax CUDA kernel.
//
// Requires `x` to be a CUDA tensor with contiguous storage, then forwards to
// online_softmax_cuda. Mirrors the validation done by row_softmax_dispatch.
torch::Tensor online_softmax_dispatch(torch::Tensor x) {
    LAB_CHECK_CUDA(x);
    // NOTE(review): contiguity requirement implies raw-storage indexing in
    // the kernel — confirm against the kernel implementation.
    LAB_CHECK_CONTIGUOUS(x);
    return online_softmax_cuda(x);
}
|
|
|
|
// Host-side entry point for the flash-attention forward CUDA kernel.
//
// `q`, `k`, `v` must all be CUDA tensors; `causal` is forwarded unchanged to
// the kernel launcher. Returns the output of flash_attention_fwd_cuda.
torch::Tensor flash_attention_fwd_dispatch(
        torch::Tensor q,
        torch::Tensor k,
        torch::Tensor v,
        bool causal) {
    LAB_CHECK_CUDA(q);
    LAB_CHECK_CUDA(k);
    LAB_CHECK_CUDA(v);
    // Consistent with the softmax dispatchers in this file. NOTE(review):
    // assumes the kernel indexes contiguous storage — confirm against the
    // kernel before relying on this.
    LAB_CHECK_CONTIGUOUS(q);
    LAB_CHECK_CONTIGUOUS(k);
    LAB_CHECK_CONTIGUOUS(v);
    // A single-dtype kernel launch cannot mix element types; report a clear
    // error here instead of a confusing failure inside the kernel.
    TORCH_CHECK(k.scalar_type() == q.scalar_type() &&
                    v.scalar_type() == q.scalar_type(),
                "flash_attention_fwd: q, k, v must share a dtype");
    return flash_attention_fwd_cuda(q, k, v, causal);
}
|
|
|
|
} // namespace kernel_lab
|
|
|
|
// Registers the operator schemas for the `kernel_lab` namespace with the
// PyTorch dispatcher (TorchScript schema syntax). Backend implementations
// are bound separately via TORCH_LIBRARY_IMPL.
TORCH_LIBRARY(kernel_lab, m) {
    m.def("vector_add(Tensor x, Tensor y) -> Tensor");
    m.def("row_softmax(Tensor x) -> Tensor");
    m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
    m.def("online_softmax(Tensor x) -> Tensor");
    // `causal` defaults to False so callers may omit it.
    m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
}
|
|
|
|
// Binds each `kernel_lab` schema to its dispatch function for the CUDA
// dispatch key. Impl names here must match the schema names registered in
// TORCH_LIBRARY(kernel_lab, ...) above.
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
    m.impl("vector_add", &kernel_lab::vector_add_dispatch);
    m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
    m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
    m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
    m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
}
|
|
|
|
// Exposes the dispatch functions directly as Python attributes of the
// extension module, in addition to the dispatcher registrations above, so
// they can be called without going through torch.ops.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
    m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
    m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
    m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
    m.def(
        "flash_attention_fwd_dispatch",
        &kernel_lab::flash_attention_fwd_dispatch,
        "Flash attention forward dispatch");
}
|