Initial project scaffold

This commit is contained in:
wjh
2026-04-10 13:15:06 +00:00
commit a4a6b1f1c8
94 changed files with 3964 additions and 0 deletions

2
kernels/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""Kernel modules for Triton and CUDA learning tasks."""

View File

@@ -0,0 +1,29 @@
cmake_minimum_required(VERSION 3.25)
project(kernel_lab LANGUAGES CXX CUDA)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Default to SM 12.0; override with -DCMAKE_CUDA_ARCHITECTURES=... as needed.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES 120 CACHE STRING "Target CUDA architectures")
endif()

find_package(Torch REQUIRED)

# binding.cpp uses PYBIND11_MODULE (via torch/extension.h), so the extension
# must also see the Python development headers and link torch_python; linking
# only ${TORCH_LIBRARIES} leaves the pybind11 symbols unresolved at import.
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")

# LibTorch exports compile flags (notably the C++11 ABI setting) that every
# translation unit linking against it must share.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_library(kernel_lab_extension SHARED
  binding/binding.cpp
  src/vector_add.cu
  src/row_softmax.cu
  src/tiled_matmul.cu
  src/online_softmax.cu
  src/flash_attention_fwd.cu
)
target_include_directories(kernel_lab_extension PRIVATE include)
target_link_libraries(kernel_lab_extension PRIVATE "${TORCH_LIBRARIES}" Python::Module)
if(TORCH_PYTHON_LIBRARY)
  target_link_libraries(kernel_lab_extension PRIVATE "${TORCH_PYTHON_LIBRARY}")
endif()
target_compile_features(kernel_lab_extension PRIVATE cxx_std_17)
set_target_properties(kernel_lab_extension PROPERTIES
  # No "lib" prefix so Python can import the module by its target name.
  PREFIX ""
  CUDA_SEPARABLE_COMPILATION ON
)

View File

@@ -0,0 +1,69 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
#include <torch/extension.h>
namespace kernel_lab {

// Dispatch layer: validates tensor placement/layout at the operator boundary
// and forwards to the CUDA launchers declared in common.h. Per-op shape and
// dtype checks live in the corresponding .cu files.

torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
  check_cuda_pair(x, y);
  LAB_CHECK_SAME_SHAPE(x, y);
  return vector_add_cuda(x, y);
}

torch::Tensor row_softmax_dispatch(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  return row_softmax_cuda(x);
}

torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
  check_cuda_pair(a, b);
  return tiled_matmul_cuda(a, b);
}

torch::Tensor online_softmax_dispatch(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  return online_softmax_cuda(x);
}

torch::Tensor flash_attention_fwd_dispatch(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal) {
  LAB_CHECK_CUDA(q);
  LAB_CHECK_CUDA(k);
  LAB_CHECK_CUDA(v);
  // Fix: also require contiguity here, matching every other dispatcher, so
  // layout errors surface at the boundary rather than inside the launcher.
  LAB_CHECK_CONTIGUOUS(q);
  LAB_CHECK_CONTIGUOUS(k);
  LAB_CHECK_CONTIGUOUS(v);
  return flash_attention_fwd_cuda(q, k, v, causal);
}

}  // namespace kernel_lab
// Declares the operator schemas under the `kernel_lab` namespace so the ops
// are reachable as torch.ops.kernel_lab.* independent of backend.
TORCH_LIBRARY(kernel_lab, m) {
  m.def("vector_add(Tensor x, Tensor y) -> Tensor");
  m.def("row_softmax(Tensor x) -> Tensor");
  m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
  m.def("online_softmax(Tensor x) -> Tensor");
  // `causal` defaults to False in the schema, mirroring the C++ dispatcher.
  m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
}
// Registers the CUDA-backend implementations for the schemas declared above;
// the dispatcher routes torch.ops.kernel_lab.* calls on CUDA tensors here.
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
  m.impl("vector_add", &kernel_lab::vector_add_dispatch);
  m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
  m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
  m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
  m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
}
// Also exposes the dispatchers directly as plain module-level functions,
// bypassing the torch dispatcher registration above — handy for testing the
// validation layer without going through torch.ops.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
  m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
  m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
  m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
  m.def(
      "flash_attention_fwd_dispatch",
      &kernel_lab::flash_attention_fwd_dispatch,
      "Flash attention forward dispatch");
}

View File

@@ -0,0 +1,17 @@
#pragma once

#include <torch/extension.h>

namespace kernel_lab {

// Forward declarations of the host-side CUDA launchers defined in src/*.cu.
// Each launcher validates its inputs; in the current scaffold all of them
// throw via TORCH_CHECK(false, ...) until the student implements them.

// Elementwise sum of two same-shape tensors.
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);

// Row-wise softmax of a 2D tensor.
torch::Tensor row_softmax_cuda(torch::Tensor x);

// Matrix product using a shared-memory tiled kernel.
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);

// Row-wise softmax using the single-pass online recurrence.
torch::Tensor online_softmax_cuda(torch::Tensor x);

// FlashAttention-style fused attention forward; `causal` enables masking.
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal);

}  // namespace kernel_lab

View File

@@ -0,0 +1,15 @@
#pragma once

#include <torch/extension.h>

// Input-validation helpers shared by the dispatch layer and the CUDA
// launchers. Each macro raises through TORCH_CHECK, using the preprocessor
// stringizer (#x) to name the offending argument in the error message.
#define LAB_CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define LAB_CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define LAB_CHECK_SAME_SHAPE(x, y) TORCH_CHECK((x).sizes() == (y).sizes(), #x " and " #y " must have the same shape")

// Convenience check for binary ops: both operands must be CUDA-resident and
// contiguous. Note the error text names the helper's parameters ("x"/"y").
inline void check_cuda_pair(const torch::Tensor& x, const torch::Tensor& y) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CUDA(y);
  LAB_CHECK_CONTIGUOUS(x);
  LAB_CHECK_CONTIGUOUS(y);
}

View File

@@ -0,0 +1,54 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {

// FlashAttention forward skeleton. Intended design (per the TODOs): each
// block owns a batch/head/query tile, streams K/V tiles, and maintains an
// online-softmax running max/sum so the output needs only a single pass.
__global__ void flash_attention_fwd_kernel(
    const float* q,
    const float* k,
    const float* v,
    float* out,
    int64_t batch,
    int64_t heads,
    int64_t seq_len,
    int64_t head_dim,
    bool causal) {
  // Void casts only silence unused-parameter warnings while the body is a
  // stub; remove them together with this comment when implementing.
  (void)q;
  (void)k;
  (void)v;
  (void)out;
  (void)batch;
  (void)heads;
  (void)seq_len;
  (void)head_dim;
  (void)causal;
  // TODO(student): assign each block to a batch/head/query tile.
  // TODO(student): cooperatively load K/V tiles.
  // TODO(student): compute score blocks and apply causal masking when requested.
  // TODO(student): maintain online softmax state and accumulate the output tile.
}

// Host launcher: validates [batch, heads, seq, dim] float32 CUDA inputs with
// identical q/k/v shapes (self-attention layout), then throws until the
// student implements the launch.
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal) {
  LAB_CHECK_CUDA(q);
  LAB_CHECK_CUDA(k);
  LAB_CHECK_CUDA(v);
  LAB_CHECK_CONTIGUOUS(q);
  LAB_CHECK_CONTIGUOUS(k);
  LAB_CHECK_CONTIGUOUS(v);
  TORCH_CHECK(q.sizes() == k.sizes(), "q and k must match");
  TORCH_CHECK(q.sizes() == v.sizes(), "q and v must match");
  TORCH_CHECK(q.dim() == 4, "flash_attention_fwd_cuda expects [batch, heads, seq, dim]");
  TORCH_CHECK(q.scalar_type() == torch::kFloat32, "flash_attention_fwd_cuda currently assumes float32");
  TORCH_CHECK(
      false,
      "TODO(student): implement flash_attention_fwd_cuda in kernels/cuda/src/flash_attention_fwd.cu.");
  // Unreachable (the TORCH_CHECK above always throws); satisfies the return type.
  return torch::Tensor();
}

}  // namespace kernel_lab

View File

@@ -0,0 +1,36 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {

// Streaming (online) softmax skeleton: blockIdx.x selects the row; columns
// are meant to be processed in tiles with a running max/sum recurrence.
__global__ void online_softmax_kernel(
    const float* x,
    float* out,
    int64_t num_rows,
    int64_t num_cols) {
  int row = blockIdx.x;
  // Guard against a grid launched with more blocks than rows.
  if (row >= num_rows) {
    return;
  }
  // TODO(student): maintain running max and running sum across column tiles.
  // TODO(student): write the normalized row after finishing the recurrence.
  // Void casts only silence unused-parameter warnings while this is a stub.
  (void)x;
  (void)out;
  (void)num_rows;
  (void)num_cols;
}

// Host launcher: validates a contiguous 2D float32 CUDA tensor, then throws
// until the student implements the launch.
torch::Tensor online_softmax_cuda(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  TORCH_CHECK(x.dim() == 2, "online_softmax_cuda expects a 2D tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "online_softmax_cuda currently assumes float32");
  TORCH_CHECK(
      false,
      "TODO(student): implement online_softmax_cuda in kernels/cuda/src/online_softmax.cu.");
  // Unreachable; satisfies the return type.
  return torch::Tensor();
}

}  // namespace kernel_lab

View File

@@ -0,0 +1,37 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {

// Row-wise softmax skeleton: blockIdx.x selects the row (whether a block owns
// a whole row or a row tile is left to the student — see the first TODO).
__global__ void row_softmax_kernel(
    const float* x,
    float* out,
    int64_t num_rows,
    int64_t num_cols) {
  int row = blockIdx.x;
  // Guard against a grid launched with more blocks than rows.
  if (row >= num_rows) {
    return;
  }
  // TODO(student): decide whether one block owns one row or one row tile.
  // TODO(student): compute the row max for numerical stability.
  // TODO(student): compute exp(x - max), reduce the sum, and normalize.
  // Void casts only silence unused-parameter warnings while this is a stub.
  (void)x;
  (void)out;
  (void)num_rows;
  (void)num_cols;
}

// Host launcher: validates a contiguous 2D float32 CUDA tensor, then throws
// until the student implements the launch.
torch::Tensor row_softmax_cuda(torch::Tensor x) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CONTIGUOUS(x);
  TORCH_CHECK(x.dim() == 2, "row_softmax_cuda expects a 2D tensor");
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "row_softmax_cuda currently assumes float32");
  TORCH_CHECK(
      false,
      "TODO(student): implement row_softmax_cuda in kernels/cuda/src/row_softmax.cu.");
  // Unreachable; satisfies the return type.
  return torch::Tensor();
}

}  // namespace kernel_lab

View File

@@ -0,0 +1,40 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {

// Shared-memory tiled matmul skeleton computing C[m,n] = A[m,k] @ B[k,n]
// (row-major, per the contiguity checks in the launcher below).
__global__ void tiled_matmul_kernel(
    const float* a,
    const float* b,
    float* c,
    int64_t m,
    int64_t n,
    int64_t k) {
  // TODO(student): map blockIdx/threadIdx to a C tile.
  // TODO(student): cooperatively load A and B tiles into shared memory.
  // TODO(student): accumulate partial products across the K dimension.
  // Void casts only silence unused-parameter warnings while this is a stub.
  (void)a;
  (void)b;
  (void)c;
  (void)m;
  (void)n;
  (void)k;
}

// Host launcher: validates contiguous 2D float32 CUDA operands with matching
// inner dimensions, then throws until the student implements the launch.
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b) {
  LAB_CHECK_CUDA(a);
  LAB_CHECK_CUDA(b);
  LAB_CHECK_CONTIGUOUS(a);
  LAB_CHECK_CONTIGUOUS(b);
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "tiled_matmul_cuda expects 2D tensors");
  TORCH_CHECK(a.size(1) == b.size(0), "inner dimensions must match");
  TORCH_CHECK(a.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
  TORCH_CHECK(b.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
  TORCH_CHECK(
      false,
      "TODO(student): implement tiled_matmul_cuda in kernels/cuda/src/tiled_matmul.cu.");
  // Unreachable; satisfies the return type.
  return torch::Tensor();
}

}  // namespace kernel_lab

View File

@@ -0,0 +1,35 @@
#include "../include/common.h"
#include "../include/cuda_utils.h"
namespace kernel_lab {

// Elementwise add skeleton: one thread per element (see the hint below).
__global__ void vector_add_kernel(
    const float* x,
    const float* y,
    float* out,
    int64_t numel) {
  // Flatten (block, thread) into one element index; the 64-bit cast avoids
  // overflow for tensors with more than 2^31 elements.
  int64_t global_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (global_idx >= numel) {
    return;
  }
  // Void casts only silence unused-parameter warnings while this is a stub.
  (void)x;
  (void)y;
  (void)out;
  (void)numel;
  // TODO(student): replace this placeholder with the real vector-add math.
  // Hint: one thread should own one element for the first implementation.
}

// Host launcher: validates CUDA/contiguous/same-shape float32 inputs, then
// throws until the student implements the launch.
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y) {
  check_cuda_pair(x, y);
  LAB_CHECK_SAME_SHAPE(x, y);
  TORCH_CHECK(x.scalar_type() == torch::kFloat32, "vector_add_cuda currently assumes float32");
  TORCH_CHECK(
      false,
      "TODO(student): implement vector_add_cuda in kernels/cuda/src/vector_add.cu and then launch the kernel.");
  // Unreachable; satisfies the return type.
  return torch::Tensor();
}

}  // namespace kernel_lab

View File

@@ -0,0 +1,2 @@
"""Triton learner skeletons."""

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import torch
try:
import triton
import triton.language as tl
except ImportError: # pragma: no cover - depends on local environment
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:

    @triton.jit
    def flash_attention_fwd_kernel(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        seq_len,
        head_dim,
        stride_q_batch,
        stride_q_head,
        stride_q_seq,
        stride_q_dim,
        stride_k_batch,
        stride_k_head,
        stride_k_seq,
        stride_k_dim,
        stride_v_batch,
        stride_v_head,
        stride_v_seq,
        stride_v_dim,
        stride_out_batch,
        stride_out_head,
        stride_out_seq,
        stride_out_dim,
        causal,
        block_q: tl.constexpr,
        block_k: tl.constexpr,
        block_d: tl.constexpr,
    ):
        """FlashAttention forward skeleton.

        Launch grid (intended, per the two program ids below): axis 0 indexes
        query tiles, axis 1 indexes flattened (batch, head) pairs. Per-tensor
        strides are passed explicitly so no contiguity is assumed.
        """
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)
        # TODO(student): map pid_q and pid_bh to a batch/head/query tile.
        # TODO(student): load Q, K, and V blocks.
        # TODO(student): compute scores for the current block pair.
        # TODO(student): apply optional causal masking.
        # TODO(student): update online softmax state and accumulate the output block.
        # TODO(student): store the final output tile.
        pass
def triton_flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_q: int = 64,
    block_k: int = 64,
) -> torch.Tensor:
    """Student entrypoint for the Triton FlashAttention forward task.

    Validates q/k/v, then (once implemented) launches
    ``flash_attention_fwd_kernel`` and returns the attention output.
    ``block_q`` and ``block_k`` size the query and key tiles.
    """
    # Guard clauses: fail fast on environment and input problems.
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if not (q.shape == k.shape == v.shape):
        raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
    if q.ndim != 4:
        raise ValueError("expected [batch, heads, seq, dim] inputs")
    if not (q.is_cuda and k.is_cuda and v.is_cuda):
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import torch
try:
import triton
import triton.language as tl
except ImportError: # pragma: no cover - depends on local environment
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:

    @triton.jit
    def online_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Online softmax skeleton: program id on axis 0 selects the row;
        columns are meant to be scanned in ``block_size`` chunks while keeping
        a running max/sum."""
        row_idx = tl.program_id(axis=0)
        # TODO(student): maintain running max and running sum for this row.
        # TODO(student): process the row in blocks rather than assuming all columns fit at once.
        # TODO(student): write the final normalized probabilities.
        pass
def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Student entrypoint for the Triton online (streaming) softmax task.

    Validates ``x``, then (once implemented) launches ``online_softmax_kernel``
    with one program per row, scanning columns in ``block_size`` chunks.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    # CUDA inputs fall through to the (not yet written) launch path.
    if x.is_cuda:
        raise NotImplementedError("TODO(student): implement online softmax in Triton.")
    raise ValueError("Triton kernels in this lab expect CUDA tensors.")

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import torch
try:
import triton
import triton.language as tl
except ImportError: # pragma: no cover - depends on local environment
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:

    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Row softmax skeleton: program id on axis 0 selects the row;
        ``col_offsets`` spans one ``block_size`` chunk of columns."""
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # TODO(student): convert row_idx and col_offsets into pointers for this row.
        # TODO(student): load a row with masking.
        # TODO(student): subtract the row max for stability.
        # TODO(student): exponentiate, sum, and normalize.
        # TODO(student): store the normalized row.
        pass
def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Student entrypoint for the Triton row-wise softmax task.

    Validates ``x``, then (once implemented) launches ``row_softmax_kernel``
    with one program per row and ``block_size`` columns per load.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    # CUDA inputs fall through to the (not yet written) launch path.
    if x.is_cuda:
        raise NotImplementedError("TODO(student): implement row-wise softmax launch logic.")
    raise ValueError("Triton kernels in this lab expect CUDA tensors.")

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import torch
try:
import triton
import triton.language as tl
except ImportError: # pragma: no cover - depends on local environment
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:

    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        """Tiled matmul skeleton: program (pid_m, pid_n) on the 2D grid owns
        one (block_m x block_n) tile of C = A @ B."""
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        # TODO(student): compute the tile owned by this program instance.
        # TODO(student): loop over K tiles and accumulate partial products.
        # TODO(student): use masking on edge tiles.
        # TODO(student): store the output tile.
        pass
def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Student entrypoint for the tiled Triton matmul task.

    Validates the operands, then (once implemented) launches
    ``tiled_matmul_kernel`` on a 2D grid of (block_m x block_n) output tiles,
    stepping the reduction in ``block_k`` chunks.
    """
    # Guard clauses: fail fast on environment and input problems.
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if not (a.ndim == 2 and b.ndim == 2):
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not (a.is_cuda and b.is_cuda):
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.")

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import torch
try:
import triton
import triton.language as tl
except ImportError: # pragma: no cover - depends on local environment
triton = None
tl = None
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:

    @triton.jit
    def vector_add_kernel(
        x_ptr,
        y_ptr,
        out_ptr,
        num_elements,
        block_size: tl.constexpr,
    ):
        """Vector add skeleton: each program owns one ``block_size`` slice of
        the flattened inputs; ``mask`` guards the ragged final block."""
        pid = tl.program_id(axis=0)
        offsets = pid * block_size + tl.arange(0, block_size)
        mask = offsets < num_elements
        # TODO(student): load x and y using masked tl.load calls.
        # TODO(student): add the vectors.
        # TODO(student): write the result with tl.store.
        pass
def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Student entrypoint for the Triton vector add task.

    Checks shape and device agreement, then (once implemented) launches
    ``vector_add_kernel`` with one program per ``block_size`` slice and
    returns the elementwise sum.
    """
    # Guard clauses: fail fast on environment and input problems.
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    if not (x.is_cuda and y.is_cuda):
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): launch vector_add_kernel and return the output tensor.")