# Task 04: Online Softmax

## 1. Problem Statement

Implement the running max / running sum formulation of softmax and connect it to blockwise attention.

## 2. Expected Input/Output Shapes

- Input: `[num_rows, num_cols]`
- Output: `[num_rows, num_cols]`

## 3. Performance Intuition

The main goal is algorithmic structure rather than raw speed. Online softmax is powerful because it lets you process a row incrementally, without materializing the entire row and its reductions at once.

## 4. Memory Access Discussion

Think in column tiles. Each tile updates the running normalization state. This matters later when attention scores are processed block by block.

## 5. What Triton Is Abstracting

Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.

## 6. What CUDA Makes Explicit

CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.

## 7. Reflection Questions

- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?

## 8. Implementation Checklist

- Read the reference online softmax
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first

## Informal Recurrence

Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile = sum(exp(x - m_tile))` over the tile's entries, define:

- `m_new = max(m_prev, m_tile)`
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`

The rescaling preserves the invariant that `l` always equals `sum(exp(x - m))` over every entry seen so far, so after the last tile `(m, l)` yields the exact softmax denominator. That is the key idea you will reuse in FlashAttention.
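Below is a minimal NumPy sketch of this recurrence, not the task's reference code: the function name `online_softmax`, the `tile_size` parameter, and the second full-row normalization pass are illustrative choices. It folds one column tile at a time into the per-row `(m, l)` state and then checks the result against a plain full-row softmax, in the spirit of the last checklist item.

```python
import numpy as np


def online_softmax(x: np.ndarray, tile_size: int = 32) -> np.ndarray:
    """Row-wise softmax computed by scanning column tiles (illustrative sketch)."""
    num_rows, num_cols = x.shape
    m = np.full(num_rows, -np.inf)  # running max per row
    l = np.zeros(num_rows)          # running denominator per row

    # Pass 1: fold each column tile into the running (m, l) state.
    for start in range(0, num_cols, tile_size):
        tile = x[:, start:start + tile_size]
        m_tile = tile.max(axis=1)
        l_tile = np.exp(tile - m_tile[:, None]).sum(axis=1)

        m_new = np.maximum(m, m_tile)
        # Rescale the old denominator to the new max before adding the tile's.
        l = l * np.exp(m - m_new) + l_tile * np.exp(m_tile - m_new)
        m = m_new

    # Pass 2: normalize every element with the final per-row (m, l).
    return np.exp(x - m[:, None]) / l[:, None]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 100))

    # Plain full-row softmax for comparison.
    ref = np.exp(x - x.max(axis=1, keepdims=True))
    ref /= ref.sum(axis=1, keepdims=True)

    assert np.allclose(online_softmax(x, tile_size=16), ref)
    print("online softmax matches full softmax")
```

A fused Triton or CUDA kernel would typically avoid the second pass, either by keeping the exponentiated tile in registers or, in the attention setting, by also carrying a rescaled running output accumulator as FlashAttention does.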