diff --git a/csrc/ops/flash_attention.cu b/csrc/ops/flash_attention.cu index b5d986e..31e04b8 100644 --- a/csrc/ops/flash_attention.cu +++ b/csrc/ops/flash_attention.cu @@ -127,18 +127,23 @@ __global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V, } __syncthreads(); - // Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)). - // Each thread owns a strided subset of hd; loops over the tile columns. + // Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per + // column (instead of recomputing expf inside the per-dim V loop, which + // would cost hd× the transcendentals). Sum them for l. float lsum = 0.0f; - for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new); + for (int c = t; c < tile; c += nthreads) { + float p = expf(s_tile[c] - m_new); + s_tile[c] = p; + lsum += p; + } lsum = fa_block_sum(lsum); + // Rescale old accumulator + add this tile's p·V (p cached in s_tile). + // Each thread owns a strided subset of hd; loops over the tile columns. for (int d = t; d < hd; d += nthreads) { float a = acc[d] * corr; - for (int c = 0; c < tile; ++c) { - float p = expf(s_tile[c] - m_new); - a += p * vb[(size_t)(j0 + c) * hd + d]; - } + for (int c = 0; c < tile; ++c) + a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d]; acc[d] = a; } if (t == 0) {