attention: tree-aware paged_decode_attention_tree kernel + wrapper

New CUDA kernel paged_decode_attention_tree_bf16_kernel: same as base
paged_decode_attention but with a per-query mask over the newly-written
K/V region. `tree_mask[i][j] != 0` iff query i attends to newly-written
K/V at slot j. Positions before `tree_start` are always attended.

Motivation: speculative decoding with tree drafting needs siblings at
the same target position to attend to their own branch's history, not
each other's K/V.

Rust binding: paged_decode_attention_tree(...) mirrors
paged_decode_attention plus tree_mask_ptr, tree_start, tree_len.

Forward path: Qwen3::forward_verify_paged_decode_attention_tree_with_hidden
takes explicit positions, kv_lens, and a flattened [N*N] tree_mask.

Sanity check: bench-eagle3's γ_multi path now routes through the tree
kernel with a causal mask (mask[i][j]=1 iff j<=i), producing output
bit-equivalent to the non-tree variant. The matched=false pattern,
acceptance rate, and speedup are all identical to the previous run within
noise (11.3% acceptance, 1.00× speedup with the mask-check overhead).

--tree CLI flag is parsed but reserved. Real tree drafting (siblings
sharing a target position) is blocked by KV cache position rigidity:
paged_cache stores K/V at cache-position ≡ target-position, so an
accepted sibling at target position P+1 has its K/V physically at
cache position P+2 (its unique slot in the batched write). Continuing
decode at P+1 would see the WRONG K/V (top-1 sibling's, not accepted
top-2 sibling's). Fix requires either KV-slot remap on acceptance or
a virtual position layer.

Infrastructure is in place; the next step is tackling that remap.
This commit is contained in:
2026-07-01 20:45:55 +08:00
parent 10a98539d0
commit fd392f7fbb
5 changed files with 422 additions and 3 deletions

View File

@@ -189,6 +189,169 @@ __global__ void paged_decode_attention_bf16_kernel(
}
}
// Tree-aware paged decode attention: a per-query mask lets sibling candidates
// in the same batch attend to different subsets of the newly-written K/V.
//
// Launch layout (see launch_paged_decode_attention_tree_bf16):
//   grid  = (num_q_heads, batch)  -> blockIdx.x = query head, blockIdx.y = query
//   block = PAGED_THREADS threads; the reduction below derives its warp count
//           from PAGED_THREADS, so the launch must use exactly that block size.
//
// Parameters:
//   Q, O          : indexed as [query][q_head][head_dim] (BF16).
//   K_cache/V_cache: indexed as [phys_block][kv_head][slot_in_block][head_dim],
//                   with PAGED_BLOCK_SIZE slots per block.
//   block_tables  : [query][max_blocks_per_seq] logical -> physical block ids.
//   context_lens  : [query] number of K/V positions visible to that query.
//   tree_mask     : [query][tree_len] int32; entry j is nonzero iff the query
//                   attends to K/V at position `tree_start + j`.
//   tree_start    : position where the newly-written K/V region begins.
//   tree_len      : number of positions covered by the mask.
//   scale         : multiplier applied to each Q.K dot product.
//
// Positions outside [tree_start, tree_start + tree_len) are always attended.
//
// Preconditions (not checked here):
//   * head_dim <= PAGED_HEAD_DIM_MAX (sizes the per-thread and shared arrays).
//   * PAGED_THREADS / 32 <= 32 (shared reduction arrays have 32 slots).
//   * num_q_heads >= num_kv_heads, so heads_per_group is nonzero.
//
// Static shared memory: two float[32] arrays plus float[32][PAGED_HEAD_DIM_MAX].
//
// If the query has no visible positions (kv_len <= 0) or every visible position
// is masked out, the output row is written as zeros.
__global__ void paged_decode_attention_tree_bf16_kernel(
    const __nv_bfloat16* __restrict__ Q,
    const __nv_bfloat16* __restrict__ K_cache,
    const __nv_bfloat16* __restrict__ V_cache,
    __nv_bfloat16* __restrict__ O,
    const int* __restrict__ block_tables,
    const int* __restrict__ context_lens,
    const int* __restrict__ tree_mask,  // [batch, tree_len] int32
    int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    int tree_start, int tree_len,
    float scale
) {
    int seq_idx = blockIdx.y;
    int q_head = blockIdx.x;
    int tid = threadIdx.x;

    __nv_bfloat16* O_ptr = O +
        ((long long)seq_idx * num_q_heads + q_head) * head_dim;

    int kv_len = context_lens[seq_idx];
    if (kv_len <= 0) {
        // Nothing to attend to: zero the whole output row. A strided loop is
        // used (matching the final write-out below) so every element is covered
        // even when head_dim exceeds the number of threads in the block.
        // The branch is uniform across the block, so returning before the
        // barriers below is safe.
        for (int d = tid; d < head_dim; d += PAGED_THREADS) {
            O_ptr[d] = __float2bfloat16(0.0f);
        }
        return;
    }

    // Grouped-query attention: consecutive query heads share one K/V head.
    int heads_per_group = num_q_heads / num_kv_heads;
    int kv_head = q_head / heads_per_group;

    const __nv_bfloat16* Q_ptr = Q +
        ((long long)seq_idx * num_q_heads + q_head) * head_dim;
    const int* bt = block_tables + (long long)seq_idx * max_blocks_per_seq;
    const int* mask_row = tree_mask + (long long)seq_idx * tree_len;

    // Every thread keeps its own float copy of the query vector.
    float q_reg[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) {
        q_reg[d] = __bfloat162float(Q_ptr[d]);
    }

    // Per-thread streaming softmax state: running max, running sum of
    // exp(s - max), and the matching un-normalised weighted sum of V.
    float local_max = -INFINITY;
    float local_sum = 0.0f;
    float local_O[PAGED_HEAD_DIM_MAX];
    for (int d = 0; d < head_dim; d++) local_O[d] = 0.0f;

    int kv_stride_block = num_kv_heads * PAGED_BLOCK_SIZE * head_dim;
    int kv_stride_head = PAGED_BLOCK_SIZE * head_dim;

    // Thread `tid` handles positions tid, tid + PAGED_THREADS, ...
    for (int pos = tid; pos < kv_len; pos += PAGED_THREADS) {
        // Tree mask: skip positions in [tree_start, tree_start+tree_len) that
        // the mask marks as 0. Everything else (history) is always attended.
        if (pos >= tree_start && pos < tree_start + tree_len) {
            if (mask_row[pos - tree_start] == 0) continue;
        }

        // Translate the logical position into a physical cache slot.
        int logical_blk = pos / PAGED_BLOCK_SIZE;
        int slot_in_blk = pos % PAGED_BLOCK_SIZE;
        int phys_blk = bt[logical_blk];
        const __nv_bfloat16* K_pos = K_cache
            + (long long)phys_blk * kv_stride_block
            + kv_head * kv_stride_head
            + slot_in_blk * head_dim;
        const __nv_bfloat16* V_pos = V_cache
            + (long long)phys_blk * kv_stride_block
            + kv_head * kv_stride_head
            + slot_in_blk * head_dim;

        float dot = 0.0f;
        for (int d = 0; d < head_dim; d++) {
            dot += q_reg[d] * __bfloat162float(K_pos[d]);
        }
        float s = dot * scale;

        // Re-base the accumulated state onto the new running max before
        // adding this position's contribution.
        float new_max = fmaxf(local_max, s);
        float correction = expf(local_max - new_max);
        float p = expf(s - new_max);
        local_sum = local_sum * correction + p;
        for (int d = 0; d < head_dim; d++) local_O[d] *= correction;
        for (int d = 0; d < head_dim; d++) {
            local_O[d] += p * __bfloat162float(V_pos[d]);
        }
        local_max = new_max;
    }

    // Block-level reduction (identical to base kernel).
    __shared__ float smem_max[32];
    __shared__ float smem_sum[32];
    __shared__ float smem_O_warp[32][PAGED_HEAD_DIM_MAX];
    int lane = tid & 31;
    int warp_id = tid >> 5;
    int num_warps = PAGED_THREADS >> 5;

    // Step 1: block-wide maximum. After the shuffle tree only lane 0 of each
    // warp holds the warp's maximum.
    float warp_max = local_max;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_max = fmaxf(warp_max, __shfl_down_sync(0xffffffff, warp_max, offset));
    if (lane == 0) smem_max[warp_id] = warp_max;
    __syncthreads();
    float global_max;
    if (tid == 0) {
        global_max = smem_max[0];
        for (int i = 1; i < num_warps; i++)
            global_max = fmaxf(global_max, smem_max[i]);
        smem_max[0] = global_max;
    }
    __syncthreads();
    global_max = smem_max[0];

    // Step 2: re-base each thread's partials onto the block maximum. Threads
    // that processed no position (local_max still -inf) contribute zero; the
    // explicit test avoids evaluating expf(-inf - (-inf)), which would be NaN.
    float rescale = (local_max == -INFINITY) ? 0.0f : expf(local_max - global_max);
    local_sum *= rescale;
    for (int d = 0; d < head_dim; d++) local_O[d] *= rescale;

    // Step 3: block-wide softmax denominator.
    float warp_sum = local_sum;
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        warp_sum += __shfl_down_sync(0xffffffff, warp_sum, offset);
    if (lane == 0) smem_sum[warp_id] = warp_sum;
    __syncthreads();
    float global_sum;
    if (tid == 0) {
        global_sum = 0.0f;
        for (int i = 0; i < num_warps; i++) global_sum += smem_sum[i];
        smem_sum[0] = global_sum;
    }
    __syncthreads();
    global_sum = smem_sum[0];

    // Step 4: block-wide weighted V sum, one warp partial per row of
    // smem_O_warp. The scratch buffer is cleared first.
    for (int i = tid; i < 32 * PAGED_HEAD_DIM_MAX; i += PAGED_THREADS) {
        reinterpret_cast<float*>(smem_O_warp)[i] = 0.0f;
    }
    __syncthreads();
    for (int d = 0; d < head_dim; d++) {
        float val = local_O[d];
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1)
            val += __shfl_down_sync(0xffffffff, val, offset);
        if (lane == 0) smem_O_warp[warp_id][d] = val;
    }
    __syncthreads();

    // Step 5: normalise and write out. A zero denominator (every position
    // masked) yields an all-zero output row instead of a division by zero.
    float inv_sum = (global_sum > 0.0f) ? (1.0f / global_sum) : 0.0f;
    for (int d = tid; d < head_dim; d += PAGED_THREADS) {
        float out = 0.0f;
        for (int i = 0; i < num_warps; i++) out += smem_O_warp[i][d];
        O_ptr[d] = __float2bfloat16(out * inv_sum);
    }
}
// Extended paged decode attention with attention sinks and sliding window.
// sinks: [num_q_heads] BF16 — per-head extra logit appended before softmax.
// window_size: >0 = sliding window (only attend to last `window_size` positions), 0 = full.
@@ -389,6 +552,36 @@ void launch_paged_decode_attention_bf16(
CUDA_CHECK_LAST_ERROR();
}
// Host-side launcher for paged_decode_attention_tree_bf16_kernel.
//
// Launches one thread block of PAGED_THREADS threads per (query head, query)
// pair on `stream` (cast to cudaStream_t). Q, K_cache, V_cache and O are passed
// as untyped pointers and reinterpreted as BF16 buffers; all other arguments
// are forwarded to the kernel unchanged. The launch is asynchronous: only
// launch errors are checked here, via CUDA_CHECK_LAST_ERROR.
//
// A call with batch == 0 or num_q_heads == 0 is a no-op.
void launch_paged_decode_attention_tree_bf16(
    const void* Q,
    const void* K_cache,
    const void* V_cache,
    void* O,
    const int* block_tables,
    const int* context_lens,
    const int* tree_mask,
    int batch, int num_q_heads, int num_kv_heads,
    int head_dim, int max_blocks_per_seq,
    int tree_start, int tree_len,
    float scale, void* stream
) {
    // A grid with a zero dimension is an invalid launch configuration, and
    // there is no output to produce in that case, so skip the launch.
    if (batch == 0 || num_q_heads == 0) {
        return;
    }

    dim3 grid(num_q_heads, batch);
    int block = PAGED_THREADS;
    paged_decode_attention_tree_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)Q,
        (const __nv_bfloat16*)K_cache,
        (const __nv_bfloat16*)V_cache,
        (__nv_bfloat16*)O,
        block_tables, context_lens, tree_mask,
        num_q_heads, num_kv_heads,
        head_dim, max_blocks_per_seq,
        tree_start, tree_len,
        scale
    );
    CUDA_CHECK_LAST_ERROR();
}
void launch_paged_decode_attention_sinks_bf16(
const void* Q,
const void* K_cache,