"""Row-wise online softmax implemented as a Triton kernel.

The module imports cleanly when Triton is missing; in that case
``TRITON_AVAILABLE`` is ``False`` and ``triton_online_softmax`` raises
``RuntimeError`` when called.
"""

from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def online_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Compute softmax over one row, streaming it in ``block_size`` chunks.

        One program instance handles one row. Elements within a row are
        addressed as ``row_base + column``, i.e. the kernel assumes a column
        stride of 1 for both the input and the output.

        Pass 1 keeps, for every lane of the block, a running maximum and a
        running sum of exponentials that is rescaled whenever the maximum
        grows (the "online" softmax recurrence). Pass 2 re-reads the row and
        writes ``exp(x - row_max) / row_sum``. Arithmetic is done in float32.
        """
        # 64-bit row index so that row_idx * stride cannot overflow int32.
        row_idx = tl.program_id(axis=0).to(tl.int64)
        x_row_ptr = x_ptr + row_idx * stride_x_row
        out_row_ptr = out_ptr + row_idx * stride_out_row

        lane_max = tl.full([block_size], float("-inf"), dtype=tl.float32)
        lane_sum = tl.zeros([block_size], dtype=tl.float32)

        # Pass 1: per-lane running max / running sum over the row's blocks.
        for start in range(0, num_cols, block_size):
            offsets = start + tl.arange(0, block_size)
            mask = offsets < num_cols
            # Masked-out lanes read -inf so they never win the max and
            # contribute exp(-inf) == 0 to the sum.
            x = tl.load(
                x_row_ptr + offsets, mask=mask, other=float("-inf")
            ).to(tl.float32)

            new_max = tl.maximum(lane_max, x)
            # While a lane has only seen -inf, subtracting its max would give
            # (-inf) - (-inf) = NaN; shift by 0 instead so both terms become 0.
            safe_max = tl.where(new_max == float("-inf"), 0.0, new_max)
            lane_sum = lane_sum * tl.exp(lane_max - safe_max) + tl.exp(x - safe_max)
            lane_max = new_max

        # Merge the per-lane statistics into a single (max, sum) for the row.
        row_max = tl.max(lane_max, axis=0)
        safe_row_max = tl.where(row_max == float("-inf"), 0.0, row_max)
        row_sum = tl.sum(lane_sum * tl.exp(lane_max - safe_row_max), axis=0)

        # Pass 2: write the normalized probabilities.
        for start in range(0, num_cols, block_size):
            offsets = start + tl.arange(0, block_size)
            mask = offsets < num_cols
            x = tl.load(
                x_row_ptr + offsets, mask=mask, other=float("-inf")
            ).to(tl.float32)
            probs = tl.exp(x - safe_row_max) / row_sum
            # tl.store casts the float32 values to the output element type.
            tl.store(out_row_ptr + offsets, probs, mask=mask)


def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Return the softmax of ``x`` along its last dimension using Triton.

    Args:
        x: 2D floating-point CUDA tensor; softmax is taken over each row.
        block_size: Number of columns processed per loop iteration inside the
            kernel. Must be a positive power of two (required by
            ``tl.arange``).

    Returns:
        A new contiguous tensor with the same shape, dtype and device as
        ``x``. Computation is carried out in float32 inside the kernel, so
        float64 inputs are only accurate to float32 precision.

    Raises:
        RuntimeError: If Triton is not installed.
        ValueError: If ``x`` is not 2D, is not on a CUDA device, is not a
            floating-point tensor, or ``block_size`` is not a positive power
            of two.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    if not x.is_floating_point():
        raise ValueError(f"expected a floating-point input, got {x.dtype}")
    if block_size < 1 or block_size & (block_size - 1):
        raise ValueError(
            f"block_size must be a positive power of two, got {block_size}"
        )

    num_rows, num_cols = x.shape
    # torch.empty (not empty_like) guarantees a contiguous output, so the
    # kernel's unit column stride assumption holds for `out`.
    out = torch.empty((num_rows, num_cols), dtype=x.dtype, device=x.device)
    if num_rows == 0 or num_cols == 0:
        # Nothing to compute, and an empty launch grid is not valid.
        return out

    if x.stride(1) != 1:
        # The kernel only receives row strides; make columns unit-stride.
        x = x.contiguous()

    # Launch on the tensor's own device, which may not be the current one.
    with torch.cuda.device(x.device):
        online_softmax_kernel[(num_rows,)](
            x,
            out,
            num_cols,
            x.stride(0),
            out.stride(0),
            block_size=block_size,
        )
    return out