phase 4: transformer core kernels
CUDA kernels (csrc/): - common.cuh: shared warp_reduce_sum/max, block_reduce_sum/max - normalization/rmsnorm.cu: RMSNorm (F32 + BF16) - normalization/layernorm.cu: LayerNorm with Welford (F32 + BF16) - activation/activations.cu: GELU tanh-approx + SiLU (F32 + BF16) - reduce/softmax.cu: safe softmax, 3-pass (F32 + BF16) - embedding/embedding.cu: gather lookup (F32 + BF16) - embedding/rope.cu: RoPE in-place + precomputed cos/sin cache (F32 + BF16) Rust wrappers (xserv-kernels/src/): - rmsnorm.rs, layernorm.rs, activation.rs, softmax.rs, embedding.rs, rope.rs - RopeCache struct with GPU-side precomputation Tests: 12 new tests (ops_test.rs), all passing with good precision: - F32: max_err 1e-6 ~ 1e-9 - BF16: max_err 2e-3 ~ 7e-3 Total: 29 kernel tests + 27 prior = 56 tests passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
66
csrc/activation/activations.cu
Normal file
66
csrc/activation/activations.cu
Normal file
@@ -0,0 +1,66 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <math.h>
|
||||
|
||||
// GELU (tanh approximation):
|
||||
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
__device__ __forceinline__ float gelu_f(float x) {
|
||||
const float SQRT_2_OVER_PI = 0.7978845608f;
|
||||
float cube = x * x * x;
|
||||
float inner = SQRT_2_OVER_PI * (x + 0.044715f * cube);
|
||||
return 0.5f * x * (1.0f + tanhf(inner));
|
||||
}
|
||||
|
||||
// SiLU (Swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
||||
__device__ __forceinline__ float silu_f(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
||||
__global__ void gelu_f32(const float* x, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = gelu_f(x[idx]);
|
||||
}
|
||||
|
||||
__global__ void gelu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(gelu_f(__bfloat162float(x[idx])));
|
||||
}
|
||||
|
||||
__global__ void silu_f32(const float* x, float* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = silu_f(x[idx]);
|
||||
}
|
||||
|
||||
__global__ void silu_bf16(const __nv_bfloat16* x, __nv_bfloat16* out, int n) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx < n) out[idx] = __float2bfloat16(silu_f(__bfloat162float(x[idx])));
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_gelu_f32(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
gelu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
|
||||
}
|
||||
|
||||
void launch_gelu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
gelu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
void launch_silu_f32(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
silu_f32<<<grid, block, 0, (cudaStream_t)stream>>>((const float*)x, (float*)out, n);
|
||||
}
|
||||
|
||||
void launch_silu_bf16(const void* x, void* out, int n, void* stream) {
|
||||
int block = 256;
|
||||
int grid = (n + block - 1) / block;
|
||||
silu_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, n);
|
||||
}
|
||||
|
||||
}
|
||||
50
csrc/common.cuh
Normal file
50
csrc/common.cuh
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// --- Warp-level reductions (no shared memory needed) ---
|
||||
|
||||
__device__ __forceinline__ float warp_reduce_sum(float val) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warp_reduce_max(float val) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
|
||||
return val;
|
||||
}
|
||||
|
||||
// --- Block-level reductions ---
|
||||
|
||||
__device__ __forceinline__ float block_reduce_sum(float val) {
|
||||
__shared__ float shared[32];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp_id = threadIdx.x >> 5;
|
||||
int num_warps = (blockDim.x + 31) >> 5;
|
||||
|
||||
val = warp_reduce_sum(val);
|
||||
if (lane == 0) shared[warp_id] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : 0.0f;
|
||||
if (warp_id == 0) val = warp_reduce_sum(val);
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float block_reduce_max(float val) {
|
||||
__shared__ float shared[32];
|
||||
int lane = threadIdx.x & 31;
|
||||
int warp_id = threadIdx.x >> 5;
|
||||
int num_warps = (blockDim.x + 31) >> 5;
|
||||
|
||||
val = warp_reduce_max(val);
|
||||
if (lane == 0) shared[warp_id] = val;
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : -INFINITY;
|
||||
if (warp_id == 0) val = warp_reduce_max(val);
|
||||
return val;
|
||||
}
|
||||
55
csrc/embedding/embedding.cu
Normal file
55
csrc/embedding/embedding.cu
Normal file
@@ -0,0 +1,55 @@
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// Embedding lookup: out[seq_idx] = table[token_ids[seq_idx]]
|
||||
// Grid: num_tokens, Block: handles hidden_size elements per token.
|
||||
|
||||
__global__ void embedding_f32(
|
||||
const float* __restrict__ table, // [vocab_size, hidden_size]
|
||||
const int* __restrict__ token_ids, // [num_tokens]
|
||||
float* __restrict__ out, // [num_tokens, hidden_size]
|
||||
int hidden_size
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = token_ids[token_idx];
|
||||
const float* row = table + tid * hidden_size;
|
||||
float* dst = out + token_idx * hidden_size;
|
||||
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
dst[i] = row[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void embedding_bf16(
|
||||
const __nv_bfloat16* __restrict__ table,
|
||||
const int* __restrict__ token_ids,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int tid = token_ids[token_idx];
|
||||
const __nv_bfloat16* row = table + tid * hidden_size;
|
||||
__nv_bfloat16* dst = out + token_idx * hidden_size;
|
||||
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
dst[i] = row[i];
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_embedding_f32(const void* table, const void* token_ids, void* out,
|
||||
int num_tokens, int hidden_size, void* stream) {
|
||||
int block = (hidden_size < 256) ? hidden_size : 256;
|
||||
embedding_f32<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)table, (const int*)token_ids, (float*)out, hidden_size);
|
||||
}
|
||||
|
||||
void launch_embedding_bf16(const void* table, const void* token_ids, void* out,
|
||||
int num_tokens, int hidden_size, void* stream) {
|
||||
int block = (hidden_size < 256) ? hidden_size : 256;
|
||||
embedding_bf16<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)table, (const int*)token_ids,
|
||||
(__nv_bfloat16*)out, hidden_size);
|
||||
}
|
||||
|
||||
}
|
||||
116
csrc/embedding/rope.cu
Normal file
116
csrc/embedding/rope.cu
Normal file
@@ -0,0 +1,116 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <math.h>
|
||||
|
||||
// RoPE: Rotary Position Embedding
|
||||
// For each pair (x[2i], x[2i+1]) at position `pos`:
|
||||
// y[2i] = x[2i] * cos - x[2i+1] * sin
|
||||
// y[2i+1] = x[2i] * sin + x[2i+1] * cos
|
||||
// where cos/sin come from precomputed cos_cache/sin_cache.
|
||||
//
|
||||
// cos_cache[pos][i] = cos(pos * freq[i])
|
||||
// sin_cache[pos][i] = sin(pos * freq[i])
|
||||
// freq[i] = 1.0 / (theta ^ (2i / head_dim))
|
||||
|
||||
// Apply RoPE in-place to Q or K tensor.
|
||||
// x shape: [num_tokens, num_heads, head_dim]
|
||||
// cos_cache, sin_cache shape: [max_seq_len, head_dim/2]
|
||||
// positions: [num_tokens] — the position index for each token
|
||||
|
||||
__global__ void rope_f32(
|
||||
float* __restrict__ x, // [num_tokens, num_heads, head_dim]
|
||||
const float* __restrict__ cos_cache, // [max_seq_len, half_dim]
|
||||
const float* __restrict__ sin_cache, // [max_seq_len, half_dim]
|
||||
const int* __restrict__ positions, // [num_tokens]
|
||||
int num_heads, int head_dim
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int head_idx = blockIdx.y;
|
||||
int half_dim = head_dim / 2;
|
||||
int pair_idx = threadIdx.x; // which pair (0..half_dim)
|
||||
|
||||
if (pair_idx >= half_dim) return;
|
||||
|
||||
int pos = positions[token_idx];
|
||||
float cos_val = cos_cache[pos * half_dim + pair_idx];
|
||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||
|
||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||
float x0 = x[base + 2 * pair_idx];
|
||||
float x1 = x[base + 2 * pair_idx + 1];
|
||||
|
||||
x[base + 2 * pair_idx] = x0 * cos_val - x1 * sin_val;
|
||||
x[base + 2 * pair_idx + 1] = x0 * sin_val + x1 * cos_val;
|
||||
}
|
||||
|
||||
__global__ void rope_bf16(
|
||||
__nv_bfloat16* __restrict__ x,
|
||||
const float* __restrict__ cos_cache,
|
||||
const float* __restrict__ sin_cache,
|
||||
const int* __restrict__ positions,
|
||||
int num_heads, int head_dim
|
||||
) {
|
||||
int token_idx = blockIdx.x;
|
||||
int head_idx = blockIdx.y;
|
||||
int half_dim = head_dim / 2;
|
||||
int pair_idx = threadIdx.x;
|
||||
|
||||
if (pair_idx >= half_dim) return;
|
||||
|
||||
int pos = positions[token_idx];
|
||||
float cos_val = cos_cache[pos * half_dim + pair_idx];
|
||||
float sin_val = sin_cache[pos * half_dim + pair_idx];
|
||||
|
||||
int base = (token_idx * num_heads + head_idx) * head_dim;
|
||||
float x0 = __bfloat162float(x[base + 2 * pair_idx]);
|
||||
float x1 = __bfloat162float(x[base + 2 * pair_idx + 1]);
|
||||
|
||||
x[base + 2 * pair_idx] = __float2bfloat16(x0 * cos_val - x1 * sin_val);
|
||||
x[base + 2 * pair_idx + 1] = __float2bfloat16(x0 * sin_val + x1 * cos_val);
|
||||
}
|
||||
|
||||
// Precompute cos/sin cache on GPU
|
||||
__global__ void compute_rope_cache(
|
||||
float* __restrict__ cos_cache, // [max_seq_len, half_dim]
|
||||
float* __restrict__ sin_cache,
|
||||
int max_seq_len, int half_dim, float theta
|
||||
) {
|
||||
int pos = blockIdx.x;
|
||||
int i = threadIdx.x;
|
||||
if (i >= half_dim) return;
|
||||
|
||||
float freq = 1.0f / powf(theta, (float)(2 * i) / (float)(2 * half_dim));
|
||||
float angle = (float)pos * freq;
|
||||
cos_cache[pos * half_dim + i] = cosf(angle);
|
||||
sin_cache[pos * half_dim + i] = sinf(angle);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rope_f32(void* x, const void* cos_cache, const void* sin_cache,
|
||||
const void* positions, int num_tokens, int num_heads,
|
||||
int head_dim, void* stream) {
|
||||
dim3 grid(num_tokens, num_heads);
|
||||
int block = head_dim / 2;
|
||||
rope_f32<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(float*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||
(const int*)positions, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_rope_bf16(void* x, const void* cos_cache, const void* sin_cache,
|
||||
const void* positions, int num_tokens, int num_heads,
|
||||
int head_dim, void* stream) {
|
||||
dim3 grid(num_tokens, num_heads);
|
||||
int block = head_dim / 2;
|
||||
rope_bf16<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)x, (const float*)cos_cache, (const float*)sin_cache,
|
||||
(const int*)positions, num_heads, head_dim);
|
||||
}
|
||||
|
||||
void launch_compute_rope_cache(void* cos_cache, void* sin_cache,
|
||||
int max_seq_len, int half_dim, float theta,
|
||||
void* stream) {
|
||||
compute_rope_cache<<<max_seq_len, half_dim, 0, (cudaStream_t)stream>>>(
|
||||
(float*)cos_cache, (float*)sin_cache, max_seq_len, half_dim, theta);
|
||||
}
|
||||
|
||||
}
|
||||
102
csrc/normalization/layernorm.cu
Normal file
102
csrc/normalization/layernorm.cu
Normal file
@@ -0,0 +1,102 @@
|
||||
#include "../common.cuh"
|
||||
|
||||
// LayerNorm: y[i] = gamma[i] * (x[i] - mean) / sqrt(var + eps) + beta[i]
|
||||
// Each block processes one row of shape [hidden_size].
|
||||
|
||||
__global__ void layernorm_f32(
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ gamma,
|
||||
const float* __restrict__ beta,
|
||||
float* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const float* x_row = x + row * hidden_size;
|
||||
float* out_row = out + row * hidden_size;
|
||||
|
||||
// Welford online: compute mean and variance in one pass
|
||||
float local_sum = 0.0f;
|
||||
float local_sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = x_row[i];
|
||||
local_sum += v;
|
||||
local_sum_sq += v * v;
|
||||
}
|
||||
local_sum = block_reduce_sum(local_sum);
|
||||
local_sum_sq = block_reduce_sum(local_sum_sq);
|
||||
|
||||
__shared__ float s_mean, s_inv_std;
|
||||
if (threadIdx.x == 0) {
|
||||
float mean = local_sum / hidden_size;
|
||||
float var = local_sum_sq / hidden_size - mean * mean;
|
||||
s_mean = mean;
|
||||
s_inv_std = rsqrtf(var + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float mean = s_mean;
|
||||
float inv_std = s_inv_std;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
out_row[i] = gamma[i] * (x_row[i] - mean) * inv_std + beta[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void layernorm_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ gamma,
|
||||
const __nv_bfloat16* __restrict__ beta,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * hidden_size;
|
||||
__nv_bfloat16* out_row = out + row * hidden_size;
|
||||
|
||||
float local_sum = 0.0f;
|
||||
float local_sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
local_sum += v;
|
||||
local_sum_sq += v * v;
|
||||
}
|
||||
local_sum = block_reduce_sum(local_sum);
|
||||
local_sum_sq = block_reduce_sum(local_sum_sq);
|
||||
|
||||
__shared__ float s_mean, s_inv_std;
|
||||
if (threadIdx.x == 0) {
|
||||
float mean = local_sum / hidden_size;
|
||||
float var = local_sum_sq / hidden_size - mean * mean;
|
||||
s_mean = mean;
|
||||
s_inv_std = rsqrtf(var + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float mean = s_mean;
|
||||
float inv_std = s_inv_std;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
float g = __bfloat162float(gamma[i]);
|
||||
float b = __bfloat162float(beta[i]);
|
||||
out_row[i] = __float2bfloat16(g * (v - mean) * inv_std + b);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_layernorm_f32(const void* x, const void* gamma, const void* beta,
|
||||
void* out, int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
layernorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (const float*)gamma, (const float*)beta,
|
||||
(float*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
void launch_layernorm_bf16(const void* x, const void* gamma, const void* beta,
|
||||
void* out, int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
layernorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma, (const __nv_bfloat16*)beta,
|
||||
(__nv_bfloat16*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
}
|
||||
83
csrc/normalization/rmsnorm.cu
Normal file
83
csrc/normalization/rmsnorm.cu
Normal file
@@ -0,0 +1,83 @@
|
||||
#include "../common.cuh"
|
||||
|
||||
// RMSNorm: y[i] = x[i] * rsqrt(mean(x²) + eps) * gamma[i]
|
||||
// Each block processes one row of shape [hidden_size].
|
||||
|
||||
__global__ void rmsnorm_f32(
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ gamma,
|
||||
float* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const float* x_row = x + row * hidden_size;
|
||||
float* out_row = out + row * hidden_size;
|
||||
|
||||
float sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = x_row[i];
|
||||
sum_sq += v * v;
|
||||
}
|
||||
sum_sq = block_reduce_sum(sum_sq);
|
||||
|
||||
__shared__ float s_rms_inv;
|
||||
if (threadIdx.x == 0) {
|
||||
s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float rms_inv = s_rms_inv;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
out_row[i] = x_row[i] * rms_inv * gamma[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void rmsnorm_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ gamma,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int hidden_size, float eps
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * hidden_size;
|
||||
__nv_bfloat16* out_row = out + row * hidden_size;
|
||||
|
||||
float sum_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
sum_sq += v * v;
|
||||
}
|
||||
sum_sq = block_reduce_sum(sum_sq);
|
||||
|
||||
__shared__ float s_rms_inv;
|
||||
if (threadIdx.x == 0) {
|
||||
s_rms_inv = rsqrtf(sum_sq / hidden_size + eps);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float rms_inv = s_rms_inv;
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float v = __bfloat162float(x_row[i]);
|
||||
float g = __bfloat162float(gamma[i]);
|
||||
out_row[i] = __float2bfloat16(v * rms_inv * g);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_rmsnorm_f32(const void* x, const void* gamma, void* out,
|
||||
int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
rmsnorm_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (const float*)gamma, (float*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
void launch_rmsnorm_bf16(const void* x, const void* gamma, void* out,
|
||||
int rows, int hidden_size, float eps, void* stream) {
|
||||
int block = (hidden_size < 1024) ? hidden_size : 1024;
|
||||
rmsnorm_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (const __nv_bfloat16*)gamma,
|
||||
(__nv_bfloat16*)out, hidden_size, eps);
|
||||
}
|
||||
|
||||
}
|
||||
106
csrc/reduce/softmax.cu
Normal file
106
csrc/reduce/softmax.cu
Normal file
@@ -0,0 +1,106 @@
|
||||
#include "../common.cuh"
|
||||
|
||||
// Safe softmax along the last dimension.
|
||||
// Each block handles one row of length `cols`.
|
||||
// Three-pass: 1) find max, 2) exp + sum, 3) normalize.
|
||||
|
||||
__global__ void softmax_f32(
|
||||
const float* __restrict__ x,
|
||||
float* __restrict__ out,
|
||||
int cols
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const float* x_row = x + row * cols;
|
||||
float* out_row = out + row * cols;
|
||||
|
||||
// Pass 1: find max
|
||||
float local_max = -INFINITY;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
local_max = fmaxf(local_max, x_row[i]);
|
||||
}
|
||||
float row_max = block_reduce_max(local_max);
|
||||
|
||||
__shared__ float s_max;
|
||||
if (threadIdx.x == 0) s_max = row_max;
|
||||
__syncthreads();
|
||||
row_max = s_max;
|
||||
|
||||
// Pass 2: exp and sum
|
||||
float local_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
float e = expf(x_row[i] - row_max);
|
||||
out_row[i] = e;
|
||||
local_sum += e;
|
||||
}
|
||||
float row_sum = block_reduce_sum(local_sum);
|
||||
|
||||
__shared__ float s_inv_sum;
|
||||
if (threadIdx.x == 0) s_inv_sum = 1.0f / row_sum;
|
||||
__syncthreads();
|
||||
float inv_sum = s_inv_sum;
|
||||
|
||||
// Pass 3: normalize
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
out_row[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void softmax_bf16(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int cols
|
||||
) {
|
||||
int row = blockIdx.x;
|
||||
const __nv_bfloat16* x_row = x + row * cols;
|
||||
__nv_bfloat16* out_row = out + row * cols;
|
||||
|
||||
float local_max = -INFINITY;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
local_max = fmaxf(local_max, __bfloat162float(x_row[i]));
|
||||
}
|
||||
float row_max = block_reduce_max(local_max);
|
||||
|
||||
__shared__ float s_max;
|
||||
if (threadIdx.x == 0) s_max = row_max;
|
||||
__syncthreads();
|
||||
row_max = s_max;
|
||||
|
||||
// We need float scratch for exp values. Reuse out (write bf16 in pass 3).
|
||||
// Use registers to hold exp values during sum pass instead.
|
||||
float local_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
float e = expf(__bfloat162float(x_row[i]) - row_max);
|
||||
// Temporarily store exp in output as bf16 (slight precision loss, acceptable)
|
||||
out_row[i] = __float2bfloat16(e);
|
||||
local_sum += e;
|
||||
}
|
||||
float row_sum = block_reduce_sum(local_sum);
|
||||
|
||||
__shared__ float s_inv_sum;
|
||||
if (threadIdx.x == 0) s_inv_sum = 1.0f / row_sum;
|
||||
__syncthreads();
|
||||
float inv_sum = s_inv_sum;
|
||||
|
||||
for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
||||
float e = __bfloat162float(out_row[i]);
|
||||
out_row[i] = __float2bfloat16(e * inv_sum);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stream) {
|
||||
int block = (cols < 1024) ? cols : 1024;
|
||||
if (block < 32) block = 32;
|
||||
softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const float*)x, (float*)out, cols);
|
||||
}
|
||||
|
||||
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
|
||||
int block = (cols < 1024) ? cols : 1024;
|
||||
if (block < 32) block = 32;
|
||||
softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user