Task 05: Flash Attention Forward

1. Problem Statement

Implement a learning-oriented, forward-only FlashAttention-style kernel in both Triton and CUDA.

2. Expected Input/Output Shapes

  • Q: [batch, heads, seq_len, head_dim]
  • K: [batch, heads, seq_len, head_dim]
  • V: [batch, heads, seq_len, head_dim]
  • Output: [batch, heads, seq_len, head_dim]
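
As a concrete anchor for these shapes, a minimal PyTorch reference might look like the sketch below; the name naive_attention and the causal flag are illustrative, not part of the task's required interface.

```python
import math
import torch

def naive_attention(q, k, v, causal=False):
    # q, k, v: [batch, heads, seq_len, head_dim]
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # [b, h, s, s] score matrix
    if causal:
        s = scores.shape[-1]
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)                            # [batch, heads, seq_len, head_dim]
```

Every kernel built in this task should match this reference within floating-point tolerance.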

3. Performance Intuition

The goal is to reduce memory traffic by never materializing the full seq_len × seq_len score matrix in global memory. Correctness comes first: performance work only matters after the blockwise algorithm is correct.
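
As a rough, illustrative back-of-envelope (the shapes, block sizes, and fp32 assumption are arbitrary), compare what full materialization costs against what one blockwise tile pass keeps live:

```python
# Working-set estimate for one (batch, head) pair; numbers are illustrative.
seq_len, head_dim, block = 4096, 64, 128
bytes_fp32 = 4

full_scores = seq_len * seq_len * bytes_fp32                     # 64 MiB if materialized
tile_pass = (3 * block * head_dim + block * block) * bytes_fp32  # ~160 KiB: Q/K/V tiles + one score tile
print(f"full score matrix: {full_scores / 2**20:.0f} MiB, "
      f"blockwise working set: {tile_pass / 2**10:.0f} KiB")
```

The blockwise working set is what has to fit on chip; the full score matrix never does.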

4. Memory Access Discussion

This task is about staged movement:

  • load a Q block
  • iterate over K and V blocks
  • compute score blocks
  • update online normalization
  • accumulate the output block

Track where each quantity lives at every step: global memory, registers, or shared memory. The sketch below walks through one pass of this loop in plain PyTorch.
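
A minimal sketch of that loop, assuming a non-causal forward and made-up block sizes, might look like this; the variable names m and l for the running max and running denominator are illustrative conventions, not required interfaces.

```python
import math
import torch

def blockwise_attention(q, k, v, block_q=64, block_k=64):
    # q, k, v: [batch, heads, seq_len, head_dim]; non-causal for clarity
    b, h, s, d = q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.empty_like(q)
    for i in range(0, s, block_q):
        q_blk = q[:, :, i:i + block_q]                      # "load a Q block"
        acc = torch.zeros_like(q_blk)                       # running output accumulator
        m = torch.full(q_blk.shape[:-1], float("-inf"),     # running row max
                       dtype=q.dtype, device=q.device)
        l = torch.zeros_like(m)                             # running softmax denominator
        for j in range(0, s, block_k):                      # "iterate over K and V blocks"
            k_blk = k[:, :, j:j + block_k]
            v_blk = v[:, :, j:j + block_k]
            scores = torch.matmul(q_blk, k_blk.transpose(-2, -1)) * scale  # score block
            m_new = torch.maximum(m, scores.amax(dim=-1))   # online max update
            p = torch.exp(scores - m_new.unsqueeze(-1))
            alpha = torch.exp(m - m_new)                    # rescale the previous state
            l = l * alpha + p.sum(dim=-1)
            acc = acc * alpha.unsqueeze(-1) + torch.matmul(p, v_blk)
            m = m_new
        out[:, :, i:i + block_q] = acc / l.unsqueeze(-1)    # final normalization
    return out
```

In the actual kernels, q_blk, acc, m, and l stay on chip for the lifetime of one query tile, while each K/V tile is staged from global memory once per inner iteration.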

5. What Triton Is Abstracting

Triton makes block pointers, program IDs, and masked block operations compact, but those abstractions still correspond to explicit decisions about which program owns which tile of memory.
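
A minimal Triton sketch of those abstractions, assuming a contiguous [batch*heads, seq_len, head_dim] layout and a power-of-two head_dim, might look like the following; it only stages one Q tile and writes it back (the K/V loop would sit between the load and the store), and the kernel and launcher names are made up for illustration.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def q_tile_kernel(q_ptr, out_ptr,
                  stride_bh, stride_s, stride_d, seq_len,
                  BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr):
    pid_m = tl.program_id(axis=0)    # which query tile this program owns
    pid_bh = tl.program_id(axis=1)   # flattened batch*head index
    base = q_ptr + pid_bh * stride_bh
    q_block_ptr = tl.make_block_ptr(
        base=base, shape=(seq_len, HEAD_DIM), strides=(stride_s, stride_d),
        offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0))
    q_blk = tl.load(q_block_ptr, boundary_check=(0,), padding_option="zero")
    # ... the loop over K/V block pointers and the online softmax state would go here ...
    o_block_ptr = tl.make_block_ptr(
        base=out_ptr + pid_bh * stride_bh, shape=(seq_len, HEAD_DIM),
        strides=(stride_s, stride_d), offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0))
    tl.store(o_block_ptr, q_blk, boundary_check=(0,))

def launch_q_tile(q, BLOCK_M=64):
    b, h, s, d = q.shape
    qf = q.reshape(b * h, s, d).contiguous()
    out = torch.empty_like(qf)
    grid = (triton.cdiv(s, BLOCK_M), b * h)   # axis 0: query tiles, axis 1: batch*heads
    q_tile_kernel[grid](qf, out, qf.stride(0), qf.stride(1), qf.stride(2),
                        s, BLOCK_M=BLOCK_M, HEAD_DIM=d)
    return out.reshape(b, h, s, d)
```

The boundary_check argument is the masked edge handling that the mapping list at the end of this task pairs with explicit tail checks in CUDA.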

6. What CUDA Makes Explicit

CUDA exposes thread-block mapping, shared-memory staging, synchronization, and reduction details directly. This is where the same algorithm becomes visibly lower level.

7. Reflection Questions

  • How does online softmax avoid writing out the full score matrix?
  • Which loop corresponds to iterating over key/value blocks?
  • Where do causal masking and normalization interact?
  • How does a Triton block pointer map to a CUDA shared-memory load phase?

8. Implementation Checklist

  • Confirm the PyTorch reference on tiny shapes (a minimal check follows this list)
  • Trace the online softmax state update
  • Implement one Triton blockwise forward path
  • Implement one CUDA blockwise forward path
  • Test non-causal first, then causal
  • Benchmark only after small-shape correctness passes
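
The first checklist item can be as small as the snippet below, which pins a hand-written softmax attention against torch.nn.functional.scaled_dot_product_attention on tiny shapes; the shapes and tolerance are arbitrary, and the same harness can later wrap the blockwise, Triton, and CUDA paths.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, s, d = 2, 3, 5, 4                                   # tiny on purpose
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Hand-written reference: full score matrix, explicit softmax.
ref = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(d), dim=-1) @ v

# Library oracle on the same inputs (non-causal first, per the checklist).
oracle = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(ref, oracle, atol=1e-5), "reference disagrees with SDPA"
print("tiny-shape reference check passed")
```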

9. Explicit Triton-to-CUDA Mapping

  • Triton program_id(axis=0) for query tiles maps to CUDA query-tile block ownership
  • Triton program_id(axis=1) for batch/head maps to a flattened batch-head block index
  • Triton block pointer math maps to shared-memory staging and pointer arithmetic
  • Triton masked edge handling maps to explicit tail checks and mask branches