"""Tiled matrix multiplication implemented as a Triton kernel.

Triton is an optional dependency: when it cannot be imported the module still
loads, ``TRITON_AVAILABLE`` is ``False``, and ``triton_tiled_matmul`` raises
``RuntimeError`` when called.
"""

from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


TRITON_AVAILABLE = triton is not None

# Element types the wrapper accepts; the kernel feeds both tiles to ``tl.dot``.
_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)


if TRITON_AVAILABLE:

    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        """Compute one ``block_m x block_n`` tile of ``C = A @ B``.

        Program ``(pid_m, pid_n)`` owns rows ``pid_m * block_m ...`` and
        columns ``pid_n * block_n ...`` of ``C``. ``A`` is ``m x k``, ``B`` is
        ``k x n``; the ``stride_*`` arguments are element strides, so
        non-contiguous inputs are addressed correctly. Partial products are
        accumulated in float32 and cast to the element type of ``c_ptr`` on
        store.
        """
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)

        # Row/column indices of the output tile owned by this program.
        offs_m = pid_m * block_m + tl.arange(0, block_m)
        offs_n = pid_n * block_n + tl.arange(0, block_n)
        offs_k = tl.arange(0, block_k)

        # Edge tiles may extend past the matrix; these masks keep every
        # load/store in bounds.
        row_mask = offs_m[:, None] < m
        col_mask = offs_n[None, :] < n

        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

        acc = tl.zeros((block_m, block_n), dtype=tl.float32)
        for k_start in range(0, k, block_k):
            k_idx = k_start + offs_k
            # Out-of-range elements are loaded as 0 so they do not contribute
            # to the dot product on the last (partial) K tile.
            a_tile = tl.load(
                a_ptrs, mask=row_mask & (k_idx[None, :] < k), other=0.0
            )
            b_tile = tl.load(
                b_ptrs, mask=(k_idx[:, None] < k) & col_mask, other=0.0
            )
            acc += tl.dot(a_tile, b_tile)
            # Advance both tile pointers to the next K tile.
            a_ptrs += block_k * stride_ak
            b_ptrs += block_k * stride_bk

        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(
            c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=row_mask & col_mask
        )


def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Return ``a @ b`` computed by ``tiled_matmul_kernel``.

    Args:
        a: 2D CUDA tensor of shape ``(m, k)``.
        b: 2D CUDA tensor of shape ``(k, n)`` on the same device and with the
            same dtype as ``a``.
        block_m: Tile height of the output handled by one program.
        block_n: Tile width of the output handled by one program.
        block_k: Number of K elements consumed per loop iteration.

    Returns:
        A new ``(m, n)`` tensor on ``a``'s device with ``a``'s dtype.

    Raises:
        RuntimeError: If Triton is not installed.
        ValueError: If the inputs are not 2D, have incompatible shapes, are
            not CUDA tensors, live on different devices, have mismatched or
            unsupported dtypes, or if a block size is not a power of two
            that is at least 16.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not a.is_cuda or not b.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    if a.device != b.device:
        raise ValueError(
            f"expected both tensors on the same device, got {a.device} and {b.device}"
        )
    if a.dtype != b.dtype:
        raise ValueError(f"mismatched dtypes: {a.dtype} and {b.dtype}")
    if a.dtype not in _SUPPORTED_DTYPES:
        raise ValueError(f"unsupported dtype: {a.dtype}")
    for label, size in (
        ("block_m", block_m),
        ("block_n", block_n),
        ("block_k", block_k),
    ):
        # ``tl.arange`` needs power-of-two extents; 16 is used as a
        # conservative lower bound for ``tl.dot`` operand dimensions.
        if size < 16 or (size & (size - 1)) != 0:
            raise ValueError(
                f"{label} must be a power of two >= 16, got {size}"
            )

    m, k = a.shape
    n = b.shape[1]

    # Degenerate shapes: an empty output needs no launch (a zero-sized grid
    # is not launchable), and an empty reduction dimension yields all zeros.
    if m == 0 or n == 0 or k == 0:
        return torch.zeros((m, n), device=a.device, dtype=a.dtype)

    c = torch.empty((m, n), device=a.device, dtype=a.dtype)
    grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))

    # Launch on the tensors' device even when it is not the current one.
    with torch.cuda.device(a.device):
        tiled_matmul_kernel[grid](
            a,
            b,
            c,
            m,
            n,
            k,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            block_m=block_m,
            block_n=block_n,
            block_k=block_k,
        )
    return c