Initial project scaffold
2 tasks/02_row_softmax/__init__.py Normal file
@@ -0,0 +1,2 @@
"""Row softmax task."""
49 tasks/02_row_softmax/bench.py Normal file
@@ -0,0 +1,49 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

# Make the repository root importable so this benchmark can run as a plain script.
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_row_softmax import torch_row_softmax


def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            # Synchronize around the timed call so asynchronous CUDA launches
            # do not distort the wall-clock measurement.
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096, 1024, device=device)
    ref_ms = benchmark(torch_row_softmax, x)
    print(f"torch_row_softmax: {ref_ms:.3f} ms")
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_row_softmax, x)
            print(f"triton_row_softmax: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_row_softmax: skipped ({exc})")


if __name__ == "__main__":
    main()
11 tasks/02_row_softmax/cuda_skeleton.cu Normal file
@@ -0,0 +1,11 @@
// Workbook-local CUDA sketch for row softmax.
//
// Reflection prompt:
// Softmax is usually bandwidth-bound: the arithmetic per element is cheap,
// but each row is read and written several times unless the passes are fused.
// Keep track of how many global-memory passes your implementation needs.

// TODO(student):
// 1. Assign one block (or block tile) to a row.
// 2. Compute the row max.
// 3. Compute the sum of exp(x - row_max).
// 4. Normalize the row.
40 tasks/02_row_softmax/spec.md Normal file
@@ -0,0 +1,40 @@
# Task 02: Row Softmax

## 1. Problem Statement

Implement a row-wise softmax with numerical stability and compare the naive and fused viewpoints.

## 2. Expected Input/Output Shapes

- Input: a 2D tensor `[num_rows, num_cols]`
- Output: a 2D tensor with the same shape
## 3. Performance Intuition

Softmax is often bandwidth-bound because each element is read several times unless you fuse work carefully. The arithmetic is cheap relative to the data movement.
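A back-of-the-envelope estimate makes this concrete. The sketch below assumes the `4096x1024` fp32 tensor used in `bench.py` and a naive three-read, one-write schedule; treat it as a rough illustration, not a measurement:

```python
# Rough traffic-vs-compute estimate (assumes fp32 and a naive 3-read / 1-write schedule).
rows, cols, bytes_per_elem = 4096, 1024, 4
tensor_bytes = rows * cols * bytes_per_elem      # ~16.8 MB per full pass over the tensor
naive_traffic = 3 * tensor_bytes + tensor_bytes  # reads for max, exp-sum, normalize + one write
flops = rows * cols * 3                          # roughly: subtract, exp, divide per element
print(f"~{naive_traffic / 1e6:.0f} MB moved for ~{flops / 1e6:.0f} MFLOP")  # traffic dominates
```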
## 4. Memory Access Discussion

A naive implementation may read each row multiple times: once for the max, once for the sum of exponentials, and once for normalization. Think about which intermediate values can stay on chip.
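For reference, a minimal PyTorch sketch of the stable three-pass version looks like the following. The shipped `reference/torch_row_softmax.py` is not part of this file, so the function name here is illustrative:

```python
import torch


def row_softmax_sketch(x: torch.Tensor) -> torch.Tensor:
    """Numerically stable row softmax written as three explicit passes."""
    row_max = x.max(dim=1, keepdim=True).values                 # pass 1: row max
    shifted_exp = torch.exp(x - row_max)                        # pass 2: exponentiate shifted values
    return shifted_exp / shifted_exp.sum(dim=1, keepdim=True)   # pass 3: normalize
```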
## 5. What Triton Is Abstracting

Triton makes it easy to load a row block, apply masked operations, and reduce across the block with tensor-style code.
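For orientation, here is a minimal one-program-per-row sketch in the style of the standard Triton softmax tutorial. The names are illustrative and this is not the `kernels/triton/row_softmax.py` module the tests import; it assumes a contiguous 2D input:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _row_softmax_kernel(x_ptr, out_ptr, n_cols, row_stride, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)                     # one program instance owns one row
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols                       # mask the tail when n_cols < BLOCK_SIZE
    x = tl.load(x_ptr + row * row_stride + cols, mask=mask, other=float("-inf"))
    x = x - tl.max(x, axis=0)                  # subtract the row max for stability
    num = tl.exp(x)
    out = num / tl.sum(num, axis=0)
    tl.store(out_ptr + row * row_stride + cols, out, mask=mask)


def row_softmax_triton_sketch(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    block = triton.next_power_of_2(x.shape[1])  # tl.arange needs a power-of-two extent
    _row_softmax_kernel[(x.shape[0],)](x, out, x.shape[1], x.stride(0), BLOCK_SIZE=block)
    return out
```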
## 6. What CUDA Makes Explicit

CUDA forces you to decide where the row reduction lives: one block per row, multiple warps per row, or a tiled strategy. Shared-memory use and synchronization become explicit design choices.

## 7. Reflection Questions

- Why is max subtraction required for stable softmax? (A small demonstration follows this list.)
- Why is softmax often bandwidth-bound rather than compute-bound?
- Which intermediate quantities would you prefer not to write back to global memory?
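One way to see the first question concretely is a small REPL-style check (not part of the task code):

```python
import torch

x = torch.tensor([[1000.0, 1001.0, 1002.0]])
naive = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)  # exp(1000.) overflows to inf -> nan
stable = torch.softmax(x, dim=1)                              # subtracts the row max internally
print(naive)   # tensor([[nan, nan, nan]])
print(stable)  # finite probabilities that sum to 1
```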
## 8. Implementation Checklist

- Validate the reference row softmax
- Fill in Triton row loading, max reduction, sum reduction, and normalization
- Fill in the CUDA reduction structure
- Test large positive and negative values
- Compare against `torch.softmax`
40 tasks/02_row_softmax/test_task.py Normal file
@@ -0,0 +1,40 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_row_softmax import torch_row_softmax


def _run_impl_or_skip(fn, *args):
    try:
        return fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_row_softmax_reference_matches_torch():
    x = torch.randn(8, 17)
    out = torch_row_softmax(x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1))


@pytest.mark.reference
def test_row_softmax_reference_is_numerically_stable():
    # Values this large overflow exp() unless the row max is subtracted first.
    x = torch.tensor([[1000.0, 1001.0, 1002.0], [-1000.0, -999.0, -998.0]])
    out = torch_row_softmax(x)
    torch.testing.assert_close(out.sum(dim=1), torch.ones(2), atol=1e-6, rtol=1e-6)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_row_softmax_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    x = torch.randn(16, 63, device="cuda")
    out = _run_impl_or_skip(triton_row_softmax, x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1), atol=1e-4, rtol=1e-4)
20 tasks/02_row_softmax/triton_skeleton.py Normal file
@@ -0,0 +1,20 @@
"""Workbook-local Triton notes for row softmax."""


def notes() -> str:
    return """
TODO(student):
1. Decide what one program instance owns: a whole row or a row tile.
2. Load a row with masking.
3. Compute row_max = max(x).
4. Compute exp(x - row_max), then the row sum.
5. Normalize and store.

Reflection:
- Why does numerical stability matter here more than in vector add?
- Where does extra memory traffic appear in a naive multi-kernel approach?
"""


if __name__ == "__main__":
    print(notes())