diff --git a/csrc/ops/flash_attention.cu b/csrc/ops/flash_attention.cu index 31e04b8..0e01cd6 100644 --- a/csrc/ops/flash_attention.cu +++ b/csrc/ops/flash_attention.cu @@ -221,8 +221,9 @@ __global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V, int valid = i + 1; for (int j0 = 0; j0 < valid; j0 += FA_TILE) { int tile = min(FA_TILE, valid - j0); - // Per-tile ds[c] (one per column), computed by the thread that owns column c. + // Phase 1: per-column ds[c] and p[c] (the column owner does the dots). __shared__ float s_ds[FA_TILE]; + __shared__ float s_p[FA_TILE]; for (int c = t; c < tile; c += nthreads) { const float* kj = kb + (size_t)(j0 + c) * hd; const float* vj = vb + (size_t)(j0 + c) * hd; @@ -232,17 +233,19 @@ __global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V, dpdot += sdo[d] * vj[d]; } float p = expf(sdot * scale - Li); - float ds = p * (dpdot - Di) * scale; - s_ds[c] = ds; - // dV_j += p · dOᵢ ; dK_j += ds · Qᵢ (accumulated across rows → atomic) - float* dvj = dvb + (size_t)(j0 + c) * hd; - float* dkj = dkb + (size_t)(j0 + c) * hd; - for (int d = 0; d < hd; ++d) { - atomicAdd(&dvj[d], p * sdo[d]); - atomicAdd(&dkj[d], ds * sq[d]); - } + s_p[c] = p; + s_ds[c] = p * (dpdot - Di) * scale; } __syncthreads(); + // Phase 2: dV_j += p·dOᵢ ; dK_j += ds·Qᵢ (accumulated across rows → atomic). + // Spread the tile×hd atomics over ALL threads (was serial in the column + // owner) — flatten (c,d) so every thread issues a balanced share. + for (int idx = t; idx < tile * hd; idx += nthreads) { + int c = idx / hd, d = idx % hd; + size_t off = (size_t)(j0 + c) * hd + d; + atomicAdd(&dvb[off], s_p[c] * sdo[d]); + atomicAdd(&dkb[off], s_ds[c] * sq[d]); + } // dQᵢ += Σ_c ds[c] · K_{j0+c} (this row owns dQ — no atomic). for (int d = t; d < hd; d += nthreads) { float a = 0.0f;