cuda: fused flash-attention kernel (fwd + flash-style bwd)
csrc/ops/flash_attention.cu: a single fused fwd kernel (one block per query row, streams KV in tiles of 32, online softmax — running max/sum + rescaled V accumulator, causal mask inlined, never materializes the [bh,S,S] scores) writing out[bh,S,hd] + the per-row logsumexp L (O(N), saved for backward). flash-style bwd: recompute scores from Q/K/V + L, collapse the softmax Jacobian with D[i]=ΣdO·O, dQ owned per row, dK/dV atomicAdd across rows. Tensor::flash_attention / flash_attention_backward wrap them (bf16 upcasts Q/K/V→f32 for the kernel, same fp32-softmax policy as composed). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -36,6 +36,7 @@ fn main() {
|
||||
.file("../../csrc/ops/model.cu")
|
||||
.file("../../csrc/ops/optim.cu")
|
||||
.file("../../csrc/ops/attention.cu")
|
||||
.file("../../csrc/ops/flash_attention.cu")
|
||||
.file("../../csrc/ops/cast.cu")
|
||||
.compile("xtrain_cuda_kernels");
|
||||
}
|
||||
|
||||
@@ -243,6 +243,59 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
|
||||
// Fused flash-attention (csrc/ops/flash_attention.cu, Phase T14). A SINGLE kernel
|
||||
// each for forward/backward that streams over KV tiles with an online softmax and
|
||||
// NEVER materializes the [bh,S,S] score matrix. Q/K/V/out are [bh,S,hd] row-major
|
||||
// F32; the forward saves only the per-row logsumexp `l` ([bh*S], O(N)) for backward.
|
||||
#[cfg(not(no_cuda))]
|
||||
unsafe extern "C" {
|
||||
// Forward: o[bh,S,hd] = softmax(causal(Q·Kᵀ·scale))·V, online over KV tiles.
|
||||
// Also writes l[bh*S] = per-row logsumexp (saved for backward, not the scores).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn launch_flash_attention_fwd_f32(
|
||||
q: *const f32,
|
||||
k: *const f32,
|
||||
v: *const f32,
|
||||
o: *mut f32,
|
||||
l: *mut f32,
|
||||
bh: i32,
|
||||
seq: i32,
|
||||
hd: i32,
|
||||
scale: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Per-row D[i]=Σ_d dO[i,d]·O[i,d] over `rows`=bh*S rows of width `hd`. Must run
|
||||
// before the backward kernel (which takes the precomputed D, not O).
|
||||
pub fn launch_flash_attention_rowdot_f32(
|
||||
d_o: *const f32,
|
||||
o: *const f32,
|
||||
d_d: *mut f32,
|
||||
rows: i32,
|
||||
hd: i32,
|
||||
s: CudaStream,
|
||||
);
|
||||
// Backward: recomputes scores from Q/K/V + saved logsumexp `l` (NO cached probs)
|
||||
// and the precomputed `d_d` (= D), produces dq/dk/dv. dq/dk/dv must be PRE-ZEROED
|
||||
// (dk/dv are accumulated across query rows via atomicAdd).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn launch_flash_attention_bwd_f32(
|
||||
q: *const f32,
|
||||
k: *const f32,
|
||||
v: *const f32,
|
||||
d_o: *const f32,
|
||||
l: *const f32,
|
||||
d_d: *mut f32,
|
||||
dq: *mut f32,
|
||||
dk: *mut f32,
|
||||
dv: *mut f32,
|
||||
bh: i32,
|
||||
seq: i32,
|
||||
hd: i32,
|
||||
scale: f32,
|
||||
s: CudaStream,
|
||||
);
|
||||
}
|
||||
|
||||
// GPU-side optimizer kernels (csrc/ops/optim.cu): AdamW step (m/v on device) and
|
||||
// the global grad-norm reduction + in-place rescale (Phase T7).
|
||||
#[cfg(not(no_cuda))]
|
||||
|
||||
@@ -1092,6 +1092,119 @@ impl Tensor {
|
||||
(dq, dk, dv)
|
||||
}
|
||||
|
||||
// --- Fused flash-attention (the T14 op) ---
|
||||
|
||||
/// Fused flash-attention forward (Phase T14). `self`=Q, `k`, `v` each
|
||||
/// `[bh, seq, head_dim]`, contiguous on one GPU. Computes, per batch element,
|
||||
/// `out = softmax(causal(Q·Kᵀ·scale))·V` in a SINGLE kernel that streams over
|
||||
/// KV tiles with an online softmax — the `[bh,seq,seq]` score matrix is NEVER
|
||||
/// materialized. Returns `(out, lse)` where `lse`:[bh,seq] (F32) is the per-row
|
||||
/// logsumexp cached for backward (O(N), vs the composed path's O(N²) probs).
|
||||
///
|
||||
/// The fused kernel is fp32; for bf16 we upcast Q/K/V → f32 → kernel → downcast
|
||||
/// `out` back to bf16 (same fp32-softmax policy as the composed [`attention`]),
|
||||
/// so flash and composed produce the same softmax numerics. `lse` stays fp32.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn flash_attention(&self, k: &Tensor, v: &Tensor, scale: f32) -> (Tensor, Tensor) {
|
||||
assert_eq!(
|
||||
self.ndim(),
|
||||
3,
|
||||
"flash_attention Q must be [bh,seq,head_dim]"
|
||||
);
|
||||
assert_eq!(self.shape(), k.shape(), "Q/K shape mismatch");
|
||||
assert_eq!(self.shape(), v.shape(), "Q/V shape mismatch");
|
||||
assert_eq!(self.dtype, k.dtype, "Q/K dtype mismatch");
|
||||
assert_eq!(self.dtype, v.dtype, "Q/V dtype mismatch");
|
||||
let (bh, seq, hd) = (self.shape[0], self.shape[1], self.shape[2]);
|
||||
let dev = self.device();
|
||||
let dt = self.dtype;
|
||||
|
||||
let qf = self.to_dtype(DType::F32);
|
||||
let kf = k.to_dtype(DType::F32);
|
||||
let vf = v.to_dtype(DType::F32);
|
||||
let out_f32 = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
let lse = Tensor::zeros(&[bh, seq], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_flash_attention_fwd_f32(
|
||||
qf.data_ptr() as *const f32,
|
||||
kf.data_ptr() as *const f32,
|
||||
vf.data_ptr() as *const f32,
|
||||
out_f32.data_ptr() as *mut f32,
|
||||
lse.data_ptr() as *mut f32,
|
||||
bh as i32,
|
||||
seq as i32,
|
||||
hd as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
(out_f32.to_dtype(dt), lse)
|
||||
}
|
||||
|
||||
/// Backward of [`flash_attention`](Self::flash_attention). Inputs: forward
|
||||
/// `q`,`k`,`v`, the forward output `out`, the cached `lse`:[bh,seq], the upstream
|
||||
/// `dout`, and the same `scale`. Returns `(dq, dk, dv)`.
|
||||
///
|
||||
/// flash-style: NO cached probs. Recomputes scores from Q/K/V + `lse`, uses
|
||||
/// `D[i]=Σ dOᵢ·Oᵢ` to collapse the softmax Jacobian, streams KV in tiles. dQ is
|
||||
/// owned per query row; dK/dV are accumulated across rows (atomicAdd). Same
|
||||
/// fp32 kernel; bf16 callers get fp32 grads which the autograd `cast` op casts.
|
||||
#[cfg(not(no_cuda))]
|
||||
pub fn flash_attention_backward(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
out: &Tensor,
|
||||
lse: &Tensor,
|
||||
dout: &Tensor,
|
||||
scale: f32,
|
||||
) -> (Tensor, Tensor, Tensor) {
|
||||
let (bh, seq, hd) = (q.shape[0], q.shape[1], q.shape[2]);
|
||||
let dev = q.device();
|
||||
let dt = q.dtype;
|
||||
|
||||
let qf = q.to_dtype(DType::F32);
|
||||
let kf = k.to_dtype(DType::F32);
|
||||
let vf = v.to_dtype(DType::F32);
|
||||
let of = out.to_dtype(DType::F32);
|
||||
let dof = dout.to_dtype(DType::F32);
|
||||
// D[i] = Σ_d dO[i,d]·O[i,d] (one scalar per query row, O(N)).
|
||||
let d = Tensor::zeros(&[bh, seq], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_flash_attention_rowdot_f32(
|
||||
dof.data_ptr() as *const f32,
|
||||
of.data_ptr() as *const f32,
|
||||
d.data_ptr() as *mut f32,
|
||||
(bh * seq) as i32,
|
||||
hd as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
// dq/dk/dv pre-zeroed (Tensor::zeros memsets); dk/dv accumulate via atomicAdd.
|
||||
let dq = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
let dk = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
let dv = Tensor::zeros(&[bh, seq, hd], DType::F32, dev);
|
||||
unsafe {
|
||||
xtrain_cuda::ffi::launch_flash_attention_bwd_f32(
|
||||
qf.data_ptr() as *const f32,
|
||||
kf.data_ptr() as *const f32,
|
||||
vf.data_ptr() as *const f32,
|
||||
dof.data_ptr() as *const f32,
|
||||
lse.data_ptr() as *const f32,
|
||||
d.data_ptr() as *mut f32,
|
||||
dq.data_ptr() as *mut f32,
|
||||
dk.data_ptr() as *mut f32,
|
||||
dv.data_ptr() as *mut f32,
|
||||
bh as i32,
|
||||
seq as i32,
|
||||
hd as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
(dq.to_dtype(dt), dk.to_dtype(dt), dv.to_dtype(dt))
|
||||
}
|
||||
|
||||
/// 4D axis-(1,2) transpose: `self`:[a,b,c,d] → [a,c,b,d],
|
||||
/// `out[i,k,j,l]=self[i,j,k,l]`. Lays out batched multi-head attention
|
||||
/// (`[B,S,nh,hd] <-> [B,nh,S,hd]`). Its own backward is the same op (swap b,c).
|
||||
|
||||
273
csrc/ops/flash_attention.cu
Normal file
273
csrc/ops/flash_attention.cu
Normal file
@@ -0,0 +1,273 @@
|
||||
// Hand-written fused flash-attention (Phase T14).
|
||||
//
|
||||
// The T10 composed SDPA path is 3 launches that MATERIALIZE the [bh,S,S] score
|
||||
// matrix: cublasSgemmStridedBatched (Q·Kᵀ) → causal-softmax kernel (writes the
|
||||
// whole probs) → cublasSgemmStridedBatched (P·V), and backward caches that whole
|
||||
// probs. flash-attention NEVER materializes N×N: a single fused kernel streams
|
||||
// over KV tiles with an ONLINE softmax (running max/sum + rescaled V accumulator),
|
||||
// so peak attention activation drops from O(S²) to O(S·hd) (= the output itself).
|
||||
//
|
||||
// Layout (matches the T10 op): Q/K/V/out are [bh, S, hd] row-major contiguous,
|
||||
// bh = batch·n_heads. The query's position within its sequence is the row index
|
||||
// within its [S,hd] block (so the flat row's qpos = (row % S) is automatic here —
|
||||
// we index per (bh, row)). CAUSAL: a query at position i attends to keys j ≤ i.
|
||||
// `scale` (= 1/sqrt(hd)) is folded into the logits before the max/exp.
|
||||
//
|
||||
// All F32, contiguous. (bf16 callers upcast Q/K/V → f32 on the Rust side and
|
||||
// downcast the f32 out, mirroring the composed path's fp32 softmax policy, so the
|
||||
// kernel only ever sees fp32.) Reduction helpers are inlined (self-contained file,
|
||||
// matching the csrc/ layout).
|
||||
//
|
||||
// Parallelisation: grid = bh*S, one block per query row; blockDim.x threads
|
||||
// cooperate. Forward keeps m (running max), l (running sum), acc[hd] (rescaled
|
||||
// V accumulator) in shared memory, streams KV in tiles of BK. Backward recomputes
|
||||
// scores from Q/K/V + the saved logsumexp L[bh,S] (NO cached probs), uses
|
||||
// D[i]=Σ dOᵢ·Oᵢ to collapse the softmax Jacobian, and atomicAdds dK/dV (which are
|
||||
// accumulated across query rows).
|
||||
|
||||
#include <math.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
__device__ __forceinline__ float fa_warp_sum(float v) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1)
|
||||
v += __shfl_down_sync(0xffffffff, v, off);
|
||||
return v;
|
||||
}
|
||||
__device__ __forceinline__ float fa_warp_max(float v) {
|
||||
#pragma unroll
|
||||
for (int off = 16; off > 0; off >>= 1)
|
||||
v = fmaxf(v, __shfl_down_sync(0xffffffff, v, off));
|
||||
return v;
|
||||
}
|
||||
__device__ __forceinline__ float fa_block_sum(float v) {
|
||||
__shared__ float sh[32];
|
||||
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
v = fa_warp_sum(v);
|
||||
if (lane == 0) sh[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : 0.0f;
|
||||
if (warp == 0) v = fa_warp_sum(v);
|
||||
__shared__ float bc;
|
||||
if (threadIdx.x == 0) bc = v;
|
||||
__syncthreads();
|
||||
return bc;
|
||||
}
|
||||
__device__ __forceinline__ float fa_block_max(float v) {
|
||||
__shared__ float sh[32];
|
||||
int lane = threadIdx.x & 31, warp = threadIdx.x >> 5;
|
||||
int nwarps = (blockDim.x + 31) >> 5;
|
||||
v = fa_warp_max(v);
|
||||
if (lane == 0) sh[warp] = v;
|
||||
__syncthreads();
|
||||
v = (threadIdx.x < nwarps) ? sh[threadIdx.x] : -INFINITY;
|
||||
if (warp == 0) v = fa_warp_max(v);
|
||||
__shared__ float bc;
|
||||
if (threadIdx.x == 0) bc = v;
|
||||
__syncthreads();
|
||||
return bc;
|
||||
}
|
||||
|
||||
#define FA_TILE 32 // KV tile width (columns streamed per step)
|
||||
|
||||
// One block per (bh-row, query-position). Computes out[bh, i, :] and L[bh, i] via
|
||||
// an online softmax that streams the keys in tiles of FA_TILE — the [S,S] score
|
||||
// row is never stored, only the per-tile partials flow through shared memory.
|
||||
__global__ void flash_attn_fwd_k(const float* Q, const float* K, const float* V,
|
||||
float* O, float* L, int seq, int hd, float scale) {
|
||||
int row = blockIdx.x; // global query row over bh*S
|
||||
int b = row / seq; // which (batch,head) block
|
||||
int i = row % seq; // query position within the sequence (causal limit)
|
||||
int t = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
|
||||
const float* q = Q + (size_t)row * hd;
|
||||
const float* kb = K + (size_t)b * seq * hd; // this block's keys [seq,hd]
|
||||
const float* vb = V + (size_t)b * seq * hd; // this block's values[seq,hd]
|
||||
|
||||
// Q row in shared memory (reused every tile); acc accumulator over hd.
|
||||
extern __shared__ float smem[];
|
||||
float* sq = smem; // [hd]
|
||||
float* acc = smem + hd; // [hd]
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
sq[d] = q[d];
|
||||
acc[d] = 0.0f;
|
||||
}
|
||||
__shared__ float m_run, l_run;
|
||||
if (t == 0) { m_run = -INFINITY; l_run = 0.0f; }
|
||||
__syncthreads();
|
||||
|
||||
int valid = i + 1; // causal: attend to keys [0, i]
|
||||
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
|
||||
int tile = min(FA_TILE, valid - j0);
|
||||
// Each thread computes whole logits for a strided subset of the tile's
|
||||
// columns: s = scale * (q · k_j). hd is small (≤128) so the per-thread
|
||||
// dot loop is cheap; this avoids a block-reduce per column.
|
||||
__shared__ float s_tile[FA_TILE];
|
||||
for (int c = t; c < tile; c += nthreads) {
|
||||
const float* kj = kb + (size_t)(j0 + c) * hd;
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < hd; ++d) dot += sq[d] * kj[d];
|
||||
s_tile[c] = dot * scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Tile max, then online rescale of (m, l, acc).
|
||||
float tmax = -INFINITY;
|
||||
for (int c = t; c < tile; c += nthreads) tmax = fmaxf(tmax, s_tile[c]);
|
||||
tmax = fa_block_max(tmax);
|
||||
|
||||
__shared__ float m_new, corr;
|
||||
if (t == 0) {
|
||||
float mn = fmaxf(m_run, tmax);
|
||||
corr = (m_run == -INFINITY) ? 0.0f : expf(m_run - mn); // rescale old state
|
||||
m_new = mn;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Rescale old accumulator + add this tile's p·V (p = exp(s - m_new)).
|
||||
// Each thread owns a strided subset of hd; loops over the tile columns.
|
||||
float lsum = 0.0f;
|
||||
for (int c = t; c < tile; c += nthreads) lsum += expf(s_tile[c] - m_new);
|
||||
lsum = fa_block_sum(lsum);
|
||||
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
float a = acc[d] * corr;
|
||||
for (int c = 0; c < tile; ++c) {
|
||||
float p = expf(s_tile[c] - m_new);
|
||||
a += p * vb[(size_t)(j0 + c) * hd + d];
|
||||
}
|
||||
acc[d] = a;
|
||||
}
|
||||
if (t == 0) {
|
||||
l_run = l_run * corr + lsum;
|
||||
m_run = m_new;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// out = acc / l ; L = m + log(l) (logsumexp, saved for backward).
|
||||
float inv = 1.0f / l_run;
|
||||
for (int d = t; d < hd; d += nthreads) O[(size_t)row * hd + d] = acc[d] * inv;
|
||||
if (t == 0) L[row] = m_run + logf(l_run);
|
||||
}
|
||||
|
||||
void launch_flash_attention_fwd_f32(const float* q, const float* k, const float* v,
|
||||
float* o, float* l, int bh, int seq, int hd,
|
||||
float scale, void* s) {
|
||||
int blk = hd < 1024 ? hd : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
size_t shmem = (size_t)2 * hd * sizeof(float); // sq[hd] + acc[hd]
|
||||
flash_attn_fwd_k<<<bh * seq, blk, shmem, (cudaStream_t)s>>>(q, k, v, o, l, seq, hd, scale);
|
||||
}
|
||||
|
||||
// Per-row D[i] = Σ_d dO[i,d] · O[i,d]. One block per row (bh*S rows). Used to
|
||||
// collapse the softmax Jacobian in backward (Σ_j P_ij dP_ij = dOᵢ·Oᵢ).
|
||||
__global__ void flash_attn_rowdot_k(const float* dO, const float* O, float* D, int hd) {
|
||||
int row = blockIdx.x;
|
||||
int t = threadIdx.x;
|
||||
const float* d = dO + (size_t)row * hd;
|
||||
const float* o = O + (size_t)row * hd;
|
||||
float v = 0.0f;
|
||||
for (int c = t; c < hd; c += blockDim.x) v += d[c] * o[c];
|
||||
v = fa_block_sum(v);
|
||||
if (t == 0) D[row] = v;
|
||||
}
|
||||
|
||||
// Backward: one block per query row i. Recomputes scores from Q/K/V + the saved
|
||||
// logsumexp L (NO cached probs), streams KV in tiles. dQ accumulates locally (this
|
||||
// row owns it). dK/dV are accumulated ACROSS query rows so they atomicAdd into the
|
||||
// shared global buffers (pre-zeroed by the caller).
|
||||
// p_ij = exp(Qᵢ·Kⱼ·scale - L[i]) ; dp_ij = dOᵢ·Vⱼ ;
|
||||
// ds_ij = p_ij·(dp_ij - D[i])·scale
|
||||
// dQᵢ += Σ_j ds_ij·Kⱼ ; dKⱼ += ds_ij·Qᵢ ; dVⱼ += p_ij·dOᵢ
|
||||
__global__ void flash_attn_bwd_k(const float* Q, const float* K, const float* V,
|
||||
const float* dO, const float* L, const float* D,
|
||||
float* dQ, float* dK, float* dV,
|
||||
int seq, int hd, float scale) {
|
||||
int row = blockIdx.x;
|
||||
int b = row / seq;
|
||||
int i = row % seq;
|
||||
int t = threadIdx.x;
|
||||
int nthreads = blockDim.x;
|
||||
|
||||
const float* q = Q + (size_t)row * hd;
|
||||
const float* doi = dO + (size_t)row * hd;
|
||||
const float* kb = K + (size_t)b * seq * hd;
|
||||
const float* vb = V + (size_t)b * seq * hd;
|
||||
float* dkb = dK + (size_t)b * seq * hd;
|
||||
float* dvb = dV + (size_t)b * seq * hd;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
float* sq = smem; // [hd] Qᵢ
|
||||
float* sdo = smem + hd; // [hd] dOᵢ
|
||||
float* dqa = smem + 2*hd; // [hd] dQᵢ accumulator
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
sq[d] = q[d];
|
||||
sdo[d] = doi[d];
|
||||
dqa[d] = 0.0f;
|
||||
}
|
||||
__shared__ float Li, Di;
|
||||
if (t == 0) { Li = L[row]; Di = D[row]; }
|
||||
__syncthreads();
|
||||
|
||||
int valid = i + 1;
|
||||
for (int j0 = 0; j0 < valid; j0 += FA_TILE) {
|
||||
int tile = min(FA_TILE, valid - j0);
|
||||
// Per-tile ds[c] (one per column), computed by the thread that owns column c.
|
||||
__shared__ float s_ds[FA_TILE];
|
||||
for (int c = t; c < tile; c += nthreads) {
|
||||
const float* kj = kb + (size_t)(j0 + c) * hd;
|
||||
const float* vj = vb + (size_t)(j0 + c) * hd;
|
||||
float sdot = 0.0f, dpdot = 0.0f;
|
||||
for (int d = 0; d < hd; ++d) {
|
||||
sdot += sq[d] * kj[d];
|
||||
dpdot += sdo[d] * vj[d];
|
||||
}
|
||||
float p = expf(sdot * scale - Li);
|
||||
float ds = p * (dpdot - Di) * scale;
|
||||
s_ds[c] = ds;
|
||||
// dV_j += p · dOᵢ ; dK_j += ds · Qᵢ (accumulated across rows → atomic)
|
||||
float* dvj = dvb + (size_t)(j0 + c) * hd;
|
||||
float* dkj = dkb + (size_t)(j0 + c) * hd;
|
||||
for (int d = 0; d < hd; ++d) {
|
||||
atomicAdd(&dvj[d], p * sdo[d]);
|
||||
atomicAdd(&dkj[d], ds * sq[d]);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// dQᵢ += Σ_c ds[c] · K_{j0+c} (this row owns dQ — no atomic).
|
||||
for (int d = t; d < hd; d += nthreads) {
|
||||
float a = 0.0f;
|
||||
for (int c = 0; c < tile; ++c)
|
||||
a += s_ds[c] * kb[(size_t)(j0 + c) * hd + d];
|
||||
dqa[d] += a;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
for (int d = t; d < hd; d += nthreads) dQ[(size_t)row * hd + d] = dqa[d];
|
||||
}
|
||||
|
||||
void launch_flash_attention_bwd_f32(const float* q, const float* k, const float* v,
|
||||
const float* d_o, const float* l, float* d_d,
|
||||
float* dq, float* dk, float* dv,
|
||||
int bh, int seq, int hd, float scale, void* s) {
|
||||
int blk = hd < 1024 ? hd : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
// d_d is the pre-computed D[i]=Σ dOᵢ·Oᵢ (the Rust wrapper runs rowdot first,
|
||||
// since it holds the forward O). dq/dk/dv are pre-zeroed by the caller.
|
||||
flash_attn_bwd_k<<<bh * seq, blk, (size_t)3 * hd * sizeof(float), (cudaStream_t)s>>>(
|
||||
q, k, v, d_o, l, d_d, dq, dk, dv, seq, hd, scale);
|
||||
}
|
||||
|
||||
// Standalone D = rowdot(dO, O) launcher (the Rust wrapper calls this before bwd).
|
||||
void launch_flash_attention_rowdot_f32(const float* d_o, const float* o, float* d_d,
|
||||
int rows, int hd, void* s) {
|
||||
int blk = hd < 1024 ? hd : 1024;
|
||||
if (blk < 32) blk = 32;
|
||||
flash_attn_rowdot_k<<<rows, blk, 0, (cudaStream_t)s>>>(d_o, o, d_d, hd);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
Reference in New Issue
Block a user