speculative: copy_kv_position primitive for tree drafting KV remap

SGLang-style "write-all, copy-move on acceptance" approach: after tree
verification, physically copy an accepted sibling's K/V from its
physical cache slot to the canonical sequential position.

New CUDA kernel: copy_kv_position_kernel in reshape_and_cache.cu.
For one token (src_pos → dst_pos), copies head_dim × num_kv_heads BF16
elements in both K and V pools. Grid = num_kv_heads, block = head_dim.
Cost for one token across 36 layers: ~5.3 MB D2D copy @ 900 GB/s = <6μs.

Rust FFI: copy_kv_position(k_pool, v_pool, block_ids, src_pos, dst_pos,
num_kv_heads, head_dim, block_size, stream).

PagedKVCache method: copy_kv_position(slot, src_pos, dst_pos) — uploads
block_ids for the sequence, calls the kernel per layer. This is the
primitive needed by tree drafting: when a non-primary sibling at cache
position P+2 is accepted as the "true" token for target position P+1,
call copy_kv_position(slot, P+2, P+1) then truncate to P+2.

Next: wire into bench-eagle3 tree drafting loop with top-2 siblings.
This commit is contained in:
2026-07-01 23:09:35 +08:00
parent 40d8a29e33
commit 6da0972740
4 changed files with 143 additions and 3 deletions

View File

@@ -145,6 +145,17 @@ unsafe extern "C" {
max_blocks_per_seq: i32,
stream: *mut c_void,
);
fn launch_copy_kv_position(
k_pool: *mut c_void,
v_pool: *mut c_void,
block_ids: *const i32,
src_pos: i32,
dst_pos: i32,
num_kv_heads: i32,
head_dim: i32,
block_size: i32,
stream: *mut c_void,
);
}
/// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged
@@ -231,6 +242,36 @@ pub unsafe fn reshape_and_cache_batched_bf16(
}
}
/// Copy one token's K/V from `src_pos` to `dst_pos` within the same sequence's
/// paged cache (one layer). Used by tree speculative decoding to remap
/// accepted sibling K/V to canonical sequential positions after acceptance.
///
/// # Safety
/// Pool and block_ids pointers must be valid GPU pointers for the given layer.
pub unsafe fn copy_kv_position(
k_pool_ptr: *mut c_void,
v_pool_ptr: *mut c_void,
block_ids_gpu: *const i32,
src_pos: usize,
dst_pos: usize,
num_kv_heads: usize,
head_dim: usize,
block_size: usize,
stream: *mut c_void,
) {
launch_copy_kv_position(
k_pool_ptr,
v_pool_ptr,
block_ids_gpu,
src_pos as i32,
dst_pos as i32,
num_kv_heads as i32,
head_dim as i32,
block_size as i32,
stream,
);
}
fn apply_causal_mask(scores: &Tensor, offset: usize) {
let ndim = scores.ndim();
let rows = scores.shape()[ndim - 2];

View File

@@ -15,9 +15,9 @@ pub mod transpose;
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use attention::{
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16,
reshape_and_cache_bf16,
attention, copy_kv_position, decode_attention, flash_attention, flash_attention_sinks,
paged_decode_attention, paged_decode_attention_sinks, paged_decode_attention_tree,
reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
};
pub use embedding::{embedding, embedding_device_ids};
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};

View File

@@ -515,6 +515,51 @@ impl PagedKVCache {
Ok(())
}
/// Copy K/V data from `src_pos` to `dst_pos` within the same slot, across
/// all layers. Used by tree speculative decoding to remap an accepted
/// sibling's K/V to the canonical sequential position after acceptance.
///
/// Requires: both positions within the currently-allocated block range.
pub fn copy_kv_position(&self, slot: usize, src_pos: usize, dst_pos: usize) {
let state = self.seq_states[slot]
.as_ref()
.expect("copy_kv_position: slot not registered");
assert!(
src_pos < state.seq_len && dst_pos < state.seq_len,
"copy_kv_position: positions must be within seq_len"
);
// Upload this sequence's block_ids to a small GPU buffer.
let block_ids_host: Vec<i32> = state.block_ids.iter().map(|&b| b as i32).collect();
let bytes: &[u8] = unsafe {
std::slice::from_raw_parts(
block_ids_host.as_ptr() as *const u8,
block_ids_host.len() * 4,
)
};
let mut ids_buf =
xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc block_ids for copy");
ids_buf.copy_from_host(bytes).unwrap();
let ids_ptr = ids_buf.as_ptr() as *const i32;
let stream = xserv_cuda::current_stream_raw();
let num_layers = self.k_pools.len();
for layer in 0..num_layers {
unsafe {
xserv_kernels::copy_kv_position(
self.k_pools[layer].as_ptr() as *mut std::ffi::c_void,
self.v_pools[layer].as_ptr() as *mut std::ffi::c_void,
ids_ptr,
src_pos,
dst_pos,
self.num_kv_heads,
self.head_dim,
BLOCK_SIZE,
stream,
);
}
}
}
/// Refresh the host-side block table + context lens from `seq_states`,
/// then upload to GPU. Call once per decode step before the paged kernel.
pub fn sync_to_gpu(&mut self) {