Files
kernel-lab/kernels/triton/row_softmax.py
Gahow Wang 165a1b0bd5 Implement all 5 Triton kernel labs
- vector_add: basic masked load/store with block indexing
- row_softmax: single-pass numerically stable softmax per row
- tiled_matmul: K-dimension tile loop with edge masking (IEEE precision)
- online_softmax: two-pass running max/sum recurrence across blocks
- flash_attention_fwd: blockwise Q/K/V with online softmax, causal support

All 26 tests pass on RTX 5090 (CUDA 12.8, Triton 3.6).
2026-05-15 20:46:04 +08:00

55 lines
1.7 KiB
Python

from __future__ import annotations
import torch
try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    # Triton is optional: keep this module importable without it and expose
    # a flag so callers can check availability before launching anything.
    triton = None
    tl = None
    TRITON_AVAILABLE = False
else:
    TRITON_AVAILABLE = True
if TRITON_AVAILABLE:
    # Row-wise softmax: one program instance per row (program id on axis 0).
    #
    # Each program loads its entire row as a single `block_size`-wide tile,
    # so the launcher must pass block_size >= num_cols (and a power of two,
    # as tl.arange requires). Columns are addressed as `row_base +
    # col_offsets`, i.e. with unit stride, so every row must be contiguous in
    # memory; only the row strides are parameters.
    #
    # Args:
    #   x_ptr / out_ptr: input and output matrices. The result is cast to
    #       out_ptr's element type by tl.store.
    #   num_cols: number of valid columns per row; lanes at or beyond it are
    #       masked out of both the load and the store.
    #   stride_x_row / stride_out_row: element distance between consecutive
    #       rows of the input / output.
    #   block_size: compile-time tile width.
    @triton.jit
    def row_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        row_idx = tl.program_id(axis=0)
        # program_id is int32: widen before scaling by the row stride so the
        # offset does not overflow for matrices larger than 2**31 elements.
        row_start = row_idx.to(tl.int64)
        col_offsets = tl.arange(0, block_size)
        mask = col_offsets < num_cols
        x_ptrs = x_ptr + row_start * stride_x_row + col_offsets
        out_ptrs = out_ptr + row_start * stride_out_row + col_offsets
        # Padding lanes load -inf so they never win the max and contribute
        # exp(-inf) == 0 to the sum.
        row = tl.load(x_ptrs, mask=mask, other=float("-inf"))
        # Widen sub-fp32 rows (fp16/bf16) to fp32 through ordinary type
        # promotion against an fp32 zero tile; fp32 and fp64 rows keep their
        # dtype and values. Half-precision exp/sum would otherwise lose
        # accuracy (NOTE(review): recent Triton's tl.exp may also reject
        # fp16 outright).
        row = row + tl.zeros([block_size], dtype=tl.float32)
        # Subtract the row max before exponentiating for numerical stability.
        row_max = tl.max(row, axis=0)
        numerator = tl.exp(row - row_max)
        denominator = tl.sum(numerator, axis=0)
        result = numerator / denominator
        tl.store(out_ptrs, result, mask=mask)
def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Apply softmax along the last dimension of a 2D CUDA tensor with Triton.

    One Triton program is launched per row, and each program reads its whole
    row in a single tile, so the tile width is widened to cover every column.

    Args:
        x: 2D floating-point CUDA tensor of shape ``(num_rows, num_cols)``.
            Non-contiguous inputs (e.g. transposed or column-sliced views) are
            copied to a contiguous buffer first, because the kernel addresses
            columns with unit stride.
        block_size: minimum tile width. It is raised to at least ``num_cols``
            and rounded up to a power of two (required by ``tl.arange``).
            NOTE(review): very wide rows therefore produce very large tiles;
            confirm the kernel still compiles/performs for the widths you need.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.

    Raises:
        RuntimeError: if Triton is not installed.
        ValueError: if ``x`` is not 2D, not a CUDA tensor, or not floating
            point.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    if not x.is_floating_point():
        raise ValueError(f"expected a floating-point tensor, got {x.dtype}")
    # The kernel assumes unit column stride. .contiguous() is a no-op for
    # already-contiguous inputs, and it makes empty_like below allocate a
    # contiguous output as well.
    x = x.contiguous()
    out = torch.empty_like(x)
    num_rows, num_cols = x.shape
    if out.numel() == 0:
        # Zero rows or zero columns: nothing to compute, skip the launch.
        return out
    # Single-pass kernel: the tile must span the full row, and tl.arange needs
    # a power-of-two length, so round up *after* taking the max (a
    # non-power-of-two block_size argument would otherwise leak through).
    block_size = triton.next_power_of_2(max(block_size, num_cols))
    grid = (num_rows,)
    row_softmax_kernel[grid](
        x, out, num_cols, x.stride(0), out.stride(0), block_size=block_size
    )
    return out