The dense MoE kernels (moe_replicate, moe_bias_add_3d, moe_weighted_sum) computed total / expert_stride / element indices in int32. gpt-oss prefill runs the whole prompt through the dense path unchunked (SPARSE_MAX_TOKENS=8), so local_experts*num_tokens*hidden (and batch*num_tokens*dim, and local_id*expert_stride) overflow int32 at ~3.6k-23k prefill tokens (TP-dependent) — well inside the supported context window. The launch then fails silently because CUDA_CHECK_LAST_ERROR was ((void)0) under NDEBUG, so the bias / weighted-sum simply never runs and the forward pass is corrupted with no error reported. Fix: switch the three kernels and their launchers to long long, mirroring the (long long) indexing already used in moe_sparse.cu. Also make CUDA_CHECK_LAST_ERROR always-on — cudaGetLastError does not sync, so the per-launch host cost is negligible, and a silent launch failure is exactly the class of bug this one was. Verified on dash5 (RTX 5090): a direct kernel test at 2.21B elements (>2^31) for both moe_replicate and moe_bias_add_3d produces correct results with no launch error; bench-gpt-oss TP=2 holds at 5.9ms TPOT, output unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
255 lines
8.7 KiB
Plaintext
255 lines
8.7 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <float.h>
|
|
#include "../common.cuh"
|
|
|
|
// ============================================================
// MoE Top-K + Softmax kernel
//
// Inputs:
//   router_logits [num_tokens, num_experts]  BF16
// Outputs:
//   topk_ids      [num_tokens, top_k]        int32
//   topk_weights  [num_tokens, top_k]        float32
//
// Launch layout: one block per token (blockIdx.x is the token index).
// Every thread of the block helps stage the token's logits into shared
// memory as FP32. Thread 0 then selects the top_k experts by repeated
// argmax (masking each winner with -INFINITY) and turns the winning
// logits into weights with an FP32 softmax.
//
// Preconditions (not checked in the kernel): num_experts <= MOE_MAX_EXPERTS
// and top_k <= MOE_MAX_TOPK, because both size the static shared buffers.
// ============================================================

#define MOE_MAX_EXPERTS 256
#define MOE_MAX_TOPK 8

__global__ void moe_topk_softmax_bf16_kernel(
    const __nv_bfloat16* __restrict__ router_logits,
    int* __restrict__ topk_ids,
    float* __restrict__ topk_weights,
    int num_experts, int top_k
) {
    __shared__ float logits_f32[MOE_MAX_EXPERTS];
    __shared__ int picked_ids[MOE_MAX_TOPK];
    __shared__ float picked_vals[MOE_MAX_TOPK];

    const int token = blockIdx.x;
    const int lane = threadIdx.x;
    const __nv_bfloat16* token_logits = router_logits + token * num_experts;

    // Cooperative BF16 -> FP32 staging of this token's logits.
    for (int e = lane; e < num_experts; e += blockDim.x) {
        logits_f32[e] = __bfloat162float(token_logits[e]);
    }
    __syncthreads();

    // Selection and normalization are serial; only thread 0 continues.
    // No barrier follows, so the other threads may leave here.
    if (lane != 0) return;

    // Repeated argmax: the strict '>' keeps the lowest index on ties.
    for (int slot = 0; slot < top_k; ++slot) {
        int winner = 0;
        float winner_val = -INFINITY;
        for (int e = 0; e < num_experts; ++e) {
            const float candidate = logits_f32[e];
            if (candidate > winner_val) {
                winner_val = candidate;
                winner = e;
            }
        }
        picked_ids[slot] = winner;
        picked_vals[slot] = winner_val;
        logits_f32[winner] = -INFINITY;  // exclude from later rounds
    }

    // Numerically stable softmax over the selected logits (FP32).
    float peak = picked_vals[0];
    for (int slot = 1; slot < top_k; ++slot) {
        peak = fmaxf(peak, picked_vals[slot]);
    }

    float denom = 0.0f;
    for (int slot = 0; slot < top_k; ++slot) {
        picked_vals[slot] = expf(picked_vals[slot] - peak);
        denom += picked_vals[slot];
    }
    const float inv_denom = 1.0f / denom;

    // Normalize and emit this token's row of ids / weights.
    int* ids_row = topk_ids + token * top_k;
    float* weights_row = topk_weights + token * top_k;
    for (int slot = 0; slot < top_k; ++slot) {
        picked_vals[slot] *= inv_denom;
        ids_row[slot] = picked_ids[slot];
        weights_row[slot] = picked_vals[slot];
    }
}
|
|
|
|
// ============================================================
// MoE Replicate kernel
//
// Input:  x     [num_tokens, hidden]                 BF16
// Output: x_rep [local_experts, num_tokens, hidden]  BF16
//
// x_rep[expert, token, dim] = x[token, dim] for every local expert.
//
// Launch layout: 1-D grid, one thread per output element; threads past
// the end of x_rep exit immediately.
// ============================================================

__global__ void moe_replicate_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,
    __nv_bfloat16* __restrict__ x_rep,
    int num_tokens, int hidden, int local_experts
) {
    // All element counts and indices are 64-bit: the product
    // local_experts * num_tokens * hidden can exceed INT32_MAX for long
    // prefills, and a wrapped 32-bit value would corrupt the bounds check.
    const long long per_expert = (long long)num_tokens * hidden;
    const long long out_elems = (long long)local_experts * num_tokens * hidden;
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;

    if (gid < out_elems) {
        // Position inside one expert's slab is the matching position in x.
        x_rep[gid] = x[gid % per_expert];
    }
}
|
|
|
|
// ============================================================
// MoE Bias Add 3D kernel
//
// In/out: x    [batch, num_tokens, dim]  BF16 (updated in place)
// Input:  bias [batch, dim]              BF16
//
// x[b, t, d] += bias[b, d], with the addition done in FP32 and the
// result rounded back to BF16.
//
// Launch layout: 1-D grid, one thread per element of x; threads past
// the end exit immediately.
// ============================================================

__global__ void moe_bias_add_3d_bf16_kernel(
    __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ bias,
    int batch, int num_tokens, int dim
) {
    // 64-bit throughout: batch * num_tokens * dim can exceed INT32_MAX
    // for long prefills.
    const long long per_batch = (long long)num_tokens * dim;
    const long long n_elems = (long long)batch * num_tokens * dim;
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= n_elems) return;

    // Recover the batch index and the innermost (dim) index of this element;
    // the token index is not needed because bias is shared across tokens.
    const long long b = gid / per_batch;
    const long long d = gid % dim;

    const float biased = __bfloat162float(x[gid]) + __bfloat162float(bias[b * dim + d]);
    x[gid] = __float2bfloat16(biased);
}
|
|
|
|
// ============================================================
// MoE Weighted Sum kernel
//
// Inputs:
//   expert_out    [local_experts, num_tokens, hidden]  BF16
//   topk_ids      [num_tokens, top_k]                  int32 (global expert ids)
//   topk_weights  [num_tokens, top_k]                  float32
//   expert_start  first global expert id owned by this rank
//   local_experts number of experts owned by this rank
//
// Output:
//   out           [num_tokens, hidden]                 BF16
//
// For each (token, dim), accumulated in FP32:
//   out[token, dim] = bf16( sum over k of
//       topk_weights[token, k] * expert_out[local_id, token, dim] )
//   where local_id = topk_ids[token, k] - expert_start, and only terms
//   with 0 <= local_id < local_experts contribute.
//
// Launch layout: 1-D grid, one thread per element of out; threads past
// the end exit immediately.
// ============================================================

__global__ void moe_weighted_sum_bf16_kernel(
    const __nv_bfloat16* __restrict__ expert_out,
    const int* __restrict__ topk_ids,
    const float* __restrict__ topk_weights,
    __nv_bfloat16* __restrict__ out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    // One expert's slab of expert_out has the same element count as out,
    // so this value is both the bounds limit and the stride between experts.
    // Kept 64-bit so local_id * slab cannot wrap for long prefills.
    const long long slab = (long long)num_tokens * hidden;
    const long long gid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= slab) return;

    // Row of routing decisions for the token that owns this element.
    const long long tok = gid / hidden;
    const int* tok_ids = topk_ids + tok * top_k;
    const float* tok_weights = topk_weights + tok * top_k;

    float acc = 0.0f;
    for (int slot = 0; slot < top_k; ++slot) {
        const int local_id = tok_ids[slot] - expert_start;
        if (local_id < 0 || local_id >= local_experts) continue;  // expert not on this rank

        const float w = tok_weights[slot];
        // gid == tok * hidden + dim, i.e. the offset inside the expert's slab.
        const float v = __bfloat162float(expert_out[local_id * slab + gid]);
        acc += w * v;
    }
    out[gid] = __float2bfloat16(acc);
}
|
|
|
|
|
|
extern "C" {
|
|
|
|
// Host launcher for moe_topk_softmax_bf16_kernel.
//
//   router_logits  device pointer, [num_tokens, num_experts] BF16
//   topk_ids       device pointer, [num_tokens, top_k] int32   (written)
//   topk_weights   device pointer, [num_tokens, top_k] float32 (written)
//   stream         cudaStream_t passed as void*; the launch is enqueued on it
//
// Launches one 128-thread block per token. The kernel's static shared
// buffers require num_experts <= MOE_MAX_EXPERTS and top_k <= MOE_MAX_TOPK;
// this launcher does not validate those bounds.
void launch_moe_topk_softmax_bf16(
    const void* router_logits,
    void* topk_ids, void* topk_weights,
    int num_tokens, int num_experts, int top_k,
    void* stream
) {
    // Nothing to write: a zero-sized grid is an invalid launch configuration
    // and would be reported as an error, so skip the launch entirely.
    if (num_tokens <= 0 || top_k <= 0) return;

    const int block = 128;
    moe_topk_softmax_bf16_kernel<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)router_logits,
        (int*)topk_ids, (float*)topk_weights,
        num_experts, top_k
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for moe_replicate_bf16_kernel.
//
//   x       device pointer, [num_tokens, hidden] BF16
//   x_rep   device pointer, [local_experts, num_tokens, hidden] BF16 (written)
//   stream  cudaStream_t passed as void*; the launch is enqueued on it
//
// Launches one thread per output element in 256-thread blocks.
void launch_moe_replicate_bf16(
    const void* x, void* x_rep,
    int num_tokens, int hidden, int local_experts,
    void* stream
) {
    // Element count in 64-bit: the product can exceed INT32_MAX for long prefills.
    const long long total = (long long)local_experts * num_tokens * hidden;

    // Empty output: a zero-sized grid is an invalid launch configuration
    // and would be reported as an error, so skip the launch entirely.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);  // ceil-div
    moe_replicate_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep,
        num_tokens, hidden, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for moe_bias_add_3d_bf16_kernel.
//
//   x       device pointer, [batch, num_tokens, dim] BF16 (updated in place)
//   bias    device pointer, [batch, dim] BF16
//   stream  cudaStream_t passed as void*; the launch is enqueued on it
//
// Launches one thread per element of x in 256-thread blocks.
void launch_moe_bias_add_3d_bf16(
    void* x, const void* bias,
    int batch, int num_tokens, int dim,
    void* stream
) {
    // Element count in 64-bit: the product can exceed INT32_MAX for long prefills.
    const long long total = (long long)batch * num_tokens * dim;

    // Empty tensor: a zero-sized grid is an invalid launch configuration
    // and would be reported as an error, so skip the launch entirely.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);  // ceil-div
    moe_bias_add_3d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)x, (const __nv_bfloat16*)bias,
        batch, num_tokens, dim
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
// Host launcher for moe_weighted_sum_bf16_kernel.
//
//   expert_out    device pointer, [local_experts, num_tokens, hidden] BF16
//   topk_ids      device pointer, [num_tokens, top_k] int32 (global expert ids)
//   topk_weights  device pointer, [num_tokens, top_k] float32
//   out           device pointer, [num_tokens, hidden] BF16 (written)
//   expert_start  first global expert id owned by this rank
//   local_experts number of experts owned by this rank
//   stream        cudaStream_t passed as void*; the launch is enqueued on it
//
// Launches one thread per element of out in 256-thread blocks.
void launch_moe_weighted_sum_bf16(
    const void* expert_out,
    const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    const long long total = (long long)num_tokens * hidden;

    // Empty output: a zero-sized grid is an invalid launch configuration
    // and would be reported as an error, so skip the launch entirely.
    if (total <= 0) return;

    const int block = 256;
    const int grid = (int)((total + block - 1) / block);  // ceil-div
    moe_weighted_sum_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)expert_out,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k,
        expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
|
|
|
|
}
|