"""Workbook-local Triton notes for row softmax."""


def notes() -> str:
    """Return the worksheet text for the row-softmax exercise.

    The text is a fixed, multi-line string: a ``TODO(student)`` checklist of
    five numbered steps followed by a ``Reflection`` section with two
    questions. It begins with a newline (the literal opens on its own line)
    and ends with a newline.

    Returns:
        The worksheet text, unindented, exactly as it should be printed.
    """
    # The literal is deliberately flush-left: any indentation placed here
    # would become part of the returned text.
    return """
TODO(student):
1. Decide what one program instance owns: a whole row or a row tile.
2. Load a row with masking.
3. Compute row_max = max(x).
4. Compute exp(x - row_max), then the row sum.
5. Normalize and store.

Reflection:
- Why does numerical stability matter here more than in vector add?
- Where does extra memory traffic appear in a naive multi-kernel approach?
"""
# Script entry point: when this module is run directly (not imported),
# print the worksheet text to stdout.
if __name__ == "__main__":
    print(notes())