Initial project scaffold

This commit is contained in:
2026-04-10 13:22:19 +00:00
commit 7fa69b1354
94 changed files with 3964 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
"""Online softmax task."""

View File

@@ -0,0 +1,49 @@
from __future__ import annotations
import statistics
import sys
import time
from pathlib import Path
# Make the repository root importable so `kernels.*` and `reference.*`
# resolve when this script is run directly rather than as a module.
# NOTE(review): parents[2] assumes this file sits two directories below
# the repo root — confirm against the actual layout.
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import torch
from kernels.triton.online_softmax import triton_online_softmax
from reference.torch_online_softmax import torch_online_softmax
def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    """Return the median wall-clock time of one ``fn(*args)`` call, in ms.

    Runs ``warmup`` untimed calls first, then times ``reps`` calls with
    ``time.perf_counter``. When the first positional argument is a CUDA
    tensor, ``torch.cuda.synchronize`` is called around each timed call so
    the measurement covers kernel completion, not just kernel launch.

    Args:
        fn: Callable to benchmark.
        *args: Positional arguments forwarded to ``fn``.
        warmup: Number of untimed warmup iterations.
        reps: Number of timed iterations.

    Returns:
        Median elapsed time per call, in milliseconds.
    """
    # Hoist the device test out of the loop. The original re-evaluated
    # args[0].is_cuda on every iteration, raised IndexError when fn took
    # no positional arguments, and AttributeError for non-tensor args.
    use_cuda = bool(args) and getattr(args[0], "is_cuda", False)

    for _ in range(warmup):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()

    times_ms = []
    for _ in range(reps):
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if use_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
def main() -> None:
    """Benchmark the reference online softmax and, on CUDA, the Triton one."""
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    x = torch.randn(2048, 2048, device=device)

    baseline_ms = benchmark(torch_online_softmax, x)
    print(f"torch_online_softmax: {baseline_ms:.3f} ms")

    if not use_gpu:
        return
    try:
        triton_ms = benchmark(triton_online_softmax, x)
        print(f"triton_online_softmax: {triton_ms:.3f} ms")
    except (NotImplementedError, RuntimeError) as exc:
        # The Triton kernel is a student TODO; report rather than crash.
        print(f"triton_online_softmax: skipped ({exc})")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,7 @@
// Workbook-local CUDA sketch for online softmax.
//
// TODO(student):
// 1. Decide whether one block owns a whole row or only a row tile.
// 2. Keep running_max and running_sum across column tiles.
// 3. Update the recurrence carefully for numerical stability.
// 4. Normalize the final row.

View File

@@ -0,0 +1,49 @@
# Task 04: Online Softmax
## 1. Problem Statement
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
## 2. Expected Input/Output Shapes
- Input: `[num_rows, num_cols]`
- Output: `[num_rows, num_cols]`
## 3. Performance Intuition
The main goal is algorithmic structure rather than raw speed. Online softmax becomes powerful because it lets you process a row incrementally without materializing the full reduction context at once.
## 4. Memory Access Discussion
Think in column tiles. Each tile updates the running normalization state. This matters later when attention scores are processed block by block.
## 5. What Triton Is Abstracting
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
## 6. What CUDA Makes Explicit
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
## 7. Reflection Questions
- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?
## 8. Implementation Checklist
- Read the reference online softmax
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first
## Informal Recurrence
Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile = sum_j exp(x_j - m_tile)` (summed over the tile's entries), define:
- `m_new = max(m_prev, m_tile)`
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`
That is the key idea you will reuse in FlashAttention.

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import pytest
import torch
from kernels.triton.online_softmax import triton_online_softmax
from reference.torch_online_softmax import torch_online_softmax
def _run_impl_or_skip(fn, *args):
    """Call ``fn(*args)``; convert unimplemented/unrunnable kernels to skips.

    ``NotImplementedError`` marks a TODO stub and ``RuntimeError`` usually
    means the backend cannot run here; either way the test is skipped
    instead of failing.
    """
    try:
        result = fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))
    return result
@pytest.mark.reference
def test_online_softmax_reference_matches_torch():
    """The reference online softmax must agree with ``torch.softmax``."""
    inputs = torch.randn(6, 19)
    expected = torch.softmax(inputs, dim=1)
    actual = torch_online_softmax(inputs)
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_online_softmax_if_available():
    """On CUDA machines, the Triton kernel must match ``torch.softmax``."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    inputs = torch.randn(8, 97, device="cuda")
    actual = _run_impl_or_skip(triton_online_softmax, inputs)
    expected = torch.softmax(inputs, dim=1)
    torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

View File

@@ -0,0 +1,15 @@
"""Workbook-local Triton notes for online softmax."""
def notes() -> str:
return """
TODO(student):
1. Keep running_max and running_sum for one row.
2. Process the row in blocks.
3. Update the recurrence after each block.
4. Normalize once the full row has been seen.
"""
if __name__ == "__main__":
print(notes())