perf: cache softmax weights in shared mem (drop hd× redundant expf)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -127,18 +127,23 @@ __global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)).
|
||||
// Each thread owns a strided subset of hd; loops over the tile columns.
|
||||
// Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per
|
||||
// column (instead of recomputing expf inside the per-dim V loop, which
|
||||
// would cost hd× the transcendentals). Sum them for l.
|
||||
float lsum = 0.0f;
|
||||
for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new);
|
||||
for (int c = t; c < tile; c += nthreads) {
|
||||
float p = expf(s_tile[c] - m_new);
|
||||
s_tile[c] = p;
|
||||
lsum += p;
|
||||
}
|
||||
lsum = fa_block_sum(lsum);
|
||||
|
||||
// Rescale old accumulator + add this tile's p·V (p cached in s_tile).
|
||||
// Each thread owns a strided subset of hd; loops over the tile columns.
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
float a = acc[d] * corr;
|
||||
for (int c = 0; c < tile; ++c) {
|
||||
float p = expf(s_tile[c] - m_new);
|
||||
a += p * vb[(size_t)(j0 + c) * hd + d];
|
||||
}
|
||||
for (int c = 0; c < tile; ++c)
|
||||
a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
|
||||
acc[d] = a;
|
||||
}
|
||||
if (t == 0) {
|
||||
|
||||
Reference in New Issue
Block a user