Files
xserv/csrc/moe/moe_kernels.cu
Gahow Wang 4368e79695 model: fused GPU MoE kernel — eliminate CPU roundtrip
Replace the per-token CPU-routed MoE forward with an all-GPU path:

  1. moe_topk_softmax: GPU top-k + softmax (was CPU sort + softmax)
  2. moe_replicate: broadcast input to all local experts
  3. cublasGemmStridedBatchedEx: batched expert matmul (was per-expert cuBLAS)
  4. moe_weighted_sum: FP32-accumulated weighted sum on GPU (was GPU→CPU→F32→BF16→GPU)

Expert weights stored as contiguous 3D tensors for strided batched GEMM.
Zero CPU↔GPU transfers per MoE layer (was ~40 per token per layer).

Also: configurable geglu_alpha, LayerNorm bias auto-detect, unused-weight
diagnostic at load time.

GSM8K 30-problem: 11/30 → 23/30 (76.7%) vs llama.cpp 30/30 (100%).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-31 13:22:59 +08:00

248 lines
7.9 KiB
Plaintext

#include <cuda_bf16.h>
#include <float.h>
#include "../common.cuh"
// ============================================================
// MoE Top-K + Softmax kernel
//
// Input:  router_logits [num_tokens, num_experts] BF16
// Output: topk_ids      [num_tokens, top_k] int32
//         topk_weights  [num_tokens, top_k] float32
//
// Launch layout: one block per token (blockIdx.x is the token index),
// 1D block of any size. All threads of the block cooperate to stage the
// token's logits into shared memory as FP32; thread 0 alone then picks
// the top_k experts by repeated argmax and computes a softmax over the
// k winning logits only. Winners are emitted in descending logit order.
//
// Static shared memory: MOE_MAX_EXPERTS floats for the logits plus
// MOE_MAX_TOPK ints and MOE_MAX_TOPK floats for the winners.
//
// Preconditions:
//   num_experts <= MOE_MAX_EXPERTS
//   top_k       <= MOE_MAX_TOPK
// Exceeding either limit would write past the static shared buffers, so
// the kernel traps instead of silently corrupting memory.
// top_k <= 0 is treated as "nothing to select" and writes no output.
// ============================================================
#define MOE_MAX_EXPERTS 256
#define MOE_MAX_TOPK 8
__global__ void moe_topk_softmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ router_logits,
    int* __restrict__ topk_ids,
    float* __restrict__ topk_weights,
    int num_experts, int top_k
) {
    // Kernel arguments are identical for every thread, so both early exits
    // below are taken by the whole block together and cannot leave a thread
    // waiting at the barrier further down.
    if (num_experts > MOE_MAX_EXPERTS || top_k > MOE_MAX_TOPK) {
        __trap();  // shared buffers are too small for this configuration
    }
    if (top_k <= 0) {
        return;
    }

    const int token = blockIdx.x;
    const int tid = threadIdx.x;

    // 64-bit offsets: token * num_experts can exceed INT_MAX for large batches.
    const __nv_bfloat16* row = router_logits + (long long)token * num_experts;

    __shared__ float smem_logits[MOE_MAX_EXPERTS];
    __shared__ int smem_ids[MOE_MAX_TOPK];
    __shared__ float smem_vals[MOE_MAX_TOPK];

    // Stage this token's logits in shared memory, widened to FP32.
    for (int i = tid; i < num_experts; i += blockDim.x) {
        smem_logits[i] = __bfloat162float(row[i]);
    }
    __syncthreads();

    // Selection and softmax are serial; only thread 0 continues past here.
    if (tid != 0) {
        return;
    }

    // Repeated argmax: take the current maximum, then mask it out with
    // -INFINITY so the next pass finds the runner-up. Ties resolve to the
    // lowest expert index because the comparison is strict.
    for (int k = 0; k < top_k; k++) {
        float best_val = -INFINITY;
        int best_idx = 0;
        for (int e = 0; e < num_experts; e++) {
            if (smem_logits[e] > best_val) {
                best_val = smem_logits[e];
                best_idx = e;
            }
        }
        smem_ids[k] = best_idx;
        smem_vals[k] = best_val;
        smem_logits[best_idx] = -INFINITY;
    }

    // Numerically stable softmax over the k winners, entirely in FP32.
    float max_val = smem_vals[0];
    for (int k = 1; k < top_k; k++) {
        max_val = fmaxf(max_val, smem_vals[k]);
    }
    float exp_sum = 0.0f;
    for (int k = 0; k < top_k; k++) {
        smem_vals[k] = expf(smem_vals[k] - max_val);
        exp_sum += smem_vals[k];
    }
    const float inv_sum = 1.0f / exp_sum;

    // Normalise and write this token's row of both outputs.
    const long long out_base = (long long)token * top_k;
    for (int k = 0; k < top_k; k++) {
        topk_ids[out_base + k] = smem_ids[k];
        topk_weights[out_base + k] = smem_vals[k] * inv_sum;
    }
}
// ============================================================
// MoE Replicate kernel
//
// Input:  x     [num_tokens, hidden] BF16
// Output: x_rep [local_experts, num_tokens, hidden] BF16
//
// Copies x into each expert's batch slot:
//   x_rep[expert, token, dim] = x[token, dim]
//
// Launch layout: 1D grid, one thread per output element; threads whose
// flat index falls past the end of x_rep exit immediately. Indices are
// computed in 64 bits because local_experts * num_tokens * hidden can
// exceed INT_MAX.
// ============================================================
__global__ void moe_replicate_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,
    __nv_bfloat16* __restrict__ x_rep,
    int num_tokens, int hidden, int local_experts
) {
    // Number of elements in x, i.e. the size of one expert's slot in x_rep.
    const long long slot_elems = (long long)num_tokens * hidden;
    const long long total = slot_elems * local_experts;

    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) {
        return;
    }

    // The offset inside the expert slot is also the offset into x.
    x_rep[idx] = x[idx % slot_elems];
}
// ============================================================
// MoE Bias Add 3D kernel
//
// Input: x    [batch, num_tokens, dim] BF16 (updated in place)
//        bias [batch, dim] BF16
//
// x[b, t, d] += bias[b, d]
//
// The sum is formed in FP32 and rounded back to BF16 on store.
//
// Launch layout: 1D grid, one thread per element of x; threads whose
// flat index falls past the end of x exit immediately. Indices are
// computed in 64 bits because batch * num_tokens * dim can exceed
// INT_MAX.
// ============================================================
__global__ void moe_bias_add_3d_bf16_kernel(
    __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ bias,
    int batch, int num_tokens, int dim
) {
    // Elements per batch entry: one [num_tokens, dim] slab.
    const long long slab_elems = (long long)num_tokens * dim;
    const long long total = slab_elems * batch;

    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) {
        return;
    }

    const long long b = idx / slab_elems;  // batch index selects the bias row
    const long long d = idx % dim;         // innermost coordinate selects the bias column

    const float v = __bfloat162float(x[idx]) + __bfloat162float(bias[b * dim + d]);
    x[idx] = __float2bfloat16(v);
}
// ============================================================
// MoE Weighted Sum kernel
//
// Input:  expert_out   [local_experts, num_tokens, hidden] BF16
//         topk_ids     [num_tokens, top_k] int32 (global expert ids)
//         topk_weights [num_tokens, top_k] float32
//         expert_start:  first global expert id this rank owns
//         local_experts: number of experts this rank owns
//
// Output: out [num_tokens, hidden] BF16
//
// For each (token, dim), accumulate in FP32:
//   sum = 0
//   for k in 0..top_k:
//     global_id = topk_ids[token, k]
//     if global_id in [expert_start, expert_start + local_experts):
//       local_id = global_id - expert_start
//       sum += topk_weights[token, k] * expert_out[local_id, token, dim]
//   out[token, dim] = bf16(sum)
//
// Selected experts outside this rank's range contribute nothing, so a
// token none of whose experts are local gets an all-zero output row.
//
// Launch layout: 1D grid, one thread per element of out; threads whose
// flat index falls past the end of out exit immediately. Offsets into
// expert_out are computed in 64 bits because
// local_experts * num_tokens * hidden can exceed INT_MAX.
// ============================================================
__global__ void moe_weighted_sum_bf16_kernel(
    const __nv_bfloat16* __restrict__ expert_out,
    const int* __restrict__ topk_ids,
    const float* __restrict__ topk_weights,
    __nv_bfloat16* __restrict__ out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    // Size of out, which is also the stride between experts in expert_out.
    const long long expert_stride = (long long)num_tokens * hidden;

    const long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= expert_stride) {
        return;
    }

    // Row of routing decisions for the token this element belongs to.
    const long long token = idx / hidden;
    const int* ids = topk_ids + token * top_k;
    const float* weights = topk_weights + token * top_k;

    float sum = 0.0f;
    for (int k = 0; k < top_k; k++) {
        const int local_id = ids[k] - expert_start;
        if (local_id >= 0 && local_id < local_experts) {
            // idx == token * hidden + dim, so it is already the offset of
            // (token, dim) inside one expert's slab.
            const float w = weights[k];
            const float v = __bfloat162float(expert_out[(long long)local_id * expert_stride + idx]);
            sum += w * v;
        }
    }
    out[idx] = __float2bfloat16(sum);
}
extern "C" {
// Launches moe_topk_softmax_bf16_kernel on `stream` with one 128-thread
// block per token.
//
//   router_logits: device pointer, [num_tokens, num_experts] BF16
//   topk_ids:      device pointer, [num_tokens, top_k] int32 (written)
//   topk_weights:  device pointer, [num_tokens, top_k] float32 (written)
//   stream:        a cudaStream_t passed as an opaque pointer
//
// The kernel stages logits and winners in fixed-size shared arrays, so
// callers must keep num_experts <= MOE_MAX_EXPERTS and
// top_k <= MOE_MAX_TOPK; this launcher does not validate them.
// The launch is asynchronous; only launch errors are checked here.
void launch_moe_topk_softmax_bf16(
    const void* router_logits,
    void* topk_ids, void* topk_weights,
    int num_tokens, int num_experts, int top_k,
    void* stream
) {
    // A zero-block grid is an invalid launch configuration, and with no
    // experts to select the kernel would write nothing anyway: treat both
    // as an empty no-op rather than raising a launch error.
    if (num_tokens <= 0 || top_k <= 0) {
        return;
    }

    const int block = 128;
    moe_topk_softmax_bf16_kernel<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)router_logits,
        (int*)topk_ids, (float*)topk_weights,
        num_experts, top_k
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_replicate_bf16_kernel on `stream`, one thread per element
// of x_rep, in 256-thread blocks.
//
//   x:      device pointer, [num_tokens, hidden] BF16
//   x_rep:  device pointer, [local_experts, num_tokens, hidden] BF16 (written)
//   stream: a cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous; only launch errors are checked here.
void launch_moe_replicate_bf16(
    const void* x, void* x_rep,
    int num_tokens, int hidden, int local_experts,
    void* stream
) {
    // Count elements in 64 bits: the three-way product can exceed INT_MAX,
    // which would otherwise yield a bogus grid size.
    const long long total = (long long)local_experts * num_tokens * hidden;
    if (total <= 0) {
        return;  // nothing to copy; a zero-block grid is an invalid launch
    }

    const int block = 256;
    const unsigned int grid = (unsigned int)((total + block - 1) / block);
    moe_replicate_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep,
        num_tokens, hidden, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_bias_add_3d_bf16_kernel on `stream`, one thread per element
// of x, in 256-thread blocks.
//
//   x:      device pointer, [batch, num_tokens, dim] BF16 (updated in place)
//   bias:   device pointer, [batch, dim] BF16
//   stream: a cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous; only launch errors are checked here.
void launch_moe_bias_add_3d_bf16(
    void* x, const void* bias,
    int batch, int num_tokens, int dim,
    void* stream
) {
    // Count elements in 64 bits: the three-way product can exceed INT_MAX,
    // which would otherwise yield a bogus grid size.
    const long long total = (long long)batch * num_tokens * dim;
    if (total <= 0) {
        return;  // nothing to update; a zero-block grid is an invalid launch
    }

    const int block = 256;
    const unsigned int grid = (unsigned int)((total + block - 1) / block);
    moe_bias_add_3d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)x, (const __nv_bfloat16*)bias,
        batch, num_tokens, dim
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_weighted_sum_bf16_kernel on `stream`, one thread per element
// of out, in 256-thread blocks.
//
//   expert_out:    device pointer, [local_experts, num_tokens, hidden] BF16
//   topk_ids:      device pointer, [num_tokens, top_k] int32 (global expert ids)
//   topk_weights:  device pointer, [num_tokens, top_k] float32
//   out:           device pointer, [num_tokens, hidden] BF16 (written)
//   expert_start:  first global expert id owned by this rank
//   local_experts: number of experts owned by this rank
//   stream:        a cudaStream_t passed as an opaque pointer
//
// The launch is asynchronous; only launch errors are checked here.
void launch_moe_weighted_sum_bf16(
    const void* expert_out,
    const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    // Count output elements in 64 bits so the grid size stays correct even
    // when num_tokens * hidden exceeds INT_MAX.
    const long long total = (long long)num_tokens * hidden;
    if (total <= 0) {
        return;  // empty output; a zero-block grid is an invalid launch
    }

    const int block = 256;
    const unsigned int grid = (unsigned int)((total + block - 1) / block);
    moe_weighted_sum_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)expert_out,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k,
        expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
}