perf: cache softmax weights in shared mem (drop hd× redundant expf)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:24:56 +08:00
parent 9b05f4f93f
commit 4d7b69f8d4

View File

@@ -127,18 +127,23 @@ __global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
} }
__syncthreads(); __syncthreads();
// Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)). // Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per
// Each thread owns a strided subset of hd; loops over the tile columns. // column (instead of recomputing expf inside the per-dim V loop, which
// would cost hd× the transcendentals). Sum them for l.
float lsum = 0.0f; float lsum = 0.0f;
for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new); for (int c = t; c < tile; c += nthreads) {
float p = expf(s_tile[c] - m_new);
s_tile[c] = p;
lsum += p;
}
lsum = fa_block_sum(lsum); lsum = fa_block_sum(lsum);
// Rescale old accumulator + add this tile's p·V (p cached in s_tile).
// Each thread owns a strided subset of hd; loops over the tile columns.
for (int d = t; d < hd; d += nthreads) { for (int d = t; d < hd; d += nthreads) {
float a = acc[d] * corr; float a = acc[d] * corr;
for (int c = 0; c < tile; ++c) { for (int c = 0; c < tile; ++c)
float p = expf(s_tile[c] - m_new); a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
a += p * vb[(size_t)(j0 + c) * hd + d];
}
acc[d] = a; acc[d] = a;
} }
if (t == 0) { if (t == 0) {