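"""Lab skeleton: a row-wise softmax kernel in Triton.

Students fill in the TODO(student) markers below. Reference sketches
(clearly marked as such, and not the graded solution) follow each skeleton.
"""
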
from __future__ import annotations

import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = None
    tl = None


TRITON_AVAILABLE = triton is not None


if TRITON_AVAILABLE:

    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # TODO(student): convert row_idx and col_offsets into pointers for this row.
        # TODO(student): load a row with masking.
        # TODO(student): subtract the row max for stability.
        # TODO(student): exponentiate, sum, and normalize.
        # TODO(student): store the normalized row.
        pass
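
    # Reference sketch, not the graded solution: one way the TODOs above can
    # be filled in. It assumes each row fits in a single block of `block_size`
    # lanes and that rows are contiguous (unit column stride). The name
    # `_row_softmax_kernel_sketch` is illustrative, not part of the lab API.
    @triton.jit
    def _row_softmax_kernel_sketch(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        row_idx = tl.program_id(axis=0)
        col_offsets = tl.arange(0, block_size)
        # One pointer per lane: the row's base address plus the lane's column offset.
        row_ptrs = x_ptr + row_idx * stride_x_row + col_offsets
        mask = col_offsets < num_cols
        # Masked lanes read -inf so they contribute exp(-inf) == 0 to the sum.
        row = tl.load(row_ptrs, mask=mask, other=float("-inf"))
        # Subtract the row max so exp() cannot overflow.
        row = row - tl.max(row, axis=0)
        numerator = tl.exp(row)
        denominator = tl.sum(numerator, axis=0)
        out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets
        tl.store(out_ptrs, numerator / denominator, mask=mask)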


def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    raise NotImplementedError("TODO(student): implement row-wise softmax launch logic.")
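

def _triton_row_softmax_sketch(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Reference sketch for the launch logic, not the graded solution: one
    # program instance per row, with the block rounded up to a power of two so
    # tl.arange accepts it. `_triton_row_softmax_sketch` and the kernel it
    # calls are illustrative names, not part of the lab API.
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    x = x.contiguous()  # the sketch kernel assumes unit column stride
    out = torch.empty_like(x)
    num_rows, num_cols = x.shape
    # tl.arange needs a power-of-two extent that covers the whole row.
    block = triton.next_power_of_2(max(block_size, num_cols))
    _row_softmax_kernel_sketch[(num_rows,)](
        x, out, num_cols, x.stride(0), out.stride(0), block_size=block
    )
    return out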
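

# Quick self-check sketch (assumes a CUDA device): once triton_row_softmax is
# implemented, its output should match torch.softmax row for row, e.g.
#
#     x = torch.randn(4, 1000, device="cuda")
#     torch.testing.assert_close(triton_row_softmax(x), torch.softmax(x, dim=-1))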