Initial project scaffold
2  kernels/triton/__init__.py  Normal file
@@ -0,0 +1,2 @@
"""Triton learner skeletons."""
75  kernels/triton/flash_attention_fwd.py  Normal file
@@ -0,0 +1,75 @@
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def flash_attention_fwd_kernel(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        seq_len,
        head_dim,
        stride_q_batch,
        stride_q_head,
        stride_q_seq,
        stride_q_dim,
        stride_k_batch,
        stride_k_head,
        stride_k_seq,
        stride_k_dim,
        stride_v_batch,
        stride_v_head,
        stride_v_seq,
        stride_v_dim,
        stride_out_batch,
        stride_out_head,
        stride_out_seq,
        stride_out_dim,
        causal,
        block_q: tl.constexpr,
        block_k: tl.constexpr,
        block_d: tl.constexpr,
    ):
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)
        # TODO(student): map pid_q and pid_bh to a batch/head/query tile.
        # TODO(student): load Q, K, and V blocks.
        # TODO(student): compute scores for the current block pair.
        # TODO(student): apply optional causal masking.
        # TODO(student): update online softmax state and accumulate the output block.
        # TODO(student): store the final output tile.
        pass


def triton_flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_q: int = 64,
    block_k: int = 64,
) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
    if q.ndim != 4:
        raise ValueError("expected [batch, heads, seq, dim] inputs")
    if not q.is_cuda or not k.is_cuda or not v.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")
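A hedged reference to test a finished kernel against: this is not the scaffold's answer, just a naive PyTorch attention forward. The helper name reference_attention_fwd and the 1/sqrt(head_dim) scaling are assumptions rather than anything the skeleton defines; inputs follow the [batch, heads, seq, dim] layout the wrapper validates above.

import torch

def reference_attention_fwd(q, k, v, causal=False):
    # Naive O(seq^2) attention; useful only as a correctness oracle on small shapes.
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if causal:
        seq = q.shape[-2]
        upper = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(upper, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)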
42  kernels/triton/online_softmax.py  Normal file
@@ -0,0 +1,42 @@
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def online_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        row_idx = tl.program_id(axis=0)
        # TODO(student): maintain running max and running sum for this row.
        # TODO(student): process the row in blocks rather than assuming all columns fit at once.
        # TODO(student): write the final normalized probabilities.
        pass


def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement online softmax in Triton.")
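For the recurrence the TODOs describe, a hedged PyTorch sketch of the streaming max/sum update can serve as a mental model and as a test oracle against torch.softmax. reference_online_softmax is a hypothetical helper, not part of the scaffold; the key step is rescaling the running sum by exp(old_max - new_max) before adding each chunk's contribution.

import torch

def reference_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Sweep each row in chunks, carrying a running max m and a rescaled running sum d.
    m = torch.full((x.shape[0],), float("-inf"), dtype=x.dtype, device=x.device)
    d = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for start in range(0, x.shape[1], block_size):
        blk = x[:, start:start + block_size]
        new_m = torch.maximum(m, blk.max(dim=1).values)
        d = d * torch.exp(m - new_m) + torch.exp(blk - new_m[:, None]).sum(dim=1)
        m = new_m
    # A second pass (vectorized here) normalizes with the final statistics.
    return torch.exp(x - m[:, None]) / d[:, None]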
45  kernels/triton/row_softmax.py  Normal file
@@ -0,0 +1,45 @@
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # TODO(student): convert row_idx and col_offsets into pointers for this row.
        # TODO(student): load a row with masking.
        # TODO(student): subtract the row max for stability.
        # TODO(student): exponentiate, sum, and normalize.
        # TODO(student): store the normalized row.
        pass


def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement row-wise softmax launch logic.")
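Since the wrapper's remaining TODO is launch logic, here is a minimal launch sketch, assuming one program per row and a block wide enough to cover the whole row. The helper name _example_row_softmax_launch is illustrative only; it reuses torch, triton, and row_softmax_kernel from the file above and shows one possible grid and argument layout, not the lab's required one.

def _example_row_softmax_launch(x: torch.Tensor) -> torch.Tensor:
    # One program per row; block_size must be a power of two >= num_cols so a
    # single masked load can cover the row.
    out = torch.empty_like(x)
    num_rows, num_cols = x.shape
    block = triton.next_power_of_2(num_cols)
    grid = (num_rows,)
    row_softmax_kernel[grid](x, out, num_cols, x.stride(0), out.stride(0), block_size=block)
    return out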
61  kernels/triton/tiled_matmul.py  Normal file
@@ -0,0 +1,61 @@
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        # TODO(student): compute the tile owned by this program instance.
        # TODO(student): loop over K tiles and accumulate partial products.
        # TODO(student): use masking on edge tiles.
        # TODO(student): store the output tile.
        pass


def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not a.is_cuda or not b.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.")
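The tiling the kernel TODOs ask for (one output tile per program, a loop over K chunks, edge handling) can be mimicked in plain PyTorch as a hedged reference. reference_tiled_matmul below is a hypothetical helper for comparing against a @ b, not the Triton solution; slicing stands in for the masking the real kernel needs on edge tiles.

import torch

def reference_tiled_matmul(a, b, block_m=64, block_n=64, block_k=32):
    # Each (i, j) pair plays the role of one Triton program instance.
    m, k = a.shape
    _, n = b.shape
    c = torch.zeros(m, n, dtype=a.dtype, device=a.device)
    for i in range(0, m, block_m):
        for j in range(0, n, block_n):
            acc = torch.zeros(min(block_m, m - i), min(block_n, n - j),
                              dtype=a.dtype, device=a.device)
            for p in range(0, k, block_k):
                # Slicing clamps at the edges, standing in for masked loads.
                acc += a[i:i + block_m, p:p + block_k] @ b[p:p + block_k, j:j + block_n]
            c[i:i + block_m, j:j + block_n] = acc
    return c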
44  kernels/triton/vector_add.py  Normal file
@@ -0,0 +1,44 @@
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def vector_add_kernel(
        x_ptr,
        y_ptr,
        out_ptr,
        num_elements,
        block_size: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        offsets = pid * block_size + tl.arange(0, block_size)
        mask = offsets < num_elements
        # TODO(student): load x and y using masked tl.load calls.
        # TODO(student): add the vectors.
        # TODO(student): write the result with tl.store.
        pass


def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Student entrypoint for the Triton vector add task."""
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    if not x.is_cuda or not y.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): launch vector_add_kernel and return the output tensor.")
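A minimal launch sketch for the wrapper's final TODO, assuming contiguous inputs and a 1D grid of one program per block_size chunk. The helper name _example_vector_add_launch is illustrative and reuses torch, triton, and vector_add_kernel from the file above.

def _example_vector_add_launch(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    # Assumes contiguous tensors so flat offsets index both inputs correctly.
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, block_size),)
    vector_add_kernel[grid](x, y, out, n, block_size=block_size)
    return out

Inside the kernel, the three remaining TODOs typically reduce to masked tl.load calls for x and y, an elementwise add, and a masked tl.store of the sum.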