attention: tree-aware paged_decode_attention_tree kernel + wrapper
New CUDA kernel paged_decode_attention_tree_bf16_kernel: same as base paged_decode_attention but with a per-query mask over the newly-written K/V region. `tree_mask[i][j] != 0` iff query i attends to newly-written K/V at slot j. Positions before `tree_start` are always attended. Motivation: speculative decoding with tree drafting needs siblings at the same target position to attend to their own branch's history, not each other's K/V. Rust binding: paged_decode_attention_tree(...) mirrors paged_decode_attention plus tree_mask_ptr, tree_start, tree_len. Forward path: Qwen3::forward_verify_paged_decode_attention_tree_with_hidden takes explicit positions, kv_lens, and a flattened [N*N] tree_mask. Sanity check: bench-eagle3's γ_multi path now routes through the tree kernel with a causal mask (mask[i][j]=1 iff j<=i), producing bit- equivalent output to the non-tree variant. matched=false pattern + acceptance rate + speedup all identical to previous run within noise (11.3% acceptance, 1.00× speedup with the mask-check overhead). --tree CLI flag is parsed but reserved. Real tree drafting (siblings sharing a target position) is blocked by KV cache position rigidity: paged_cache stores K/V at cache-position ≡ target-position, so an accepted sibling at target position P+1 has its K/V physically at cache position P+2 (its unique slot in the batched write). Continuing decode at P+1 would see the WRONG K/V (top-1 sibling's, not accepted top-2 sibling's). Fix requires either KV-slot remap on acceptance or a virtual position layer. Infrastructure is in place, next step is tackling that remap.
This commit is contained in:
@@ -83,6 +83,24 @@ unsafe extern "C" {
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_paged_decode_attention_tree_bf16(
|
||||
q: *const c_void,
|
||||
k_cache: *const c_void,
|
||||
v_cache: *const c_void,
|
||||
o: *mut c_void,
|
||||
block_tables: *const i32,
|
||||
context_lens: *const i32,
|
||||
tree_mask: *const i32,
|
||||
batch: i32,
|
||||
num_q_heads: i32,
|
||||
num_kv_heads: i32,
|
||||
head_dim: i32,
|
||||
max_blocks_per_seq: i32,
|
||||
tree_start: i32,
|
||||
tree_len: i32,
|
||||
scale: f32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_paged_decode_attention_sinks_bf16(
|
||||
q: *const c_void,
|
||||
k_cache: *const c_void,
|
||||
@@ -515,6 +533,62 @@ pub fn paged_decode_attention(
|
||||
output
|
||||
}
|
||||
|
||||
/// Tree-aware paged decode attention. Adds a per-query attention mask over
|
||||
/// the newly-written K/V region `[tree_start, tree_start+tree_len)`. Query i
|
||||
/// attends to position tree_start+j iff tree_mask[i, j] != 0. Positions <
|
||||
/// tree_start are always attended.
|
||||
///
|
||||
/// Used by speculative decoding with tree drafting to let sibling candidates
|
||||
/// share position slots without seeing each other's K/V.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn paged_decode_attention_tree(
|
||||
q: &Tensor,
|
||||
k_cache_ptr: *const c_void,
|
||||
v_cache_ptr: *const c_void,
|
||||
block_tables_ptr: *const i32,
|
||||
context_lens_ptr: *const i32,
|
||||
tree_mask_ptr: *const i32,
|
||||
batch: usize,
|
||||
num_q_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
max_blocks_per_seq: usize,
|
||||
tree_start: usize,
|
||||
tree_len: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(q.ndim(), 4);
|
||||
assert_eq!(q.shape()[2], 1);
|
||||
assert_eq!(q.dtype(), DType::BF16);
|
||||
assert!(num_q_heads % num_kv_heads == 0);
|
||||
assert!(head_dim <= 128);
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
|
||||
|
||||
unsafe {
|
||||
launch_paged_decode_attention_tree_bf16(
|
||||
q.data_ptr() as *const c_void,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
output.data_ptr() as *mut c_void,
|
||||
block_tables_ptr,
|
||||
context_lens_ptr,
|
||||
tree_mask_ptr,
|
||||
batch as i32,
|
||||
num_q_heads as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
tree_start as i32,
|
||||
tree_len as i32,
|
||||
scale,
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Paged decode attention with attention sinks and optional sliding window.
|
||||
///
|
||||
/// sinks_ptr: pointer to [num_q_heads] BF16 on GPU (or null for no sinks)
|
||||
|
||||
@@ -16,7 +16,8 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use attention::{
|
||||
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
|
||||
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||
paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16,
|
||||
reshape_and_cache_bf16,
|
||||
};
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
|
||||
|
||||
@@ -94,6 +94,7 @@ fn main() {
|
||||
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
|
||||
let device = arg_usize(&args, "--device", 0) as u32;
|
||||
let gamma = arg_usize(&args, "--gamma", 2).max(1);
|
||||
let use_tree = args.iter().any(|a| a == "--tree");
|
||||
|
||||
xserv_cuda::device::set_device(device).unwrap();
|
||||
let info = xserv_cuda::device::device_info(device).unwrap();
|
||||
@@ -150,7 +151,11 @@ fn main() {
|
||||
baseline_tokens += baseline.ids.len();
|
||||
drop(baseline_cache);
|
||||
|
||||
// Speculative with EAGLE, γ from CLI.
|
||||
// Speculative with EAGLE, γ from CLI. Verify uses the tree kernel with
|
||||
// a causal mask (equivalent to non-tree behavior); a real tree
|
||||
// (siblings sharing target positions) would require KV cache slot
|
||||
// remap after acceptance, which is out of scope for this iteration.
|
||||
let _ = use_tree; // reserved for future tree drafting
|
||||
let mut target_cache = new_cache(&target_config, max_seq_len, device);
|
||||
let spec = if gamma == 1 {
|
||||
run_eagle_gamma1(
|
||||
@@ -441,9 +446,23 @@ fn run_eagle_gamma_multi(
|
||||
for &d in drafts.iter() {
|
||||
verify_input.push(d);
|
||||
}
|
||||
let n = verify_input.len();
|
||||
let pos_offset = cache.seq_len(slot);
|
||||
let positions_v: Vec<u32> = (0..n).map(|i| (pos_offset + i) as u32).collect();
|
||||
let kv_lens_v: Vec<i32> = (0..n).map(|i| (pos_offset + i + 1) as i32).collect();
|
||||
// Causal mask over new-writes: mask[i][j] = 1 iff j <= i.
|
||||
let mut tree_mask: Vec<i32> = vec![0; n * n];
|
||||
for i in 0..n {
|
||||
for j in 0..=i {
|
||||
tree_mask[i * n + j] = 1;
|
||||
}
|
||||
}
|
||||
let (verify_logits, verify_hooks) = target
|
||||
.forward_verify_paged_decode_attention_with_hidden(
|
||||
.forward_verify_paged_decode_attention_tree_with_hidden(
|
||||
&verify_input,
|
||||
&positions_v,
|
||||
&kv_lens_v,
|
||||
&tree_mask,
|
||||
slot,
|
||||
cache,
|
||||
&EAGLE_HOOK_LAYERS,
|
||||
|
||||
@@ -1230,6 +1230,138 @@ impl Qwen3 {
|
||||
(logits, hidden_arr)
|
||||
}
|
||||
|
||||
/// Tree-aware verify: like `_with_hidden` but supports sibling candidates
|
||||
/// sharing the same target position. Caller supplies per-token positions
|
||||
/// (for RoPE), kv_lens (attention context length), and a flattened
|
||||
/// `tree_mask` (`[new_tokens, new_tokens]` i32; `mask[i, j]!=0` iff query i
|
||||
/// attends to newly-written K/V at slot j). Positions in the paged cache
|
||||
/// before pos_offset are always attended (regular history).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward_verify_paged_decode_attention_tree_with_hidden(
|
||||
&self,
|
||||
token_ids: &[u32],
|
||||
positions: &[u32],
|
||||
kv_lens: &[i32],
|
||||
tree_mask: &[i32],
|
||||
slot: usize,
|
||||
paged_cache: &mut PagedKVCache,
|
||||
hook_layers: &[usize; 3],
|
||||
) -> (Tensor, [Tensor; 3]) {
|
||||
let new_tokens = token_ids.len();
|
||||
assert_eq!(positions.len(), new_tokens);
|
||||
assert_eq!(kv_lens.len(), new_tokens);
|
||||
assert_eq!(tree_mask.len(), new_tokens * new_tokens);
|
||||
|
||||
let pos_offset = paged_cache.seq_len(slot);
|
||||
let num_heads = self.local_num_heads;
|
||||
let num_kv_heads = self.local_num_kv_heads;
|
||||
let head_dim = self.config.head_dim();
|
||||
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
|
||||
|
||||
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
|
||||
paged_cache.advance_seq_len(slot, new_tokens);
|
||||
|
||||
let slots = vec![slot; new_tokens];
|
||||
paged_cache.sync_active_batch_with_lens(&slots, kv_lens);
|
||||
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
|
||||
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
|
||||
let max_blocks = paged_cache.max_blocks_per_seq();
|
||||
|
||||
// Upload tree_mask [new_tokens, new_tokens] i32 to GPU.
|
||||
let mask_bytes: &[u8] = unsafe {
|
||||
std::slice::from_raw_parts(tree_mask.as_ptr() as *const u8, tree_mask.len() * 4)
|
||||
};
|
||||
let mut mask_buf =
|
||||
xserv_cuda::allocator::cached_alloc(mask_bytes.len()).expect("alloc tree_mask");
|
||||
mask_buf.copy_from_host(mask_bytes).unwrap();
|
||||
let mask_ptr = mask_buf.as_ptr() as *const i32;
|
||||
|
||||
let mut x = embedding(&self.embed_tokens, token_ids);
|
||||
let mut hooks: [Option<Tensor>; 3] = [None, None, None];
|
||||
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
let residual = x.clone();
|
||||
let normed = rmsnorm(&x, &layer.input_norm, eps);
|
||||
|
||||
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
|
||||
let q_dim = num_heads * head_dim;
|
||||
let kv_dim = num_kv_heads * head_dim;
|
||||
let q_all = qkv.narrow(1, 0, q_dim);
|
||||
let k_all = qkv.narrow(1, q_dim, kv_dim);
|
||||
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
|
||||
|
||||
let q_flat = q_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_heads, head_dim]);
|
||||
let k_flat = k_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens * num_kv_heads, head_dim]);
|
||||
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
|
||||
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
|
||||
|
||||
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
|
||||
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
rope_inplace(&q_3d, &self.rope_cache, positions);
|
||||
rope_inplace(&k_3d, &self.rope_cache, positions);
|
||||
|
||||
let v_3d = v_all
|
||||
.contiguous()
|
||||
.reshape(&[new_tokens, num_kv_heads, head_dim]);
|
||||
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
|
||||
|
||||
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
|
||||
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
|
||||
let attn_out = xserv_kernels::paged_decode_attention_tree(
|
||||
&q_decode,
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
bt_ptr,
|
||||
cl_ptr,
|
||||
mask_ptr,
|
||||
new_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_blocks,
|
||||
pos_offset,
|
||||
new_tokens,
|
||||
);
|
||||
|
||||
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
|
||||
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
|
||||
self.all_reduce(&attn_proj);
|
||||
|
||||
let (normed, x_new) =
|
||||
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
|
||||
let residual = x_new.clone();
|
||||
|
||||
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
|
||||
let ffn_dim = gate_up.shape()[1] / 2;
|
||||
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
|
||||
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
|
||||
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
|
||||
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
|
||||
self.all_reduce(&down);
|
||||
x = add_any(&residual, &down);
|
||||
|
||||
for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
|
||||
if layer_idx == h_layer {
|
||||
hooks[h_idx] = Some(x.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let x = rmsnorm(&x, &self.norm, eps);
|
||||
let logits = matmul_2d(&x, &self.lm_head_t);
|
||||
let hidden_arr = [
|
||||
hooks[0].take().expect("hook layer 0 not reached"),
|
||||
hooks[1].take().expect("hook layer 1 not reached"),
|
||||
hooks[2].take().expect("hook layer 2 not reached"),
|
||||
];
|
||||
(logits, hidden_arr)
|
||||
}
|
||||
|
||||
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
|
||||
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
|
||||
let new_tokens = token_ids.len();
|
||||
|
||||
Reference in New Issue
Block a user