- vector_add: basic masked load/store with block indexing
- row_softmax: single-pass, numerically stable softmax per row
- tiled_matmul: K-dimension tile loop with edge masking (IEEE precision)
- online_softmax: two-pass running max/sum recurrence across blocks
- flash_attention_fwd: blockwise Q/K/V with online softmax and causal-mask support

All 26 tests pass on an RTX 5090 (CUDA 12.8, Triton 3.6).
87 lines · 2.7 KiB · Python
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
try:
|
|
import triton
|
|
import triton.language as tl
|
|
except ImportError: # pragma: no cover - depends on local environment
|
|
triton = None
|
|
tl = None
|
|
|
|
|
|
# True when the optional `import triton` above succeeded. The kernel
# definition is guarded by this flag, and triton_tiled_matmul checks it before
# touching Triton so the module still imports on machines without Triton.
TRITON_AVAILABLE: bool = triton is not None
|
|
|
|
|
|
if TRITON_AVAILABLE:

    @triton.jit
    def tiled_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        k,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        block_m: tl.constexpr,
        block_n: tl.constexpr,
        block_k: tl.constexpr,
    ):
        """Compute one ``block_m x block_n`` tile of ``C = A @ B``.

        Program ``(axis 0, axis 1)`` owns a tile of output rows and columns.
        It walks the shared K dimension ``block_k`` elements at a time. Tiles
        of A and B are loaded with bounds masks, and out-of-range elements
        read as 0.0 so they add nothing to the dot product. The partial
        products accumulate in float32, and the tile is written to C under a
        row/column bounds mask.
        """
        rows = tl.program_id(axis=0) * block_m + tl.arange(0, block_m)
        cols = tl.program_id(axis=1) * block_n + tl.arange(0, block_n)
        inner = tl.arange(0, block_k)

        # Row/column bounds do not depend on the K step, so build them once.
        row_ok = rows[:, None] < m
        col_ok = cols[None, :] < n

        tile_a = a_ptr + rows[:, None] * stride_am + inner[None, :] * stride_ak
        tile_b = b_ptr + inner[:, None] * stride_bk + cols[None, :] * stride_bn

        total = tl.zeros((block_m, block_n), dtype=tl.float32)
        num_steps = tl.cdiv(k, block_k)
        for step in range(0, num_steps):
            # Columns of A / rows of B still inside K for this step; the last
            # step may be partial when k is not a multiple of block_k.
            k_left = k - step * block_k
            a_vals = tl.load(tile_a, mask=row_ok & (inner[None, :] < k_left), other=0.0)
            b_vals = tl.load(tile_b, mask=(inner[:, None] < k_left) & col_ok, other=0.0)
            total += tl.dot(a_vals, b_vals, input_precision="ieee")
            # Slide both tiles one K-step forward.
            tile_a += block_k * stride_ak
            tile_b += block_k * stride_bk

        out_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
        tl.store(out_ptrs, total, mask=row_ok & col_ok)
|
|
|
|
|
|
def triton_tiled_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    block_m: int = 64,
    block_n: int = 64,
    block_k: int = 32,
) -> torch.Tensor:
    """Compute ``a @ b`` with ``tiled_matmul_kernel``.

    Args:
        a: 2D CUDA tensor of shape ``(m, k)``.
        b: 2D CUDA tensor of shape ``(k, n)`` with the same dtype and device
            as ``a``. Strides are passed to the kernel, so inputs need not be
            contiguous.
        block_m: Output rows computed per program (compile-time constant).
        block_n: Output columns computed per program (compile-time constant).
        block_k: K-dimension step size (compile-time constant).
            All three block sizes must be positive powers of two because the
            kernel builds its index vectors with ``tl.arange``.

    Returns:
        A newly allocated ``(m, n)`` tensor with ``a``'s dtype on ``a``'s
        device. The kernel accumulates in float32 before storing into it.

    Raises:
        RuntimeError: If Triton could not be imported.
        ValueError: If the inputs are not 2D, their inner dimensions differ,
            they are not CUDA tensors, they are on different devices or have
            different dtypes, or a block size is not a positive power of two.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("expected two 2D tensors")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if not a.is_cuda or not b.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    # Both operands are read through raw pointers in one launch, and tl.dot
    # requires matching operand dtypes. Report these up front instead of as
    # an opaque launch or compile failure.
    if a.device != b.device:
        raise ValueError(f"inputs are on different devices: {a.device} and {b.device}")
    if a.dtype != b.dtype:
        raise ValueError(f"inputs have different dtypes: {a.dtype} and {b.dtype}")
    for name, size in (("block_m", block_m), ("block_n", block_n), ("block_k", block_k)):
        # tl.arange only accepts power-of-two extents.
        if size <= 0 or size & (size - 1):
            raise ValueError(f"{name} must be a positive power of two, got {size}")

    m, k = a.shape
    _, n = b.shape
    c = torch.empty((m, n), device=a.device, dtype=a.dtype)
    if c.numel() == 0:
        # Nothing to compute; skip launching a kernel with an empty grid.
        return c

    # One program per (block_m x block_n) output tile.
    grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))
    # Triton launches on the *current* CUDA device. Make that the inputs'
    # device so tensors on a non-default GPU are not launched on the wrong one.
    with torch.cuda.device(a.device):
        tiled_matmul_kernel[grid](
            a, b, c,
            m, n, k,
            a.stride(0), a.stride(1),
            b.stride(0), b.stride(1),
            c.stride(0), c.stride(1),
            block_m=block_m, block_n=block_n, block_k=block_k,
        )
    return c
|
|
|