Initial project scaffold
tasks/04_online_softmax/spec.md

# Task 04: Online Softmax
## 1. Problem Statement
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
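
For reference, the full (non-online) row-wise softmax you will compare against can be sketched in NumPy; the function name `softmax_rows` is illustrative, not part of the task's API:

```python
import numpy as np

def softmax_rows(x: np.ndarray) -> np.ndarray:
    """Numerically stable row-wise softmax: shift by the row max, then normalize."""
    m = x.max(axis=1, keepdims=True)         # [num_rows, 1] row maxima
    e = np.exp(x - m)                        # shifted exponentials, cannot overflow
    return e / e.sum(axis=1, keepdims=True)  # each row sums to 1
```

The online formulation computes the same result while visiting each row piece by piece.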
## 2. Expected Input/Output Shapes
- Input: `[num_rows, num_cols]`
- Output: `[num_rows, num_cols]`
## 3. Performance Intuition
The main goal here is algorithmic structure rather than raw speed. Online softmax is powerful because it lets you process a row incrementally, without ever materializing the full reduction context at once.
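
As a minimal illustration of that incremental view (a sketch using only the standard library; the function name is hypothetical), a single pass over a stream of scores can carry just a running max `m` and running sum `l` and still recover the log-denominator:

```python
import math

def streaming_logsumexp(stream):
    """One pass over a stream of scores with O(1) state.

    Carries only (running max m, running sum l); whenever a new score
    raises the max, the old sum is rescaled by exp(m - m_new).
    Returns log(sum(exp(x))) without holding the row in memory.
    """
    m, l = -math.inf, 0.0
    for x in stream:
        m_new = max(m, x)
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m + math.log(l)
```

The softmax denominator is `exp` of this value, so the same state suffices for normalization.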
## 4. Memory Access Discussion
Think in column tiles. Each tile updates the running normalization state. This matters later when attention scores are processed block by block.
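
A tile-by-tile sketch of that idea in NumPy (the helper name `online_softmax_rows` and the two-pass structure are one illustrative way to organize it, not the only way):

```python
import numpy as np

def online_softmax_rows(x: np.ndarray, tile: int = 4) -> np.ndarray:
    """Row-wise softmax computed one column tile at a time.

    Pass 1 folds each tile into the running state (m, l) per row;
    pass 2 rescales the exponentials tile by tile with the final state.
    """
    num_rows, num_cols = x.shape
    m = np.full(num_rows, -np.inf)  # running max per row
    l = np.zeros(num_rows)          # running denominator per row
    for start in range(0, num_cols, tile):
        t = x[:, start:start + tile]
        m_tile = t.max(axis=1)
        l_tile = np.exp(t - m_tile[:, None]).sum(axis=1)
        m_new = np.maximum(m, m_tile)
        l = l * np.exp(m - m_new) + l_tile * np.exp(m_tile - m_new)
        m = m_new
    # Second pass: with (m, l) fixed, emit normalized values tile by tile.
    out = np.empty_like(x, dtype=float)
    for start in range(0, num_cols, tile):
        out[:, start:start + tile] = (
            np.exp(x[:, start:start + tile] - m[:, None]) / l[:, None]
        )
    return out
```

Note that only `m` and `l` (one scalar pair per row) survive between tiles; this is exactly the state FlashAttention carries between score blocks.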
## 5. What Triton Is Abstracting
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
## 6. What CUDA Makes Explicit
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
## 7. Reflection Questions
- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?
## 8. Implementation Checklist
- Read the reference online softmax implementation
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first
## Informal Recurrence
Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile`, define:
- `m_new = max(m_prev, m_tile)`
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`
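
The merge rule can be sanity-checked numerically: combining the `(m, l)` states of two tiles of a row must equal the state computed over the whole row at once. A minimal standard-library check (the row values are arbitrary):

```python
import math

# Concrete check: merging per-tile (max, denominator) states matches
# computing the denominator over the whole row in one shot.
row = [1.0, 3.0, 2.0, 0.0]
tile_a, tile_b = row[:2], row[2:]

def tile_state(tile):
    """Return (max, sum of exp shifted by that max) for one tile."""
    m = max(tile)
    l = sum(math.exp(x - m) for x in tile)
    return m, l

m_prev, l_prev = tile_state(tile_a)
m_tile, l_tile = tile_state(tile_b)

# The recurrence from the spec:
m_new = max(m_prev, m_tile)
l_new = l_prev * math.exp(m_prev - m_new) + l_tile * math.exp(m_tile - m_new)

m_full, l_full = tile_state(row)
assert math.isclose(m_new, m_full) and math.isclose(l_new, l_full)
```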
That is the key idea you will reuse in FlashAttention.