Files
xserv/csrc/reduce/softmax.cu
Gahow Wang a67753f516 softmax: cap block size at 512 threads
launch_softmax_{f32,bf16} clamped block to 1024 threads when cols was
larger. Halving the ceiling to 512 keeps two blocks per SM resident on
the large vocab kernels that dominate speculative verify workloads
without changing rows/block indexing. The existing 32-thread floor is kept, so the block size can still exceed cols for very narrow rows.
2026-07-01 14:16:32 +08:00

109 lines
3.2 KiB
Plaintext

#include "../common.cuh"
// Numerically safe softmax along the last dimension (float32).
//
// Launch layout: one block per row (row index = blockIdx.x). The threads of
// the block stride over the `cols` elements of that row with a bounds check,
// so the loops below are correct for any blockDim.x. The block_reduce_*
// helpers come from common.cuh and may impose their own requirements on
// blockDim.x (NOTE(review): confirm against common.cuh).
//
// Three passes over the row:
//   1) find the row maximum,
//   2) write exp(x - max) into `out` and accumulate the sum,
//   3) scale `out` by 1 / sum.
//
// x    : input, `cols` contiguous floats per row
// out  : output, same layout as x; declared __restrict__, so it is assumed
//        not to alias x
// cols : row length
__global__ void softmax_f32(
    const float* __restrict__ x,
    float* __restrict__ out,
    int cols
) {
    // Widen before multiplying: a 32-bit `row * cols` overflows once the
    // tensor holds 2^31 or more elements (many rows times a large vocab).
    const size_t row_offset =
        static_cast<size_t>(blockIdx.x) * static_cast<size_t>(cols);
    const float* x_row = x + row_offset;
    float* out_row = out + row_offset;

    // Pass 1: find max
    float local_max = -INFINITY;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        local_max = fmaxf(local_max, x_row[i]);
    }
    float row_max = block_reduce_max(local_max);

    // Broadcast thread 0's reduced value to the whole block through shared
    // memory, so every thread subtracts the same maximum.
    __shared__ float s_max;
    if (threadIdx.x == 0) s_max = row_max;
    __syncthreads();
    row_max = s_max;

    // Pass 2: exp and sum. Subtracting the row maximum keeps every exponent
    // at or below zero, so expf cannot overflow for finite inputs.
    float local_sum = 0.0f;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float e = expf(x_row[i] - row_max);
        out_row[i] = e;
        local_sum += e;
    }
    float row_sum = block_reduce_sum(local_sum);

    // Compute the reciprocal once and broadcast it, so pass 3 multiplies
    // instead of dividing per element.
    __shared__ float s_inv_sum;
    if (threadIdx.x == 0) s_inv_sum = 1.0f / row_sum;
    __syncthreads();
    float inv_sum = s_inv_sum;

    // Pass 3: normalize. Each thread revisits exactly the indices it wrote
    // in pass 2 (same start and stride).
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        out_row[i] *= inv_sum;
    }
}
// Numerically safe softmax along the last dimension (bfloat16 in/out,
// float32 arithmetic).
//
// Launch layout: one block per row (row index = blockIdx.x). The threads of
// the block stride over the `cols` elements of that row with a bounds check,
// so the loops below are correct for any blockDim.x. The block_reduce_*
// helpers come from common.cuh and may impose their own requirements on
// blockDim.x (NOTE(review): confirm against common.cuh).
//
// x    : input, `cols` contiguous bf16 values per row
// out  : output, same layout as x; declared __restrict__, so it is assumed
//        not to alias x
// cols : row length
__global__ void softmax_bf16(
    const __nv_bfloat16* __restrict__ x,
    __nv_bfloat16* __restrict__ out,
    int cols
) {
    // Widen before multiplying: a 32-bit `row * cols` overflows once the
    // tensor holds 2^31 or more elements (many rows times a large vocab).
    const size_t row_offset =
        static_cast<size_t>(blockIdx.x) * static_cast<size_t>(cols);
    const __nv_bfloat16* x_row = x + row_offset;
    __nv_bfloat16* out_row = out + row_offset;

    // Pass 1: find max (compared in float).
    float local_max = -INFINITY;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        local_max = fmaxf(local_max, __bfloat162float(x_row[i]));
    }
    float row_max = block_reduce_max(local_max);

    // Broadcast thread 0's reduced value to the whole block through shared
    // memory, so every thread subtracts the same maximum.
    __shared__ float s_max;
    if (threadIdx.x == 0) s_max = row_max;
    __syncthreads();
    row_max = s_max;

    // Pass 2: exp and sum. The exponentials are staged in `out` as bf16
    // (slight precision loss, acceptable), while the sum is accumulated from
    // the unrounded float values.
    float local_sum = 0.0f;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float e = expf(__bfloat162float(x_row[i]) - row_max);
        out_row[i] = __float2bfloat16(e);
        local_sum += e;
    }
    float row_sum = block_reduce_sum(local_sum);

    // Compute the reciprocal once and broadcast it, so pass 3 multiplies
    // instead of dividing per element.
    __shared__ float s_inv_sum;
    if (threadIdx.x == 0) s_inv_sum = 1.0f / row_sum;
    __syncthreads();
    float inv_sum = s_inv_sum;

    // Pass 3: normalize. Each thread reloads the staged bf16 exponentials it
    // wrote in pass 2 (same start and stride), scales in float, and rounds
    // back to bf16.
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float e = __bfloat162float(out_row[i]);
        out_row[i] = __float2bfloat16(e * inv_sum);
    }
}
// Picks the thread-block size for the softmax kernels: min(cols, 512),
// rounded up to a multiple of the warp size, with a floor of one warp.
//
// Rounding up means no warp in the block is partially populated. The
// block_reduce_* helpers live in common.cuh and are not visible here; a
// reduction built on full-mask warp shuffles needs every lane of every warp
// to exist (NOTE(review): confirm against common.cuh). Surplus threads are
// harmless because the kernels bounds-check against `cols`, so those threads
// contribute only the reduction identities (-INFINITY for max, 0 for sum).
static inline int softmax_block_size(int cols) {
    const int kWarpSize = 32;
    const int kMaxThreads = 512;  // already a multiple of kWarpSize
    int block = (cols < kMaxThreads) ? cols : kMaxThreads;
    block = (block + kWarpSize - 1) / kWarpSize * kWarpSize;
    if (block < kWarpSize) block = kWarpSize;
    return block;
}

extern "C" {

// Launches softmax_f32 on `stream` with one block per row.
//
// x, out : device pointers to rows * cols floats (passed through to the
//          kernel as const float* / float*)
// rows   : number of rows; used as the grid size
// cols   : row length
// stream : opaque handle, cast to cudaStream_t
//
// An empty input (rows == 0 or cols == 0) returns without launching: a
// zero-size grid is an invalid launch configuration, and with cols == 0 the
// kernel would write nothing anyway.
void launch_softmax_f32(const void* x, void* out, int rows, int cols, void* stream) {
    if (rows == 0 || cols == 0) return;
    const int block = softmax_block_size(cols);
    softmax_f32<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const float*)x, (float*)out, cols);
    CUDA_CHECK_LAST_ERROR();
}

// Launches softmax_bf16 on `stream` with one block per row.
//
// x, out : device pointers to rows * cols bf16 values (passed through to the
//          kernel as const __nv_bfloat16* / __nv_bfloat16*)
// rows   : number of rows; used as the grid size
// cols   : row length
// stream : opaque handle, cast to cudaStream_t
//
// An empty input (rows == 0 or cols == 0) returns without launching: a
// zero-size grid is an invalid launch configuration, and with cols == 0 the
// kernel would write nothing anyway.
void launch_softmax_bf16(const void* x, void* out, int rows, int cols, void* stream) {
    if (rows == 0 || cols == 0) return;
    const int block = softmax_block_size(cols);
    softmax_bf16<<<rows, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)out, cols);
    CUDA_CHECK_LAST_ERROR();
}

}