# Roadmap
## Week 1 Study Plan

Day 1:

- Run `tools/check_env.py`
- Read `docs/gpu_basics.md`
- Read `docs/cuda_execution_model.md`
- Inspect `reference/torch_vector_add.py`
- Implement or partially implement `tasks/01_vector_add/triton_skeleton.py`

Day 2:

- Read `docs/triton_vs_cuda.md`
- Inspect `kernels/cuda/src/vector_add.cu`
- Fill in the vector add indexing TODOs in Triton and CUDA
- Run `pytest -q tasks/01_vector_add/test_task.py`
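
The indexing pattern behind the vector add TODOs can be sketched in plain NumPy (a hedged illustration, not the repo's skeleton; the function and parameter names here are made up): each program instance handles one contiguous block of elements, and a mask guards the ragged final block.

```python
import numpy as np

def vector_add_blocked(x, y, block_size=4):
    # Emulates the grid / program id / offsets / mask pattern of a
    # Triton vector-add kernel, purely on the CPU.
    n = x.shape[0]
    out = np.empty_like(x)
    num_blocks = (n + block_size - 1) // block_size   # ceil-div launch grid
    for pid in range(num_blocks):                     # pid plays the role of the program id
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n                            # guard the ragged last block
        idx = offsets[mask]
        out[idx] = x[idx] + y[idx]                    # masked load, add, store
    return out
```

The same three steps, offsets, mask, masked access, are what the indexing TODOs ask for in both skeletons; only the syntax differs between Triton and CUDA.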

Day 3:

- Read `reference/torch_row_softmax.py`
- Read `tasks/02_row_softmax/spec.md`
- Implement numerically stable row softmax in Triton first
- Compare against the CUDA skeleton and map the reduction strategy
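
The numerical-stability trick is small enough to state as a NumPy sketch (an illustration under the usual definition, not the repo's reference code): subtract the row max before exponentiating so `exp` never overflows.

```python
import numpy as np

def row_softmax(x):
    # Numerically stable softmax along the last axis.
    m = x.max(axis=-1, keepdims=True)   # row max: the stabilizer
    e = np.exp(x - m)                   # shifted exponentials, all <= 1
    return e / e.sum(axis=-1, keepdims=True)
```

The row max and the row sum are the two reductions whose mapping onto threads you compare between the Triton and CUDA versions.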

Day 4:

- Study `tasks/03_tiled_matmul/spec.md`
- Draw the tile decomposition on paper
- Implement one matmul tile path, prioritizing correctness over performance
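
The tile decomposition you draw on paper can be checked with a correctness-first NumPy sketch (illustrative only; tile size and function name are made up): each output tile accumulates partial products over tiles of the shared K dimension.

```python
import numpy as np

def tiled_matmul(a, b, tile=2):
    # Correctness-only tiled matmul: no data layout or vectorization tricks,
    # just the loop structure a tiled GPU kernel mirrors.
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    c = np.zeros((M, N), dtype=a.dtype)
    for i in range(0, M, tile):              # output tile rows
        for j in range(0, N, tile):          # output tile cols
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), dtype=a.dtype)
            for k in range(0, K, tile):      # walk the shared K dimension
                acc += a[i:i+tile, k:k+tile] @ b[k:k+tile, j:j+tile]
            c[i:i+tile, j:j+tile] = acc      # write the finished tile once
    return c
```

In the real kernel the `acc` tile lives in registers and the `a`/`b` slices come from shared memory, but the index math is the same.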

Day 5:

- Read `docs/flashattention_notes.md`
- Read `tasks/04_online_softmax/spec.md`
- Derive the running max / running sum recurrence informally
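
A sketch of the recurrence you should arrive at (stated informally, in NumPy, not the task's expected form): when the running max grows from `m` to `m'`, rescale the accumulated sum by `exp(m - m')` so all terms stay relative to the current max.

```python
import numpy as np

def online_softmax(x):
    # One-pass softmax via the running-max / running-sum recurrence.
    m = -np.inf          # running max seen so far
    s = 0.0              # running sum of exp(x_i - m)
    for v in x:
        m_new = max(m, v)
        s = s * np.exp(m - m_new) + np.exp(v - m_new)  # rescale old sum, add new term
        m = m_new
    return np.exp(x - m) / s
```

This single pass over the data is the ingredient that lets flash attention normalize without materializing the full score row.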

Day 6:

- Inspect `tasks/05_flash_attention_fwd/spec.md`
- Trace the PyTorch reference line by line
- Annotate where Q/K/V loads, score computation, normalization, and output accumulation happen
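
As a map for that annotation exercise, here is a naive single-head forward pass in NumPy with the four stages labeled (a sketch of the standard formulation, not the repo's PyTorch reference):

```python
import numpy as np

def attention_forward(Q, K, V):
    # Naive attention forward; comments mark the stages to find in the reference.
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                   # score computation
    scores -= scores.max(axis=-1, keepdims=True)    # stabilizer (row max)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # normalization
    return weights @ V                              # output accumulation
```

In the flash attention kernel these stages interleave per K/V block instead of running as four full-matrix passes, which is exactly what the online softmax recurrence enables.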

Day 7:

- Read `docs/profiling_guide.md`
- Run one benchmark and one profiler command
- Write down which numbers changed after warmup and synchronization
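
The warmup discipline can be sketched generically (a CPU-only illustration with made-up names; on GPU you would additionally synchronize, e.g. with `torch.cuda.synchronize()`, before reading the clock, since kernel launches are asynchronous):

```python
import time

def bench(fn, warmup=3, iters=10):
    # Time a callable, discarding warmup iterations so one-time costs
    # (compilation, cache fills, allocator growth) don't pollute the mean.
    for _ in range(warmup):
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters   # mean seconds per call
```

Comparing the first raw call time against the post-warmup mean is a quick way to see which numbers the warmup actually changed.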

## Recommended TODO Order

1. Environment checks
2. Vector add Triton
3. Vector add CUDA
4. Row softmax Triton
5. Row softmax CUDA
6. Tiled matmul Triton
7. Tiled matmul CUDA
8. Online softmax Triton
9. Online softmax CUDA
10. Flash attention forward Triton
11. Flash attention forward CUDA
12. PyTorch custom op binding
13. Profiling passes and benchmark validation

## What To Focus On First

- Correctness on tiny shapes
- Clear index math
- Explicit shape assumptions
- Numerically stable reductions
- Repeatable measurement

Do not chase peak performance before you can explain the memory traffic and launch geometry of your kernel.