from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

# True when the optional Triton dependency imported successfully.
TRITON_AVAILABLE = triton is not None

if TRITON_AVAILABLE:

    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Numerically stable softmax over one matrix row per program.

        Each program instance (grid axis 0) owns exactly one row and must
        have ``block_size >= num_cols`` so the whole row fits in one block.
        Elements are assumed contiguous within a row (unit column stride).
        """
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # Lanes past the row end are masked; they load -inf so that
        # exp(-inf) == 0 and they drop out of the sum entirely.
        mask = col_offsets < num_cols
        row_ptrs = x_ptr + row_idx * stride_x_row + col_offsets
        row = tl.load(row_ptrs, mask=mask, other=float("-inf"))
        # Subtract the row max before exponentiating to avoid overflow.
        row = row - tl.max(row, axis=0)
        numerator = tl.exp(row)
        denominator = tl.sum(numerator, axis=0)
        result = numerator / denominator
        out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets
        tl.store(out_ptrs, result, mask=mask)


def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Row-wise softmax of a 2D CUDA tensor using the Triton kernel.

    Args:
        x: 2D tensor on a CUDA device.
        block_size: minimum block width; grown automatically to the next
            power of two covering ``x.shape[1]`` when that is larger.

    Returns:
        A new tensor of the same shape where each row sums to 1.

    Raises:
        RuntimeError: if Triton is not installed.
        ValueError: if ``x`` is not 2D or not on a CUDA device.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")

    num_rows, num_cols = x.shape
    # The kernel reads a full row in a single block, so the block must span
    # every column; tl.arange also requires a power-of-two extent.
    block = max(block_size, triton.next_power_of_2(max(num_cols, 1)))
    # Guarantee unit stride within a row, as the kernel assumes.
    x = x.contiguous()
    out = torch.empty_like(x)
    if num_rows == 0:
        return out  # nothing to launch for an empty matrix
    grid = (num_rows,)
    row_softmax_kernel[grid](
        x,
        out,
        num_cols,
        x.stride(0),
        out.stride(0),
        block_size=block,
    )
    return out