Initial project scaffold
This commit is contained in:
42
kernels/triton/online_softmax.py
Normal file
42
kernels/triton/online_softmax.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
except ImportError: # pragma: no cover - depends on local environment
|
||||
triton = None
|
||||
tl = None
|
||||
|
||||
|
||||
# True only when the guarded `import triton` above succeeded; the ImportError
# fallback binds `triton` to None, which makes this flag False.
TRITON_AVAILABLE = triton is not None
|
||||
|
||||
|
||||
if TRITON_AVAILABLE:
    # The kernel is only defined when Triton imported successfully, because
    # both the decorator and the `tl.constexpr` annotation need the real module.

    @triton.jit
    def online_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Scaffold for an online-softmax Triton kernel (body not implemented).

        At present the kernel only reads its program index along axis 0; it
        performs no loads or stores, so launching it leaves the output
        untouched.

        Parameter roles below are inferred from the names and are not yet
        exercised by any code -- confirm them when implementing:

        * ``x_ptr`` / ``out_ptr``: presumably the input and output buffers.
        * ``num_cols``: presumably the number of columns in each row.
        * ``stride_x_row`` / ``stride_out_row``: presumably the element
          offset between consecutive rows of the input / output.
        * ``block_size``: compile-time constant (``tl.constexpr``);
          presumably the number of columns handled per block.
        """
        # Program index along launch axis 0; the name suggests one program
        # per row, but nothing in this scaffold relies on that yet.
        row = tl.program_id(0)

        # TODO(student): maintain running max and running sum for this row.
        # TODO(student): process the row in blocks rather than assuming all columns fit at once.
        # TODO(student): write the final normalized probabilities.
|
||||
|
||||
|
||||
def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Validate the input for the Triton online-softmax path (not implemented).

    This is a scaffold: once every precondition passes it raises
    ``NotImplementedError`` instead of launching a kernel, so it currently
    never returns a value.

    Args:
        x: Input tensor; must be 2D and live on a CUDA device.
        block_size: Currently unused. Presumably intended as the kernel's
            ``block_size`` argument once the launch is implemented.

    Returns:
        Nothing yet; the annotation describes the intended result type.

    Raises:
        RuntimeError: Triton could not be imported in this environment.
        ValueError: ``x`` is not 2D, or ``x`` is not a CUDA tensor.
        NotImplementedError: All checks passed; the computation is still a TODO.
    """
    # Availability is checked first, so a missing Triton install is reported
    # even when the input itself would also be rejected.
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")

    is_matrix = x.ndim == 2
    if not is_matrix:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")

    on_gpu = x.is_cuda
    if not on_gpu:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")

    raise NotImplementedError("TODO(student): implement online softmax in Triton.")
|
||||
|
||||
Reference in New Issue
Block a user