From 165a1b0bd51f9fa9852ee952b7acf3968c677348 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 15 May 2026 20:46:04 +0800 Subject: [PATCH] Implement all 5 Triton kernel labs - vector_add: basic masked load/store with block indexing - row_softmax: single-pass numerically stable softmax per row - tiled_matmul: K-dimension tile loop with edge masking (IEEE precision) - online_softmax: two-pass running max/sum recurrence across blocks - flash_attention_fwd: blockwise Q/K/V with online softmax, causal support All 26 tests pass on RTX 5090 (CUDA 12.8, Triton 3.6). --- kernels/triton/flash_attention_fwd.py | 96 ++++++++++++++++++++++++--- kernels/triton/online_softmax.py | 31 +++++++-- kernels/triton/row_softmax.py | 23 +++++-- kernels/triton/tiled_matmul.py | 37 +++++++++-- kernels/triton/vector_add.py | 14 ++-- 5 files changed, 170 insertions(+), 31 deletions(-) diff --git a/kernels/triton/flash_attention_fwd.py b/kernels/triton/flash_attention_fwd.py index 30df4a4..4cca2cd 100644 --- a/kernels/triton/flash_attention_fwd.py +++ b/kernels/triton/flash_attention_fwd.py @@ -46,13 +46,79 @@ if TRITON_AVAILABLE: ): pid_q = tl.program_id(axis=0) pid_bh = tl.program_id(axis=1) - # TODO(student): map pid_q and pid_bh to a batch/head/query tile. - # TODO(student): load Q, K, and V blocks. - # TODO(student): compute scores for the current block pair. - # TODO(student): apply optional causal masking. - # TODO(student): update online softmax state and accumulate the output block. - # TODO(student): store the final output tile. - pass + num_heads = stride_q_batch // stride_q_head + batch_idx = pid_bh // num_heads + head_idx = pid_bh % num_heads + + q_offset = batch_idx * stride_q_batch + head_idx * stride_q_head + k_offset = batch_idx * stride_k_batch + head_idx * stride_k_head + v_offset = batch_idx * stride_v_batch + head_idx * stride_v_head + out_offset = batch_idx * stride_out_batch + head_idx * stride_out_head + + offs_q = pid_q * block_q + tl.arange(0, block_q) + offs_d = tl.arange(0, block_d) + + # Load Q block [block_q, block_d] + q_ptrs = q_ptr + q_offset + offs_q[:, None] * stride_q_seq + offs_d[None, :] * stride_q_dim + q_mask = (offs_q[:, None] < seq_len) & (offs_d[None, :] < head_dim) + q_block = tl.load(q_ptrs, mask=q_mask, other=0.0) + + scale = 1.0 / tl.sqrt(head_dim.to(tl.float32)) + + # Online softmax accumulators + m_i = tl.full((block_q,), float('-inf'), dtype=tl.float32) + l_i = tl.zeros((block_q,), dtype=tl.float32) + acc = tl.zeros((block_q, block_d), dtype=tl.float32) + + # Determine K range + if causal: + k_end = tl.minimum((pid_q + 1) * block_q, seq_len) + else: + k_end = seq_len + + for k_start in range(0, k_end, block_k): + offs_k = k_start + tl.arange(0, block_k) + + # Load K block [block_k, block_d] + k_ptrs = k_ptr + k_offset + offs_k[:, None] * stride_k_seq + offs_d[None, :] * stride_k_dim + k_mask = (offs_k[:, None] < seq_len) & (offs_d[None, :] < head_dim) + k_block = tl.load(k_ptrs, mask=k_mask, other=0.0) + + # Compute scores [block_q, block_k] + scores = tl.dot(q_block, tl.trans(k_block), input_precision="ieee") * scale + + # Apply causal mask + if causal: + causal_mask = offs_q[:, None] >= offs_k[None, :] + scores = tl.where(causal_mask, scores, float('-inf')) + + # Mask out-of-bounds keys + scores = tl.where(offs_k[None, :] < seq_len, scores, float('-inf')) + + # Online softmax update + m_ij = tl.max(scores, axis=1) + m_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_new) + p = tl.exp(scores - m_new[:, None]) + + l_i = l_i * alpha + tl.sum(p, axis=1) + acc = acc * alpha[:, None] + + # Load V block [block_k, block_d] + v_ptrs = v_ptr + v_offset + offs_k[:, None] * stride_v_seq + offs_d[None, :] * stride_v_dim + v_mask = (offs_k[:, None] < seq_len) & (offs_d[None, :] < head_dim) + v_block = tl.load(v_ptrs, mask=v_mask, other=0.0) + + acc += tl.dot(p.to(v_block.dtype), v_block, input_precision="ieee") + m_i = m_new + + # Normalize + acc = acc / l_i[:, None] + + # Store output + out_ptrs = out_ptr + out_offset + offs_q[:, None] * stride_out_seq + offs_d[None, :] * stride_out_dim + out_mask = (offs_q[:, None] < seq_len) & (offs_d[None, :] < head_dim) + tl.store(out_ptrs, acc, mask=out_mask) def triton_flash_attention_fwd( @@ -71,5 +137,19 @@ def triton_flash_attention_fwd( raise ValueError("expected [batch, heads, seq, dim] inputs") if not q.is_cuda or not k.is_cuda or not v.is_cuda: raise ValueError("Triton kernels in this lab expect CUDA tensors.") - raise NotImplementedError("TODO(student): implement the FlashAttention forward launch.") + batch, heads, seq_len, head_dim = q.shape + block_d = triton.next_power_of_2(head_dim) + out = torch.empty_like(q) + grid = (triton.cdiv(seq_len, block_q), batch * heads) + flash_attention_fwd_kernel[grid]( + q, k, v, out, + seq_len, head_dim, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + causal, + block_q=block_q, block_k=block_k, block_d=block_d, + ) + return out diff --git a/kernels/triton/online_softmax.py b/kernels/triton/online_softmax.py index 328c818..14bb8de 100644 --- a/kernels/triton/online_softmax.py +++ b/kernels/triton/online_softmax.py @@ -25,10 +25,27 @@ if TRITON_AVAILABLE: block_size: tl.constexpr, ): row_idx = tl.program_id(axis=0) - # TODO(student): maintain running max and running sum for this row. - # TODO(student): process the row in blocks rather than assuming all columns fit at once. - # TODO(student): write the final normalized probabilities. - pass + # First pass: compute running max and sum + running_max = float('-inf') + running_sum = 0.0 + for block_start in range(0, num_cols, block_size): + col_offsets = block_start + tl.arange(0, block_size) + mask = col_offsets < num_cols + x_ptrs = x_ptr + row_idx * stride_x_row + col_offsets + x_block = tl.load(x_ptrs, mask=mask, other=float('-inf')) + block_max = tl.max(x_block, axis=0) + new_max = tl.maximum(running_max, block_max) + running_sum = running_sum * tl.exp(running_max - new_max) + tl.sum(tl.exp(x_block - new_max), axis=0) + running_max = new_max + # Second pass: write normalized output + for block_start in range(0, num_cols, block_size): + col_offsets = block_start + tl.arange(0, block_size) + mask = col_offsets < num_cols + x_ptrs = x_ptr + row_idx * stride_x_row + col_offsets + out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets + x_block = tl.load(x_ptrs, mask=mask, other=float('-inf')) + result = tl.exp(x_block - running_max) / running_sum + tl.store(out_ptrs, result, mask=mask) def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor: @@ -38,5 +55,9 @@ def triton_online_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tenso raise ValueError(f"expected 2D input, got {tuple(x.shape)}") if not x.is_cuda: raise ValueError("Triton kernels in this lab expect CUDA tensors.") - raise NotImplementedError("TODO(student): implement online softmax in Triton.") + num_rows, num_cols = x.shape + out = torch.empty_like(x) + grid = (num_rows,) + online_softmax_kernel[grid](x, out, num_cols, x.stride(0), out.stride(0), block_size=block_size) + return out diff --git a/kernels/triton/row_softmax.py b/kernels/triton/row_softmax.py index db8dd27..db35884 100644 --- a/kernels/triton/row_softmax.py +++ b/kernels/triton/row_softmax.py @@ -26,12 +26,15 @@ if TRITON_AVAILABLE: ): row_idx = tl.program_id(axis=0) col_offsets = tl.arange(0, block_size) - # TODO(student): convert row_idx and col_offsets into pointers for this row. - # TODO(student): load a row with masking. - # TODO(student): subtract the row max for stability. - # TODO(student): exponentiate, sum, and normalize. - # TODO(student): store the normalized row. - pass + mask = col_offsets < num_cols + x_ptrs = x_ptr + row_idx * stride_x_row + col_offsets + out_ptrs = out_ptr + row_idx * stride_out_row + col_offsets + row = tl.load(x_ptrs, mask=mask, other=float('-inf')) + row_max = tl.max(row, axis=0) + numerator = tl.exp(row - row_max) + denominator = tl.sum(numerator, axis=0) + result = numerator / denominator + tl.store(out_ptrs, result, mask=mask) def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor: @@ -41,5 +44,11 @@ def triton_row_softmax(x: torch.Tensor, block_size: int = 128) -> torch.Tensor: raise ValueError(f"expected 2D input, got {tuple(x.shape)}") if not x.is_cuda: raise ValueError("Triton kernels in this lab expect CUDA tensors.") - raise NotImplementedError("TODO(student): implement row-wise softmax launch logic.") + num_rows, num_cols = x.shape + # block_size must be >= num_cols for this single-pass kernel + block_size = max(block_size, triton.next_power_of_2(num_cols)) + out = torch.empty_like(x) + grid = (num_rows,) + row_softmax_kernel[grid](x, out, num_cols, x.stride(0), out.stride(0), block_size=block_size) + return out diff --git a/kernels/triton/tiled_matmul.py b/kernels/triton/tiled_matmul.py index 9059458..28a9ca1 100644 --- a/kernels/triton/tiled_matmul.py +++ b/kernels/triton/tiled_matmul.py @@ -35,11 +35,24 @@ if TRITON_AVAILABLE: ): pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) - # TODO(student): compute the tile owned by this program instance. - # TODO(student): loop over K tiles and accumulate partial products. - # TODO(student): use masking on edge tiles. - # TODO(student): store the output tile. - pass + offs_m = pid_m * block_m + tl.arange(0, block_m) + offs_n = pid_n * block_n + tl.arange(0, block_n) + offs_k = tl.arange(0, block_k) + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + acc = tl.zeros((block_m, block_n), dtype=tl.float32) + for ki in range(0, tl.cdiv(k, block_k)): + k_offset = ki * block_k + a_mask = (offs_m[:, None] < m) & ((k_offset + offs_k[None, :]) < k) + b_mask = ((k_offset + offs_k[:, None]) < k) & (offs_n[None, :] < n) + a_tile = tl.load(a_ptrs, mask=a_mask, other=0.0) + b_tile = tl.load(b_ptrs, mask=b_mask, other=0.0) + acc += tl.dot(a_tile, b_tile, input_precision="ieee") + a_ptrs += block_k * stride_ak + b_ptrs += block_k * stride_bk + c_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_ptrs, acc, mask=c_mask) def triton_tiled_matmul( @@ -57,5 +70,17 @@ def triton_tiled_matmul( raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}") if not a.is_cuda or not b.is_cuda: raise ValueError("Triton kernels in this lab expect CUDA tensors.") - raise NotImplementedError("TODO(student): implement the tiled Triton matmul path.") + m, k = a.shape + _, n = b.shape + c = torch.empty((m, n), device=a.device, dtype=a.dtype) + grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n)) + tiled_matmul_kernel[grid]( + a, b, c, + m, n, k, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + block_m=block_m, block_n=block_n, block_k=block_k, + ) + return c diff --git a/kernels/triton/vector_add.py b/kernels/triton/vector_add.py index d5a2bfc..0efbcd0 100644 --- a/kernels/triton/vector_add.py +++ b/kernels/triton/vector_add.py @@ -26,10 +26,10 @@ if TRITON_AVAILABLE: pid = tl.program_id(axis=0) offsets = pid * block_size + tl.arange(0, block_size) mask = offsets < num_elements - # TODO(student): load x and y using masked tl.load calls. - # TODO(student): add the vectors. - # TODO(student): write the result with tl.store. - pass + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + out = x + y + tl.store(out_ptr + offsets, out, mask=mask) def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor: @@ -40,5 +40,9 @@ def triton_vector_add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) raise ValueError(f"shape mismatch: {x.shape} vs {y.shape}") if not x.is_cuda or not y.is_cuda: raise ValueError("Triton kernels in this lab expect CUDA tensors.") - raise NotImplementedError("TODO(student): launch vector_add_kernel and return the output tensor.") + out = torch.empty_like(x) + num_elements = x.numel() + grid = ((num_elements + block_size - 1) // block_size,) + vector_add_kernel[grid](x, y, out, num_elements, block_size=block_size) + return out