"""Workbook-local Triton notes for FlashAttention forward.""" def notes() -> str: return """ TODO(student): 1. Assign one program instance to one query block for one batch/head. 2. Load a Q block. 3. Iterate over K/V blocks. 4. Compute score blocks. 5. Apply optional causal masking. 6. Update running max and running sum. 7. Accumulate the output block. 8. Store the final output. """ if __name__ == "__main__": print(notes())