Files
kernel-lab/tasks/05_flash_attention_fwd/triton_skeleton.py
2026-04-10 13:22:19 +00:00

20 lines
440 B
Python

"""Workbook-local Triton notes for FlashAttention forward."""
def notes() -> str:
    """Return the student TODO checklist for the Triton FlashAttention forward kernel.

    The checklist is a numbered, newline-delimited list that starts and ends
    with a newline, so it prints as a standalone block.

    Returns:
        The checklist text as a single string.
    """
    # NOTE: the string body is deliberately flush-left so the returned text
    # carries no source indentation. Steps 7 and 8 spell out the online-softmax
    # rescale and final normalization; omitting either produces wrong output.
    return """
TODO(student):
1. Assign one program instance to one query block for one batch/head.
2. Load a Q block.
3. Iterate over K/V blocks.
4. Compute score blocks.
5. Apply optional causal masking.
6. Update running max and running sum.
7. Rescale the output accumulator by exp(old_max - new_max), then
   accumulate P @ V, where P = exp(scores - new_max).
8. Divide the accumulator by the running sum, then store the final output.
"""
if __name__ == "__main__":
    # Run as a script: dump the checklist to stdout.
    checklist = notes()
    print(checklist)