// NOTE(review): the two system header names below were not legible in the
// copy under review; they are restored as the headers this file needs for
// cudaStream_t and __nv_bfloat16 — confirm against version control.
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include "../common.cuh"

// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
//
// Algorithm: outer loop over Q tiles (BR rows), inner loop over K/V tiles (BC rows).
// Uses online softmax — no O(S^2) memory.
//
// Layout: Q [batch, num_q_heads, q_len, head_dim]
//         K [batch, num_kv_heads, kv_len, head_dim]
//         V [batch, num_kv_heads, kv_len, head_dim]
//         O [batch, num_q_heads, q_len, head_dim]
//
// Shared memory (BF16), sized here for head_dim = 128:
//   smem_q[BR][head_dim]  — 64 * 128 * 2 = 16 KB (loaded once per Q tile)
//   smem_kv[BC][head_dim] — 64 * 128 * 2 = 16 KB (alternates K and V)
//   Total: 32 KB (fits in default 48 KB shared memory)

#define BR 64
#define BC 64
#define THREADS_PER_BLOCK 128

// Computes one Q tile (up to BR rows) of attention output for one
// (batch, Q head) pair per thread block.
//
// Launch contract:
//   Grid:   (ceil(q_len / BR), batch * num_q_heads)
//   Block:  THREADS_PER_BLOCK threads. Every thread helps with the cooperative
//           shared-memory loads; thread t < q_tile_rows additionally owns Q
//           row t of the tile and keeps its accumulator in per-thread storage.
//   Shared: dynamic, (BR + BC) * head_dim BF16 elements.
//
// Preconditions (not checked here):
//   - head_dim <= 128 (size of the per-thread O_acc array).
//   - num_kv_heads <= num_q_heads, so heads_per_group is non-zero.
//
// Parameters:
//   Q, K, V   inputs in the layouts listed at the top of the file.
//   O         output, same layout as Q.
//   scale     multiplier applied to every Q·K dot product.
//   causal    non-zero masks key k for query row q when
//             k > q + (kv_len - q_len).
//
// A query row for which no key is visible is written as zeros.
__global__ void flash_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    int num_q_heads,
    int num_kv_heads,
    int q_len,
    int kv_len,
    int head_dim,
    float scale,
    int causal
) {
    // Grid: (ceil(q_len / BR), batch * num_q_heads)
    const int q_tile_idx = blockIdx.x;
    const int bh = blockIdx.y;
    const int batch_idx = bh / num_q_heads;
    const int q_head = bh % num_q_heads;

    // GQA: map Q head to KV head
    const int heads_per_group = num_q_heads / num_kv_heads;
    const int kv_head = q_head / heads_per_group;

    const int q_tile_start = q_tile_idx * BR;
    if (q_tile_start >= q_len) return;  // block-uniform, so no barrier is skipped by a subset of threads
    const int q_tile_rows = min(BR, q_len - q_tile_start);

    // Pointers to this batch/head's data (offsets computed in 64-bit).
    const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
    const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;

    const int tid = threadIdx.x;

    // Dynamic shared memory
    extern __shared__ __nv_bfloat16 smem[];
    __nv_bfloat16* smem_q = smem;                   // BR * head_dim elements
    __nv_bfloat16* smem_kv = smem + BR * head_dim;  // BC * head_dim elements

    const __nv_bfloat16 bf16_zero = __float2bfloat16(0.0f);

    // ---- Load Q tile into shared memory (cooperative) ----
    // The tile's rows are contiguous in both global and shared memory, so a
    // flat element index maps 1:1 and no per-element div/mod is needed.
    const __nv_bfloat16* Q_tile = Q_head + (long long)q_tile_start * head_dim;
    const int q_elems = q_tile_rows * head_dim;
    for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
        smem_q[i] = Q_tile[i];
    }
    // Zero-pad if q_tile_rows < BR
    for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
        smem_q[i] = bf16_zero;
    }
    __syncthreads();

    // Thread t (0 <= t < q_tile_rows) owns Q row t
    const bool owns_row = (tid < q_tile_rows);

    // Per-thread FP32 accumulators (head_dim up to 128)
    float O_acc[128];
    float m_val = -INFINITY;  // running row maximum
    float l_val = 0.0f;       // running softmax denominator
    if (owns_row) {
        for (int d = 0; d < head_dim; d++) {
            O_acc[d] = 0.0f;
        }
    }

    // kv_offset handles cached KV longer than Q (decode step)
    const int kv_offset = kv_len - q_len;
    const int num_kv_tiles = (kv_len + BC - 1) / BC;

    // Last key position visible to the last row of this Q tile (causal only).
    const int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
    // Last key position visible to this thread's own row (causal only).
    const int row_kv_limit = q_tile_start + tid + kv_offset;

    // ---- Inner loop over K/V tiles ----
    for (int j = 0; j < num_kv_tiles; j++) {
        const int kv_tile_start = j * BC;

        // Causal: once a tile lies entirely in the future for every row of
        // this Q tile, so do all later tiles. The test is block-uniform, so
        // leaving the loop here cannot split the block across a barrier.
        if (causal && kv_tile_start > max_allowed_kv) {
            break;
        }

        const int kv_tile_cols = min(BC, kv_len - kv_tile_start);
        const int kv_elems = kv_tile_cols * head_dim;
        const __nv_bfloat16* K_tile = K_head + (long long)kv_tile_start * head_dim;
        const __nv_bfloat16* V_tile = V_head + (long long)kv_tile_start * head_dim;

        // ---- Load K tile into smem_kv ----
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            smem_kv[i] = K_tile[i];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = bf16_zero;
        }
        __syncthreads();

        // ---- Compute S = Q @ K^T * scale, causal mask, online softmax ----
        float P[BC];
        if (owns_row) {
            const __nv_bfloat16* q_row = smem_q + tid * head_dim;
            float row_max = -INFINITY;
            for (int c = 0; c < kv_tile_cols; c++) {
                const __nv_bfloat16* k_row = smem_kv + c * head_dim;
                float dot = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    dot += __bfloat162float(q_row[d]) * __bfloat162float(k_row[d]);
                }
                float s = dot * scale;
                if (causal && (kv_tile_start + c) > row_kv_limit) {
                    s = -INFINITY;
                }
                P[c] = s;  // store score temporarily in P
                row_max = fmaxf(row_max, s);
            }

            if (row_max != -INFINITY) {
                // Online softmax: m_new, P = exp(S - m_new), l_new
                const float m_new = fmaxf(m_val, row_max);
                float psum = 0.0f;
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = expf(P[c] - m_new);
                    psum += P[c];
                }
                // Rescale previous accumulator
                const float correction = expf(m_val - m_new);
                l_val = correction * l_val + psum;
                for (int d = 0; d < head_dim; d++) {
                    O_acc[d] *= correction;
                }
                m_val = m_new;
            } else {
                // Every key of this tile is masked for this row. Folding it in
                // while m_val is still -INFINITY would evaluate
                // expf(-inf - (-inf)) = NaN and poison the row (reachable when
                // kv_len < q_len). A masked tile contributes nothing, so leave
                // the running state untouched — same guard as the sinks kernel.
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = 0.0f;
                }
            }
        }

        // Sync before overwriting smem_kv with V tile
        __syncthreads();

        // ---- Load V tile (reuse smem_kv) ----
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            smem_kv[i] = V_tile[i];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = bf16_zero;
        }
        __syncthreads();

        // ---- Accumulate O += P @ V_tile ----
        if (owns_row) {
            for (int c = 0; c < kv_tile_cols; c++) {
                const float p = P[c];
                if (p != 0.0f) {
                    const __nv_bfloat16* v_row = smem_kv + c * head_dim;
                    for (int d = 0; d < head_dim; d++) {
                        O_acc[d] += p * __bfloat162float(v_row[d]);
                    }
                }
            }
        }
        // Keep smem_kv intact until every row has finished reading V.
        __syncthreads();
    }

    // ---- Final normalize and write output (convert FP32 → BF16) ----
    if (owns_row) {
        // l_val == 0 means no key was visible for this row; emit zeros.
        const float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
        __nv_bfloat16* O_row = O_head + (long long)(q_tile_start + tid) * head_dim;
        for (int d = 0; d < head_dim; d++) {
            O_row[d] = __float2bfloat16(O_acc[d] * inv_l);
        }
    }
}

// Flash Attention 2 forward with gpt-oss attention sinks + optional sliding window.
// Identical to flash_attention_bf16_kernel, plus:
//   - sinks: [num_q_heads] BF16 — a per-head extra softmax logit (no value),
//     folded into the denominator after the K/V tiles (exactly as the decode
//     sink kernel does).
//   - window_size > 0: sliding-window mask.
//     Query at global position p attends to keys k with
//     p - window_size < k <= p (matches HF gpt-oss).
//
// Launch contract:
//   Grid:   (ceil(q_len / BR), batch * num_q_heads)
//   Block:  THREADS_PER_BLOCK threads; thread t < q_tile_rows owns Q row t of
//           the tile, all threads cooperate on the shared-memory loads.
//   Shared: dynamic, (BR + BC) * head_dim BF16 elements.
//
// Preconditions (not checked here): head_dim <= 128 (size of O_acc);
// num_kv_heads <= num_q_heads so heads_per_group is non-zero.
//
// Parameters beyond the tensors:
//   sinks        [num_q_heads] BF16 logits, or NULL to disable the sink.
//   scale        multiplier applied to every Q·K dot product.
//   causal       non-zero masks keys after the query's global position.
//   window_size  > 0 enables the sliding-window mask; <= 0 disables it.
__global__ void flash_attention_sinks_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    const __nv_bfloat16* __restrict__ sinks,  // [num_q_heads] or NULL
    int num_q_heads,
    int num_kv_heads,
    int q_len,
    int kv_len,
    int head_dim,
    float scale,
    int causal,
    int window_size
) {
    const int q_tile_idx = blockIdx.x;
    const int bh = blockIdx.y;
    const int batch_idx = bh / num_q_heads;
    const int q_head = bh % num_q_heads;

    // GQA: several Q heads share one KV head.
    const int heads_per_group = num_q_heads / num_kv_heads;
    const int kv_head = q_head / heads_per_group;

    const int q_tile_start = q_tile_idx * BR;
    if (q_tile_start >= q_len) return;  // block-uniform early exit
    const int q_tile_rows = min(BR, q_len - q_tile_start);

    const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
    const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;

    const int tid = threadIdx.x;

    extern __shared__ __nv_bfloat16 smem[];
    __nv_bfloat16* smem_q = smem;                   // BR * head_dim elements
    __nv_bfloat16* smem_kv = smem + BR * head_dim;  // BC * head_dim elements

    const __nv_bfloat16 bf16_zero = __float2bfloat16(0.0f);

    // Cooperative Q-tile load. Rows are contiguous in global and shared
    // memory, so the flat element index is used directly.
    const __nv_bfloat16* Q_tile = Q_head + (long long)q_tile_start * head_dim;
    const int q_elems = q_tile_rows * head_dim;
    for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
        smem_q[i] = Q_tile[i];
    }
    for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
        smem_q[i] = bf16_zero;  // zero-pad a partial last tile
    }
    __syncthreads();

    const bool owns_row = (tid < q_tile_rows);

    float O_acc[128];
    float m_val = -INFINITY;  // running row maximum
    float l_val = 0.0f;       // running softmax denominator
    if (owns_row) {
        for (int d = 0; d < head_dim; d++) O_acc[d] = 0.0f;
    }

    // kv_offset shifts query rows to global positions when the KV sequence is
    // longer than the Q sequence.
    const int kv_offset = kv_len - q_len;
    const int num_kv_tiles = (kv_len + BC - 1) / BC;

    // Global positions of the first and last query row of this tile, and of
    // this thread's own row.
    const int min_q_pos = q_tile_start + kv_offset;
    const int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
    const int q_pos = q_tile_start + tid + kv_offset;

    for (int j = 0; j < num_kv_tiles; j++) {
        const int kv_tile_start = j * BC;

        // Causal: a tile that is entirely in the future for the last row is in
        // the future for every row, and so is every later tile. Block-uniform.
        if (causal && kv_tile_start > max_allowed_kv) break;

        const int kv_tile_cols = min(BC, kv_len - kv_tile_start);

        // Sliding window: if the tile's last key is already outside the window
        // of the FIRST query row, it is outside the window of every row in
        // this Q tile (later rows have larger positions). Such a tile would be
        // fully masked for all rows and leave their state unchanged, so skip
        // the loads and barriers. Block-uniform, taken before any barrier.
        if (window_size > 0 &&
            kv_tile_start + kv_tile_cols - 1 <= min_q_pos - window_size) {
            continue;
        }

        const int kv_elems = kv_tile_cols * head_dim;
        const __nv_bfloat16* K_tile = K_head + (long long)kv_tile_start * head_dim;
        const __nv_bfloat16* V_tile = V_head + (long long)kv_tile_start * head_dim;

        // ---- Load K tile into smem_kv ----
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            smem_kv[i] = K_tile[i];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = bf16_zero;
        }
        __syncthreads();

        float P[BC];
        if (owns_row) {
            const __nv_bfloat16* q_row = smem_q + tid * head_dim;
            float row_max = -INFINITY;
            for (int c = 0; c < kv_tile_cols; c++) {
                const __nv_bfloat16* k_row = smem_kv + c * head_dim;
                float dot = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    dot += __bfloat162float(q_row[d]) * __bfloat162float(k_row[d]);
                }
                float s = dot * scale;
                const int kv_pos = kv_tile_start + c;
                if (causal && kv_pos > q_pos) {
                    s = -INFINITY;
                }
                // Sliding window: drop keys older than the window.
                if (window_size > 0 && kv_pos <= q_pos - window_size) {
                    s = -INFINITY;
                }
                P[c] = s;
                row_max = fmaxf(row_max, s);
            }

            // A fully-masked KV tile (every key causal- or window-masked) has
            // row_max == -INFINITY. Folding it in computes expf(-inf - (-inf))
            // = NaN, and a later valid tile's 0*NaN correction then poisons the
            // whole row. The block-level skips above only remove tiles that are
            // masked for EVERY row; an individual row can still see a fully
            // masked tile here. A masked tile contributes nothing to the
            // softmax — skip it.
            if (row_max != -INFINITY) {
                const float m_new = fmaxf(m_val, row_max);
                float psum = 0.0f;
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = expf(P[c] - m_new);
                    psum += P[c];
                }
                const float correction = expf(m_val - m_new);
                l_val = correction * l_val + psum;
                for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
                m_val = m_new;
            } else {
                for (int c = 0; c < kv_tile_cols; c++) P[c] = 0.0f;
            }
        }
        // All rows must be done with K before smem_kv is reused for V.
        __syncthreads();

        // ---- Load V tile (reuse smem_kv) ----
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            smem_kv[i] = V_tile[i];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = bf16_zero;
        }
        __syncthreads();

        // ---- Accumulate O += P @ V_tile ----
        if (owns_row) {
            for (int c = 0; c < kv_tile_cols; c++) {
                const float p = P[c];
                if (p != 0.0f) {
                    const __nv_bfloat16* v_row = smem_kv + c * head_dim;
                    for (int d = 0; d < head_dim; d++) {
                        O_acc[d] += p * __bfloat162float(v_row[d]);
                    }
                }
            }
        }
        // Keep smem_kv intact until every row has finished reading V.
        __syncthreads();
    }

    // Fold in the per-head attention sink (extra logit, no value contribution).
    if (owns_row && sinks != nullptr) {
        const float sink_logit = __bfloat162float(sinks[q_head]);
        const float m_new = fmaxf(m_val, sink_logit);
        const float correction = expf(m_val - m_new);
        l_val = correction * l_val + expf(sink_logit - m_new);
        for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
        m_val = m_new;
    }

    // Normalize and write the row (FP32 → BF16). l_val == 0 means nothing was
    // visible and no sink was applied; emit zeros.
    if (owns_row) {
        const float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
        __nv_bfloat16* O_row = O_head + (long long)(q_tile_start + tid) * head_dim;
        for (int d = 0; d < head_dim; d++) {
            O_row[d] = __float2bfloat16(O_acc[d] * inv_l);
        }
    }
}

// ============================================================
// Decode Attention kernel: optimized for Q_len=1 (single-token decode).
// Parallelizes across KV sequence dimension instead of Q rows.
//
// Grid: (batch * num_q_heads, 1) — one block per Q head
// Block: 256 threads — each thread handles ceil(kv_len / 256) KV positions
// Uses online softmax reduction across threads.
// ============================================================

#define DECODE_THREADS 256
#define HEAD_DIM_MAX 128

// Single-query attention: one thread block computes the output vector for one
// (batch, Q head) pair.
//
// Launch contract:
//   Grid:   batch * num_q_heads blocks (1-D).
//   Block:  exactly DECODE_THREADS threads — the per-warp reduction buffers
//           below are sized for DECODE_THREADS / 32 warps.
//   Shared: static only.
//
// Preconditions (not checked here): head_dim <= HEAD_DIM_MAX;
// num_kv_heads <= num_q_heads so heads_per_group is non-zero.
//
// Layouts: Q and O are [batch, num_q_heads, 1, head_dim];
//          K and V are [batch, num_kv_heads, kv_len, head_dim].
// No mask is applied: every one of the kv_len keys is attended to.
// With kv_len == 0 the output is all zeros.
__global__ void decode_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    int num_q_heads,
    int num_kv_heads,
    int kv_len,
    int head_dim,
    float scale
) {
    constexpr int kNumWarps = DECODE_THREADS / 32;  // 8 warps

    const int bh = blockIdx.x;
    const int batch_idx = bh / num_q_heads;
    const int q_head = bh % num_q_heads;

    // GQA mapping
    const int heads_per_group = num_q_heads / num_kv_heads;
    const int kv_head = q_head / heads_per_group;

    const int tid = threadIdx.x;
    const int lane = tid & 31;
    const int warp_id = tid >> 5;

    // Pointers to this batch/head's data
    // Q: [batch, num_q_heads, 1, head_dim]
    const __nv_bfloat16* Q_ptr = Q + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
    // K/V: [batch, num_kv_heads, kv_len, head_dim]
    const __nv_bfloat16* K_base = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_base = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_ptr = O + ((long long)batch_idx * num_q_heads + q_head) * head_dim;

    // Per-thread FP32 copy of the Q vector (head_dim <= HEAD_DIM_MAX).
    float q_reg[HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }

    // Per-thread online-softmax state over the positions this thread visits:
    // tid, tid + DECODE_THREADS, tid + 2 * DECODE_THREADS, ...
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        local_O[d] = 0.0f;
    }

    for (int pos = tid; pos < kv_len; pos += DECODE_THREADS) {
        // Compute dot(Q, K[pos]) * scale
        const __nv_bfloat16* K_pos = K_base + (long long)pos * head_dim;
        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        }
        const float s = dot * scale;

        // Online softmax update
        const float new_max = fmaxf(local_max, s);
        const float correction = expf(local_max - new_max);
        const float p = expf(s - new_max);

        // Rescale running sum and O
        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] = local_O[d] * correction;
        }

        // Accumulate V[pos] weighted by p
        const __nv_bfloat16* V_pos = V_base + (long long)pos * head_dim;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }

        local_max = new_max;
    }

    // --- Block-level online softmax reduction ---
    // Combine (local_max, local_sum, local_O) across all threads:
    // reduce max, rescale each thread's partials to it, then reduce sum and O.

    // One slot per warp. Every slot that is read below is written first by
    // lane 0 of its warp, so no zero-initialisation is required.
    __shared__ float smem_max[kNumWarps];
    __shared__ float smem_sum[kNumWarps];
    __shared__ float smem_O_warp[kNumWarps][HEAD_DIM_MAX];

    // Step 1: Block-wide max reduction
    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();

    if (tid == 0) {
        float block_max = smem_max[0];
        for (int i = 1; i < kNumWarps; i++)
            block_max = fmaxf(block_max, smem_max[i]);
        smem_max[0] = block_max;  // broadcast slot
    }
    __syncthreads();
    const float global_max = smem_max[0];

    // Step 2: Each thread rescales its local_sum and local_O with global_max.
    // A thread that saw no position keeps local_max == -INFINITY; give it a
    // zero weight instead of evaluating expf(-inf - global_max).
    const float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) {
        local_O[d] *= rescale;
    }

    // Step 3: Reduce sum across block
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();

    if (tid == 0) {
        float block_sum = 0.0f;
        for (int i = 0; i < kNumWarps; i++)
            block_sum += smem_sum[i];
        smem_sum[0] = block_sum;  // broadcast slot
    }
    __syncthreads();
    const float global_sum = smem_sum[0];

    // Step 4: Reduce O across block, dim by dim. Store one partial per warp
    // and sum in warp-id order; atomicAdd made greedy decode nondeterministic
    // when logits were close (same fix pattern as paged_attention.cu / gemv.cu).
    const float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;

    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O_warp[warp_id][d] = val;
    }
    __syncthreads();

    // Thread 0..head_dim-1 write final output
    for (int d = tid; d < head_dim; d += DECODE_THREADS) {
        float out = 0.0f;
        for (int i = 0; i < kNumWarps; i++)
            out += smem_O_warp[i][d];
        O_ptr[d] = __float2bfloat16(out * inv_sum);
    }
}

extern "C" {

// Launches flash_attention_bf16_kernel on `stream` (a cudaStream_t passed as
// void*). Q/K/V/O are device pointers to BF16 data in the layouts documented
// at the top of this file.
//
// Preconditions: head_dim <= 128; batch * num_q_heads must fit in grid.y
// (at most 65535 blocks). Returns without launching when there are no query
// rows to compute.
void launch_flash_attention_bf16(
    const void* Q,
    const void* K,
    const void* V,
    void* O,
    int batch,
    int num_q_heads,
    int num_kv_heads,
    int q_len,
    int kv_len,
    int head_dim,
    float scale,
    int causal,
    void* stream
) {
    // No output rows: a zero-sized grid would be rejected as an invalid
    // launch configuration, so treat empty work as a no-op.
    if (batch <= 0 || num_q_heads <= 0 || q_len <= 0) return;

    const int q_tiles = (q_len + BR - 1) / BR;
    dim3 grid(q_tiles, batch * num_q_heads);
    const int block = THREADS_PER_BLOCK;

    // Shared memory: smem_q[BR * head_dim] + smem_kv[BC * head_dim], all BF16
    const int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);

    flash_attention_bf16_kernel<<<grid, block, smem_bytes, static_cast<cudaStream_t>(stream)>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        num_q_heads, num_kv_heads,
        q_len, kv_len, head_dim,
        scale, causal
    );
    CUDA_CHECK_LAST_ERROR();
}

// Launches flash_attention_sinks_bf16_kernel on `stream`. Same contract as
// launch_flash_attention_bf16, plus:
//   sinks        device pointer to [num_q_heads] BF16 logits, or NULL.
//   window_size  > 0 enables the sliding-window mask.
void launch_flash_attention_sinks_bf16(
    const void* Q,
    const void* K,
    const void* V,
    void* O,
    const void* sinks,
    int batch,
    int num_q_heads,
    int num_kv_heads,
    int q_len,
    int kv_len,
    int head_dim,
    float scale,
    int causal,
    int window_size,
    void* stream
) {
    // Empty work is a no-op rather than an invalid zero-sized launch.
    if (batch <= 0 || num_q_heads <= 0 || q_len <= 0) return;

    const int q_tiles = (q_len + BR - 1) / BR;
    dim3 grid(q_tiles, batch * num_q_heads);
    const int block = THREADS_PER_BLOCK;
    const int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);

    flash_attention_sinks_bf16_kernel<<<grid, block, smem_bytes, static_cast<cudaStream_t>(stream)>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        (const __nv_bfloat16*)sinks,
        num_q_heads, num_kv_heads,
        q_len, kv_len, head_dim,
        scale, causal, window_size
    );
    CUDA_CHECK_LAST_ERROR();
}

// Launches decode_attention_bf16_kernel on `stream`: one block of
// DECODE_THREADS threads per (batch, Q head), no dynamic shared memory.
//
// `causal` is accepted for signature parity with the other launchers but is
// not forwarded: the decode kernel takes no mask parameter.
// Precondition: head_dim <= HEAD_DIM_MAX.
void launch_decode_attention_bf16(
    const void* Q,
    const void* K,
    const void* V,
    void* O,
    int batch,
    int num_q_heads,
    int num_kv_heads,
    int kv_len,
    int head_dim,
    float scale,
    int causal,
    void* stream
) {
    (void)causal;

    // Empty work is a no-op rather than an invalid zero-sized launch.
    if (batch <= 0 || num_q_heads <= 0) return;

    const int grid = batch * num_q_heads;
    const int block = DECODE_THREADS;

    decode_attention_bf16_kernel<<<grid, block, 0, static_cast<cudaStream_t>(stream)>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        num_q_heads, num_kv_heads,
        kv_len, head_dim,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}

}