Initial project scaffold
This commit is contained in:
2
kernels/__init__.py
Normal file
2
kernels/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Kernel modules for Triton and CUDA learning tasks."""
|
||||
|
||||
29
kernels/cuda/CMakeLists.txt
Normal file
29
kernels/cuda/CMakeLists.txt
Normal file
@@ -0,0 +1,29 @@
|
||||
# Build configuration for the kernel-lab Torch CUDA extension.
cmake_minimum_required(VERSION 3.25)
project(kernel_lab LANGUAGES CXX CUDA)

# C++17 for both host and device compilation; PIC is required because the
# extension is built as a shared library loaded from Python.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Default target architecture; callers can override via -DCMAKE_CUDA_ARCHITECTURES.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 120 CACHE STRING "Target CUDA architectures")
endif()

# LibTorch provides <torch/extension.h> and the link libraries below.
find_package(Torch REQUIRED)

add_library(kernel_lab_extension SHARED
binding/binding.cpp
src/vector_add.cu
src/row_softmax.cu
src/tiled_matmul.cu
src/online_softmax.cu
src/flash_attention_fwd.cu
)

target_include_directories(kernel_lab_extension PRIVATE include)
target_link_libraries(kernel_lab_extension PRIVATE "${TORCH_LIBRARIES}")
target_compile_features(kernel_lab_extension PRIVATE cxx_std_17)
# PREFIX "" drops the "lib" prefix so Python can import the module by name;
# separable compilation is needed if device code is later split across units.
set_target_properties(kernel_lab_extension PROPERTIES
PREFIX ""
CUDA_SEPARABLE_COMPILATION ON
)
|
||||
69
kernels/cuda/binding/binding.cpp
Normal file
69
kernels/cuda/binding/binding.cpp
Normal file
@@ -0,0 +1,69 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
torch::Tensor vector_add_dispatch(torch::Tensor x, torch::Tensor y) {
|
||||
check_cuda_pair(x, y);
|
||||
LAB_CHECK_SAME_SHAPE(x, y);
|
||||
return vector_add_cuda(x, y);
|
||||
}
|
||||
|
||||
torch::Tensor row_softmax_dispatch(torch::Tensor x) {
|
||||
LAB_CHECK_CUDA(x);
|
||||
LAB_CHECK_CONTIGUOUS(x);
|
||||
return row_softmax_cuda(x);
|
||||
}
|
||||
|
||||
torch::Tensor tiled_matmul_dispatch(torch::Tensor a, torch::Tensor b) {
|
||||
check_cuda_pair(a, b);
|
||||
return tiled_matmul_cuda(a, b);
|
||||
}
|
||||
|
||||
torch::Tensor online_softmax_dispatch(torch::Tensor x) {
|
||||
LAB_CHECK_CUDA(x);
|
||||
LAB_CHECK_CONTIGUOUS(x);
|
||||
return online_softmax_cuda(x);
|
||||
}
|
||||
|
||||
torch::Tensor flash_attention_fwd_dispatch(
|
||||
torch::Tensor q,
|
||||
torch::Tensor k,
|
||||
torch::Tensor v,
|
||||
bool causal) {
|
||||
LAB_CHECK_CUDA(q);
|
||||
LAB_CHECK_CUDA(k);
|
||||
LAB_CHECK_CUDA(v);
|
||||
return flash_attention_fwd_cuda(q, k, v, causal);
|
||||
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
|
||||
TORCH_LIBRARY(kernel_lab, m) {
|
||||
m.def("vector_add(Tensor x, Tensor y) -> Tensor");
|
||||
m.def("row_softmax(Tensor x) -> Tensor");
|
||||
m.def("tiled_matmul(Tensor a, Tensor b) -> Tensor");
|
||||
m.def("online_softmax(Tensor x) -> Tensor");
|
||||
m.def("flash_attention_fwd(Tensor q, Tensor k, Tensor v, bool causal=False) -> Tensor");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(kernel_lab, CUDA, m) {
|
||||
m.impl("vector_add", &kernel_lab::vector_add_dispatch);
|
||||
m.impl("row_softmax", &kernel_lab::row_softmax_dispatch);
|
||||
m.impl("tiled_matmul", &kernel_lab::tiled_matmul_dispatch);
|
||||
m.impl("online_softmax", &kernel_lab::online_softmax_dispatch);
|
||||
m.impl("flash_attention_fwd", &kernel_lab::flash_attention_fwd_dispatch);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("vector_add_dispatch", &kernel_lab::vector_add_dispatch, "Vector add dispatch");
|
||||
m.def("row_softmax_dispatch", &kernel_lab::row_softmax_dispatch, "Row softmax dispatch");
|
||||
m.def("tiled_matmul_dispatch", &kernel_lab::tiled_matmul_dispatch, "Tiled matmul dispatch");
|
||||
m.def("online_softmax_dispatch", &kernel_lab::online_softmax_dispatch, "Online softmax dispatch");
|
||||
m.def(
|
||||
"flash_attention_fwd_dispatch",
|
||||
&kernel_lab::flash_attention_fwd_dispatch,
|
||||
"Flash attention forward dispatch");
|
||||
}
|
||||
17
kernels/cuda/include/common.h
Normal file
17
kernels/cuda/include/common.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#pragma once

#include <torch/extension.h>

namespace kernel_lab {

// Forward declarations for the CUDA entry points implemented in
// kernels/cuda/src/*.cu. All wrappers validate that their inputs are
// contiguous float32 CUDA tensors before launching.

torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y);
torch::Tensor row_softmax_cuda(torch::Tensor x);
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b);
torch::Tensor online_softmax_cuda(torch::Tensor x);
// q/k/v are expected as [batch, heads, seq, dim]; `causal` requests masking.
torch::Tensor flash_attention_fwd_cuda(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    bool causal);

} // namespace kernel_lab
|
||||
15
kernels/cuda/include/cuda_utils.h
Normal file
15
kernels/cuda/include/cuda_utils.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#pragma once

#include <torch/extension.h>

// Validation helpers shared by the host-side dispatch/wrapper functions.
// Each macro raises a c10::Error (TORCH_CHECK) with the offending
// expression name stringified into the message.
#define LAB_CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define LAB_CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define LAB_CHECK_SAME_SHAPE(x, y) TORCH_CHECK((x).sizes() == (y).sizes(), #x " and " #y " must have the same shape")

// Require both tensors to be CUDA-resident and contiguous.
inline void check_cuda_pair(const torch::Tensor& x, const torch::Tensor& y) {
  LAB_CHECK_CUDA(x);
  LAB_CHECK_CUDA(y);
  LAB_CHECK_CONTIGUOUS(x);
  LAB_CHECK_CONTIGUOUS(y);
}
|
||||
|
||||
54
kernels/cuda/src/flash_attention_fwd.cu
Normal file
54
kernels/cuda/src/flash_attention_fwd.cu
Normal file
@@ -0,0 +1,54 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
__global__ void flash_attention_fwd_kernel(
|
||||
const float* q,
|
||||
const float* k,
|
||||
const float* v,
|
||||
float* out,
|
||||
int64_t batch,
|
||||
int64_t heads,
|
||||
int64_t seq_len,
|
||||
int64_t head_dim,
|
||||
bool causal) {
|
||||
(void)q;
|
||||
(void)k;
|
||||
(void)v;
|
||||
(void)out;
|
||||
(void)batch;
|
||||
(void)heads;
|
||||
(void)seq_len;
|
||||
(void)head_dim;
|
||||
(void)causal;
|
||||
|
||||
// TODO(student): assign each block to a batch/head/query tile.
|
||||
// TODO(student): cooperatively load K/V tiles.
|
||||
// TODO(student): compute score blocks and apply causal masking when requested.
|
||||
// TODO(student): maintain online softmax state and accumulate the output tile.
|
||||
}
|
||||
|
||||
torch::Tensor flash_attention_fwd_cuda(
|
||||
torch::Tensor q,
|
||||
torch::Tensor k,
|
||||
torch::Tensor v,
|
||||
bool causal) {
|
||||
LAB_CHECK_CUDA(q);
|
||||
LAB_CHECK_CUDA(k);
|
||||
LAB_CHECK_CUDA(v);
|
||||
LAB_CHECK_CONTIGUOUS(q);
|
||||
LAB_CHECK_CONTIGUOUS(k);
|
||||
LAB_CHECK_CONTIGUOUS(v);
|
||||
TORCH_CHECK(q.sizes() == k.sizes(), "q and k must match");
|
||||
TORCH_CHECK(q.sizes() == v.sizes(), "q and v must match");
|
||||
TORCH_CHECK(q.dim() == 4, "flash_attention_fwd_cuda expects [batch, heads, seq, dim]");
|
||||
TORCH_CHECK(q.scalar_type() == torch::kFloat32, "flash_attention_fwd_cuda currently assumes float32");
|
||||
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"TODO(student): implement flash_attention_fwd_cuda in kernels/cuda/src/flash_attention_fwd.cu.");
|
||||
return torch::Tensor();
|
||||
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
36
kernels/cuda/src/online_softmax.cu
Normal file
36
kernels/cuda/src/online_softmax.cu
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
__global__ void online_softmax_kernel(
|
||||
const float* x,
|
||||
float* out,
|
||||
int64_t num_rows,
|
||||
int64_t num_cols) {
|
||||
int row = blockIdx.x;
|
||||
if (row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(student): maintain running max and running sum across column tiles.
|
||||
// TODO(student): write the normalized row after finishing the recurrence.
|
||||
(void)x;
|
||||
(void)out;
|
||||
(void)num_rows;
|
||||
(void)num_cols;
|
||||
}
|
||||
|
||||
torch::Tensor online_softmax_cuda(torch::Tensor x) {
|
||||
LAB_CHECK_CUDA(x);
|
||||
LAB_CHECK_CONTIGUOUS(x);
|
||||
TORCH_CHECK(x.dim() == 2, "online_softmax_cuda expects a 2D tensor");
|
||||
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "online_softmax_cuda currently assumes float32");
|
||||
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"TODO(student): implement online_softmax_cuda in kernels/cuda/src/online_softmax.cu.");
|
||||
return torch::Tensor();
|
||||
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
37
kernels/cuda/src/row_softmax.cu
Normal file
37
kernels/cuda/src/row_softmax.cu
Normal file
@@ -0,0 +1,37 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
__global__ void row_softmax_kernel(
|
||||
const float* x,
|
||||
float* out,
|
||||
int64_t num_rows,
|
||||
int64_t num_cols) {
|
||||
int row = blockIdx.x;
|
||||
if (row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(student): decide whether one block owns one row or one row tile.
|
||||
// TODO(student): compute the row max for numerical stability.
|
||||
// TODO(student): compute exp(x - max), reduce the sum, and normalize.
|
||||
(void)x;
|
||||
(void)out;
|
||||
(void)num_rows;
|
||||
(void)num_cols;
|
||||
}
|
||||
|
||||
torch::Tensor row_softmax_cuda(torch::Tensor x) {
|
||||
LAB_CHECK_CUDA(x);
|
||||
LAB_CHECK_CONTIGUOUS(x);
|
||||
TORCH_CHECK(x.dim() == 2, "row_softmax_cuda expects a 2D tensor");
|
||||
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "row_softmax_cuda currently assumes float32");
|
||||
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"TODO(student): implement row_softmax_cuda in kernels/cuda/src/row_softmax.cu.");
|
||||
return torch::Tensor();
|
||||
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
40
kernels/cuda/src/tiled_matmul.cu
Normal file
40
kernels/cuda/src/tiled_matmul.cu
Normal file
@@ -0,0 +1,40 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
__global__ void tiled_matmul_kernel(
|
||||
const float* a,
|
||||
const float* b,
|
||||
float* c,
|
||||
int64_t m,
|
||||
int64_t n,
|
||||
int64_t k) {
|
||||
// TODO(student): map blockIdx/threadIdx to a C tile.
|
||||
// TODO(student): cooperatively load A and B tiles into shared memory.
|
||||
// TODO(student): accumulate partial products across the K dimension.
|
||||
(void)a;
|
||||
(void)b;
|
||||
(void)c;
|
||||
(void)m;
|
||||
(void)n;
|
||||
(void)k;
|
||||
}
|
||||
|
||||
torch::Tensor tiled_matmul_cuda(torch::Tensor a, torch::Tensor b) {
|
||||
LAB_CHECK_CUDA(a);
|
||||
LAB_CHECK_CUDA(b);
|
||||
LAB_CHECK_CONTIGUOUS(a);
|
||||
LAB_CHECK_CONTIGUOUS(b);
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "tiled_matmul_cuda expects 2D tensors");
|
||||
TORCH_CHECK(a.size(1) == b.size(0), "inner dimensions must match");
|
||||
TORCH_CHECK(a.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
|
||||
TORCH_CHECK(b.scalar_type() == torch::kFloat32, "tiled_matmul_cuda currently assumes float32");
|
||||
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"TODO(student): implement tiled_matmul_cuda in kernels/cuda/src/tiled_matmul.cu.");
|
||||
return torch::Tensor();
|
||||
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
35
kernels/cuda/src/vector_add.cu
Normal file
35
kernels/cuda/src/vector_add.cu
Normal file
@@ -0,0 +1,35 @@
|
||||
#include "../include/common.h"
|
||||
#include "../include/cuda_utils.h"
|
||||
|
||||
namespace kernel_lab {
|
||||
|
||||
__global__ void vector_add_kernel(
|
||||
const float* x,
|
||||
const float* y,
|
||||
float* out,
|
||||
int64_t numel) {
|
||||
int64_t global_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
if (global_idx >= numel) {
|
||||
return;
|
||||
}
|
||||
|
||||
(void)x;
|
||||
(void)y;
|
||||
(void)out;
|
||||
(void)numel;
|
||||
// TODO(student): replace this placeholder with the real vector-add math.
|
||||
// Hint: one thread should own one element for the first implementation.
|
||||
}
|
||||
|
||||
torch::Tensor vector_add_cuda(torch::Tensor x, torch::Tensor y) {
|
||||
check_cuda_pair(x, y);
|
||||
LAB_CHECK_SAME_SHAPE(x, y);
|
||||
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "vector_add_cuda currently assumes float32");
|
||||
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"TODO(student): implement vector_add_cuda in kernels/cuda/src/vector_add.cu and then launch the kernel.");
|
||||
return torch::Tensor();
|
||||
}
|
||||
|
||||
} // namespace kernel_lab
|
||||
2
kernels/triton/__init__.py
Normal file
2
kernels/triton/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Triton learner skeletons."""
|
||||
|
||||
75
kernels/triton/flash_attention_fwd.py
Normal file
75
kernels/triton/flash_attention_fwd.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


# True when the optional triton dependency imported successfully.
TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def flash_attention_fwd_kernel(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        seq_len,
        head_dim,
        stride_q_batch,
        stride_q_head,
        stride_q_seq,
        stride_q_dim,
        stride_k_batch,
        stride_k_head,
        stride_k_seq,
        stride_k_dim,
        stride_v_batch,
        stride_v_head,
        stride_v_seq,
        stride_v_dim,
        stride_out_batch,
        stride_out_head,
        stride_out_seq,
        stride_out_dim,
        causal,
        block_q: tl.constexpr,
        block_k: tl.constexpr,
        block_d: tl.constexpr,
    ):
        """Skeleton: axis 0 walks query tiles, axis 1 walks (batch, head) pairs."""
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)
        # TODO(student): map pid_q and pid_bh to a batch/head/query tile.
        # TODO(student): load Q, K, and V blocks.
        # TODO(student): compute scores for the current block pair.
        # TODO(student): apply optional causal masking.
        # TODO(student): update online softmax state and accumulate the output block.
        # TODO(student): store the final output tile.
        pass


def triton_flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_q: int = 64,
    block_k: int = 64,
) -> torch.Tensor:
    """Student entrypoint for the FlashAttention forward task.

    Validates the inputs, then hands off to the (not yet implemented)
    kernel launch. Expects ``[batch, heads, seq, dim]`` CUDA tensors.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if not (q.shape == k.shape == v.shape):
        raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
    if q.ndim != 4:
        raise ValueError("expected [batch, heads, seq, dim] inputs")
    if not (q.is_cuda and k.is_cuda and v.is_cuda):
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")
|
||||
|
||||
42
kernels/triton/online_softmax.py
Normal file
42
kernels/triton/online_softmax.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


# True when the optional triton dependency imported successfully.
TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def online_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Skeleton: one program instance normalizes one row in a single pass."""
        row_idx = tl.program_id(axis=0)
        # TODO(student): maintain running max and running sum for this row.
        # TODO(student): process the row in blocks rather than assuming all columns fit at once.
        # TODO(student): write the final normalized probabilities.
        pass


def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Student entrypoint for the online-softmax task.

    Validates a 2D CUDA input, then hands off to the (not yet
    implemented) kernel launch.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement online softmax in Triton.")
|
||||
|
||||
45
kernels/triton/row_softmax.py
Normal file
45
kernels/triton/row_softmax.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


# True when the optional triton dependency imported successfully.
TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Skeleton: one program instance computes softmax over one row."""
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # TODO(student): convert row_idx and col_offsets into pointers for this row.
        # TODO(student): load a row with masking.
        # TODO(student): subtract the row max for stability.
        # TODO(student): exponentiate, sum, and normalize.
        # TODO(student): store the normalized row.
        pass


def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Student entrypoint for the row-wise softmax task.

    Validates a 2D CUDA input, then hands off to the (not yet
    implemented) kernel launch.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement row-wise softmax launch logic.")
|
||||
|
||||
61
kernels/triton/tiled_matmul.py
Normal file
61
kernels/triton/tiled_matmul.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


# True when the optional triton dependency imported successfully.
TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        """Skeleton: the 2D program grid maps axis 0 / axis 1 to C tiles."""
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        # TODO(student): compute the tile owned by this program instance.
        # TODO(student): loop over K tiles and accumulate partial products.
        # TODO(student): use masking on edge tiles.
        # TODO(student): store the output tile.
        pass


def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Student entrypoint for the tiled matmul task.

    Validates two 2D CUDA tensors with compatible inner dimensions, then
    hands off to the (not yet implemented) kernel launch.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not (a.is_cuda and b.is_cuda):
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.")
|
||||
|
||||
44
kernels/triton/vector_add.py
Normal file
44
kernels/triton/vector_add.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


# True when the optional triton dependency imported successfully.
TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def vector_add_kernel(
        x_ptr,
        y_ptr,
        out_ptr,
        num_elements,
        block_size: tl.constexpr,
    ):
        """Skeleton: each program instance owns one contiguous block of elements."""
        pid = tl.program_id(axis=0)
        offsets = pid * block_size + tl.arange(0, block_size)
        mask = offsets < num_elements
        # TODO(student): load x and y using masked tl.load calls.
        # TODO(student): add the vectors.
        # TODO(student): write the result with tl.store.
        pass


def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Student entrypoint for the Triton vector add task."""
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    if not (x.is_cuda and y.is_cuda):
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): launch vector_add_kernel and return the output tensor.")
|
||||
|
||||
Reference in New Issue
Block a user