diff --git a/crates/xserv-kernels/src/attention.rs b/crates/xserv-kernels/src/attention.rs index 9d1f340..9da96a2 100644 --- a/crates/xserv-kernels/src/attention.rs +++ b/crates/xserv-kernels/src/attention.rs @@ -145,6 +145,17 @@ unsafe extern "C" { max_blocks_per_seq: i32, stream: *mut c_void, ); + fn launch_copy_kv_position( + k_pool: *mut c_void, + v_pool: *mut c_void, + block_ids: *const i32, + src_pos: i32, + dst_pos: i32, + num_kv_heads: i32, + head_dim: i32, + block_size: i32, + stream: *mut c_void, + ); } /// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged @@ -231,6 +242,36 @@ pub unsafe fn reshape_and_cache_batched_bf16( } } +/// Copy one token's K/V from `src_pos` to `dst_pos` within the same sequence's +/// paged cache (one layer). Used by tree speculative decoding to remap +/// accepted sibling K/V to canonical sequential positions after acceptance. +/// +/// # Safety +/// Pool and block_ids pointers must be valid GPU pointers for the given layer. +pub unsafe fn copy_kv_position( + k_pool_ptr: *mut c_void, + v_pool_ptr: *mut c_void, + block_ids_gpu: *const i32, + src_pos: usize, + dst_pos: usize, + num_kv_heads: usize, + head_dim: usize, + block_size: usize, + stream: *mut c_void, +) { + launch_copy_kv_position( + k_pool_ptr, + v_pool_ptr, + block_ids_gpu, + src_pos as i32, + dst_pos as i32, + num_kv_heads as i32, + head_dim as i32, + block_size as i32, + stream, + ); +} + fn apply_causal_mask(scores: &Tensor, offset: usize) { let ndim = scores.ndim(); let rows = scores.shape()[ndim - 2]; diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index 73ba21d..2cd8122 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -15,9 +15,9 @@ pub mod transpose; pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul}; pub use argmax::{argmax_bf16_single, argmax_bf16_to_host}; pub use attention::{ - attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, - paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16, - reshape_and_cache_bf16, + attention, copy_kv_position, decode_attention, flash_attention, flash_attention_sinks, + paged_decode_attention, paged_decode_attention_sinks, paged_decode_attention_tree, + reshape_and_cache_batched_bf16, reshape_and_cache_bf16, }; pub use embedding::{embedding, embedding_device_ids}; pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv}; diff --git a/crates/xserv-model/src/paged_kv_cache.rs b/crates/xserv-model/src/paged_kv_cache.rs index 34f9bac..96c0adc 100644 --- a/crates/xserv-model/src/paged_kv_cache.rs +++ b/crates/xserv-model/src/paged_kv_cache.rs @@ -515,6 +515,51 @@ impl PagedKVCache { Ok(()) } + /// Copy K/V data from `src_pos` to `dst_pos` within the same slot, across + /// all layers. Used by tree speculative decoding to remap an accepted + /// sibling's K/V to the canonical sequential position after acceptance. + /// + /// Requires: both positions within the currently-allocated block range. + pub fn copy_kv_position(&self, slot: usize, src_pos: usize, dst_pos: usize) { + let state = self.seq_states[slot] + .as_ref() + .expect("copy_kv_position: slot not registered"); + assert!( + src_pos < state.seq_len && dst_pos < state.seq_len, + "copy_kv_position: positions must be within seq_len" + ); + // Upload this sequence's block_ids to a small GPU buffer. + let block_ids_host: Vec = state.block_ids.iter().map(|&b| b as i32).collect(); + let bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + block_ids_host.as_ptr() as *const u8, + block_ids_host.len() * 4, + ) + }; + let mut ids_buf = + xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc block_ids for copy"); + ids_buf.copy_from_host(bytes).unwrap(); + let ids_ptr = ids_buf.as_ptr() as *const i32; + + let stream = xserv_cuda::current_stream_raw(); + let num_layers = self.k_pools.len(); + for layer in 0..num_layers { + unsafe { + xserv_kernels::copy_kv_position( + self.k_pools[layer].as_ptr() as *mut std::ffi::c_void, + self.v_pools[layer].as_ptr() as *mut std::ffi::c_void, + ids_ptr, + src_pos, + dst_pos, + self.num_kv_heads, + self.head_dim, + BLOCK_SIZE, + stream, + ); + } + } + } + /// Refresh the host-side block table + context lens from `seq_states`, /// then upload to GPU. Call once per decode step before the paged kernel. pub fn sync_to_gpu(&mut self) { diff --git a/csrc/attention/reshape_and_cache.cu b/csrc/attention/reshape_and_cache.cu index cc14a50..c2928d1 100644 --- a/csrc/attention/reshape_and_cache.cu +++ b/csrc/attention/reshape_and_cache.cu @@ -158,4 +158,58 @@ void launch_reshape_and_cache_batched_bf16( CUDA_CHECK_LAST_ERROR(); } +// Copy one token's K/V from src_pos to dst_pos within one pool. +// Grid: (num_kv_heads,). Block: head_dim threads. +// pool: [num_blocks_total, num_kv_heads, block_size, head_dim] +// block_ids: [max_blocks] for this sequence (logical → physical block map). +__global__ void copy_kv_position_kernel( + __nv_bfloat16* __restrict__ pool, + const int* __restrict__ block_ids, + int src_pos, int dst_pos, + int head_dim, int block_size +) { + int h = blockIdx.x; + int d = threadIdx.x; + if (d >= head_dim) return; + + int num_kv_heads = gridDim.x; + + int src_blk = src_pos / block_size; + int src_slot = src_pos % block_size; + int src_phys = block_ids[src_blk]; + + int dst_blk = dst_pos / block_size; + int dst_slot = dst_pos % block_size; + int dst_phys = block_ids[dst_blk]; + + long long src_off = ((long long)src_phys * num_kv_heads + h) * block_size * head_dim + + src_slot * head_dim + d; + long long dst_off = ((long long)dst_phys * num_kv_heads + h) * block_size * head_dim + + dst_slot * head_dim + d; + + pool[dst_off] = pool[src_off]; +} + +void launch_copy_kv_position( + void* k_pool, void* v_pool, + const int* block_ids, + int src_pos, int dst_pos, + int num_kv_heads, int head_dim, int block_size, + void* stream +) { + int threads = head_dim < 32 ? 32 : head_dim; + if (threads > 1024) threads = 1024; + dim3 grid(num_kv_heads); + copy_kv_position_kernel<<>>( + (__nv_bfloat16*)k_pool, block_ids, + src_pos, dst_pos, head_dim, block_size + ); + CUDA_CHECK_LAST_ERROR(); + copy_kv_position_kernel<<>>( + (__nv_bfloat16*)v_pool, block_ids, + src_pos, dst_pos, head_dim, block_size + ); + CUDA_CHECK_LAST_ERROR(); +} + }