CUDA layer for the paged-KV + swap work: - csrc: new paged_attention.cu plus updates across attention/gemm/norm/activation/embedding/reduce kernels and common.cuh. - xserv-kernels: new dispatch module and kernel-binding updates. - xserv-cuda: cudaMallocHost/FreeHost bindings + PinnedBuffer (host swap pool backing) and offset-aware D2H/H2D copies used to move KV blocks between the GPU pool and pinned host memory. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
65 lines
1.8 KiB
Plaintext
65 lines
1.8 KiB
Plaintext
#pragma once
|
|
#include <cuda_bf16.h>
|
|
|
|
// --- Warp-level reductions (no shared memory needed) ---
|
|
|
|
// Adds up `val` across the lanes of the calling warp using register shuffles
// only (no shared memory).
//
// Each step pulls the running value from the lane `delta` positions higher
// and adds it in, with `delta` going 16, 8, 4, 2, 1. The shuffle mask names
// all 32 lanes, so the whole warp is expected to make this call together.
//
// Returns the calling lane's running value after the last step. Lane 0 ends
// up with the sum of all 32 inputs; higher lanes only hold partial results.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    const unsigned all_lanes = 0xffffffff;
    float acc = val;
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        acc = acc + __shfl_down_sync(all_lanes, acc, delta);
    }
    return acc;
}
|
|
|
|
// Takes the maximum of `val` across the lanes of the calling warp using
// register shuffles only (no shared memory).
//
// Each step pulls the running value from the lane `delta` positions higher
// and keeps the larger of the two (fmaxf), with `delta` going 16, 8, 4, 2, 1.
// The shuffle mask names all 32 lanes, so the whole warp is expected to make
// this call together.
//
// Returns the calling lane's running value after the last step. Lane 0 ends
// up with the maximum of all 32 inputs; higher lanes only hold partial
// results.
__device__ __forceinline__ float warp_reduce_max(float val) {
    const unsigned all_lanes = 0xffffffff;
    float best = val;
#pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        const float incoming = __shfl_down_sync(all_lanes, best, delta);
        best = fmaxf(best, incoming);
    }
    return best;
}
|
|
|
|
// --- Block-level reductions ---
|
|
|
|
// Sums `val` over the threads of the calling block.
//
// Stage 1: every warp reduces its own values with warp_reduce_sum and lane 0
//          of each warp stores the warp total in a 32-entry shared scratch
//          array.
// Stage 2: after a block barrier, the first `num_warps` threads reload those
//          totals (all other threads take 0.0f) and warp 0 reduces them.
//
// Contract:
//  - Contains __syncthreads(), so every thread of the block must call it.
//  - Only thread 0 is guaranteed to hold the block-wide sum on return. Other
//    lanes of warp 0 hold partial sums and threads outside warp 0 return
//    0.0f; broadcast through shared memory if all threads need the result.
//  - Threads are identified by threadIdx.x / blockDim.x only.
//    NOTE(review): this presumably expects a 1D block -- confirm at callers.
//  - The scratch array has 32 slots, i.e. room for at most 32 warps.
//  - NOTE(review): the full-mask shuffles presumably rely on blockDim.x being
//    a multiple of 32 (no partially filled last warp) -- confirm at callers.
__device__ __forceinline__ float block_reduce_sum(float val) {
    __shared__ float shared[32];
    int lane = threadIdx.x & 31;
    int warp_id = threadIdx.x >> 5;
    int num_warps = (blockDim.x + 31) >> 5;

    val = warp_reduce_sum(val);

    // The scratch array is static and therefore reused by every call. Without
    // this barrier a fast warp could overwrite its slot for a *new* reduction
    // while warp 0 is still reading the slots of the previous one. Waiting
    // here makes back-to-back calls (e.g. inside a loop) race-free.
    __syncthreads();
    if (lane == 0) shared[warp_id] = val;
    // Make every warp's total visible before warp 0 reads them back.
    __syncthreads();

    // Only indices below num_warps were written; pad the rest with the
    // additive identity so the final warp reduction is unaffected.
    val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : 0.0f;
    if (warp_id == 0) val = warp_reduce_sum(val);
    return val;
}
|
|
|
|
// Takes the maximum of `val` over the threads of the calling block.
//
// Stage 1: every warp reduces its own values with warp_reduce_max and lane 0
//          of each warp stores the warp maximum in a 32-entry shared scratch
//          array.
// Stage 2: after a block barrier, the first `num_warps` threads reload those
//          maxima (all other threads take -INFINITY) and warp 0 reduces them.
//
// Contract:
//  - Contains __syncthreads(), so every thread of the block must call it.
//  - Only thread 0 is guaranteed to hold the block-wide maximum on return.
//    Other lanes of warp 0 hold partial maxima and threads outside warp 0
//    return -INFINITY; broadcast through shared memory if all threads need
//    the result.
//  - Threads are identified by threadIdx.x / blockDim.x only.
//    NOTE(review): this presumably expects a 1D block -- confirm at callers.
//  - The scratch array has 32 slots, i.e. room for at most 32 warps.
//  - NOTE(review): the full-mask shuffles presumably rely on blockDim.x being
//    a multiple of 32 (no partially filled last warp) -- confirm at callers.
__device__ __forceinline__ float block_reduce_max(float val) {
    __shared__ float shared[32];
    int lane = threadIdx.x & 31;
    int warp_id = threadIdx.x >> 5;
    int num_warps = (blockDim.x + 31) >> 5;

    val = warp_reduce_max(val);

    // The scratch array is static and therefore reused by every call. Without
    // this barrier a fast warp could overwrite its slot for a *new* reduction
    // while warp 0 is still reading the slots of the previous one. Waiting
    // here makes back-to-back calls (e.g. inside a loop) race-free.
    __syncthreads();
    if (lane == 0) shared[warp_id] = val;
    // Make every warp's maximum visible before warp 0 reads them back.
    __syncthreads();

    // Only indices below num_warps were written; pad the rest with the
    // identity for max so the final warp reduction is unaffected.
    val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : -INFINITY;
    if (warp_id == 0) val = warp_reduce_max(val);
    return val;
}
|
|
|
|
// --- Launch error checking (debug builds only) ---
|
|
// CUDA_CHECK_LAST_ERROR(): place right after a kernel launch.
//
// Without NDEBUG: calls cudaGetLastError() and, if the result is not
// cudaSuccess, prints the source location and the error string to stderr.
// It only reports -- execution continues afterwards.
// With NDEBUG: expands to a no-op, so cudaGetLastError() is not called.
#ifndef NDEBUG
#include <cstdio>
#define CUDA_CHECK_LAST_ERROR() \
    do { \
        cudaError_t launch_status_ = cudaGetLastError(); \
        if (launch_status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA kernel launch error at %s:%d: %s\n", \
                    __FILE__, __LINE__, \
                    cudaGetErrorString(launch_status_)); \
        } \
    } while (0)
#else
#define CUDA_CHECK_LAST_ERROR() ((void)0)
#endif
|