// Workbook-local CUDA sketch for row softmax.
//
// Reflection prompt:
// Softmax is usually memory-bandwidth-bound: the arithmetic per element is cheap, but each row may be read from global memory several times and written once.
// Keep track of how many global-memory passes your implementation needs.
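// One way to make the pass count concrete is a small estimate like the one
// sketched here. row_traffic_bytes is a hypothetical helper (not part of the
// workbook) that turns the number of read and write passes over a row of
// floats into bytes of global-memory traffic. It ignores caches, so treat the
// result as a rough estimate rather than a measurement.

```cuda
#include <cstddef>

// Hypothetical helper: bytes moved to and from global memory for one row of
// n floats that is read `reads` times and written `writes` times.
constexpr std::size_t row_traffic_bytes(std::size_t n, int reads, int writes) {
    return n * sizeof(float) * static_cast<std::size_t>(reads + writes);
}

// Separate max, sum, and normalize passes: 3 reads + 1 write per element.
static_assert(row_traffic_bytes(1024, 3, 1) == 16384, "3 reads + 1 write");

// Row kept in registers or shared memory after one load: 1 read + 1 write.
static_assert(row_traffic_bytes(1024, 1, 1) == 8192, "1 read + 1 write");
```

// Under these assumptions, removing the two extra read passes halves the
// estimated traffic for a 1024-element row.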
//
// TODO(student):
// 1. Assign one block or block tile to a row.
// 2. Compute the row max.
// 3. Compute the sum of exp(x - row_max).
// 4. Normalize the row.
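//
// The four steps above could be sketched as follows. This is one possible
// baseline, not the intended solution: it assumes float data in row-major
// layout, one block per row, and a block size of kThreads (a power of two).
// The names row_softmax, block_reduce, MaxOp, SumOp, softmax_rows, and
// kThreads are all hypothetical. The kernel makes three read passes and one
// write pass over each row, and the host wrapper skips CUDA error checking
// for brevity.

```cuda
#include <cfloat>
#include <cstddef>
#include <vector>
#include <cuda_runtime.h>

constexpr int kThreads = 256;  // assumed block size; the reduction needs a power of two

struct MaxOp { __device__ float operator()(float a, float b) const { return fmaxf(a, b); } };
struct SumOp { __device__ float operator()(float a, float b) const { return a + b; } };

// Tree reduction over the block in shared memory; every thread gets the result.
template <typename Op>
__device__ float block_reduce(float v, float* scratch, Op op) {
    scratch[threadIdx.x] = v;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s)
            scratch[threadIdx.x] = op(scratch[threadIdx.x], scratch[threadIdx.x + s]);
        __syncthreads();
    }
    const float result = scratch[0];
    __syncthreads();  // scratch is reused by the next reduction
    return result;
}

// Step 1: one block per row (blockIdx.x is the row index).
__global__ void row_softmax(const float* in, float* out, int cols) {
    __shared__ float scratch[kThreads];
    const float* x = in + static_cast<std::size_t>(blockIdx.x) * cols;
    float* y = out + static_cast<std::size_t>(blockIdx.x) * cols;

    // Step 2: row max (first read of the row).
    float row_max = -FLT_MAX;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        row_max = fmaxf(row_max, x[c]);
    row_max = block_reduce(row_max, scratch, MaxOp{});

    // Step 3: sum of exp(x - row_max) (second read of the row).
    float sum = 0.0f;
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        sum += expf(x[c] - row_max);
    sum = block_reduce(sum, scratch, SumOp{});

    // Step 4: normalize (third read of the row, plus the only write).
    for (int c = threadIdx.x; c < cols; c += blockDim.x)
        y[c] = expf(x[c] - row_max) / sum;
}

// Hypothetical host wrapper: copies a rows x cols matrix to the device,
// launches one block per row, and returns the result. No error checking.
std::vector<float> softmax_rows(const std::vector<float>& host_in, int rows, int cols) {
    const std::size_t bytes = host_in.size() * sizeof(float);
    float* d_in = nullptr;
    float* d_out = nullptr;
    cudaMalloc(&d_in, bytes);
    cudaMalloc(&d_out, bytes);
    cudaMemcpy(d_in, host_in.data(), bytes, cudaMemcpyHostToDevice);
    row_softmax<<<rows, kThreads>>>(d_in, d_out, cols);
    std::vector<float> host_out(host_in.size());
    cudaMemcpy(host_out.data(), d_out, bytes, cudaMemcpyDeviceToHost);  // waits for the kernel
    cudaFree(d_in);
    cudaFree(d_out);
    return host_out;
}
```

// Calling softmax_rows(data, rows, cols) on a row-major matrix should return
// rows that each sum to approximately 1. A natural follow-up is to cut the
// read passes, for example by caching the row in shared memory or registers
// when it fits.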