Task 04: Online Softmax
1. Problem Statement
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
2. Expected Input/Output Shapes
- Input: [num_rows, num_cols]
- Output: [num_rows, num_cols]
3. Performance Intuition
The main goal here is algorithmic structure rather than raw speed. Online softmax is powerful because it lets you process a row incrementally, without ever materializing the full reduction context at once.
4. Memory Access Discussion
Think in column tiles. Each tile updates the running normalization state. This matters later when attention scores are processed block by block.
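As a minimal sketch of this tile-wise update (the function name and the `block` tile width are illustrative, not part of the task's reference code), the running state per row is just the pair (m, l):

```python
import numpy as np

def online_normalizer(x, block=4):
    """Scan one row in column tiles, carrying only (m, l):
    m = running max, l = running sum of exp(x - m)."""
    m, l = -np.inf, 0.0
    for start in range(0, x.shape[0], block):
        tile = x[start:start + block]
        m_new = max(m, tile.max())
        # rescale the old sum into the new max's frame, then add the tile's contribution
        l = l * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new
    return m, l
```

Note that only `m` and `l` survive from one tile to the next; the tile itself can be discarded, which is exactly the property blockwise attention relies on.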
5. What Triton Is Abstracting
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
6. What CUDA Makes Explicit
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
7. Reflection Questions
- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?
8. Implementation Checklist
- Read the reference online softmax
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first
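For the last checklist item, a small NumPy comparison is enough to validate the recurrence before touching Triton or CUDA. This is a sketch, not the reference implementation; the function names and the `block` parameter are assumptions. It also illustrates the extra bookkeeping a full online softmax needs: previously written exponentials must be rescaled whenever the running max changes.

```python
import numpy as np

def online_softmax(x, block=4):
    """One pass over column tiles, then a final normalization.
    Carries (m, l) per row and rescales already-written output when m changes."""
    n_rows, n_cols = x.shape
    m = np.full(n_rows, -np.inf)
    l = np.zeros(n_rows)
    out = np.zeros_like(x)
    for start in range(0, n_cols, block):
        tile = x[:, start:start + block]
        m_new = np.maximum(m, tile.max(axis=1))
        scale = np.exp(m - m_new)            # correction factor for the old state
        p = np.exp(tile - m_new[:, None])    # current tile in the new max's frame
        l = l * scale + p.sum(axis=1)
        out[:, :start] *= scale[:, None]     # bring earlier columns into the new frame
        out[:, start:start + block] = p
        m = m_new
    return out / l[:, None]

def full_softmax(x):
    """Standard two-pass softmax, used as the reference."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)
```

Comparing the two on small random shapes first makes it easy to localize bugs in the rescaling step before moving to a kernel.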
Informal Recurrence
Given a previous state (m_prev, l_prev) and a new tile with max m_tile and denominator contribution l_tile, define:
m_new = max(m_prev, m_tile)
l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)
That is the key idea you will reuse in FlashAttention.
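The recurrence can be sanity-checked numerically: merging the (m, l) states of two chunks must give the same state as computing it over their concatenation. A small sketch, with illustrative names:

```python
import numpy as np

def state(x):
    """(m, l) state of a chunk: m = max, l = sum of exp(x - m)."""
    m = x.max()
    return m, np.exp(x - m).sum()

def merge(s_prev, s_tile):
    """The recurrence above, applied to two (m, l) states."""
    m_prev, l_prev = s_prev
    m_tile, l_tile = s_tile
    m_new = max(m_prev, m_tile)
    l_new = l_prev * np.exp(m_prev - m_new) + l_tile * np.exp(m_tile - m_new)
    return m_new, l_new

a = np.array([0.3, 2.0, -1.0])
b = np.array([1.5, 0.7, 3.2, -0.4])
merged = merge(state(a), state(b))
reference = state(np.concatenate([a, b]))
```

Because the merge is associative, tiles can be folded in left to right in any block size, which is what makes the blockwise formulation valid.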