// Hand-written fused flash-attention (Phase T14).
//
// The T10 composed SDPA path is 3 launches that MATERIALIZE the [bh,S,S] score
// matrix: cublasSgemmStridedBatched (Q·Kᵀ) → causal-softmax kernel (writes the
// whole probs) → cublasSgemmStridedBatched (P·V), and backward caches that whole
// probs. flash-attention NEVER materializes N×N: a single fused kernel streams
// over KV tiles with an ONLINE softmax (running max/sum + rescaled V accumulator),
// so peak attention activation drops from O(S²) to O(S·hd) (= the output itself).
//
// Layout (matches the T10 op): Q/K/V/out are [bh, S, hd] row-major contiguous,
// bh = batch·n_heads. The query's position within its sequence is the row index
// within its [S,hd] block (so the flat row's qpos = (row % S) is automatic here —
// we index per (bh, row)). CAUSAL: a query at position i attends to keys j ≤ i.
// `scale` (= 1/sqrt(hd)) is folded into the logits before the max/exp.
//
// All F32, contiguous. (bf16 callers upcast Q/K/V → f32 on the Rust side and
// downcast the f32 out, mirroring the composed path's fp32 softmax policy, so the
// kernel only ever sees fp32.) Reduction helpers are inlined (self-contained file,
// matching the csrc/ layout).
//
// Parallelisation: grid = bh*S, one block per query row; blockDim.x threads
// cooperate. Forward keeps m (running max), l (running sum), acc[hd] (rescaled
// V accumulator) in shared memory, streams KV in tiles of BK. Backward recomputes
// scores from Q/K/V + the saved logsumexp L[bh,S] (NO cached probs), uses
// D[i]=Σ dOᵢ·Oᵢ to collapse the softmax Jacobian, and atomicAdds dK/dV (which are
// accumulated across query rows).

#include <math.h>

extern "C" {

// Warp-wide sum using a shuffle-down tree (offsets 16, 8, 4, 2, 1).
// Every lane of the warp must call this together (full-warp mask). Only lane 0
// ends up holding the complete sum of all 32 inputs; callers must read the
// result from lane 0.
__device__ __forceinline__ float fa_warp_sum(float v) {
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float neighbour = __shfl_down_sync(0xffffffff, v, delta);
        v += neighbour;
    }
    return v;
}
// Warp-wide maximum using a shuffle-down tree (offsets 16, 8, 4, 2, 1).
// Every lane of the warp must call this together (full-warp mask). Only lane 0
// is guaranteed to hold the maximum over all 32 inputs.
__device__ __forceinline__ float fa_warp_max(float v) {
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float neighbour = __shfl_down_sync(0xffffffff, v, delta);
        v = fmaxf(v, neighbour);
    }
    return v;
}
// Block-wide sum, broadcast to every thread of the block.
//
// Stage 1: each warp reduces its own values and lane 0 parks the warp total in
//          shared memory.
// Stage 2: warp 0 reduces the per-warp totals; thread 0 publishes the result.
//
// Must be called by ALL threads of the block (it contains two __syncthreads()).
// The warp stage uses a full-warp shuffle mask, so blockDim.x is expected to be
// a multiple of 32.
__device__ __forceinline__ float fa_block_sum(float v) {
    __shared__ float warp_partial[32];  // one slot per warp (max 32 warps/block)
    __shared__ float block_total;       // broadcast slot for the final value

    const int tid = threadIdx.x;
    const int lane_id = tid % 32;
    const int warp_id = tid / 32;
    const int warp_count = (blockDim.x + 31) / 32;

    const float warp_total = fa_warp_sum(v);
    if (lane_id == 0) warp_partial[warp_id] = warp_total;
    __syncthreads();

    // Threads beyond the number of warps contribute the additive identity.
    float folded = 0.0f;
    if (tid < warp_count) folded = warp_partial[tid];
    if (warp_id == 0) folded = fa_warp_sum(folded);

    if (tid == 0) block_total = folded;
    __syncthreads();
    return block_total;
}
// Block-wide maximum, broadcast to every thread of the block.
//
// Stage 1: each warp reduces its own values and lane 0 parks the warp maximum
//          in shared memory.
// Stage 2: warp 0 reduces the per-warp maxima; thread 0 publishes the result.
//
// Must be called by ALL threads of the block (it contains two __syncthreads()).
// The warp stage uses a full-warp shuffle mask, so blockDim.x is expected to be
// a multiple of 32.
__device__ __forceinline__ float fa_block_max(float v) {
    __shared__ float warp_partial[32];  // one slot per warp (max 32 warps/block)
    __shared__ float block_best;        // broadcast slot for the final value

    const int tid = threadIdx.x;
    const int lane_id = tid % 32;
    const int warp_id = tid / 32;
    const int warp_count = (blockDim.x + 31) / 32;

    const float warp_best = fa_warp_max(v);
    if (lane_id == 0) warp_partial[warp_id] = warp_best;
    __syncthreads();

    // Threads beyond the number of warps contribute the identity for max.
    float folded = -INFINITY;
    if (tid < warp_count) folded = warp_partial[tid];
    if (warp_id == 0) folded = fa_warp_max(folded);

    if (tid == 0) block_best = folded;
    __syncthreads();
    return block_best;
}

#define FA_TILE 32  // KV tile width (columns streamed per step)

// Forward pass: one block per flat query row (blockIdx.x in [0, bh*seq)).
//
// Writes O[row, :] (the attention output) and L[row] (the row's logsumexp,
// m + log(l)). Keys are streamed in tiles of FA_TILE with an online softmax, so
// only one tile of scores lives in shared memory at a time.
//
// Launch requirements:
//   - dynamic shared memory: 2 * hd floats (query row + output accumulator)
//   - all threads of the block reach every barrier; the block reductions expect
//     blockDim.x to be a multiple of 32
__global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
                                 float* O, float* L, int seq, int hd, float scale) {
    const int tid = threadIdx.x;
    const int stride = blockDim.x;
    const int qrow = blockIdx.x;   // flat query row over bh*seq
    const int head = qrow / seq;   // which (batch, head) slab
    const int qpos = qrow % seq;   // position inside the sequence (causal limit)

    const size_t head_off = (size_t)head * seq * hd;
    const float* q_row = Q + (size_t)qrow * hd;
    const float* keys = K + head_off;   // this slab's keys   [seq, hd]
    const float* vals = V + head_off;   // this slab's values [seq, hd]
    float* out_row = O + (size_t)qrow * hd;

    extern __shared__ float smem[];
    float* q_sh = smem;         // [hd] cached query row
    float* o_acc = smem + hd;   // [hd] rescaled sum of p * V

    __shared__ float s_tile[FA_TILE];  // logits, then softmax weights, of the tile
    __shared__ float m_run, l_run;     // running max / running sum of exp
    __shared__ float m_new, corr;      // per-tile: updated max, rescale factor

    for (int d = tid; d < hd; d += stride) {
        q_sh[d] = q_row[d];
        o_acc[d] = 0.0f;
    }
    if (tid == 0) {
        m_run = -INFINITY;
        l_run = 0.0f;
    }
    __syncthreads();

    const int n_keys = qpos + 1;  // causal: keys [0, qpos]
    for (int base = 0; base < n_keys; base += FA_TILE) {
        const int remaining = n_keys - base;
        const int width = remaining < FA_TILE ? remaining : FA_TILE;

        // Scaled logits: each participating thread computes full dot products
        // for its columns, so no per-column block reduction is needed.
        for (int col = tid; col < width; col += stride) {
            const float* key = keys + (size_t)(base + col) * hd;
            float qk = 0.0f;
            for (int d = 0; d < hd; ++d) qk += q_sh[d] * key[d];
            s_tile[col] = qk * scale;
        }
        __syncthreads();

        // Maximum logit of this tile.
        float local_max = -INFINITY;
        for (int col = tid; col < width; col += stride)
            local_max = fmaxf(local_max, s_tile[col]);
        const float tile_max = fa_block_max(local_max);

        // Thread 0 folds the tile max into the running max and derives the
        // factor that rescales everything accumulated under the old max.
        if (tid == 0) {
            const float next_max = fmaxf(m_run, tile_max);
            if (m_run == -INFINITY) {
                corr = 0.0f;  // first tile: nothing accumulated yet
            } else {
                corr = expf(m_run - next_max);
            }
            m_new = next_max;
        }
        __syncthreads();

        // Replace each logit by its weight exp(s - m_new), computed once per
        // column, and sum the weights.
        float local_sum = 0.0f;
        for (int col = tid; col < width; col += stride) {
            const float w = expf(s_tile[col] - m_new);
            s_tile[col] = w;
            local_sum += w;
        }
        // The barriers inside the reduction also make the weights visible to
        // every thread before the accumulation below.
        const float tile_sum = fa_block_sum(local_sum);

        // Rescale the old accumulator and add this tile's weighted values.
        for (int d = tid; d < hd; d += stride) {
            float a = o_acc[d] * corr;
            for (int col = 0; col < width; ++col)
                a += s_tile[col] * vals[(size_t)(base + col) * hd + d];
            o_acc[d] = a;
        }
        if (tid == 0) {
            l_run = l_run * corr + tile_sum;
            m_run = m_new;
        }
        __syncthreads();
    }

    // out = acc / l ; L = m + log(l) (logsumexp, kept for the backward pass).
    const float inv_l = 1.0f / l_run;
    for (int d = tid; d < hd; d += stride) out_row[d] = o_acc[d] * inv_l;
    if (tid == 0) L[qrow] = m_run + logf(l_run);
}

// Launches the fused forward kernel on stream `s` (a cudaStream_t passed as void*).
//
//   q, k, v : inputs, [bh, seq, hd] contiguous f32 (device pointers)
//   o       : output, [bh, seq, hd]
//   l       : per-row logsumexp output, [bh * seq]
//   scale   : multiplier applied to the Q·K logits
//
// Grid: one block per query row (bh * seq blocks). Dynamic shared memory holds
// the cached query row and the output accumulator (2 * hd floats).
void launch_flash_attention_fwd_f32(const float* q, const float* k, const float* v,
                                    float* o, float* l, int bh, int seq, int hd,
                                    float scale, void* s) {
    // An empty problem would request a zero-sized grid, which is an invalid
    // launch configuration; there is nothing to compute, so just return.
    if (bh <= 0 || seq <= 0) return;

    // One thread per head-dim element, capped at the 1024-thread block limit,
    // then rounded UP to a whole number of warps. The kernel's block
    // reductions shuffle with a full-warp mask; a trailing partial warp (e.g.
    // hd = 48 or 80) would read undefined values from lanes that do not exist.
    int blk = hd < 1024 ? hd : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;

    size_t shmem = (size_t)2 * hd * sizeof(float);  // sq[hd] + acc[hd]
    flash_attn_fwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(q, k, v, o, l, seq, hd, scale);
}

// Row-wise dot product D[row] = Σ_d dO[row, d] · O[row, d], one block per row.
// The backward pass uses D to collapse the softmax Jacobian
// (Σ_j P_ij dP_ij = dOᵢ·Oᵢ).
//
// Every thread of the block takes part in the block reduction, which expects
// blockDim.x to be a multiple of 32.
__global__ void flash_attn_rowdot_k(const float* dO, const float* O, float* D, int hd) {
    const int tid = threadIdx.x;
    const int r = blockIdx.x;
    const size_t base = (size_t)r * hd;

    // Each thread accumulates a strided slice of the row.
    float partial = 0.0f;
    for (int c = tid; c < hd; c += blockDim.x)
        partial += dO[base + c] * O[base + c];

    const float total = fa_block_sum(partial);
    if (tid == 0) D[r] = total;
}

// Backward pass: one block per flat query row i (blockIdx.x in [0, bh*seq)).
//
// Probabilities are recomputed from Q/K and the saved logsumexp L (nothing
// cached from forward), streaming keys/values in tiles of FA_TILE:
//     p_ij  = exp(Qᵢ·Kⱼ·scale - L[i])
//     dp_ij = dOᵢ·Vⱼ
//     ds_ij = p_ij·(dp_ij - D[i])·scale
//     dQᵢ  += Σ_j ds_ij·Kⱼ ;  dKⱼ += ds_ij·Qᵢ ;  dVⱼ += p_ij·dOᵢ
//
// dQ for this row is accumulated in shared memory and stored once at the end.
// dK/dV receive contributions from many query rows, so they are updated with
// atomicAdd and must be zeroed by the caller before launch.
//
// Launch requirement: 3 * hd floats of dynamic shared memory.
__global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
                                 const float* dO, const float* L, const float* D,
                                 float* dQ, float* dK, float* dV,
                                 int seq, int hd, float scale) {
    const int tid = threadIdx.x;
    const int stride = blockDim.x;
    const int qrow = blockIdx.x;   // flat query row over bh*seq
    const int head = qrow / seq;   // which (batch, head) slab
    const int qpos = qrow % seq;   // position inside the sequence (causal limit)

    const size_t row_off = (size_t)qrow * hd;
    const size_t head_off = (size_t)head * seq * hd;
    const float* keys = K + head_off;
    const float* vals = V + head_off;
    float* dk_head = dK + head_off;
    float* dv_head = dV + head_off;

    extern __shared__ float smem[];
    float* q_sh = smem;            // [hd] Qᵢ
    float* do_sh = smem + hd;      // [hd] dOᵢ
    float* dq_acc = smem + 2 * hd; // [hd] dQᵢ accumulator

    __shared__ float s_ds[FA_TILE];  // ds for each column of the current tile
    __shared__ float s_p[FA_TILE];   // p  for each column of the current tile
    __shared__ float lse_i, delta_i; // L[i] and D[i] for this row

    for (int d = tid; d < hd; d += stride) {
        q_sh[d] = Q[row_off + d];
        do_sh[d] = dO[row_off + d];
        dq_acc[d] = 0.0f;
    }
    if (tid == 0) {
        lse_i = L[qrow];
        delta_i = D[qrow];
    }
    __syncthreads();

    const int n_keys = qpos + 1;  // causal: keys [0, qpos]
    for (int base = 0; base < n_keys; base += FA_TILE) {
        const int remaining = n_keys - base;
        const int width = remaining < FA_TILE ? remaining : FA_TILE;

        // Phase 1: the thread owning a column computes both dot products for
        // it and stores p and ds.
        for (int col = tid; col < width; col += stride) {
            const float* key = keys + (size_t)(base + col) * hd;
            const float* val = vals + (size_t)(base + col) * hd;
            float qk = 0.0f, dov = 0.0f;
            for (int d = 0; d < hd; ++d) {
                qk += q_sh[d] * key[d];
                dov += do_sh[d] * val[d];
            }
            const float prob = expf(qk * scale - lse_i);
            s_p[col] = prob;
            s_ds[col] = prob * (dov - delta_i) * scale;
        }
        __syncthreads();

        // Phase 2: scatter dV/dK contributions. The (column, dim) pairs are
        // flattened so the atomics are shared evenly by all threads.
        const int n_pairs = width * hd;
        for (int flat = tid; flat < n_pairs; flat += stride) {
            const int col = flat / hd;
            const int d = flat % hd;
            const size_t off = (size_t)(base + col) * hd + d;
            atomicAdd(&dv_head[off], s_p[col] * do_sh[d]);
            atomicAdd(&dk_head[off], s_ds[col] * q_sh[d]);
        }

        // dQᵢ += Σ_c ds[c] · K_{base+c}; each thread owns its dims, no atomics.
        for (int d = tid; d < hd; d += stride) {
            float a = 0.0f;
            for (int col = 0; col < width; ++col)
                a += s_ds[col] * keys[(size_t)(base + col) * hd + d];
            dq_acc[d] += a;
        }
        // Keep s_p / s_ds intact until every thread is done with this tile.
        __syncthreads();
    }

    for (int d = tid; d < hd; d += stride) dQ[row_off + d] = dq_acc[d];
}

// Launches the fused backward kernel on stream `s` (a cudaStream_t passed as void*).
//
//   q, k, v    : forward inputs, [bh, seq, hd] contiguous f32
//   d_o        : gradient w.r.t. the forward output, [bh, seq, hd]
//   l          : per-row logsumexp saved by the forward pass, [bh * seq]
//   d_d        : pre-computed D[i] = Σ dOᵢ·Oᵢ (the Rust wrapper runs rowdot
//                first, since it holds the forward O), [bh * seq]
//   dq, dk, dv : gradient outputs; pre-zeroed by the caller (dk/dv are
//                accumulated with atomics)
//
// Grid: one block per query row (bh * seq blocks). Dynamic shared memory holds
// Qᵢ, dOᵢ and the dQᵢ accumulator (3 * hd floats).
void launch_flash_attention_bwd_f32(const float* q, const float* k, const float* v,
                                    const float* d_o, const float* l, float* d_d,
                                    float* dq, float* dk, float* dv,
                                    int bh, int seq, int hd, float scale, void* s) {
    // An empty problem would request a zero-sized grid, which is an invalid
    // launch configuration; there is nothing to compute, so just return.
    if (bh <= 0 || seq <= 0) return;

    // One thread per head-dim element, capped at the 1024-thread block limit
    // and rounded up to whole warps so no warp is launched partially populated.
    int blk = hd < 1024 ? hd : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;

    size_t shmem = (size_t)3 * hd * sizeof(float);  // sq[hd] + sdo[hd] + dqa[hd]
    flash_attn_bwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(
        q, k, v, d_o, l, d_d, dq, dk, dv, seq, hd, scale);
}

// Standalone D = rowdot(dO, O) launcher (the Rust wrapper calls this before bwd).
//
//   d_o  : gradient w.r.t. the forward output, [rows, hd] contiguous f32
//   o    : forward output, [rows, hd]
//   d_d  : result, one float per row
//   s    : cudaStream_t passed as void*
//
// Grid: one block per row; no dynamic shared memory.
void launch_flash_attention_rowdot_f32(const float* d_o, const float* o, float* d_d,
                                       int rows, int hd, void* s) {
    // Zero rows would request a zero-sized grid (invalid launch configuration).
    if (rows <= 0) return;

    // One thread per head-dim element, capped at the 1024-thread block limit,
    // then rounded UP to a whole number of warps. The kernel's block reduction
    // shuffles with a full-warp mask; a trailing partial warp (e.g. hd = 48 or
    // 80) would read undefined values from lanes that do not exist.
    int blk = hd < 1024 ? hd : 1024;
    blk = (blk + 31) & ~31;
    if (blk < 32) blk = 32;

    flash_attn_rowdot_k<<<rows, blk, 0, (cudaStream_t)s>>>(d_o, o, d_d, hd);
}

} // extern "C"