- vector_add: basic masked load/store with block indexing
- row_softmax: single-pass numerically stable softmax per row
- tiled_matmul: K-dimension tile loop with edge masking (IEEE precision)
- online_softmax: two-pass running max/sum recurrence across blocks
- flash_attention_fwd: blockwise Q/K/V with online softmax, causal support

All 26 tests pass on RTX 5090 (CUDA 12.8, Triton 3.6).
64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
try:
|
|
import triton
|
|
import triton.language as tl
|
|
except ImportError: # pragma: no cover - depends on local environment
|
|
triton = None
|
|
tl = None
|
|
|
|
|
|
TRITON_AVAILABLE = triton is not None
|
|
|
|
|
|
if TRITON_AVAILABLE:

    @triton.jit
    def online_softmax_kernel(
        x_ptr,
        out_ptr,
        num_cols,
        stride_x_row,
        stride_out_row,
        block_size: tl.constexpr,
    ):
        """Softmax over one row, streamed in chunks of ``block_size`` columns.

        Each program instance handles the row selected by ``program_id(0)``.
        The row is read twice: the first pass maintains a running maximum and
        a running sum of exponentials that is rescaled whenever the maximum
        grows; the second pass writes ``exp(x - max) / sum``.

        Columns are addressed as ``row_base + column_index``, so both buffers
        must have unit stride along the column dimension.

        Args:
            x_ptr: pointer to the first element of the input.
            out_ptr: pointer to the first element of the output.
            num_cols: number of valid columns in each row.
            stride_x_row: element offset between consecutive input rows.
            stride_out_row: element offset between consecutive output rows.
            block_size: compile-time number of columns processed per loop
                iteration (used as the extent of ``tl.arange``).
        """
        # Widen the row index before multiplying by the stride: with 32-bit
        # arithmetic the product wraps once the tensor reaches 2**31 elements.
        row_idx = tl.program_id(axis=0).to(tl.int64)
        # The row base pointers do not change across blocks; compute them once.
        x_row_ptr = x_ptr + row_idx * stride_x_row
        out_row_ptr = out_ptr + row_idx * stride_out_row

        # First pass: compute running max and sum
        running_max = float('-inf')
        running_sum = 0.0
        for block_start in range(0, num_cols, block_size):
            col_offsets = block_start + tl.arange(0, block_size)
            mask = col_offsets < num_cols
            # Out-of-range lanes load -inf so they cannot raise the max and
            # contribute exp(-inf) == 0 to the sum.
            x_block = tl.load(x_row_ptr + col_offsets, mask=mask, other=float('-inf'))
            block_max = tl.max(x_block, axis=0)
            new_max = tl.maximum(running_max, block_max)
            # Re-express the sum accumulated so far relative to the new max,
            # then add this block's contribution.
            running_sum = running_sum * tl.exp(running_max - new_max) + tl.sum(tl.exp(x_block - new_max), axis=0)
            running_max = new_max

        # Second pass: write normalized output
        for block_start in range(0, num_cols, block_size):
            col_offsets = block_start + tl.arange(0, block_size)
            mask = col_offsets < num_cols
            x_block = tl.load(x_row_ptr + col_offsets, mask=mask, other=float('-inf'))
            result = tl.exp(x_block - running_max) / running_sum
            tl.store(out_row_ptr + col_offsets, result, mask=mask)
|
|
|
|
|
|
def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Return the row-wise softmax of a 2D CUDA tensor via the Triton kernel.

    One kernel program is launched per row; each program walks its row in
    chunks of ``block_size`` columns.

    Args:
        x: 2D CUDA tensor. Inputs whose column stride is not 1 are copied to
            a contiguous buffer first, because the kernel addresses columns
            with unit stride.
        block_size: number of columns processed per kernel loop iteration.
            Must be a positive power of two (it is used as a ``tl.arange``
            extent).

    Returns:
        A new contiguous tensor with the same shape, dtype and device as ``x``.

    Raises:
        RuntimeError: if Triton could not be imported.
        ValueError: if ``x`` is not 2D, is not on a CUDA device, or
            ``block_size`` is not a positive power of two.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.ndim != 2:
        raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
    if not x.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    # Fail early with a clear message instead of a Triton compile error.
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(f"block_size must be a positive power of two, got {block_size}")

    num_rows, num_cols = x.shape

    # The kernel computes addresses as row_base + column_index, so a transposed
    # or column-sliced view would be read incorrectly without this copy.
    # A single-column tensor never touches a second column, so its column
    # stride is irrelevant.
    if num_cols > 1 and x.stride(1) != 1:
        x = x.contiguous()

    # Force a row-major output so its column stride is 1 as well;
    # the default memory format may mirror the input's strides.
    out = torch.empty_like(x, memory_format=torch.contiguous_format)

    # Nothing to compute for an empty tensor; skip the launch.
    if num_rows == 0 or num_cols == 0:
        return out

    grid = (num_rows,)
    online_softmax_kernel[grid](x, out, num_cols, x.stride(0), out.stride(0), block_size=block_size)
    return out
|
|
|