"""Lab skeleton: online (streaming) softmax implemented as a Triton kernel."""
from __future__ import annotations

import torch

try:
|
|
import triton
|
|
import triton.language as tl
|
|
except ImportError: # pragma: no cover - depends on local environment
|
|
triton = None
|
|
tl = None
|
|
|
|
|
|
TRITON_AVAILABLE = triton is not None
|
|
|
|
|
|
if TRITON_AVAILABLE:
|
|
|
|
@triton.jit
|
|
def online_softmax_kernel(
|
|
x_ptr,
|
|
out_ptr,
|
|
num_cols,
|
|
stride_x_row,
|
|
stride_out_row,
|
|
block_size: tl.constexpr,
|
|
):
|
|
row_idx = tl.program_id(axis=0)
|
|
# TODO(student): maintain running max and running sum for this row.
|
|
# TODO(student): process the row in blocks rather than assuming all columns fit at once.
|
|
# TODO(student): write the final normalized probabilities.
|
|
pass
|
|
|
|
|
|
def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
|
if not TRITON_AVAILABLE:
|
|
raise RuntimeError("Triton is not installed in this environment.")
|
|
if x.ndim != 2:
|
|
raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
|
|
if not x.is_cuda:
|
|
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
|
raise NotImplementedError("TODO(student): implement online softmax in Triton.")
|
|
|