Files
xserv/csrc/moe/moe_sparse.cu
Gahow Wang fb20178992 moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)
Dense MoE replicated x across all 16 local experts and ran the full
batched GEMM, reading every expert's weights per token; the weighted
sum then discarded 12 of 16 results. Decode is memory-bound, so this
was ~8x wasted expert bytes — the entire decode gap vs llama.cpp.

New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read
topk_ids on-device (no host sync) and early-return block-uniformly
for experts other ranks own. FP8 runs W8A16 (activations stay BF16 —
tensor cores are irrelevant at M=1, and activation quantization error
disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the
GEMV epilogue; slot-indexed weighted sum skips (never multiplies)
unwritten non-local slots. Dense path retained for num_tokens > 8
(prefill) and via XSERV_DENSE_MOE=1 for A/B.

dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms.
Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms —
first config where xserv wins decode outright. GSM8K-100: 96% (dense
FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode
CUDA graphs, non-expert quantization, sparse prefill (docs/20).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:29:10 +08:00

255 lines
9.8 KiB
Plaintext

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cstdint>
#include "../common.cuh"
// ============================================================
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
//
// The dense path replicates x across all local experts and runs a
// batched GEMM, reading every expert's weights per token. Decode is
// memory-bound, so reading only the top-k routed experts' weights
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
//
// Each block handles one (token, slot) pair's tile of output columns.
// It reads topk_ids[token, slot] from device memory (no host sync),
// and exits early if the expert is not owned by this rank. The early
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
// and returns before the shared-memory staging + __syncthreads), so
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
//
// Outputs for non-local slots are NEVER written (uninitialized memory,
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
// slots rather than multiply by zero (NaN * 0 = NaN).
//
// Per-expert weight scale and bias are fused into the epilogue:
// y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
// which matches the dense path's GEMM -> moe_bias_add_3d sequence.
//
// Activation addressing (x_per_slot):
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
// down: each slot has its own activation row
// x[token * top_k + slot, :] (x_per_slot=1)
// ============================================================
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
// Sparse W8A16 GEMV: weights FP8 E4M3 [local_experts, N, K], activations
// BF16. Decode is memory-bound (~2 FLOP/byte), so dequant-in-registers
// GEMV loses nothing to tensor cores and skips activation quantization.
//
// Launch layout (see launch_moe_sparse_gemv_fp8_bf16):
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N * 32 threads, one warp per output column
//   dynamic shared memory = K * sizeof(float)
//
// Epilogue: y[token, slot, n] = acc * w_scales[lid] + bias[lid, n].
// Blocks whose routed expert lies outside
// [expert_start, expert_start + local_experts) return without writing y.
//
// Weight rows are read with 16-byte loads when K is a multiple of 16 and
// the row start is 16-byte aligned; otherwise a scalar path is taken so
// that the K tail is not dropped and no misaligned uint4 load is issued.
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,      // [T, K] or [T*top_k, K]
    const __nv_fp8_e4m3* __restrict__ w,      // [local_experts, N, K]
    const float* __restrict__ w_scales,       // [local_experts]
    const __nv_bfloat16* __restrict__ bias,   // [local_experts, N]
    const int* __restrict__ topk_ids,         // [T, top_k] global expert ids
    __nv_bfloat16* __restrict__ y,            // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int token = blockIdx.z;
    const int slot = blockIdx.y;
    // Flat (token, slot) index, formed in 64-bit so the product cannot wrap.
    const long long ts = (long long)token * top_k + slot;

    const int lid = topk_ids[ts] - expert_start;
    // Depends only on blockIdx, so every thread of the block takes the same
    // branch and nobody is left waiting at the barrier below.
    if (lid < 0 || lid >= local_experts) return;

    extern __shared__ float xs[];  // [K] activation row widened to float
    const __nv_bfloat16* xrow =
        x + (x_per_slot ? ts : (long long)token) * K;
    for (int i = threadIdx.x; i < K; i += blockDim.x) {
        xs[i] = __bfloat162float(xrow[i]);
    }
    __syncthreads();

    // threadIdx.x >> 5 is the warp index, so n is identical for all lanes of
    // a warp: the return below retires whole warps, after the barrier.
    const int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
    if (n >= N) return;
    const int lane = threadIdx.x & 31;

    const uint8_t* wrow = (const uint8_t*)w + ((long long)lid * N + n) * K;
    float acc = 0.0f;

    // wrow and K are the same for every lane, so this choice is warp-uniform.
    const bool vec16 =
        ((K & 15) == 0) &&
        ((reinterpret_cast<unsigned long long>(wrow) & 15ull) == 0ull);
    if (vec16) {
        // uint4 = 16 FP8 weights per lane; the warp covers 512 contiguous
        // bytes per iteration (coalesced).
        for (int i = lane; i < (K >> 4); i += 32) {
            uint4 packed = *(const uint4*)(wrow + (long long)i * 16);
            const __nv_fp8_e4m3* pw = (const __nv_fp8_e4m3*)&packed;
            const float* xk = xs + i * 16;
            #pragma unroll
            for (int j = 0; j < 16; j++) {
                acc += xk[j] * float(pw[j]);
            }
        }
    } else {
        // Fallback: one weight byte per lane per step, covering all K elements.
        const __nv_fp8_e4m3* wq = (const __nv_fp8_e4m3*)wrow;
        for (int k = lane; k < K; k += 32) {
            acc += xs[k] * float(wq[k]);
        }
    }

    // Warp tree reduction; lane 0 ends up holding the full dot product.
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, o);
    }
    if (lane == 0) {
        float v = acc * w_scales[lid]
                + __bfloat162float(bias[(long long)lid * N + n]);
        y[ts * N + n] = __float2bfloat16(v);
    }
}
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
// via topk_ids and with fused per-expert bias.
#define MXFP4_BLOCK 32
// Magnitudes selected by the low three bits of a 4-bit code.
__device__ __constant__ float kSparseFp4Levels[8] = {
    0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f
};

// Decodes one 4-bit code: bits 0-2 index kSparseFp4Levels, bit 3 negates
// the result. Bits above bit 3 are ignored.
__device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
    const float magnitude = kSparseFp4Levels[code & 0x7];
    if (code & 0x8) {
        return -magnitude;
    }
    return magnitude;
}
// Sparse W4A16 GEMV over packed 4-bit weights with one scale byte per
// MXFP4_BLOCK (32) weights.
//
// Each block serves one (token, slot) pair (blockIdx.z, blockIdx.y) and
// SPARSE_TILE_N output columns, one warp per column. The routed expert is
// read from topk_ids; blocks whose expert is outside
// [expert_start, expert_start + local_experts) return without writing y.
//
// Weight decoding, per 32-element block:
//   scale byte e  -> 2^(e - 127)
//   packed byte b -> low nibble is element 2*i, high nibble is element 2*i+1
// Epilogue: y[token, slot, n] = acc + bias[lid, n].
//
// Only K / 32 full blocks are accumulated, and each block is fetched with
// one 16-byte load. NOTE(review): this presumes K % 32 == 0 and 16-byte
// aligned packed rows; confirm against callers.
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,      // [T, K] or [T*top_k, K]
    const uint8_t* __restrict__ w_packed,     // [local_experts, N, K/2]
    const uint8_t* __restrict__ w_scales,     // [local_experts, N, K/32]
    const __nv_bfloat16* __restrict__ bias,   // [local_experts, N]
    const int* __restrict__ topk_ids,         // [T, top_k]
    __nv_bfloat16* __restrict__ y,            // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int tok = blockIdx.z;
    const int slt = blockIdx.y;
    const int flat_slot = tok * top_k + slt;

    // Translate the global expert id into this rank's local index.
    const int lid = topk_ids[flat_slot] - expert_start;
    const bool owned = (lid >= 0) && (lid < local_experts);
    if (!owned) return;  // same decision for the whole block

    // Stage the activation row in shared memory as float.
    extern __shared__ float xs[];
    const int act_row = x_per_slot ? flat_slot : tok;
    const __nv_bfloat16* act = x + (long long)act_row * K;
    for (int k = threadIdx.x; k < K; k += blockDim.x) {
        xs[k] = __bfloat162float(act[k]);
    }
    __syncthreads();

    // Warp index picks the output column; lanes split the K blocks.
    const int col = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
    if (col >= N) return;
    const int lane_id = threadIdx.x & 31;

    const int blocks_per_row = K / MXFP4_BLOCK;
    const long long row = (long long)lid * N + col;
    const uint8_t* nibbles = w_packed + row * (K >> 1);
    const uint8_t* scales = w_scales + row * blocks_per_row;

    float acc = 0.0f;
    for (int b = lane_id; b < blocks_per_row; b += 32) {
        const float scale = exp2f((float)((int)scales[b] - 127));
        // One 16-byte load brings in the 32 nibbles of this block.
        const uint4 chunk =
            *reinterpret_cast<const uint4*>(nibbles + (long long)b * 16);
        const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&chunk);
        const float* xblk = xs + b * MXFP4_BLOCK;
        #pragma unroll
        for (int p = 0; p < 16; ++p) {
            const uint8_t pair = bytes[p];
            acc += xblk[2 * p] * (sparse_fp4_to_float(pair & 0xF) * scale);
            acc += xblk[2 * p + 1] * (sparse_fp4_to_float(pair >> 4) * scale);
        }
    }

    // Fold the 32 per-lane partial sums down into lane 0.
    #pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2) {
        acc += __shfl_down_sync(0xffffffffu, acc, delta);
    }

    if (lane_id != 0) return;
    const float result =
        acc + __bfloat162float(bias[(long long)lid * N + col]);
    y[((long long)tok * top_k + slt) * N + col] = __float2bfloat16(result);
}
// Combines per-slot expert outputs into one row per token:
//   out[t, d] = sum over k with a local expert of
//               topk_weights[t, k] * down[t, k, d]
// A slot is local when topk_ids[t, k] - expert_start lies in
// [0, local_experts). Non-local slots are never read: they hold
// uninitialized memory, so they are skipped rather than multiplied by zero.
// One thread produces one out[t, d] element; threads past
// num_tokens * hidden do nothing.
__global__ void moe_weighted_sum_sparse_bf16_kernel(
    const __nv_bfloat16* __restrict__ down,   // [T, top_k, hidden]
    const int* __restrict__ topk_ids,         // [T, top_k]
    const float* __restrict__ topk_weights,   // [T, top_k]
    __nv_bfloat16* __restrict__ out,          // [T, hidden]
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    if (flat >= num_tokens * hidden) return;

    const int t = flat / hidden;
    const int d = flat - t * hidden;  // column within the token's row

    float acc = 0.0f;
    for (int k = 0; k < top_k; ++k) {
        const int route = t * top_k + k;
        const int lid = topk_ids[route] - expert_start;
        if (lid < 0 || lid >= local_experts) continue;  // not ours: skip

        const float weight = topk_weights[route];
        const float value = __bfloat162float(
            down[((long long)t * top_k + k) * hidden + d]);
        acc += weight * value;
    }
    out[flat] = __float2bfloat16(acc);
}
extern "C" {
// Launches moe_sparse_gemv_fp8_bf16_kernel on `stream`.
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N warps (one warp per output column)
//   dynamic shared memory = K floats (the staged activation row)
// All pointers are passed through to the kernel untouched. The call is a
// no-op when num_tokens, top_k or N is not positive.
void launch_moe_sparse_gemv_fp8_bf16(
    const void* x, const void* w, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    // Nothing to compute. A zero-sized grid dimension would otherwise be
    // rejected as an invalid launch configuration.
    if (num_tokens <= 0 || top_k <= 0 || N <= 0) return;

    dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    int block = SPARSE_TILE_N * 32;
    size_t smem = (size_t)K * sizeof(float);

    // More than 48 KB of dynamic shared memory per block requires an
    // explicit per-kernel opt-in. The return value is not checked here: if
    // the request is refused, the launch below fails and is reported by
    // CUDA_CHECK_LAST_ERROR.
    if (smem > (size_t)48 * 1024) {
        (void)cudaFuncSetAttribute(
            moe_sparse_gemv_fp8_bf16_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
    }

    moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
        (const float*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_sparse_gemv_mxfp4_bf16_kernel on `stream`.
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N warps (one warp per output column)
//   dynamic shared memory = K floats (the staged activation row)
// All pointers are passed through to the kernel untouched. The call is a
// no-op when num_tokens, top_k or N is not positive.
void launch_moe_sparse_gemv_mxfp4_bf16(
    const void* x, const void* w_packed, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    // Nothing to compute. A zero-sized grid dimension would otherwise be
    // rejected as an invalid launch configuration.
    if (num_tokens <= 0 || top_k <= 0 || N <= 0) return;

    dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    int block = SPARSE_TILE_N * 32;
    size_t smem = (size_t)K * sizeof(float);

    // More than 48 KB of dynamic shared memory per block requires an
    // explicit per-kernel opt-in. The return value is not checked here: if
    // the request is refused, the launch below fails and is reported by
    // CUDA_CHECK_LAST_ERROR.
    if (smem > (size_t)48 * 1024) {
        (void)cudaFuncSetAttribute(
            moe_sparse_gemv_mxfp4_bf16_kernel,
            cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
    }

    moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const uint8_t*)w_packed,
        (const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}
// Launches moe_weighted_sum_sparse_bf16_kernel on `stream` with one thread
// per output element (num_tokens * hidden in total, 256 threads per block).
// All pointers are passed through to the kernel untouched. The call is a
// no-op when there are no output elements.
void launch_moe_weighted_sum_sparse_bf16(
    const void* down, const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    // An empty output would yield grid == 0, which is an invalid launch
    // configuration; there is nothing to write, so return early.
    if (num_tokens <= 0 || hidden <= 0) return;

    int total = num_tokens * hidden;
    int block = 256;
    int grid = (total + block - 1) / block;  // ceil-div over output elements
    moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)down,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k, expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}
}