speculative: copy_kv_position primitive for tree drafting KV remap
SGLang-style "write-all, copy-move on acceptance" approach: after tree verification, physically copy an accepted sibling's K/V from its physical cache slot to the canonical sequential position. New CUDA kernel: copy_kv_position_kernel in reshape_and_cache.cu. For one token (src_pos → dst_pos), it copies head_dim × num_kv_heads BF16 elements in each of the K and V pools. Grid = num_kv_heads, block = head_dim. Cost for one token across 36 layers: ~5.3 MB D2D copy @ 900 GB/s = <6μs. Rust FFI wrapper: copy_kv_position(k_pool, v_pool, block_ids, src_pos, dst_pos, num_kv_heads, head_dim, block_size, stream), which forwards to the extern launcher launch_copy_kv_position. PagedKVCache method: copy_kv_position(slot, src_pos, dst_pos) — uploads block_ids for the sequence, then calls the kernel once per layer. This is the primitive needed by tree drafting: when a non-primary sibling at cache position P+2 is accepted as the "true" token for target position P+1, call copy_kv_position(slot, P+2, P+1) then truncate to P+2. Next: wire into bench-eagle3 tree drafting loop with top-2 siblings.
This commit is contained in:
@@ -145,6 +145,17 @@ unsafe extern "C" {
|
||||
max_blocks_per_seq: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
/// Foreign launcher for the single-token K/V copy kernel (one layer).
///
/// Called by the Rust wrapper `copy_kv_position`, which passes its arguments
/// through in this exact order after narrowing each `usize` to `i32`.
///
/// - `k_pool` / `v_pool`: raw pointers to the K and V pools.
/// - `block_ids`: pointer to the sequence's block ids (the wrapper names this
///   argument `block_ids_gpu`, so presumably a device pointer — confirm
///   against the kernel source).
/// - `src_pos` / `dst_pos`: token position to copy from / to.
/// - `num_kv_heads`, `head_dim`, `block_size`: cache geometry.
/// - `stream`: raw stream handle the launch is issued on.
///
/// NOTE(review): no status is returned, so launch failures are not observable
/// from this declaration.
fn launch_copy_kv_position(
|
||||
k_pool: *mut c_void,
|
||||
v_pool: *mut c_void,
|
||||
block_ids: *const i32,
|
||||
src_pos: i32,
|
||||
dst_pos: i32,
|
||||
num_kv_heads: i32,
|
||||
head_dim: i32,
|
||||
block_size: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
}
|
||||
|
||||
/// Scatter `[num_kv_heads, num_tokens, head_dim]` BF16 K/V into a paged
|
||||
@@ -231,6 +242,36 @@ pub unsafe fn reshape_and_cache_batched_bf16(
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy one token's K/V from `src_pos` to `dst_pos` within the same sequence's
|
||||
/// paged cache (one layer). Used by tree speculative decoding to remap
|
||||
/// accepted sibling K/V to canonical sequential positions after acceptance.
|
||||
///
|
||||
/// # Safety
|
||||
/// Pool and block_ids pointers must be valid GPU pointers for the given layer.
|
||||
pub unsafe fn copy_kv_position(
|
||||
k_pool_ptr: *mut c_void,
|
||||
v_pool_ptr: *mut c_void,
|
||||
block_ids_gpu: *const i32,
|
||||
src_pos: usize,
|
||||
dst_pos: usize,
|
||||
num_kv_heads: usize,
|
||||
head_dim: usize,
|
||||
block_size: usize,
|
||||
stream: *mut c_void,
|
||||
) {
|
||||
launch_copy_kv_position(
|
||||
k_pool_ptr,
|
||||
v_pool_ptr,
|
||||
block_ids_gpu,
|
||||
src_pos as i32,
|
||||
dst_pos as i32,
|
||||
num_kv_heads as i32,
|
||||
head_dim as i32,
|
||||
block_size as i32,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
let ndim = scores.ndim();
|
||||
let rows = scores.shape()[ndim - 2];
|
||||
|
||||
@@ -15,9 +15,9 @@ pub mod transpose;
|
||||
pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul};
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use attention::{
|
||||
attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention,
|
||||
paged_decode_attention_sinks, paged_decode_attention_tree, reshape_and_cache_batched_bf16,
|
||||
reshape_and_cache_bf16,
|
||||
attention, copy_kv_position, decode_attention, flash_attention, flash_attention_sinks,
|
||||
paged_decode_attention, paged_decode_attention_sinks, paged_decode_attention_tree,
|
||||
reshape_and_cache_batched_bf16, reshape_and_cache_bf16,
|
||||
};
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{GemmBackend, batched_matmul, matmul, matmul_batched_gemv};
|
||||
|
||||
@@ -515,6 +515,51 @@ impl PagedKVCache {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Copy K/V data from `src_pos` to `dst_pos` within the same slot, across
|
||||
/// all layers. Used by tree speculative decoding to remap an accepted
|
||||
/// sibling's K/V to the canonical sequential position after acceptance.
|
||||
///
|
||||
/// Requires: both positions within the currently-allocated block range.
|
||||
pub fn copy_kv_position(&self, slot: usize, src_pos: usize, dst_pos: usize) {
|
||||
let state = self.seq_states[slot]
|
||||
.as_ref()
|
||||
.expect("copy_kv_position: slot not registered");
|
||||
assert!(
|
||||
src_pos < state.seq_len && dst_pos < state.seq_len,
|
||||
"copy_kv_position: positions must be within seq_len"
|
||||
);
|
||||
// Upload this sequence's block_ids to a small GPU buffer.
|
||||
let block_ids_host: Vec<i32> = state.block_ids.iter().map(|&b| b as i32).collect();
|
||||
let bytes: &[u8] = unsafe {
|
||||
std::slice::from_raw_parts(
|
||||
block_ids_host.as_ptr() as *const u8,
|
||||
block_ids_host.len() * 4,
|
||||
)
|
||||
};
|
||||
let mut ids_buf =
|
||||
xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc block_ids for copy");
|
||||
ids_buf.copy_from_host(bytes).unwrap();
|
||||
let ids_ptr = ids_buf.as_ptr() as *const i32;
|
||||
|
||||
let stream = xserv_cuda::current_stream_raw();
|
||||
let num_layers = self.k_pools.len();
|
||||
for layer in 0..num_layers {
|
||||
unsafe {
|
||||
xserv_kernels::copy_kv_position(
|
||||
self.k_pools[layer].as_ptr() as *mut std::ffi::c_void,
|
||||
self.v_pools[layer].as_ptr() as *mut std::ffi::c_void,
|
||||
ids_ptr,
|
||||
src_pos,
|
||||
dst_pos,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
BLOCK_SIZE,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Refresh the host-side block table + context lens from `seq_states`,
|
||||
/// then upload to GPU. Call once per decode step before the paged kernel.
|
||||
pub fn sync_to_gpu(&mut self) {
|
||||
|
||||
Reference in New Issue
Block a user