#include <cuda_bf16.h>

#include "../common.cuh"

// Scatter [num_tokens] new K/V into a paged KV pool for ONE sequence.
//
// Source layouts (BF16, contiguous):
//   k_src, v_src   : [num_kv_heads, num_tokens, head_dim]   (head-major)
//
// Pool layouts (BF16, contiguous):
//   k_pool, v_pool : [num_blocks_total, num_kv_heads, BLOCK_SIZE, head_dim]
//
// For token t (0 <= t < num_tokens):
//   p           = start_pos + t
//   logical_blk = p / BLOCK_SIZE
//   slot_in_blk = p % BLOCK_SIZE
//   phys        = block_ids[logical_blk]
//   pool[phys, h, slot_in_blk, :] := src[h, t, :]
//
// Replaces a Rust-side per-token, per-head cudaMemcpy loop. With Qwen3-8B
// (8 KV heads, 36 layers) and a 1024-token prefill, that loop fired
// ~290k device-side memcpys; one kernel launch per layer is dramatically
// less overhead.
//
// Grid : (num_tokens, num_kv_heads)
// Block: head_dim threads (≤128 in practice; head_dim is padded to a
//        multiple of 32 by the model and all our shipping configs are
//        128, i.e. four warps per block). Any blockDim.x is correct
//        because threads stride over head_dim.
//
// Preconditions (not checked here): block_size > 0, and block_ids holds an
// entry for every logical block touched by [start_pos, start_pos + num_tokens).
__global__ void reshape_and_cache_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_ids,
    int num_tokens,
    int num_heads,
    int head_dim,
    int start_pos,
    int block_size
) {
    const int t = blockIdx.x;
    const int h = blockIdx.y;
    if (t >= num_tokens || h >= num_heads) return;

    // Absolute position of this token in the sequence, split into
    // (logical block, slot inside that block).
    const int p = start_pos + t;
    const int logical_blk = p / block_size;
    const int slot_in_blk = p - logical_blk * block_size;
    const int phys = block_ids[logical_blk];

    // 64-bit offsets: the pool can exceed 2^31 elements.
    const long long src_off = ((long long)h * num_tokens + t) * head_dim;
    const long long dst_off =
        (((long long)phys * num_heads + h) * block_size + slot_in_blk) * head_dim;

    // Per-thread strided copy. head_dim is typically 128 and blockDim.x is
    // 128, so each thread copies exactly one element — but the loop keeps
    // the kernel correct for non-128 head_dim configs (Phi-style 64, etc.).
    const int stride = blockDim.x;
    for (int d = threadIdx.x; d < head_dim; d += stride) {
        k_pool[dst_off + d] = k_src[src_off + d];
        v_pool[dst_off + d] = v_src[src_off + d];
    }
}

// Batched variant: writes one new K/V token per sequence into a paged
// pool, indexed by a per-batch block table that also drives the paged
// attention kernel. Used in the decode path where every seq advances
// by exactly one position per step.
//
// Source layouts (BF16, contiguous):
//   k_src, v_src   : [batch, num_kv_heads, head_dim]
//
// Pool layouts (BF16, contiguous):
//   k_pool, v_pool : [num_blocks_total, num_kv_heads, BLOCK_SIZE, head_dim]
//
//   block_tables   : int32 [batch, max_blocks_per_seq]
//   kv_lens        : int32 [batch]  (current seq_len BEFORE this step + 1
//                                    — i.e. the same buffer paged attention
//                                    reads. The new token's logical index
//                                    is `kv_lens[b] - 1`.)
//
// Grid : (batch, num_kv_heads)
// Block: head_dim threads (threads stride over head_dim, so any blockDim.x
//        is correct).
//
// Sequences whose kv_lens[b] is <= 0, or whose new token would fall outside
// their block-table row, are skipped instead of indexing out of bounds.
__global__ void reshape_and_cache_batched_bf16_kernel(
    const __nv_bfloat16* __restrict__ k_src,
    const __nv_bfloat16* __restrict__ v_src,
    __nv_bfloat16* __restrict__ k_pool,
    __nv_bfloat16* __restrict__ v_pool,
    const int* __restrict__ block_tables,
    const int* __restrict__ kv_lens,
    int num_heads,
    int head_dim,
    int block_size,
    int max_blocks_per_seq
) {
    const int b = blockIdx.x;
    const int h = blockIdx.y;

    // kv_lens[b] <= 0 leaves no valid position to write (presumably an
    // unused batch slot — confirm with caller); the unguarded arithmetic
    // would produce a negative block or slot index.
    const int new_pos = kv_lens[b] - 1;
    if (new_pos < 0) return;

    const int logical_blk = new_pos / block_size;
    // Reading past this sequence's row would pick up another row's block id.
    if (logical_blk >= max_blocks_per_seq) return;
    const int slot_in_blk = new_pos - logical_blk * block_size;
    const int phys = block_tables[(long long)b * max_blocks_per_seq + logical_blk];

    const long long src_off = ((long long)b * num_heads + h) * head_dim;
    const long long dst_off =
        (((long long)phys * num_heads + h) * block_size + slot_in_blk) * head_dim;

    const int stride = blockDim.x;
    for (int d = threadIdx.x; d < head_dim; d += stride) {
        k_pool[dst_off + d] = k_src[src_off + d];
        v_pool[dst_off + d] = v_src[src_off + d];
    }
}

extern "C" {

// Host launcher for reshape_and_cache_bf16_kernel.
//
// All pointers are device pointers passed as void* across the C ABI;
// `stream` is a cudaStream_t passed as void* (null = default stream).
// Launches asynchronously on `stream` with grid (num_tokens, num_heads) and
// max(32, min(head_dim, 1024)) threads per block. Returns without launching
// when there is nothing to copy, since a zero grid dimension is an invalid
// launch configuration.
void launch_reshape_and_cache_bf16(
    const void* k_src,
    const void* v_src,
    void* k_pool,
    void* v_pool,
    const void* block_ids,
    int num_tokens,
    int num_heads,
    int head_dim,
    int start_pos,
    int block_size,
    void* stream
) {
    if (num_tokens <= 0 || num_heads <= 0 || head_dim <= 0) return;

    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;  // hardware limit; kernel strides over the rest
    dim3 grid(num_tokens, num_heads);

    reshape_and_cache_bf16_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_ids,
        num_tokens, num_heads, head_dim, start_pos, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for reshape_and_cache_batched_bf16_kernel.
//
// All pointers are device pointers passed as void* across the C ABI;
// `stream` is a cudaStream_t passed as void* (null = default stream).
// Launches asynchronously on `stream` with grid (batch, num_heads) and
// max(32, min(head_dim, 1024)) threads per block. Returns without launching
// when there is nothing to copy.
void launch_reshape_and_cache_batched_bf16(
    const void* k_src,
    const void* v_src,
    void* k_pool,
    void* v_pool,
    const void* block_tables,
    const void* kv_lens,
    int batch,
    int num_heads,
    int head_dim,
    int block_size,
    int max_blocks_per_seq,
    void* stream
) {
    if (batch <= 0 || num_heads <= 0 || head_dim <= 0) return;

    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;  // hardware limit; kernel strides over the rest
    dim3 grid(batch, num_heads);

    reshape_and_cache_batched_bf16_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)k_src,
        (const __nv_bfloat16*)v_src,
        (__nv_bfloat16*)k_pool,
        (__nv_bfloat16*)v_pool,
        (const int*)block_tables,
        (const int*)kv_lens,
        num_heads, head_dim, block_size, max_blocks_per_seq
    );
    CUDA_CHECK_LAST_ERROR();
}

// Copy one token's K/V from src_pos to dst_pos within one pool.
// Grid: (num_kv_heads,). Block: head_dim threads.
// pool: [num_blocks_total, num_kv_heads, block_size, head_dim]
// block_ids: [max_blocks] for this sequence (logical → physical block map).
__global__ void copy_kv_position_kernel(
    __nv_bfloat16* __restrict__ pool,
    const int* __restrict__ block_ids,
    int src_pos,
    int dst_pos,
    int head_dim,
    int block_size
) {
    // One thread block per KV head; the head count is taken from the grid
    // size rather than passed as an argument.
    const int h = blockIdx.x;
    const int num_kv_heads = gridDim.x;

    // Resolve both logical positions to (physical block, slot in block).
    const int src_blk = src_pos / block_size;
    const int src_slot = src_pos % block_size;
    const int src_phys = block_ids[src_blk];

    const int dst_blk = dst_pos / block_size;
    const int dst_slot = dst_pos % block_size;
    const int dst_phys = block_ids[dst_blk];

    // 64-bit offsets of the first element of each head_dim-long row.
    const long long src_off =
        (((long long)src_phys * num_kv_heads + h) * block_size + src_slot) * head_dim;
    const long long dst_off =
        (((long long)dst_phys * num_kv_heads + h) * block_size + dst_slot) * head_dim;

    // Strided copy so the whole row is moved even when head_dim exceeds
    // blockDim.x (the launcher caps the block at 1024 threads).
    const int stride = blockDim.x;
    for (int d = threadIdx.x; d < head_dim; d += stride) {
        pool[dst_off + d] = pool[src_off + d];
    }
}

// Copies the K and V rows at logical position src_pos to dst_pos for every
// KV head of one sequence, using one kernel launch per pool.
//
// k_pool / v_pool and block_ids are device pointers; `stream` is a
// cudaStream_t passed as void* (null = default stream). Both launches are
// asynchronous on `stream`, with grid (num_kv_heads,) and
// max(32, min(head_dim, 1024)) threads per block.
//
// Returns without launching when there are no heads or elements to copy
// (a zero grid dimension is an invalid launch configuration) or when source
// and destination are the same position.
void launch_copy_kv_position(
    void* k_pool,
    void* v_pool,
    const int* block_ids,
    int src_pos,
    int dst_pos,
    int num_kv_heads,
    int head_dim,
    int block_size,
    void* stream
) {
    if (num_kv_heads <= 0 || head_dim <= 0) return;
    if (src_pos == dst_pos) return;  // self-copy: nothing to do

    int threads = head_dim < 32 ? 32 : head_dim;
    if (threads > 1024) threads = 1024;  // hardware limit; kernel strides over the rest
    dim3 grid(num_kv_heads);

    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)k_pool, block_ids, src_pos, dst_pos, head_dim, block_size
    );
    CUDA_CHECK_LAST_ERROR();

    copy_kv_position_kernel<<<grid, threads, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)v_pool, block_ids, src_pos, dst_pos, head_dim, block_size
    );
    CUDA_CHECK_LAST_ERROR();
}

}  // extern "C"