20 lines
440 B
Python
20 lines
440 B
Python
"""Workbook-local Triton notes for FlashAttention forward."""
|
|
|
|
|
|
def notes() -> str:
    """Return the student TODO checklist for the FlashAttention forward pass.

    The checklist is a fixed, numbered outline of eight steps, from
    assigning a program instance to a query block through storing the
    final output.

    Returns:
        The checklist as a multi-line string. It begins with a newline
        (the literal opens on the line after the triple quote) and ends
        with a newline, so printing it yields a blank line above and below.
    """
    # The literal's contents are deliberately flush-left: any indentation
    # here would become part of the returned text.
    return """
TODO(student):
1. Assign one program instance to one query block for one batch/head.
2. Load a Q block.
3. Iterate over K/V blocks.
4. Compute score blocks.
5. Apply optional causal masking.
6. Update running max and running sum.
7. Accumulate the output block.
8. Store the final output.
"""
|
|
|
|
|
|
# Script entry point: print the checklist when this file is run directly;
# importing the module produces no output.
if __name__ == "__main__":
    print(notes())
|