"""Lab skeleton: a tiled matrix-multiplication kernel in Triton.

Triton is imported defensively so the module can still be imported (and the
availability flag inspected) on machines without a working Triton install.
"""

from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        # TODO(student): compute the tile owned by this program instance.
        # TODO(student): loop over K tiles and accumulate partial products.
        # TODO(student): use masking on edge tiles.
        # TODO(student): store the output tile.
        pass
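
    # --- Reference sketch, not the assigned solution -----------------------
    # A minimal working version of the kernel above, following the standard
    # Triton tiled-matmul recipe: each program instance owns one
    # (block_m x block_n) tile of C, walks the K dimension in block_k steps,
    # and masks loads/stores on edge tiles. Parameter names mirror the
    # skeleton; the tile layout and the fp32 accumulator are assumptions.
    # Kept as a separate function so the TODO skeleton stays untouched.
    @triton.jit
    def _tiled_matmul_kernel_sketch(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        # Row/column offsets of the C tile owned by this program instance.
        offs_m = pid_m * block_m + tl.arange(0, block_m)
        offs_n = pid_n * block_n + tl.arange(0, block_n)
        offs_k = tl.arange(0, block_k)
        # Pointers to the first (block_m x block_k) tile of A and the first
        # (block_k x block_n) tile of B.
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
        # Accumulate partial products over K in fp32.
        acc = tl.zeros((block_m, block_n), dtype=tl.float32)
        for k_start in range(0, k, block_k):
            a_mask = (offs_m[:, None] < m) & (offs_k[None, :] + k_start < k)
            b_mask = (offs_k[:, None] + k_start < k) & (offs_n[None, :] < n)
            a_tile = tl.load(a_ptrs, mask=a_mask, other=0.0)
            b_tile = tl.load(b_ptrs, mask=b_mask, other=0.0)
            acc += tl.dot(a_tile, b_tile)
            # Advance both tile pointers one K step.
            a_ptrs += block_k * stride_ak
            b_ptrs += block_k * stride_bk
        # Store the output tile, masking rows/columns that fall past the edge.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        c_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
        tl.store(c_ptrs, acc, mask=c_mask)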


def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not a.is_cuda or not b.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.")
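

def _triton_tiled_matmul_sketch(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Host-side launch sketch for ``_tiled_matmul_kernel_sketch``.

    One plausible completion of ``triton_tiled_matmul``, not the assigned
    solution: the 2D launch grid (one program per output tile) and the fp32
    output dtype are assumptions, and the skeleton's input validation is
    omitted here for brevity.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    m, k = a.shape
    _, n = b.shape
    c = torch.empty((m, n), device=a.device, dtype=torch.float32)
    # One program instance per (block_m x block_n) tile of C; ceil-divide so
    # partially filled edge tiles are still covered (and masked in-kernel).
    grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))
    _tiled_matmul_kernel_sketch[grid](
        a, b, c,
        m, n, k,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
    )
    return c


# Quick correctness check against torch.matmul (assumes a CUDA device):
#   a = torch.randn(512, 384, device="cuda")
#   b = torch.randn(384, 256, device="cuda")
#   torch.testing.assert_close(_triton_tiled_matmul_sketch(a, b), a @ b,
#                              rtol=1e-2, atol=1e-2)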