- vector_add: basic masked load/store with block indexing
- row_softmax: single-pass numerically stable softmax per row
- tiled_matmul: K-dimension tile loop with edge masking (IEEE precision)
- online_softmax: two-pass running max/sum recurrence across blocks
- flash_attention_fwd: blockwise Q/K/V with online softmax, causal support

All 26 tests pass on RTX 5090 (CUDA 12.8, Triton 3.6).
156 lines
5.2 KiB
Python
156 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
try:
|
|
import triton
|
|
import triton.language as tl
|
|
except ImportError: # pragma: no cover - depends on local environment
|
|
triton = None
|
|
tl = None
|
|
|
|
|
|
TRITON_AVAILABLE = triton is not None
|
|
|
|
|
|
if TRITON_AVAILABLE:

    # Flash-attention forward kernel: one program handles `block_q` query rows
    # of one (batch, head) pair and streams over K/V in `block_k` tiles using
    # the online-softmax recurrence (running row max `m_i`, running
    # denominator `l_i`, rescaled accumulator `acc`).
    #
    # Launch contract (see triton_flash_attention_fwd):
    #   grid      = (cdiv(seq_len, block_q), batch * num_heads)
    #   num_heads = number of heads; passed explicitly because deriving it
    #               from stride ratios is only valid for contiguous inputs.
    #   sm_scale  = softmax scale applied to q @ k^T, computed on the host.
    #               Integer args equal to 1 are specialized to constexprs by
    #               Triton, so computing it from `head_dim` in-kernel breaks
    #               for head_dim == 1.
    #   block_d   = power of two >= max(head_dim, 16); 16 is the minimum
    #               reduction size tl.dot accepts. Padded lanes are loaded as
    #               zeros and never stored.
    #   causal    = constexpr, so the non-causal variant has no mask code.
    @triton.jit
    def flash_attention_fwd_kernel(
        q_ptr,
        k_ptr,
        v_ptr,
        out_ptr,
        seq_len,
        head_dim,
        num_heads,
        sm_scale,
        stride_q_batch,
        stride_q_head,
        stride_q_seq,
        stride_q_dim,
        stride_k_batch,
        stride_k_head,
        stride_k_seq,
        stride_k_dim,
        stride_v_batch,
        stride_v_head,
        stride_v_seq,
        stride_v_dim,
        stride_out_batch,
        stride_out_head,
        stride_out_seq,
        stride_out_dim,
        causal: tl.constexpr,
        block_q: tl.constexpr,
        block_k: tl.constexpr,
        block_d: tl.constexpr,
    ):
        pid_q = tl.program_id(axis=0)
        pid_bh = tl.program_id(axis=1)

        # int64 so batch/head base offsets cannot overflow on large tensors.
        batch_idx = (pid_bh // num_heads).to(tl.int64)
        head_idx = (pid_bh % num_heads).to(tl.int64)

        q_base = q_ptr + batch_idx * stride_q_batch + head_idx * stride_q_head
        k_base = k_ptr + batch_idx * stride_k_batch + head_idx * stride_k_head
        v_base = v_ptr + batch_idx * stride_v_batch + head_idx * stride_v_head
        out_base = out_ptr + batch_idx * stride_out_batch + head_idx * stride_out_head

        offs_q = pid_q * block_q + tl.arange(0, block_q)
        offs_d = tl.arange(0, block_d)
        d_valid = offs_d[None, :] < head_dim

        # Q tile [block_q, block_d]; out-of-range rows / padded dims read as 0.
        q_mask = (offs_q[:, None] < seq_len) & d_valid
        q_ptrs = q_base + offs_q[:, None] * stride_q_seq + offs_d[None, :] * stride_q_dim
        q_block = tl.load(q_ptrs, mask=q_mask, other=0.0)

        # Online-softmax state, kept in fp32 regardless of the input dtype.
        m_i = tl.full((block_q,), float("-inf"), dtype=tl.float32)
        l_i = tl.zeros((block_q,), dtype=tl.float32)
        acc = tl.zeros((block_q, block_d), dtype=tl.float32)

        # Causal: keys past the last query row of this tile are never visible.
        if causal:
            k_end = tl.minimum((pid_q + 1) * block_q, seq_len)
        else:
            k_end = seq_len

        for k_start in range(0, k_end, block_k):
            offs_k = k_start + tl.arange(0, block_k)
            k_valid = offs_k < seq_len
            kv_mask = k_valid[:, None] & d_valid

            # K tile [block_k, block_d].
            k_ptrs = k_base + offs_k[:, None] * stride_k_seq + offs_d[None, :] * stride_k_dim
            k_block = tl.load(k_ptrs, mask=kv_mask, other=0.0)

            # Scores [block_q, block_k].
            scores = tl.dot(q_block, tl.trans(k_block), input_precision="ieee") * sm_scale
            if causal:
                scores = tl.where(offs_q[:, None] >= offs_k[None, :], scores, float("-inf"))
            scores = tl.where(k_valid[None, :], scores, float("-inf"))

            # Rescale previous state to the new running max, then accumulate.
            # Key 0 is visible to every row in the first tile, so m_new is
            # finite from the first iteration on and no NaN is produced.
            m_ij = tl.max(scores, axis=1)
            m_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_new)
            p = tl.exp(scores - m_new[:, None])
            l_i = l_i * alpha + tl.sum(p, axis=1)
            acc = acc * alpha[:, None]

            # V tile [block_k, block_d].
            v_ptrs = v_base + offs_k[:, None] * stride_v_seq + offs_d[None, :] * stride_v_dim
            v_block = tl.load(v_ptrs, mask=kv_mask, other=0.0)
            acc += tl.dot(p.to(v_block.dtype), v_block, input_precision="ieee")
            m_i = m_new

        acc = acc / l_i[:, None]

        out_ptrs = out_base + offs_q[:, None] * stride_out_seq + offs_d[None, :] * stride_out_dim
        tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=q_mask)


def triton_flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
    block_q: int = 64,
    block_k: int = 64,
    sm_scale: float | None = None,
) -> torch.Tensor:
    """Compute attention ``softmax(q @ k^T * sm_scale) @ v`` with a Triton kernel.

    Args:
        q, k, v: ``[batch, heads, seq, dim]`` CUDA tensors with identical
            shape, dtype and device. Arbitrary strides are supported.
        causal: If True, query ``i`` only attends to keys ``j <= i``.
        block_q: Query rows per program; must be a positive power of two.
        block_k: Key/value rows per inner-loop tile; must be a positive power
            of two (tl.dot additionally needs it to be >= 16).
        sm_scale: Softmax scale; defaults to ``1 / sqrt(dim)``.

    Returns:
        A tensor shaped and typed like ``q`` holding the attention output.

    Raises:
        RuntimeError: Triton is not installed.
        ValueError: On shape, rank, device, dtype or block-size problems.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("Triton is not installed in this environment.")
    if q.shape != k.shape or q.shape != v.shape:
        raise ValueError(f"q, k, v must match; got {q.shape}, {k.shape}, {v.shape}")
    if q.ndim != 4:
        raise ValueError("expected [batch, heads, seq, dim] inputs")
    if not q.is_cuda or not k.is_cuda or not v.is_cuda:
        raise ValueError("Triton kernels in this lab expect CUDA tensors.")
    if not (q.device == k.device == v.device):
        raise ValueError(f"q, k, v must be on one device; got {q.device}, {k.device}, {v.device}")
    if not (q.dtype == k.dtype == v.dtype):
        raise ValueError(f"q, k, v must share a dtype; got {q.dtype}, {k.dtype}, {v.dtype}")
    for name, block in (("block_q", block_q), ("block_k", block_k)):
        # tl.arange only accepts power-of-two extents.
        if block <= 0 or block & (block - 1):
            raise ValueError(f"{name} must be a positive power of two; got {block}")

    batch, heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    if out.numel() == 0:
        # Nothing to compute (and 1/sqrt(0) would be undefined).
        return out

    if sm_scale is None:
        sm_scale = head_dim ** -0.5
    # tl.dot needs a reduction dimension of at least 16; masking makes the
    # extra zero-filled lanes harmless for small head dims.
    block_d = max(16, triton.next_power_of_2(head_dim))
    grid = (triton.cdiv(seq_len, block_q), batch * heads)

    # Triton launches on the current CUDA device; make it match the inputs.
    with torch.cuda.device(q.device):
        flash_attention_fwd_kernel[grid](
            q, k, v, out,
            seq_len, head_dim, heads, sm_scale,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            causal,
            block_q=block_q, block_k=block_k, block_d=block_d,
        )
    return out
|
|
|