# Task 04: Online Softmax
## 1. Problem Statement
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
## 2. Expected Input/Output Shapes
- Input: `[num_rows, num_cols]`
- Output: `[num_rows, num_cols]`
## 3. Performance Intuition
The main goal is algorithmic structure rather than raw speed. Online softmax is useful because it processes a row incrementally: only a running max and a running sum carry over from one chunk to the next, so the normalizer can be computed without holding the whole row at once.
## 4. Memory Access Discussion
Think in column tiles. Each tile updates the running normalization state, that is, the running max and running sum for its row. This matters later when attention scores are processed block by block.
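
The tile-by-tile update can be sketched in plain Python before moving to Triton or CUDA. This is a minimal illustration, not the task's reference implementation: it assumes the input is a nested list shaped `[num_rows, num_cols]`, and the names `online_softmax_row`, `online_softmax`, `full_softmax_row`, and `tile_size` are hypothetical.

```python
import math

def online_softmax_row(row, tile_size):
    """Softmax of one row, with the normalizer built one column tile at a time."""
    m = float("-inf")  # running max over the tiles seen so far
    l = 0.0            # running sum of exp(x - m) over the tiles seen so far
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        m_tile = max(tile)
        l_tile = sum(math.exp(x - m_tile) for x in tile)
        m_new = max(m, m_tile)
        # Rescale both partial sums to the shared max before adding them.
        l = l * math.exp(m - m_new) + l_tile * math.exp(m_tile - m_new)
        m = m_new
    # Second pass: normalize every element with the final (m, l).
    return [math.exp(x - m) / l for x in row]

def online_softmax(x, tile_size=4):
    """Apply the tiled recurrence to each row; the output shape matches the input."""
    return [online_softmax_row(row, tile_size) for row in x]

def full_softmax_row(row):
    """Ordinary max-subtracted softmax, used only as a comparison baseline."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

# Small-shape comparison, including a row whose raw exponentials would overflow.
x = [[1.0, 2.0, 3.0, 4.0, 5.0], [1000.0, 999.0, -5.0, 0.0, 1000.0]]
tiled = online_softmax(x, tile_size=2)
baseline = [full_softmax_row(row) for row in x]
max_abs_diff = max(abs(a - b) for t, f in zip(tiled, baseline) for a, b in zip(t, f))
```

Only `m` and `l` survive from one tile to the next; the tile's own values can be discarded as soon as its contribution has been folded in.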
## 5. What Triton Is Abstracting
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
## 6. What CUDA Makes Explicit
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
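
One way to reason about thread cooperation before writing any CUDA is to treat the update as a binary combine over partial `(max, sum)` states. The sketch below simulates that in plain Python rather than CUDA. The pairwise tree layout is an assumption made for illustration, not something this task prescribes, and `partial_state`, `combine`, and `tree_reduce` are hypothetical names.

```python
import math

def partial_state(values):
    """(max, sum of exp(x - max)) for one worker's slice of a row."""
    m = max(values)
    return m, sum(math.exp(x - m) for x in values)

def combine(a, b):
    """Merge two partial states with the same rescaling as the tile recurrence."""
    (m_a, l_a), (m_b, l_b) = a, b
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)

def tree_reduce(states):
    """Pairwise reduction, imitating how cooperating threads might merge states."""
    while len(states) > 1:
        merged = [combine(states[i], states[i + 1])
                  for i in range(0, len(states) - 1, 2)]
        if len(states) % 2:
            merged.append(states[-1])  # the odd one out passes to the next round
        states = merged
    return states[0]

row = [0.5, -1.0, 3.0, 2.0, 7.5, 7.0, -2.0, 1.0]
slices = [row[i:i + 2] for i in range(0, len(row), 2)]  # four hypothetical workers
m, l = tree_reduce([partial_state(s) for s in slices])
```

Because `combine` gives the same result regardless of how the slices are grouped (up to floating-point rounding), the partial states can be merged in whatever order the thread layout makes convenient.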
## 7. Reflection Questions
- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?
## 8. Implementation Checklist
- Read the reference online softmax
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first
## Informal Recurrence
Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile` (the sum of `exp(x - m_tile)` over the tile's elements), define:

- `m_new = max(m_prev, m_tile)`
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`

That is the key idea you will reuse in FlashAttention.
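
To make the link to blockwise attention concrete, the same rescaling can be applied to an unnormalized output accumulator alongside `(m, l)`. The following is a minimal single-query sketch in plain Python. It assumes unscaled dot-product scores and no masking, and the names `blockwise_attention`, `full_attention`, `acc`, and `block_size` are illustrative rather than taken from any reference code.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def blockwise_attention(q, keys, values, block_size):
    """Attention output for one query, visiting keys/values one block at a time."""
    m = float("-inf")              # running max of the scores
    l = 0.0                        # running softmax denominator
    acc = [0.0] * len(values[0])   # running unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [dot(q, k) for k in k_blk]
        m_tile = max(scores)
        p = [math.exp(s - m_tile) for s in scores]
        l_tile = sum(p)
        m_new = max(m, m_tile)
        scale_prev = math.exp(m - m_new)
        scale_tile = math.exp(m_tile - m_new)
        l = l * scale_prev + l_tile * scale_tile
        # The accumulator is rescaled by the same factors as the denominator.
        for d in range(len(acc)):
            tile_sum = sum(p_j * v[d] for p_j, v in zip(p, v_blk))
            acc[d] = acc[d] * scale_prev + tile_sum * scale_tile
        m = m_new
    return [a / l for a in acc]

def full_attention(q, keys, values):
    """Baseline that materializes every score before normalizing."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(w_j * v[d] for w_j, v in zip(w, values)) / total
            for d in range(len(values[0]))]
```

The full row of attention scores never exists at once here: each block's scores are consumed and dropped, and only `m`, `l`, and `acc` persist across blocks.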