Task 05: Flash Attention Forward
1. Problem Statement
Implement a learning-oriented forward-only FlashAttention-style kernel in both Triton and CUDA.
2. Expected Input/Output Shapes
- Q: [batch, heads, seq_len, head_dim]
- K: [batch, heads, seq_len, head_dim]
- V: [batch, heads, seq_len, head_dim]
- Output: [batch, heads, seq_len, head_dim]
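A minimal NumPy reference for these shapes can serve as the ground truth the kernels are checked against. This is a sketch, not the repo's reference implementation; the function name is illustrative.

```python
import numpy as np

def attention_reference(Q, K, V, causal=False):
    """Naive attention: materializes the full [seq_len, seq_len] score matrix."""
    d = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d)  # [batch, heads, seq, seq]
    if causal:
        n = Q.shape[-2]
        # mask out keys that lie in the future of each query position
        mask = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V
```

With `causal=True`, query position 0 can only attend to key 0, so its output row must equal `V[..., 0, :]` exactly — a handy sanity check on tiny shapes.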
3. Performance Intuition
The goal is to reduce memory traffic by avoiding full materialization of the score matrix. Correctness comes first. Performance work only matters after the blockwise algorithm is correct.
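To make "memory traffic" concrete, a back-of-the-envelope comparison for one illustrative configuration (the numbers are assumptions, not from this task):

```python
# fp16 = 2 bytes per element; illustrative sizes only
batch, heads, seq_len, head_dim = 1, 16, 4096, 64

bytes_qkv = 3 * batch * heads * seq_len * head_dim * 2   # Q + K + V
bytes_scores = batch * heads * seq_len * seq_len * 2     # full score matrix

print(bytes_qkv / 2**20, "MiB for Q+K+V")        # 24.0 MiB
print(bytes_scores / 2**20, "MiB for scores")    # 512.0 MiB
```

The score matrix grows as seq_len squared, so at long sequence lengths it dwarfs the inputs — which is exactly why the blockwise algorithm never writes it out.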
4. Memory Access Discussion
This task is about staged movement:
- load a Q block
- iterate over K and V blocks
- compute score blocks
- update online normalization
- accumulate the output block
Track where each quantity lives: global memory, registers, or shared memory.
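The staged loop above can be sketched in NumPy for a single head. This is a teaching sketch under assumed names (`flash_forward_1h` is not from the repo); in a real kernel the Q block and the running state live in registers while K/V tiles are staged through shared memory.

```python
import numpy as np

def flash_forward_1h(Q, K, V, block=16):
    """Blockwise forward for one head: never forms the full [N, N] score matrix.

    Per query block the running state is:
      m   -- row-wise max of all scores seen so far
      l   -- running softmax denominator
      acc -- unnormalized output accumulator
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty_like(Q)
    for qs in range(0, N, block):
        q = Q[qs:qs + block]                     # "registers": one query tile
        m = np.full(q.shape[0], -np.inf)
        l = np.zeros(q.shape[0])
        acc = np.zeros_like(q)
        for ks in range(0, N, block):            # stream K/V tiles from "global"
            s = q @ K[ks:ks + block].T * scale   # small score tile only
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)            # rescale previously seen state
            p = np.exp(s - m_new[:, None])
            l = l * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ V[ks:ks + block]
            m = m_new
        O[qs:qs + block] = acc / l[:, None]      # normalize once at the end
    return O
```

Note that only one `block × block` score tile exists at any time; everything else is O(block × head_dim) state.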
5. What Triton Is Abstracting
Triton makes block pointers, program IDs, and masked block operations compact. Those abstractions still correspond to explicit memory ownership decisions.
6. What CUDA Makes Explicit
CUDA exposes thread-block mapping, shared-memory staging, synchronization, and reduction details directly. This is where the same algorithm becomes visibly lower level.
7. Reflection Questions
- How does online softmax avoid writing out the full score matrix?
- Which loop corresponds to iterating over key/value blocks?
- Where do causal masking and normalization interact?
- How does a Triton block pointer map to a CUDA shared-memory load phase?
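On the first question: the key fact is that softmax statistics of disjoint score chunks can be merged exactly, so the full row never needs to exist at once. A minimal demonstration (the `merge` helper is illustrative):

```python
import numpy as np

def merge(state_a, state_b):
    """Combine (max, denominator) softmax statistics of two disjoint chunks."""
    (ma, la), (mb, lb) = state_a, state_b
    m = max(ma, mb)
    return m, la * np.exp(ma - m) + lb * np.exp(mb - m)

scores = np.array([0.5, 2.0, -1.0, 3.0, 0.0, 1.5])

# streaming pass: fold in one chunk at a time
m, l = -np.inf, 0.0
for chunk in (scores[:3], scores[3:]):
    mc = chunk.max()
    m, l = merge((m, l), (mc, np.exp(chunk - mc).sum()))

# single pass over the full row, for comparison
full_m = scores.max()
full_l = np.exp(scores - full_m).sum()
```

The streamed `(m, l)` equals the single-pass statistics exactly; this merge rule is precisely the `alpha` rescaling step in the blockwise loop.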
8. Implementation Checklist
- Confirm the PyTorch reference on tiny shapes
- Trace the online softmax state update
- Implement one Triton blockwise forward path
- Implement one CUDA blockwise forward path
- Test non-causal first, then causal
- Benchmark only after small-shape correctness passes
9. Explicit Triton-to-CUDA Mapping
- Triton program_id(axis=0) for query tiles maps to CUDA query-tile block ownership
- Triton program_id(axis=1) for batch/head maps to a flattened batch-head block index
- Triton block pointer math maps to shared-memory staging and pointer arithmetic
- Triton masked edge handling maps to explicit tail checks and mask branches
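The masked-edge row deserves one concrete illustration: when seq_len is not a multiple of the block size, the tail tile is padded with -inf so the padded lanes vanish under exp(). A NumPy sketch of the idea (sizes are illustrative):

```python
import numpy as np

BLOCK = 8
valid = 5  # the tail tile holds only 5 real score lanes
rng = np.random.default_rng(0)
scores_tail = rng.normal(size=valid)

# Pad the tail tile to BLOCK lanes with -inf; exp() then zeroes them out --
# the same effect a Triton masked load or a CUDA tail-branch check produces.
tile = np.full(BLOCK, -np.inf)
tile[:valid] = scores_tail

m = tile.max()               # finite: taken over the valid lanes
p = np.exp(tile - m)         # padded lanes: exp(-inf) == 0, no contribution
```

Because the padded lanes contribute exactly zero to both the softmax denominator and the output accumulator, the tail tile needs no special-case normalization.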