CUDA layer for the paged-KV + swap work: - csrc: new paged_attention.cu plus updates across attention/gemm/norm/activation/embedding/reduce kernels and common.cuh. - xserv-kernels: new dispatch module and kernel-binding updates. - xserv-cuda: cudaMallocHost/FreeHost bindings + PinnedBuffer (host swap pool backing) and offset-aware D2H/H2D copies used to move KV blocks between the GPU pool and pinned host memory. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
141 lines
4.7 KiB
Plaintext
141 lines
4.7 KiB
Plaintext
#include "../common.cuh"
|
|
|
|
// RMSNorm (fp32): out[r][i] = x[r][i] * rsqrt(mean_i(x[r][i]^2) + eps) * gamma[i]
//
// Layout: x and out are addressed as contiguous rows of `hidden_size` floats;
// gamma holds `hidden_size` floats shared by every row.
// Launch: one block per row (row = blockIdx.x). There is no row bounds check,
// so gridDim.x must equal the number of rows. The element loops stride by
// blockDim.x, so any block size covers the row.
// NOTE(review): block_reduce_sum (common.cuh) may require blockDim.x to be a
// multiple of 32 — confirm there.
// Shared memory: one static float (s_rms_inv), plus whatever block_reduce_sum
// uses internally (not visible here).
__global__ void rmsnorm_f32(
    const float* __restrict__ x,
    const float* __restrict__ gamma,
    float* __restrict__ out,
    int hidden_size, float eps
) {
    // 64-bit row offset: a 32-bit `row * hidden_size` overflows once the
    // tensor holds more than 2^31 - 1 elements.
    const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
    const float* x_row = x + row_offset;
    float* out_row = out + row_offset;

    // Pass 1: per-thread partial sum of squares over a strided slice of the row.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float v = x_row[i];
        sum_sq += v * v;
    }
    // Only thread 0 consumes the reduced value below.
    sum_sq = block_reduce_sum(sum_sq);

    // Thread 0 computes the inverse RMS; the barrier publishes it to the block.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
    }
    __syncthreads();

    // Pass 2: scale each element and apply the per-channel gain.
    const float rms_inv = s_rms_inv;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        out_row[i] = x_row[i] * rms_inv * gamma[i];
    }
}
|
|
|
|
// RMSNorm (bf16 in/out, fp32 math):
//   out[r][i] = bf16(x[r][i] * rsqrt(mean_i(x[r][i]^2) + eps) * gamma[i])
//
// Layout: x and out are addressed as contiguous rows of `hidden_size` bf16
// values; gamma holds `hidden_size` bf16 values shared by every row.
// All arithmetic (sum of squares, scaling) is done in float after conversion.
// Launch: one block per row (row = blockIdx.x). There is no row bounds check,
// so gridDim.x must equal the number of rows. The element loops stride by
// blockDim.x.
// NOTE(review): block_reduce_sum (common.cuh) may require blockDim.x to be a
// multiple of 32 — confirm there.
// Shared memory: one static float (s_rms_inv), plus whatever block_reduce_sum
// uses internally (not visible here).
__global__ void rmsnorm_bf16(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ gamma,
    __nv_bfloat16* __restrict__ out,
    int hidden_size, float eps
) {
    // 64-bit row offset: a 32-bit `row * hidden_size` overflows once the
    // tensor holds more than 2^31 - 1 elements.
    const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
    const __nv_bfloat16* x_row = x + row_offset;
    __nv_bfloat16* out_row = out + row_offset;

    // Pass 1: per-thread partial sum of squares, accumulated in float.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float v = __bfloat162float(x_row[i]);
        sum_sq += v * v;
    }
    // Only thread 0 consumes the reduced value below.
    sum_sq = block_reduce_sum(sum_sq);

    // Thread 0 computes the inverse RMS; the barrier publishes it to the block.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
    }
    __syncthreads();

    // Pass 2: re-read x, scale in float, apply gamma, round once to bf16.
    const float rms_inv = s_rms_inv;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float v = __bfloat162float(x_row[i]);
        float g = __bfloat162float(gamma[i]);
        out_row[i] = __float2bfloat16(v * rms_inv * g);
    }
}
|
|
|
|
// Fused Add + RMSNorm (bf16 in/out, fp32 math):
//   sum_out[r][i]    = bf16(x[r][i] + residual[r][i])
//   normed_out[r][i] = bf16(sum_out[r][i] * rsqrt(mean_i(s^2) + eps) * gamma[i])
// where s is the float (unrounded) sum used for the sum of squares, while the
// normalization pass reads back the bf16-rounded value stored in sum_out.
//
// Layout: x, residual, sum_out and normed_out are addressed as contiguous rows
// of `hidden_size` bf16 values; gamma holds `hidden_size` bf16 values.
// Launch: one block per row (row = blockIdx.x). There is no row bounds check,
// so gridDim.x must equal the number of rows. The element loops stride by
// blockDim.x; both passes use the same indexing, so each thread only re-reads
// sum_out elements it wrote itself and no barrier is needed between the store
// and the re-read.
// NOTE(review): all pointers are __restrict__, so sum_out must not alias x or
// residual (an in-place residual update would break that contract). Also,
// block_reduce_sum (common.cuh) may require blockDim.x to be a multiple of 32
// — confirm there.
// Shared memory: one static float (s_rms_inv), plus whatever block_reduce_sum
// uses internally (not visible here).
__global__ void add_rmsnorm_bf16(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ residual,
    const __nv_bfloat16* __restrict__ gamma,
    __nv_bfloat16* __restrict__ normed_out,
    __nv_bfloat16* __restrict__ sum_out,
    int hidden_size, float eps
) {
    // 64-bit row offset: a 32-bit `row * hidden_size` overflows once the
    // tensor holds more than 2^31 - 1 elements.
    const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
    const __nv_bfloat16* x_row = x + row_offset;
    const __nv_bfloat16* res_row = residual + row_offset;
    __nv_bfloat16* sum_row = sum_out + row_offset;
    __nv_bfloat16* norm_row = normed_out + row_offset;

    // Pass 1: sum = x + residual (stored as bf16); accumulate sum_sq in float.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float s = __bfloat162float(x_row[i]) + __bfloat162float(res_row[i]);
        sum_row[i] = __float2bfloat16(s);
        sum_sq += s * s;
    }
    // Only thread 0 consumes the reduced value below.
    sum_sq = block_reduce_sum(sum_sq);

    // Thread 0 computes the inverse RMS; the barrier publishes it to the block.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
    }
    __syncthreads();

    // Pass 2: normed_out = stored (bf16-rounded) sum * rms_inv * gamma.
    const float rms_inv = s_rms_inv;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float s = __bfloat162float(sum_row[i]);
        float g = __bfloat162float(gamma[i]);
        norm_row[i] = __float2bfloat16(s * rms_inv * g);
    }
}
|
|
|
|
extern "C" {
|
|
|
|
// Host launcher for rmsnorm_f32: normalizes `rows` rows of `hidden_size` fp32
// values, enqueued on `stream` (a cudaStream_t passed as void*; a null pointer
// selects the default stream). Launches one block per row.
// Empty inputs (rows <= 0 or hidden_size <= 0) return without launching: a
// zero-sized grid is an invalid launch configuration, and there is nothing to
// write when hidden_size <= 0.
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
                        int rows, int hidden_size, float eps, void* stream) {
    if (rows <= 0 || hidden_size <= 0) return;

    // One thread per element up to 1024, rounded up to whole warps so the
    // block-wide reduction never sees a partial warp. Surplus threads skip the
    // element loops and contribute 0 to the sum of squares.
    // NOTE(review): confirm whether block_reduce_sum (common.cuh) already
    // tolerates partial warps; rounding is harmless either way.
    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = ((block + 31) / 32) * 32;

    rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for rmsnorm_bf16: normalizes `rows` rows of `hidden_size` bf16
// values, enqueued on `stream` (a cudaStream_t passed as void*; a null pointer
// selects the default stream). Launches one block per row.
// Empty inputs (rows <= 0 or hidden_size <= 0) return without launching: a
// zero-sized grid is an invalid launch configuration, and there is nothing to
// write when hidden_size <= 0.
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
                         int rows, int hidden_size, float eps, void* stream) {
    if (rows <= 0 || hidden_size <= 0) return;

    // One thread per element up to 1024, rounded up to whole warps so the
    // block-wide reduction never sees a partial warp. Surplus threads skip the
    // element loops and contribute 0 to the sum of squares.
    // NOTE(review): confirm whether block_reduce_sum (common.cuh) already
    // tolerates partial warps; rounding is harmless either way.
    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = ((block + 31) / 32) * 32;

    rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
        (__nv_bfloat16*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for add_rmsnorm_bf16: writes sum_out = x + residual and
// normed_out = rmsnorm(sum_out, gamma, eps) for `rows` rows of `hidden_size`
// bf16 values, enqueued on `stream` (a cudaStream_t passed as void*; a null
// pointer selects the default stream). Launches one block per row.
// Empty inputs (rows <= 0 or hidden_size <= 0) return without launching: a
// zero-sized grid is an invalid launch configuration, and there is nothing to
// write when hidden_size <= 0.
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
                             void* normed_out, void* sum_out,
                             int rows, int hidden_size, float eps, void* stream) {
    if (rows <= 0 || hidden_size <= 0) return;

    // One thread per element up to 1024, rounded up to whole warps so the
    // block-wide reduction never sees a partial warp. Surplus threads skip the
    // element loops and contribute 0 to the sum of squares.
    // NOTE(review): confirm whether block_reduce_sum (common.cuh) already
    // tolerates partial warps; rounding is harmless either way.
    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = ((block + 31) / 32) * 32;

    add_rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)residual,
        (const __nv_bfloat16*)gamma,
        (__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
        hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
}
|