Initial project scaffold
This commit is contained in:
61
kernels/triton/tiled_matmul.py
Normal file
61
kernels/triton/tiled_matmul.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
except ImportError: # pragma: no cover - depends on local environment
|
||||
triton = None
|
||||
tl = None
|
||||
|
||||
|
||||
TRITON_AVAILABLE = triton is not None
|
||||
|
||||
|
||||
if TRITON_AVAILABLE:
|
||||
|
||||
@triton.jit
|
||||
def tiled_matmul_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
block_m: tl.constexpr,
|
||||
block_n: tl.constexpr,
|
||||
block_k: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_n = tl.program_id(axis=1)
|
||||
# TODO(student): compute the tile owned by this program instance.
|
||||
# TODO(student): loop over K tiles and accumulate partial products.
|
||||
# TODO(student): use masking on edge tiles.
|
||||
# TODO(student): store the output tile.
|
||||
pass
|
||||
|
||||
|
||||
def triton_tiled_matmul(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
block_m: int = 64,
|
||||
block_n: int = 64,
|
||||
block_k: int = 32,
|
||||
) -> torch.Tensor:
|
||||
if not TRITON_AVAILABLE:
|
||||
raise RuntimeError("Triton is not installed in this environment.")
|
||||
if a.ndim != 2 or b.ndim != 2:
|
||||
raise ValueError("expected two 2D tensors")
|
||||
if a.shape[1] != b.shape[0]:
|
||||
raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
|
||||
if not a.is_cuda or not b.is_cuda:
|
||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||
raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.")
|
||||
|
||||
Reference in New Issue
Block a user