// Workbook-local CUDA sketch for FlashAttention forward.
//
// Map this against the Triton sketch:
// - Triton program_id for query tile -> CUDA block ownership
// - Triton block pointer loads -> CUDA cooperative global-to-shared loads
// - Triton masks -> explicit edge and causal checks
// - Triton implicit block math -> thread/block index arithmetic
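// The block-ownership mapping above can be made concrete as a
// grid-indexing sketch. BLOCK_M, NUM_HEADS, and flash_fwd_sketch are
// hypothetical names chosen for illustration, not part of this workbook:

```cuda
#include <cuda_runtime.h>

// Assumed launch-time constants; the workbook does not fix these values.
constexpr int BLOCK_M = 64;    // query rows per tile
constexpr int NUM_HEADS = 8;   // heads per batch element

__global__ void flash_fwd_sketch(const float* Q, const float* K,
                                 const float* V, float* O) {
    // Triton tl.program_id axes become CUDA grid coordinates:
    int q_tile = blockIdx.x;               // query tile this block owns
    int head   = blockIdx.y % NUM_HEADS;   // head index within the batch element
    int batch  = blockIdx.y / NUM_HEADS;   // batch element index
    int q_row0 = q_tile * BLOCK_M;         // first query row of the owned tile
    // Kernel body elided; the TODO steps below fill it in.
    (void)Q; (void)K; (void)V; (void)O;
    (void)head; (void)batch; (void)q_row0;
}

// Matching launch shape: grid.x covers query tiles, grid.y covers batch*head.
// dim3 grid((seq_len + BLOCK_M - 1) / BLOCK_M, batch_size * NUM_HEADS);
```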
// TODO(student):
// 1. Assign a block to one batch/head/query tile.
// 2. Load a Q tile and loop over K/V tiles.
// 3. Compute score tiles and causal masking.
// 4. Update online softmax state.
// 5. Accumulate the output tile and normalize by the final softmax denominator.
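// The five steps can be checked against a deliberately naive reference
// kernel: one thread per query row, no shared-memory tiling. It is a
// correctness baseline under assumed conditions (row-major
// [seq_len, head_dim] Q/K/V for one batch/head slice, head_dim <= 128,
// causal mask), not the tiled kernel the TODOs ask you to write:

```cuda
#include <cuda_runtime.h>
#include <math.h>

__global__ void flash_fwd_naive(const float* Q, const float* K,
                                const float* V, float* O,
                                int seq_len, int head_dim) {
    // Step 1 (simplified): one thread owns one query row instead of a tile.
    int q = blockIdx.x * blockDim.x + threadIdx.x;
    if (q >= seq_len) return;

    float scale = rsqrtf((float)head_dim);  // 1/sqrt(d) softmax scaling
    float m = -INFINITY;                    // running row maximum
    float l = 0.f;                          // running softmax denominator
    float acc[128];                         // unnormalized output accumulator
    for (int d = 0; d < head_dim; ++d) acc[d] = 0.f;

    // Step 2: loop over keys (a tiled kernel would stage K/V tiles in
    // shared memory here and loop tile by tile).
    for (int k = 0; k <= q; ++k) {          // Step 3: causal mask, k <= q
        float s = 0.f;
        for (int d = 0; d < head_dim; ++d)
            s += Q[q * head_dim + d] * K[k * head_dim + d];
        s *= scale;

        // Step 4: online softmax update; rescale old state by exp(m - m_new)
        // so earlier contributions stay consistent with the new maximum.
        float m_new = fmaxf(m, s);
        float corr  = expf(m - m_new);
        float p     = expf(s - m_new);
        l = l * corr + p;
        for (int d = 0; d < head_dim; ++d)  // Step 5: accumulate output
            acc[d] = acc[d] * corr + p * V[k * head_dim + d];
        m = m_new;
    }
    // Final normalization by the softmax denominator, then write out.
    for (int d = 0; d < head_dim; ++d)
        O[q * head_dim + d] = acc[d] / l;
}
```

// Comparing the tiled kernel's output row-by-row against this baseline is
// a practical way to validate each TODO step as you implement it.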