SGLang-style "write-all, copy-move on acceptance" approach: after tree verification, physically copy an accepted sibling's K/V from its physical cache slot to the canonical sequential position. New CUDA kernel: copy_kv_position_kernel in reshape_and_cache.cu. For one token (src_pos → dst_pos), it copies head_dim × num_kv_heads BF16 elements in each of the K and V pools. Grid = num_kv_heads, block = head_dim. Cost for one token across 36 layers, assuming Qwen3-8B shapes (8 KV heads × head_dim 128 × 2 B × 2 pools × 36 layers): ~147 KB of D2D copy, i.e. well under 1 μs of bandwidth at 900 GB/s, so the 72 small kernel launches dominate the cost. Rust FFI: copy_kv_position(k_pool, v_pool, block_ids, src_pos, dst_pos, num_kv_heads, head_dim, block_size, stream). PagedKVCache method: copy_kv_position(slot, src_pos, dst_pos) — uploads the sequence's block_ids and calls the kernel once per layer. This is the primitive tree drafting needs: when a non-primary sibling at cache position P+2 is accepted as the "true" token for target position P+1, call copy_kv_position(slot, P+2, P+1), then truncate to P+2. Next: wire this into the bench-eagle3 tree drafting loop with top-2 siblings.
216 lines
7.5 KiB
Plaintext
216 lines
7.5 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include "../common.cuh"
|
|
|
|
// reshape_and_cache_bf16_kernel
//
// Writes `num_tokens` freshly computed K/V vectors of ONE sequence into the
// paged KV pools, starting at logical cache position `start_pos`.
//
// Layouts (all BF16, densely packed):
//   k_src, v_src   : [num_kv_heads, num_tokens, head_dim]   (head-major)
//   k_pool, v_pool : [num_blocks_total, num_kv_heads, block_size, head_dim]
//   block_ids      : int32 logical-block -> physical-block map for the sequence
//
// Mapping for token t in [0, num_tokens):
//   pos  = start_pos + t
//   phys = block_ids[pos / block_size]
//   pool[phys, h, pos % block_size, :] = src[h, t, :]
//
// One launch per layer replaces a host-side loop of per-token, per-head
// cudaMemcpy calls. For Qwen3-8B (8 KV heads, 36 layers), a 1024-token
// prefill issued ~290k device-side memcpys that way.
//
// Launch: grid = (num_tokens, num_kv_heads). Any block size works, because
// threads stride across head_dim, so every element of a row is copied even
// when blockDim.x < head_dim. No shared memory is used.
// Precondition (not checked): block_ids has an entry for every logical block
// touched by [start_pos, start_pos + num_tokens).
__global__ void reshape_and_cache_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_ids,
    int num_tokens, int num_heads,
    int head_dim, int start_pos, int block_size
) {
    const int token = blockIdx.x;
    const int head = blockIdx.y;
    if (!(token < num_tokens && head < num_heads)) {
        return;
    }

    // Resolve the logical cache position to (physical block, slot).
    const int pos = start_pos + token;
    const int phys_blk = block_ids[pos / block_size];
    const int slot = pos % block_size;

    // 64-bit row bases: pool offsets can exceed INT_MAX for large pools.
    const long long src_row = ((long long)head * num_tokens + token) * head_dim;
    const long long dst_row =
        (((long long)phys_blk * num_heads + head) * block_size + slot) * head_dim;

    const __nv_bfloat16* k_in = k_src + src_row;
    const __nv_bfloat16* v_in = v_src + src_row;
    __nv_bfloat16* k_out = k_pool + dst_row;
    __nv_bfloat16* v_out = v_pool + dst_row;

    for (int lane = threadIdx.x; lane < head_dim; lane += blockDim.x) {
        k_out[lane] = k_in[lane];
        v_out[lane] = v_in[lane];
    }
}
|
|
|
|
// Batched variant: writes one new K/V token per sequence into a paged pool,
// indexed by the same per-batch block table that drives the paged attention
// kernel. Used in the decode path, where every sequence advances by exactly
// one position per step.
//
// Source layouts (BF16, contiguous):
//   k_src, v_src   : [batch, num_kv_heads, head_dim]
//
// Pool layouts (BF16, contiguous):
//   k_pool, v_pool : [num_blocks_total, num_kv_heads, BLOCK_SIZE, head_dim]
//
//   block_tables   : int32 [batch, max_blocks_per_seq]
//   kv_lens        : int32 [batch]. The sequence length INCLUDING the token
//                    being written (seq_len before this step + 1), i.e. the
//                    same buffer paged attention reads. The new token's
//                    logical index is `kv_lens[b] - 1`.
//
// Sequences with kv_lens[b] <= 0 (empty/inactive batch slots) are skipped.
// So are positions whose logical block lies beyond max_blocks_per_seq; they
// would otherwise read another sequence's block-table row.
//
// Grid : (batch, num_kv_heads)
// Block: any thread count; threads stride across head_dim.
__global__ void reshape_and_cache_batched_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_tables,
    const int* __restrict__ kv_lens,
    int num_heads, int head_dim,
    int block_size, int max_blocks_per_seq
) {
    int b = blockIdx.x;
    int h = blockIdx.y;

    int new_pos = kv_lens[b] - 1;
    // Nothing to write for an empty sequence. Without this check new_pos = -1
    // truncates to logical block 0 / slot -1, and the store lands in the tail
    // of the preceding head/block of the pool, corrupting another row.
    if (new_pos < 0) return;

    int logical_blk = new_pos / block_size;
    // Guard the block-table row: indexing past it would pick up (and then
    // overwrite) a physical block belonging to the next sequence.
    if (logical_blk >= max_blocks_per_seq) return;
    int slot_in_blk = new_pos - logical_blk * block_size;
    int phys = block_tables[(long long)b * max_blocks_per_seq + logical_blk];

    long long src_off = ((long long)b * num_heads + h) * head_dim;
    long long dst_off = (((long long)phys * num_heads + h) * block_size + slot_in_blk) * head_dim;

    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        k_pool[dst_off + d] = k_src[src_off + d];
        v_pool[dst_off + d] = v_src[src_off + d];
    }
}
|
|
|
|
extern "C" {
|
|
|
|
// Host launcher for reshape_and_cache_bf16_kernel (single-sequence scatter of
// num_tokens K/V vectors into the paged pools at start_pos).
//
// Launches on `stream` (a cudaStream_t passed as void* for the C ABI) with
// grid (num_tokens, num_heads) and max(32, min(head_dim, 1024)) threads. The
// kernel strides over head_dim, so head_dim > 1024 is still fully copied.
// Returns without launching when there is nothing to do (no tokens or no
// heads). A zero grid dimension would otherwise be an invalid launch
// configuration; the batched launcher skips the same way.
void launch_reshape_and_cache_bf16(
    const void* k_src, const void* v_src,
    void* k_pool, void* v_pool,
    const void* block_ids,
    int num_tokens, int num_heads,
    int head_dim, int start_pos, int block_size,
    void* stream
) {
    if (num_tokens <= 0 || num_heads <= 0) return;

    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;

    dim3 grid(num_tokens, num_heads);
    reshape_and_cache_bf16_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_ids,
        num_tokens, num_heads,
        head_dim, start_pos, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for reshape_and_cache_batched_bf16_kernel: writes one new K/V
// token per sequence (decode step) into the paged pools.
//
// `stream` is a cudaStream_t passed as void* for the C ABI. The launch uses
// grid (batch, num_heads) and a block of head_dim threads, clamped to the
// range [32, 1024]. An empty batch or zero heads is a no-op.
void launch_reshape_and_cache_batched_bf16(
    const void* k_src, const void* v_src,
    void* k_pool, void* v_pool,
    const void* block_tables, const void* kv_lens,
    int batch, int num_heads,
    int head_dim, int block_size, int max_blocks_per_seq,
    void* stream
) {
    const bool has_work = batch > 0 && num_heads > 0;
    if (!has_work) {
        return;
    }

    const int block_threads =
        head_dim < 32 ? 32 : (head_dim > 1024 ? 1024 : head_dim);
    const dim3 launch_grid(batch, num_heads);
    cudaStream_t launch_stream = static_cast<cudaStream_t>(stream);

    reshape_and_cache_batched_bf16_kernel<<<launch_grid, block_threads, 0, launch_stream>>>(
        static_cast<const __nv_bfloat16*>(k_src),
        static_cast<const __nv_bfloat16*>(v_src),
        static_cast<__nv_bfloat16*>(k_pool),
        static_cast<__nv_bfloat16*>(v_pool),
        static_cast<const int*>(block_tables),
        static_cast<const int*>(kv_lens),
        num_heads, head_dim, block_size, max_blocks_per_seq
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Copy one token's K (or V) row for every KV head from logical position
// src_pos to logical position dst_pos within ONE pool.
//
// pool      : [num_blocks_total, num_kv_heads, block_size, head_dim] (BF16)
// block_ids : logical -> physical block map for this sequence; must cover
//             both src_pos / block_size and dst_pos / block_size.
//
// Grid : (num_kv_heads,). The head count is read back from gridDim.x, so the
//        grid MUST be launched with exactly num_kv_heads blocks.
// Block: any thread count. Threads stride across head_dim, so the whole row
//        is copied even when head_dim exceeds blockDim.x (e.g. the launcher
//        clamps the block to 1024 threads).
// Positions are expected to be non-negative and distinct; src == dst is a
// harmless self-copy.
__global__ void copy_kv_position_kernel(
    __nv_bfloat16* __restrict__ pool,
    const int* __restrict__ block_ids,
    int src_pos, int dst_pos,
    int head_dim, int block_size
) {
    const int h = blockIdx.x;
    const int num_kv_heads = gridDim.x;

    const int src_blk = src_pos / block_size;
    const int src_slot = src_pos - src_blk * block_size;
    const int src_phys = block_ids[src_blk];

    const int dst_blk = dst_pos / block_size;
    const int dst_slot = dst_pos - dst_blk * block_size;
    const int dst_phys = block_ids[dst_blk];

    // 64-bit row bases: pool offsets can exceed INT_MAX for large pools.
    const long long src_row =
        (((long long)src_phys * num_kv_heads + h) * block_size + src_slot) * head_dim;
    const long long dst_row =
        (((long long)dst_phys * num_kv_heads + h) * block_size + dst_slot) * head_dim;

    // Distinct positions map to disjoint rows, and each element is read and
    // written by the same thread, so no intra-block synchronization is needed.
    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        pool[dst_row + d] = pool[src_row + d];
    }
}
|
|
|
|
// Host launcher: copy one token's K and V rows (all KV heads) from logical
// position src_pos to dst_pos of one sequence, for a single layer's pools.
//
// Two launches of copy_kv_position_kernel, the K pool then the V pool, are
// issued on the same `stream` (a cudaStream_t passed as void* for the C ABI),
// so they are ordered with respect to surrounding work on that stream.
// block_ids must be a device pointer covering both positions' logical blocks.
//
// Returns without launching when:
//   - there is nothing to copy (no heads, empty rows, or src_pos == dst_pos);
//   - the arguments are degenerate (block_size <= 0, or a negative position).
//     Without this check the kernel would divide by zero, or map the position
//     to slot -1 of logical block 0 and clobber a neighbouring pool row.
void launch_copy_kv_position(
    void* k_pool, void* v_pool,
    const int* block_ids,
    int src_pos, int dst_pos,
    int num_kv_heads, int head_dim, int block_size,
    void* stream
) {
    if (num_kv_heads <= 0 || head_dim <= 0 || block_size <= 0) return;
    if (src_pos < 0 || dst_pos < 0) return;
    if (src_pos == dst_pos) return;

    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;

    // The kernel derives num_kv_heads from gridDim.x, so the grid must be
    // exactly num_kv_heads blocks.
    dim3 grid(num_kv_heads);
    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)k_pool, block_ids,
        src_pos, dst_pos, head_dim, block_size
    );
    CUDA_CHECK_LAST_ERROR();
    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)v_pool, block_ids,
        src_pos, dst_pos, head_dim, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
}
|