Files
kernel-lab/kernels/triton/tiled_matmul.py
2026-04-10 13:22:19 +00:00

62 lines
1.6 KiB
Python

from __future__ import annotations
import torch
try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None
# True when the optional Triton dependency imported successfully; the
# kernel definition and the launch path below are both gated on this flag.
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:
    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        """Compute one (block_m, block_n) tile of C = A @ B.

        A is (m, k), B is (k, n); all layouts are described by the
        explicit strides, so non-contiguous inputs work. Launched on a
        2D grid of (cdiv(m, block_m), cdiv(n, block_n)) programs.
        """
        # Tile coordinates of this program instance on the launch grid.
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        # Row/column indices this tile covers; these can run past m/n on
        # edge tiles and are masked at every load/store below.
        offs_m = pid_m * block_m + tl.arange(0, block_m)
        offs_n = pid_n * block_n + tl.arange(0, block_n)
        offs_k = tl.arange(0, block_k)
        # Pointers to the first (block_m, block_k) tile of A and the
        # first (block_k, block_n) tile of B.
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
        # Accumulate in float32 regardless of input dtype for accuracy.
        acc = tl.zeros((block_m, block_n), dtype=tl.float32)
        for step in range(0, tl.cdiv(k, block_k)):
            # Masked lanes load 0.0, which contributes nothing to the
            # dot product, so ragged M/N/K edges are handled uniformly.
            k_remaining = k - step * block_k
            a_tile = tl.load(
                a_ptrs,
                mask=(offs_m[:, None] < m) & (offs_k[None, :] < k_remaining),
                other=0.0,
            )
            b_tile = tl.load(
                b_ptrs,
                mask=(offs_k[:, None] < k_remaining) & (offs_n[None, :] < n),
                other=0.0,
            )
            acc += tl.dot(a_tile, b_tile)
            # Advance both tile pointers one K-step.
            a_ptrs += block_k * stride_ak
            b_ptrs += block_k * stride_bk
        # Store the output tile, casting back to C's element type and
        # masking the out-of-range edge rows/columns.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        c_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
        tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Multiply two 2D CUDA tensors using the tiled Triton kernel.

    Args:
        a: Left operand of shape (m, k).
        b: Right operand of shape (k, n).
        block_m: Output-tile rows per program (compile-time constant).
        block_n: Output-tile columns per program (compile-time constant).
        block_k: K-dimension step per accumulation iteration.

    Returns:
        A new tensor of shape (m, n) on ``a``'s device with ``a``'s
        dtype, equal to ``a @ b``.

    Raises:
        RuntimeError: If Triton is not importable in this environment.
        ValueError: If the inputs are not 2D, their inner dimensions
            disagree, or either tensor is not on a CUDA device.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not a.is_cuda or not b.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    m, k = a.shape
    _, n = b.shape
    c = torch.empty((m, n), device=a.device, dtype=a.dtype)
    # One program per (block_m, block_n) output tile; ceiling-divide so
    # ragged edges still get a (masked) program. -(-x // y) is ceil-div.
    grid = (-(-m // block_m), -(-n // block_n))
    tiled_matmul_kernel[grid](
        a,
        b,
        c,
        m,
        n,
        k,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
    )
    return c