model: fused GPU MoE kernel — eliminate CPU roundtrip
Replace the per-token CPU-routed MoE forward with an all-GPU path: 1. moe_topk_softmax: GPU top-k + softmax (was CPU sort + softmax) 2. moe_replicate: broadcast input to all local experts 3. cublasGemmStridedBatchedEx: batched expert matmul (was per-expert cuBLAS) 4. moe_weighted_sum: FP32-accumulated weighted sum on GPU (was GPU→CPU→F32→BF16→GPU) Expert weights stored as contiguous 3D tensors for strided batched GEMM. Zero CPU↔GPU transfers per MoE layer (was ~40 per token per layer). Also: configurable geglu_alpha, LayerNorm bias auto-detect, unused-weight diagnostic at load time. GSM8K 30-problem: 11/30 → 23/30 (76.7%) vs llama.cpp 30/30 (100%). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
247
csrc/moe/moe_kernels.cu
Normal file
247
csrc/moe/moe_kernels.cu
Normal file
@@ -0,0 +1,247 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <float.h>
|
||||
#include "../common.cuh"
|
||||
|
||||
// ============================================================
|
||||
// MoE Top-K + Softmax kernel
|
||||
//
|
||||
// Input: router_logits [num_tokens, num_experts] BF16
|
||||
// Output: topk_ids [num_tokens, top_k] int32
|
||||
// topk_weights [num_tokens, top_k] float32
|
||||
//
|
||||
// One block per token. Threads cooperatively find top-k indices
|
||||
// via repeated argmax, then compute softmax over the k winners.
|
||||
// num_experts <= 256 (fits in registers / shared memory).
|
||||
// ============================================================
|
||||
|
||||
#define MOE_MAX_EXPERTS 256
|
||||
#define MOE_MAX_TOPK 8
|
||||
|
||||
__global__ void moe_topk_softmax_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ router_logits,
|
||||
int* __restrict__ topk_ids,
|
||||
float* __restrict__ topk_weights,
|
||||
int num_experts, int top_k
|
||||
) {
|
||||
int token = blockIdx.x;
|
||||
int tid = threadIdx.x;
|
||||
const __nv_bfloat16* row = router_logits + token * num_experts;
|
||||
|
||||
// Load logits into shared memory
|
||||
__shared__ float smem_logits[MOE_MAX_EXPERTS];
|
||||
__shared__ int smem_ids[MOE_MAX_TOPK];
|
||||
__shared__ float smem_vals[MOE_MAX_TOPK];
|
||||
|
||||
for (int i = tid; i < num_experts; i += blockDim.x) {
|
||||
smem_logits[i] = __bfloat162float(row[i]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Find top-k via repeated argmax (k is small, typically 4)
|
||||
if (tid == 0) {
|
||||
for (int k = 0; k < top_k; k++) {
|
||||
float best_val = -INFINITY;
|
||||
int best_idx = 0;
|
||||
for (int e = 0; e < num_experts; e++) {
|
||||
if (smem_logits[e] > best_val) {
|
||||
best_val = smem_logits[e];
|
||||
best_idx = e;
|
||||
}
|
||||
}
|
||||
smem_ids[k] = best_idx;
|
||||
smem_vals[k] = best_val;
|
||||
smem_logits[best_idx] = -INFINITY; // mask out selected
|
||||
}
|
||||
|
||||
// Softmax over top-k values (in FP32)
|
||||
float max_val = smem_vals[0];
|
||||
for (int k = 1; k < top_k; k++)
|
||||
max_val = fmaxf(max_val, smem_vals[k]);
|
||||
|
||||
float exp_sum = 0.0f;
|
||||
for (int k = 0; k < top_k; k++) {
|
||||
smem_vals[k] = expf(smem_vals[k] - max_val);
|
||||
exp_sum += smem_vals[k];
|
||||
}
|
||||
float inv_sum = 1.0f / exp_sum;
|
||||
for (int k = 0; k < top_k; k++)
|
||||
smem_vals[k] *= inv_sum;
|
||||
|
||||
// Write outputs
|
||||
for (int k = 0; k < top_k; k++) {
|
||||
topk_ids[token * top_k + k] = smem_ids[k];
|
||||
topk_weights[token * top_k + k] = smem_vals[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MoE Replicate kernel
|
||||
//
|
||||
// Input: x [num_tokens, hidden] BF16
|
||||
// Output: x_rep [local_experts, num_tokens, hidden] BF16
|
||||
//
|
||||
// Copies x into each expert's batch slot.
|
||||
// ============================================================
|
||||
|
||||
__global__ void moe_replicate_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
__nv_bfloat16* __restrict__ x_rep,
|
||||
int num_tokens, int hidden, int local_experts
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = local_experts * num_tokens * hidden;
|
||||
if (idx >= total) return;
|
||||
|
||||
int expert = idx / (num_tokens * hidden);
|
||||
int remainder = idx % (num_tokens * hidden);
|
||||
// x_rep[expert, token, dim] = x[token, dim]
|
||||
x_rep[idx] = x[remainder];
|
||||
(void)expert; // suppress unused warning
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MoE Bias Add 3D kernel
|
||||
//
|
||||
// Input: x [batch, num_tokens, dim] BF16 (in-place output)
|
||||
// bias [batch, dim] BF16
|
||||
//
|
||||
// x[b, t, d] += bias[b, d]
|
||||
// ============================================================
|
||||
|
||||
__global__ void moe_bias_add_3d_bf16_kernel(
|
||||
__nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ bias,
|
||||
int batch, int num_tokens, int dim
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = batch * num_tokens * dim;
|
||||
if (idx >= total) return;
|
||||
|
||||
int b = idx / (num_tokens * dim);
|
||||
int d = idx % dim;
|
||||
float v = __bfloat162float(x[idx]) + __bfloat162float(bias[b * dim + d]);
|
||||
x[idx] = __float2bfloat16(v);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// MoE Weighted Sum kernel
|
||||
//
|
||||
// Input: expert_out [local_experts, num_tokens, hidden] BF16
|
||||
// topk_ids [num_tokens, top_k] int32 (global expert ids)
|
||||
// topk_weights[num_tokens, top_k] float32
|
||||
// expert_start: first global expert id this rank owns
|
||||
// local_experts: number of experts this rank owns
|
||||
//
|
||||
// Output: out [num_tokens, hidden] BF16
|
||||
//
|
||||
// For each (token, dim): accumulate in FP32:
|
||||
// sum = 0
|
||||
// for k in 0..top_k:
|
||||
// global_id = topk_ids[token, k]
|
||||
// if global_id in [expert_start, expert_start + local_experts):
|
||||
// local_id = global_id - expert_start
|
||||
// sum += topk_weights[token, k] * expert_out[local_id, token, dim]
|
||||
// out[token, dim] = bf16(sum)
|
||||
// ============================================================
|
||||
|
||||
__global__ void moe_weighted_sum_bf16_kernel(
|
||||
const __nv_bfloat16* __restrict__ expert_out,
|
||||
const int* __restrict__ topk_ids,
|
||||
const float* __restrict__ topk_weights,
|
||||
__nv_bfloat16* __restrict__ out,
|
||||
int num_tokens, int hidden, int top_k,
|
||||
int expert_start, int local_experts
|
||||
) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int total = num_tokens * hidden;
|
||||
if (idx >= total) return;
|
||||
|
||||
int token = idx / hidden;
|
||||
int dim = idx % hidden;
|
||||
|
||||
int expert_stride = num_tokens * hidden; // stride between experts in expert_out
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < top_k; k++) {
|
||||
int global_id = topk_ids[token * top_k + k];
|
||||
int local_id = global_id - expert_start;
|
||||
if (local_id >= 0 && local_id < local_experts) {
|
||||
float w = topk_weights[token * top_k + k];
|
||||
float v = __bfloat162float(expert_out[local_id * expert_stride + token * hidden + dim]);
|
||||
sum += w * v;
|
||||
}
|
||||
}
|
||||
out[idx] = __float2bfloat16(sum);
|
||||
}
|
||||
|
||||
|
||||
extern "C" {
|
||||
|
||||
void launch_moe_topk_softmax_bf16(
|
||||
const void* router_logits,
|
||||
void* topk_ids, void* topk_weights,
|
||||
int num_tokens, int num_experts, int top_k,
|
||||
void* stream
|
||||
) {
|
||||
int block = 128;
|
||||
moe_topk_softmax_bf16_kernel<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)router_logits,
|
||||
(int*)topk_ids, (float*)topk_weights,
|
||||
num_experts, top_k
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_moe_replicate_bf16(
|
||||
const void* x, void* x_rep,
|
||||
int num_tokens, int hidden, int local_experts,
|
||||
void* stream
|
||||
) {
|
||||
int total = local_experts * num_tokens * hidden;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
moe_replicate_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep,
|
||||
num_tokens, hidden, local_experts
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_moe_bias_add_3d_bf16(
|
||||
void* x, const void* bias,
|
||||
int batch, int num_tokens, int dim,
|
||||
void* stream
|
||||
) {
|
||||
int total = batch * num_tokens * dim;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
moe_bias_add_3d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(__nv_bfloat16*)x, (const __nv_bfloat16*)bias,
|
||||
batch, num_tokens, dim
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
void launch_moe_weighted_sum_bf16(
|
||||
const void* expert_out,
|
||||
const void* topk_ids, const void* topk_weights,
|
||||
void* out,
|
||||
int num_tokens, int hidden, int top_k,
|
||||
int expert_start, int local_experts,
|
||||
void* stream
|
||||
) {
|
||||
int total = num_tokens * hidden;
|
||||
int block = 256;
|
||||
int grid = (total + block - 1) / block;
|
||||
moe_weighted_sum_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
|
||||
(const __nv_bfloat16*)expert_out,
|
||||
(const int*)topk_ids, (const float*)topk_weights,
|
||||
(__nv_bfloat16*)out,
|
||||
num_tokens, hidden, top_k,
|
||||
expert_start, local_experts
|
||||
);
|
||||
CUDA_CHECK_LAST_ERROR();
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user