cuda: fused flash-attention kernel (fwd + flash-style bwd)

csrc/ops/flash_attention.cu: a single fused fwd kernel (one block per
query row, streams KV in tiles of 32, online softmax — running max/sum
+ rescaled V accumulator, causal mask inlined, never materializes the
[bh,S,S] scores) writing out[bh,S,hd] + the per-row logsumexp L (O(N),
saved for backward). flash-style bwd: recompute scores from Q/K/V + L,
collapse the softmax Jacobian with D[i]=ΣdO·O, dQ owned per row, dK/dV
atomicAdd across rows. Tensor::flash_attention / flash_attention_backward
wrap them (bf16 upcasts Q/K/V→f32 for the kernel, same fp32-softmax
policy as composed).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 23:10:25 +08:00
parent 65a2264227
commit 326a6fadfe
4 changed files with 440 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ fn main() {
.file("../../csrc/ops/model.cu")
.file("../../csrc/ops/optim.cu")
.file("../../csrc/ops/attention.cu")
.file("../../csrc/ops/flash_attention.cu")
.file("../../csrc/ops/cast.cu")
.compile("xtrain_cuda_kernels");
}

View File

@@ -243,6 +243,59 @@ unsafe extern "C" {
);
}
// Fused flash-attention (csrc/ops/flash_attention.cu, Phase T14). A SINGLE kernel
// each for forward/backward that streams over KV tiles with an online softmax and
// NEVER materializes the [bh,S,S] score matrix. Q/K/V/out are [bh,S,hd] row-major
// F32; the forward saves only the per-row logsumexp `l` ([bh*S], O(N)) for backward.
#[cfg(not(no_cuda))]
unsafe extern "C" {
// Forward: o[bh,S,hd] = softmax(causal(Q·Kᵀ·scale))·V, online over KV tiles.
// Also writes l[bh*S] = per-row logsumexp (saved for backward, not the scores).
#[allow(clippy::too_many_arguments)]
pub fn launch_flash_attention_fwd_f32(
q: *const f32,
k: *const f32,
v: *const f32,
o: *mut f32,
l: *mut f32,
bh: i32,
seq: i32,
hd: i32,
scale: f32,
s: CudaStream,
);
// Per-row D[i]=Σ_d dO[i,d]·O[i,d] over `rows`=bh*S rows of width `hd`. Must run
// before the backward kernel (which takes the precomputed D, not O).
pub fn launch_flash_attention_rowdot_f32(
d_o: *const f32,
o: *const f32,
d_d: *mut f32,
rows: i32,
hd: i32,
s: CudaStream,
);
// Backward: recomputes scores from Q/K/V + saved logsumexp `l` (NO cached probs)
// and the precomputed `d_d` (= D), produces dq/dk/dv. dq/dk/dv must be PRE-ZEROED
// (dk/dv are accumulated across query rows via atomicAdd).
#[allow(clippy::too_many_arguments)]
pub fn launch_flash_attention_bwd_f32(
q: *const f32,
k: *const f32,
v: *const f32,
d_o: *const f32,
l: *const f32,
d_d: *mut f32,
dq: *mut f32,
dk: *mut f32,
dv: *mut f32,
bh: i32,
seq: i32,
hd: i32,
scale: f32,
s: CudaStream,
);
}
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
// the global grad-norm reduction + in-place rescale (Phase T7).
#[cfg(not(no_cuda))]

View File

@@ -1092,6 +1092,119 @@ impl Tensor {
(dq, dk, dv)
}
// --- Fused flash-attention (the T14 op) ---
/// Fused flash-attention forward (Phase T14). `self`=Q, `k`, `v` each
/// `[bh, seq, head_dim]`, contiguous on one GPU. Computes, per batch element,
/// `out = softmax(causal(Q·Kᵀ·scale))·V` in a SINGLE kernel that streams over
/// KV tiles with an online softmax — the `[bh,seq,seq]` score matrix is NEVER
/// materialized. Returns `(out, lse)` where `lse`:[bh,seq] (F32) is the per-row
/// logsumexp cached for backward (O(N), vs the composed path's O(N²) probs).
///
/// The fused kernel is fp32; for bf16 we upcast Q/K/V → f32 → kernel → downcast
/// `out` back to bf16 (same fp32-softmax policy as the composed [`attention`]),
/// so flash and composed produce the same softmax numerics. `lse` stays fp32.
#[cfg(not(no_cuda))]
pub fn flash_attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> (Tensor, Tensor) {
assert_eq!(
self.ndim(),
3,
"flash_attention Q must be [bh,seq,head_dim]"
);
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
let dev = self.device();
let dt = self.dtype;
let qf = self.to_dtype(DType::F32);
let kf = k.to_dtype(DType::F32);
let vf = v.to_dtype(DType::F32);
let out_f32 = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
let lse = Tensor::zeros(&[bh, seq], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_flash_attention_fwd_f32(
qf.data_ptr() as *const f32,
kf.data_ptr() as *const f32,
vf.data_ptr() as *const f32,
out_f32.data_ptr() as *mut f32,
lse.data_ptr() as *mut f32,
bh as i32,
seq as i32,
hd as i32,
scale,
std::ptr::null_mut(),
);
}
(out_f32.to_dtype(dt), lse)
}
/// Backward of [`flash_attention`](Self::flash_attention). Inputs: forward
/// `q`,`k`,`v`, the forward output `out`, the cached `lse`:[bh,seq], the upstream
/// `dout`, and the same `scale`. Returns `(dq, dk, dv)`.
///
/// flash-style: NO cached probs. Recomputes scores from Q/K/V + `lse`, uses
/// `D[i]=Σ dOᵢ·Oᵢ` to collapse the softmax Jacobian, streams KV in tiles. dQ is
/// owned per query row; dK/dV are accumulated across rows (atomicAdd). Same
/// fp32 kernel; bf16 callers get fp32 grads which the autograd `cast` op casts.
#[cfg(not(no_cuda))]
pub fn flash_attention_backward(
q: &Tensor,
k: &Tensor,
v: &Tensor,
out: &Tensor,
lse: &Tensor,
dout: &Tensor,
scale: f32,
) -> (Tensor, Tensor, Tensor) {
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
let dev = q.device();
let dt = q.dtype;
let qf = q.to_dtype(DType::F32);
let kf = k.to_dtype(DType::F32);
let vf = v.to_dtype(DType::F32);
let of = out.to_dtype(DType::F32);
let dof = dout.to_dtype(DType::F32);
// D[i] = Σ_d dO[i,d]·O[i,d] (one scalar per query row, O(N)).
let d = Tensor::zeros(&[bh, seq], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_flash_attention_rowdot_f32(
dof.data_ptr() as *const f32,
of.data_ptr() as *const f32,
d.data_ptr() as *mut f32,
(bh * seq) as i32,
hd as i32,
std::ptr::null_mut(),
);
}
// dq/dk/dv pre-zeroed (Tensor::zeros memsets); dk/dv accumulate via atomicAdd.
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
unsafe {
xtrain_cuda::ffi::launch_flash_attention_bwd_f32(
qf.data_ptr() as *const f32,
kf.data_ptr() as *const f32,
vf.data_ptr() as *const f32,
dof.data_ptr() as *const f32,
lse.data_ptr() as *const f32,
d.data_ptr() as *mut f32,
dq.data_ptr() as *mut f32,
dk.data_ptr() as *mut f32,
dv.data_ptr() as *mut f32,
bh as i32,
seq as i32,
hd as i32,
scale,
std::ptr::null_mut(),
);
}
(dq.to_dtype(dt), dk.to_dtype(dt), dv.to_dtype(dt))
}
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).

273
csrc/ops/flash_attention.cu Normal file
View File

@@ -0,0 +1,273 @@
// Hand-written fused flash-attention (Phase T14).
//
// The T10 composed SDPA path is 3 launches that MATERIALIZE the [bh,S,S] score
// matrix: cublasSgemmStridedBatched (Q·Kᵀ) → causal-softmax kernel (writes the
// whole probs) → cublasSgemmStridedBatched (P·V), and backward caches that whole
// probs. flash-attention NEVER materializes N×N: a single fused kernel streams
// over KV tiles with an ONLINE softmax (running max/sum + rescaled V accumulator),
// so peak attention activation drops from O(S²) to O(S·hd) (= the output itself).
//
// Layout (matches the T10 op): Q/K/V/out are [bh, S, hd] row-major contiguous,
// bh = batch·n_heads. The query's position within its sequence is the row index
// within its [S,hd] block (so the flat row's qpos = (row % S) is automatic here —
// we index per (bh, row)). CAUSAL: a query at position i attends to keys j ≤ i.
// `scale` (= 1/sqrt(hd)) is folded into the logits before the max/exp.
//
// All F32, contiguous. (bf16 callers upcast Q/K/V → f32 on the Rust side and
// downcast the f32 out, mirroring the composed path's fp32 softmax policy, so the
// kernel only ever sees fp32.) Reduction helpers are inlined (self-contained file,
// matching the csrc/ layout).
//
// Parallelisation: grid = bh*S, one block per query row; blockDim.x threads
// cooperate. Forward keeps m (running max), l (running sum), acc[hd] (rescaled
// V accumulator) in shared memory, streams KV in tiles of BK. Backward recomputes
// scores from Q/K/V + the saved logsumexp L[bh,S] (NO cached probs), uses
// D[i]=Σ dOᵢ·Oᵢ to collapse the softmax Jacobian, and atomicAdds dK/dV (which are
// accumulated across query rows).
#include <math.h>
extern "C" {
__device__ __forceinline__ float fa_warp_sum(float v) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1)
v += __shfl_down_sync(0xffffffff, v, off);
return v;
}
__device__ __forceinline__ float fa_warp_max(float v) {
#pragma unroll
for (int off = 16; off > 0; off >>= 1)
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
return v;
}
__device__ __forceinline__ float fa_block_sum(float v) {
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = fa_warp_sum(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
if (warp == 0) v = fa_warp_sum(v);
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
__device__ __forceinline__ float fa_block_max(float v) {
__shared__ float sh[32];
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
int nwarps = (blockDim.x + 31) >> 5;
v = fa_warp_max(v);
if (lane == 0) sh[warp] = v;
__syncthreads();
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
if (warp == 0) v = fa_warp_max(v);
__shared__ float bc;
if (threadIdx.x == 0) bc = v;
__syncthreads();
return bc;
}
#define FA_TILE 32 // KV tile width (columns streamed per step)
// One block per (bh-row, query-position). Computes out[bh, i, :] and L[bh, i] via
// an online softmax that streams the keys in tiles of FA_TILE — the [S,S] score
// row is never stored, only the per-tile partials flow through shared memory.
__global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
float* O, float* L, int seq, int hd, float scale) {
int row = blockIdx.x; // global query row over bh*S
int b = row / seq; // which (batch,head) block
int i = row % seq; // query position within the sequence (causal limit)
int t = threadIdx.x;
int nthreads = blockDim.x;
const float* q = Q + (size_t)row * hd;
const float* kb = K + (size_t)b * seq * hd; // this block's keys [seq,hd]
const float* vb = V + (size_t)b * seq * hd; // this block's values[seq,hd]
// Q row in shared memory (reused every tile); acc accumulator over hd.
extern __shared__ float smem[];
float* sq = smem; // [hd]
float* acc = smem + hd; // [hd]
for (int d = t; d < hd; d += nthreads) {
sq[d] = q[d];
acc[d] = 0.0f;
}
__shared__ float m_run, l_run;
if (t == 0) { m_run = -INFINITY; l_run = 0.0f; }
__syncthreads();
int valid = i + 1; // causal: attend to keys [0, i]
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
int tile = min(FA_TILE, valid - j0);
// Each thread computes whole logits for a strided subset of the tile's
// columns: s = scale * (q · k_j). hd is small (≤128) so the per-thread
// dot loop is cheap; this avoids a block-reduce per column.
__shared__ float s_tile[FA_TILE];
for (int c = t; c < tile; c += nthreads) {
const float* kj = kb + (size_t)(j0 + c) * hd;
float dot = 0.0f;
for (int d = 0; d < hd; ++d) dot += sq[d] * kj[d];
s_tile[c] = dot * scale;
}
__syncthreads();
// Tile max, then online rescale of (m, l, acc).
float tmax = -INFINITY;
for (int c = t; c < tile; c += nthreads) tmax = fmaxf(tmax, s_tile[c]);
tmax = fa_block_max(tmax);
__shared__ float m_new, corr;
if (t == 0) {
float mn = fmaxf(m_run, tmax);
corr = (m_run == -INFINITY) ? 0.0f : expf(m_run - mn); // rescale old state
m_new = mn;
}
__syncthreads();
// Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)).
// Each thread owns a strided subset of hd; loops over the tile columns.
float lsum = 0.0f;
for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new);
lsum = fa_block_sum(lsum);
for (int d = t; d < hd; d += nthreads) {
float a = acc[d] * corr;
for (int c = 0; c < tile; ++c) {
float p = expf(s_tile[c] - m_new);
a += p * vb[(size_t)(j0 + c) * hd + d];
}
acc[d] = a;
}
if (t == 0) {
l_run = l_run * corr + lsum;
m_run = m_new;
}
__syncthreads();
}
// out = acc / l ; L = m + log(l) (logsumexp, saved for backward).
float inv = 1.0f / l_run;
for (int d = t; d < hd; d += nthreads) O[(size_t)row * hd + d] = acc[d] * inv;
if (t == 0) L[row] = m_run + logf(l_run);
}
void launch_flash_attention_fwd_f32(const float* q, const float* k, const float* v,
float* o, float* l, int bh, int seq, int hd,
float scale, void* s) {
int blk = hd < 1024 ? hd : 1024;
if (blk < 32) blk = 32;
size_t shmem = (size_t)2 * hd * sizeof(float); // sq[hd] + acc[hd]
flash_attn_fwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(q, k, v, o, l, seq, hd, scale);
}
// Per-row D[i] = Σ_d dO[i,d] · O[i,d]. One block per row (bh*S rows). Used to
// collapse the softmax Jacobian in backward (Σ_j P_ij dP_ij = dOᵢ·Oᵢ).
__global__ void flash_attn_rowdot_k(const float* dO, const float* O, float* D, int hd) {
int row = blockIdx.x;
int t = threadIdx.x;
const float* d = dO + (size_t)row * hd;
const float* o = O + (size_t)row * hd;
float v = 0.0f;
for (int c = t; c < hd; c += blockDim.x) v += d[c] * o[c];
v = fa_block_sum(v);
if (t == 0) D[row] = v;
}
// Backward: one block per query row i. Recomputes scores from Q/K/V + the saved
// logsumexp L (NO cached probs), streams KV in tiles. dQ accumulates locally (this
// row owns it). dK/dV are accumulated ACROSS query rows so they atomicAdd into the
// shared global buffers (pre-zeroed by the caller).
// p_ij = exp(Qᵢ·Kⱼ·scale - L[i]) ; dp_ij = dOᵢ·Vⱼ ;
// ds_ij = p_ij·(dp_ij - D[i])·scale
// dQᵢ += Σ_j ds_ij·Kⱼ ; dKⱼ += ds_ij·Qᵢ ; dVⱼ += p_ij·dOᵢ
__global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
const float* dO, const float* L, const float* D,
float* dQ, float* dK, float* dV,
int seq, int hd, float scale) {
int row = blockIdx.x;
int b = row / seq;
int i = row % seq;
int t = threadIdx.x;
int nthreads = blockDim.x;
const float* q = Q + (size_t)row * hd;
const float* doi = dO + (size_t)row * hd;
const float* kb = K + (size_t)b * seq * hd;
const float* vb = V + (size_t)b * seq * hd;
float* dkb = dK + (size_t)b * seq * hd;
float* dvb = dV + (size_t)b * seq * hd;
extern __shared__ float smem[];
float* sq = smem; // [hd] Qᵢ
float* sdo = smem + hd; // [hd] dOᵢ
float* dqa = smem + 2*hd; // [hd] dQᵢ accumulator
for (int d = t; d < hd; d += nthreads) {
sq[d] = q[d];
sdo[d] = doi[d];
dqa[d] = 0.0f;
}
__shared__ float Li, Di;
if (t == 0) { Li = L[row]; Di = D[row]; }
__syncthreads();
int valid = i + 1;
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
int tile = min(FA_TILE, valid - j0);
// Per-tile ds[c] (one per column), computed by the thread that owns column c.
__shared__ float s_ds[FA_TILE];
for (int c = t; c < tile; c += nthreads) {
const float* kj = kb + (size_t)(j0 + c) * hd;
const float* vj = vb + (size_t)(j0 + c) * hd;
float sdot = 0.0f, dpdot = 0.0f;
for (int d = 0; d < hd; ++d) {
sdot += sq[d] * kj[d];
dpdot += sdo[d] * vj[d];
}
float p = expf(sdot * scale - Li);
float ds = p * (dpdot - Di) * scale;
s_ds[c] = ds;
// dV_j += p · dOᵢ ; dK_j += ds · Qᵢ (accumulated across rows → atomic)
float* dvj = dvb + (size_t)(j0 + c) * hd;
float* dkj = dkb + (size_t)(j0 + c) * hd;
for (int d = 0; d < hd; ++d) {
atomicAdd(&dvj[d], p * sdo[d]);
atomicAdd(&dkj[d], ds * sq[d]);
}
}
__syncthreads();
// dQᵢ += Σ_c ds[c] · K_{j0+c} (this row owns dQ — no atomic).
for (int d = t; d < hd; d += nthreads) {
float a = 0.0f;
for (int c = 0; c < tile; ++c)
a += s_ds[c] * kb[(size_t)(j0 + c) * hd + d];
dqa[d] += a;
}
__syncthreads();
}
for (int d = t; d < hd; d += nthreads) dQ[(size_t)row * hd + d] = dqa[d];
}
void launch_flash_attention_bwd_f32(const float* q, const float* k, const float* v,
const float* d_o, const float* l, float* d_d,
float* dq, float* dk, float* dv,
int bh, int seq, int hd, float scale, void* s) {
int blk = hd < 1024 ? hd : 1024;
if (blk < 32) blk = 32;
// d_d is the pre-computed D[i]=Σ dOᵢ·Oᵢ (the Rust wrapper runs rowdot first,
// since it holds the forward O). dq/dk/dv are pre-zeroed by the caller.
flash_attn_bwd_k<<<bh * seq, blk, (size_t)3 * hd * sizeof(float), (cudaStream_t)s>>>(
q, k, v, d_o, l, d_d, dq, dk, dv, seq, hd, scale);
}
// Standalone D = rowdot(dO, O) launcher (the Rust wrapper calls this before bwd).
void launch_flash_attention_rowdot_f32(const float* d_o, const float* o, float* d_d,
int rows, int hd, void* s) {
int blk = hd < 1024 ? hd : 1024;
if (blk < 32) blk = 32;
flash_attn_rowdot_k<<<rows, blk, 0, (cudaStream_t)s>>>(d_o, o, d_d, hd);
}
} // extern "C"