# Task 05: Flash Attention Forward

## 1. Problem Statement

Implement a learning-oriented forward-only FlashAttention-style kernel in both Triton and CUDA.

## 2. Expected Input/Output Shapes

- `Q`: `[batch, heads, seq_len, head_dim]`
- `K`: `[batch, heads, seq_len, head_dim]`
- `V`: `[batch, heads, seq_len, head_dim]`
- `Output`: `[batch, heads, seq_len, head_dim]`
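
These shapes are the standard scaled-dot-product attention layout. For orientation only, here is a minimal dense PyTorch version that materializes the full score matrix (the function name is illustrative, not the repo's reference implementation):

```python
import torch


def naive_attention(q, k, v):
    """Dense attention. q, k, v and the result are [batch, heads, seq_len, head_dim]."""
    # Builds the full [batch, heads, seq_len, seq_len] score tensor,
    # which is exactly the allocation the blockwise kernels avoid.
    scores = q @ k.transpose(-1, -2) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v
```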

## 3. Performance Intuition

The goal is to reduce memory traffic by avoiding full materialization of the `[seq_len, seq_len]` score matrix. At `batch=8`, `heads=16`, `seq_len=4096`, for example, storing that matrix for every (batch, head) pair in fp32 would take 8 × 16 × 4096² × 4 bytes, roughly 8 GiB. Correctness comes first: performance work only matters after the blockwise algorithm is correct.

## 4. Memory Access Discussion

This task is about staged data movement (sketched in code after this list):

- load a `Q` block
- iterate over `K` and `V` blocks
- compute score blocks
- update the online softmax normalization
- accumulate the output block

Track where each quantity lives: global memory, registers, or shared memory.
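
To make the staged loop concrete, here is a minimal non-causal PyTorch sketch of the blockwise pass (the function name and the `block_q`/`block_k` defaults are illustrative, not the repo's API). Everything below still lives in global memory; in the real kernels the running max `m`, the running sum `l`, and the accumulator stay in registers while the `K`/`V` blocks are staged through shared memory.

```python
import torch


def flash_attention_fwd_reference(q, k, v, block_q=64, block_k=64):
    """Blockwise forward with online softmax. q, k, v: [batch, heads, seq_len, head_dim]."""
    n, d = q.shape[-2], q.shape[-1]
    scale = d ** -0.5
    out = torch.empty_like(q)

    for q0 in range(0, n, block_q):                          # one tile of query rows
        q_blk = q[:, :, q0:q0 + block_q] * scale             # load a Q block
        acc = torch.zeros_like(q_blk)                        # output accumulator
        m = q_blk.new_full(q_blk.shape[:-1], float("-inf"))  # running row-wise max
        l = q_blk.new_zeros(q_blk.shape[:-1])                # running sum of exponentials

        for k0 in range(0, n, block_k):                      # iterate over K and V blocks
            k_blk = k[:, :, k0:k0 + block_k]
            v_blk = v[:, :, k0:k0 + block_k]
            s = q_blk @ k_blk.transpose(-1, -2)              # compute a score block

            m_new = torch.maximum(m, s.amax(dim=-1))         # update online normalization
            p = torch.exp(s - m_new.unsqueeze(-1))
            alpha = torch.exp(m - m_new)                     # rescale the old partial results
            l = l * alpha + p.sum(dim=-1)
            acc = acc * alpha.unsqueeze(-1) + p @ v_blk      # accumulate the output block
            m = m_new

        out[:, :, q0:q0 + block_q] = acc / l.unsqueeze(-1)   # final per-row normalization
    return out
```

Causal masking would enter in the inner loop, by setting masked score entries to `-inf` before the max/exp update, which is exactly where it interacts with the normalization.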

## 5. What Triton Is Abstracting

Triton makes block pointers, program IDs, and masked block operations compact. Those abstractions still correspond to explicit memory ownership decisions.
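
To see how compact those abstractions are, here is a deliberately trivial Triton kernel: a tile copy rather than attention, with a made-up name and signature for illustration. Program IDs choose the tile, `tl.arange` builds the block offsets, and a single mask covers the sequence tail; the attention kernel reuses this addressing pattern with the math added in.

```python
import triton
import triton.language as tl


@triton.jit
def copy_q_tile_kernel(Q, Out, seq_len,
                       stride_bh, stride_m, stride_d,
                       BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr):
    pid_m = tl.program_id(axis=0)     # which query tile
    pid_bh = tl.program_id(axis=1)    # which flattened (batch, head) slice

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_DIM)
    offsets = (pid_bh * stride_bh
               + offs_m[:, None] * stride_m
               + offs_d[None, :] * stride_d)

    mask = offs_m[:, None] < seq_len  # masked block ops at the sequence tail
    tile = tl.load(Q + offsets, mask=mask, other=0.0)
    tl.store(Out + offsets, tile, mask=mask)
```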

## 6. What CUDA Makes Explicit

CUDA exposes thread-block mapping, shared-memory staging, synchronization, and reduction details directly. This is where the same algorithm becomes visibly lower level.

## 7. Reflection Questions

- How does online softmax avoid writing out the full score matrix?
- Which loop corresponds to iterating over key/value blocks?
- Where do causal masking and normalization interact?
- How does a Triton block pointer map to a CUDA shared-memory load phase?

## 8. Implementation Checklist

- Confirm the PyTorch reference on tiny shapes (see the check sketched after this list)
- Trace the online softmax state update
- Implement one Triton blockwise forward path
- Implement one CUDA blockwise forward path
- Test non-causal first, then causal
- Benchmark only after small-shape correctness passes
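
The reference check and the causal/non-causal ordering can be wired into one small helper that compares a forward path against PyTorch's built-in attention on tiny shapes. A sketch with illustrative names (swap in whichever forward path is under test):

```python
import torch
import torch.nn.functional as F


def check_tiny_shapes(attn_fn, causal=False):
    # Odd seq_len on purpose, so block-edge masking actually gets exercised.
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 3, 17, 8) for _ in range(3))
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    out = attn_fn(q, k, v)   # attn_fn is expected to apply the same masking itself
    torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-4)


# Non-causal first, then causal, per the checklist:
# check_tiny_shapes(flash_attention_fwd_reference)
# check_tiny_shapes(causal_forward, causal=True)   # hypothetical causal kernel wrapper
```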

## 9. Explicit Triton-to-CUDA Mapping

- Triton `program_id(axis=0)` for query tiles maps to CUDA query-tile block ownership
- Triton `program_id(axis=1)` for batch/head maps to a flattened batch-head block index
- Triton block pointer math maps to shared-memory staging and pointer arithmetic
- Triton masked edge handling maps to explicit tail checks and mask branches
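
The first two bullets are ultimately a grid-shape decision. A sketch of the launch-side arithmetic (the wrapper name is illustrative); the CUDA version passes the same two extents as `gridDim.x` and `gridDim.y`:

```python
import triton


def attention_launch_grid(batch, heads, seq_len, block_m=64):
    # axis 0: one program per query tile; axis 1: one program per flattened (batch, head)
    return (triton.cdiv(seq_len, block_m), batch * heads)
```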