2026-04-10 13:22:19 +00:00

# Task 04: Online Softmax
## 1. Problem Statement
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
## 2. Expected Input/Output Shapes
- Input: `[num_rows, num_cols]`
- Output: `[num_rows, num_cols]`
## 3. Performance Intuition
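The main goal is algorithmic structure rather than raw speed. Online softmax is powerful because it processes a row incrementally: the max and the normalizing sum are updated as each chunk of the row arrives, so no separate full pass over the row is needed to find the max before accumulation can start.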
## 4. Memory Access Discussion
Think in column tiles. Each tile updates the running normalization state of its row, that is, the running max and the running sum. This matters later when attention scores are processed block by block.
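The tile-by-tile update can be sketched in plain Python before moving to Triton or CUDA. This is a minimal illustration, not the reference implementation: the function name `online_row_state` and the parameter `tile_cols` are hypothetical, and the inputs are assumed to be finite floats.

```python
import math

def online_row_state(row, tile_cols):
    """Return (m, l) for one row, visiting it in column tiles of width tile_cols.

    Hypothetical helper for illustration; assumes every value is finite.
    """
    m = float("-inf")  # running max over the columns seen so far
    l = 0.0            # running sum of exp(x - m) over the columns seen so far
    for start in range(0, len(row), tile_cols):
        tile = row[start:start + tile_cols]
        m_new = max(m, max(tile))
        # Rescale the old sum to the new max, then add this tile's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    return m, l

row = [1.0, 3.0, -2.0, 0.5, 4.0, 2.0, -1.0]
m, l = online_row_state(row, tile_cols=3)
```

Only `m` and `l` survive from one tile to the next; the tile contents themselves can be discarded once they have been folded into the state.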
## 5. What Triton Is Abstracting
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
## 6. What CUDA Makes Explicit
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
## 7. Reflection Questions
- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?
## 8. Implementation Checklist
- Read the reference online softmax
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first
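For the last checklist item, a comparison harness might look like the sketch below. It is written in plain Python so it runs anywhere; the function names (`softmax_full`, `softmax_online`, `max_abs_diff`), the chosen shapes, and the tile widths are all assumptions made for illustration, and a real check would compare the Triton and CUDA kernels against the task's own reference.

```python
import math
import random

def softmax_full(x):
    """Reference: per row, find the max, sum the exponentials, then normalize."""
    out = []
    for row in x:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out

def softmax_online(x, tile_cols):
    """Blocked recurrence for (m, l), followed by a normalizing pass."""
    out = []
    for row in x:
        m, l = float("-inf"), 0.0
        for start in range(0, len(row), tile_cols):
            tile = row[start:start + tile_cols]
            m_new = max(m, max(tile))
            l = l * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in tile)
            m = m_new
        out.append([math.exp(v - m) / l for v in row])
    return out

def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two 2D lists."""
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))

random.seed(0)
worst = 0.0
# Small shapes first, including a tile wider than the row and a ragged last tile.
for num_rows, num_cols, tile_cols in [(1, 1, 1), (2, 5, 2), (3, 8, 4), (4, 7, 16)]:
    x = [[random.uniform(-5.0, 5.0) for _ in range(num_cols)] for _ in range(num_rows)]
    worst = max(worst, max_abs_diff(softmax_full(x), softmax_online(x, tile_cols)))
```

Shapes where `tile_cols` does not divide `num_cols` are worth including early, since the ragged last tile is a common source of masking bugs in the blocked kernels.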
## Informal Recurrence
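Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile` (the sum of `exp(x - m_tile)` over the tile's elements), define:
- `m_new = max(m_prev, m_tile)`
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`
Rescaling both terms to the shared maximum `m_new` makes `l_new` equal to the sum of `exp(x - m_new)` over every element seen so far.
That is the key idea you will reuse in FlashAttention.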
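The recurrence can be written as a small merge of two states. The sketch below is one way to express it in plain Python; `tile_state` and `merge` are hypothetical names, and the empty state `(-inf, 0.0)` is an assumed starting convention that works because `exp(-inf)` evaluates to `0.0`. It also assumes each tile contains at least one finite value.

```python
import math

def tile_state(tile):
    """State of a single tile: its max and the sum of exp(x - max)."""
    m_tile = max(tile)
    l_tile = sum(math.exp(x - m_tile) for x in tile)
    return m_tile, l_tile

def merge(prev, new):
    """Combine two (m, l) states using the recurrence for m_new and l_new."""
    m_prev, l_prev = prev
    m_tile, l_tile = new
    m_new = max(m_prev, m_tile)
    l_new = l_prev * math.exp(m_prev - m_new) + l_tile * math.exp(m_tile - m_new)
    return m_new, l_new

tiles = [[0.5, 2.0], [3.0, -1.0], [1.5]]

state = (float("-inf"), 0.0)  # assumed empty state: no columns seen yet
for tile in tiles:
    state = merge(state, tile_state(tile))

m, l = state
```

Because the merge only needs the two `(m, l)` pairs, tiles can be reduced independently and combined afterwards, which is the property the blockwise attention formulation relies on.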