Dense MoE replicated x across all 16 local experts and ran the full batched GEMM, reading every expert's weights per token; the weighted sum then discarded 12 of 16 results. Decode is memory-bound, so this was ~8x wasted expert bytes — the entire decode gap vs llama.cpp. New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read topk_ids on-device (no host sync) and early-return block-uniformly for experts other ranks own. FP8 runs W8A16 (activations stay BF16 — tensor cores are irrelevant at M=1, and activation quantization error disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the GEMV epilogue; slot-indexed weighted sum skips (never multiplies) unwritten non-local slots. Dense path retained for num_tokens > 8 (prefill) and via XSERV_DENSE_MOE=1 for A/B. dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms. Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms — first config where xserv wins decode outright. GSM8K-100: 96% (dense FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode CUDA graphs, non-expert quantization, sparse prefill (docs/20). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
255 lines
9.8 KiB
Plaintext
255 lines
9.8 KiB
Plaintext
#include <cuda_bf16.h>
|
|
#include <cuda_fp8.h>
|
|
#include <cstdint>
|
|
#include "../common.cuh"
|
|
|
|
// ============================================================
|
|
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
|
|
//
|
|
// The dense path replicates x across all local experts and runs a
|
|
// batched GEMM, reading every expert's weights per token. Decode is
|
|
// memory-bound, so reading only the top-k routed experts' weights
|
|
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
|
|
//
|
|
// Each block handles one (token, slot) pair's tile of output columns.
|
|
// It reads topk_ids[token, slot] from device memory (no host sync),
|
|
// and exits early if the expert is not owned by this rank. The early
|
|
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
|
|
// and returns before the shared-memory staging + __syncthreads), so
|
|
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
|
|
//
|
|
// Outputs for non-local slots are NEVER written (uninitialized memory,
|
|
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
|
|
// slots rather than multiply by zero (NaN * 0 = NaN).
|
|
//
|
|
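// Per-expert weight scale and bias are fused into the epilogue:
//   y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
// which matches the dense path's GEMM -> moe_bias_add_3d sequence. This is
// the FP8 form; the MXFP4 kernel has no per-expert scale and instead applies
// its per-32-weight block scales during accumulation, so its epilogue only
// adds the bias.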
|
|
//
|
|
// Activation addressing (x_per_slot):
|
|
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
|
|
// down: each slot has its own activation row
|
|
// x[token * top_k + slot, :] (x_per_slot=1)
|
|
// ============================================================
|
|
|
|
// Output columns computed per block. Each column is owned by one 32-lane
// warp (n = blockIdx.x * SPARSE_TILE_N + warp index), so the launchers use
// SPARSE_TILE_N * 32 threads per block and ceil(N / SPARSE_TILE_N) blocks
// along grid.x.
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
|
|
|
|
// Weights FP8 E4M3 [local_experts, N, K], activations BF16 (W8A16).
// Decode is memory-bound (~2 FLOP/byte), so dequant-in-registers GEMV
// loses nothing to tensor cores and skips activation quantization.
//
// Launch layout (see launch_moe_sparse_gemv_fp8_bf16):
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N warps; warp w computes output column
//           n = blockIdx.x * SPARSE_TILE_N + w
//   dynamic shared memory = K * sizeof(float) (staged activation row)
//
// Fast path (K % 16 == 0): each lane loads 16 FP8 weights as one uint4, so
// a warp streams 512 contiguous bytes per iteration. The staged activation
// row is stored XOR-swizzled inside each 16-float chunk: element j of chunk
// c lives at xs[16*c + (j ^ ((c >> 1) & 15))]. With the natural layout,
// lane c reading xs[16*c + j] touches only two distinct banks per step
// (16-way bank conflict); with the swizzle all 32 lanes hit distinct banks.
// The swizzle only moves storage; the accumulation order is unchanged.
//
// Fallback (K % 16 != 0): weight rows are then not 16-byte aligned, so the
// uint4 loads would be misaligned and the K tail would be dropped. A scalar
// one-byte-per-lane loop over the natural (unswizzled) layout is used.
//
// bias may be nullptr, in which case no bias is added. Outputs for slots
// whose expert is not owned by this rank are not written.
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,       // [T, K] or [T*top_k, K]
    const __nv_fp8_e4m3* __restrict__ w,       // [local_experts, N, K]
    const float* __restrict__ w_scales,        // [local_experts]
    const __nv_bfloat16* __restrict__ bias,    // [local_experts, N] or nullptr
    const int* __restrict__ topk_ids,          // [T, top_k] global expert ids
    __nv_bfloat16* __restrict__ y,             // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int token = blockIdx.z;
    const int slot = blockIdx.y;
    const int lid = topk_ids[token * top_k + slot] - expert_start;
    // Every thread reads the same topk_ids entry, so this return is
    // block-uniform and happens before the __syncthreads below.
    if (lid < 0 || lid >= local_experts) return;

    // Block-uniform choice between the vectorized and the scalar path.
    const bool vec = (K & 15) == 0;

    extern __shared__ float xs[];   // [K] activation row as float
    const __nv_bfloat16* xrow =
        x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
    for (int k = threadIdx.x; k < K; k += blockDim.x) {
        // Swizzle within the 16-float chunk (k >> 4); see header comment.
        const int dst = vec ? ((k & ~15) | ((k & 15) ^ ((k >> 5) & 15))) : k;
        xs[dst] = __bfloat162float(xrow[k]);
    }
    __syncthreads();

    const int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
    if (n >= N) return;   // warp-uniform and after the barrier: safe
    const int lane = threadIdx.x & 31;

    const uint8_t* wrow = (const uint8_t*)w + ((long long)lid * N + n) * K;
    float acc = 0.0f;
    if (vec) {
        // One warp per output column; uint4 = 16 FP8 weights per lane, the
        // warp covers 512 contiguous bytes per iteration (coalesced).
        for (int c = lane; c < (K >> 4); c += 32) {
            const uint4 packed = *(const uint4*)(wrow + (long long)c * 16);
            const __nv_fp8_e4m3* pw = (const __nv_fp8_e4m3*)&packed;
            const float* xk = xs + c * 16;
            const int sw = (c >> 1) & 15;   // must match the staging swizzle
            #pragma unroll
            for (int j = 0; j < 16; j++) {
                acc += xk[j ^ sw] * float(pw[j]);
            }
        }
    } else {
        // Unaligned rows: adjacent lanes read adjacent bytes (coalesced).
        const __nv_fp8_e4m3* wq = (const __nv_fp8_e4m3*)wrow;
        for (int k = lane; k < K; k += 32) {
            acc += xs[k] * float(wq[k]);
        }
    }

    // Warp reduction; all 32 lanes are active here (n is warp-uniform).
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, o);
    }
    if (lane == 0) {
        float v = acc * w_scales[lid];
        if (bias != nullptr) {
            v += __bfloat162float(bias[(long long)lid * N + n]);
        }
        y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
    }
}
|
|
|
|
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
|
|
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
|
|
// via topk_ids and with fused per-expert bias.
|
|
// Weights per MXFP4 scale block: each UE8M0 scale byte covers 32 E2M1
// codes, i.e. 16 packed bytes (one uint4 load per block in the MXFP4
// kernel, which reads K / MXFP4_BLOCK scales per output row).
#define MXFP4_BLOCK 32
|
|
|
|
// Values of the eight non-negative E2M1 (FP4) codes 0..7; the sign is
// carried separately by bit 3 of the 4-bit code.
__device__ __constant__ float kSparseFp4Levels[8] = {
    0.0f, 0.5f, 1.0f, 1.5f,   // codes 0-3
    2.0f, 3.0f, 4.0f, 6.0f,   // codes 4-7
};
|
|
|
|
// Decode one 4-bit E2M1 (FP4) code to float. Only the low 4 bits of
// `code` are used:
//   bit 3    : sign
//   bits 2:1 : exponent e (bias 1)
//   bit 0    : mantissa m
// e == 0 encodes the subnormals {0, 0.5}; otherwise the value is
// 2^(e-1) * (1 + m/2), giving magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
// Sign bit with magnitude 0 yields -0.0f.
//
// Decoded with selects instead of a __constant__ lookup table: lanes of a
// warp decode different codes, and divergent constant-memory addresses
// within a warp are serialized. __host__ as well so it is testable on CPU.
__host__ __device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
    const bool m = (code & 0x1u) != 0u;
    const uint32_t e = ((uint32_t)code >> 1) & 0x3u;
    float mag;
    if (e == 0u) {
        mag = m ? 0.5f : 0.0f;                                   // subnormal
    } else {
        const float pow2 = (e == 1u) ? 1.0f : ((e == 2u) ? 2.0f : 4.0f);  // 2^(e-1)
        mag = pow2 * (m ? 1.5f : 1.0f);
    }
    return (code & 0x8u) ? -mag : mag;
}
|
|
|
|
// MXFP4 W4A16 sparse GEMV, expert-indexed via topk_ids.
//
// Weights: output row (lid, n) holds K/2 packed bytes (two E2M1 codes per
// byte, low nibble = even element) and K/32 UE8M0 scale bytes; weight k of
// block b is fp4(code) * 2^(scale[b] - 127). Requires K % 32 == 0 so that
// every packed row starts on a 16-byte boundary for the uint4 loads; only
// whole 32-element blocks are consumed.
//
// Launch layout (see launch_moe_sparse_gemv_mxfp4_bf16):
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N warps, one warp per output column; each lane
//           consumes whole 32-weight blocks (one uint4 + one scale byte)
//   dynamic shared memory = K * sizeof(float)
//
// Shared memory: the activation row is staged XOR-swizzled, element e of
// block b at xs[32*b + (e ^ (b % 32))]. Lane L only reads blocks with
// b % 32 == L, so at each step all 32 lanes hit different banks; with the
// natural layout every lane would hit the same bank (32-way conflict).
//
// bias may be nullptr (no bias add). Non-local slots are not written.
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,        // [T, K] or [T*top_k, K]
    const uint8_t* __restrict__ w_packed,       // [local_experts, N, K/2]
    const uint8_t* __restrict__ w_scales,       // [local_experts, N, K/32]
    const __nv_bfloat16* __restrict__ bias,     // [local_experts, N] or nullptr
    const int* __restrict__ topk_ids,           // [T, top_k]
    __nv_bfloat16* __restrict__ y,              // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    static_assert(MXFP4_BLOCK == 32, "swizzle below assumes 32-element blocks");

    const int token = blockIdx.z;
    const int slot = blockIdx.y;
    const int lid = topk_ids[token * top_k + slot] - expert_start;
    if (lid < 0 || lid >= local_experts) return;   // block-uniform: safe

    const int nblk = K / MXFP4_BLOCK;
    const int kfull = nblk * MXFP4_BLOCK;   // elements in whole blocks

    extern __shared__ float xs[];           // activation row, swizzled
    const __nv_bfloat16* xrow =
        x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
    for (int k = threadIdx.x; k < kfull; k += blockDim.x) {
        const int bk = k >> 5;              // block index of element k
        xs[(bk << 5) | ((k & 31) ^ (bk & 31))] = __bfloat162float(xrow[k]);
    }
    __syncthreads();

    const int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
    if (n >= N) return;   // warp-uniform and after the barrier: safe
    const int lane = threadIdx.x & 31;

    const long long row = (long long)lid * N + n;
    const uint8_t* wp = w_packed + row * (K >> 1);
    const uint8_t* ws = w_scales + row * nblk;

    float acc = 0.0f;
    for (int blk = lane; blk < nblk; blk += 32) {
        const uint4 packed = *(const uint4*)(wp + (long long)blk * 16);  // 32 nibbles
        const uint8_t* pb = (const uint8_t*)&packed;
        const float* xk = xs + blk * MXFP4_BLOCK;
        const int sw = blk & 31;            // must match the staging swizzle
        float part = 0.0f;
        #pragma unroll
        for (int i = 0; i < 16; i++) {
            const uint8_t b = pb[i];
            part += xk[(2 * i) ^ sw] * sparse_fp4_to_float(b & 0xF);
            part += xk[(2 * i + 1) ^ sw] * sparse_fp4_to_float(b >> 4);
        }
        // UE8M0 block scale 2^(e - 127), applied once per block.
        acc += part * exp2f((float)((int)ws[blk] - 127));
    }

    // Warp reduction; all 32 lanes are active here (n is warp-uniform).
    #pragma unroll
    for (int o = 16; o > 0; o >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, o);
    }
    if (lane == 0) {
        float v = acc;
        if (bias != nullptr) {
            v += __bfloat162float(bias[(long long)lid * N + n]);
        }
        y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
    }
}
|
|
|
|
// Collapse the slot axis of the down-projection output, one thread per
// output element (token t, column d):
//   out[t, d] = sum over slots k whose expert lies in
//               [expert_start, expert_start + local_experts) of
//               topk_weights[t, k] * down[t, k, d]
// Slots routed to experts outside that range are skipped entirely: their
// rows of `down` are never written by the sparse GEMVs, so they must not
// be read at all (multiplying by 0 would not mask a NaN bit pattern).
// A token with no local slots gets 0.
__global__ void moe_weighted_sum_sparse_bf16_kernel(
    const __nv_bfloat16* __restrict__ down,        // [T, top_k, hidden]
    const int* __restrict__ topk_ids,              // [T, top_k]
    const float* __restrict__ topk_weights,        // [T, top_k]
    __nv_bfloat16* __restrict__ out,               // [T, hidden]
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    if (flat >= num_tokens * hidden) return;

    const int t = flat / hidden;
    const int d = flat - t * hidden;

    const int* ids = topk_ids + t * top_k;
    const float* wts = topk_weights + t * top_k;
    const __nv_bfloat16* src = down + (long long)t * top_k * hidden + d;

    float total = 0.0f;
    for (int k = 0; k < top_k; ++k) {
        const int local = ids[k] - expert_start;
        if (local < 0 || local >= local_experts) continue;   // unwritten slot
        total += wts[k] * __bfloat162float(src[(long long)k * hidden]);
    }
    out[flat] = __float2bfloat16(total);
}
|
|
|
|
extern "C" {

// Host launcher for moe_sparse_gemv_fp8_bf16_kernel.
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N * 32 threads (one warp per output column)
//   dynamic shared memory = K floats (staged activation row)
// An empty problem (num_tokens, top_k or N <= 0) is a no-op: a zero grid
// dimension would otherwise make the launch fail with an
// invalid-configuration error. Rows of y for non-local slots are not
// written by the kernel.
void launch_moe_sparse_gemv_fp8_bf16(
    const void* x, const void* w, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    if (num_tokens <= 0 || top_k <= 0 || N <= 0) return;

    const dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    const int block = SPARSE_TILE_N * 32;
    const size_t smem = (size_t)K * sizeof(float);
    if (smem > 48 * 1024) {
        // Dynamic shared memory above 48 KB needs an explicit opt-in. If the
        // request exceeds the device limit the launch below fails as well,
        // and that error is what CUDA_CHECK_LAST_ERROR reports.
        cudaFuncSetAttribute(moe_sparse_gemv_fp8_bf16_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             (int)smem);
    }
    moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
        (const float*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for moe_sparse_gemv_mxfp4_bf16_kernel; same launch layout
// and empty-problem / large-shared-memory handling as the FP8 launcher.
void launch_moe_sparse_gemv_mxfp4_bf16(
    const void* x, const void* w_packed, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    if (num_tokens <= 0 || top_k <= 0 || N <= 0) return;

    const dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    const int block = SPARSE_TILE_N * 32;
    const size_t smem = (size_t)K * sizeof(float);
    if (smem > 48 * 1024) {
        // See launch_moe_sparse_gemv_fp8_bf16 for the opt-in rationale.
        cudaFuncSetAttribute(moe_sparse_gemv_mxfp4_bf16_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             (int)smem);
    }
    moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const uint8_t*)w_packed,
        (const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for moe_weighted_sum_sparse_bf16_kernel: one thread per
// output element, 256 threads per block. With no output elements
// (num_tokens or hidden <= 0) nothing is launched, since a zero-sized grid
// is an invalid launch configuration.
void launch_moe_weighted_sum_sparse_bf16(
    const void* down, const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    if (num_tokens <= 0 || hidden <= 0) return;

    const int total = num_tokens * hidden;
    const int block = 256;
    const int grid = (total + block - 1) / block;
    moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)down,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k, expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}

}
|