Compare commits
1 Commits
main
...
triton-lab
| Author | SHA1 | Date | |
|---|---|---|---|
| 165a1b0bd5 |
@@ -46,13 +46,79 @@ if TRITON_AVAILABLE:
|
|||||||
):
|
):
|
||||||
pid_q = tl.program_id(axis=0)
|
pid_q = tl.program_id(axis=0)
|
||||||
pid_bh = tl.program_id(axis=1)
|
pid_bh = tl.program_id(axis=1)
|
||||||
# TODO(student): map pid_q and pid_bh to a batch/head/query tile.
|
num_heads = stride_q_batch // stride_q_head
|
||||||
# TODO(student): load Q, K, and V blocks.
|
batch_idx = pid_bh // num_heads
|
||||||
# TODO(student): compute scores for the current block pair.
|
head_idx = pid_bh % num_heads
|
||||||
# TODO(student): apply optional causal masking.
|
|
||||||
# TODO(student): update online softmax state and accumulate the output block.
|
q_offset = batch_idx * stride_q_batch + head_idx * stride_q_head
|
||||||
# TODO(student): store the final output tile.
|
k_offset = batch_idx * stride_k_batch + head_idx * stride_k_head
|
||||||
pass
|
v_offset = batch_idx * stride_v_batch + head_idx * stride_v_head
|
||||||
|
out_offset = batch_idx * stride_out_batch + head_idx * stride_out_head
|
||||||
|
|
||||||
|
offs_q = pid_q * block_q + tl.arange(0, block_q)
|
||||||
|
offs_d = tl.arange(0, block_d)
|
||||||
|
|
||||||
|
# Load Q block [block_q, block_d]
|
||||||
|
q_ptrs = q_ptr + q_offset + offs_q[:, None] * stride_q_seq + offs_d[None, :] * stride_q_dim
|
||||||
|
q_mask = (offs_q[:, None] < seq_len) & (offs_d[None, :] < head_dim)
|
||||||
|
q_block = tl.load(q_ptrs, mask=q_mask, other=0.0)
|
||||||
|
|
||||||
|
scale = 1.0 / tl.sqrt(head_dim.to(tl.float32))
|
||||||
|
|
||||||
|
# Online softmax accumulators
|
||||||
|
m_i = tl.full((block_q,), float('-inf'), dtype=tl.float32)
|
||||||
|
l_i = tl.zeros((block_q,), dtype=tl.float32)
|
||||||
|
acc = tl.zeros((block_q, block_d), dtype=tl.float32)
|
||||||
|
|
||||||
|
# Determine K range
|
||||||
|
if causal:
|
||||||
|
k_end = tl.minimum((pid_q + 1) * block_q, seq_len)
|
||||||
|
else:
|
||||||
|
k_end = seq_len
|
||||||
|
|
||||||
|
for k_start in range(0, k_end, block_k):
|
||||||
|
offs_k = k_start + tl.arange(0, block_k)
|
||||||
|
|
||||||
|
# Load K block [block_k, block_d]
|
||||||
|
k_ptrs = k_ptr + k_offset + offs_k[:, None] * stride_k_seq + offs_d[None, :] * stride_k_dim
|
||||||
|
k_mask = (offs_k[:, None] < seq_len) & (offs_d[None, :] < head_dim)
|
||||||
|
k_block = tl.load(k_ptrs, mask=k_mask, other=0.0)
|
||||||
|
|
||||||
|
# Compute scores [block_q, block_k]
|
||||||
|
scores = tl.dot(q_block, tl.trans(k_block), input_precision="ieee") * scale
|
||||||
|
|
||||||
|
# Apply causal mask
|
||||||
|
if causal:
|
||||||
|
causal_mask = offs_q[:, None] >= offs_k[None, :]
|
||||||
|
scores = tl.where(causal_mask, scores, float('-inf'))
|
||||||
|
|
||||||
|
# Mask out-of-bounds keys
|
||||||
|
scores = tl.where(offs_k[None, :] < seq_len, scores, float('-inf'))
|
||||||
|
|
||||||
|
# Online softmax update
|
||||||
|
m_ij = tl.max(scores, axis=1)
|
||||||
|
m_new = tl.maximum(m_i, m_ij)
|
||||||
|
alpha = tl.exp(m_i - m_new)
|
||||||
|
p = tl.exp(scores - m_new[:, None])
|
||||||
|
|
||||||
|
l_i = l_i * alpha + tl.sum(p, axis=1)
|
||||||
|
acc = acc * alpha[:, None]
|
||||||
|
|
||||||
|
# Load V block [block_k, block_d]
|
||||||
|
v_ptrs = v_ptr + v_offset + offs_k[:, None] * stride_v_seq + offs_d[None, :] * stride_v_dim
|
||||||
|
v_mask = (offs_k[:, None] < seq_len) & (offs_d[None, :] < head_dim)
|
||||||
|
v_block = tl.load(v_ptrs, mask=v_mask, other=0.0)
|
||||||
|
|
||||||
|
acc += tl.dot(p.to(v_block.dtype), v_block, input_precision="ieee")
|
||||||
|
m_i = m_new
|
||||||
|
|
||||||
|
# Normalize
|
||||||
|
acc = acc / l_i[:, None]
|
||||||
|
|
||||||
|
# Store output
|
||||||
|
out_ptrs = out_ptr + out_offset + offs_q[:, None] * stride_out_seq + offs_d[None, :] * stride_out_dim
|
||||||
|
out_mask = (offs_q[:, None] < seq_len) & (offs_d[None, :] < head_dim)
|
||||||
|
tl.store(out_ptrs, acc, mask=out_mask)
|
||||||
|
|
||||||
|
|
||||||
def triton_flash_attention_fwd(
|
def triton_flash_attention_fwd(
|
||||||
@@ -71,5 +137,19 @@ def triton_flash_attention_fwd(
|
|||||||
raise ValueError("expected [batch, heads, seq, dim] inputs")
|
raise ValueError("expected [batch, heads, seq, dim] inputs")
|
||||||
if not q.is_cuda or not k.is_cuda or not v.is_cuda:
|
if not q.is_cuda or not k.is_cuda or not v.is_cuda:
|
||||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||||
raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.")
|
batch, heads, seq_len, head_dim = q.shape
|
||||||
|
block_d = triton.next_power_of_2(head_dim)
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
grid = (triton.cdiv(seq_len, block_q), batch * heads)
|
||||||
|
flash_attention_fwd_kernel[grid](
|
||||||
|
q, k, v, out,
|
||||||
|
seq_len, head_dim,
|
||||||
|
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||||
|
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||||
|
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||||
|
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
|
||||||
|
causal,
|
||||||
|
block_q=block_q, block_k=block_k, block_d=block_d,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|||||||
@@ -25,10 +25,27 @@ if TRITON_AVAILABLE:
|
|||||||
block_size: tl.constexpr,
|
block_size: tl.constexpr,
|
||||||
):
|
):
|
||||||
row_idx = tl.program_id(axis=0)
|
row_idx = tl.program_id(axis=0)
|
||||||
# TODO(student): maintain running max and running sum for this row.
|
# First pass: compute running max and sum
|
||||||
# TODO(student): process the row in blocks rather than assuming all columns fit at once.
|
running_max = float('-inf')
|
||||||
# TODO(student): write the final normalized probabilities.
|
running_sum = 0.0
|
||||||
pass
|
for block_start in range(0, num_cols, block_size):
|
||||||
|
col_offsets = block_start + tl.arange(0, block_size)
|
||||||
|
mask = col_offsets < num_cols
|
||||||
|
x_ptrs = x_ptr + row_idx * stride_x_row + col_offsets
|
||||||
|
x_block = tl.load(x_ptrs, mask=mask, other=float('-inf'))
|
||||||
|
block_max = tl.max(x_block, axis=0)
|
||||||
|
new_max = tl.maximum(running_max, block_max)
|
||||||
|
running_sum = running_sum * tl.exp(running_max - new_max) + tl.sum(tl.exp(x_block - new_max), axis=0)
|
||||||
|
running_max = new_max
|
||||||
|
# Second pass: write normalized output
|
||||||
|
for block_start in range(0, num_cols, block_size):
|
||||||
|
col_offsets = block_start + tl.arange(0, block_size)
|
||||||
|
mask = col_offsets < num_cols
|
||||||
|
x_ptrs = x_ptr + row_idx * stride_x_row + col_offsets
|
||||||
|
out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets
|
||||||
|
x_block = tl.load(x_ptrs, mask=mask, other=float('-inf'))
|
||||||
|
result = tl.exp(x_block - running_max) / running_sum
|
||||||
|
tl.store(out_ptrs, result, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
||||||
@@ -38,5 +55,9 @@ def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tenso
|
|||||||
raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
|
raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
|
||||||
if not x.is_cuda:
|
if not x.is_cuda:
|
||||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||||
raise NotImplementedError("TODO(student): implement online softmax in Triton.")
|
num_rows, num_cols = x.shape
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
grid = (num_rows,)
|
||||||
|
online_softmax_kernel[grid](x, out, num_cols, x.stride(0), out.stride(0), block_size=block_size)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|||||||
@@ -26,12 +26,15 @@ if TRITON_AVAILABLE:
|
|||||||
):
|
):
|
||||||
row_idx = tl.program_id(axis=0)
|
row_idx = tl.program_id(axis=0)
|
||||||
col_offsets = tl.arange(0, block_size)
|
col_offsets = tl.arange(0, block_size)
|
||||||
# TODO(student): convert row_idx and col_offsets into pointers for this row.
|
mask = col_offsets < num_cols
|
||||||
# TODO(student): load a row with masking.
|
x_ptrs = x_ptr + row_idx * stride_x_row + col_offsets
|
||||||
# TODO(student): subtract the row max for stability.
|
out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets
|
||||||
# TODO(student): exponentiate, sum, and normalize.
|
row = tl.load(x_ptrs, mask=mask, other=float('-inf'))
|
||||||
# TODO(student): store the normalized row.
|
row_max = tl.max(row, axis=0)
|
||||||
pass
|
numerator = tl.exp(row - row_max)
|
||||||
|
denominator = tl.sum(numerator, axis=0)
|
||||||
|
result = numerator / denominator
|
||||||
|
tl.store(out_ptrs, result, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
||||||
@@ -41,5 +44,11 @@ def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
|||||||
raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
|
raise ValueError(f"expected 2D input, got {tuple(x.shape)}")
|
||||||
if not x.is_cuda:
|
if not x.is_cuda:
|
||||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||||
raise NotImplementedError("TODO(student): implement row-wise softmax launch logic.")
|
num_rows, num_cols = x.shape
|
||||||
|
# block_size must be >= num_cols for this single-pass kernel
|
||||||
|
block_size = max(block_size, triton.next_power_of_2(num_cols))
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
grid = (num_rows,)
|
||||||
|
row_softmax_kernel[grid](x, out, num_cols, x.stride(0), out.stride(0), block_size=block_size)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|||||||
@@ -35,11 +35,24 @@ if TRITON_AVAILABLE:
|
|||||||
):
|
):
|
||||||
pid_m = tl.program_id(axis=0)
|
pid_m = tl.program_id(axis=0)
|
||||||
pid_n = tl.program_id(axis=1)
|
pid_n = tl.program_id(axis=1)
|
||||||
# TODO(student): compute the tile owned by this program instance.
|
offs_m = pid_m * block_m + tl.arange(0, block_m)
|
||||||
# TODO(student): loop over K tiles and accumulate partial products.
|
offs_n = pid_n * block_n + tl.arange(0, block_n)
|
||||||
# TODO(student): use masking on edge tiles.
|
offs_k = tl.arange(0, block_k)
|
||||||
# TODO(student): store the output tile.
|
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||||
pass
|
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
|
||||||
|
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
|
||||||
|
for ki in range(0, tl.cdiv(k, block_k)):
|
||||||
|
k_offset = ki * block_k
|
||||||
|
a_mask = (offs_m[:, None] < m) & ((k_offset + offs_k[None, :]) < k)
|
||||||
|
b_mask = ((k_offset + offs_k[:, None]) < k) & (offs_n[None, :] < n)
|
||||||
|
a_tile = tl.load(a_ptrs, mask=a_mask, other=0.0)
|
||||||
|
b_tile = tl.load(b_ptrs, mask=b_mask, other=0.0)
|
||||||
|
acc += tl.dot(a_tile, b_tile, input_precision="ieee")
|
||||||
|
a_ptrs += block_k * stride_ak
|
||||||
|
b_ptrs += block_k * stride_bk
|
||||||
|
c_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
|
||||||
|
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
|
||||||
|
tl.store(c_ptrs, acc, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
def triton_tiled_matmul(
|
def triton_tiled_matmul(
|
||||||
@@ -57,5 +70,17 @@ def triton_tiled_matmul(
|
|||||||
raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
|
raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
|
||||||
if not a.is_cuda or not b.is_cuda:
|
if not a.is_cuda or not b.is_cuda:
|
||||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||||
raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.")
|
m, k = a.shape
|
||||||
|
_, n = b.shape
|
||||||
|
c = torch.empty((m, n), device=a.device, dtype=a.dtype)
|
||||||
|
grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n))
|
||||||
|
tiled_matmul_kernel[grid](
|
||||||
|
a, b, c,
|
||||||
|
m, n, k,
|
||||||
|
a.stride(0), a.stride(1),
|
||||||
|
b.stride(0), b.stride(1),
|
||||||
|
c.stride(0), c.stride(1),
|
||||||
|
block_m=block_m, block_n=block_n, block_k=block_k,
|
||||||
|
)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|||||||
@@ -26,10 +26,10 @@ if TRITON_AVAILABLE:
|
|||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
offsets = pid * block_size + tl.arange(0, block_size)
|
offsets = pid * block_size + tl.arange(0, block_size)
|
||||||
mask = offsets < num_elements
|
mask = offsets < num_elements
|
||||||
# TODO(student): load x and y using masked tl.load calls.
|
x = tl.load(x_ptr + offsets, mask=mask)
|
||||||
# TODO(student): add the vectors.
|
y = tl.load(y_ptr + offsets, mask=mask)
|
||||||
# TODO(student): write the result with tl.store.
|
out = x + y
|
||||||
pass
|
tl.store(out_ptr + offsets, out, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
|
def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
|
||||||
@@ -40,5 +40,9 @@ def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024)
|
|||||||
raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
|
raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}")
|
||||||
if not x.is_cuda or not y.is_cuda:
|
if not x.is_cuda or not y.is_cuda:
|
||||||
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
raise ValueError("Triton kernels in this lab expect CUDA tensors.")
|
||||||
raise NotImplementedError("TODO(student): launch vector_add_kernel and return the output tensor.")
|
out = torch.empty_like(x)
|
||||||
|
num_elements = x.numel()
|
||||||
|
grid = ((num_elements + block_size - 1) // block_size,)
|
||||||
|
vector_add_kernel[grid](x, y, out, num_elements, block_size=block_size)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user