Initial project scaffold
This commit is contained in:
2
tasks/01_vector_add/__init__.py
Normal file
2
tasks/01_vector_add/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Vector add task."""
|
||||
|
||||
50
tasks/01_vector_add/bench.py
Normal file
50
tasks/01_vector_add/bench.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import torch
|
||||
|
||||
from kernels.triton.vector_add import triton_vector_add
|
||||
from reference.torch_vector_add import torch_vector_add
|
||||
|
||||
|
||||
def benchmark(fn, *args, warmup: int = 5, reps: int = 25) -> float:
    """Return the median wall-clock latency of one ``fn(*args)`` call in ms.

    Args:
        fn: Callable to time.
        *args: Positional arguments forwarded to ``fn``. When the first
            argument is a CUDA tensor, the GPU is synchronized around each
            timed call so asynchronous kernel launches are fully measured.
        warmup: Untimed calls executed first (JIT compilation, allocator
            warm-up, caches).
        reps: Number of timed repetitions; the median damps outliers.

    Returns:
        Median latency of a single call, in milliseconds.
    """
    # Hoist the loop-invariant device check. Guarding with bool(args) also
    # tolerates zero-argument callables (the original indexed args[0]
    # unconditionally and raised IndexError for them).
    on_cuda = bool(args) and args[0].is_cuda

    for _ in range(warmup):
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()

    times_ms = []
    for _ in range(reps):
        if on_cuda:
            # Ensure prior GPU work is drained before starting the clock.
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args)
        if on_cuda:
            # Block until the launched kernel finishes so the elapsed time
            # covers execution, not just the launch.
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)
|
||||
|
||||
|
||||
def main() -> None:
    """Benchmark the PyTorch reference and, when CUDA is present, the Triton kernel."""
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    numel = 1 << 20  # 1M elements: large enough for bandwidth to matter.
    inputs = (
        torch.randn(numel, device=device),
        torch.randn(numel, device=device),
    )

    ref_ms = benchmark(torch_vector_add, *inputs)
    print(f"torch_vector_add: {ref_ms:.3f} ms")

    # The Triton kernel only runs on a GPU; on CPU-only hosts we stop here.
    if not use_cuda:
        return
    try:
        triton_ms = benchmark(triton_vector_add, *inputs)
    except (NotImplementedError, RuntimeError) as exc:
        # Kernel still a TODO stub, or the Triton toolchain is unavailable.
        print(f"triton_vector_add: skipped ({exc})")
    else:
        print(f"triton_vector_add: {triton_ms:.3f} ms")
|
||||
|
||||
|
||||
# Entry point: lets `python tasks/01_vector_add/bench.py` run the benchmark directly.
if __name__ == "__main__":
    main()
|
||||
10
tasks/01_vector_add/cuda_skeleton.cu
Normal file
10
tasks/01_vector_add/cuda_skeleton.cu
Normal file
@@ -0,0 +1,10 @@
|
||||
// Workbook-local CUDA sketch for vector add.
|
||||
//
|
||||
// The repository-level implementation lives in kernels/cuda/src/vector_add.cu.
|
||||
// Read this side by side with the Triton version.
|
||||
|
||||
// TODO(student):
|
||||
// 1. Compute global_idx from blockIdx.x, blockDim.x, and threadIdx.x.
|
||||
// 2. Guard the tail with if (global_idx < numel).
|
||||
// 3. Load x[global_idx] and y[global_idx].
|
||||
// 4. Store the sum.
|
||||
40
tasks/01_vector_add/spec.md
Normal file
40
tasks/01_vector_add/spec.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Task 01: Vector Add
|
||||
|
||||
## 1. Problem Statement
|
||||
|
||||
Implement `out[i] = x[i] + y[i]` in both Triton and CUDA, then compare both against the PyTorch reference.
|
||||
|
||||
## 2. Expected Input/Output Shapes
|
||||
|
||||
- Input: two tensors with identical 1D or flattened shapes
|
||||
- Output: one tensor with the same shape
|
||||
|
||||
## 3. Performance Intuition
|
||||
|
||||
Vector add is simple enough that launch overhead and memory bandwidth dominate quickly. It is a good place to learn indexing before the math becomes interesting.
|
||||
|
||||
## 4. Memory Access Discussion
|
||||
|
||||
This kernel should read `x[i]` and `y[i]` once and write `out[i]` once. The main thing to inspect is whether neighboring threads or lanes access neighboring elements.
|
||||
|
||||
## 5. What Triton Is Abstracting
|
||||
|
||||
Triton lets you express one block of contiguous offsets with `program_id` and `tl.arange`, then apply a mask on the tail.
|
||||
|
||||
## 6. What CUDA Makes Explicit
|
||||
|
||||
CUDA makes you compute `global_idx` from block and thread indices yourself and write the boundary check explicitly.
|
||||
|
||||
## 7. Reflection Questions
|
||||
|
||||
- What is the exact correspondence between `program_id` and `blockIdx.x` here?
|
||||
- Why is a mask or bounds check required on the final block?
|
||||
- How would the ownership change if one thread handled multiple elements?
|
||||
|
||||
## 8. Implementation Checklist
|
||||
|
||||
- Confirm the reference implementation
|
||||
- Fill in the Triton masked loads, add, and store
|
||||
- Fill in the CUDA thread ownership and store
|
||||
- Test small and non-multiple-of-block-size shapes
|
||||
- Benchmark bandwidth on larger vectors
|
||||
35
tasks/01_vector_add/test_task.py
Normal file
35
tasks/01_vector_add/test_task.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels.triton.vector_add import triton_vector_add
|
||||
from reference.torch_vector_add import torch_vector_add
|
||||
|
||||
|
||||
def _run_impl_or_skip(fn, *args):
    """Call *fn* with *args*, converting expected failures into pytest skips.

    NotImplementedError indicates the kernel body is still a TODO stub;
    RuntimeError is treated as an environment problem (e.g. no Triton
    backend), so the test is skipped rather than failed in both cases.
    """
    try:
        result = fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))
    else:
        return result
|
||||
|
||||
|
||||
@pytest.mark.reference
def test_vector_add_reference_matches_torch():
    """The reference implementation must agree with plain `+` on a non-power-of-two size."""
    lhs, rhs = torch.randn(257), torch.randn(257)
    torch.testing.assert_close(torch_vector_add(lhs, rhs), lhs + rhs)
|
||||
|
||||
|
||||
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_vector_add_if_available():
    """Check the Triton kernel on CUDA; skip when hardware or the impl is missing."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    # 513 = 512 + 1 deliberately exercises the masked tail block.
    a = torch.randn(513, device="cuda")
    b = torch.randn(513, device="cuda")
    got = _run_impl_or_skip(triton_vector_add, a, b)
    torch.testing.assert_close(got, a + b)
|
||||
19
tasks/01_vector_add/triton_skeleton.py
Normal file
19
tasks/01_vector_add/triton_skeleton.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Workbook-local Triton sketch for vector add.
|
||||
|
||||
The repository-level implementation lives in kernels/triton/vector_add.py.
|
||||
Use this file as a short-form scratchpad before editing the real kernel.
|
||||
"""
|
||||
|
||||
|
||||
def notes() -> str:
    """Return the student-facing checklist for the Triton vector-add sketch."""
    checklist = """
TODO(student):
1. Map one Triton program instance to one contiguous block of elements.
2. Compute offsets with pid * BLOCK_SIZE + arange.
3. Mask the tail.
4. Load x and y, add them, store the result.
"""
    return checklist
|
||||
|
||||
|
||||
# Print the checklist when run directly: `python tasks/01_vector_add/triton_skeleton.py`.
if __name__ == "__main__":
    print(notes())
|
||||
Reference in New Issue
Block a user