attention: tree-aware paged_decode_attention_tree kernel + wrapper
New CUDA kernel paged_decode_attention_tree_bf16_kernel: same as base paged_decode_attention but with a per-query mask over the newly-written K/V region. `tree_mask[i][j] != 0` iff query i attends to newly-written K/V at slot j. Positions before `tree_start` are always attended. Motivation: speculative decoding with tree drafting needs siblings at the same target position to attend to their own branch's history, not each other's K/V. Rust binding: paged_decode_attention_tree(...) mirrors paged_decode_attention plus tree_mask_ptr, tree_start, tree_len. Forward path: Qwen3::forward_verify_paged_decode_attention_tree_with_hidden takes explicit positions, kv_lens, and a flattened [N*N] tree_mask. Sanity check: bench-eagle3's γ_multi path now routes through the tree kernel with a causal mask (mask[i][j]=1 iff j<=i), producing bit-equivalent output to the non-tree variant. matched=false pattern + acceptance rate + speedup all identical to previous run within noise (11.3% acceptance, 1.00× speedup with the mask-check overhead). --tree CLI flag is parsed but reserved. Real tree drafting (siblings sharing a target position) is blocked by KV cache position rigidity: paged_cache stores K/V at cache-position ≡ target-position, so an accepted sibling at target position P+1 has its K/V physically at cache position P+2 (its unique slot in the batched write). Continuing decode at P+1 would see the WRONG K/V (top-1 sibling's, not accepted top-2 sibling's). Fix requires either KV-slot remap on acceptance or a virtual position layer. Infrastructure is in place; the next step is tackling that remap.
This commit is contained in:
@@ -83,6 +83,24 @@ unsafe extern "C" {
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
/// FFI binding for the C launcher of the tree-aware paged decode attention
/// kernel (`paged_decode_attention_tree_bf16_kernel`).
///
/// The parameter order mirrors the C definition: query, K cache, V cache,
/// output, block tables, context lengths, tree mask, then the integer
/// geometry (`batch`, `num_q_heads`, `num_kv_heads`, `head_dim`,
/// `max_blocks_per_seq`), the masked window (`tree_start`, `tree_len`), the
/// softmax `scale`, and the stream the kernel is launched on.
///
/// `tree_mask` is read by the kernel as a row-major `[batch, tree_len]` array
/// of 32-bit integers; a zero entry hides cache position `tree_start + j`
/// from query row `i`.
///
/// All pointers are dereferenced by device code, so they must refer to memory
/// the kernel can read (or, for `o`, write).
fn launch_paged_decode_attention_tree_bf16(
|
||||
q: *const c_void,
|
||||
k_cache: *const c_void,
|
||||
v_cache: *const c_void,
|
||||
o: *mut c_void,
|
||||
block_tables: *const i32,
|
||||
context_lens: *const i32,
|
||||
tree_mask: *const i32,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
head_dim: i32,
|
||||
max_blocks_per_seq: i32,
|
||||
tree_start: i32,
|
||||
tree_len: i32,
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_paged_decode_attention_sinks_bf16(
|
||||
q: *const c_void,
|
||||
k_cache: *const c_void,
|
||||
@@ -515,6 +533,62 @@ pub fn paged_decode_attention(
|
||||
output
|
||||
}
|
||||
|
||||
/// Tree-aware paged decode attention. Adds a per-query attention mask over
|
||||
/// the newly-written K/V region `[tree_start, tree_start+tree_len)`. Query i
|
||||
/// attends to position tree_start+j iff tree_mask[i, j] != 0. Positions <
|
||||
/// tree_start are always attended.
|
||||
///
|
||||
/// Used by speculative decoding with tree drafting to let sibling candidates
|
||||
/// share position slots without seeing each other's K/V.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_decode_attention_tree(
|
||||
q: &Tensor,
|
||||
k_cache_ptr: *const c_void,
|
||||
v_cache_ptr: *const c_void,
|
||||
block_tables_ptr: *const i32,
|
||||
context_lens_ptr: *const i32,
|
||||
tree_mask_ptr: *const i32,
|
||||
batch: usize,
|
||||
num_q_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
tree_start: usize,
|
||||
tree_len: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1);
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_tree_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
output.data_ptr() as *mut c_void,
|
||||
block_tables_ptr,
|
||||
context_lens_ptr,
|
||||
tree_mask_ptr,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
tree_start as i32,
|
||||
tree_len as i32,
|
||||
scale,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Paged decode attention with attention sinks and optional sliding window.
|
||||
///
|
||||
/// sinks_ptr: pointer to [num_q_heads] BF16 on GPU (or null for no sinks)
|
||||
|
||||
@@ -16,7 +16,8 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use attention::{
|
||||
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
|
||||
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||
paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16,
|
||||
reshape_and_cache_bf16,
|
||||
};
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
|
||||
|
||||
@@ -94,6 +94,7 @@ fn main() {
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let gamma = arg_usize(&args, "--gamma", 2).max(1);
|
||||
let use_tree = args.iter().any(|a| a == "--tree");
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
@@ -150,7 +151,11 @@ fn main() {
|
||||
baseline_tokens += baseline.ids.len();
|
||||
drop(baseline_cache);
|
||||
|
||||
// Speculative with EAGLE, γ from CLI.
|
||||
// Speculative with EAGLE, γ from CLI. Verify uses the tree kernel with
|
||||
// a causal mask (equivalent to non-tree behavior); a real tree
|
||||
// (siblings sharing target positions) would require KV cache slot
|
||||
// remap after acceptance, which is out of scope for this iteration.
|
||||
let _ = use_tree; // reserved for future tree drafting
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let spec = if gamma == 1 {
|
||||
run_eagle_gamma1(
|
||||
@@ -441,9 +446,23 @@ fn run_eagle_gamma_multi(
|
||||
for &d in drafts.iter() {
|
||||
verify_input.push(d);
|
||||
}
|
||||
let n = verify_input.len();
|
||||
let pos_offset = cache.seq_len(slot);
|
||||
let positions_v: Vec<u32> = (0..n).map(|i| (pos_offset + i) as u32).collect();
|
||||
let kv_lens_v: Vec<i32> = (0..n).map(|i| (pos_offset + i + 1) as i32).collect();
|
||||
// Causal mask over new-writes: mask[i][j] = 1 iff j <= i.
|
||||
let mut tree_mask: Vec<i32> = vec![0; n * n];
|
||||
for i in 0..n {
|
||||
for j in 0..=i {
|
||||
tree_mask[i * n + j] = 1;
|
||||
}
|
||||
}
|
||||
let (verify_logits, verify_hooks) = target
|
||||
.forward_verify_paged_decode_attention_with_hidden(
|
||||
.forward_verify_paged_decode_attention_tree_with_hidden(
|
||||
&verify_input,
|
||||
&positions_v,
|
||||
&kv_lens_v,
|
||||
&tree_mask,
|
||||
slot,
|
||||
cache,
|
||||
&EAGLE_HOOK_LAYERS,
|
||||
|
||||
@@ -1230,6 +1230,138 @@ impl Qwen3 {
|
||||
(logits, hidden_arr)
|
||||
}
|
||||
|
||||
/// Tree-aware verify: like `_with_hidden` but supports sibling candidates
|
||||
/// sharing the same target position. Caller supplies per-token positions
|
||||
/// (for RoPE), kv_lens (attention context length), and a flattened
|
||||
/// `tree_mask` (`[new_tokens, new_tokens]` i32; `mask[i, j]!=0` iff query i
|
||||
/// attends to newly-written K/V at slot j). Positions in the paged cache
|
||||
/// before pos_offset are always attended (regular history).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward_verify_paged_decode_attention_tree_with_hidden(
|
||||
&self,
|
||||
token_ids: &[u32],
|
||||
positions: &[u32],
|
||||
kv_lens: &[i32],
|
||||
tree_mask: &[i32],
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
hook_layers: &[usize; 3],
|
||||
) -> (Tensor, [Tensor; 3]) {
|
||||
let new_tokens = token_ids.len();
|
||||
assert_eq!(positions.len(), new_tokens);
|
||||
assert_eq!(kv_lens.len(), new_tokens);
|
||||
assert_eq!(tree_mask.len(), new_tokens * new_tokens);
|
||||
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let slots = vec![slot; new_tokens];
|
||||
paged_cache.sync_active_batch_with_lens(&slots, kv_lens);
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
// Upload tree_mask [new_tokens, new_tokens] i32 to GPU.
|
||||
let mask_bytes: &[u8] = unsafe {
|
||||
std::slice::from_raw_parts(tree_mask.as_ptr() as *const u8, tree_mask.len() * 4)
|
||||
};
|
||||
let mut mask_buf =
|
||||
xserv_cuda::allocator::cached_alloc(mask_bytes.len()).expect("alloc tree_mask");
|
||||
mask_buf.copy_from_host(mask_bytes).unwrap();
|
||||
let mask_ptr = mask_buf.as_ptr() as *const i32;
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let mut hooks: [Option<Tensor>; 3] = [None, None, None];
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
let q_flat = q_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, positions);
|
||||
|
||||
let v_3d = v_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
|
||||
|
||||
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let attn_out = xserv_kernels::paged_decode_attention_tree(
|
||||
&q_decode,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
mask_ptr,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
pos_offset,
|
||||
new_tokens,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
|
||||
for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
|
||||
if layer_idx == h_layer {
|
||||
hooks[h_idx] = Some(x.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_t);
|
||||
let hidden_arr = [
|
||||
hooks[0].take().expect("hook layer 0 not reached"),
|
||||
hooks[1].take().expect("hook layer 1 not reached"),
|
||||
hooks[2].take().expect("hook layer 2 not reached"),
|
||||
];
|
||||
(logits, hidden_arr)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
|
||||
@@ -189,6 +189,169 @@ __global__ void paged_decode_attention_bf16_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Tree-aware paged decode attention: per-query mask lets sibling candidates
// in the same batch attend to different subsets of newly-written K/V.
//
// Launch layout: grid = (num_q_heads, batch), block = PAGED_THREADS threads.
// blockIdx.x selects the query head, blockIdx.y the query row. Thread t
// visits cache positions t, t + PAGED_THREADS, ... and keeps a private
// running softmax (max, sum, weighted V); the block then merges those.
//
// Parameters:
//   Q, O          indexed as [batch][num_q_heads][head_dim].
//   K_cache,      indexed as [phys_block][kv_head][slot_in_block][head_dim],
//   V_cache       with PAGED_BLOCK_SIZE slots per block.
//   block_tables  [batch][max_blocks_per_seq] physical block ids.
//   context_lens  [batch] number of cache positions visible to each query.
//   tree_mask     [batch][tree_len] int32.
//   tree_start    position where newly-written K/V begins (typically
//                 pos_offset).
//   tree_len      number of newly-written K/V rows (= batch, one per query).
//   scale         factor applied to each Q.K dot product before softmax.
//
// `tree_mask[i][j] != 0` iff query i attends to K/V at position
// `tree_start + j`. Positions < tree_start are always attended (regular
// history). A query with no visible position produces an all-zero output row.
//
// Requirements: head_dim <= PAGED_HEAD_DIM_MAX (sizes of the per-thread and
// shared arrays). NOTE(review): the warp shuffles use a full mask and the
// shared arrays hold 32 warps, so PAGED_THREADS is assumed to be a multiple
// of 32 and at most 1024 - confirm against its definition.
__global__ void paged_decode_attention_tree_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,
    const int* __restrict__ context_lens,
    const int* __restrict__ tree_mask,     // [batch, tree_len] int32
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    int tree_start, int tree_len,
    float scale
) {
    int seq_idx = blockIdx.y;
    int q_head = blockIdx.x;
    int tid = threadIdx.x;

    const __nv_bfloat16* Q_ptr = Q +
        ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    __nv_bfloat16* O_ptr = O +
        ((long long)seq_idx * num_q_heads + q_head) * head_dim;

    int kv_len = context_lens[seq_idx];
    if (kv_len <= 0) {
        // Strided like the final store below, so the zero fill does not
        // depend on PAGED_THREADS being at least head_dim. The branch is
        // uniform across the block, so returning before the barriers is safe.
        for (int d = tid; d < head_dim; d += PAGED_THREADS) {
            O_ptr[d] = __float2bfloat16(0.0f);
        }
        return;
    }

    // A block-table row only has max_blocks_per_seq entries; never index
    // past it even if context_lens holds a larger value.
    long long table_positions = (long long)max_blocks_per_seq * PAGED_BLOCK_SIZE;
    if ((long long)kv_len > table_positions) kv_len = (int)table_positions;

    // Grouped-query attention: consecutive query heads share one KV head.
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;
    const int* mask_row = tree_mask + (long long)seq_idx * tree_len;
    int tree_end = tree_start + tree_len;

    // Private copy of the query vector, widened to float.
    float q_reg[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }

    // Per-thread running softmax state.
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;

    int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
    int kv_stride_head = PAGED_BLOCK_SIZE * head_dim;

    for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
        // Tree mask: skip positions in [tree_start, tree_start+tree_len) that
        // the mask marks as 0. Everything else (history) is always attended.
        if (pos >= tree_start && pos < tree_end) {
            if (mask_row[pos - tree_start] == 0) continue;
        }

        // Translate the logical position into its physical cache location.
        int logical_blk = pos / PAGED_BLOCK_SIZE;
        int slot_in_blk = pos % PAGED_BLOCK_SIZE;
        int phys_blk = bt[logical_blk];

        // K and V use the same layout, so one offset serves both.
        long long kv_offset = (long long)phys_blk * kv_stride_block
            + kv_head * kv_stride_head
            + slot_in_blk * head_dim;
        const __nv_bfloat16* K_pos = K_cache + kv_offset;
        const __nv_bfloat16* V_pos = V_cache + kv_offset;

        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        }
        float s = dot * scale;

        // Online softmax update: rescale what was accumulated under the old
        // maximum, then add this position's contribution.
        float new_max = fmaxf(local_max, s);
        float correction = expf(local_max - new_max);
        float p = expf(s - new_max);

        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) local_O[d] *= correction;

        for (int d = 0; d < head_dim; d++) {
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }

        local_max = new_max;
    }

    // Block-level reduction (identical to base kernel).
    __shared__ float smem_max[32];
    __shared__ float smem_sum[32];
    __shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];

    int lane = tid & 31;
    int warp_id = tid >> 5;
    int num_warps = PAGED_THREADS >> 5;

    // 1) Maximum score: within each warp, then across warps by thread 0.
    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();

    float global_max;
    if (tid == 0) {
        global_max = smem_max[0];
        for (int i = 1; i < num_warps; i++)
            global_max = fmaxf(global_max, smem_max[i]);
        smem_max[0] = global_max;
    }
    __syncthreads();
    global_max = smem_max[0];

    // 2) Bring each thread's partial results onto the common maximum. A
    //    thread that saw no position contributes exactly zero (this also
    //    avoids evaluating -inf - -inf when every position was masked).
    float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;

    // 3) Softmax denominator: warp sum, then across warps by thread 0.
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();

    float global_sum;
    if (tid == 0) {
        global_sum = 0.0f;
        for (int i = 0; i < num_warps; i++) global_sum += smem_sum[i];
        smem_sum[0] = global_sum;
    }
    __syncthreads();
    global_sum = smem_sum[0];

    // 4) Weighted V: clear the staging area, then each warp's lane 0 stores
    //    the warp-reduced value for every output dimension.
    for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
        reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
    }
    __syncthreads();

    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O_warp[warp_id][d] = val;
    }
    __syncthreads();

    // 5) Normalise and store. A zero denominator (nothing attended) yields
    //    zeros instead of NaN.
    float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
    for (int d = tid; d < head_dim; d += PAGED_THREADS) {
        float out = 0.0f;
        for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
        O_ptr[d] = __float2bfloat16(out * inv_sum);
    }
}
|
||||
|
||||
// Extended paged decode attention with attention sinks and sliding window.
|
||||
// sinks: [num_q_heads] BF16 — per-head extra logit appended before softmax.
|
||||
// window_size: >0 = sliding window (only attend to last `window_size` positions), 0 = full.
|
||||
@@ -389,6 +552,36 @@ void launch_paged_decode_attention_bf16(
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
// Host-side launcher for paged_decode_attention_tree_bf16_kernel.
//
// Launches one thread block of PAGED_THREADS threads per (query head, query
// row) pair on `stream`, with no dynamic shared memory. Q/K_cache/V_cache/O
// are reinterpreted as __nv_bfloat16; every pointer is dereferenced by the
// kernel. `tree_mask` is the [batch, tree_len] int32 mask applied to cache
// positions [tree_start, tree_start + tree_len).
//
// A non-positive `batch` or `num_q_heads` means there is no output to
// produce; the function returns without launching, because a zero-sized grid
// is rejected as an invalid configuration.
//
// Note: `batch` becomes the grid's y extent, which CUDA limits to 65535; a
// larger batch is reported through CUDA_CHECK_LAST_ERROR.
void launch_paged_decode_attention_tree_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    const int* tree_mask,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    int tree_start, int tree_len,
    float scale, void* stream
) {
    if (batch <= 0 || num_q_heads <= 0) {
        return;
    }

    dim3 grid(num_q_heads, batch);
    int block = PAGED_THREADS;

    paged_decode_attention_tree_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens, tree_mask,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        tree_start, tree_len,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}
|
||||
|
||||
void launch_paged_decode_attention_sinks_bf16(
|
||||
const void* Q,
|
||||
const void* K_cache,
|
||||
|
||||
Reference in New Issue
Block a user