Files
kernel-lab/kernels/triton/vector_add.py
Gahow Wang 165a1b0bd5 Implement all 5 Triton kernel labs
- vector_add: basic masked load/store with block indexing
- row_softmax: single-pass numerically stable softmax per row
- tiled_matmul: K-dimension tile loop with edge masking (IEEE precision)
- online_softmax: two-pass running max/sum recurrence across blocks
- flash_attention_fwd: blockwise Q/K/V with online softmax, causal support

All 26 tests pass on RTX 5090 (CUDA 12.8, Triton 3.6).
2026-05-15 20:46:04 +08:00

49 lines
1.4 KiB
Python

from __future__ import annotations
import torch
# Triton is treated as an optional dependency: a failed import must not make
# this module unimportable, so the failure is caught and recorded instead.
try:
    import triton
    import triton.language as tl
except ImportError:  # pragma: no cover - depends on local environment
    # Bind placeholders so both names exist even when the import fails.
    triton = None
    tl = None

# True only when the Triton import above succeeded.
TRITON_AVAILABLE = triton is not None
if TRITON_AVAILABLE:

    @triton.jit
    def vector_add_kernel(
        x_ptr,
        y_ptr,
        out_ptr,
        num_elements,
        block_size: tl.constexpr,
    ):
        """Store ``x + y`` for the slice of elements owned by this program.

        Each program along axis 0 handles ``block_size`` consecutive linear
        offsets starting at ``program_id * block_size``. Offsets at or beyond
        ``num_elements`` are masked out of both the loads and the store.
        """
        # First linear offset handled by this program instance.
        block_start = tl.program_id(axis=0) * block_size
        offsets = block_start + tl.arange(0, block_size)

        # The final block may run past the end of the data; mask those lanes.
        in_bounds = offsets < num_elements

        lhs = tl.load(x_ptr + offsets, mask=in_bounds)
        rhs = tl.load(y_ptr + offsets, mask=in_bounds)
        tl.store(out_ptr + offsets, lhs + rhs, mask=in_bounds)
def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Student entrypoint for the Triton vector add task.

    Adds ``x`` and ``y`` elementwise by launching ``vector_add_kernel`` on a
    1D grid with one program per ``block_size`` elements (rounded up).

    Args:
        x: First operand; must be a CUDA tensor.
        y: Second operand; must be a CUDA tensor with the same shape and
            device as ``x``.
        block_size: Number of elements each program handles. Must be a
            positive power of two, because the kernel passes it to
            ``tl.arange``.

    Returns:
        A newly allocated contiguous tensor with the shape and dtype of ``x``
        containing ``x + y``.

    Raises:
        RuntimeError: If Triton is not installed.
        ValueError: If the shapes differ, either tensor is not on CUDA, the
            tensors live on different devices, or ``block_size`` is not a
            positive power of two.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if x.shape != y.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
    if not x.is_cuda or not y.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    if x.device != y.device:
        raise ValueError(f"device mismatch: {x.device} vs {y.device}")
    # Reject bad block sizes up front instead of failing with a
    # ZeroDivisionError below or a compile error inside the kernel.
    if block_size <= 0 or block_size & (block_size - 1):
        raise ValueError(
            f"block_size must be a positive power of two, got {block_size}"
        )

    # The kernel addresses memory with flat offsets, so every buffer must be
    # laid out contiguously; otherwise x and y elements would be paired wrongly.
    # .contiguous() is a no-op for tensors that are already contiguous.
    x = x.contiguous()
    y = y.contiguous()
    out = torch.empty_like(x)

    num_elements = x.numel()
    if num_elements == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return out

    grid = ((num_elements + block_size - 1) // block_size,)
    # Launch on the tensors' own device rather than whichever CUDA device
    # happens to be current.
    with torch.cuda.device(x.device):
        vector_add_kernel[grid](x, y, out, num_elements, block_size=block_size)
    return out