kernels: flash attention with gpt-oss sinks + sliding window

Add flash_attention_sinks_bf16 prefill kernel that folds the per-head
attention sink into the softmax denominator (exactly as the decode sink
kernel) and supports an optional sliding-window mask matching HF gpt-oss.

Wire it through xserv-kernels (flash_attention_sinks) and use it in
GptOss prefill, replacing the post-hoc sink approximation for an exact
match against the reference math.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
2026-05-31 00:56:10 +08:00
parent 5cb3cf28f9
commit 9c98c169ff
4 changed files with 260 additions and 4 deletions

View File

@@ -16,6 +16,13 @@ unsafe extern "C" {
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, stream: *mut c_void,
);
fn launch_flash_attention_sinks_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
sinks: *const c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
q_len: i32, kv_len: i32, head_dim: i32,
scale: f32, causal: i32, window_size: i32, stream: *mut c_void,
);
fn launch_decode_attention_bf16(
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
batch: i32, num_q_heads: i32, num_kv_heads: i32,
@@ -295,6 +302,65 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
output
}
/// Flash attention for prefill with gpt-oss attention sinks + optional sliding window.
///
/// Same layout/contract as `flash_attention`, plus a per-head `sinks` tensor
/// ([num_q_heads] BF16, GPU) folded into the softmax denominator, and a
/// `window_size` (0 = full causal, >0 = sliding window). Always causal.
pub fn flash_attention_sinks(
q: &Tensor,
k: &Tensor,
v: &Tensor,
sinks: &Tensor,
window_size: usize,
) -> Tensor {
assert_eq!(q.ndim(), 4);
assert_eq!(k.ndim(), 4);
assert_eq!(v.ndim(), 4);
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
assert_eq!(q.dtype(), DType::BF16);
assert_eq!(k.dtype(), DType::BF16);
assert_eq!(v.dtype(), DType::BF16);
let batch = q.shape()[0];
let num_q_heads = q.shape()[1];
let q_len = q.shape()[2];
let head_dim = q.shape()[3];
let num_kv_heads = k.shape()[1];
let kv_len = k.shape()[2];
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
assert!(num_q_heads % num_kv_heads == 0);
assert!(head_dim <= 128);
assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries");
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());
unsafe {
launch_flash_attention_sinks_bf16(
q.data_ptr() as *const c_void,
k.data_ptr() as *const c_void,
v.data_ptr() as *const c_void,
output.data_ptr() as *mut c_void,
sinks.data_ptr() as *const c_void,
batch as i32,
num_q_heads as i32,
num_kv_heads as i32,
q_len as i32,
kv_len as i32,
head_dim as i32,
scale,
1, // always causal
window_size as i32,
std::ptr::null_mut(),
);
}
output
}
/// Paged decode attention.
///
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU

View File

@@ -13,7 +13,7 @@ pub mod transpose;
pub use activation::{add, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::{attention, decode_attention, flash_attention, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
pub use embedding::embedding;
pub use gemm::{batched_matmul, matmul, GemmBackend};
pub use layernorm::layernorm;

View File

@@ -373,9 +373,8 @@ impl GptOss {
paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
// Flash attention for prefill (sinks handled post-hoc for simplicity)
// TODO: integrate sinks into flash attention for exact match
let attn_out = flash_attention(&q, &k_full, &v_full, true);
// Flash attention with gpt-oss sinks + (per-layer) sliding window.
let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);

View File

@@ -197,6 +197,172 @@ __global__ void flash_attention_bf16_kernel(
}
}
// Flash Attention 2 forward with gpt-oss attention sinks + optional sliding window.
// Identical to flash_attention_bf16_kernel, plus:
// - sinks: [num_q_heads] BF16 — a per-head extra softmax logit (no value),
// folded into the denominator after the K/V tiles (exactly as the decode
// sink kernel does).
// - window_size > 0: sliding-window mask. Query at global position p attends
// to keys k with p - window_size < k <= p (matches HF gpt-oss).
__global__ void flash_attention_sinks_bf16_kernel(
const __nv_bfloat16* __restrict__ Q,
const __nv_bfloat16* __restrict__ K,
const __nv_bfloat16* __restrict__ V,
__nv_bfloat16* __restrict__ O,
const __nv_bfloat16* __restrict__ sinks, // [num_q_heads] or NULL
int num_q_heads, int num_kv_heads,
int q_len, int kv_len, int head_dim,
float scale, int causal, int window_size
) {
int q_tile_idx = blockIdx.x;
int bh = blockIdx.y;
int batch_idx = bh / num_q_heads;
int q_head = bh % num_q_heads;
int heads_per_group = num_q_heads / num_kv_heads;
int kv_head = q_head / heads_per_group;
int q_tile_start = q_tile_idx * BR;
if (q_tile_start >= q_len) return;
int q_tile_rows = min(BR, q_len - q_tile_start);
const __nv_bfloat16* Q_head = Q + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
const __nv_bfloat16* K_head = K + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
const __nv_bfloat16* V_head = V + ((long long)batch_idx * num_kv_heads + kv_head) * kv_len * head_dim;
__nv_bfloat16* O_head = O + ((long long)batch_idx * num_q_heads + q_head) * q_len * head_dim;
int tid = threadIdx.x;
extern __shared__ __nv_bfloat16 smem[];
__nv_bfloat16* smem_q = smem;
__nv_bfloat16* smem_kv = smem + BR * head_dim;
int q_elems = q_tile_rows * head_dim;
for (int i = tid; i < q_elems; i += THREADS_PER_BLOCK) {
int row = i / head_dim;
int col = i % head_dim;
smem_q[row * head_dim + col] = Q_head[(q_tile_start + row) * head_dim + col];
}
for (int i = q_elems + tid; i < BR * head_dim; i += THREADS_PER_BLOCK) {
smem_q[i] = __float2bfloat16(0.0f);
}
__syncthreads();
bool owns_row = (tid < q_tile_rows);
float O_acc[128];
float m_val = -INFINITY;
float l_val = 0.0f;
if (owns_row) {
for (int d = 0; d < head_dim; d++) O_acc[d] = 0.0f;
}
int kv_offset = kv_len - q_len;
int num_kv_tiles = (kv_len + BC - 1) / BC;
for (int j = 0; j < num_kv_tiles; j++) {
int kv_tile_start = j * BC;
int kv_tile_cols = min(BC, kv_len - kv_tile_start);
if (causal) {
int max_allowed_kv = (q_tile_start + q_tile_rows - 1) + kv_offset;
if (kv_tile_start > max_allowed_kv) continue;
}
int kv_elems = kv_tile_cols * head_dim;
for (int i = tid; i < kv_elems; i += THREADS_PER_BLOCK) {
int row = i / head_dim;
int col = i % head_dim;
smem_kv[row * head_dim + col] = K_head[(kv_tile_start + row) * head_dim + col];
}
for (int i = kv_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
smem_kv[i] = __float2bfloat16(0.0f);
}
__syncthreads();
float P[BC];
if (owns_row) {
float row_max = -INFINITY;
int q_pos = q_tile_start + tid + kv_offset; // global query position
for (int c = 0; c < kv_tile_cols; c++) {
float dot = 0.0f;
for (int d = 0; d < head_dim; d++) {
dot += __bfloat162float(smem_q[tid * head_dim + d])
* __bfloat162float(smem_kv[c * head_dim + d]);
}
float s = dot * scale;
int kv_pos = kv_tile_start + c;
if (causal && kv_pos > q_pos) {
s = -INFINITY;
}
// Sliding window: drop keys older than the window.
if (window_size > 0 && kv_pos <= q_pos - window_size) {
s = -INFINITY;
}
P[c] = s;
row_max = fmaxf(row_max, s);
}
float m_new = fmaxf(m_val, row_max);
float psum = 0.0f;
for (int c = 0; c < kv_tile_cols; c++) {
P[c] = expf(P[c] - m_new);
psum += P[c];
}
float correction = expf(m_val - m_new);
l_val = correction * l_val + psum;
for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
m_val = m_new;
}
__syncthreads();
int v_elems = kv_tile_cols * head_dim;
for (int i = tid; i < v_elems; i += THREADS_PER_BLOCK) {
int row = i / head_dim;
int col = i % head_dim;
smem_kv[row * head_dim + col] = V_head[(kv_tile_start + row) * head_dim + col];
}
for (int i = v_elems + tid; i < BC * head_dim; i += THREADS_PER_BLOCK) {
smem_kv[i] = __float2bfloat16(0.0f);
}
__syncthreads();
if (owns_row) {
for (int c = 0; c < kv_tile_cols; c++) {
float p = P[c];
if (p != 0.0f) {
for (int d = 0; d < head_dim; d++) {
O_acc[d] += p * __bfloat162float(smem_kv[c * head_dim + d]);
}
}
}
}
__syncthreads();
}
// Fold in the per-head attention sink (extra logit, no value contribution).
if (owns_row && sinks != nullptr) {
float sink_logit = __bfloat162float(sinks[q_head]);
float m_new = fmaxf(m_val, sink_logit);
float correction = expf(m_val - m_new);
l_val = correction * l_val + expf(sink_logit - m_new);
for (int d = 0; d < head_dim; d++) O_acc[d] *= correction;
m_val = m_new;
}
if (owns_row) {
float inv_l = (l_val > 0.0f) ? (1.0f / l_val) : 0.0f;
int global_row = q_tile_start + tid;
for (int d = 0; d < head_dim; d++) {
O_head[global_row * head_dim + d] = __float2bfloat16(O_acc[d] * inv_l);
}
}
}
// ============================================================
// Decode Attention kernel: optimized for Q_len=1 (single-token decode).
// Parallelizes across KV sequence dimension instead of Q rows.
@@ -395,6 +561,31 @@ void launch_flash_attention_bf16(
CUDA_CHECK_LAST_ERROR();
}
void launch_flash_attention_sinks_bf16(
const void* Q, const void* K, const void* V, void* O,
const void* sinks,
int batch, int num_q_heads, int num_kv_heads,
int q_len, int kv_len, int head_dim,
float scale, int causal, int window_size, void* stream
) {
int q_tiles = (q_len + BR - 1) / BR;
dim3 grid(q_tiles, batch * num_q_heads);
int block = THREADS_PER_BLOCK;
int smem_bytes = (BR + BC) * head_dim * (int)sizeof(__nv_bfloat16);
flash_attention_sinks_bf16_kernel<<<grid, block, smem_bytes, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)Q,
(const __nv_bfloat16*)K,
(const __nv_bfloat16*)V,
(__nv_bfloat16*)O,
(const __nv_bfloat16*)sinks,
num_q_heads, num_kv_heads,
q_len, kv_len, head_dim,
scale, causal, window_size
);
CUDA_CHECK_LAST_ERROR();
}
void launch_decode_attention_bf16(
const void* Q, const void* K, const void* V, void* O,
int batch, int num_q_heads, int num_kv_heads,