Initial project scaffold

2026-04-10 13:22:19 +00:00
commit 7fa69b1354
94 changed files with 3964 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
"""Environment sanity task."""

View File

@@ -0,0 +1,13 @@
# Environment Checklist
- PyTorch imports successfully
- `torch.cuda.is_available()` is `True`
- At least one CUDA device is visible
- The GPU name matches the machine you expect to be using
- Device capability is printed and recorded
- Triton imports successfully, or you know why it does not
- `torch.version.cuda` is visible when using CUDA-enabled PyTorch
- `nvcc --version` works if you plan to build the CUDA extension
- `nvidia-smi` works if the driver stack is installed
If any line above fails, fix that before working on later tasks.
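
A minimal sketch of these checks in Python is shown below. It is not a substitute for `tools/check_env.py`; it only mirrors the list above so you can see what each line is asking for.

```python
# Sketch only: tools/check_env.py is the maintained version of these checks.
import torch

print("torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device count:", torch.cuda.device_count())
    print("device name:", torch.cuda.get_device_name(0))
    print("device capability:", torch.cuda.get_device_capability(0))
try:
    import triton
    print("triton:", triton.__version__)
except ImportError as exc:
    print("triton import failed:", exc)
```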

View File

@@ -0,0 +1,46 @@
# Task 00: Environment Sanity
## 1. Problem Statement
Confirm that your machine can see the GPU software stack needed for the rest of the lab.
## 2. Expected Input/Output Shapes
This task is informational rather than tensor-shaped. The outputs are environment facts:
- PyTorch version
- CUDA availability
- Triton import status
- GPU name
- device capability
- toolkit and driver hints when available
## 3. Performance Intuition
Do not benchmark anything yet. First confirm that the environment is what you think it is.
## 4. Memory Access Discussion
Not applicable yet. The point is to avoid debugging kernels when the real problem is a mismatched driver or toolkit.
## 5. What Triton Is Abstracting
Even importing Triton depends on a compatible Python, PyTorch, driver, and GPU stack.
## 6. What CUDA Makes Explicit
CUDA makes the toolkit and architecture targeting explicit. Keep that explicit throughout this repo.
## 7. Reflection Questions
- What exact GPU name does the system report?
- What device capability does PyTorch report?
- Does Triton import cleanly?
- Which part of the stack would you inspect first if CUDA is unavailable?
## 8. Implementation Checklist
- Run `python tools/check_env.py`
- Run `python tools/print_device_info.py`
- Write down the reported capability
- Set `KERNEL_LAB_CUDA_ARCH` explicitly if you need to change architecture targeting
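
As a small illustration of the last two items, the capability pair and any explicit override can be read like this; the repo's build tooling owns the real handling of `KERNEL_LAB_CUDA_ARCH`, so treat the snippet as a sketch only.

```python
import os

import torch

# Illustration only: record the capability and check for an explicit override.
major, minor = torch.cuda.get_device_capability(0)
print("device capability:", (major, minor))  # e.g. (8, 6)
print("KERNEL_LAB_CUDA_ARCH override:", os.environ.get("KERNEL_LAB_CUDA_ARCH"))
```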

View File

@@ -0,0 +1,2 @@
"""Vector add task."""

View File

@@ -0,0 +1,50 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.vector_add import triton_vector_add
from reference.torch_vector_add import torch_vector_add


def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1 << 20, device=device)
    y = torch.randn(1 << 20, device=device)
    ref_ms = benchmark(torch_vector_add, x, y)
    print(f"torch_vector_add: {ref_ms:.3f} ms")
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_vector_add, x, y)
            print(f"triton_vector_add: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_vector_add: skipped ({exc})")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,10 @@
// Workbook-local CUDA sketch for vector add.
//
// The repository-level implementation lives in kernels/cuda/src/vector_add.cu.
// Read this side by side with the Triton version.
// TODO(student):
// 1. Compute global_idx from blockIdx.x, blockDim.x, and threadIdx.x.
// 2. Guard the tail with if (global_idx < numel).
// 3. Load x[global_idx] and y[global_idx].
// 4. Store the sum.

View File

@@ -0,0 +1,40 @@
# Task 01: Vector Add
## 1. Problem Statement
Implement `out[i] = x[i] + y[i]` in both Triton and CUDA, then compare both against the PyTorch reference.
## 2. Expected Input/Output Shapes
- Input: two tensors with identical 1D or flattened shapes
- Output: one tensor with the same shape
## 3. Performance Intuition
Vector add is simple enough that launch overhead and memory bandwidth dominate quickly. It is a good place to learn indexing before the math becomes interesting.
## 4. Memory Access Discussion
This kernel should read `x[i]` and `y[i]` once and write `out[i]` once. The main thing to inspect is whether neighboring threads or lanes access neighboring elements.
## 5. What Triton Is Abstracting
Triton lets you express one block of contiguous offsets with `program_id` and `tl.arange`, then apply a mask on the tail.
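
As a rough illustration of that shape, a masked block kernel can look like the sketch below. Names such as `_add_kernel` and `BLOCK_SIZE` are placeholders, not the repo's implementation; the graded version belongs in `kernels/triton/vector_add.py`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                      # one program owns one block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < numel                           # guard the tail block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), block_size),)
    _add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=block_size)
    return out
```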
## 6. What CUDA Makes Explicit
CUDA makes you compute `global_idx` from block and thread indices yourself and write the boundary check explicitly.
## 7. Reflection Questions
- What is the exact correspondence between `program_id` and `blockIdx.x` here?
- Why is a mask or bounds check required on the final block?
- How would the ownership change if one thread handled multiple elements?
## 8. Implementation Checklist
- Confirm the reference implementation
- Fill in the Triton masked loads, add, and store
- Fill in the CUDA thread ownership and store
- Test small and non-multiple-of-block-size shapes
- Benchmark bandwidth on larger vectors

View File

@@ -0,0 +1,35 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.vector_add import triton_vector_add
from reference.torch_vector_add import torch_vector_add


def _run_impl_or_skip(fn, *args):
    try:
        return fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_vector_add_reference_matches_torch():
    x = torch.randn(257)
    y = torch.randn(257)
    out = torch_vector_add(x, y)
    torch.testing.assert_close(out, x + y)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_vector_add_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    x = torch.randn(513, device="cuda")
    y = torch.randn(513, device="cuda")
    out = _run_impl_or_skip(triton_vector_add, x, y)
    torch.testing.assert_close(out, x + y)

View File

@@ -0,0 +1,19 @@
"""Workbook-local Triton sketch for vector add.
The repository-level implementation lives in kernels/triton/vector_add.py.
Use this file as a short-form scratchpad before editing the real kernel.
"""
def notes() -> str:
return """
TODO(student):
1. Map one Triton program instance to one contiguous block of elements.
2. Compute offsets with pid * BLOCK_SIZE + arange.
3. Mask the tail.
4. Load x and y, add them, store the result.
"""
if __name__ == "__main__":
print(notes())

View File

@@ -0,0 +1,2 @@
"""Row softmax task."""

View File

@@ -0,0 +1,49 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_row_softmax import torch_row_softmax


def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096, 1024, device=device)
    ref_ms = benchmark(torch_row_softmax, x)
    print(f"torch_row_softmax: {ref_ms:.3f} ms")
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_row_softmax, x)
            print(f"triton_row_softmax: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_row_softmax: skipped ({exc})")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,11 @@
// Workbook-local CUDA sketch for row softmax.
//
// Reflection prompt:
// Softmax is usually bandwidth-bound: the arithmetic is cheap relative to how often each row is read and written.
// Keep track of how many global-memory passes your implementation needs.
// TODO(student):
// 1. Assign one block or block tile to a row.
// 2. Compute the row max.
// 3. Compute the sum of exp(x - row_max).
// 4. Normalize the row.

View File

@@ -0,0 +1,40 @@
# Task 02: Row Softmax
## 1. Problem Statement
Implement a row-wise softmax with numerical stability and compare naive and fused viewpoints.
## 2. Expected Input/Output Shapes
- Input: a 2D tensor `[num_rows, num_cols]`
- Output: a 2D tensor with the same shape
## 3. Performance Intuition
Softmax is often bandwidth-bound because each element is read several times unless you fuse work carefully. The arithmetic is cheap relative to the data movement.
## 4. Memory Access Discussion
A naive implementation may read rows multiple times: once for the max, once for the sum of exponentials, and once for normalization. Think about which intermediate values can stay on chip.
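
The naive pass structure is easiest to see in plain PyTorch; this sketch exists to count memory traffic, not to serve as the kernel:

```python
import torch


def naive_row_softmax(x: torch.Tensor) -> torch.Tensor:
    row_max = x.max(dim=1, keepdim=True).values        # pass 1: read every row
    numer = (x - row_max).exp()                         # pass 2: read again, write exp
    return numer / numer.sum(dim=1, keepdim=True)       # pass 3: read, normalize, write
```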
## 5. What Triton Is Abstracting
Triton makes it easy to load a row block, apply masked operations, and reduce across the block with tensor-style code.
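
One possible shape, assuming a contiguous input whose rows fit in a single block, is sketched below; it is illustrative rather than the implementation expected in `kernels/triton/row_softmax.py`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _row_softmax_kernel(x_ptr, out_ptr, num_cols, stride_row, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(axis=0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < num_cols
    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=float("-inf"))
    x = x - tl.max(x, axis=0)                    # stability: subtract the row max
    numer = tl.exp(x)
    denom = tl.sum(numer, axis=0)
    tl.store(out_ptr + row * stride_row + cols, numer / denom, mask=mask)


def row_softmax(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    block = triton.next_power_of_2(x.shape[1])
    _row_softmax_kernel[(x.shape[0],)](x, out, x.shape[1], x.stride(0), BLOCK_SIZE=block)
    return out
```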
## 6. What CUDA Makes Explicit
CUDA forces you to decide where the row reduction lives: one block per row, multiple warps per row, or a tiled strategy. Shared-memory use and synchronization become explicit design choices.
## 7. Reflection Questions
- Why is max subtraction required for stable softmax?
- Why is softmax often bandwidth-bound rather than compute-bound?
- Which intermediate quantities would you prefer not to write back to global memory?
## 8. Implementation Checklist
- Validate the reference row softmax
- Fill in Triton row loading, max reduction, sum reduction, and normalization
- Fill in the CUDA reduction structure
- Test large positive and negative values
- Compare against `torch.softmax`

View File

@@ -0,0 +1,40 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.row_softmax import triton_row_softmax
from reference.torch_row_softmax import torch_row_softmax


def _run_impl_or_skip(fn, *args):
    try:
        return fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_row_softmax_reference_matches_torch():
    x = torch.randn(8, 17)
    out = torch_row_softmax(x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1))


@pytest.mark.reference
def test_row_softmax_reference_is_numerically_stable():
    x = torch.tensor([[1000.0, 1001.0, 1002.0], [-1000.0, -999.0, -998.0]])
    out = torch_row_softmax(x)
    torch.testing.assert_close(out.sum(dim=1), torch.ones(2), atol=1e-6, rtol=1e-6)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_row_softmax_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    x = torch.randn(16, 63, device="cuda")
    out = _run_impl_or_skip(triton_row_softmax, x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1), atol=1e-4, rtol=1e-4)

View File

@@ -0,0 +1,20 @@
"""Workbook-local Triton notes for row softmax."""
def notes() -> str:
return """
TODO(student):
1. Decide what one program instance owns: a whole row or a row tile.
2. Load a row with masking.
3. Compute row_max = max(x).
4. Compute exp(x - row_max), then the row sum.
5. Normalize and store.
Reflection:
- Why does numerical stability matter here more than in vector add?
- Where does extra memory traffic appear in a naive multi-kernel approach?
"""
if __name__ == "__main__":
print(notes())

View File

@@ -0,0 +1,2 @@
"""Tiled matmul task."""

View File

@@ -0,0 +1,51 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.tiled_matmul import triton_tiled_matmul
from reference.torch_matmul import torch_matmul


def benchmark(fn, *args, warmup: int = 5, reps: int = 20) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for m, k, n in [(128, 128, 128), (512, 512, 512)]:
        a = torch.randn(m, k, device=device)
        b = torch.randn(k, n, device=device)
        ref_ms = benchmark(torch_matmul, a, b)
        print(f"torch_matmul {m}x{k}x{n}: {ref_ms:.3f} ms")
        if device == "cuda":
            try:
                triton_ms = benchmark(triton_tiled_matmul, a, b)
                print(f"triton_tiled_matmul {m}x{k}x{n}: {triton_ms:.3f} ms")
            except (NotImplementedError, RuntimeError) as exc:
                print(f"triton_tiled_matmul {m}x{k}x{n}: skipped ({exc})")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,9 @@
// Workbook-local CUDA sketch for tiled matmul.
//
// TODO(student):
// 1. Choose a block tile size, for example 16x16 or 32x32.
// 2. Load one A tile and one B tile into shared memory.
// 3. Synchronize.
// 4. Accumulate partial products.
// 5. Synchronize before loading the next tile.
// 6. Store the final C element or tile.

View File

@@ -0,0 +1,51 @@
# Task 03: Tiled Matmul
## 1. Problem Statement
Implement a tiled matrix multiplication and compare the tile abstraction in Triton with the explicit shared-memory strategy in CUDA.
## 2. Expected Input/Output Shapes
- Input `A`: `[M, K]`
- Input `B`: `[K, N]`
- Output `C`: `[M, N]`
## 3. Performance Intuition
Matmul becomes interesting once data reuse matters. Re-reading the same `A` and `B` values from global memory is expensive; tiling exists to reuse those values across many multiply-accumulate operations.
## 4. Memory Access Discussion
Think about which `A` tile and `B` tile each work unit needs. The performance win comes from moving those tiles into on-chip storage and reusing them before fetching the next tile.
## 5. What Triton Is Abstracting
Triton lets you think in output tiles and blocked pointer arithmetic. The tile loads and accumulations read like tensor operations.
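
The sketch below shows one way that blocked arithmetic reads, with edge masking on all three tensors. Tile sizes and names are assumptions (`tl.dot` expects tiles of at least 16); the actual kernel belongs in `kernels/triton/tiled_matmul.py`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                   stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(axis=0)                 # which output tile row
    pid_n = tl.program_id(axis=1)                 # which output tile column
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + rm[:, None] * stride_am + (k0 + rk)[None, :] * stride_ak,
                    mask=(rm[:, None] < M) & ((k0 + rk)[None, :] < K), other=0.0)
        b = tl.load(b_ptr + (k0 + rk)[:, None] * stride_bk + rn[None, :] * stride_bn,
                    mask=((k0 + rk)[:, None] < K) & (rn[None, :] < N), other=0.0)
        acc += tl.dot(a, b)                       # reuse both tiles across the whole C tile
    tl.store(c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc,
             mask=(rm[:, None] < M) & (rn[None, :] < N))


def matmul(a: torch.Tensor, b: torch.Tensor, bm: int = 64, bn: int = 64, bk: int = 32) -> torch.Tensor:
    m, k = a.shape
    _, n = b.shape
    c = torch.empty((m, n), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(m, bm), triton.cdiv(n, bn))
    _matmul_kernel[grid](a, b, c, m, n, k,
                         a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                         c.stride(0), c.stride(1),
                         BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk)
    return c
```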
## 6. What CUDA Makes Explicit
CUDA makes you choose block dimensions, allocate shared memory, manage cooperative loads, and synchronize between load and compute phases.
## 7. Reflection Questions
- Which values in `A` and `B` are reused across multiple output elements?
- Why does tiling reduce global-memory traffic?
- How does a Triton tile map to CUDA shared-memory tiles and threads?
## 8. Implementation Checklist
- Confirm the reference matmul
- Draw a block/tile diagram before coding
- Implement the Triton tile loop over `K`
- Implement the CUDA shared-memory tile loop
- Benchmark against `torch.matmul` on small and medium sizes
## Tile Diagram Prompt
Sketch:
- one output tile `C[m0:m1, n0:n1]`
- the matching `A[m0:m1, k0:k1]`
- the matching `B[k0:k1, n0:n1]`
That sketch should tell you what belongs in shared memory.

View File

@@ -0,0 +1,35 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.tiled_matmul import triton_tiled_matmul
from reference.torch_matmul import torch_matmul


def _run_impl_or_skip(fn, *args):
    try:
        return fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_tiled_matmul_reference_matches_torch():
    a = torch.randn(8, 16)
    b = torch.randn(16, 12)
    out = torch_matmul(a, b)
    torch.testing.assert_close(out, a @ b)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_tiled_matmul_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    a = torch.randn(32, 48, device="cuda")
    b = torch.randn(48, 40, device="cuda")
    out = _run_impl_or_skip(triton_tiled_matmul, a, b)
    torch.testing.assert_close(out, a @ b, atol=1e-3, rtol=1e-3)

View File

@@ -0,0 +1,16 @@
"""Workbook-local Triton notes for tiled matmul."""
def notes() -> str:
return """
TODO(student):
1. Map one program instance to one output tile.
2. Build row/col offsets for the tile.
3. Loop over K in block_k chunks.
4. Load A and B tiles, accumulate partial products.
5. Store the output tile with masking on edges.
"""
if __name__ == "__main__":
print(notes())

View File

@@ -0,0 +1,2 @@
"""Online softmax task."""

View File

@@ -0,0 +1,49 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.online_softmax import triton_online_softmax
from reference.torch_online_softmax import torch_online_softmax


def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(2048, 2048, device=device)
    ref_ms = benchmark(torch_online_softmax, x)
    print(f"torch_online_softmax: {ref_ms:.3f} ms")
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_online_softmax, x)
            print(f"triton_online_softmax: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_online_softmax: skipped ({exc})")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,7 @@
// Workbook-local CUDA sketch for online softmax.
//
// TODO(student):
// 1. Choose how one block owns one row or row tile.
// 2. Keep running_max and running_sum across column tiles.
// 3. Update the recurrence carefully for numerical stability.
// 4. Normalize the final row.

View File

@@ -0,0 +1,49 @@
# Task 04: Online Softmax
## 1. Problem Statement
Implement the running max / running sum formulation of softmax and connect it to blockwise attention.
## 2. Expected Input/Output Shapes
- Input: `[num_rows, num_cols]`
- Output: `[num_rows, num_cols]`
## 3. Performance Intuition
The main goal is algorithmic structure rather than raw speed. Online softmax becomes powerful because it lets you process a row incrementally without materializing the full reduction context at once.
## 4. Memory Access Discussion
Think in column tiles. Each tile updates the running normalization state. This matters later when attention scores are processed block by block.
## 5. What Triton Is Abstracting
Triton can express the blocked recurrence with vectorized loads and tensor math while still letting you reason about per-row state.
## 6. What CUDA Makes Explicit
CUDA forces you to decide where the running max and running sum live and how threads cooperate to update them across tiles.
## 7. Reflection Questions
- Why is a running max needed instead of only a running sum?
- Why does online softmax enable FlashAttention-style blockwise computation?
- Which values must persist from one tile to the next?
## 8. Implementation Checklist
- Read the reference online softmax
- Derive the recurrence informally
- Implement the Triton blocked recurrence
- Implement the CUDA blocked recurrence
- Compare against full softmax on small shapes first
## Informal Recurrence
Given a previous state `(m_prev, l_prev)` and a new tile with max `m_tile` and denominator contribution `l_tile`, define:
- `m_new = max(m_prev, m_tile)`
- `l_new = l_prev * exp(m_prev - m_new) + l_tile * exp(m_tile - m_new)`
That is the key idea you will reuse in FlashAttention.
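
A tiny PyTorch trace of that recurrence for a single row, checked against the full softmax, can look like this (the tile size is an arbitrary assumption):

```python
import torch


def online_softmax_row(x: torch.Tensor, tile: int = 128) -> torch.Tensor:
    m = torch.tensor(float("-inf"))          # running max, m_prev
    l = torch.tensor(0.0)                    # running denominator, l_prev
    for start in range(0, x.numel(), tile):
        chunk = x[start:start + tile]
        m_tile = chunk.max()
        l_tile = (chunk - m_tile).exp().sum()
        m_new = torch.maximum(m, m_tile)
        l = l * (m - m_new).exp() + l_tile * (m_tile - m_new).exp()
        m = m_new
    return (x - m).exp() / l                 # normalize once (m, l) cover the row


x = torch.randn(1000)
assert torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-6)
```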

View File

@@ -0,0 +1,33 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.online_softmax import triton_online_softmax
from reference.torch_online_softmax import torch_online_softmax


def _run_impl_or_skip(fn, *args):
    try:
        return fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_online_softmax_reference_matches_torch():
    x = torch.randn(6, 19)
    out = torch_online_softmax(x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1), atol=1e-5, rtol=1e-5)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_online_softmax_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    x = torch.randn(8, 97, device="cuda")
    out = _run_impl_or_skip(triton_online_softmax, x)
    torch.testing.assert_close(out, torch.softmax(x, dim=1), atol=1e-4, rtol=1e-4)

View File

@@ -0,0 +1,15 @@
"""Workbook-local Triton notes for online softmax."""
def notes() -> str:
return """
TODO(student):
1. Keep running_max and running_sum for one row.
2. Process the row in blocks.
3. Update the recurrence after each block.
4. Normalize once the full row has been seen.
"""
if __name__ == "__main__":
print(notes())

View File

@@ -0,0 +1,2 @@
"""Flash attention forward task."""

View File

@@ -0,0 +1,51 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention


def benchmark(fn, *args, warmup: int = 5, reps: int = 20, **kwargs) -> float:
    for _ in range(warmup):
        fn(*args, **kwargs)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args, **kwargs)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(2, 8, 128, 64, device=device)
    k = torch.randn(2, 8, 128, 64, device=device)
    v = torch.randn(2, 8, 128, 64, device=device)
    ref_ms = benchmark(torch_attention, q, k, v, causal=False)
    print(f"torch_attention: {ref_ms:.3f} ms")
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_flash_attention_fwd, q, k, v, causal=False)
            print(f"triton_flash_attention_fwd: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_flash_attention_fwd: skipped ({exc})")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,14 @@
// Workbook-local CUDA sketch for FlashAttention forward.
//
// Map this against the Triton sketch:
// - Triton program_id for query tile -> CUDA block ownership
// - Triton block pointer loads -> CUDA cooperative global-to-shared loads
// - Triton masks -> explicit edge and causal checks
// - Triton implicit block math -> thread/block index arithmetic
// TODO(student):
// 1. Assign a block to one batch/head/query tile.
// 2. Load a Q tile and loop over K/V tiles.
// 3. Compute score tiles and causal masking.
// 4. Update online softmax state.
// 5. Accumulate the output tile.

View File

@@ -0,0 +1,59 @@
# Task 05: Flash Attention Forward
## 1. Problem Statement
Implement a learning-oriented forward-only FlashAttention-style kernel in both Triton and CUDA.
## 2. Expected Input/Output Shapes
- `Q`: `[batch, heads, seq_len, head_dim]`
- `K`: `[batch, heads, seq_len, head_dim]`
- `V`: `[batch, heads, seq_len, head_dim]`
- `Output`: `[batch, heads, seq_len, head_dim]`
## 3. Performance Intuition
The goal is to reduce memory traffic by avoiding full materialization of the score matrix. Correctness comes first. Performance work only matters after the blockwise algorithm is correct.
## 4. Memory Access Discussion
This task is about staged movement:
- load a `Q` block
- iterate over `K` and `V` blocks
- compute score blocks
- update online normalization
- accumulate the output block
Track where each quantity lives: global memory, registers, or shared memory.
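
A pure-PyTorch trace of that staging for one non-causal head is sketched below; it is for checking the state updates on small shapes, not for performance, and the block size is an arbitrary assumption.

```python
import math

import torch


def blockwise_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block: int = 32) -> torch.Tensor:
    """q, k, v: [seq_len, head_dim]; single head, non-causal, tracing only."""
    seq_len, head_dim = q.shape
    scale = 1.0 / math.sqrt(head_dim)
    out = torch.zeros_like(q)                     # unnormalized output accumulator
    m = torch.full((seq_len, 1), float("-inf"))   # running row max
    l = torch.zeros(seq_len, 1)                   # running row denominator
    for start in range(0, seq_len, block):
        k_blk = k[start:start + block]
        v_blk = v[start:start + block]
        scores = (q @ k_blk.T) * scale
        m_new = torch.maximum(m, scores.max(dim=1, keepdim=True).values)
        p = (scores - m_new).exp()
        correction = (m - m_new).exp()            # rescale previous state to the new max
        l = l * correction + p.sum(dim=1, keepdim=True)
        out = out * correction + p @ v_blk
        m = m_new
    return out / l


q, k, v = (torch.randn(64, 16) for _ in range(3))
ref = torch.softmax((q @ k.T) / math.sqrt(16), dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-5)
```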
## 5. What Triton Is Abstracting
Triton makes block pointers, program IDs, and masked block operations compact. Those abstractions still correspond to explicit memory ownership decisions.
## 6. What CUDA Makes Explicit
CUDA exposes thread-block mapping, shared-memory staging, synchronization, and reduction details directly. This is where the same algorithm becomes visibly lower level.
## 7. Reflection Questions
- How does online softmax avoid writing out the full score matrix?
- Which loop corresponds to iterating over key/value blocks?
- Where do causal masking and normalization interact?
- How does a Triton block pointer map to a CUDA shared-memory load phase?
## 8. Implementation Checklist
- Confirm the PyTorch reference on tiny shapes
- Trace the online softmax state update
- Implement one Triton blockwise forward path
- Implement one CUDA blockwise forward path
- Test non-causal first, then causal
- Benchmark only after small-shape correctness passes
## Explicit Triton To CUDA Mapping
- Triton `program_id(axis=0)` for query tiles maps to CUDA query-tile block ownership
- Triton `program_id(axis=1)` for batch/head maps to a flattened batch-head block index
- Triton block pointer math maps to shared-memory staging and pointer arithmetic
- Triton masked edge handling maps to explicit tail checks and mask branches

View File

@@ -0,0 +1,49 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention


def _run_impl_or_skip(fn, *args, **kwargs):
    try:
        return fn(*args, **kwargs)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_attention_reference_small_shape():
    q = torch.randn(1, 2, 8, 16)
    k = torch.randn(1, 2, 8, 16)
    v = torch.randn(1, 2, 8, 16)
    out = torch_attention(q, k, v, causal=False)
    expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
    torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.reference
def test_attention_reference_causal_small_shape():
    q = torch.randn(1, 1, 8, 16)
    k = torch.randn(1, 1, 8, 16)
    v = torch.randn(1, 1, 8, 16)
    out = torch_attention(q, k, v, causal=True)
    expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_flash_attention_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    q = torch.randn(1, 2, 16, 32, device="cuda")
    k = torch.randn(1, 2, 16, 32, device="cuda")
    v = torch.randn(1, 2, 16, 32, device="cuda")
    out = _run_impl_or_skip(triton_flash_attention_fwd, q, k, v, causal=False)
    expected = torch_attention(q, k, v, causal=False)
    torch.testing.assert_close(out, expected, atol=2e-3, rtol=2e-3)

View File

@@ -0,0 +1,19 @@
"""Workbook-local Triton notes for FlashAttention forward."""
def notes() -> str:
return """
TODO(student):
1. Assign one program instance to one query block for one batch/head.
2. Load a Q block.
3. Iterate over K/V blocks.
4. Compute score blocks.
5. Apply optional causal masking.
6. Update running max and running sum.
7. Accumulate the output block.
8. Store the final output.
"""
if __name__ == "__main__":
print(notes())

View File

@@ -0,0 +1,2 @@
"""PyTorch custom op task."""

View File

@@ -0,0 +1,26 @@
from __future__ import annotations

import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from tools.lab_extension import build_extension


def main() -> None:
    ext = build_extension(verbose=True)
    if ext is None:
        return
    print("Extension loaded.")
    print("Available torch.ops namespace:", hasattr(torch.ops, "kernel_lab"))
    if hasattr(torch.ops, "kernel_lab"):
        print("Registered ops:", dir(torch.ops.kernel_lab))


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,27 @@
from __future__ import annotations

import pytest
import torch

from tools.lab_extension import build_extension


@pytest.mark.cuda_required
@pytest.mark.skeleton
def test_vector_add_opcheck_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    ext = build_extension(verbose=False)
    if ext is None or not hasattr(torch.ops, "kernel_lab"):
        pytest.skip("extension is unavailable")
    if not hasattr(torch.library, "opcheck"):
        pytest.skip("torch.library.opcheck is unavailable")
    x = torch.randn(32, device="cuda")
    y = torch.randn(32, device="cuda")
    try:
        torch.ops.kernel_lab.vector_add(x, y)
    except Exception as exc:
        pytest.skip(f"operator is not implemented yet: {exc}")
    torch.library.opcheck(torch.ops.kernel_lab.vector_add, (x, y))

View File

@@ -0,0 +1,45 @@
# Task 06: PyTorch Custom Op
## 1. Problem Statement
Expose a CUDA kernel as a PyTorch operator so Python code can call it and test it like any other operator.
## 2. Expected Input/Output Shapes
For the starter binding, use vector add:
- `x`: `[N]`
- `y`: `[N]`
- output: `[N]`
The same pattern can later be extended to the other operators.
## 3. Performance Intuition
The binding layer is not usually where the kernel time goes, but it determines whether you can test, benchmark, and profile the CUDA implementation from Python.
## 4. Memory Access Discussion
The binding itself does not optimize memory traffic; it passes tensors and dispatches the kernel. Still, the binding must preserve shape, dtype, device, and contiguity assumptions.
## 5. What Triton Is Abstracting
Triton often avoids a separate C++ binding layer because Python can launch the JIT kernel directly.
## 6. What CUDA Makes Explicit
CUDA plus PyTorch binding requires you to define function signatures, operator registration, and build integration explicitly.
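
One way to drive that build integration from Python is `torch.utils.cpp_extension.load`; `build_extension` in `tools/lab_extension.py` presumably wraps something similar. The source paths below come from this repo's layout, everything else in the sketch is illustrative.

```python
import torch
from torch.utils.cpp_extension import load

# Sketch only: compile the binding and kernel inline, then call the registered op.
ext = load(
    name="kernel_lab_ext",
    sources=[
        "kernels/cuda/binding/binding.cpp",
        "kernels/cuda/src/vector_add.cu",
    ],
    verbose=True,
)

x = torch.randn(1 << 16, device="cuda")
y = torch.randn(1 << 16, device="cuda")
out = torch.ops.kernel_lab.vector_add(x, y)   # available once binding.cpp registers it
torch.testing.assert_close(out, x + y)
```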
## 7. Reflection Questions
- What assumptions should the binding validate before calling a CUDA kernel?
- Why is operator registration useful for testing and benchmarking?
- What changes once you want autograd support?
## 8. Implementation Checklist
- Read `kernels/cuda/binding/binding.cpp`
- Build or load the extension from Python
- Call the operator from `torch.ops.kernel_lab`
- Add correctness checks once the CUDA kernel is implemented
- Try `torch.library.opcheck` if your PyTorch build provides it

View File

@@ -0,0 +1 @@
"""Profiling task."""

View File

@@ -0,0 +1,23 @@
# Profiling Examples
## Nsight Compute
```bash
./tools/profile_ncu.sh python bench/bench_vector_add.py --device cuda --mode triton
./tools/profile_ncu.sh python bench/bench_softmax.py --device cuda --mode torch
```
## Nsight Systems
```bash
./tools/profile_nsys.sh python bench/bench_matmul.py --device cuda --mode triton
./tools/profile_nsys.sh python bench/bench_attention.py --device cuda --mode torch
```
## First Things To Inspect
- median runtime from the benchmark harness
- whether warmup was excluded
- whether kernels overlap or serialize
- whether memory throughput is near a practical ceiling
- whether a kernel launch is tiny enough that launch overhead matters
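
Before opening the profilers, it can help to cross-check the harness numbers with device-side timestamps. The sketch below uses CUDA events; it is not part of the repo's tooling, just a comparison point for the wall-clock harness.

```python
import torch


def time_cuda_median_ms(fn, *args, warmup: int = 5, reps: int = 20) -> float:
    """Median runtime in milliseconds measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    times = []
    for _ in range(reps):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))   # elapsed_time already returns ms
    return sorted(times)[len(times) // 2]
```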

View File

@@ -0,0 +1,40 @@
# Task 07: Profiling
## 1. Problem Statement
Profile one kernel at a time and learn to interpret the first few metrics before tuning anything.
## 2. Expected Input/Output Shapes
Use the same shapes as your benchmark harness so measurements stay comparable.
## 3. Performance Intuition
Profiling is how you turn guesses into evidence. Use it after correctness is established.
## 4. Memory Access Discussion
Profilers can tell you whether the kernel is limited by memory throughput, occupancy, or something else. Interpret those numbers in terms of the operator's access pattern.
## 5. What Triton Is Abstracting
Triton hides low-level details in code, but profilers still show the resulting kernels and hardware behavior.
## 6. What CUDA Makes Explicit
CUDA kernels expose their launch shapes, synchronization behavior, and memory hierarchy choices more directly, which can make profiler results easier to map back to code.
## 7. Reflection Questions
- Did you profile a single kernel or an entire script?
- Did you warm up before timing?
- Which metric was the first signal that the kernel was bandwidth-bound or compute-bound?
## 8. Implementation Checklist
- Pick one benchmark and one implementation
- Warm up first
- Synchronize before and after timing
- Run `ncu` and inspect a small set of metrics
- Run `nsys` and inspect the timeline
- Write down what you learned before changing the kernel

tasks/__init__.py
View File

@@ -0,0 +1,2 @@
"""Workbook tasks."""