Task 04: Online Softmax

1. Problem Statement

Implement the running max / running sum formulation of softmax and connect it to blockwise attention.

2. Expected Input/Output Shapes

  • Input: [num_rows, num_cols]
  • Output: [num_rows, num_cols]

3. Performance Intuition

The main goal here is algorithmic structure rather than raw speed. Online softmax is powerful because it lets you process a row incrementally: the running max and running sum are updated tile by tile, so the full reduction context never has to be materialized at once.
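As a concrete picture, here is a minimal sketch of that incremental processing for a single row, in plain NumPy (the function name and tile_size are illustrative, not part of the task):

```python
import numpy as np

def online_softmax_row(x, tile_size=4):
    """Softmax over a 1-D row, consuming it one tile at a time."""
    m = -np.inf   # running max of everything seen so far
    l = 0.0       # running sum of exp(x - m) for everything seen so far
    for start in range(0, x.shape[0], tile_size):
        tile = x[start:start + tile_size]
        m_new = max(m, tile.max())
        # rescale the old sum to the new max, then add this tile's contribution
        l = l * np.exp(m - m_new) + np.exp(tile - m_new).sum()
        m = m_new
    # second pass: normalize the whole row with the final (m, l)
    return np.exp(x - m) / l

x = np.random.randn(10)
full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax_row(x), full)
```

The only state carried from one tile to the next is the pair (m, l); the full row is only needed again in the final normalization pass.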

4. Memory Access Discussion

Think in column tiles: each tile of a row folds into the running normalization state (the running max and running sum), and the state accumulated from earlier tiles is rescaled whenever the max grows. This matters later when attention scores are produced and consumed block by block.
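The same idea extends to the full [num_rows, num_cols] case: the state becomes a pair of per-row vectors, and each column tile updates both. A rough NumPy sketch (tile_cols is illustrative):

```python
import numpy as np

def online_softmax_2d(x, tile_cols=8):
    """Row-wise softmax over [num_rows, num_cols], tiled along the columns."""
    num_rows, num_cols = x.shape
    m = np.full(num_rows, -np.inf)   # per-row running max
    l = np.zeros(num_rows)           # per-row running sum of exp(x - m)
    for start in range(0, num_cols, tile_cols):
        tile = x[:, start:start + tile_cols]
        m_new = np.maximum(m, tile.max(axis=1))
        # rescale the old per-row sums to the new maxes, then add the tile
        l = l * np.exp(m - m_new) + np.exp(tile - m_new[:, None]).sum(axis=1)
        m = m_new
    # second pass over the input to produce normalized outputs
    return np.exp(x - m[:, None]) / l[:, None]
```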

5. What Triton Is Abstracting

Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
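A rough, untested sketch of what that might look like in Triton, with one program instance per row (assumes a recent Triton version; BLOCK and all names are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def online_softmax_kernel(x_ptr, out_ptr, n_cols, row_stride, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    base = row * row_stride
    cols = tl.arange(0, BLOCK)

    m = tl.zeros([1], dtype=tl.float32) - float('inf')   # running max
    l = tl.zeros([1], dtype=tl.float32)                   # running sum

    # first pass: fold each column tile into the running state
    for start in range(0, n_cols, BLOCK):
        offs = start + cols
        x = tl.load(x_ptr + base + offs, mask=offs < n_cols, other=-float('inf'))
        m_new = tl.maximum(m, tl.max(x, axis=0))
        l = l * tl.exp(m - m_new) + tl.sum(tl.exp(x - m_new), axis=0)
        m = m_new

    # second pass: normalize with the final state and store
    for start in range(0, n_cols, BLOCK):
        offs = start + cols
        x = tl.load(x_ptr + base + offs, mask=offs < n_cols, other=-float('inf'))
        tl.store(out_ptr + base + offs, tl.exp(x - m) / l, mask=offs < n_cols)

# launch sketch, assuming contiguous float32 CUDA tensors x and out:
# online_softmax_kernel[(num_rows,)](x, out, num_cols, x.stride(0), BLOCK=128)
```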

6. What CUDA Makes Explicit

CUDA forces you to decide where the running max and running sum live (registers, shared memory, or both) and how the threads of a block cooperate to update them across tiles.

7. Reflection Questions

  • Why is a running max needed instead of only a running sum?
  • Why does online softmax enable FlashAttention-style blockwise computation?
  • Which values must persist from one tile to the next?

8. Implementation Checklist

  • Read the reference online softmax
  • Derive the recurrence informally
  • Implement the Triton blocked recurrence
  • Implement the CUDA blocked recurrence
  • Compare against full softmax on small shapes first
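A minimal way to do that last check, assuming an online_softmax_2d(x) function like the NumPy sketch above (or a wrapper around your Triton / CUDA kernel):

```python
import numpy as np

def reference_softmax(x):
    """Numerically stable row-wise softmax computed in one shot."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def check(impl, shapes=((1, 1), (3, 5), (4, 17), (8, 64))):
    rng = np.random.default_rng(0)
    for num_rows, num_cols in shapes:
        x = rng.standard_normal((num_rows, num_cols))
        np.testing.assert_allclose(impl(x), reference_softmax(x), atol=1e-6)
    print("all shapes match")

# check(online_softmax_2d)
```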

Informal Recurrence

Given a previous state (m_prev, l_prev) and a new tile with max m_tile and denominator contribution l_tile (the sum of exp(x_j - m_tile) over the tile's entries), define:

  • m_new = max(m_prev, m_tile)
  • l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)

That is the key idea you will reuse in FlashAttention.
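A quick numeric sanity check of that recurrence, splitting one row into two tiles and verifying that the merged state matches the full-row max and denominator:

```python
import numpy as np

x = np.array([0.3, -1.2, 2.0, 0.7, -0.5, 1.1])
left, right = x[:3], x[3:]

# per-tile states: each tile's max and its local denominator
m_prev, l_prev = left.max(), np.exp(left - left.max()).sum()
m_tile, l_tile = right.max(), np.exp(right - right.max()).sum()

# merge with the recurrence above
m_new = max(m_prev, m_tile)
l_new = l_prev * np.exp(m_prev - m_new) + l_tile * np.exp(m_tile - m_new)

assert np.isclose(m_new, x.max())
assert np.isclose(l_new, np.exp(x - x.max()).sum())
```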