Initial project scaffold

commit a4a6b1f1c8
Author: wjh
Date: 2026-04-10 13:15:06 +00:00
94 changed files with 3964 additions and 0 deletions

kernels/cuda/CMakeLists.txt

@@ -0,0 +1,29 @@
cmake_minimum_required(VERSION 3.25)
project(kernel_lab LANGUAGES CXX CUDA)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES 120 CACHE STRING "Target CUDA architectures")
endif()

find_package(Torch REQUIRED)
# binding.cpp defines a pybind11 module that passes torch::Tensor across the
# Python boundary, so the build needs the Python headers and libtorch_python
# in addition to ${TORCH_LIBRARIES}.
find_package(Python3 REQUIRED COMPONENTS Development.Module)
find_library(TORCH_PYTHON_LIBRARY torch_python HINTS "${TORCH_INSTALL_PREFIX}/lib")

add_library(kernel_lab_extension SHARED
  binding/binding.cpp
  src/vector_add.cu
  src/row_softmax.cu
  src/tiled_matmul.cu
  src/online_softmax.cu
  src/flash_attention_fwd.cu
)

target_include_directories(kernel_lab_extension PRIVATE include)
target_link_libraries(kernel_lab_extension PRIVATE
  "${TORCH_LIBRARIES}"
  "${TORCH_PYTHON_LIBRARY}"
  Python3::Module
)
target_compile_features(kernel_lab_extension PRIVATE cxx_std_17)
# PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) relies on this macro; setuptools
# injects it automatically, but a plain CMake build has to define it itself.
target_compile_definitions(kernel_lab_extension PRIVATE
  TORCH_EXTENSION_NAME=kernel_lab_extension
)
set_target_properties(kernel_lab_extension PROPERTIES
  PREFIX ""
  CUDA_SEPARABLE_COMPILATION ON
)

kernels/cuda/binding/binding.cpp

@@ -0,0 +1,69 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
#include <torch/extension.h>
namespace kernel_lab {
torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
check_cuda_pair(x, y);
LAB_CHECK_SAME_SHAPE(x, y);
return vector_add_cuda(x, y);
}
torch::Tensor row_softmax_dispatch(torch::Tensor x) {
LAB_CHECK_CUDA(x);
LAB_CHECK_CONTIGUOUS(x);
return row_softmax_cuda(x);
}
torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
check_cuda_pair(a, b);
return tiled_matmul_cuda(a, b);
}
torch::Tensor online_softmax_dispatch(torch::Tensor x) {
LAB_CHECK_CUDA(x);
LAB_CHECK_CONTIGUOUS(x);
return online_softmax_cuda(x);
}
torch::Tensor flash_attention_fwd_dispatch(
torch::Tensor q,
torch::Tensor k,
torch::Tensor v,
bool causal) {
LAB_CHECK_CUDA(q);
LAB_CHECK_CUDA(k);
LAB_CHECK_CUDA(v);
return flash_attention_fwd_cuda(q, k, v, causal);
}
} // namespace kernel_lab
TORCH_LIBRARY(kernel_lab, m) {
m.def("vector_add(Tensor x, Tensor y) -> Tensor");
m.def("row_softmax(Tensor x) -> Tensor");
m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
m.def("online_softmax(Tensor x) -> Tensor");
m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
}
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
m.impl("vector_add", &kernel_lab::vector_add_dispatch);
m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
m.def(
"flash_attention_fwd_dispatch",
&kernel_lab::flash_attention_fwd_dispatch,
"Flash attention forward dispatch");
}
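
Because the operators are also registered through TORCH_LIBRARY, they can be reached without the pybind11 wrappers. Below is a minimal sketch of a dispatcher-side call; it assumes kernel_lab_extension.so has already been loaded into the process, and the helper name call_vector_add is an invention of the sketch.

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

// Resolve the registered schema once, then call it like a normal function.
torch::Tensor call_vector_add(const torch::Tensor& x, const torch::Tensor& y) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("kernel_lab::vector_add", "")
                       .typed<torch::Tensor(torch::Tensor, torch::Tensor)>();
  return op.call(x, y);
}

From Python, the same operator is reachable as torch.ops.kernel_lab.vector_add once the shared library has been loaded with torch.ops.load_library.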

kernels/cuda/include/common.h

@@ -0,0 +1,17 @@
#pragma once

#include <torch/extension.h>

namespace kernel_lab {

torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);
torch::Tensor row_softmax_cuda(torch::Tensor x);
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);
torch::Tensor online_softmax_cuda(torch::Tensor x);
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal);

} // namespace kernel_lab

kernels/cuda/include/cuda_utils.h

@@ -0,0 +1,15 @@
#pragma once

#include <torch/extension.h>

#define LAB_CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define LAB_CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define LAB_CHECK_SAME_SHAPE(x, y) TORCH_CHECK((x).sizes() == (y).sizes(), #x " and " #y " must have the same shape")

inline void check_cuda_pair(const torch::Tensor& x, const torch::Tensor& y) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CUDA(y);
  LAB_CHECK_CONTIGUOUS(x);
  LAB_CHECK_CONTIGUOUS(y);
}
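
Every kernel wrapper in src/ repeats the same float32 dtype check by hand; a companion macro in the style above would keep those call sites uniform. A hypothetical addition, not part of this commit:

// Hypothetical companion to the checks above; not defined in this scaffold.
#define LAB_CHECK_FLOAT32(x) \
  TORCH_CHECK((x).scalar_type() == torch::kFloat32, #x " must be float32")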

kernels/cuda/src/flash_attention_fwd.cu

@@ -0,0 +1,54 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {
__global__ void flash_attention_fwd_kernel(
const float* q,
const float* k,
const float* v,
float* out,
int64_t batch,
int64_t heads,
int64_t seq_len,
int64_t head_dim,
bool causal) {
(void)q;
(void)k;
(void)v;
(void)out;
(void)batch;
(void)heads;
(void)seq_len;
(void)head_dim;
(void)causal;
// TODO(student): assign each block to a batch/head/query tile.
// TODO(student): cooperatively load K/V tiles.
// TODO(student): compute score blocks and apply causal masking when requested.
// TODO(student): maintain online softmax state and accumulate the output tile.
}
torch::Tensor flash_attention_fwd_cuda(
torch::Tensor q,
torch::Tensor k,
torch::Tensor v,
bool causal) {
LAB_CHECK_CUDA(q);
LAB_CHECK_CUDA(k);
LAB_CHECK_CUDA(v);
LAB_CHECK_CONTIGUOUS(q);
LAB_CHECK_CONTIGUOUS(k);
LAB_CHECK_CONTIGUOUS(v);
TORCH_CHECK(q.sizes() == k.sizes(), "q and k must match");
TORCH_CHECK(q.sizes() == v.sizes(), "q and v must match");
TORCH_CHECK(q.dim() == 4, "flash_attention_fwd_cuda expects [batch, heads, seq, dim]");
TORCH_CHECK(q.scalar_type() == torch::kFloat32, "flash_attention_fwd_cuda currently assumes float32");
TORCH_CHECK(
false,
"TODO(student): implement flash_attention_fwd_cuda in kernels/cuda/src/flash_attention_fwd.cu.");
return torch::Tensor();
}
} // namespace kernel_lab
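
For orientation only: the sketch below skips the shared-memory K/V tiling the TODOs ask for and gives each thread one (batch, head, query) row, so that only the causal mask and the online-softmax accumulation are visible. The kernel name, kMaxHeadDim, and the 1/sqrt(head_dim) scale are assumptions of this sketch, which targets correctness on contiguous float32 [batch, heads, seq, dim] inputs rather than performance.

#include <cmath>
#include <cstdint>

// Reference sketch: one thread per (batch, head, query) row, no tiling.
// Launch with one thread per row, e.g. 256 threads per block.
constexpr int kMaxHeadDim = 128;

__global__ void flash_attention_fwd_reference_kernel(
    const float* q, const float* k, const float* v, float* out,
    int64_t batch, int64_t heads, int64_t seq_len, int64_t head_dim,
    bool causal) {
  int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t total_rows = batch * heads * seq_len;
  if (row >= total_rows || head_dim > kMaxHeadDim) {
    return;
  }
  int64_t query = row % seq_len;   // query position within the sequence
  int64_t bh = row / seq_len;      // flattened (batch, head) index
  const float* q_row = q + row * head_dim;
  const float* k_base = k + bh * seq_len * head_dim;
  const float* v_base = v + bh * seq_len * head_dim;
  float* out_row = out + row * head_dim;

  const float scale = rsqrtf(static_cast<float>(head_dim));
  float running_max = -INFINITY;   // m in the online-softmax recurrence
  float running_sum = 0.0f;        // l in the online-softmax recurrence
  float acc[kMaxHeadDim];          // unnormalized output accumulator
  for (int64_t d = 0; d < head_dim; ++d) {
    acc[d] = 0.0f;
  }

  int64_t key_end = causal ? query + 1 : seq_len;
  for (int64_t key = 0; key < key_end; ++key) {
    float score = 0.0f;
    for (int64_t d = 0; d < head_dim; ++d) {
      score += q_row[d] * k_base[key * head_dim + d];
    }
    score *= scale;

    // When the running max grows, rescale the previous sum and accumulator.
    float new_max = fmaxf(running_max, score);
    float correction = expf(running_max - new_max);
    float p = expf(score - new_max);
    running_sum = running_sum * correction + p;
    for (int64_t d = 0; d < head_dim; ++d) {
      acc[d] = acc[d] * correction + p * v_base[key * head_dim + d];
    }
    running_max = new_max;
  }

  for (int64_t d = 0; d < head_dim; ++d) {
    out_row[d] = acc[d] / running_sum;
  }
}

The real kernel keeps the same per-row statistics, but per query tile, with K/V blocks staged through shared memory and the correction applied once per tile instead of once per key.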

kernels/cuda/src/online_softmax.cu

@@ -0,0 +1,36 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {
__global__ void online_softmax_kernel(
const float* x,
float* out,
int64_t num_rows,
int64_t num_cols) {
int row = blockIdx.x;
if (row >= num_rows) {
return;
}
// TODO(student): maintain running max and running sum across column tiles.
// TODO(student): write the normalized row after finishing the recurrence.
(void)x;
(void)out;
(void)num_rows;
(void)num_cols;
}
torch::Tensor online_softmax_cuda(torch::Tensor x) {
LAB_CHECK_CUDA(x);
LAB_CHECK_CONTIGUOUS(x);
TORCH_CHECK(x.dim() == 2, "online_softmax_cuda expects a 2D tensor");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "online_softmax_cuda currently assumes float32");
TORCH_CHECK(
false,
"TODO(student): implement online_softmax_cuda in kernels/cuda/src/online_softmax.cu.");
return torch::Tensor();
}
} // namespace kernel_lab
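
The stub dedicates one block per row; purely to keep the recurrence readable, the sketch below flattens that to one thread per row. It streams the columns once, rescaling the running sum whenever the running max grows, then normalizes in a second sweep. The kernel name and launch layout are assumptions of the sketch, not the expected solution.

#include <cmath>
#include <cstdint>

// Sketch: one thread owns one row; launch ceil(num_rows / 256) blocks of 256 threads.
__global__ void online_softmax_reference_kernel(
    const float* x, float* out, int64_t num_rows, int64_t num_cols) {
  int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (row >= num_rows) {
    return;
  }
  const float* x_row = x + row * num_cols;
  float* out_row = out + row * num_cols;

  // Single streaming pass: whenever the max grows, rescale the running sum.
  float running_max = -INFINITY;
  float running_sum = 0.0f;
  for (int64_t col = 0; col < num_cols; ++col) {
    float v = x_row[col];
    float new_max = fmaxf(running_max, v);
    running_sum = running_sum * expf(running_max - new_max) + expf(v - new_max);
    running_max = new_max;
  }

  // Second pass only normalizes; the statistics are already final.
  for (int64_t col = 0; col < num_cols; ++col) {
    out_row[col] = expf(x_row[col] - running_max) / running_sum;
  }
}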

kernels/cuda/src/row_softmax.cu

@@ -0,0 +1,37 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {
__global__ void row_softmax_kernel(
const float* x,
float* out,
int64_t num_rows,
int64_t num_cols) {
int row = blockIdx.x;
if (row >= num_rows) {
return;
}
// TODO(student): decide whether one block owns one row or one row tile.
// TODO(student): compute the row max for numerical stability.
// TODO(student): compute exp(x - max), reduce the sum, and normalize.
(void)x;
(void)out;
(void)num_rows;
(void)num_cols;
}
torch::Tensor row_softmax_cuda(torch::Tensor x) {
LAB_CHECK_CUDA(x);
LAB_CHECK_CONTIGUOUS(x);
TORCH_CHECK(x.dim() == 2, "row_softmax_cuda expects a 2D tensor");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "row_softmax_cuda currently assumes float32");
TORCH_CHECK(
false,
"TODO(student): implement row_softmax_cuda in kernels/cuda/src/row_softmax.cu.");
return torch::Tensor();
}
} // namespace kernel_lab
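
One possible shape for the block-per-row layout the TODOs hint at, as a hedged sketch: every block reduces its row's max and exp-sum through shared memory, then normalizes. It assumes float32 input, grid.x = num_rows, and exactly kBlockThreads threads per block; the kernel name and kBlockThreads are inventions of the sketch.

#include <cmath>
#include <cstdint>

constexpr int kBlockThreads = 256;  // launch with blockDim.x == kBlockThreads

__global__ void row_softmax_reference_kernel(
    const float* x, float* out, int64_t num_rows, int64_t num_cols) {
  __shared__ float scratch[kBlockThreads];
  int64_t row = blockIdx.x;
  if (row >= num_rows) {
    return;
  }
  const float* x_row = x + row * num_cols;
  float* out_row = out + row * num_cols;
  int tid = threadIdx.x;

  // 1) Row max for numerical stability.
  float local_max = -INFINITY;
  for (int64_t col = tid; col < num_cols; col += kBlockThreads) {
    local_max = fmaxf(local_max, x_row[col]);
  }
  scratch[tid] = local_max;
  __syncthreads();
  for (int stride = kBlockThreads / 2; stride > 0; stride /= 2) {
    if (tid < stride) {
      scratch[tid] = fmaxf(scratch[tid], scratch[tid + stride]);
    }
    __syncthreads();
  }
  float row_max = scratch[0];
  __syncthreads();

  // 2) Sum of exp(x - max), reduced the same way.
  float local_sum = 0.0f;
  for (int64_t col = tid; col < num_cols; col += kBlockThreads) {
    local_sum += expf(x_row[col] - row_max);
  }
  scratch[tid] = local_sum;
  __syncthreads();
  for (int stride = kBlockThreads / 2; stride > 0; stride /= 2) {
    if (tid < stride) {
      scratch[tid] += scratch[tid + stride];
    }
    __syncthreads();
  }
  float row_sum = scratch[0];

  // 3) Normalize.
  for (int64_t col = tid; col < num_cols; col += kBlockThreads) {
    out_row[col] = expf(x_row[col] - row_max) / row_sum;
  }
}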

kernels/cuda/src/tiled_matmul.cu

@@ -0,0 +1,40 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {
__global__ void tiled_matmul_kernel(
const float* a,
const float* b,
float* c,
int64_t m,
int64_t n,
int64_t k) {
// TODO(student): map blockIdx/threadIdx to a C tile.
// TODO(student): cooperatively load A and B tiles into shared memory.
// TODO(student): accumulate partial products across the K dimension.
(void)a;
(void)b;
(void)c;
(void)m;
(void)n;
(void)k;
}
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b) {
LAB_CHECK_CUDA(a);
LAB_CHECK_CUDA(b);
LAB_CHECK_CONTIGUOUS(a);
LAB_CHECK_CONTIGUOUS(b);
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "tiled_matmul_cuda expects 2D tensors");
TORCH_CHECK(a.size(1) == b.size(0), "inner dimensions must match");
TORCH_CHECK(a.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
TORCH_CHECK(b.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
TORCH_CHECK(
false,
"TODO(student): implement tiled_matmul_cuda in kernels/cuda/src/tiled_matmul.cu.");
return torch::Tensor();
}
} // namespace kernel_lab
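
A sketch of the classic shared-memory tiling: each block owns a kTile x kTile tile of C, stages matching tiles of A and B in shared memory, and walks the K dimension tile by tile, zero-padding out-of-range loads so sizes need not be multiples of the tile. kTile, the kernel name, and the launch geometry in the comment are assumptions of the sketch.

#include <cstdint>

constexpr int kTile = 16;
// Launch: dim3 block(kTile, kTile); dim3 grid((n + kTile - 1) / kTile, (m + kTile - 1) / kTile).

__global__ void tiled_matmul_reference_kernel(
    const float* a, const float* b, float* c,
    int64_t m, int64_t n, int64_t k) {
  __shared__ float a_tile[kTile][kTile];
  __shared__ float b_tile[kTile][kTile];

  int64_t row = static_cast<int64_t>(blockIdx.y) * kTile + threadIdx.y;  // row of C
  int64_t col = static_cast<int64_t>(blockIdx.x) * kTile + threadIdx.x;  // column of C
  float acc = 0.0f;

  for (int64_t k0 = 0; k0 < k; k0 += kTile) {
    // Cooperative loads; out-of-range elements become zeros.
    int64_t a_col = k0 + threadIdx.x;
    int64_t b_row = k0 + threadIdx.y;
    a_tile[threadIdx.y][threadIdx.x] = (row < m && a_col < k) ? a[row * k + a_col] : 0.0f;
    b_tile[threadIdx.y][threadIdx.x] = (b_row < k && col < n) ? b[b_row * n + col] : 0.0f;
    __syncthreads();

    // Accumulate the partial dot product contributed by this K tile.
    for (int kk = 0; kk < kTile; ++kk) {
      acc += a_tile[threadIdx.y][kk] * b_tile[kk][threadIdx.x];
    }
    __syncthreads();
  }

  if (row < m && col < n) {
    c[row * n + col] = acc;
  }
}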

kernels/cuda/src/vector_add.cu

@@ -0,0 +1,35 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {
__global__ void vector_add_kernel(
const float* x,
const float* y,
float* out,
int64_t numel) {
int64_t global_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (global_idx >= numel) {
return;
}
(void)x;
(void)y;
(void)out;
(void)numel;
// TODO(student): replace this placeholder with the real vector-add math.
// Hint: one thread should own one element for the first implementation.
}
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y) {
check_cuda_pair(x, y);
LAB_CHECK_SAME_SHAPE(x, y);
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "vector_add_cuda currently assumes float32");
TORCH_CHECK(
false,
"TODO(student): implement vector_add_cuda in kernels/cuda/src/vector_add.cu and then launch the kernel.");
return torch::Tensor();
}
} // namespace kernel_lab
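
For comparison, a minimal sketch of the finished kernel body plus a host wrapper in the shape the stub implies: one thread per element, float32 only, launched on the current CUDA stream. The names vector_add_reference_kernel and vector_add_reference are placeholders, not part of the scaffold.

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <torch/extension.h>
#include <cstdint>

// One thread per element; the bounds check handles the tail block.
__global__ void vector_add_reference_kernel(
    const float* x, const float* y, float* out, int64_t numel) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < numel) {
    out[i] = x[i] + y[i];
  }
}

torch::Tensor vector_add_reference(torch::Tensor x, torch::Tensor y) {
  auto out = torch::empty_like(x);
  const int64_t numel = x.numel();
  if (numel == 0) {
    return out;
  }
  const int threads = 256;
  const unsigned int blocks = static_cast<unsigned int>((numel + threads - 1) / threads);
  vector_add_reference_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), out.data_ptr<float>(), numel);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}

Once the stub is filled in along these lines, torch.ops.kernel_lab.vector_add(x, y) should match x + y for CUDA float32 tensors.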