// xtrain/csrc/ops/flash_attention.cu (revision of 2026-06-17 23:27:33 +08:00)
//
// Note: the comments in this file intentionally use non-ASCII characters
// (e.g. ᵀ, ², ·, Σ, ᵢ, ⱼ, →, ≤, —).
// Hand-written fused flash-attention (Phase T14).
//
// The T10 composed SDPA path is 3 launches that MATERIALIZE the [bh,S,S] score
// matrix: cublasSgemmStridedBatched (Q·Kᵀ) → causal-softmax kernel (writes the
// whole probs) → cublasSgemmStridedBatched (P·V), and backward caches that whole
// probs. flash-attention NEVER materializes N×N: a single fused kernel streams
// over KV tiles with an ONLINE softmax (running max/sum + rescaled V accumulator),
// so peak attention activation drops from O(S²) to O(S·hd) (= the output itself).
//
// Layout (matches the T10 op): Q/K/V/out are [bh, S, hd] row-major contiguous,
// bh = batch·n_heads. The query's position within its sequence is the row index
// within its [S,hd] block (so the flat row's qpos = (row % S) is automatic here —
// we index per (bh, row)). CAUSAL: a query at position i attends to keys j ≤ i.
// `scale` (= 1/sqrt(hd)) is folded into the logits before the max/exp.
//
// All F32, contiguous. (bf16 callers upcast Q/K/V → f32 on the Rust side and
// downcast the f32 out, mirroring the composed path's fp32 softmax policy, so the
// kernel only ever sees fp32.) Reduction helpers are inlined (self-contained file,
// matching the csrc/ layout).
//
// Parallelisation: grid = bh*S, one block per query row; blockDim.x threads
// cooperate. Forward keeps m (running max), l (running sum), acc[hd] (rescaled
// V accumulator) in shared memory, streams KV in tiles of BK. Backward recomputes
// scores from Q/K/V + the saved logsumexp L[bh,S] (NO cached probs), uses
// D[i]=Σ dOᵢ·Oᵢ to collapse the softmax Jacobian, and atomicAdds dK/dV (which are
// accumulated across query rows).
#include <math.h>
extern "C" {
// Warp-level sum using a shuffle-down tree (offsets 16, 8, 4, 2, 1).
// Only lane 0 ends up holding the sum over all 32 lanes; the other lanes
// return intermediate values. The shuffles use a full-warp mask, so all 32
// lanes of the warp must exist and call this together. In a partial trailing
// warp (blockDim.x not a multiple of 32), lanes would read from non-existent
// lanes and get undefined values.
__device__ __forceinline__ float fa_warp_sum(float v) {
    float total = v;
#pragma unroll
    for (unsigned delta = 16u; delta != 0u; delta /= 2u)
        total += __shfl_down_sync(0xffffffffu, total, delta);
    return total;
}
// Warp-level maximum using a shuffle-down tree (offsets 16, 8, 4, 2, 1).
// As with fa_warp_sum, only lane 0's return value covers all 32 lanes. The
// full-warp mask requires every lane of the warp to be present and to call
// this together.
__device__ __forceinline__ float fa_warp_max(float v) {
    float best = v;
#pragma unroll
    for (unsigned delta = 16u; delta != 0u; delta /= 2u) {
        const float other = __shfl_down_sync(0xffffffffu, best, delta);
        best = fmaxf(best, other);
    }
    return best;
}
// Block-wide sum; every thread of the block receives the same total.
//
// Two-level reduction:
//   1. Each warp reduces its values with fa_warp_sum.
//   2. Lane 0 of each warp parks its partial in shared memory.
//   3. Warp 0 reduces those partials; thread 0 publishes the result.
//
// Contains two __syncthreads(), so it must be called by ALL threads of the
// block, outside divergent control flow. The 32-slot scratch covers the
// 1024-thread (32-warp) block maximum. blockDim.x must be a multiple of 32
// because fa_warp_sum shuffles with a full-warp mask.
__device__ __forceinline__ float fa_block_sum(float v) {
    __shared__ float warp_partials[32];
    __shared__ float block_total;
    const unsigned tid = threadIdx.x;
    const unsigned lane_id = tid & 31u;
    const unsigned warp_id = tid >> 5;
    const unsigned warp_count = (blockDim.x + 31u) >> 5;

    const float warp_total = fa_warp_sum(v);
    if (lane_id == 0u) warp_partials[warp_id] = warp_total;
    __syncthreads();

    // Only the first warp_count threads (all in warp 0) pick up a partial;
    // everyone else contributes the additive identity.
    float x = 0.0f;
    if (tid < warp_count) x = warp_partials[tid];
    if (warp_id == 0u) x = fa_warp_sum(x);
    if (tid == 0u) block_total = x;
    __syncthreads();
    return block_total;
}
// Block-wide maximum; every thread of the block receives the same value.
//
// Same two-level scheme as fa_block_sum: per-warp fa_warp_max, lane-0
// partials in shared memory, then a final fa_warp_max in warp 0. Threads
// without a partial feed -INFINITY, the identity for max.
//
// Contains two __syncthreads(), so it must be called by ALL threads of the
// block. blockDim.x must be a multiple of 32 because the warp shuffles use a
// full-warp mask.
__device__ __forceinline__ float fa_block_max(float v) {
    __shared__ float warp_partials[32];
    __shared__ float block_result;
    const unsigned tid = threadIdx.x;
    const unsigned lane_id = tid & 31u;
    const unsigned warp_id = tid >> 5;
    const unsigned warp_count = (blockDim.x + 31u) >> 5;

    const float warp_best = fa_warp_max(v);
    if (lane_id == 0u) warp_partials[warp_id] = warp_best;
    __syncthreads();

    float x = -INFINITY;
    if (tid < warp_count) x = warp_partials[tid];
    if (warp_id == 0u) x = fa_warp_max(x);
    if (tid == 0u) block_result = x;
    __syncthreads();
    return block_result;
}
// KV tile width: number of consecutive key/value rows processed per step of the
// streaming loops. It also sizes the static per-tile shared buffers (s_tile in
// the forward kernel, s_ds / s_p in the backward kernel).
#define FA_TILE 32
// Causal flash-attention forward: one block per flattened query row.
//
// Launch contract (as used by launch_flash_attention_fwd_f32):
//   grid.x = bh*seq. Block `row` handles query position i = row % seq of
//            (batch, head) block b = row / seq.
//   block  = every per-dimension and per-column loop is strided by
//            blockDim.x. However, fa_block_max / fa_block_sum shuffle with a
//            full-warp mask, so blockDim.x should be a whole number of warps.
//            All threads must reach every barrier; there are no early exits.
//   dynamic shared memory = 2*hd floats: sq[hd] (the Q row) + acc[hd].
//
// Streams keys 0..i in tiles of FA_TILE with an online softmax. For each tile:
//   - the scaled logits are reduced to a tile max;
//   - the running state (m, l, acc) is rescaled by corr = exp(m_old - m_new);
//   - the tile's Σ p·V is added.
// The [S,S] score row is never materialised. Outputs: O[row, :] = acc / l and
// L[row] = m + log(l), the row logsumexp consumed by the backward kernel.
__global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
                                 float* O, float* L, int seq, int hd, float scale) {
    int row = blockIdx.x; // flat query row over bh*S
    int b = row / seq;    // which (batch, head) block
    int i = row % seq;    // query position within the sequence (causal limit)
    int t = threadIdx.x;
    int nthreads = blockDim.x;
    const float* q = Q + (size_t)row * hd;
    const float* kb = K + (size_t)b * seq * hd; // this block's keys   [seq, hd]
    const float* vb = V + (size_t)b * seq * hd; // this block's values [seq, hd]
    // Dynamic shared memory: the Q row (reused by every tile) and the
    // un-normalised output accumulator, hd floats each.
    extern __shared__ float smem[];
    float* sq = smem;       // [hd]
    float* acc = smem + hd; // [hd]
    for (int d = t; d < hd; d += nthreads) {
        sq[d] = q[d];
        acc[d] = 0.0f;
    }
    // Running softmax state (max m, denominator l). Written only by thread 0;
    // the other threads read it only after a barrier.
    __shared__ float m_run, l_run;
    if (t == 0) { m_run = -INFINITY; l_run = 0.0f; }
    __syncthreads();
    int valid = i + 1; // causal: attend to keys [0, i]; always at least one tile
    for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
        int tile = min(FA_TILE, valid - j0); // 1..FA_TILE columns this step
        // Each thread computes whole logits for a strided subset of the tile's
        // columns: s = scale * (q · k_j). The per-thread dot loop over hd avoids
        // a block reduction per column.
        __shared__ float s_tile[FA_TILE];
        for (int c = t; c < tile; c += nthreads) {
            const float* kj = kb + (size_t)(j0 + c) * hd;
            float dot = 0.0f;
            for (int d = 0; d < hd; ++d) dot += sq[d] * kj[d];
            s_tile[c] = dot * scale;
        }
        __syncthreads();
        // Tile max, then online rescale of (m, l, acc). Threads without a
        // column contribute -INFINITY.
        float tmax = -INFINITY;
        for (int c = t; c < tile; c += nthreads) tmax = fmaxf(tmax, s_tile[c]);
        tmax = fa_block_max(tmax);
        // Thread 0 merges the tile max into the running max. On the first tile
        // m_run is -inf and acc/l are still zero, so corr is forced to 0. This
        // also avoids exp(-inf - -inf) = NaN if every logit so far is -inf.
        __shared__ float m_new, corr;
        if (t == 0) {
            float mn = fmaxf(m_run, tmax);
            corr = (m_run == -INFINITY) ? 0.0f : expf(m_run - mn); // rescale old state
            m_new = mn;
        }
        __syncthreads();
        // Overwrite s_tile with the softmax weights p = exp(s - m_new), ONCE per
        // column, instead of recomputing expf inside the per-dim V loop (which
        // would cost hd× the transcendentals). Each thread rewrites only the
        // columns it computed above. Sum them for l.
        float lsum = 0.0f;
        for (int c = t; c < tile; c += nthreads) {
            float p = expf(s_tile[c] - m_new);
            s_tile[c] = p;
            lsum += p;
        }
        // The barriers inside fa_block_sum also make every p in s_tile visible
        // to the whole block before the V loop reads them.
        lsum = fa_block_sum(lsum);
        // Rescale the old accumulator and add this tile's p·V (p cached in
        // s_tile). Each thread owns a strided subset of hd and loops over the
        // tile's columns.
        for (int d = t; d < hd; d += nthreads) {
            float a = acc[d] * corr;
            for (int c = 0; c < tile; ++c)
                a += s_tile[c] * vb[(size_t)(j0 + c) * hd + d];
            acc[d] = a;
        }
        if (t == 0) {
            l_run = l_run * corr + lsum;
            m_run = m_new;
        }
        // All reads of s_tile / corr / m_new for this tile must finish before
        // the next tile overwrites them.
        __syncthreads();
    }
    // out = acc / l ; L = m + log(l) (logsumexp, saved for backward).
    // For finite inputs, l_run >= 1 here: the max-logit key contributes exp(0).
    float inv = 1.0f / l_run;
    for (int d = t; d < hd; d += nthreads) O[(size_t)row * hd + d] = acc[d] * inv;
    if (t == 0) L[row] = m_run + logf(l_run);
}
// Host launcher for flash_attn_fwd_k.
//   q/k/v/o : [bh, seq, hd] contiguous f32 device buffers.
//   l       : receives the per-row logsumexp ([bh, seq] rows).
//   s       : the cudaStream_t (passed as void*) the kernel is enqueued on.
// The launch is asynchronous and errors are not reported here; callers
// should check cudaGetLastError() afterwards. An empty problem (bh or
// seq <= 0) is a no-op instead of an invalid zero-sized grid launch.
void launch_flash_attention_fwd_f32(const float* q, const float* k, const float* v,
                                    float* o, float* l, int bh, int seq, int hd,
                                    float scale, void* s) {
    if (bh <= 0 || seq <= 0 || hd < 0) return;
    // One thread per head-dim element (capped at 1024), rounded UP to a whole
    // number of warps. fa_block_max / fa_block_sum shuffle with a full-warp
    // mask, so a partial trailing warp (hd = 40, 80, ...) would read undefined
    // values from non-existent lanes. Surplus threads skip the strided loops
    // and feed identity values (-inf / 0) into the reductions.
    int blk = hd < 1024 ? hd : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    size_t shmem = (size_t)2 * hd * sizeof(float); // sq[hd] + acc[hd]
    flash_attn_fwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(q, k, v, o, l, seq, hd, scale);
}
// D[row] = Σ_d dO[row, d] · O[row, d] for each flattened row (bh*S rows in
// total), one block per row. The backward kernel uses D to collapse the
// softmax Jacobian (Σ_j P_ij dP_ij = dOᵢ·Oᵢ). The reduction goes through
// fa_block_sum, so every thread must participate and blockDim.x should be a
// whole number of warps.
__global__ void flash_attn_rowdot_k(const float* dO, const float* O, float* D, int hd) {
    const size_t base = (size_t)blockIdx.x * hd;
    const float* grad_row = dO + base;
    const float* out_row = O + base;
    float partial = 0.0f;
    for (int col = threadIdx.x; col < hd; col += blockDim.x)
        partial += grad_row[col] * out_row[col];
    const float total = fa_block_sum(partial);
    if (threadIdx.x == 0) D[blockIdx.x] = total;
}
// Causal flash-attention backward: one block per flattened query row i.
//
// Launch contract (as used by launch_flash_attention_bwd_f32):
//   grid.x = bh*seq, with the same row split as the forward kernel
//            (b = row / seq, i = row % seq).
//   block  = loops are strided by blockDim.x; this kernel itself uses no warp
//            shuffles. All threads must reach every barrier.
//   dynamic shared memory = 3*hd floats: sq = Qᵢ, sdo = dOᵢ, dqa = dQᵢ
//            accumulator.
// Inputs: L = per-row logsumexp from the forward kernel, D = rowdot(dO, O).
// Scores are recomputed from Q/K plus L, so no probabilities are cached.
//
// Outputs:
//   - dQ[row, :] is written with a plain store, because this block owns it.
//   - dK/dV rows are shared by every query row of the same (batch, head), so
//     they are accumulated with atomicAdd. They must be zeroed before launch.
//   - Float atomics make the dK/dV summation order, and hence the last bits,
//     vary between runs.
//
// Per key j ≤ i:
//   p_ij  = exp(Qᵢ·Kⱼ·scale - L[i]) ; dp_ij = dOᵢ·Vⱼ
//   ds_ij = p_ij·(dp_ij - D[i])·scale
//   dQᵢ += Σ_j ds_ij·Kⱼ ; dKⱼ += ds_ij·Qᵢ ; dVⱼ += p_ij·dOᵢ
__global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
                                 const float* dO, const float* L, const float* D,
                                 float* dQ, float* dK, float* dV,
                                 int seq, int hd, float scale) {
    int row = blockIdx.x; // flat query row over bh*S
    int b = row / seq;    // which (batch, head) block
    int i = row % seq;    // query position (causal limit)
    int t = threadIdx.x;
    int nthreads = blockDim.x;
    const float* q = Q + (size_t)row * hd;
    const float* doi = dO + (size_t)row * hd;
    const float* kb = K + (size_t)b * seq * hd;
    const float* vb = V + (size_t)b * seq * hd;
    float* dkb = dK + (size_t)b * seq * hd;
    float* dvb = dV + (size_t)b * seq * hd;
    extern __shared__ float smem[];
    float* sq = smem;         // [hd] Qᵢ
    float* sdo = smem + hd;   // [hd] dOᵢ
    float* dqa = smem + 2*hd; // [hd] dQᵢ accumulator
    for (int d = t; d < hd; d += nthreads) {
        sq[d] = q[d];
        sdo[d] = doi[d];
        dqa[d] = 0.0f;
    }
    // Row scalars loaded once by thread 0 and published by the barrier below.
    __shared__ float Li, Di;
    if (t == 0) { Li = L[row]; Di = D[row]; }
    __syncthreads();
    int valid = i + 1; // causal: keys [0, i]
    for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
        int tile = min(FA_TILE, valid - j0);
        // Phase 1: per-column ds[c] and p[c]. The thread owning column c
        // (c = t, t + nthreads, ...) does both dot products for it.
        __shared__ float s_ds[FA_TILE];
        __shared__ float s_p[FA_TILE];
        for (int c = t; c < tile; c += nthreads) {
            const float* kj = kb + (size_t)(j0 + c) * hd;
            const float* vj = vb + (size_t)(j0 + c) * hd;
            float sdot = 0.0f, dpdot = 0.0f;
            for (int d = 0; d < hd; ++d) {
                sdot += sq[d] * kj[d];
                dpdot += sdo[d] * vj[d];
            }
            float p = expf(sdot * scale - Li);
            s_p[c] = p;
            s_ds[c] = p * (dpdot - Di) * scale;
        }
        // Every thread reads all of s_p / s_ds below.
        __syncthreads();
        // Phase 2: dV_j += p·dOᵢ ; dK_j += ds·Qᵢ (accumulated across rows, so
        // atomic). The flattened (c, d) index spreads the tile×hd atomics
        // evenly over all threads.
        for (int idx = t; idx < tile * hd; idx += nthreads) {
            int c = idx / hd, d = idx % hd;
            size_t off = (size_t)(j0 + c) * hd + d;
            atomicAdd(&dvb[off], s_p[c] * sdo[d]);
            atomicAdd(&dkb[off], s_ds[c] * sq[d]);
        }
        // dQᵢ += Σ_c ds[c] · K_{j0+c}. This row owns dQ, so no atomic is
        // needed. dqa[d] is only touched by the thread owning d.
        for (int d = t; d < hd; d += nthreads) {
            float a = 0.0f;
            for (int c = 0; c < tile; ++c)
                a += s_ds[c] * kb[(size_t)(j0 + c) * hd + d];
            dqa[d] += a;
        }
        // All reads of s_p / s_ds must finish before the next tile rewrites them.
        __syncthreads();
    }
    // Each thread stores the dqa entries it accumulated itself.
    for (int d = t; d < hd; d += nthreads) dQ[(size_t)row * hd + d] = dqa[d];
}
// Host launcher for flash_attn_bwd_k on stream `s` (a cudaStream_t passed as
// void*).
//   d_d must already hold D[i] = Σ dOᵢ·Oᵢ. The Rust wrapper runs the rowdot
//       launcher first, since it holds the forward O.
//   dk/dv are accumulated with atomicAdd, so the caller must zero them.
//   dq is fully overwritten.
// The launch is asynchronous; check cudaGetLastError() afterwards. An empty
// problem (bh or seq <= 0) is a no-op instead of an invalid zero-sized grid.
void launch_flash_attention_bwd_f32(const float* q, const float* k, const float* v,
                                    const float* d_o, const float* l, float* d_d,
                                    float* dq, float* dk, float* dv,
                                    int bh, int seq, int hd, float scale, void* s) {
    if (bh <= 0 || seq <= 0 || hd < 0) return;
    // One thread per head-dim element (capped at 1024), rounded up to whole
    // warps so the block never ends in a partial warp. Surplus threads simply
    // skip the strided loops.
    int blk = hd < 1024 ? hd : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    size_t shmem = (size_t)3 * hd * sizeof(float); // sq[hd] + sdo[hd] + dqa[hd]
    flash_attn_bwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(
        q, k, v, d_o, l, d_d, dq, dk, dv, seq, hd, scale);
}
// Standalone D = rowdot(dO, O) launcher (the Rust wrapper calls this before bwd).
// One block per row. `rows` is presumably bh*seq (confirm against the caller);
// `hd` is the row length. `s` is the cudaStream_t (passed as void*). A zero
// row count is a no-op instead of an invalid zero-sized grid launch.
void launch_flash_attention_rowdot_f32(const float* d_o, const float* o, float* d_d,
                                       int rows, int hd, void* s) {
    if (rows <= 0) return;
    // flash_attn_rowdot_k reduces with fa_block_sum, which shuffles with a
    // full-warp mask, so the block must be a whole number of warps. Round
    // min(hd, 1024) up to a multiple of 32 (e.g. hd = 80 -> 96); threads past
    // hd contribute 0 to the sum.
    int blk = hd < 1024 ? hd : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;
    flash_attn_rowdot_k<<<rows, blk, 0, (cudaStream_t)s>>>(d_o, o, d_d, hd);
}
} // extern "C"