speculative: copy_kv_position primitive for tree drafting KV remap
SGLang-style "write-all, copy-move on acceptance" approach: after tree verification, physically copy an accepted sibling's K/V from its physical cache slot to the canonical sequential position. New CUDA kernel: copy_kv_position_kernel in reshape_and_cache.cu. For one token (src_pos → dst_pos), copies head_dim × num_kv_heads BF16 elements in both K and V pools. Grid = num_kv_heads, block = head_dim. Cost for one token across 36 layers: ~5.3 MB D2D copy @ 900 GB/s = <6μs. Rust FFI: copy_kv_position(k_pool, v_pool, block_ids, src_pos, dst_pos, num_kv_heads, head_dim, block_size, stream). PagedKVCache method: copy_kv_position(slot, src_pos, dst_pos) — uploads block_ids for the sequence, calls the kernel per layer. This is the primitive needed by tree drafting: when a non-primary sibling at cache position P+2 is accepted as the "true" token for target position P+1, call copy_kv_position(slot, P+2, P+1) then truncate to P+2. Next: wire into bench-eagle3 tree drafting loop with top-2 siblings.
This commit is contained in:
@@ -158,4 +158,58 @@ void launch_reshape_and_cache_batched_bf16(
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
// Copy one token's K/V from src_pos to dst_pos within one pool.
|
||||
// Grid: (num_kv_heads,). Block: head_dim threads.
|
||||
// pool: [num_blocks_total, num_kv_heads, block_size, head_dim]
|
||||
// block_ids: [max_blocks] for this sequence (logical → physical block map).
|
||||
__global__ void copy_kv_position_kernel(
|
||||
__nv_bfloat16* __restrict__ pool,
|
||||
const int* __restrict__ block_ids,
|
||||
int src_pos, int dst_pos,
|
||||
int head_dim, int block_size
|
||||
) {
|
||||
int h = blockIdx.x;
|
||||
int d = threadIdx.x;
|
||||
if (d >= head_dim) return;
|
||||
|
||||
int num_kv_heads = gridDim.x;
|
||||
|
||||
int src_blk = src_pos / block_size;
|
||||
int src_slot = src_pos % block_size;
|
||||
int src_phys = block_ids[src_blk];
|
||||
|
||||
int dst_blk = dst_pos / block_size;
|
||||
int dst_slot = dst_pos % block_size;
|
||||
int dst_phys = block_ids[dst_blk];
|
||||
|
||||
long long src_off = ((long long)src_phys * num_kv_heads + h) * block_size * head_dim
|
||||
+ src_slot * head_dim + d;
|
||||
long long dst_off = ((long long)dst_phys * num_kv_heads + h) * block_size * head_dim
|
||||
+ dst_slot * head_dim + d;
|
||||
|
||||
pool[dst_off] = pool[src_off];
|
||||
}
|
||||
|
||||
void launch_copy_kv_position(
|
||||
void* k_pool, void* v_pool,
|
||||
const int* block_ids,
|
||||
int src_pos, int dst_pos,
|
||||
int num_kv_heads, int head_dim, int block_size,
|
||||
void* stream
|
||||
) {
|
||||
int threads = head_dim < 32 ? 32 : head_dim;
|
||||
if (threads > 1024) threads = 1024;
|
||||
dim3 grid(num_kv_heads);
|
||||
copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)k_pool, block_ids,
|
||||
src_pos, dst_pos, head_dim, block_size
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)v_pool, block_ids,
|
||||
src_pos, dst_pos, head_dim, block_size
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user