xtrain/crates/xtrain-tensor
Gahow Wang 326a6fadfe cuda: fused flash-attention kernel (fwd + flash-style bwd)
csrc/ops/flash_attention.cu: add a single fused forward kernel. It runs
one block per query row, streams K/V in tiles of 32, and uses an online
softmax (running max/sum plus a rescaled V accumulator) with the causal
mask inlined, so the [bh,S,S] scores are never materialized. It writes
out[bh,S,hd] and the per-row logsumexp L (O(N), saved for backward).

The flash-style backward recomputes the scores from Q/K/V + L and
collapses the softmax Jacobian with D[i] = Σ_d dO[i,d]·O[i,d]. dQ is
owned per row; dK/dV are accumulated with atomicAdd across rows.

Tensor::flash_attention / flash_attention_backward wrap the kernels
(bf16 inputs upcast Q/K/V to f32 for the kernel, the same fp32-softmax
policy as the composed path).
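
For readers who want to see the online-softmax recurrence outside CUDA, here is a minimal CPU sketch in Rust. It is not the code in this commit: the function names `naive_attention` and `flash_attention_ref`, the row-major [S,hd] layout for a single (batch, head) slice, and the 1/sqrt(hd) score scale are all assumptions made for illustration. It only mirrors the idea described above: walk K/V in tiles, keep a running max and sum, rescale the V accumulator when the max changes, and emit the per-row logsumexp.

```rust
/// Naive reference (hypothetical helper): materializes one full score row,
/// applies softmax, then weights V. q, k, v are row-major [s, hd].
fn naive_attention(q: &[f32], k: &[f32], v: &[f32], s: usize, hd: usize, causal: bool) -> Vec<f32> {
    let scale = 1.0 / (hd as f32).sqrt(); // assumed scaling, not taken from the commit
    let mut out = vec![0.0f32; s * hd];
    for i in 0..s {
        let n = if causal { i + 1 } else { s }; // causal: keys 0..=i only
        let scores: Vec<f32> = (0..n)
            .map(|j| scale * (0..hd).map(|d| q[i * hd + d] * k[j * hd + d]).sum::<f32>())
            .collect();
        let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let sum: f32 = scores.iter().map(|x| (x - m).exp()).sum();
        for (j, x) in scores.iter().enumerate() {
            let p = (x - m).exp() / sum;
            for d in 0..hd {
                out[i * hd + d] += p * v[j * hd + d];
            }
        }
    }
    out
}

/// Tiled online-softmax forward (hypothetical CPU sketch).
/// Returns (out [s, hd], lse [s]) without ever holding a full score row.
fn flash_attention_ref(
    q: &[f32], k: &[f32], v: &[f32],
    s: usize, hd: usize, tile: usize, causal: bool,
) -> (Vec<f32>, Vec<f32>) {
    assert!(tile > 0, "tile size must be positive");
    let scale = 1.0 / (hd as f32).sqrt();
    let mut out = vec![0.0f32; s * hd];
    let mut lse = vec![0.0f32; s];
    for i in 0..s {
        let n = if causal { i + 1 } else { s };
        let mut m = f32::NEG_INFINITY; // running max of the scores seen so far
        let mut l = 0.0f32; // running sum of exp(score - m)
        let mut acc = vec![0.0f32; hd]; // unnormalized V accumulator
        let mut start = 0;
        while start < n {
            let end = (start + tile).min(n);
            let scores: Vec<f32> = (start..end)
                .map(|j| scale * (0..hd).map(|d| q[i * hd + d] * k[j * hd + d]).sum::<f32>())
                .collect();
            let tile_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let m_new = m.max(tile_max);
            // Rescale previous partial results to the new max.
            // On the first tile this is exp(-inf) = 0, and l/acc are already 0.
            let rescale = (m - m_new).exp();
            l *= rescale;
            for a in acc.iter_mut() {
                *a *= rescale;
            }
            for (t, x) in scores.iter().enumerate() {
                let p = (x - m_new).exp();
                l += p;
                let j = start + t;
                for d in 0..hd {
                    acc[d] += p * v[j * hd + d];
                }
            }
            m = m_new;
            start = end;
        }
        for d in 0..hd {
            out[i * hd + d] = acc[d] / l;
        }
        lse[i] = m + l.ln(); // per-row logsumexp, the quantity saved for backward
    }
    (out, lse)
}
```

Calling `flash_attention_ref(&q, &k, &v, s, hd, 32, true)` and comparing against `naive_attention` on the same inputs should agree to within f32 rounding, which is a cheap sanity check for any fused kernel of this shape.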

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-17 23:10:25 +08:00