Three CUDA bugs from the review after 5b350ee/cfbd64d that those commits missed:
- flash_attention.cu: decode_attention_bf16_kernel used atomicAdd to merge per-warp partials into smem_O — the same nondeterminism pattern that 5b350ee already fixed in paged_attention.cu and gemv.cu. This kernel is on the legacy forward_gpu_cache path plus the speculative bench baseline, so verify/decode parity depended on it. Replace it with smem_O_warp[32][HEAD_DIM_MAX] partials reduced in fixed warp-id order.
- causal_mask.cu computed the flat address as `batch_idx * rows * cols + row * cols + col` in int; batch=128 heads=28 seq=32768 already overflows int32. Promote the index to long long.
- quantization/dequant_fp8.cu had `int total = num_experts * rows * cols` and `int expert_stride = rows * cols`; 32 experts × 8k × 8k overflows. Same fix pattern as the MoE dense kernels in cfbd64d — 64-bit total / idx / expert_stride, and grid computed in long long.
617 lines
22 KiB
Plaintext
617 lines
22 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <float.h>
|
|
#include "../common.cuh"
|
|
|
|
// Flash Attention 2 forward kernel for BF16 with FP32 accumulation.
//
// Algorithm: outer loop over Q tiles (BR rows), inner loop over K/V tiles (BC rows).
// Uses online softmax — no O(S^2) memory.
//
// Layout: Q [batch, num_q_heads, q_len, head_dim]
//         K [batch, num_kv_heads, kv_len, head_dim]
//         V [batch, num_kv_heads, kv_len, head_dim]
//         O [batch, num_q_heads, q_len, head_dim]
//
// Shared memory (BF16, dynamic; the launcher passes (BR + BC) * head_dim elements):
//   smem_q[BR][head_dim]  — 64 * 128 * 2 = 16 KB at head_dim = 128 (loaded once per Q tile)
//   smem_kv[BC][head_dim] — 64 * 128 * 2 = 16 KB at head_dim = 128 (alternates K and V)
//   Total: 32 KB at head_dim = 128 (fits in the default 48 KB shared-memory budget)
//
// Precondition: head_dim <= 128 — each row-owning thread keeps a fixed
// 128-float O_acc accumulator in registers/local memory.

#define BR 64
#define BC 64
#define THREADS_PER_BLOCK 128

// Thread t (t < q_tile_rows) owns Q row t of the tile, so a block with fewer
// than BR threads would silently leave the trailing rows uncomputed.
static_assert(THREADS_PER_BLOCK >= BR,
              "flash attention needs at least one thread per Q-tile row");
static_assert(BR > 0 && BC > 0, "tile sizes must be positive");
|
|
|
|
// Flash Attention 2 forward (BF16 in/out, FP32 accumulation, online softmax).
//
// Grid:  (ceil(q_len / BR), batch * num_q_heads); block: THREADS_PER_BLOCK threads.
// Dynamic shared memory: (BR + BC) * head_dim BF16 elements (Q tile + K/V tile).
// Thread t < q_tile_rows owns Q row (q_tile_start + t); the other threads only
// help with the cooperative tile loads.
// Requires head_dim <= 128 (size of the per-thread O_acc array).
// causal != 0 masks key kv_pos for query q_pos when kv_pos > q_pos + (kv_len - q_len).
__global__ void flash_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal
) {
    // Grid: (ceil(q_len / BR), batch * num_q_heads)
    int q_tile_idx = blockIdx.x;
    int bh = blockIdx.y;
    int batch_idx = bh / num_q_heads;
    int q_head = bh % num_q_heads;

    // GQA: map Q head to KV head
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    int q_tile_start = q_tile_idx * BR;
    if (q_tile_start >= q_len) return;  // block-uniform, so no barrier is skipped divergently
    int q_tile_rows = min(BR, q_len - q_tile_start);

    // Pointers to this batch/head's data (64-bit offset math).
    const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
    const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;

    int tid = threadIdx.x;

    // Dynamic shared memory
    extern __shared__ __nv_bfloat16 smem[];
    __nv_bfloat16* smem_q = smem;                   // BR * head_dim elements
    __nv_bfloat16* smem_kv = smem + BR * head_dim;  // BC * head_dim elements

    // ---- Load Q tile into shared memory (cooperative) ----
    int q_elems = q_tile_rows * head_dim;
    for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
        int row = i / head_dim;
        int col = i % head_dim;
        smem_q[row * head_dim + col] = Q_head[(q_tile_start + row) * head_dim + col];
    }
    // Zero-pad if q_tile_rows < BR
    for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
        smem_q[i] = __float2bfloat16(0.0f);
    }
    __syncthreads();

    // Thread t (0 <= t < q_tile_rows) owns Q row t
    bool owns_row = (tid < q_tile_rows);

    // Per-thread FP32 accumulators (head_dim up to 128)
    float O_acc[128];
    float m_val = -INFINITY;
    float l_val = 0.0f;
    if (owns_row) {
        for (int d = 0; d < head_dim; d++) {
            O_acc[d] = 0.0f;
        }
    }

    // kv_offset handles cached KV longer than Q (decode step)
    int kv_offset = kv_len - q_len;
    int num_kv_tiles = (kv_len + BC - 1) / BC;

    // ---- Inner loop over K/V tiles ----
    for (int j = 0; j < num_kv_tiles; j++) {
        int kv_tile_start = j * BC;
        int kv_tile_cols = min(BC, kv_len - kv_tile_start);

        // Causal: skip entire tile if all K positions are in the future for
        // every row of this Q tile (block-uniform, so the barriers below stay matched).
        if (causal) {
            int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
            if (kv_tile_start > max_allowed_kv) {
                continue;
            }
        }

        // ---- Load K tile into smem_kv ----
        int kv_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = K_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        // ---- Compute S = Q @ K^T * scale, causal mask, online softmax ----
        float P[BC];

        if (owns_row) {
            float row_max = -INFINITY;
            for (int c = 0; c < kv_tile_cols; c++) {
                float dot = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    dot += __bfloat162float(smem_q[tid * head_dim + d])
                         * __bfloat162float(smem_kv[c * head_dim + d]);
                }
                float s = dot * scale;

                if (causal) {
                    int q_pos = q_tile_start + tid;
                    int kv_pos = kv_tile_start + c;
                    if (kv_pos > q_pos + kv_offset) {
                        s = -INFINITY;
                    }
                }

                P[c] = s;  // store score temporarily in P
                row_max = fmaxf(row_max, s);
            }

            // A tile in which every key is masked for this row has
            // row_max == -INFINITY. If no earlier tile was visible either
            // (m_val == -INFINITY, e.g. q_len > kv_len under causal masking),
            // folding it in computes expf(-inf - (-inf)) = NaN and poisons the
            // row. A fully-masked tile contributes nothing to the softmax, so
            // skip the update and zero its probabilities instead.
            if (row_max != -INFINITY) {
                // Online softmax: m_new, P = exp(S - m_new), l_new
                float m_new = fmaxf(m_val, row_max);

                float psum = 0.0f;
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = expf(P[c] - m_new);
                    psum += P[c];
                }

                // Rescale previous accumulator
                float correction = expf(m_val - m_new);
                l_val = correction * l_val + psum;
                for (int d = 0; d < head_dim; d++) {
                    O_acc[d] *= correction;
                }

                m_val = m_new;
            } else {
                for (int c = 0; c < kv_tile_cols; c++) P[c] = 0.0f;
            }
        }

        // Sync before overwriting smem_kv with V tile
        __syncthreads();

        // ---- Load V tile (reuse smem_kv) ----
        int v_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < v_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = V_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = v_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        // ---- Accumulate O += P @ V_tile ----
        if (owns_row) {
            for (int c = 0; c < kv_tile_cols; c++) {
                float p = P[c];
                if (p != 0.0f) {
                    for (int d = 0; d < head_dim; d++) {
                        O_acc[d] += p * __bfloat162float(smem_kv[c * head_dim + d]);
                    }
                }
            }
        }

        // Sync before the next iteration overwrites smem_kv with the next K tile
        __syncthreads();
    }

    // ---- Final normalize and write output (convert FP32 → BF16) ----
    // Rows that saw no visible key keep l_val == 0 and are written as zeros.
    if (owns_row) {
        float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
        int global_row = q_tile_start + tid;
        for (int d = 0; d < head_dim; d++) {
            O_head[global_row * head_dim + d] = __float2bfloat16(O_acc[d] * inv_l);
        }
    }
}
|
|
|
|
// Flash Attention 2 forward with gpt-oss attention sinks + optional sliding window.
// Identical to flash_attention_bf16_kernel, plus:
//   - sinks: [num_q_heads] BF16 — a per-head extra softmax logit (no value),
//     folded into the denominator after the K/V tiles (exactly as the decode
//     sink kernel does). May be NULL.
//   - window_size > 0: sliding-window mask. Query at global position p attends
//     to keys k with p - window_size < k <= p (matches HF gpt-oss).
//
// Grid:  (ceil(q_len / BR), batch * num_q_heads); block: THREADS_PER_BLOCK threads.
// Dynamic shared memory: (BR + BC) * head_dim BF16 elements.
// Requires head_dim <= 128 (size of the per-thread O_acc array).
__global__ void flash_attention_sinks_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    const __nv_bfloat16* __restrict__ sinks,  // [num_q_heads] or NULL
    int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal, int window_size
) {
    int q_tile_idx = blockIdx.x;
    int bh = blockIdx.y;
    int batch_idx = bh / num_q_heads;
    int q_head = bh % num_q_heads;

    // GQA: map Q head to KV head
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    int q_tile_start = q_tile_idx * BR;
    if (q_tile_start >= q_len) return;  // block-uniform
    int q_tile_rows = min(BR, q_len - q_tile_start);

    const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
    const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;

    int tid = threadIdx.x;

    extern __shared__ __nv_bfloat16 smem[];
    __nv_bfloat16* smem_q = smem;                   // BR * head_dim elements
    __nv_bfloat16* smem_kv = smem + BR * head_dim;  // BC * head_dim elements

    // ---- Load Q tile (cooperative), zero-padding rows past q_len ----
    int q_elems = q_tile_rows * head_dim;
    for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
        int row = i / head_dim;
        int col = i % head_dim;
        smem_q[row * head_dim + col] = Q_head[(q_tile_start + row) * head_dim + col];
    }
    for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
        smem_q[i] = __float2bfloat16(0.0f);
    }
    __syncthreads();

    bool owns_row = (tid < q_tile_rows);

    float O_acc[128];
    float m_val = -INFINITY;
    float l_val = 0.0f;
    if (owns_row) {
        for (int d = 0; d < head_dim; d++) O_acc[d] = 0.0f;
    }

    int kv_offset = kv_len - q_len;
    int num_kv_tiles = (kv_len + BC - 1) / BC;

    for (int j = 0; j < num_kv_tiles; j++) {
        int kv_tile_start = j * BC;
        int kv_tile_cols = min(BC, kv_len - kv_tile_start);

        // Both skips below depend only on the Q tile and j, so every thread of
        // the block takes the same path and the barriers stay matched.

        // Causal: every key of this tile is in the future of every Q row.
        if (causal) {
            int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
            if (kv_tile_start > max_allowed_kv) continue;
        }

        // Sliding window: every key of this tile is older than even the
        // earliest row's window (row 0 has the earliest window start).
        // Such a tile would be fully masked for every row and contribute
        // nothing, so skip loading and scoring it.
        if (window_size > 0) {
            int min_allowed_kv = q_tile_start + kv_offset - window_size + 1;
            if (kv_tile_start + kv_tile_cols - 1 < min_allowed_kv) continue;
        }

        // ---- Load K tile ----
        int kv_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = K_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        float P[BC];

        if (owns_row) {
            float row_max = -INFINITY;
            int q_pos = q_tile_start + tid + kv_offset;  // global query position
            for (int c = 0; c < kv_tile_cols; c++) {
                float dot = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    dot += __bfloat162float(smem_q[tid * head_dim + d])
                         * __bfloat162float(smem_kv[c * head_dim + d]);
                }
                float s = dot * scale;

                int kv_pos = kv_tile_start + c;
                if (causal && kv_pos > q_pos) {
                    s = -INFINITY;
                }
                // Sliding window: drop keys older than the window.
                if (window_size > 0 && kv_pos <= q_pos - window_size) {
                    s = -INFINITY;
                }

                P[c] = s;
                row_max = fmaxf(row_max, s);
            }

            // A KV tile that is fully masked for this row (every key causal- or
            // window-masked) has row_max == -INFINITY. Folding it in computes
            // expf(-inf - (-inf)) = NaN when no earlier tile was visible, and a
            // later valid tile's 0*NaN correction then poisons the whole row.
            // The tile-level skips above only drop tiles masked for *every* row
            // of the Q tile; later rows, whose windows start further right, can
            // still see a tile that is fully out of their own window.
            // A masked tile contributes nothing to the softmax — skip it.
            if (row_max != -INFINITY) {
                float m_new = fmaxf(m_val, row_max);
                float psum = 0.0f;
                for (int c = 0; c < kv_tile_cols; c++) {
                    P[c] = expf(P[c] - m_new);
                    psum += P[c];
                }
                float correction = expf(m_val - m_new);
                l_val = correction * l_val + psum;
                for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
                m_val = m_new;
            } else {
                for (int c = 0; c < kv_tile_cols; c++) P[c] = 0.0f;
            }
        }

        // Sync before overwriting smem_kv with the V tile
        __syncthreads();

        // ---- Load V tile (reuse smem_kv) ----
        int v_elems = kv_tile_cols * head_dim;
        for (int i = tid; i < v_elems; i += THREADS_PER_BLOCK) {
            int row = i / head_dim;
            int col = i % head_dim;
            smem_kv[row * head_dim + col] = V_head[(kv_tile_start + row) * head_dim + col];
        }
        for (int i = v_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
            smem_kv[i] = __float2bfloat16(0.0f);
        }
        __syncthreads();

        // ---- Accumulate O += P @ V_tile ----
        if (owns_row) {
            for (int c = 0; c < kv_tile_cols; c++) {
                float p = P[c];
                if (p != 0.0f) {
                    for (int d = 0; d < head_dim; d++) {
                        O_acc[d] += p * __bfloat162float(smem_kv[c * head_dim + d]);
                    }
                }
            }
        }
        __syncthreads();
    }

    // Fold in the per-head attention sink (extra logit, no value contribution).
    if (owns_row && sinks != nullptr) {
        float sink_logit = __bfloat162float(sinks[q_head]);
        float m_new = fmaxf(m_val, sink_logit);
        float correction = expf(m_val - m_new);
        l_val = correction * l_val + expf(sink_logit - m_new);
        for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
        m_val = m_new;
    }

    // ---- Final normalize and write output (FP32 → BF16) ----
    if (owns_row) {
        float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
        int global_row = q_tile_start + tid;
        for (int d = 0; d < head_dim; d++) {
            O_head[global_row * head_dim + d] = __float2bfloat16(O_acc[d] * inv_l);
        }
    }
}
|
|
|
|
// ============================================================
// Decode Attention kernel: optimized for Q_len=1 (single-token decode).
// Parallelizes across the KV sequence dimension instead of Q rows.
//
// Grid:  (batch * num_q_heads, 1) — one block per Q head
// Block: DECODE_THREADS (256) threads — thread t handles KV positions
//        t, t + 256, t + 512, ... i.e. at most ceil(kv_len / 256) positions
// Uses an online-softmax reduction across threads.
// Precondition: head_dim <= HEAD_DIM_MAX (per-thread q_reg / local_O arrays).
// ============================================================

#define DECODE_THREADS 256
#define HEAD_DIM_MAX 128

// The decode block reduction shuffles with a full-warp mask (0xffffffff) and
// keeps one partial per warp in shared memory, so the block must consist of
// whole warps and stay within the CUDA block-size limit (<= 32 warps).
static_assert(DECODE_THREADS % 32 == 0, "DECODE_THREADS must be a whole number of warps");
static_assert(DECODE_THREADS > 0 && DECODE_THREADS <= 1024,
              "DECODE_THREADS must be a valid CUDA block size");
|
|
|
|
// Single-token (q_len == 1) attention: one block per (batch, q_head).
//
// Grid:  batch * num_q_heads blocks (1-D); block: DECODE_THREADS threads.
// Static shared memory: one max, one sum and one head_dim-wide O partial per
// warp (DECODE_THREADS / 32 warps).
// Requires head_dim <= HEAD_DIM_MAX (q_reg / local_O are fixed-size arrays).
// All cross-thread reductions use warp shuffles plus a fixed warp-id summation
// order (no atomics), so the output is bitwise deterministic run to run.
__global__ void decode_attention_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K,
    const __nv_bfloat16* __restrict__ V,
    __nv_bfloat16* __restrict__ O,
    int num_q_heads, int num_kv_heads,
    int kv_len, int head_dim,
    float scale
) {
    static_assert(DECODE_THREADS % 32 == 0, "decode reduction assumes full warps");
    constexpr int kNumWarps = DECODE_THREADS / 32;  // 8 warps at 256 threads

    int bh = blockIdx.x;
    int batch_idx = bh / num_q_heads;
    int q_head = bh % num_q_heads;

    // GQA mapping
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    int tid = threadIdx.x;
    int lane = tid & 31;
    int warp_id = tid >> 5;

    // Q/O: [batch, num_q_heads, 1, head_dim]
    // K/V: [batch, num_kv_heads, kv_len, head_dim]
    const __nv_bfloat16* Q_ptr = Q + ((long long)batch_idx * num_q_heads + q_head) * head_dim;
    const __nv_bfloat16* K_base = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    const __nv_bfloat16* V_base = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
    __nv_bfloat16* O_ptr = O + ((long long)batch_idx * num_q_heads + q_head) * head_dim;

    // Load the Q vector into per-thread storage (every thread needs all of it).
    float q_reg[HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }

    // ---- Per-thread online softmax over positions tid, tid + DECODE_THREADS, ... ----
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        local_O[d] = 0.0f;
    }

    for (int pos = tid; pos < kv_len; pos += DECODE_THREADS) {
        // Compute dot(Q, K[pos]) * scale
        const __nv_bfloat16* K_pos = K_base + pos * head_dim;
        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        }
        float s = dot * scale;

        // Online softmax update
        float new_max = fmaxf(local_max, s);
        float correction = expf(local_max - new_max);
        float p = expf(s - new_max);

        // Rescale running sum and O (kept as separate passes to preserve the
        // exact rounding sequence of the accumulation)
        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] = local_O[d] * correction;
        }

        // Accumulate V[pos] weighted by p
        const __nv_bfloat16* V_pos = V_base + pos * head_dim;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }

        local_max = new_max;
    }

    // ---- Block-level online softmax reduction ----
    // Combine (local_max, local_sum, local_O) across all threads: reduce the
    // max, rescale each thread's partials to it, then reduce sum and O.
    // Only kNumWarps partials are ever written or read, so the arrays are sized
    // by warp count; every entry read below is written first, so no zero-fill.
    __shared__ float smem_max[kNumWarps];
    __shared__ float smem_sum[kNumWarps];
    __shared__ float smem_O_warp[kNumWarps][HEAD_DIM_MAX];

    // Step 1: Block-wide max reduction
    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();

    if (tid == 0) {
        float block_max = smem_max[0];
        for (int i = 1; i < kNumWarps; i++)
            block_max = fmaxf(block_max, smem_max[i]);
        smem_max[0] = block_max;
    }
    __syncthreads();
    float global_max = smem_max[0];

    // Step 2: Each thread rescales its local_sum and local_O to global_max.
    // Threads that saw no position (local_max == -inf) contribute exactly 0.
    float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) {
        local_O[d] *= rescale;
    }

    // Step 3: Reduce sum across block (fixed warp-id order)
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();

    if (tid == 0) {
        float block_sum = 0.0f;
        for (int i = 0; i < kNumWarps; i++)
            block_sum += smem_sum[i];
        smem_sum[0] = block_sum;
    }
    __syncthreads();
    float global_sum = smem_sum[0];

    // Step 4: Reduce O across block, dim by dim. Store one partial per warp
    // and sum in warp-id order; atomicAdd made greedy decode nondeterministic
    // when logits were close (same fix pattern as paged_attention.cu / gemv.cu).
    float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;

    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O_warp[warp_id][d] = val;
    }
    __syncthreads();

    // Threads 0..head_dim-1 write the final output
    for (int d = tid; d < head_dim; d += DECODE_THREADS) {
        float out = 0.0f;
        for (int i = 0; i < kNumWarps; i++) out += smem_O_warp[i][d];
        O_ptr[d] = __float2bfloat16(out * inv_sum);
    }
}
|
|
|
|
extern "C" {

// Host launcher for flash_attention_bf16_kernel.
// Grid (ceil(q_len / BR), batch * num_q_heads), THREADS_PER_BLOCK threads, and
// (BR + BC) * head_dim BF16 elements of dynamic shared memory.
// Q/K/V/O are device pointers; stream is a cudaStream_t passed as void*.
// Nothing is launched when the output is empty (batch, num_q_heads or q_len
// <= 0): a zero-sized grid is an invalid launch configuration, not a no-op.
void launch_flash_attention_bf16(
    const void* Q, const void* K, const void* V, void* O,
    int batch, int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal, void* stream
) {
    if (batch <= 0 || num_q_heads <= 0 || q_len <= 0) return;

    int q_tiles = (q_len + BR - 1) / BR;
    dim3 grid(q_tiles, batch * num_q_heads);
    int block = THREADS_PER_BLOCK;

    // Shared memory: smem_q[BR * head_dim] + smem_kv[BC * head_dim], all BF16
    int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);

    flash_attention_bf16_kernel<<<grid, block, smem_bytes, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        num_q_heads, num_kv_heads,
        q_len, kv_len, head_dim,
        scale, causal
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for flash_attention_sinks_bf16_kernel (same launch geometry as
// launch_flash_attention_bf16). sinks may be NULL; window_size <= 0 disables
// the sliding window. Empty outputs return without launching, as above.
void launch_flash_attention_sinks_bf16(
    const void* Q, const void* K, const void* V, void* O,
    const void* sinks,
    int batch, int num_q_heads, int num_kv_heads,
    int q_len, int kv_len, int head_dim,
    float scale, int causal, int window_size, void* stream
) {
    if (batch <= 0 || num_q_heads <= 0 || q_len <= 0) return;

    int q_tiles = (q_len + BR - 1) / BR;
    dim3 grid(q_tiles, batch * num_q_heads);
    int block = THREADS_PER_BLOCK;
    int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);

    flash_attention_sinks_bf16_kernel<<<grid, block, smem_bytes, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        (const __nv_bfloat16*)sinks,
        num_q_heads, num_kv_heads,
        q_len, kv_len, head_dim,
        scale, causal, window_size
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for decode_attention_bf16_kernel: batch * num_q_heads blocks of
// DECODE_THREADS threads, static shared memory only.
// causal is accepted for signature parity with the prefill launchers but is
// not forwarded: the decode kernel attends to all kv_len cached positions.
// Empty outputs (batch or num_q_heads <= 0) return without launching.
void launch_decode_attention_bf16(
    const void* Q, const void* K, const void* V, void* O,
    int batch, int num_q_heads, int num_kv_heads,
    int kv_len, int head_dim,
    float scale, int causal, void* stream
) {
    (void)causal;
    if (batch <= 0 || num_q_heads <= 0) return;

    int grid = batch * num_q_heads;
    int block = DECODE_THREADS;

    decode_attention_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K,
        (const __nv_bfloat16*)V,
        (__nv_bfloat16*)O,
        num_q_heads, num_kv_heads,
        kv_len, head_dim,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}

}
|