Initial project scaffold
This commit is contained in:
2
tasks/04_online_softmax/__init__.py
Normal file
2
tasks/04_online_softmax/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Online softmax task."""
|
||||
|
||||
49
tasks/04_online_softmax/bench.py
Normal file
49
tasks/04_online_softmax/bench.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Make the repository root importable so the `kernels` and `reference`
# packages resolve when this file is executed directly as a script
# (tasks/04_online_softmax/bench.py -> parents[2] is the repo root).
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
||||
|
||||
import torch
|
||||
|
||||
from kernels.triton.online_softmax import triton_online_softmax
|
||||
from reference.torch_online_softmax import torch_online_softmax
|
||||
|
||||
|
||||
def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    """Return the median wall-clock latency of ``fn(*args)`` in milliseconds.

    Runs ``warmup`` untimed calls first, then times ``reps`` calls with
    ``time.perf_counter``. When the first positional argument is a CUDA
    tensor, the GPU is synchronized around each timed call so the measured
    interval covers kernel execution rather than just launch overhead.

    Args:
        fn: Callable under test.
        *args: Positional arguments forwarded to ``fn``. If ``args[0]`` is a
            ``torch.Tensor`` on CUDA, synchronization is enabled.
        warmup: Number of untimed warmup iterations.
        reps: Number of timed iterations.

    Returns:
        Median latency over ``reps`` runs, in milliseconds.
    """
    # Hoist the device check out of the loops. The original re-read
    # args[0].is_cuda on every iteration and raised IndexError for
    # zero-argument callables; getattr keeps this safe for non-tensors too.
    use_cuda = bool(args) and getattr(args[0], "is_cuda", False)

    for _ in range(warmup):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()

    times_ms = []
    for _ in range(reps):
        if use_cuda:
            # Drain pending GPU work so the timer starts from an idle device.
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if use_cuda:
            # Wait for the launched kernels to finish before stopping the clock.
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    # Median is robust to occasional scheduling hiccups.
    return statistics.median(times_ms)
|
||||
|
||||
|
||||
def main() -> None:
    """Benchmark the reference online softmax and, on CUDA, the Triton kernel."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(2048, 2048, device=device)

    ref_ms = benchmark(torch_online_softmax, x)
    print(f"torch_online_softmax: {ref_ms:.3f} ms")

    # The Triton path only exists on GPU; a still-TODO or failing kernel is
    # reported as skipped instead of aborting the benchmark.
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_online_softmax, x)
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_online_softmax: skipped ({exc})")
        else:
            print(f"triton_online_softmax: {triton_ms:.3f} ms")
|
||||
|
||||
|
||||
# Allow running this benchmark directly as a script.
if __name__ == "__main__":
    main()
|
||||
7
tasks/04_online_softmax/cuda_skeleton.cu
Normal file
7
tasks/04_online_softmax/cuda_skeleton.cu
Normal file
@@ -0,0 +1,7 @@
|
||||
// Workbook-local CUDA sketch for online softmax.
|
||||
//
|
||||
// TODO(student):
|
||||
// 1. Decide whether one thread block owns a full row or a tile of a row.
|
||||
// 2. Keep running_max and running_sum across column tiles.
|
||||
// 3. Update the recurrence carefully for numerical stability.
|
||||
// 4. Normalize the final row.
|
||||
49
tasks/04_online_softmax/spec.md
Normal file
49
tasks/04_online_softmax/spec.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Task 04: Online Softmax
|
||||
|
||||
## 1. Problem Statement
|
||||
|
||||
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
|
||||
|
||||
## 2. Expected Input/Output Shapes
|
||||
|
||||
- Input: `[num_rows, num_cols]`
|
||||
- Output: `[num_rows, num_cols]`
|
||||
|
||||
## 3. Performance Intuition
|
||||
|
||||
The main goal is algorithmic structure rather than raw speed. Online softmax becomes powerful because it lets you process a row incrementally without materializing the full reduction context at once.
|
||||
|
||||
## 4. Memory Access Discussion
|
||||
|
||||
Think in column tiles. Each tile updates the running normalization state. This matters later when attention scores are processed block by block.
|
||||
|
||||
## 5. What Triton Is Abstracting
|
||||
|
||||
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
|
||||
|
||||
## 6. What CUDA Makes Explicit
|
||||
|
||||
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
|
||||
|
||||
## 7. Reflection Questions
|
||||
|
||||
- Why is a running max needed instead of only a running sum?
|
||||
- Why does online softmax enable FlashAttention-style blockwise computation?
|
||||
- Which values must persist from one tile to the next?
|
||||
|
||||
## 8. Implementation Checklist
|
||||
|
||||
- Read the reference online softmax
|
||||
- Derive the recurrence informally
|
||||
- Implement the Triton blocked recurrence
|
||||
- Implement the CUDA blocked recurrence
|
||||
- Compare against full softmax on small shapes first
|
||||
|
||||
## Informal Recurrence
|
||||
|
||||
Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile`, define:
|
||||
|
||||
- `m_new = max(m_prev, m_tile)`
|
||||
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`
|
||||
|
||||
That is the key idea you will reuse in FlashAttention.
|
||||
33
tasks/04_online_softmax/test_task.py
Normal file
33
tasks/04_online_softmax/test_task.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels.triton.online_softmax import triton_online_softmax
|
||||
from reference.torch_online_softmax import torch_online_softmax
|
||||
|
||||
|
||||
def _run_impl_or_skip(fn, *args):
|
||||
try:
|
||||
return fn(*args)
|
||||
except NotImplementedError:
|
||||
pytest.skip("implementation is still TODO")
|
||||
except RuntimeError as exc:
|
||||
pytest.skip(str(exc))
|
||||
|
||||
|
||||
@pytest.mark.reference
def test_online_softmax_reference_matches_torch():
    """The reference online softmax must agree with ``torch.softmax``."""
    inputs = torch.randn(6, 19)
    expected = torch.softmax(inputs, dim=1)
    actual = torch_online_softmax(inputs)
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_online_softmax_if_available():
    """The Triton kernel must match ``torch.softmax`` when CUDA is present."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    inputs = torch.randn(8, 97, device="cuda")
    # _run_impl_or_skip converts a still-TODO kernel into a skip.
    actual = _run_impl_or_skip(triton_online_softmax, inputs)
    expected = torch.softmax(inputs, dim=1)
    torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
|
||||
15
tasks/04_online_softmax/triton_skeleton.py
Normal file
15
tasks/04_online_softmax/triton_skeleton.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Workbook-local Triton notes for online softmax."""
|
||||
|
||||
|
||||
def notes() -> str:
    """Return the student-facing TODO checklist for the Triton online softmax."""
    return """
TODO(student):
1. Keep running_max and running_sum for one row.
2. Process the row in blocks.
3. Update the recurrence after each block.
4. Normalize once the full row has been seen.
"""
|
||||
|
||||
|
||||
# Print the checklist when this module is executed directly.
if __name__ == "__main__":
    print(notes())
|
||||
Reference in New Issue
Block a user