attention: tree-aware paged_decode_attention_tree kernel + wrapper

New CUDA kernel paged_decode_attention_tree_bf16_kernel: same as base
paged_decode_attention but with a per-query mask over the newly-written
K/V region. `tree_mask[i][j] != 0` iff query i attends to newly-written
K/V at slot j. Positions before `tree_start` are always attended.

Motivation: speculative decoding with tree drafting needs siblings at
the same target position to attend to their own branch's history, not
each other's K/V.

Rust binding: paged_decode_attention_tree(...) mirrors
paged_decode_attention plus tree_mask_ptr, tree_start, tree_len.

Forward path: Qwen3::forward_verify_paged_decode_attention_tree_with_hidden
takes explicit positions, kv_lens, and a flattened [N*N] tree_mask.

Sanity check: bench-eagle3's γ_multi path now routes through the tree
kernel with a causal mask (mask[i][j]=1 iff j<=i), producing bit-
equivalent output to the non-tree variant. matched=false pattern +
acceptance rate + speedup all identical to previous run within noise
(11.3% acceptance, 1.00× speedup with the mask-check overhead).

--tree CLI flag is parsed but reserved. Real tree drafting (siblings
sharing a target position) is blocked by KV cache position rigidity:
paged_cache stores K/V at cache-position ≡ target-position, so an
accepted sibling at target position P+1 has its K/V physically at
cache position P+2 (its unique slot in the batched write). Continuing
decode at P+1 would see the WRONG K/V (top-1 sibling's, not accepted
top-2 sibling's). Fix requires either KV-slot remap on acceptance or
a virtual position layer.

Infrastructure is in place, next step is tackling that remap.
This commit is contained in:
2026-07-01 20:45:55 +08:00
parent 10a98539d0
commit fd392f7fbb
5 changed files with 422 additions and 3 deletions

View File

@@ -83,6 +83,24 @@ unsafe extern "C" {
scale: f32,
stream: *mut c_void,
);
fn launch_paged_decode_attention_tree_bf16(
q: *const c_void,
k_cache: *const c_void,
v_cache: *const c_void,
o: *mut c_void,
block_tables: *const i32,
context_lens: *const i32,
tree_mask: *const i32,
batch: i32,
num_q_heads: i32,
num_kv_heads: i32,
head_dim: i32,
max_blocks_per_seq: i32,
tree_start: i32,
tree_len: i32,
scale: f32,
stream: *mut c_void,
);
fn launch_paged_decode_attention_sinks_bf16(
q: *const c_void,
k_cache: *const c_void,
@@ -515,6 +533,62 @@ pub fn paged_decode_attention(
output
}
/// Tree-aware paged decode attention. Adds a per-query attention mask over
/// the newly-written K/V region `[tree_start, tree_start+tree_len)`. Query i
/// attends to position tree_start+j iff tree_mask[i, j] != 0. Positions <
/// tree_start are always attended.
///
/// Used by speculative decoding with tree drafting to let sibling candidates
/// share position slots without seeing each other's K/V.
#[allow(clippy::too_many_arguments)]
pub fn paged_decode_attention_tree(
q: &Tensor,
k_cache_ptr: *const c_void,
v_cache_ptr: *const c_void,
block_tables_ptr: *const i32,
context_lens_ptr: *const i32,
tree_mask_ptr: *const i32,
batch: usize,
num_q_heads: usize,
num_kv_heads: usize,
head_dim: usize,
max_blocks_per_seq: usize,
tree_start: usize,
tree_len: usize,
) -> Tensor {
assert_eq!(q.ndim(), 4);
assert_eq!(q.shape()[2], 1);
assert_eq!(q.dtype(), DType::BF16);
assert!(num_q_heads % num_kv_heads == 0);
assert!(head_dim <= 128);
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device());
unsafe {
launch_paged_decode_attention_tree_bf16(
q.data_ptr() as *const c_void,
k_cache_ptr,
v_cache_ptr,
output.data_ptr() as *mut c_void,
block_tables_ptr,
context_lens_ptr,
tree_mask_ptr,
batch as i32,
num_q_heads as i32,
num_kv_heads as i32,
head_dim as i32,
max_blocks_per_seq as i32,
tree_start as i32,
tree_len as i32,
scale,
xserv_cuda::current_stream_raw(),
);
}
output
}
/// Paged decode attention with attention sinks and optional sliding window.
///
/// sinks_ptr: pointer to [num_q_heads] BF16 on GPU (or null for no sinks)

View File

@@ -16,7 +16,8 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use attention::{
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16,
reshape_and_cache_bf16,
};
pub use embedding::{embedding, embedding_device_ids};
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};

View File

@@ -94,6 +94,7 @@ fn main() {
let max_seq_len = arg_usize(&args, "--max-seq-len", DEFAULT_MAX_SEQ_LEN);
let device = arg_usize(&args, "--device", 0) as u32;
let gamma = arg_usize(&args, "--gamma", 2).max(1);
let use_tree = args.iter().any(|a| a == "--tree");
xserv_cuda::device::set_device(device).unwrap();
let info = xserv_cuda::device::device_info(device).unwrap();
@@ -150,7 +151,11 @@ fn main() {
baseline_tokens += baseline.ids.len();
drop(baseline_cache);
// Speculative with EAGLE, γ from CLI.
// Speculative with EAGLE, γ from CLI. Verify uses the tree kernel with
// a causal mask (equivalent to non-tree behavior); a real tree
// (siblings sharing target positions) would require KV cache slot
// remap after acceptance, which is out of scope for this iteration.
let _ = use_tree; // reserved for future tree drafting
let mut target_cache = new_cache(&target_config, max_seq_len, device);
let spec = if gamma == 1 {
run_eagle_gamma1(
@@ -441,9 +446,23 @@ fn run_eagle_gamma_multi(
for &d in drafts.iter() {
verify_input.push(d);
}
let n = verify_input.len();
let pos_offset = cache.seq_len(slot);
let positions_v: Vec<u32> = (0..n).map(|i| (pos_offset + i) as u32).collect();
let kv_lens_v: Vec<i32> = (0..n).map(|i| (pos_offset + i + 1) as i32).collect();
// Causal mask over new-writes: mask[i][j] = 1 iff j <= i.
let mut tree_mask: Vec<i32> = vec![0; n * n];
for i in 0..n {
for j in 0..=i {
tree_mask[i * n + j] = 1;
}
}
let (verify_logits, verify_hooks) = target
.forward_verify_paged_decode_attention_with_hidden(
.forward_verify_paged_decode_attention_tree_with_hidden(
&verify_input,
&positions_v,
&kv_lens_v,
&tree_mask,
slot,
cache,
&EAGLE_HOOK_LAYERS,

View File

@@ -1230,6 +1230,138 @@ impl Qwen3 {
(logits, hidden_arr)
}
/// Tree-aware verify: like `_with_hidden` but supports sibling candidates
/// sharing the same target position. Caller supplies per-token positions
/// (for RoPE), kv_lens (attention context length), and a flattened
/// `tree_mask` (`[new_tokens, new_tokens]` i32; `mask[i, j]!=0` iff query i
/// attends to newly-written K/V at slot j). Positions in the paged cache
/// before pos_offset are always attended (regular history).
#[allow(clippy::too_many_arguments)]
pub fn forward_verify_paged_decode_attention_tree_with_hidden(
&self,
token_ids: &[u32],
positions: &[u32],
kv_lens: &[i32],
tree_mask: &[i32],
slot: usize,
paged_cache: &mut PagedKVCache,
hook_layers: &[usize; 3],
) -> (Tensor, [Tensor; 3]) {
let new_tokens = token_ids.len();
assert_eq!(positions.len(), new_tokens);
assert_eq!(kv_lens.len(), new_tokens);
assert_eq!(tree_mask.len(), new_tokens * new_tokens);
let pos_offset = paged_cache.seq_len(slot);
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32;
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
paged_cache.advance_seq_len(slot, new_tokens);
let slots = vec![slot; new_tokens];
paged_cache.sync_active_batch_with_lens(&slots, kv_lens);
let bt_ptr = paged_cache.block_table_gpu().as_ptr() as *const i32;
let cl_ptr = paged_cache.context_lens_gpu().as_ptr() as *const i32;
let max_blocks = paged_cache.max_blocks_per_seq();
// Upload tree_mask [new_tokens, new_tokens] i32 to GPU.
let mask_bytes: &[u8] = unsafe {
std::slice::from_raw_parts(tree_mask.as_ptr() as *const u8, tree_mask.len() * 4)
};
let mut mask_buf =
xserv_cuda::allocator::cached_alloc(mask_bytes.len()).expect("alloc tree_mask");
mask_buf.copy_from_host(mask_bytes).unwrap();
let mask_ptr = mask_buf.as_ptr() as *const i32;
let mut x = embedding(&self.embed_tokens, token_ids);
let mut hooks: [Option<Tensor>; 3] = [None, None, None];
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let qkv = matmul_2d(&normed, &layer.qkv_proj_wt);
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_all = qkv.narrow(1, 0, q_dim);
let k_all = qkv.narrow(1, q_dim, kv_dim);
let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim);
let q_flat = q_all
.contiguous()
.reshape(&[new_tokens * num_heads, head_dim]);
let k_flat = k_all
.contiguous()
.reshape(&[new_tokens * num_kv_heads, head_dim]);
let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps);
let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps);
let q_3d = q_normed.reshape(&[new_tokens, num_heads, head_dim]);
let k_3d = k_normed.reshape(&[new_tokens, num_kv_heads, head_dim]);
rope_inplace(&q_3d, &self.rope_cache, positions);
rope_inplace(&k_3d, &self.rope_cache, positions);
let v_3d = v_all
.contiguous()
.reshape(&[new_tokens, num_kv_heads, head_dim]);
paged_cache.append_tokens_batched(layer_idx, &k_3d, &v_3d, new_tokens);
let q_decode = q_3d.reshape(&[new_tokens, num_heads, 1, head_dim]);
let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void;
let attn_out = xserv_kernels::paged_decode_attention_tree(
&q_decode,
k_pool_ptr,
v_pool_ptr,
bt_ptr,
cl_ptr,
mask_ptr,
new_tokens,
num_heads,
num_kv_heads,
head_dim,
max_blocks,
pos_offset,
new_tokens,
);
let attn_merged = attn_out.reshape(&[new_tokens, num_heads * head_dim]);
let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt);
self.all_reduce(&attn_proj);
let (normed, x_new) =
xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let residual = x_new.clone();
let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt);
let ffn_dim = gate_up.shape()[1] / 2;
let gate = gate_up.narrow(1, 0, ffn_dim).contiguous();
let up = gate_up.narrow(1, ffn_dim, ffn_dim).contiguous();
let hidden_states = xserv_kernels::silu_mul(&gate, &up);
let down = matmul_2d(&hidden_states, &layer.down_proj_wt);
self.all_reduce(&down);
x = add_any(&residual, &down);
for (h_idx, &h_layer) in hook_layers.iter().enumerate() {
if layer_idx == h_layer {
hooks[h_idx] = Some(x.clone());
}
}
}
let x = rmsnorm(&x, &self.norm, eps);
let logits = matmul_2d(&x, &self.lm_head_t);
let hidden_arr = [
hooks[0].take().expect("hook layer 0 not reached"),
hooks[1].take().expect("hook layer 1 not reached"),
hooks[2].take().expect("hook layer 2 not reached"),
];
(logits, hidden_arr)
}
/// Forward with GPU-resident KV cache and GPU transpose/reshape kernels.
pub fn forward_gpu_cache(&self, token_ids: &[u32], cache: &mut GpuKVCache) -> Tensor {
let new_tokens = token_ids.len();

View File

@@ -189,6 +189,169 @@ __global__ void paged_decode_attention_bf16_kernel(
}
}
// Tree-aware paged decode attention: per-query mask lets sibling candidates
// in the same batch attend to different subsets of newly-written K/V.
// `tree_start`: position where newly-written K/V begins (typically pos_offset).
// `tree_len`: number of newly-written K/V rows (= batch, one per query).
// `tree_mask[i][j] = 1` iff query i attends to K/V at position `tree_start+j`.
// Positions < tree_start are always attended (regular history).
__global__ void paged_decode_attention_tree_bf16_kernel(
const __nv_bfloat16* __restrict__ Q,
const __nv_bfloat16* __restrict__ K_cache,
const __nv_bfloat16* __restrict__ V_cache,
__nv_bfloat16* __restrict__ O,
const int* __restrict__ block_tables,
const int* __restrict__ context_lens,
const int* __restrict__ tree_mask, // [batch, tree_len] int32
int num_q_heads, int num_kv_heads,
int head_dim, int max_blocks_per_seq,
int tree_start, int tree_len,
float scale
) {
int seq_idx = blockIdx.y;
int q_head = blockIdx.x;
int tid = threadIdx.x;
int kv_len = context_lens[seq_idx];
if (kv_len <= 0) {
if (tid < head_dim) {
O[((long long)seq_idx * num_q_heads + q_head) * head_dim + tid] =
__float2bfloat16(0.0f);
}
return;
}
int heads_per_group = num_q_heads / num_kv_heads;
int kv_head = q_head / heads_per_group;
const __nv_bfloat16* Q_ptr = Q +
((long long)seq_idx * num_q_heads + q_head) * head_dim;
__nv_bfloat16* O_ptr = O +
((long long)seq_idx * num_q_heads + q_head) * head_dim;
const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;
const int* mask_row = tree_mask + (long long)seq_idx * tree_len;
float q_reg[PAGED_HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) {
q_reg[d] = __bfloat162float(Q_ptr[d]);
}
float local_max = -INFINITY;
float local_sum = 0.0f;
float local_O[PAGED_HEAD_DIM_MAX];
for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;
int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
int kv_stride_head = PAGED_BLOCK_SIZE * head_dim;
for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
// Tree mask: skip positions in [tree_start, tree_start+tree_len) that
// the mask marks as 0. Everything else (history) is always attended.
if (pos >= tree_start && pos < tree_start + tree_len) {
if (mask_row[pos - tree_start] == 0) continue;
}
int logical_blk = pos / PAGED_BLOCK_SIZE;
int slot_in_blk = pos % PAGED_BLOCK_SIZE;
int phys_blk = bt[logical_blk];
const __nv_bfloat16* K_pos = K_cache
+ (long long)phys_blk * kv_stride_block
+ kv_head * kv_stride_head
+ slot_in_blk * head_dim;
const __nv_bfloat16* V_pos = V_cache
+ (long long)phys_blk * kv_stride_block
+ kv_head * kv_stride_head
+ slot_in_blk * head_dim;
float dot = 0.0f;
for (int d = 0; d < head_dim; d++) {
dot += q_reg[d] * __bfloat162float(K_pos[d]);
}
float s = dot * scale;
float new_max = fmaxf(local_max, s);
float correction = expf(local_max - new_max);
float p = expf(s - new_max);
local_sum = local_sum * correction + p;
for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
for (int d = 0; d < head_dim; d++) {
local_O[d] += p * __bfloat162float(V_pos[d]);
}
local_max = new_max;
}
// Block-level reduction (identical to base kernel).
__shared__ float smem_max[32];
__shared__ float smem_sum[32];
__shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
int lane = tid & 31;
int warp_id = tid >> 5;
int num_warps = PAGED_THREADS >> 5;
float warp_max = local_max;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
if (lane == 0) smem_max[warp_id] = warp_max;
__syncthreads();
float global_max;
if (tid == 0) {
global_max = smem_max[0];
for (int i = 1; i < num_warps; i++)
global_max = fmaxf(global_max, smem_max[i]);
smem_max[0] = global_max;
}
__syncthreads();
global_max = smem_max[0];
float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
local_sum *= rescale;
for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;
float warp_sum = local_sum;
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
if (lane == 0) smem_sum[warp_id] = warp_sum;
__syncthreads();
float global_sum;
if (tid == 0) {
global_sum = 0.0f;
for (int i = 0; i < num_warps; i++) global_sum += smem_sum[i];
smem_sum[0] = global_sum;
}
__syncthreads();
global_sum = smem_sum[0];
for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
}
__syncthreads();
for (int d = 0; d < head_dim; d++) {
float val = local_O[d];
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
val += __shfl_down_sync(0xffffffff, val, offset);
if (lane == 0) smem_O_warp[warp_id][d] = val;
}
__syncthreads();
float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
for (int d = tid; d < head_dim; d += PAGED_THREADS) {
float out = 0.0f;
for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
O_ptr[d] = __float2bfloat16(out * inv_sum);
}
}
// Extended paged decode attention with attention sinks and sliding window.
// sinks: [num_q_heads] BF16 — per-head extra logit appended before softmax.
// window_size: >0 = sliding window (only attend to last `window_size` positions), 0 = full.
@@ -389,6 +552,36 @@ void launch_paged_decode_attention_bf16(
CUDA_CHECK_LAST_ERROR();
}
void launch_paged_decode_attention_tree_bf16(
const void* Q,
const void* K_cache,
const void* V_cache,
void* O,
const int* block_tables,
const int* context_lens,
const int* tree_mask,
int batch, int num_q_heads, int num_kv_heads,
int head_dim, int max_blocks_per_seq,
int tree_start, int tree_len,
float scale, void* stream
) {
dim3 grid(num_q_heads, batch);
int block = PAGED_THREADS;
paged_decode_attention_tree_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)Q,
(const __nv_bfloat16*)K_cache,
(const __nv_bfloat16*)V_cache,
(__nv_bfloat16*)O,
block_tables, context_lens, tree_mask,
num_q_heads, num_kv_heads,
head_dim, max_blocks_per_seq,
tree_start, tree_len,
scale
);
CUDA_CHECK_LAST_ERROR();
}
void launch_paged_decode_attention_sinks_bf16(
const void* Q,
const void* K_cache,