- vector_add: basic masked load/store with block indexing
- row_softmax: single-pass numerically stable softmax per row
- tiled_matmul: K-dimension tile loop with edge masking (IEEE precision)
- online_softmax: two-pass running max/sum recurrence across blocks
- flash_attention_fwd: blockwise Q/K/V with online softmax, causal support

All 26 tests pass on RTX 5090 (CUDA 12.8, Triton 3.6).
49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
# Triton is an optional dependency. When it cannot be imported, both module
# handles fall back to None so this file still imports cleanly.
try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    triton = tl = None

# True only when both Triton imports above succeeded.
TRITON_AVAILABLE = triton is not None
|
|
|
|
|
|
if TRITON_AVAILABLE:

    @triton.jit
    def vector_add_kernel(
        x_ptr,
        y_ptr,
        out_ptr,
        num_elements,
        block_size: tl.constexpr,
    ):
        """Write ``x + y`` to ``out`` for one block of ``block_size`` elements.

        Each program instance along axis 0 handles the flat element range
        ``[pid * block_size, (pid + 1) * block_size)``. Lanes whose offset is
        not below ``num_elements`` are masked out of every load and store.

        The three pointers are indexed with the same flat offsets, so the
        buffers are assumed to be laid out element-for-element alike
        (the caller is expected to pass contiguous tensors).
        """
        pid = tl.program_id(axis=0)
        # Widen to 64-bit before multiplying: a 32-bit `pid * block_size`
        # wraps around once the tensor holds 2**31 or more elements.
        block_start = pid.to(tl.int64) * block_size
        offsets = block_start + tl.arange(0, block_size)
        # Guard the final, possibly partial, block.
        mask = offsets < num_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)
|
|
|
|
|
|
def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Student entrypoint for the Triton vector add task.

    Launches ``vector_add_kernel`` over a 1-D grid with one program per
    ``block_size`` elements and returns the element-wise sum of ``x`` and ``y``.

    Args:
        x: First operand; must be a CUDA tensor.
        y: Second operand; must have the same shape and device as ``x``.
        block_size: Number of elements handled by each program. It is used as
            the extent of ``tl.arange`` in the kernel, so it must be a
            positive power of two.

    Returns:
        A new contiguous tensor allocated like ``x`` (same shape, dtype and
        device) holding ``x + y``.

    Raises:
        RuntimeError: If Triton is not installed.
        ValueError: If the shapes differ, either tensor is not on CUDA, the
            tensors live on different devices, or ``block_size`` is not a
            positive power of two.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    if not x.is_cuda or not y.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    if x.device != y.device:
        raise ValueError(f"device mismatch: {x.device} vs {y.device}")
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(f"block_size must be a positive power of two, got {block_size}")

    # The kernel addresses all three buffers with the same flat offsets, so
    # strided or differently laid-out inputs must be densified first;
    # otherwise elements would be paired up incorrectly.
    x = x.contiguous()
    y = y.contiguous()
    out = torch.empty_like(x, memory_format=torch.contiguous_format)

    num_elements = x.numel()
    if num_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return out

    # Ceiling division: one program per (possibly partial) block.
    grid = ((num_elements + block_size - 1) // block_size,)
    # Launch on the tensors' own device, which may differ from the current one.
    with torch.cuda.device(x.device):
        vector_add_kernel[grid](x, y, out, num_elements, block_size=block_size)
    return out
|
|
|