perf: spread flash bwd dK/dV atomics across all threads

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:27:33 +08:00
parent 4d7b69f8d4
commit d217f4fbd3

View File

@@ -221,8 +221,9 @@ __global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
int valid = i + 1;
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
int tile = min(FA_TILE, valid - j0);
// Per-tile ds[c] (one per column), computed by the thread that owns column c.
// Phase 1: per-column ds[c] and p[c] (the column owner does the dots).
__shared__ float s_ds[FA_TILE];
__shared__ float s_p[FA_TILE];
for (int c = t; c < tile; c += nthreads) {
const float* kj = kb + (size_t)(j0 + c) * hd;
const float* vj = vb + (size_t)(j0 + c) * hd;
@@ -232,17 +233,19 @@ __global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
dpdot += sdo[d] * vj[d];
}
float p = expf(sdot * scale - Li);
float ds = p * (dpdot - Di) * scale;
s_ds[c] = ds;
// dV_j += p · dOᵢ ; dK_j += ds · Qᵢ (accumulated across rows → atomic)
float* dvj = dvb + (size_t)(j0 + c) * hd;
float* dkj = dkb + (size_t)(j0 + c) * hd;
for (int d = 0; d < hd; ++d) {
atomicAdd(&dvj[d], p * sdo[d]);
atomicAdd(&dkj[d], ds * sq[d]);
}
s_p[c] = p;
s_ds[c] = p * (dpdot - Di) * scale;
}
__syncthreads();
// Phase 2: dV_j += p·dOᵢ ; dK_j += ds·Qᵢ (accumulated across rows → atomic).
// Spread the tile×hd atomics over ALL threads (was serial in the column
// owner) — flatten (c,d) so every thread issues a balanced share.
for (int idx = t; idx < tile * hd; idx += nthreads) {
int c = idx / hd, d = idx % hd;
size_t off = (size_t)(j0 + c) * hd + d;
atomicAdd(&dvb[off], s_p[c] * sdo[d]);
atomicAdd(&dkb[off], s_ds[c] * sq[d]);
}
// dQᵢ += Σ_c ds[c] · K_{j0+c} (this row owns dQ — no atomic).
for (int d = t; d < hd; d += nthreads) {
float a = 0.0f;