Files
xserv/csrc/moe/moe_kernels.cu
Gahow Wang cfbd64d206 cuda: fix int32 overflow in MoE dense kernels; surface launch errors in release
The dense MoE kernels (moe_replicate, moe_bias_add_3d, moe_weighted_sum)
computed total / expert_stride / element indices in int32. gpt-oss prefill
runs the whole prompt through the dense path unchunked (SPARSE_MAX_TOKENS=8),
so local_experts*num_tokens*hidden (and batch*num_tokens*dim, and
local_id*expert_stride) overflow int32 at ~3.6k-23k prefill tokens
(TP-dependent) — well inside the supported context window. The launch then
fails silently because CUDA_CHECK_LAST_ERROR was ((void)0) under NDEBUG, so
the bias / weighted-sum simply never runs and the forward pass is corrupted
with no error reported.

Fix: switch the three kernels and their launchers to long long, mirroring the
(long long) indexing already used in moe_sparse.cu. Also make
CUDA_CHECK_LAST_ERROR always-on — cudaGetLastError does not sync, so the
per-launch host cost is negligible, and a silent launch failure is exactly
the class of bug this one was.

Verified on dash5 (RTX 5090): a direct kernel test at 2.21B elements (>2^31)
for both moe_replicate and moe_bias_add_3d produces correct results with no
launch error; bench-gpt-oss TP=2 holds at 5.9ms TPOT, output unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-07-01 12:37:21 +08:00

255 lines
8.7 KiB
Plaintext

#include <cuda_bf16.h>
#include <float.h>
#include "../common.cuh"
// ============================================================
// MoE Top-K + Softmax kernel
//
// Input:  router_logits [num_tokens, num_experts] BF16
// Output: topk_ids      [num_tokens, top_k]       int32
//         topk_weights  [num_tokens, top_k]       float32
//
// Launch layout: one block per token (blockIdx.x is the token index).
// Threads stride over the experts by blockDim.x to stage the token's logits
// in shared memory as FP32; thread 0 then picks the top-k experts by
// repeated argmax and normalises the k winning logits with a softmax.
//
// Static shared memory: MOE_MAX_EXPERTS floats + MOE_MAX_TOPK ints +
// MOE_MAX_TOPK floats per block.
//
// Preconditions: num_experts <= MOE_MAX_EXPERTS and top_k <= MOE_MAX_TOPK.
// The scratch buffers are statically sized, so a violation would write past
// them; the kernel traps instead of silently corrupting shared memory.
// ============================================================
#define MOE_MAX_EXPERTS 256
#define MOE_MAX_TOPK 8
__global__ void moe_topk_softmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ router_logits,
    int* __restrict__ topk_ids,
    float* __restrict__ topk_weights,
    int num_experts, int top_k
) {
    __shared__ float smem_logits[MOE_MAX_EXPERTS];
    __shared__ int smem_ids[MOE_MAX_TOPK];
    __shared__ float smem_vals[MOE_MAX_TOPK];

    // The condition depends only on kernel arguments, so it is uniform across
    // the block and cannot leave some threads stranded at the barrier below.
    if (num_experts > MOE_MAX_EXPERTS || top_k > MOE_MAX_TOPK) {
        __trap();
    }

    int token = blockIdx.x;
    int tid = threadIdx.x;
    // 64-bit offsets, matching the indexing used by the other kernels here.
    const __nv_bfloat16* row = router_logits + (long long)token * num_experts;
    const long long out_base = (long long)token * top_k;

    // Load logits into shared memory (converted to FP32).
    for (int i = tid; i < num_experts; i += blockDim.x) {
        smem_logits[i] = __bfloat162float(row[i]);
    }
    __syncthreads();

    // Find top-k via repeated argmax (k is small, typically 4).
    if (tid == 0) {
        for (int k = 0; k < top_k; k++) {
            float best_val = -INFINITY;
            int best_idx = 0;
            // Strict '>' means the lowest expert index wins ties.
            for (int e = 0; e < num_experts; e++) {
                if (smem_logits[e] > best_val) {
                    best_val = smem_logits[e];
                    best_idx = e;
                }
            }
            smem_ids[k] = best_idx;
            smem_vals[k] = best_val;
            smem_logits[best_idx] = -INFINITY;  // mask out selected
        }

        // Softmax over the top-k values (in FP32); subtracting the maximum
        // keeps expf() from overflowing.
        float max_val = smem_vals[0];
        for (int k = 1; k < top_k; k++)
            max_val = fmaxf(max_val, smem_vals[k]);
        float exp_sum = 0.0f;
        for (int k = 0; k < top_k; k++) {
            smem_vals[k] = expf(smem_vals[k] - max_val);
            exp_sum += smem_vals[k];
        }
        float inv_sum = 1.0f / exp_sum;
        for (int k = 0; k < top_k; k++)
            smem_vals[k] *= inv_sum;

        // Write outputs for this token.
        for (int k = 0; k < top_k; k++) {
            topk_ids[out_base + k] = smem_ids[k];
            topk_weights[out_base + k] = smem_vals[k];
        }
    }
}
// ============================================================
// MoE Replicate kernel
//
// Input:  x     [num_tokens, hidden]                BF16
// Output: x_rep [local_experts, num_tokens, hidden] BF16
//
// Broadcasts x into every expert's slot:
//   x_rep[expert, token, dim] = x[token, dim]
//
// Launch layout: 1-D grid, one thread per output element; threads past the
// end of x_rep do nothing.
// ============================================================
__global__ void moe_replicate_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,
    __nv_bfloat16* __restrict__ x_rep,
    int num_tokens, int hidden, int local_experts
) {
    // All extents and indices are 64-bit: the output element count
    // local_experts * num_tokens * hidden can exceed INT32_MAX for long
    // prefills, and a wrapped 32-bit count would break the bounds check.
    const long long per_expert = (long long)num_tokens * hidden;     // elements in x
    const long long n_out = (long long)local_experts * per_expert;   // elements in x_rep
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;

    if (gid < n_out) {
        // The position inside an expert's slot is the flat index into x.
        x_rep[gid] = x[gid % per_expert];
    }
}
// ============================================================
// MoE Bias Add 3D kernel
//
// In/out: x    [batch, num_tokens, dim] BF16 (updated in place)
// Input:  bias [batch, dim]             BF16
//
//   x[b, t, d] += bias[b, d]
//
// Launch layout: 1-D grid, one thread per element of x; threads past the
// end of x do nothing. The addition is performed in FP32 and rounded back
// to BF16.
// ============================================================
__global__ void moe_bias_add_3d_bf16_kernel(
    __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ bias,
    int batch, int num_tokens, int dim
) {
    // 64-bit extents: batch * num_tokens * dim can exceed INT32_MAX for long
    // prefills, so nothing here is computed in int.
    const long long slab = (long long)num_tokens * dim;   // elements per batch entry
    const long long n_elems = (long long)batch * slab;
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= n_elems) return;

    // Recover (b, d) from the flat index; the token coordinate is not needed
    // because the bias is shared by every token of a batch entry.
    const long long b = gid / slab;   // < batch
    const long long d = gid % dim;    // < dim
    const long long bias_idx = b * dim + d;

    const float biased = __bfloat162float(x[gid]) + __bfloat162float(bias[bias_idx]);
    x[gid] = __float2bfloat16(biased);
}
// ============================================================
// MoE Weighted Sum kernel
//
// Input:  expert_out   [local_experts, num_tokens, hidden] BF16
//         topk_ids     [num_tokens, top_k] int32   (global expert ids)
//         topk_weights [num_tokens, top_k] float32
//         expert_start:  first global expert id this rank owns
//         local_experts: number of experts this rank owns
//
// Output: out [num_tokens, hidden] BF16
//
// Each thread produces one out[token, dim], accumulating in FP32:
//   sum = 0
//   for k in 0..top_k:
//     local_id = topk_ids[token, k] - expert_start
//     if 0 <= local_id < local_experts:
//       sum += topk_weights[token, k] * expert_out[local_id, token, dim]
//   out[token, dim] = bf16(sum)
//
// Launch layout: 1-D grid, one thread per element of out; threads past the
// end of out do nothing.
// ============================================================
__global__ void moe_weighted_sum_bf16_kernel(
    const __nv_bfloat16* __restrict__ expert_out,
    const int* __restrict__ topk_ids,
    const float* __restrict__ topk_weights,
    __nv_bfloat16* __restrict__ out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    // num_tokens * hidden is both the size of `out` and the distance between
    // consecutive experts in `expert_out`. It is kept in 64 bits so that
    // local_id * plane cannot wrap for long prefills.
    const long long plane = (long long)num_tokens * hidden;
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= plane) return;

    // Routing entries for the token this output element belongs to.
    const long long token = gid / hidden;
    const int* ids = topk_ids + token * top_k;
    const float* weights = topk_weights + token * top_k;

    float acc = 0.0f;
    for (int k = 0; k < top_k; ++k) {
        const int local_id = ids[k] - expert_start;
        // Skip experts owned by other ranks.
        if (local_id < 0 || local_id >= local_experts) continue;
        const float w = weights[k];
        // token * hidden + dim is exactly gid, so the element's offset inside
        // an expert's plane is the flat output index itself.
        const float v = __bfloat162float(expert_out[local_id * plane + gid]);
        acc += w * v;
    }
    out[gid] = __float2bfloat16(acc);
}
extern "C" {
// Launches moe_topk_softmax_bf16_kernel on `stream` with one 128-thread block
// per token.
//
//   router_logits: device pointer, [num_tokens, num_experts] BF16
//   topk_ids:      device pointer, [num_tokens, top_k] int32   (output)
//   topk_weights:  device pointer, [num_tokens, top_k] float32 (output)
//   stream:        a cudaStream_t passed as void*
//
// The launch is asynchronous; only launch-time errors are checked here.
void launch_moe_topk_softmax_bf16(
    const void* router_logits,
    void* topk_ids, void* topk_weights,
    int num_tokens, int num_experts, int top_k,
    void* stream
) {
    // An empty batch has nothing to route. A zero-block grid is an invalid
    // launch configuration, so launching anyway would report an error for a
    // legitimate no-op. Negative counts are deliberately not swallowed.
    if (num_tokens == 0) return;

    const int block = 128;
    moe_topk_softmax_bf16_kernel<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)router_logits,
        (int*)topk_ids, (float*)topk_weights,
        num_experts, top_k
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_replicate_bf16_kernel on `stream`, one thread per element of
// x_rep, in 256-thread blocks.
//
//   x:      device pointer, [num_tokens, hidden] BF16
//   x_rep:  device pointer, [local_experts, num_tokens, hidden] BF16 (output)
//   stream: a cudaStream_t passed as void*
//
// The launch is asynchronous; only launch-time errors are checked here.
void launch_moe_replicate_bf16(
    const void* x, void* x_rep,
    int num_tokens, int hidden, int local_experts,
    void* stream
) {
    // The element count is 64-bit because it can exceed INT32_MAX.
    const long long total = (long long)local_experts * num_tokens * hidden;

    // Nothing to copy. A zero-block grid is an invalid launch configuration,
    // so launching anyway would report an error for a legitimate no-op.
    // Negative extents are deliberately not swallowed.
    if (total == 0) return;

    const int block = 256;
    // Ceil-divide in 64 bits; the block count itself is assumed to fit in int.
    const int grid = (int)((total + block - 1) / block);
    moe_replicate_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep,
        num_tokens, hidden, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_bias_add_3d_bf16_kernel on `stream`, one thread per element of
// x, in 256-thread blocks.
//
//   x:      device pointer, [batch, num_tokens, dim] BF16 (updated in place)
//   bias:   device pointer, [batch, dim] BF16
//   stream: a cudaStream_t passed as void*
//
// The launch is asynchronous; only launch-time errors are checked here.
void launch_moe_bias_add_3d_bf16(
    void* x, const void* bias,
    int batch, int num_tokens, int dim,
    void* stream
) {
    // The element count is 64-bit because it can exceed INT32_MAX.
    const long long total = (long long)batch * num_tokens * dim;

    // Nothing to update. A zero-block grid is an invalid launch
    // configuration, so launching anyway would report an error for a
    // legitimate no-op. Negative extents are deliberately not swallowed.
    if (total == 0) return;

    const int block = 256;
    // Ceil-divide in 64 bits; the block count itself is assumed to fit in int.
    const int grid = (int)((total + block - 1) / block);
    moe_bias_add_3d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)x, (const __nv_bfloat16*)bias,
        batch, num_tokens, dim
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_weighted_sum_bf16_kernel on `stream`, one thread per element
// of out, in 256-thread blocks.
//
//   expert_out:    device pointer, [local_experts, num_tokens, hidden] BF16
//   topk_ids:      device pointer, [num_tokens, top_k] int32 (global expert ids)
//   topk_weights:  device pointer, [num_tokens, top_k] float32
//   out:           device pointer, [num_tokens, hidden] BF16 (output)
//   expert_start:  first global expert id owned by this rank
//   local_experts: number of experts owned by this rank
//   stream:        a cudaStream_t passed as void*
//
// The launch is asynchronous; only launch-time errors are checked here.
void launch_moe_weighted_sum_bf16(
    const void* expert_out,
    const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    // Output element count, computed in 64 bits like the kernel's indices.
    const long long total = (long long)num_tokens * hidden;

    // Nothing to produce. A zero-block grid is an invalid launch
    // configuration, so launching anyway would report an error for a
    // legitimate no-op. Negative extents are deliberately not swallowed.
    if (total == 0) return;

    const int block = 256;
    // Ceil-divide in 64 bits; the block count itself is assumed to fit in int.
    const int grid = (int)((total + block - 1) / block);
    moe_weighted_sum_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)expert_out,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k,
        expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
}