perf: cache softmax weights in shared mem (drop hd× redundant expf)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -127,18 +127,23 @@ __global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
|
|||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)).
|
// Overwrite s_tile with the softmax weights p = exp(s - m_new) ONCE per
|
||||||
// Each thread owns a strided subset of hd; loops over the tile columns.
|
// column (instead of recomputing expf inside the per-dim V loop, which
|
||||||
|
// would cost hd× the transcendentals). Sum them for l.
|
||||||
float lsum = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new);
|
for (int c = t; c < tile; c += nthreads) {
|
||||||
|
float p = expf(s_tile[c] - m_new);
|
||||||
|
s_tile[c] = p;
|
||||||
|
lsum += p;
|
||||||
|
}
|
||||||
lsum = fa_block_sum(lsum);
|
lsum = fa_block_sum(lsum);
|
||||||
|
|
||||||
|
// Rescale old accumulator + add this tile's p·V (p cached in s_tile).
|
||||||
|
// Each thread owns a strided subset of hd; loops over the tile columns.
|
||||||
for (int d = t; d < hd; d += nthreads) {
|
for (int d = t; d < hd; d += nthreads) {
|
||||||
float a = acc[d] * corr;
|
float a = acc[d] * corr;
|
||||||
for (int c = 0; c < tile; ++c) {
|
for (int c = 0; c < tile; ++c)
|
||||||
float p = expf(s_tile[c] - m_new);
|
a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
|
||||||
a += p * vb[(size_t)(j0 + c) * hd + d];
|
|
||||||
}
|
|
||||||
acc[d] = a;
|
acc[d] = a;
|
||||||
}
|
}
|
||||||
if (t == 0) {
|
if (t == 0) {
|
||||||
|
|||||||
Reference in New Issue
Block a user