Files
xserv/csrc/common.cuh
Gahow Wang c8e8153702 phase 4: transformer core kernels
CUDA kernels (csrc/):
- common.cuh: shared warp_reduce_sum/max, block_reduce_sum/max
- normalization/rmsnorm.cu: RMSNorm (F32 + BF16)
- normalization/layernorm.cu: LayerNorm with Welford (F32 + BF16)
- activation/activations.cu: GELU tanh-approx + SiLU (F32 + BF16)
- reduce/softmax.cu: safe softmax, 3-pass (F32 + BF16)
- embedding/embedding.cu: gather lookup (F32 + BF16)
- embedding/rope.cu: RoPE in-place + precomputed cos/sin cache (F32 + BF16)

Rust wrappers (xserv-kernels/src/):
- rmsnorm.rs, layernorm.rs, activation.rs, softmax.rs, embedding.rs, rope.rs
- RopeCache struct with GPU-side precomputation

Tests: 12 new tests (ops_test.rs), all passing with good precision:
- F32: max_err 1e-9 ~ 1e-6
- BF16: max_err 2e-3 ~ 7e-3
Total: 29 kernel tests + 27 prior = 56 tests passing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-21 21:07:24 +08:00

51 lines
1.4 KiB
Plaintext

#pragma once
#include <cuda_bf16.h>
// --- Warp-level reductions (no shared memory needed) ---
// Sum-reduce `val` across the 32 lanes of the calling warp with a
// shuffle-down tree (strides 16, 8, 4, 2, 1); no shared memory is touched.
//
// Returns: in lane 0, the sum of all 32 lanes' inputs. Other lanes end up
// holding partial (not meaningful) values and should ignore the result.
// Precondition: every lane of the warp calls this together, because each
// shuffle names the full-warp mask 0xffffffff.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    constexpr unsigned kAllLanes = 0xffffffffu;
    float acc = val;
#pragma unroll
    for (unsigned stride = 16u; stride >= 1u; stride >>= 1) {
        // Pull the running value from lane (lane + stride) and fold it in.
        const float neighbor = __shfl_down_sync(kAllLanes, acc, stride);
        acc = acc + neighbor;
    }
    return acc;
}
// Max-reduce `val` across the 32 lanes of the calling warp with a
// shuffle-down tree (strides 16, 8, 4, 2, 1); no shared memory is touched.
//
// Returns: in lane 0, the maximum of all 32 lanes' inputs (combined with
// fmaxf). Other lanes end up holding partial values and should ignore it.
// Precondition: every lane of the warp calls this together, because each
// shuffle names the full-warp mask 0xffffffff.
__device__ __forceinline__ float warp_reduce_max(float val) {
    constexpr unsigned kAllLanes = 0xffffffffu;
    float best = val;
#pragma unroll
    for (unsigned stride = 16u; stride >= 1u; stride >>= 1) {
        // Compare against the value held by lane (lane + stride).
        const float neighbor = __shfl_down_sync(kAllLanes, best, stride);
        best = fmaxf(best, neighbor);
    }
    return best;
}
// --- Block-level reductions ---
// Sum-reduce `val` across every thread of the block.
//
// Two stages: each warp reduces its own lanes with warp_reduce_sum, lane 0 of
// each warp parks its partial in a 32-entry __shared__ buffer, and warp 0 then
// reduces those partials.
//
// Returns: the block-wide sum in thread 0 (threadIdx.x == 0) only. Other lanes
// of warp 0 hold partial sums and threads outside warp 0 return 0.0f, so a
// caller that needs the value in every thread must broadcast it itself
// (e.g. via a __shared__ variable followed by __syncthreads()).
//
// Preconditions:
//   - Must be reached by every thread of the block (contains __syncthreads()).
//   - Indexing uses threadIdx.x only, so a 1-D block is assumed; the partials
//     buffer holds 32 entries, i.e. up to 32 warps (1024 threads).
//   - NOTE(review): warp_reduce_sum shuffles with the full 0xffffffff mask, so
//     blockDim.x is presumably expected to be a multiple of 32 -- TODO confirm
//     against the launch configurations.
//
// Safe to call repeatedly back-to-back: a barrier after the partials are read
// keeps a later call from overwriting the buffer while warp 0 still reads it.
__device__ __forceinline__ float block_reduce_sum(float val) {
    __shared__ float shared[32];
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;

    // Stage 1: per-warp partial sums, published by each warp's lane 0.
    val = warp_reduce_sum(val);
    if (lane == 0) shared[warp_id] = val;
    __syncthreads();

    // Stage 2 input: warp 0 picks up one partial per lane (0 for padding lanes).
    val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : 0.0f;
    // Without this barrier, warps other than warp 0 can return, enter the next
    // block_reduce_sum call, and overwrite shared[] before warp 0 has read it
    // (the __shared__ buffer is a single per-block instance shared by all calls).
    __syncthreads();

    if (warp_id == 0) val = warp_reduce_sum(val);
    return val;
}
// Max-reduce `val` across every thread of the block.
//
// Two stages: each warp reduces its own lanes with warp_reduce_max, lane 0 of
// each warp parks its partial in a 32-entry __shared__ buffer, and warp 0 then
// reduces those partials.
//
// Returns: the block-wide maximum in thread 0 (threadIdx.x == 0) only. Other
// lanes of warp 0 hold partial maxima and threads outside warp 0 return
// -INFINITY, so a caller that needs the value in every thread must broadcast
// it itself (e.g. via a __shared__ variable followed by __syncthreads()).
//
// Preconditions:
//   - Must be reached by every thread of the block (contains __syncthreads()).
//   - Indexing uses threadIdx.x only, so a 1-D block is assumed; the partials
//     buffer holds 32 entries, i.e. up to 32 warps (1024 threads).
//   - NOTE(review): warp_reduce_max shuffles with the full 0xffffffff mask, so
//     blockDim.x is presumably expected to be a multiple of 32 -- TODO confirm
//     against the launch configurations.
//
// Safe to call repeatedly back-to-back: a barrier after the partials are read
// keeps a later call from overwriting the buffer while warp 0 still reads it.
__device__ __forceinline__ float block_reduce_max(float val) {
    __shared__ float shared[32];
    const int lane = threadIdx.x & 31;
    const int warp_id = threadIdx.x >> 5;
    const int num_warps = (blockDim.x + 31) >> 5;

    // Stage 1: per-warp partial maxima, published by each warp's lane 0.
    val = warp_reduce_max(val);
    if (lane == 0) shared[warp_id] = val;
    __syncthreads();

    // Stage 2 input: warp 0 picks up one partial per lane; padding lanes use
    // -INFINITY, the identity element for max.
    val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : -INFINITY;
    // Without this barrier, warps other than warp 0 can return, enter the next
    // block_reduce_max call, and overwrite shared[] before warp 0 has read it
    // (the __shared__ buffer is a single per-block instance shared by all calls).
    __syncthreads();

    if (warp_id == 0) val = warp_reduce_max(val);
    return val;
}