#include "../include/common.h" #include "../include/cuda_utils.h" #include namespace kernel_lab { torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) { check_cuda_pair(x, y); LAB_CHECK_SAME_SHAPE(x, y); return vector_add_cuda(x, y); } torch::Tensor row_softmax_dispatch(torch::Tensor x) { LAB_CHECK_CUDA(x); LAB_CHECK_CONTIGUOUS(x); return row_softmax_cuda(x); } torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) { check_cuda_pair(a, b); return tiled_matmul_cuda(a, b); } torch::Tensor online_softmax_dispatch(torch::Tensor x) { LAB_CHECK_CUDA(x); LAB_CHECK_CONTIGUOUS(x); return online_softmax_cuda(x); } torch::Tensor flash_attention_fwd_dispatch( torch::Tensor q, torch::Tensor k, torch::Tensor v, bool causal) { LAB_CHECK_CUDA(q); LAB_CHECK_CUDA(k); LAB_CHECK_CUDA(v); return flash_attention_fwd_cuda(q, k, v, causal); } } // namespace kernel_lab TORCH_LIBRARY(kernel_lab, m) { m.def("vector_add(Tensor x, Tensor y) -> Tensor"); m.def("row_softmax(Tensor x) -> Tensor"); m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor"); m.def("online_softmax(Tensor x) -> Tensor"); m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor"); } TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) { m.impl("vector_add", &kernel_lab::vector_add_dispatch); m.impl("row_softmax", &kernel_lab::row_softmax_dispatch); m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch); m.impl("online_softmax", &kernel_lab::online_softmax_dispatch); m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch"); m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch"); m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch"); m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch"); m.def( "flash_attention_fwd_dispatch", &kernel_lab::flash_attention_fwd_dispatch, "Flash attention forward dispatch"); }