From 6da09727402f3798553152eb0e91a1afc8ea6c5c Mon Sep 17 00:00:00 2001
From: Gahow Wang
Date: Wed, 1 Jul 2026 23:09:35 +0800
Subject: [PATCH] speculative: copy_kv_position primitive for tree drafting KV remap
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

SGLang-style "write-all, copy-move on acceptance" approach: after tree
verification, physically copy an accepted sibling's K/V from its physical
cache slot to the canonical sequential position.

New CUDA kernel: copy_kv_position_kernel in reshape_and_cache.cu.
For one token (src_pos → dst_pos), copies head_dim × num_kv_heads BF16
elements in both K and V pools. Grid = num_kv_heads, block = head_dim.
Cost for one token across 36 layers: ~5.3 MB D2D copy @ 900 GB/s = <6μs.

Rust FFI: copy_kv_position(k_pool, v_pool, block_ids, src_pos, dst_pos,
num_kv_heads, head_dim, block_size, stream).

PagedKVCache method: copy_kv_position(slot, src_pos, dst_pos) — uploads
block_ids for the sequence, calls the kernel per layer.

This is the primitive needed by tree drafting: when a non-primary sibling
at cache position P+2 is accepted as the "true" token for target position
P+1, call copy_kv_position(slot, P+2, P+1) then truncate to P+2.

Next: wire into bench-eagle3 tree drafting loop with top-2 siblings.
---
 crates/xserv-kernels/src/attention.rs    | 41 ++++++++++++++++++
 crates/xserv-kernels/src/lib.rs          |  6 +--
 crates/xserv-model/src/paged_kv_cache.rs | 52 +++++++++++++++++++++++
 csrc/attention/reshape_and_cache.cu      | 54 ++++++++++++++++++++++++
 4 files changed, 150 insertions(+), 3 deletions(-)

diff --git a/crates/xserv-kernels/src/attention.rs b/crates/xserv-kernels/src/attention.rs
index 9d1f340..9da96a2 100644
--- a/crates/xserv-kernels/src/attention.rs
+++ b/crates/xserv-kernels/src/attention.rs
@@ -145,6 +145,17 @@ unsafe extern "C" {
         max_blocks_per_seq: i32,
         stream: *mut c_void,
     );
+    fn launch_copy_kv_position(
+        k_pool: *mut c_void,
+        v_pool: *mut c_void,
+        block_ids: *const i32,
+        src_pos: i32,
+        dst_pos: i32,
+        num_kv_heads: i32,
+        head_dim: i32,
+        block_size: i32,
+        stream: *mut c_void,
+    );
 }
 
 /// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged
@@ -231,6 +242,36 @@ pub unsafe fn reshape_and_cache_batched_bf16(
     }
 }
 
+/// Copy one token's K/V from `src_pos` to `dst_pos` within the same sequence's
+/// paged cache (one layer). Used by tree speculative decoding to remap
+/// accepted sibling K/V to canonical sequential positions after acceptance.
+///
+/// # Safety
+/// Pool and block_ids pointers must be valid GPU pointers for the given layer.
+pub unsafe fn copy_kv_position(
+    k_pool_ptr: *mut c_void,
+    v_pool_ptr: *mut c_void,
+    block_ids_gpu: *const i32,
+    src_pos: usize,
+    dst_pos: usize,
+    num_kv_heads: usize,
+    head_dim: usize,
+    block_size: usize,
+    stream: *mut c_void,
+) {
+    launch_copy_kv_position(
+        k_pool_ptr,
+        v_pool_ptr,
+        block_ids_gpu,
+        src_pos as i32,
+        dst_pos as i32,
+        num_kv_heads as i32,
+        head_dim as i32,
+        block_size as i32,
+        stream,
+    );
+}
+
 fn apply_causal_mask(scores: &Tensor, offset: usize) {
     let ndim = scores.ndim();
     let rows = scores.shape()[ndim - 2];
diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs
index 73ba21d..2cd8122 100644
--- a/crates/xserv-kernels/src/lib.rs
+++ b/crates/xserv-kernels/src/lib.rs
@@ -15,9 +15,9 @@ pub mod transpose;
 pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
 pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
 pub use attention::{
-    attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
-    paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16,
-    reshape_and_cache_bf16,
+    attention, copy_kv_position, decode_attention, flash_attention, flash_attention_sinks,
+    paged_decode_attention, paged_decode_attention_sinks, paged_decode_attention_tree,
+    reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
 };
 pub use embedding::{embedding, embedding_device_ids};
 pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
diff --git a/crates/xserv-model/src/paged_kv_cache.rs b/crates/xserv-model/src/paged_kv_cache.rs
index 34f9bac..96c0adc 100644
--- a/crates/xserv-model/src/paged_kv_cache.rs
+++ b/crates/xserv-model/src/paged_kv_cache.rs
@@ -515,6 +515,58 @@ impl PagedKVCache {
         Ok(())
     }
 
+    /// Copy K/V data from `src_pos` to `dst_pos` within the same slot, across
+    /// all layers. Used by tree speculative decoding to remap an accepted
+    /// sibling's K/V to the canonical sequential position after acceptance.
+    ///
+    /// Panics if `slot` is not registered or either position is `>= seq_len`.
+    pub fn copy_kv_position(&self, slot: usize, src_pos: usize, dst_pos: usize) {
+        let state = self.seq_states[slot]
+            .as_ref()
+            .expect("copy_kv_position: slot not registered");
+        assert!(
+            src_pos < state.seq_len && dst_pos < state.seq_len,
+            "copy_kv_position: positions must be within seq_len"
+        );
+        // Upload this sequence's block_ids to a small GPU buffer.
+        let block_ids_host: Vec<i32> = state.block_ids.iter().map(|&b| b as i32).collect();
+        // SAFETY: `block_ids_host` outlives `bytes`; its `i32` elements are plain data, so
+        // viewing them as `len * 4` initialized bytes is sound.
+        let bytes: &[u8] = unsafe {
+            std::slice::from_raw_parts(
+                block_ids_host.as_ptr() as *const u8,
+                block_ids_host.len() * 4,
+            )
+        };
+        let mut ids_buf =
+            xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc block_ids for copy");
+        ids_buf.copy_from_host(bytes).unwrap();
+        let ids_ptr = ids_buf.as_ptr() as *const i32;
+
+        // NOTE(review): the launches below are queued on `stream`, but `ids_buf` is dropped
+        // when this function returns. Presumably `cached_alloc` keeps the memory valid until
+        // the queued kernels have read it — TODO confirm against the allocator.
+        let stream = xserv_cuda::current_stream_raw();
+        let num_layers = self.k_pools.len();
+        for layer in 0..num_layers {
+            // SAFETY: pool pointers come from this cache's per-layer pools, `ids_ptr` is the
+            // block table uploaded above, and both positions were checked against `seq_len`.
+            unsafe {
+                xserv_kernels::copy_kv_position(
+                    self.k_pools[layer].as_ptr() as *mut std::ffi::c_void,
+                    self.v_pools[layer].as_ptr() as *mut std::ffi::c_void,
+                    ids_ptr,
+                    src_pos,
+                    dst_pos,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    BLOCK_SIZE,
+                    stream,
+                );
+            }
+        }
+    }
+
     /// Refresh the host-side block table + context lens from `seq_states`,
     /// then upload to GPU. Call once per decode step before the paged kernel.
     pub fn sync_to_gpu(&mut self) {
diff --git a/csrc/attention/reshape_and_cache.cu b/csrc/attention/reshape_and_cache.cu
index cc14a50..c2928d1 100644
--- a/csrc/attention/reshape_and_cache.cu
+++ b/csrc/attention/reshape_and_cache.cu
@@ -158,4 +158,58 @@ void launch_reshape_and_cache_batched_bf16(
     CUDA_CHECK_LAST_ERROR();
 }
 
+// Copy one token's K/V from src_pos to dst_pos within one pool.
+// Grid: (num_kv_heads,). Block: up to 1024 threads; each thread strides over
+// head_dim, so a head_dim larger than the block is still copied in full.
+// pool: [num_blocks_total, num_kv_heads, block_size, head_dim]
+// block_ids: [max_blocks] for this sequence (logical → physical block map).
+__global__ void copy_kv_position_kernel(
+    __nv_bfloat16* __restrict__ pool,
+    const int* __restrict__ block_ids,
+    int src_pos, int dst_pos,
+    int head_dim, int block_size
+) {
+    int h = blockIdx.x;
+    int num_kv_heads = gridDim.x;
+
+    int src_blk = src_pos / block_size;
+    int src_slot = src_pos % block_size;
+    int src_phys = block_ids[src_blk];
+
+    int dst_blk = dst_pos / block_size;
+    int dst_slot = dst_pos % block_size;
+    int dst_phys = block_ids[dst_blk];
+
+    long long src_off = ((long long)src_phys * num_kv_heads + h) * block_size * head_dim
+                      + (long long)src_slot * head_dim;
+    long long dst_off = ((long long)dst_phys * num_kv_heads + h) * block_size * head_dim
+                      + (long long)dst_slot * head_dim;
+
+    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
+        pool[dst_off + d] = pool[src_off + d];
+    }
+}
+
+void launch_copy_kv_position(
+    void* k_pool, void* v_pool,
+    const int* block_ids,
+    int src_pos, int dst_pos,
+    int num_kv_heads, int head_dim, int block_size,
+    void* stream
+) {
+    int threads = head_dim < 32 ? 32 : head_dim;
+    if (threads > 1024) threads = 1024;
+    dim3 grid(num_kv_heads);
+    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
+        (__nv_bfloat16*)k_pool, block_ids,
+        src_pos, dst_pos, head_dim, block_size
+    );
+    CUDA_CHECK_LAST_ERROR();
+    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
+        (__nv_bfloat16*)v_pool, block_ids,
+        src_pos, dst_pos, head_dim, block_size
+    );
+    CUDA_CHECK_LAST_ERROR();
+}
+
 }