kernels: flash attention with gpt-oss sinks + sliding window
Add flash_attention_sinks_bf16 prefill kernel that folds the per-head attention sink into the softmax denominator (exactly as the decode sink kernel) and supports an optional sliding-window mask matching HF gpt-oss. Wire it through xserv-kernels (flash_attention_sinks) and use it in GptOss prefill, replacing the post-hoc sink approximation for an exact match against the reference math. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,13 @@ unsafe extern "C" {
|
||||
q_len: i32, kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, stream: *mut c_void,
|
||||
);
|
||||
fn launch_flash_attention_sinks_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
sinks: *const c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
q_len: i32, kv_len: i32, head_dim: i32,
|
||||
scale: f32, causal: i32, window_size: i32, stream: *mut c_void,
|
||||
);
|
||||
fn launch_decode_attention_bf16(
|
||||
q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void,
|
||||
batch: i32, num_q_heads: i32, num_kv_heads: i32,
|
||||
@@ -295,6 +302,65 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
||||
output
|
||||
}
|
||||
|
||||
/// Flash attention for prefill with gpt-oss attention sinks + optional sliding window.
|
||||
///
|
||||
/// Same layout/contract as `flash_attention`, plus a per-head `sinks` tensor
|
||||
/// ([num_q_heads] BF16, GPU) folded into the softmax denominator, and a
|
||||
/// `window_size` (0 = full causal, >0 = sliding window). Always causal.
|
||||
pub fn flash_attention_sinks(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
sinks: &Tensor,
|
||||
window_size: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(k.ndim(), 4);
|
||||
assert_eq!(v.ndim(), 4);
|
||||
assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous());
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert_eq!(k.dtype(), DType::BF16);
|
||||
assert_eq!(v.dtype(), DType::BF16);
|
||||
|
||||
let batch = q.shape()[0];
|
||||
let num_q_heads = q.shape()[1];
|
||||
let q_len = q.shape()[2];
|
||||
let head_dim = q.shape()[3];
|
||||
let num_kv_heads = k.shape()[1];
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries");
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_flash_attention_sinks_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k.data_ptr() as *const c_void,
|
||||
v.data_ptr() as *const c_void,
|
||||
output.data_ptr() as *mut c_void,
|
||||
sinks.data_ptr() as *const c_void,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
q_len as i32,
|
||||
kv_len as i32,
|
||||
head_dim as i32,
|
||||
scale,
|
||||
1, // always causal
|
||||
window_size as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Paged decode attention.
|
||||
///
|
||||
/// q: [batch, num_q_heads, 1, head_dim] BF16, contiguous, GPU
|
||||
|
||||
@@ -13,7 +13,7 @@ pub mod transpose;
|
||||
pub use activation::{add, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::{attention, decode_attention, flash_attention, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
||||
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
||||
pub use embedding::embedding;
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
|
||||
@@ -373,9 +373,8 @@ impl GptOss {
|
||||
paged_cache.append_tokens(slot, layer_idx, &k, &v, new_tokens, pos_offset);
|
||||
let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx);
|
||||
|
||||
// Flash attention for prefill (sinks handled post-hoc for simplicity)
|
||||
// TODO: integrate sinks into flash attention for exact match
|
||||
let attn_out = flash_attention(&q, &k_full, &v_full, true);
|
||||
// Flash attention with gpt-oss sinks + (per-layer) sliding window.
|
||||
let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size);
|
||||
|
||||
let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
|
||||
Reference in New Issue
Block a user