perf: cache softmax weights in shared mem (drop hd× redundant expf)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:24:56 +08:00
parent 9b05f4f93f
commit 4d7b69f8d4

View File

@@ -127,18 +127,23 @@ __global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
}
__syncthreads();
// Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)).
// Each thread owns a strided subset of hd; loops over the tile columns.
// Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per
// column (instead of recomputing expf inside the per-dim V loop, which
// would cost hd× the transcendentals). Sum them for l.
float lsum = 0.0f;
for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new);
for (int c = t; c < tile; c += nthreads) {
float p = expf(s_tile[c] - m_new);
s_tile[c] = p;
lsum += p;
}
lsum = fa_block_sum(lsum);
// Rescale old accumulator + add this tile's p·V (p cached in s_tile).
// Each thread owns a strided subset of hd; loops over the tile columns.
for (int d = t; d < hd; d += nthreads) {
float a = acc[d] * corr;
for (int c = 0; c < tile; ++c) {
float p = expf(s_tile[c] - m_new);
a += p * vb[(size_t)(j0 + c) * hd + d];
}
for (int c = 0; c < tile; ++c)
a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
acc[d] = a;
}
if (t == 0) {