Initial project scaffold
2
tasks/05_flash_attention_fwd/__init__.py
Normal file
@@ -0,0 +1,2 @@
"""Flash attention forward task."""
51
tasks/05_flash_attention_fwd/bench.py
Normal file
@@ -0,0 +1,51 @@
from __future__ import annotations

import statistics
import sys
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention


def benchmark(fn, *args, warmup: int = 5, reps: int = 20, **kwargs) -> float:
    # Warm up to exclude one-time costs such as compilation and caching.
    for _ in range(warmup):
        fn(*args, **kwargs)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    times_ms = []
    for _ in range(reps):
        # Synchronize around CUDA work so the wall-clock time covers kernel execution.
        if args[0].is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(*args, **kwargs)
        if args[0].is_cuda:
            torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(2, 8, 128, 64, device=device)
    k = torch.randn(2, 8, 128, 64, device=device)
    v = torch.randn(2, 8, 128, 64, device=device)
    ref_ms = benchmark(torch_attention, q, k, v, causal=False)
    print(f"torch_attention: {ref_ms:.3f} ms")
    if device == "cuda":
        try:
            triton_ms = benchmark(triton_flash_attention_fwd, q, k, v, causal=False)
            print(f"triton_flash_attention_fwd: {triton_ms:.3f} ms")
        except (NotImplementedError, RuntimeError) as exc:
            print(f"triton_flash_attention_fwd: skipped ({exc})")


if __name__ == "__main__":
    main()
14
tasks/05_flash_attention_fwd/cuda_skeleton.cu
Normal file
@@ -0,0 +1,14 @@
// Workbook-local CUDA sketch for FlashAttention forward.
//
// Map this against the Triton sketch:
// - Triton program_id for query tile -> CUDA block ownership
// - Triton block pointer loads       -> CUDA cooperative global-to-shared loads
// - Triton masks                     -> explicit edge and causal checks
// - Triton implicit block math       -> thread/block index arithmetic

// TODO(student):
// 1. Assign a block to one batch/head/query tile.
// 2. Load a Q tile and loop over K/V tiles.
// 3. Compute score tiles and causal masking.
// 4. Update online softmax state.
// 5. Accumulate the output tile.
59
tasks/05_flash_attention_fwd/spec.md
Normal file
@@ -0,0 +1,59 @@
# Task 05: Flash Attention Forward

## 1. Problem Statement

Implement a learning-oriented, forward-only FlashAttention-style kernel in both Triton and CUDA.

## 2. Expected Input/Output Shapes

- `Q`: `[batch, heads, seq_len, head_dim]`
- `K`: `[batch, heads, seq_len, head_dim]`
- `V`: `[batch, heads, seq_len, head_dim]`
- `Output`: `[batch, heads, seq_len, head_dim]`

## 3. Performance Intuition

The goal is to reduce memory traffic by avoiding full materialization of the score matrix. Correctness comes first. Performance work only matters after the blockwise algorithm is correct.
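
For contrast, here is a minimal dense sketch that does materialize the score matrix. It is illustrative only: `naive_attention` is a hypothetical name, and the repository's actual baseline lives in `reference.torch_attention`, which is not part of this commit.

```python
import torch


def naive_attention(q, k, v, causal=False):
    """Dense reference sketch: materializes the full [batch, heads, seq_len, seq_len] scores."""
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if causal:
        seq_len = q.shape[-2]
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)
```

The `[seq_len, seq_len]` `scores` tensor is exactly what the blockwise kernel never writes out in full.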

## 4. Memory Access Discussion

This task is about staged data movement:

- load a `Q` block
- iterate over `K` and `V` blocks
- compute score blocks
- update the online normalization state
- accumulate the output block

Track where each quantity lives: global memory, registers, or shared memory. The sketch after this list walks through the same stages in plain PyTorch.
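
Below is a minimal, unoptimized sketch of those stages. Block sizes, names, and the non-causal restriction are assumptions for illustration, not the required kernel structure.

```python
import torch


def blockwise_attention_fwd(q, k, v, block_q=32, block_k=32):
    """Non-causal blockwise forward pass with an online softmax, one (batch, head) at a time."""
    batch, heads, seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.empty_like(q)
    for b in range(batch):
        for h in range(heads):
            for qs in range(0, seq_len, block_q):
                q_blk = q[b, h, qs:qs + block_q]                      # Q block stays resident
                m = q_blk.new_full((q_blk.shape[0],), float("-inf"))  # running row max
                l = q_blk.new_zeros(q_blk.shape[0])                   # running row sum
                acc = torch.zeros_like(q_blk)                         # unnormalized output block
                for ks in range(0, seq_len, block_k):
                    k_blk = k[b, h, ks:ks + block_k]
                    v_blk = v[b, h, ks:ks + block_k]
                    s = (q_blk @ k_blk.transpose(-2, -1)) * scale     # one score block, never the full matrix
                    m_new = torch.maximum(m, s.max(dim=-1).values)
                    p = torch.exp(s - m_new[:, None])
                    corr = torch.exp(m - m_new)                       # rescale old state to the new max
                    l = l * corr + p.sum(dim=-1)
                    acc = acc * corr[:, None] + p @ v_blk
                    m = m_new
                out[b, h, qs:qs + q_blk.shape[0]] = acc / l[:, None]
    return out
```

On tiny shapes this should match the dense reference up to numerical tolerance; the Triton and CUDA kernels express the same update with tiles held in registers and shared memory.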
## 5. What Triton Is Abstracting
Triton makes block pointers, program IDs, and masked block operations compact. Those abstractions still correspond to explicit memory ownership decisions.
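
As a minimal, attention-agnostic illustration of program IDs and masked block operations (assumes Triton is installed and a CUDA device is present; the kernel name and shapes are illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, y_ptr, n_elements, scale, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)               # which block this program instance owns
    offs = pid * BLOCK + tl.arange(0, BLOCK)  # element indices covered by this block
    mask = offs < n_elements                  # masked edge handling at the tail
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(y_ptr + offs, x * scale, mask=mask)


x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
scale_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), 2.0, BLOCK=256)
```

In CUDA, each of those lines becomes an explicit decision about block ownership, index arithmetic, and bounds checks.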

## 6. What CUDA Makes Explicit

CUDA exposes thread-block mapping, shared-memory staging, synchronization, and reduction details directly. This is where the same algorithm becomes visibly lower level.

## 7. Reflection Questions

- How does the online softmax avoid writing out the full score matrix?
- Which loop corresponds to iterating over key/value blocks?
- Where do causal masking and normalization interact?
- How does a Triton block pointer map to a CUDA shared-memory load phase?

## 8. Implementation Checklist

- Confirm the PyTorch reference on tiny shapes
- Trace the online softmax state update
- Implement one Triton blockwise forward path
- Implement one CUDA blockwise forward path
- Test non-causal first, then causal
- Benchmark only after small-shape correctness passes

## 9. Explicit Triton-to-CUDA Mapping

- Triton `program_id(axis=0)` for query tiles maps to CUDA query-tile block ownership
- Triton `program_id(axis=1)` for batch/head maps to a flattened batch-head block index (see the index sketch below)
- Triton block pointer math maps to shared-memory staging and pointer arithmetic
- Triton masked edge handling maps to explicit tail checks and mask branches
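
A small index sketch for the second bullet (names are illustrative):

```python
def unflatten_batch_head(pid_bh: int, num_heads: int) -> tuple[int, int]:
    # pid_bh plays the role of tl.program_id(axis=1) in Triton or blockIdx.y in CUDA:
    # one flat index per (batch, head) pair.
    return pid_bh // num_heads, pid_bh % num_heads


assert unflatten_batch_head(13, num_heads=8) == (1, 5)
```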
49
tasks/05_flash_attention_fwd/test_task.py
Normal file
@@ -0,0 +1,49 @@
from __future__ import annotations

import pytest
import torch

from kernels.triton.flash_attention_fwd import triton_flash_attention_fwd
from reference.torch_attention import torch_attention


def _run_impl_or_skip(fn, *args, **kwargs):
    try:
        return fn(*args, **kwargs)
    except NotImplementedError:
        pytest.skip("implementation is still TODO")
    except RuntimeError as exc:
        pytest.skip(str(exc))


@pytest.mark.reference
def test_attention_reference_small_shape():
    q = torch.randn(1, 2, 8, 16)
    k = torch.randn(1, 2, 8, 16)
    v = torch.randn(1, 2, 8, 16)
    out = torch_attention(q, k, v, causal=False)
    expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
    torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.reference
def test_attention_reference_causal_small_shape():
    q = torch.randn(1, 1, 8, 16)
    k = torch.randn(1, 1, 8, 16)
    v = torch.randn(1, 1, 8, 16)
    out = torch_attention(q, k, v, causal=True)
    expected = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.triton_required
@pytest.mark.skeleton
def test_triton_flash_attention_if_available():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    q = torch.randn(1, 2, 16, 32, device="cuda")
    k = torch.randn(1, 2, 16, 32, device="cuda")
    v = torch.randn(1, 2, 16, 32, device="cuda")
    out = _run_impl_or_skip(triton_flash_attention_fwd, q, k, v, causal=False)
    expected = torch_attention(q, k, v, causal=False)
    torch.testing.assert_close(out, expected, atol=2e-3, rtol=2e-3)
19
tasks/05_flash_attention_fwd/triton_skeleton.py
Normal file
@@ -0,0 +1,19 @@
"""Workbook-local Triton notes for FlashAttention forward."""


def notes() -> str:
    return """
TODO(student):
1. Assign one program instance to one query block for one batch/head.
2. Load a Q block.
3. Iterate over K/V blocks.
4. Compute score blocks.
5. Apply optional causal masking.
6. Update running max and running sum.
7. Accumulate the output block.
8. Store the final output.
"""


if __name__ == "__main__":
    print(notes())