Files
xserv/csrc/normalization/rmsnorm.cu
Gahow Wang 4c3f914459 kernels/cuda: paged-attention kernel, dispatch, pinned host memory
CUDA layer for the paged-KV + swap work:
- csrc: new paged_attention.cu plus updates across attention/gemm/norm/
  activation/embedding/reduce kernels and common.cuh.
- xserv-kernels: new dispatch module and kernel-binding updates.
- xserv-cuda: cudaMallocHost/FreeHost bindings + PinnedBuffer (host swap
  pool backing) and offset-aware D2H/H2D copies used to move KV blocks
  between the GPU pool and pinned host memory.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-28 19:58:36 +08:00

141 lines
4.7 KiB
Plaintext

#include "../common.cuh"
// RMSNorm (fp32): y[i] = x[i] * rsqrt(mean(x^2) + eps) * gamma[i]
//
// Launch layout: one block per row (blockIdx.x selects the row). Each thread
// walks the row with a blockDim.x stride, so the element loops are correct for
// any blockDim.x. Every thread calls block_reduce_sum (common.cuh);
// NOTE(review): any blockDim.x requirements of that helper (e.g. a multiple of
// 32) live in common.cuh — confirm.
//
//   x, out : [rows, hidden_size], contiguous row-major (the indexing below assumes this)
//   gamma  : [hidden_size], shared by every row
//   eps    : added to the mean square before the reciprocal square root
__global__ void rmsnorm_f32(
    const float* __restrict__ x,
    const float* __restrict__ gamma,
    float* __restrict__ out,
    int hidden_size, float eps
) {
    // Compute the row offset in 64 bits. blockIdx.x * hidden_size in 32-bit
    // int silently overflows once the tensor holds more than INT_MAX elements.
    const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
    const float* x_row = x + row_offset;
    float* out_row = out + row_offset;

    // Pass 1: per-thread partial sum of squares, then block-wide reduction.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float v = x_row[i];
        sum_sq += v * v;
    }
    sum_sq = block_reduce_sum(sum_sq);

    // Only thread 0's reduced value is consumed. It is broadcast to the rest
    // of the block through shared memory.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / (float)hidden_size + eps);
    }
    __syncthreads();
    const float rms_inv = s_rms_inv;

    // Pass 2: scale each element by the row's inverse RMS and by gamma.
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        out_row[i] = x_row[i] * rms_inv * gamma[i];
    }
}
// RMSNorm (bf16 storage, fp32 math): y[i] = x[i] * rsqrt(mean(x^2) + eps) * gamma[i]
//
// Launch layout: one block per row (blockIdx.x selects the row). Each thread
// walks the row with a blockDim.x stride. Inputs are widened to float for the
// sum of squares and for the scaling; only the final result is rounded back
// to bf16. Every thread calls block_reduce_sum (common.cuh);
// NOTE(review): any blockDim.x requirements of that helper live in common.cuh — confirm.
//
//   x, out : [rows, hidden_size], contiguous row-major (the indexing below assumes this)
//   gamma  : [hidden_size], shared by every row
__global__ void rmsnorm_bf16(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ gamma,
    __nv_bfloat16* __restrict__ out,
    int hidden_size, float eps
) {
    // 64-bit row offset: a 32-bit blockIdx.x * hidden_size overflows for
    // tensors with more than INT_MAX elements.
    const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
    const __nv_bfloat16* x_row = x + row_offset;
    __nv_bfloat16* out_row = out + row_offset;

    // Pass 1: fp32 partial sum of squares, then block-wide reduction.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float v = __bfloat162float(x_row[i]);
        sum_sq += v * v;
    }
    sum_sq = block_reduce_sum(sum_sq);

    // Only thread 0's reduced value is used. It is broadcast via shared memory.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / (float)hidden_size + eps);
    }
    __syncthreads();
    const float rms_inv = s_rms_inv;

    // Pass 2: normalize in fp32, round once to bf16 on store.
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float v = __bfloat162float(x_row[i]);
        float g = __bfloat162float(gamma[i]);
        out_row[i] = __float2bfloat16(v * rms_inv * g);
    }
}
// Fused Add + RMSNorm (bf16 storage, fp32 math):
//   sum_out    = x + residual
//   normed_out = rmsnorm(sum_out, gamma, eps)
//
// Launch layout: one block per row (blockIdx.x selects the row). Each thread
// walks the row with a blockDim.x stride.
//
//   x, residual, normed_out, sum_out : [rows, hidden_size], contiguous row-major
//   gamma                            : [hidden_size], shared by every row
//
// Numerics:
//   - The mean square is accumulated from the unrounded fp32 sum.
//   - The normalized output is computed from the bf16-rounded value stored in
//     sum_out (pass 2 reads it back).
//
// NOTE(review): all pointers are __restrict__, so sum_out / normed_out must
// not alias x or residual (e.g. an in-place residual update) — confirm callers.
__global__ void add_rmsnorm_bf16(
    const __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ residual,
    const __nv_bfloat16* __restrict__ gamma,
    __nv_bfloat16* __restrict__ normed_out,
    __nv_bfloat16* __restrict__ sum_out,
    int hidden_size, float eps
) {
    // 64-bit row offset: a 32-bit blockIdx.x * hidden_size overflows for
    // tensors with more than INT_MAX elements.
    const size_t row_offset = (size_t)blockIdx.x * (size_t)hidden_size;
    const __nv_bfloat16* x_row = x + row_offset;
    const __nv_bfloat16* res_row = residual + row_offset;
    __nv_bfloat16* sum_row = sum_out + row_offset;
    __nv_bfloat16* norm_row = normed_out + row_offset;

    // Pass 1: sum = x + residual (stored as bf16), accumulate sum^2 in fp32.
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float s = __bfloat162float(x_row[i]) + __bfloat162float(res_row[i]);
        sum_row[i] = __float2bfloat16(s);
        sum_sq += s * s;
    }
    sum_sq = block_reduce_sum(sum_sq);

    // Only thread 0's reduced value is used. It is broadcast via shared memory.
    __shared__ float s_rms_inv;
    if (threadIdx.x == 0) {
        s_rms_inv = rsqrtf(sum_sq / (float)hidden_size + eps);
    }
    __syncthreads();
    const float rms_inv = s_rms_inv;

    // Pass 2: normed_out = sum * rms_inv * gamma. Both loops use the same
    // index pattern, so each thread reads back only the sum_row elements it
    // wrote itself in pass 1.
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        float s = __bfloat162float(sum_row[i]);
        float g = __bfloat162float(gamma[i]);
        norm_row[i] = __float2bfloat16(s * rms_inv * g);
    }
}
// Block size for the one-block-per-row RMSNorm kernels.
//
// Rule: one thread per element, capped at 1024 (the CUDA per-block thread
// limit), then rounded up to a whole number of warps with a minimum of 32.
//
// Why round: a block such as 100 threads leaves a partial last warp, which
// shuffle-based block reductions (full-warp masks, blockDim.x / 32 warp
// counts) can mishandle. The surplus threads fail the kernels'
// `i < hidden_size` loop test and contribute 0 to the sum of squares.
static inline int rmsnorm_block_dim(int hidden_size) {
    int block = (hidden_size < 1024) ? hidden_size : 1024;
    block = ((block + 31) / 32) * 32;
    return (block < 32) ? 32 : block;
}

extern "C" {

// Launch rmsnorm_f32 over a [rows, hidden_size] fp32 tensor.
// `stream` is a cudaStream_t passed as an opaque pointer; nullptr selects the
// default stream. The launch is asynchronous.
// Empty input (rows <= 0 or hidden_size <= 0) is a no-op. A zero-sized grid
// would otherwise be rejected as an invalid launch configuration.
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
                        int rows, int hidden_size, float eps, void* stream) {
    if (rows <= 0 || hidden_size <= 0) return;
    const int block = rmsnorm_block_dim(hidden_size);
    rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}

// Launch rmsnorm_bf16 over a [rows, hidden_size] bf16 tensor.
// Same stream and empty-input semantics as launch_rmsnorm_f32.
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
                         int rows, int hidden_size, float eps, void* stream) {
    if (rows <= 0 || hidden_size <= 0) return;
    const int block = rmsnorm_block_dim(hidden_size);
    rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
        (__nv_bfloat16*)out, hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}

// Launch add_rmsnorm_bf16.
// Writes sum_out = x + residual and normed_out = rmsnorm(sum_out).
// Same stream and empty-input semantics as launch_rmsnorm_f32.
void launch_add_rmsnorm_bf16(const void* x, const void* residual, const void* gamma,
                             void* normed_out, void* sum_out,
                             int rows, int hidden_size, float eps, void* stream) {
    if (rows <= 0 || hidden_size <= 0) return;
    const int block = rmsnorm_block_dim(hidden_size);
    add_rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_bfloat16*)residual,
        (const __nv_bfloat16*)gamma,
        (__nv_bfloat16*)normed_out, (__nv_bfloat16*)sum_out,
        hidden_size, eps);
    CUDA_CHECK_LAST_ERROR();
}

}  // extern "C"