# Task 05: Flash Attention Forward

## 1. Problem Statement

Implement a learning-oriented, forward-only FlashAttention-style kernel in both Triton and CUDA.

## 2. Expected Input/Output Shapes

- `Q`: `[batch, heads, seq_len, head_dim]`
- `K`: `[batch, heads, seq_len, head_dim]`
- `V`: `[batch, heads, seq_len, head_dim]`
- `Output`: `[batch, heads, seq_len, head_dim]`
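As a baseline for the shapes above, here is a minimal non-causal PyTorch reference of the kind the checklist in Section 8 asks for. The function name `attention_reference` and the `1/sqrt(head_dim)` scaling are illustrative assumptions, not requirements of the task statement.

```python
import math
import torch

def attention_reference(q, k, v):
    """Naive attention reference; assumes [batch, heads, seq_len, head_dim].

    Deliberately materializes the full score matrix -- this is the baseline
    the blockwise kernels must match, not the thing to optimize.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])                    # assumed 1/sqrt(d) scaling
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # [b, h, sq, sk]
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)                           # [b, h, sq, head_dim]
```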
## 3. Performance Intuition

The goal is to reduce memory traffic by avoiding full materialization of the score matrix. Correctness comes first; performance work only matters after the blockwise algorithm is correct.

## 4. Memory Access Discussion

This task is about staged movement (traced in the sketch after this list):

- load a `Q` block
- iterate over `K` and `V` blocks
- compute score blocks
- update the online normalization state
- accumulate the output block

Track where each quantity lives: global memory, registers, or shared memory.
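The sketch below walks that staged loop in plain PyTorch, with tensor slices standing in for tiles. It is an algorithmic trace, not a kernel: the block sizes, the running-state names (`m` for the row max, `l` for the softmax denominator, `acc` for the unnormalized output), and the function name are all assumptions made for illustration.

```python
import math
import torch

def attention_blockwise(q, k, v, block_q=16, block_k=16):
    """Non-causal blockwise forward with online softmax.

    Never holds the full [seq_len, seq_len] score matrix; only one
    [block_q, block_k] score tile exists at a time.
    """
    b, h, n, d = q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.empty_like(q)
    for qs in range(0, n, block_q):                       # one Q tile per outer step
        qb = q[:, :, qs:qs + block_q] * scale
        m = torch.full(qb.shape[:-1], float("-inf"),
                       dtype=qb.dtype, device=qb.device)  # running row max
        l = torch.zeros_like(m)                           # running softmax denominator
        acc = torch.zeros_like(qb)                        # running unnormalized output
        for ks in range(0, n, block_k):                   # stream K/V tiles past the Q tile
            kb = k[:, :, ks:ks + block_k]
            vb = v[:, :, ks:ks + block_k]
            s = torch.matmul(qb, kb.transpose(-2, -1))    # [b, h, bq, bk] score tile
            m_new = torch.maximum(m, s.amax(dim=-1))      # updated running row max
            alpha = torch.exp(m - m_new)                  # rescale factor for old state
            p = torch.exp(s - m_new.unsqueeze(-1))        # tile of unnormalized probs
            l = l * alpha + p.sum(dim=-1)
            acc = acc * alpha.unsqueeze(-1) + torch.matmul(p, vb)
            m = m_new
        out[:, :, qs:qs + block_q] = acc / l.unsqueeze(-1)
    return out
```

The rescale by `alpha` is the heart of online softmax: whenever a new tile raises the running row max, previously accumulated `l` and `acc` are corrected in place, so no score tile ever needs to be revisited or written out.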
## 5. What Triton Is Abstracting

Triton makes block pointers, program IDs, and masked block operations compact. Those abstractions still correspond to explicit memory-ownership decisions.
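One way to see what those abstractions buy is a deliberately trivial Triton kernel that only stages a `Q`-shaped tile from global memory and writes it back unchanged. It is not an attention kernel; the two-axis grid (query tiles on axis 0, flattened batch-head on axis 1), the stride names, and the block sizes are illustrative assumptions.

```python
import triton
import triton.language as tl

@triton.jit
def copy_q_tile(Q, Out, stride_bh, stride_m, stride_d,
                seq_len, head_dim,
                BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    pid_m = tl.program_id(axis=0)    # which query tile this program owns
    pid_bh = tl.program_id(axis=1)   # flattened batch*heads index
    q_ptr = tl.make_block_ptr(
        base=Q + pid_bh * stride_bh,
        shape=(seq_len, head_dim),
        strides=(stride_m, stride_d),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_D),
        order=(1, 0),
    )
    # boundary_check masks the ragged last tile; in CUDA this becomes an
    # explicit `if (row < seq_len)` guard around the shared-memory load.
    q = tl.load(q_ptr, boundary_check=(0, 1), padding_option="zero")
    o_ptr = tl.make_block_ptr(
        base=Out + pid_bh * stride_bh,
        shape=(seq_len, head_dim),
        strides=(stride_m, stride_d),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_D),
        order=(1, 0),
    )
    tl.store(o_ptr, q, boundary_check=(0, 1))
```

A launch would use a grid of `(triton.cdiv(seq_len, BLOCK_M), batch * heads)`. Every construct here is something the CUDA path of Section 6 must spell out by hand.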
## 6. What CUDA Makes Explicit

CUDA exposes thread-block mapping, shared-memory staging, synchronization, and reduction details directly. This is where the same algorithm becomes visibly lower level.

## 7. Reflection Questions

- How does online softmax avoid writing out the full score matrix?
- Which loop corresponds to iterating over key/value blocks?
- Where do causal masking and normalization interact?
- How does a Triton block pointer map to a CUDA shared-memory load phase?
## 8. Implementation Checklist

- Confirm the PyTorch reference on tiny shapes (a harness is sketched after this list)
- Trace the online softmax state update
- Implement one Triton blockwise forward path
- Implement one CUDA blockwise forward path
- Test non-causal first, then causal
- Benchmark only after small-shape correctness passes
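A minimal sketch of that tiny-shape check, assuming `attention_reference` and `attention_blockwise` from the earlier sketches are in scope; the shapes and tolerances are arbitrary choices.

```python
import torch

def test_tiny_shapes():
    torch.manual_seed(0)
    # small and deliberately non-round, to exercise ragged last tiles
    q = torch.randn(2, 3, 37, 8)
    k = torch.randn(2, 3, 37, 8)
    v = torch.randn(2, 3, 37, 8)
    expected = attention_reference(q, k, v)
    actual = attention_blockwise(q, k, v, block_q=16, block_k=16)
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)

test_tiny_shapes()
```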
## 9. Explicit Triton-to-CUDA Mapping

- Triton `program_id(axis=0)` for query tiles maps to CUDA query-tile block ownership
- Triton `program_id(axis=1)` for batch/head maps to a flattened batch-head block index (decoded as in the snippet below)
- Triton block-pointer math maps to shared-memory staging and pointer arithmetic
- Triton masked edge handling maps to explicit tail checks and mask branches
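For the second mapping, the decode is the same arithmetic on both sides; only the source of the index differs. A tiny sketch with assumed names:

```python
def decode_batch_head(pid_bh: int, num_heads: int) -> tuple[int, int]:
    """Decode a flattened batch-head index into (batch, head).

    In Triton, pid_bh would come from tl.program_id(axis=1);
    in CUDA, from blockIdx.y. The arithmetic is identical.
    """
    return pid_bh // num_heads, pid_bh % num_heads

assert decode_batch_head(7, num_heads=3) == (2, 1)  # batch 2, head 1
```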