Initial project scaffold
This commit is contained in:
29
kernels/cuda/CMakeLists.txt
Normal file
29
kernels/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,29 @@
|
||||
cmake_minimum_required(VERSION 3.25)
project(kernel_lab LANGUAGES CXX CUDA)

# Require C++17 / CUDA C++17 outright instead of allowing CMake to fall back
# silently to an older standard on compilers that lack it.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Honor an architecture list supplied by the user or toolchain; otherwise
# default it (cached so it is visible/overridable in the CMake cache).
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES 120 CACHE STRING "Target CUDA architectures")
endif()

find_package(Torch REQUIRED)
# libtorch exports mandatory compile flags (notably the libstdc++ dual-ABI
# setting); extensions must compile with them to stay ABI-compatible.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_library(kernel_lab_extension SHARED
  binding/binding.cpp
  src/vector_add.cu
  src/row_softmax.cu
  src/tiled_matmul.cu
  src/online_softmax.cu
  src/flash_attention_fwd.cu
)

target_include_directories(kernel_lab_extension PRIVATE include)
target_link_libraries(kernel_lab_extension PRIVATE "${TORCH_LIBRARIES}")
target_compile_features(kernel_lab_extension PRIVATE cxx_std_17)
set_target_properties(kernel_lab_extension PROPERTIES
  # Empty prefix so the artifact is "kernel_lab_extension.so" (loadable as a
  # Python extension module) rather than "libkernel_lab_extension.so".
  PREFIX ""
  CUDA_SEPARABLE_COMPILATION ON
)
|
||||
69
kernels/cuda/binding/binding.cpp
Normal file
69
kernels/cuda/binding/binding.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
// Dispatcher for the vector_add op: validate inputs at the boundary, then
// forward to the CUDA host wrapper.
torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
  // Both operands must be contiguous CUDA tensors with identical shapes.
  check_cuda_pair(x, y);
  LAB_CHECK_SAME_SHAPE(x, y);
  return vector_add_cuda(x, y);
}
|
||||
|
||||
// Dispatcher for the row_softmax op: device/layout checks, then CUDA wrapper.
torch::Tensor row_softmax_dispatch(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  return row_softmax_cuda(x);
}
|
||||
|
||||
// Dispatcher for the tiled_matmul op. Dimension/dtype validation is done in
// tiled_matmul_cuda; only device + layout are enforced here.
torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
  check_cuda_pair(a, b);
  return tiled_matmul_cuda(a, b);
}
|
||||
|
||||
// Dispatcher for the online_softmax op: device/layout checks, then CUDA wrapper.
torch::Tensor online_softmax_dispatch(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  return online_softmax_cuda(x);
}
|
||||
|
||||
// Dispatcher for the flash_attention_fwd op.
//
// Fix: mirrors the other dispatchers by validating layout as well as device —
// previously contiguity was only checked inside flash_attention_fwd_cuda.
// The macro (and thus the error message) is identical, so callers observe
// the same failure, just raised at the dispatch boundary.
torch::Tensor flash_attention_fwd_dispatch(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal) {
  LAB_CHECK_CUDA(q);
  LAB_CHECK_CUDA(k);
  LAB_CHECK_CUDA(v);
  LAB_CHECK_CONTIGUOUS(q);
  LAB_CHECK_CONTIGUOUS(k);
  LAB_CHECK_CONTIGUOUS(v);
  return flash_attention_fwd_cuda(q, k, v, causal);
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
|
||||
// Operator schemas for the "kernel_lab" namespace. These define the
// dispatcher-visible signatures; implementations are registered separately
// per backend (see TORCH_LIBRARY_IMPL below).
TORCH_LIBRARY(kernel_lab, m) {
  m.def("vector_add(Tensor x, Tensor y) -> Tensor");
  m.def("row_softmax(Tensor x) -> Tensor");
  m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
  m.def("online_softmax(Tensor x) -> Tensor");
  m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
}
|
||||
|
||||
// CUDA-backend implementations for each schema declared in TORCH_LIBRARY.
// The dispatcher routes here when the inputs live on a CUDA device.
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
  m.impl("vector_add", &kernel_lab::vector_add_dispatch);
  m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
  m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
  m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
  m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
}
|
||||
|
||||
// Direct pybind11 bindings, so the dispatchers are also callable as plain
// module-level functions (in addition to the torch.ops.kernel_lab.* path).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
  m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
  m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
  m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
  m.def("flash_attention_fwd_dispatch",
        &kernel_lab::flash_attention_fwd_dispatch,
        "Flash attention forward dispatch");
}
|
||||
17
kernels/cuda/include/common.h
Normal file
17
kernels/cuda/include/common.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#pragma once

#include <torch/extension.h>

namespace kernel_lab {

// Forward declarations of the CUDA host wrappers defined in src/*.cu.
// All wrappers currently expect contiguous float32 CUDA tensors.

// Element-wise sum of two same-shape tensors.
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);

// Softmax over each row of a 2D tensor.
torch::Tensor row_softmax_cuda(torch::Tensor x);

// Tiled matrix product of two 2D tensors (inner dimensions must match).
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);

// Single-pass ("online") softmax over each row of a 2D tensor.
torch::Tensor online_softmax_cuda(torch::Tensor x);

// Flash-attention forward pass over [batch, heads, seq, dim] inputs;
// `causal` requests causal masking.
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal);

}  // namespace kernel_lab
|
||||
15
kernels/cuda/include/cuda_utils.h
Normal file
15
kernels/cuda/include/cuda_utils.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#pragma once

#include <torch/extension.h>

// Argument-validation helpers shared by every kernel entry point.
// Each macro raises via TORCH_CHECK, naming the offending tensor argument.
#define LAB_CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define LAB_CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define LAB_CHECK_SAME_SHAPE(x, y) TORCH_CHECK((x).sizes() == (y).sizes(), #x " and " #y " must have the same shape")

// Convenience check for binary ops: both tensors must live on the GPU and
// be contiguous. Note the macro error messages will name "x" and "y".
inline void check_cuda_pair(const torch::Tensor& x, const torch::Tensor& y) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CUDA(y);
  LAB_CHECK_CONTIGUOUS(x);
  LAB_CHECK_CONTIGUOUS(y);
}
|
||||
|
||||
54
kernels/cuda/src/flash_attention_fwd.cu
Normal file
54
kernels/cuda/src/flash_attention_fwd.cu
Normal file
@@ -0,0 +1,54 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
// Scaffold for the flash-attention forward kernel.
//
// The host wrapper enforces that q/k/v are contiguous float32 tensors of
// shape [batch, heads, seq_len, head_dim]; `causal` requests causal masking.
// The body is intentionally left for the student; the (void) casts suppress
// unused-parameter warnings until the implementation lands.
__global__ void flash_attention_fwd_kernel(
    const float* q,
    const float* k,
    const float* v,
    float* out,
    int64_t batch,
    int64_t heads,
    int64_t seq_len,
    int64_t head_dim,
    bool causal) {
  // TODO(student): assign each block to a batch/head/query tile.
  // TODO(student): cooperatively load K/V tiles.
  // TODO(student): compute score blocks and apply causal masking when requested.
  // TODO(student): maintain online softmax state and accumulate the output tile.
  (void)q;
  (void)k;
  (void)v;
  (void)out;
  (void)batch;
  (void)heads;
  (void)seq_len;
  (void)head_dim;
  (void)causal;
}
|
||||
|
||||
// Host wrapper for the flash-attention forward kernel.
//
// Requires q/k/v to be contiguous float32 CUDA tensors with identical
// [batch, heads, seq, dim] shapes; `causal` toggles causal masking.
// Currently a scaffold: it always raises until the student implements it.
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal) {
  LAB_CHECK_CUDA(q);
  LAB_CHECK_CUDA(k);
  LAB_CHECK_CUDA(v);
  LAB_CHECK_CONTIGUOUS(q);
  LAB_CHECK_CONTIGUOUS(k);
  LAB_CHECK_CONTIGUOUS(v);
  TORCH_CHECK(q.sizes() == k.sizes(), "q and k must match");
  TORCH_CHECK(q.sizes() == v.sizes(), "q and v must match");
  TORCH_CHECK(q.dim() == 4, "flash_attention_fwd_cuda expects [batch, heads, seq, dim]");
  TORCH_CHECK(q.scalar_type() == torch::kFloat32, "flash_attention_fwd_cuda currently assumes float32");
  // Fix: sizes() equality says nothing about dtype — previously a float64
  // k or v would have passed every check here. Validate all three inputs.
  TORCH_CHECK(k.scalar_type() == torch::kFloat32, "flash_attention_fwd_cuda currently assumes float32");
  TORCH_CHECK(v.scalar_type() == torch::kFloat32, "flash_attention_fwd_cuda currently assumes float32");

  TORCH_CHECK(
      false,
      "TODO(student): implement flash_attention_fwd_cuda in kernels/cuda/src/flash_attention_fwd.cu.");
  return torch::Tensor();  // unreachable; keeps the compiler satisfied
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
36
kernels/cuda/src/online_softmax.cu
Normal file
36
kernels/cuda/src/online_softmax.cu
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
// Scaffold for the single-pass (online) row softmax kernel.
// x/out point at contiguous float32 [num_rows, num_cols] buffers; the launch
// maps one block per row (blockIdx.x), guarded against oversized grids.
__global__ void online_softmax_kernel(
    const float* x,
    float* out,
    int64_t num_rows,
    int64_t num_cols) {
  const int row = blockIdx.x;
  if (row >= num_rows) {
    return;  // grid may be padded beyond the row count
  }

  // TODO(student): maintain running max and running sum across column tiles.
  // TODO(student): write the normalized row after finishing the recurrence.
  (void)x;
  (void)out;
  (void)num_rows;
  (void)num_cols;
}
|
||||
|
||||
// Host wrapper for the online softmax kernel. Accepts a contiguous float32
// 2D CUDA tensor. Currently a scaffold: it always raises until implemented.
torch::Tensor online_softmax_cuda(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  TORCH_CHECK(x.dim() == 2, "online_softmax_cuda expects a 2D tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "online_softmax_cuda currently assumes float32");

  TORCH_CHECK(
      false,
      "TODO(student): implement online_softmax_cuda in kernels/cuda/src/online_softmax.cu.");
  return torch::Tensor();  // unreachable; keeps the compiler satisfied
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
37
kernels/cuda/src/row_softmax.cu
Normal file
37
kernels/cuda/src/row_softmax.cu
Normal file
@@ -0,0 +1,37 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
// Scaffold for the per-row softmax kernel.
// x/out point at contiguous float32 [num_rows, num_cols] buffers; the launch
// maps one block per row (blockIdx.x), guarded against oversized grids.
__global__ void row_softmax_kernel(
    const float* x,
    float* out,
    int64_t num_rows,
    int64_t num_cols) {
  const int row = blockIdx.x;
  if (row >= num_rows) {
    return;  // grid may be padded beyond the row count
  }

  // TODO(student): decide whether one block owns one row or one row tile.
  // TODO(student): compute the row max for numerical stability.
  // TODO(student): compute exp(x - max), reduce the sum, and normalize.
  (void)x;
  (void)out;
  (void)num_rows;
  (void)num_cols;
}
|
||||
|
||||
// Host wrapper for the row softmax kernel. Accepts a contiguous float32
// 2D CUDA tensor. Currently a scaffold: it always raises until implemented.
torch::Tensor row_softmax_cuda(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  TORCH_CHECK(x.dim() == 2, "row_softmax_cuda expects a 2D tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "row_softmax_cuda currently assumes float32");

  TORCH_CHECK(
      false,
      "TODO(student): implement row_softmax_cuda in kernels/cuda/src/row_softmax.cu.");
  return torch::Tensor();  // unreachable; keeps the compiler satisfied
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
40
kernels/cuda/src/tiled_matmul.cu
Normal file
40
kernels/cuda/src/tiled_matmul.cu
Normal file
@@ -0,0 +1,40 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
// Scaffold for a shared-memory tiled matrix multiply: c = a @ b, with
// a: [m, k], b: [k, n], c: [m, n], all contiguous float32 device buffers
// (validated by the host wrapper).
__global__ void tiled_matmul_kernel(
    const float* a,
    const float* b,
    float* c,
    int64_t m,
    int64_t n,
    int64_t k) {
  // TODO(student): map blockIdx/threadIdx to a C tile.
  // TODO(student): cooperatively load A and B tiles into shared memory.
  // TODO(student): accumulate partial products across the K dimension.
  (void)a;
  (void)b;
  (void)c;
  (void)m;
  (void)n;
  (void)k;
}
|
||||
|
||||
// Host wrapper for the tiled matmul kernel. Requires two contiguous float32
// 2D CUDA tensors whose inner dimensions agree. Currently a scaffold: it
// always raises until implemented.
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b) {
  LAB_CHECK_CUDA(a);
  LAB_CHECK_CUDA(b);
  LAB_CHECK_CONTIGUOUS(a);
  LAB_CHECK_CONTIGUOUS(b);
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "tiled_matmul_cuda expects 2D tensors");
  TORCH_CHECK(a.size(1) == b.size(0), "inner dimensions must match");
  TORCH_CHECK(a.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
  TORCH_CHECK(b.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");

  TORCH_CHECK(
      false,
      "TODO(student): implement tiled_matmul_cuda in kernels/cuda/src/tiled_matmul.cu.");
  return torch::Tensor();  // unreachable; keeps the compiler satisfied
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
35
kernels/cuda/src/vector_add.cu
Normal file
35
kernels/cuda/src/vector_add.cu
Normal file
@@ -0,0 +1,35 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
// Scaffold for element-wise addition over flat float32 buffers.
// One thread owns one element; the index is widened to int64_t before the
// multiply so large grids cannot overflow 32-bit arithmetic.
__global__ void vector_add_kernel(
    const float* x,
    const float* y,
    float* out,
    int64_t numel) {
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= numel) {
    return;  // grid-tail guard: the last block may be partially full
  }

  // TODO(student): replace this placeholder with the real vector-add math.
  // Hint: one thread should own one element for the first implementation.
  (void)x;
  (void)y;
  (void)out;
  (void)numel;
}
|
||||
|
||||
// Host wrapper for vector_add. Requires both operands to be contiguous
// float32 CUDA tensors of identical shape. Currently a scaffold: it always
// raises until the student implements the launch.
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y) {
  check_cuda_pair(x, y);
  LAB_CHECK_SAME_SHAPE(x, y);
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "vector_add_cuda currently assumes float32");
  // Fix: y's dtype was never validated (tiled_matmul_cuda checks both
  // operands) — a float64 y of matching shape would have passed all checks.
  TORCH_CHECK(y.scalar_type() == torch::kFloat32, "vector_add_cuda currently assumes float32");

  TORCH_CHECK(
      false,
      "TODO(student): implement vector_add_cuda in kernels/cuda/src/vector_add.cu and then launch the kernel.");
  return torch::Tensor();  // unreachable; keeps the compiler satisfied
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
Reference in New Issue
Block a user