Files
kernel-lab/kernels/triton/row_softmax.py
2026-04-10 13:15:06 +00:00

46 lines
1.3 KiB
Python

from __future__ import annotations

import torch

# Triton is an optional dependency: this module must still be importable on
# machines without a CUDA/Triton toolchain, and report a clear error only at
# call time (see triton_row_softmax).
try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None

# True when the optional Triton import above succeeded.
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:
    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Numerically stable softmax over one row of a 2D tensor.

        One program instance (program_id on axis 0) handles exactly one row.
        ``block_size`` must be a power of two >= ``num_cols`` so a single
        ``tl.arange`` block covers the whole row.
        """
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # Lanes past the end of the row are masked out everywhere below.
        mask = col_offsets < num_cols

        # Pointers into this program's row of the input.
        row_ptrs = x_ptr + row_idx * stride_x_row + col_offsets
        # Pad masked lanes with -inf so they contribute exp(-inf) == 0 to the
        # max and to the sum — no special-casing needed later.
        row = tl.load(row_ptrs, mask=mask, other=-float("inf"))

        # Subtract the row max before exponentiating for numerical stability.
        row_max = tl.max(row, axis=0)
        numerator = tl.exp(row - row_max)
        denominator = tl.sum(numerator, axis=0)
        result = numerator / denominator

        # Store only the real (unmasked) columns of the output row.
        out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets
        tl.store(out_ptrs, result, mask=mask)
def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Compute softmax along dim=1 of a 2D CUDA tensor using a Triton kernel.

    Args:
        x: 2D CUDA tensor of shape (num_rows, num_cols).
        block_size: minimum block width; it is grown to the next power of two
            that covers ``num_cols`` (Triton's ``tl.arange`` needs a
            power-of-two extent, and the kernel loads each row in one block).

    Returns:
        A new tensor of the same shape as ``x`` with softmax applied row-wise.

    Raises:
        RuntimeError: if Triton is not installed.
        ValueError: if ``x`` is not 2D or not on a CUDA device.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")

    num_rows, num_cols = x.shape
    # Grow the block so one tl.arange covers the whole row; otherwise columns
    # beyond block_size would silently be dropped by the kernel's mask.
    block = max(block_size, triton.next_power_of_2(num_cols))

    # The kernel indexes rows via a single row stride, so make the layout
    # contiguous before reading strides.
    x = x.contiguous()
    out = torch.empty_like(x)

    grid = (num_rows,)  # one program per row
    row_softmax_kernel[grid](
        x,
        out,
        num_cols,
        x.stride(0),
        out.stride(0),
        block_size=block,
    )
    return out