// Listing metadata (from the code viewer): 18 lines, 450 B, C++
#pragma once
|
|
|
|
#include <torch/extension.h>
|
|
|
|
namespace kernel_lab {
|
|
|
|
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);
|
|
torch::Tensor row_softmax_cuda(torch::Tensor x);
|
|
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);
|
|
torch::Tensor online_softmax_cuda(torch::Tensor x);
|
|
torch::Tensor flash_attention_fwd_cuda(
|
|
torch::Tensor q,
|
|
torch::Tensor k,
|
|
torch::Tensor v,
|
|
bool causal);
|
|
|
|
} // namespace kernel_lab
|