Initial project scaffold
This commit is contained in:
2
tasks/03_tiled_matmul/__init__.py
Normal file
2
tasks/03_tiled_matmul/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Tiled matmul task."""
|
||||
|
||||
51
tasks/03_tiled_matmul/bench.py
Normal file
51
tasks/03_tiled_matmul/bench.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import torch
|
||||
|
||||
from kernels.triton.tiled_matmul import triton_tiled_matmul
|
||||
from reference.torch_matmul import torch_matmul
|
||||
|
||||
|
||||
def benchmark(fn, *args, warmup: int = 5, reps: int = 20) -> float:
    """Return the median wall-clock time of ``fn(*args)`` in milliseconds.

    Runs ``warmup`` untimed calls first, then times ``reps`` calls with
    ``time.perf_counter``.  When the first positional argument is a CUDA
    tensor, the GPU is synchronized around each timed call so the sample
    covers kernel execution rather than just launch overhead.
    """
    # `is_cuda` is a fixed property of the tensor; read it once up front.
    on_gpu = args[0].is_cuda

    for _ in range(warmup):
        fn(*args)
    if on_gpu:
        torch.cuda.synchronize()

    samples = []
    for _ in range(reps):
        if on_gpu:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn(*args)
        if on_gpu:
            torch.cuda.synchronize()
        samples.append((time.perf_counter() - t0) * 1e3)

    return statistics.median(samples)
|
||||
|
||||
|
||||
def main() -> None:
    """Benchmark the reference matmul, plus the Triton kernel when on CUDA."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for m, k, n in [(128, 128, 128), (512, 512, 512)]:
        a = torch.randn(m, k, device=device)
        b = torch.randn(k, n, device=device)

        ref_ms = benchmark(torch_matmul, a, b)
        print(f"torch_matmul {m}x{k}x{n}: {ref_ms:.3f} ms")

        # The Triton kernel only runs on a GPU; skip it entirely on CPU.
        if device != "cuda":
            continue
        try:
            triton_ms = benchmark(triton_tiled_matmul, a, b)
        except (NotImplementedError, RuntimeError) as exc:
            # Unimplemented or unsupported kernels are reported, not fatal.
            print(f"triton_tiled_matmul {m}x{k}x{n}: skipped ({exc})")
        else:
            print(f"triton_tiled_matmul {m}x{k}x{n}: {triton_ms:.3f} ms")


if __name__ == "__main__":
    main()
|
||||
9
tasks/03_tiled_matmul/cuda_skeleton.cu
Normal file
9
tasks/03_tiled_matmul/cuda_skeleton.cu
Normal file
@@ -0,0 +1,9 @@
|
||||
// Workbook-local CUDA sketch for tiled matmul.
|
||||
//
|
||||
// TODO(student):
|
||||
// 1. Choose a block tile size, for example 16x16 or 32x32.
|
||||
// 2. Load one A tile and one B tile into shared memory.
|
||||
// 3. Synchronize.
|
||||
// 4. Accumulate partial products.
|
||||
// 5. Synchronize before loading the next tile.
|
||||
// 6. Store the final C element or tile.
|
||||
51
tasks/03_tiled_matmul/spec.md
Normal file
51
tasks/03_tiled_matmul/spec.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Task 03: Tiled Matmul
|
||||
|
||||
## 1. Problem Statement
|
||||
|
||||
Implement a tiled matrix multiplication and compare the tile abstraction in Triton with the explicit shared-memory strategy in CUDA.
|
||||
|
||||
## 2. Expected Input/Output Shapes
|
||||
|
||||
- Input `A`: `[M, K]`
|
||||
- Input `B`: `[K, N]`
|
||||
- Output `C`: `[M, N]`
|
||||
|
||||
## 3. Performance Intuition
|
||||
|
||||
Matmul becomes interesting once data reuse matters. Re-reading the same `A` and `B` values from global memory is expensive; tiling exists to reuse those values across many multiply-accumulate operations.
|
||||
|
||||
## 4. Memory Access Discussion
|
||||
|
||||
Think about which `A` tile and `B` tile each work unit needs. The performance win comes from moving those tiles into on-chip storage and reusing them before fetching the next tile.
|
||||
|
||||
## 5. What Triton Is Abstracting
|
||||
|
||||
Triton lets you think in output tiles and blocked pointer arithmetic. The tile loads and accumulations read like tensor operations.
|
||||
|
||||
## 6. What CUDA Makes Explicit
|
||||
|
||||
CUDA makes you choose block dimensions, allocate shared memory, manage cooperative loads, and synchronize between load and compute phases.
|
||||
|
||||
## 7. Reflection Questions
|
||||
|
||||
- Which values in `A` and `B` are reused across multiple output elements?
|
||||
- Why does tiling reduce global-memory traffic?
|
||||
- How does a Triton tile map to CUDA shared-memory tiles and threads?
|
||||
|
||||
## 8. Implementation Checklist
|
||||
|
||||
- Confirm the reference matmul matches `torch.matmul` before touching the kernels
|
||||
- Draw a block/tile diagram before coding
|
||||
- Implement the Triton tile loop over `K`
|
||||
- Implement the CUDA shared-memory tile loop
|
||||
- Benchmark against `torch.matmul` on small and medium sizes
|
||||
|
||||
## Tile Diagram Prompt
|
||||
|
||||
Sketch:
|
||||
|
||||
- one output tile `C[m0:m1, n0:n1]`
|
||||
- the matching `A[m0:m1, k0:k1]`
|
||||
- the matching `B[k0:k1, n0:n1]`
|
||||
|
||||
That sketch should tell you what belongs in shared memory.
|
||||
35
tasks/03_tiled_matmul/test_task.py
Normal file
35
tasks/03_tiled_matmul/test_task.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from kernels.triton.tiled_matmul import triton_tiled_matmul
|
||||
from reference.torch_matmul import torch_matmul
|
||||
|
||||
|
||||
def _run_impl_or_skip(fn, *args):
    """Invoke *fn* and convert "not done yet" failures into pytest skips.

    ``NotImplementedError`` marks a still-TODO student kernel;
    ``RuntimeError`` covers environment problems (e.g. Triton/CUDA issues).
    """
    try:
        result = fn(*args)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))
    else:
        return result
|
||||
|
||||
|
||||
@pytest.mark.reference
def test_tiled_matmul_reference_matches_torch():
    """The pure-torch reference must agree with the `@` operator."""
    lhs = torch.randn(8, 16)
    rhs = torch.randn(16, 12)
    torch.testing.assert_close(torch_matmul(lhs, rhs), lhs @ rhs)
|
||||
|
||||
|
||||
@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_tiled_matmul_if_available():
    """Check the Triton kernel against `@` on CUDA, skipping when unavailable."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    lhs = torch.randn(32, 48, device="cuda")
    rhs = torch.randn(48, 40, device="cuda")
    result = _run_impl_or_skip(triton_tiled_matmul, lhs, rhs)
    # Loose tolerances: tiled accumulation order differs from cuBLAS.
    torch.testing.assert_close(result, lhs @ rhs, atol=1e-3, rtol=1e-3)
|
||||
16
tasks/03_tiled_matmul/triton_skeleton.py
Normal file
16
tasks/03_tiled_matmul/triton_skeleton.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Workbook-local Triton notes for tiled matmul."""
|
||||
|
||||
|
||||
def notes() -> str:
    """Return the student-facing checklist for the Triton tiled-matmul kernel."""
    checklist = """
TODO(student):
1. Map one program instance to one output tile.
2. Build row/col offsets for the tile.
3. Loop over K in block_k chunks.
4. Load A and B tiles, accumulate partial products.
5. Store the output tile with masking on edges.
"""
    return checklist


if __name__ == "__main__":
    print(notes())
|
||||
Reference in New Issue
Block a user