moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)

Dense MoE replicated x across all 16 local experts and ran the full
batched GEMM, reading every expert's weights per token; the weighted
sum then discarded 12 of 16 results. Decode is memory-bound, so this
was ~8x wasted expert bytes — the entire decode gap vs llama.cpp.

New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read
topk_ids on-device (no host sync) and early-return block-uniformly
for experts other ranks own. FP8 runs W8A16 (activations stay BF16 —
tensor cores are irrelevant at M=1, and activation quantization error
disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the
GEMV epilogue; slot-indexed weighted sum skips (never multiplies)
unwritten non-local slots. Dense path retained for num_tokens > 8
(prefill) and via XSERV_DENSE_MOE=1 for A/B.

dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms.
Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms —
first config where xserv wins decode outright. GSM8K-100: 96% (dense
FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode
CUDA graphs, non-expert quantization, sparse prefill (docs/20).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-12 16:29:10 +08:00
parent cf1e9e41db
commit fb20178992
6 changed files with 692 additions and 0 deletions


@@ -31,6 +31,7 @@ fn main() {
.file("../../csrc/attention/paged_attention.cu")
.file("../../csrc/attention/reshape_and_cache.cu")
.file("../../csrc/moe/moe_kernels.cu")
.file("../../csrc/moe/moe_sparse.cu")
.file("../../csrc/quantization/dequant_fp8.cu")
.file("../../csrc/quantization/quantize_fp8.cu")
.file("../../csrc/quantization/mxfp4_gemm.cu")


@@ -29,6 +29,29 @@ unsafe extern "C" {
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_fp8_bf16(
x: *const c_void, w: *const c_void, w_scales: *const c_void,
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
num_tokens: i32, n: i32, k: i32, top_k: i32,
expert_start: i32, local_experts: i32, x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void,
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
num_tokens: i32, n: i32, k: i32, top_k: i32,
expert_start: i32, local_experts: i32, x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_weighted_sum_sparse_bf16(
down: *const c_void,
topk_ids: *const c_void, topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32, hidden: i32, top_k: i32,
expert_start: i32, local_experts: i32,
stream: *mut c_void,
);
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
@@ -158,6 +181,110 @@ pub fn moe_weighted_sum(
out
}
/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
///
/// x: [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
/// [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
/// w_fp8_t: [local_experts, N, K] FP8E4M3 (transposed weight layout)
/// w_scales: [local_experts] F32 per-expert scalar scales
/// bias: [local_experts, N] BF16 (fused into the epilogue)
/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
///
/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
/// owned by this rank are left UNWRITTEN (uninitialized memory) — the
/// consumer must skip them (see moe_weighted_sum_sparse).
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_fp8(
x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor,
topk_ids: &Tensor, num_tokens: usize, top_k: usize,
expert_start: usize, local_experts: usize, x_per_slot: bool,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
let n = w_fp8_t.shape()[1];
let k = w_fp8_t.shape()[2];
assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
unsafe {
launch_moe_sparse_gemv_fp8_bf16(
x.data_ptr() as *const c_void,
w_fp8_t.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
bias.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
num_tokens as i32, n as i32, k as i32, top_k as i32,
expert_start as i32, local_experts as i32, x_per_slot as i32,
std::ptr::null_mut(),
);
}
y
}
/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_mxfp4(
x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor,
topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize,
expert_start: usize, local_experts: usize, x_per_slot: bool,
) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous());
assert_eq!(x.shape()[x.ndim() - 1], k);
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
unsafe {
launch_moe_sparse_gemv_mxfp4_bf16(
x.data_ptr() as *const c_void,
w_packed.data_ptr() as *const c_void,
w_scales.data_ptr() as *const c_void,
bias.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void,
num_tokens as i32, n as i32, k as i32, top_k as i32,
expert_start as i32, local_experts as i32, x_per_slot as i32,
std::ptr::null_mut(),
);
}
y
}
/// Weighted sum over the slot axis of the sparse GEMV output.
///
/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
/// and skipped, never multiplied by zero — NaN * 0 = NaN).
pub fn moe_weighted_sum_sparse(
down: &Tensor,
topk_ids: &Tensor,
topk_weights: &Tensor,
expert_start: usize,
local_experts: usize,
) -> Tensor {
assert_eq!(down.ndim(), 3);
assert_eq!(down.dtype(), DType::BF16);
let num_tokens = down.shape()[0];
let top_k = down.shape()[1];
let hidden = down.shape()[2];
let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
unsafe {
launch_moe_weighted_sum_sparse_bf16(
down.data_ptr() as *const c_void,
topk_ids.data_ptr() as *const c_void,
topk_weights.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, top_k as i32,
expert_start as i32, local_experts as i32,
std::ptr::null_mut(),
);
}
out
}
/// Strided batched GEMM for MoE expert forward.
/// C[b] = A[b] @ B[b] for b in 0..batch
///


@@ -549,6 +549,60 @@ impl GptOss {
&router_logits, num_experts, top_k,
);
// Sparse decode path: compute ONLY the routed experts. The dense path
// below reads every local expert's weights per forward; the sparse
// GEMVs read ~top_k/num_experts of the bytes, which dominates decode
// (memory-bound). Dense reads each weight once for ALL tokens, so it
// wins back at num_tokens ≈ local_experts / E[local hits] ≈ 8.
const SPARSE_MAX_TOKENS: usize = 8;
let quantized = layer.expert_gate_up_fp8.is_some() || layer.expert_gate_up_mxfp4.is_some();
if num_tokens <= SPARSE_MAX_TOKENS && quantized && !dense_moe_forced() {
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
x, packed, scales, &layer.expert_gate_up_bias, &topk_ids,
num_tokens, top_k, n, k, expert_start, local_experts, false,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
x, layer.expert_gate_up_fp8.as_ref().unwrap(),
layer.expert_gate_up_scale.as_ref().unwrap(),
&layer.expert_gate_up_bias, &topk_ids,
num_tokens, top_k, expert_start, local_experts, false,
)
};
// GLU over all slots. Non-local slots hold unwritten memory; they
// are never consumed (the down GEMV and the weighted sum both skip
// slots whose expert this rank does not own).
let inter2 = gate_up.shape()[2];
let gate_up_flat = gate_up.reshape(&[num_tokens * top_k, inter2]);
let activated = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
&activated, packed, scales, &layer.expert_down_bias, &topk_ids,
num_tokens, top_k, n, k, expert_start, local_experts, true,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
&activated, layer.expert_down_fp8.as_ref().unwrap(),
layer.expert_down_scale.as_ref().unwrap(),
&layer.expert_down_bias, &topk_ids,
num_tokens, top_k, expert_start, local_experts, true,
)
};
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
&down, &topk_ids, &topk_weights, expert_start, local_experts,
);
self.all_reduce(&moe_out);
return moe_out;
}
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
@@ -625,6 +679,12 @@ impl GptOss {
// --- Helpers ---
/// XSERV_DENSE_MOE=1 forces the dense all-expert path (A/B benchmarking).
fn dense_moe_forced() -> bool {
static FORCED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
*FORCED.get_or_init(|| std::env::var("XSERV_DENSE_MOE").is_ok_and(|v| v != "0"))
}
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);

csrc/moe/moe_sparse.cu

@@ -0,0 +1,254 @@
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cstdint>
#include "../common.cuh"
// ============================================================
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
//
// The dense path replicates x across all local experts and runs a
// batched GEMM, reading every expert's weights per token. Decode is
// memory-bound, so reading only the top-k routed experts' weights
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
//
// Each block handles one (token, slot) pair's tile of output columns.
// It reads topk_ids[token, slot] from device memory (no host sync),
// and exits early if the expert is not owned by this rank. The early
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
// and returns before the shared-memory staging + __syncthreads), so
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
//
// Outputs for non-local slots are NEVER written (uninitialized memory,
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
// slots rather than multiply by zero (NaN * 0 = NaN).
//
// Per-expert weight scale and bias are fused into the epilogue:
// y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
// which matches the dense path's GEMM -> moe_bias_add_3d sequence.
//
// Activation addressing (x_per_slot):
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
// down: each slot has its own activation row
// x[token * top_k + slot, :] (x_per_slot=1)
// ============================================================
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
// Weights FP8 E4M3 [local_experts, N, K], activations BF16 (W8A16).
// Decode is memory-bound (~2 FLOP/byte), so dequant-in-registers GEMV
// loses nothing to tensor cores and skips activation quantization.
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [T, K] or [T*top_k, K]
const __nv_fp8_e4m3* __restrict__ w, // [local_experts, N, K]
const float* __restrict__ w_scales, // [local_experts]
const __nv_bfloat16* __restrict__ bias, // [local_experts, N]
const int* __restrict__ topk_ids, // [T, top_k] global expert ids
__nv_bfloat16* __restrict__ y, // [T, top_k, N]
int N, int K, int top_k,
int expert_start, int local_experts,
int x_per_slot
) {
int token = blockIdx.z;
int slot = blockIdx.y;
int eid = topk_ids[token * top_k + slot];
int lid = eid - expert_start;
if (lid < 0 || lid >= local_experts) return; // block-uniform: safe
extern __shared__ float xs[]; // [K] activation row as float
const __nv_bfloat16* xrow =
x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
for (int i = threadIdx.x; i < K; i += blockDim.x) {
xs[i] = __bfloat162float(xrow[i]);
}
__syncthreads();
int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
if (n >= N) return; // after __syncthreads: safe
int lane = threadIdx.x & 31;
// One warp per output column; uint4 = 16 FP8 weights per lane, the
// warp covers 512 contiguous bytes per iteration (coalesced).
const uint8_t* wrow = (const uint8_t*)w + ((long long)lid * N + n) * K;
float acc = 0.0f;
for (int i = lane; i < (K >> 4); i += 32) {
uint4 packed = *(const uint4*)(wrow + (long long)i * 16);
const __nv_fp8_e4m3* pw = (const __nv_fp8_e4m3*)&packed;
const float* xk = xs + i * 16;
#pragma unroll
for (int j = 0; j < 16; j++) {
acc += xk[j] * float(pw[j]);
}
}
#pragma unroll
for (int o = 16; o > 0; o >>= 1) {
acc += __shfl_down_sync(0xffffffffu, acc, o);
}
if (lane == 0) {
float v = acc * w_scales[lid]
+ __bfloat162float(bias[(long long)lid * N + n]);
y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
}
}
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
// via topk_ids and with fused per-expert bias.
#define MXFP4_BLOCK 32
__device__ __constant__ float kSparseFp4Levels[8] =
{0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
__device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
float mag = kSparseFp4Levels[code & 0x7];
return (code & 0x8) ? -mag : mag;
}
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
const __nv_bfloat16* __restrict__ x, // [T, K] or [T*top_k, K]
const uint8_t* __restrict__ w_packed, // [local_experts, N, K/2]
const uint8_t* __restrict__ w_scales, // [local_experts, N, K/32]
const __nv_bfloat16* __restrict__ bias, // [local_experts, N]
const int* __restrict__ topk_ids, // [T, top_k]
__nv_bfloat16* __restrict__ y, // [T, top_k, N]
int N, int K, int top_k,
int expert_start, int local_experts,
int x_per_slot
) {
int token = blockIdx.z;
int slot = blockIdx.y;
int eid = topk_ids[token * top_k + slot];
int lid = eid - expert_start;
if (lid < 0 || lid >= local_experts) return; // block-uniform: safe
extern __shared__ float xs[];
const __nv_bfloat16* xrow =
x + (long long)(x_per_slot ? token * top_k + slot : token) * K;
for (int i = threadIdx.x; i < K; i += blockDim.x) {
xs[i] = __bfloat162float(xrow[i]);
}
__syncthreads();
int n = blockIdx.x * SPARSE_TILE_N + (threadIdx.x >> 5);
if (n >= N) return;
int lane = threadIdx.x & 31;
int nblk = K / MXFP4_BLOCK;
const uint8_t* wp = w_packed + ((long long)lid * N + n) * (K >> 1);
const uint8_t* ws = w_scales + ((long long)lid * N + n) * nblk;
float acc = 0.0f;
for (int blk = lane; blk < nblk; blk += 32) {
float scale = exp2f((float)((int)ws[blk] - 127));
uint4 packed = *(const uint4*)(wp + (long long)blk * 16); // 32 nibbles
const uint8_t* pb = (const uint8_t*)&packed;
const float* xk = xs + blk * MXFP4_BLOCK;
#pragma unroll
for (int i = 0; i < 16; i++) {
uint8_t b = pb[i];
acc += xk[2 * i] * (sparse_fp4_to_float(b & 0xF) * scale);
acc += xk[2 * i + 1] * (sparse_fp4_to_float(b >> 4) * scale);
}
}
#pragma unroll
for (int o = 16; o > 0; o >>= 1) {
acc += __shfl_down_sync(0xffffffffu, acc, o);
}
if (lane == 0) {
float v = acc + __bfloat162float(bias[(long long)lid * N + n]);
y[((long long)token * top_k + slot) * N + n] = __float2bfloat16(v);
}
}
// Weighted sum over the slot axis: out[t, d] = sum over local slots of
// topk_weights[t, k] * down[t, k, d]. Non-local slots hold uninitialized
// memory and are SKIPPED (not multiplied by zero).
__global__ void moe_weighted_sum_sparse_bf16_kernel(
const __nv_bfloat16* __restrict__ down, // [T, top_k, hidden]
const int* __restrict__ topk_ids, // [T, top_k]
const float* __restrict__ topk_weights, // [T, top_k]
__nv_bfloat16* __restrict__ out, // [T, hidden]
int num_tokens, int hidden, int top_k,
int expert_start, int local_experts
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = num_tokens * hidden;
if (idx >= total) return;
int token = idx / hidden;
int dim = idx % hidden;
float sum = 0.0f;
for (int k = 0; k < top_k; k++) {
int lid = topk_ids[token * top_k + k] - expert_start;
if (lid >= 0 && lid < local_experts) {
float w = topk_weights[token * top_k + k];
float v = __bfloat162float(
down[((long long)token * top_k + k) * hidden + dim]);
sum += w * v;
}
}
out[idx] = __float2bfloat16(sum);
}
extern "C" {
void launch_moe_sparse_gemv_fp8_bf16(
const void* x, const void* w, const void* w_scales, const void* bias,
const void* topk_ids, void* y,
int num_tokens, int N, int K, int top_k,
int expert_start, int local_experts, int x_per_slot,
void* stream
) {
dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
int block = SPARSE_TILE_N * 32;
size_t smem = (size_t)K * sizeof(float);
moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
(const float*)w_scales, (const __nv_bfloat16*)bias,
(const int*)topk_ids, (__nv_bfloat16*)y,
N, K, top_k, expert_start, local_experts, x_per_slot
);
CUDA_CHECK_LAST_ERROR();
}
void launch_moe_sparse_gemv_mxfp4_bf16(
const void* x, const void* w_packed, const void* w_scales, const void* bias,
const void* topk_ids, void* y,
int num_tokens, int N, int K, int top_k,
int expert_start, int local_experts, int x_per_slot,
void* stream
) {
dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
int block = SPARSE_TILE_N * 32;
size_t smem = (size_t)K * sizeof(float);
moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)x, (const uint8_t*)w_packed,
(const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
(const int*)topk_ids, (__nv_bfloat16*)y,
N, K, top_k, expert_start, local_experts, x_per_slot
);
CUDA_CHECK_LAST_ERROR();
}
void launch_moe_weighted_sum_sparse_bf16(
const void* down, const void* topk_ids, const void* topk_weights,
void* out,
int num_tokens, int hidden, int top_k,
int expert_start, int local_experts,
void* stream
) {
int total = num_tokens * hidden;
int block = 256;
int grid = (total + block - 1) / block;
moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
(const __nv_bfloat16*)down,
(const int*)topk_ids, (const float*)topk_weights,
(__nv_bfloat16*)out,
num_tokens, hidden, top_k, expert_start, local_experts
);
CUDA_CHECK_LAST_ERROR();
}
}

docs/20-sparse-moe.md

@@ -0,0 +1,160 @@
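# Phase 20: Sparse MoE Decode (compute only the routed experts)
> Goal: eliminate the dense MoE path's wasted weight reads so that decode TPOT catches up with llama.cpp and then beats it.
> Prerequisites: Phase 19 (gpt-oss MoE correctness) and FP8 W8A8 / MXFP4 W4A16 quantization
> (see `docs/benchmarks/fp8-quantization.md` and `docs/benchmarks/mxfp4-and-llama-decode.md`).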
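## 1. Status quo: what dense MoE wastes
gpt-oss-20b is a 32-expert, top-4 MoE: the router picks 4 experts per token,
so in principle each token needs only 4/32 = 12.5% of the expert weights. But `moe_forward`
(`crates/xserv-model/src/gpt_oss.rs`) is currently a **dense** implementation:
```text
1. router GEMV [T, 2880] → [T, 32]
2. topk_softmax (GPU) → topk_ids [T,4], topk_weights [T,4]
3. moe_replicate x copied 16 times → [16, T, 2880] ← waste starts here
4. batched GEMM gate_up all 16 local experts computed ← reads 16 experts' weights
5. bias + GLU
6. batched GEMM down all 16 local experts computed ← reads 16 experts' weights
7. bias
8. moe_weighted_sum picks only the top-4 for the weighted sum, discards the other 12
9. all-reduce
```
Why it was written this way: batched GEMM (cuBLAS strided-batched) requires a regular
`[E, T, K]` shape, while the top-4 expert ids live on the **GPU** (`topk_ids`). The host does not
know which experts to pick, and the picked set would have an irregular shape anyway. Dense was a
reasonable "get correctness first" starting point, but it reads the weights of all 16 experts from device memory for every token.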
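### Byte ledger (decode, per token, TP=2 with 16 local experts per GPU)
Per layer per expert: gate_up `[2880, 5760]` + down `[2880, 2880]` ≈ 24.9 M parameters.
| Scheme | Expert bytes per GPU per token | Relative |
|---|---|---|
| xserv dense FP8 (current) | 16 × 24.9 MB × 24 layers ≈ **9.6 GB** | 1× |
| xserv sparse FP8 (this phase) | ~2 × 24.9 MB × 24 layers ≈ **1.2 GB** | 1/8 |
| llama.cpp sparse MXFP4 | ~2 × 12.5 MB × 24 layers ≈ **0.6 GB** | 1/16 |
(With the top-4 spread uniformly over the 2 GPUs, each GPU expects 2 hits. Strictly speaking, each layer
pays for the max of the two GPUs' hit counts, expected ≈ 2.7, which is still a ~6-8× saving.)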
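The ledger arithmetic can be reproduced with a short sketch. The function names below are illustrative and not part of xserv; the sketch assumes 1 byte per FP8 parameter, 24 layers, and perfectly uniform routing (4 distinct experts drawn from 32, split 16/16 over two GPUs). Under those assumptions the expected per-layer maximum of the two GPUs' hit counts comes out at about 2.7.

```rust
/// Expert-weight bytes read per token on one GPU.
fn expert_bytes_per_token(
    experts_read: f64,
    params_per_expert: f64,
    bytes_per_param: f64,
    layers: f64,
) -> f64 {
    experts_read * params_per_expert * bytes_per_param * layers
}

/// Binomial coefficient as f64 (intended for small inputs only).
fn choose(n: u64, k: u64) -> f64 {
    if k > n {
        return 0.0;
    }
    (0..k).fold(1.0, |acc, i| acc * (n - i) as f64 / (i + 1) as f64)
}

/// E[max(h, top_k - h)], where h is the number of routed experts that land
/// on GPU 0. Assumption: experts are drawn uniformly without replacement
/// (hypergeometric) and split evenly over exactly two GPUs.
fn expected_max_hits_two_gpus(experts: u64, top_k: u64) -> f64 {
    let local = experts / 2;
    let total = choose(experts, top_k);
    (0..=top_k)
        .map(|h| {
            let p = choose(local, h) * choose(experts - local, top_k - h) / total;
            p * h.max(top_k - h) as f64
        })
        .sum()
}
```

Calling `expert_bytes_per_token(16.0, 2880.0 * 5760.0 + 2880.0 * 2880.0, 1.0, 24.0)` gives the dense figure; replacing `16.0` with `2.0` gives the sparse one, exactly 8× smaller.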
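Measured cross-check: FP8 dense TP=2 TPOT is 13.1 ms, of which the expert GEMMs ≈ 9.6 GB ÷ ~1 TB/s
≈ 9.5 ms and everything else (attention, qkv/o, lm_head, 48 PCIe all-reduces) ≈ 3.5 ms.
**Expert weight reads account for ~3/4 of TPOT, which is the entire gap to llama.cpp (6.6 ms).**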
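## 2. Roofline: why "fewer bytes = less time" at M=1
A decode GEMV (M=1) performs only 2 FLOPs (one multiply-add) per byte of FP8 weight it reads.
RTX 5090: memory bandwidth ~1.8 TB/s, BF16 compute ~210 TFLOPS. Saturating the compute needs an
arithmetic intensity of ~100 FLOP/byte; GEMV has only 2. Conclusions:
1. **Decode is entirely memory-bound** and tensor cores cannot help → a hand-written W8A16 GEMV
(FP8 weights, activations kept in BF16) does not lose to cuBLASLt's W8A8 tensor-core GEMM,
and it also drops the activation-quantization kernel and improves accuracy (activations no longer carry quantization error).
2. There is only one direction to optimize: **read fewer bytes**. Sparse (×8) and 4-bit (×2) are orthogonal
and stack. This phase does sparse first and supports both the FP8 and MXFP4 weight formats.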
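The roofline argument can be sketched numerically. The helper names are hypothetical; the peak numbers are the approximate RTX 5090 figures quoted in this section, and the model ignores launch latency and achievable (as opposed to peak) bandwidth.

```rust
/// Ridge point of the roofline: the arithmetic intensity (FLOP/byte) above
/// which a kernel is limited by compute rather than by memory bandwidth.
fn ridge_point(peak_flops: f64, bandwidth_bytes_per_s: f64) -> f64 {
    peak_flops / bandwidth_bytes_per_s
}

/// Roofline lower bound on kernel time: whichever of memory time and
/// compute time is larger.
fn roofline_time_s(bytes: f64, flops: f64, peak_flops: f64, bandwidth: f64) -> f64 {
    (bytes / bandwidth).max(flops / peak_flops)
}

/// A kernel is memory-bound when its intensity is below the ridge point.
fn is_memory_bound(flops_per_byte: f64, peak_flops: f64, bandwidth: f64) -> bool {
    flops_per_byte < ridge_point(peak_flops, bandwidth)
}
```

With 210 TFLOPS and 1.8 TB/s the ridge point is roughly 117 FLOP/byte, so a GEMV at 2 FLOP/byte sits far to the left of it and its time is set by `bytes / bandwidth` alone.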
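## 3. Sparse design: let the kernel index the weights by topk_ids itself
Key observation: `topk_ids` is already on the GPU. The host does not need to know which experts were chosen.
**Each block of the GEMV kernel reads `topk_ids[token, slot]` itself
and addresses the corresponding expert's weights directly**; if the expert is not on this GPU, the whole block exits. There is zero host synchronization,
and the pipeline stays fully asynchronous (checked earlier: the decode loop has no per-layer sync).
New data flow (enabled when `num_tokens ≤ 8`):
```text
x [T, 2880]
├─ router → topk_ids/weights [T, 4] (unchanged)
├─ sparse GEMV gate_up → [T, 4, 5760] bias fused, non-local slots not written
├─ GLU → [T*4, 2880]
├─ sparse GEMV down → [T, 4, 2880] bias fused, non-local slots not written
└─ weighted_sum_sparse → [T, 2880] accumulates local slots only
all-reduce (unchanged)
```
`moe_replicate` and the standalone bias kernel disappear on the sparse path; the FP8 path also drops
`quantize_bf16_to_fp8_rowwise`.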
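### Kernel design (`csrc/moe/moe_sparse.cu`)
`moe_sparse_gemv_{fp8,mxfp4}_bf16_kernel`:
- **grid = (⌈N/8⌉, top_k, tokens)**, block = 8 warps × 32 lanes.
Each block handles 8 output columns of one (token, slot) pair, with **one warp per output**.
- The block first reads `eid = topk_ids[token*top_k + slot]` and converts it to `lid = eid - expert_start`;
if `lid` is outside `[0, local_experts)` the whole block returns.
- A block that hits cooperatively stages the activation row (K=2880 BF16 values → float) into shared memory
(11.25 KB), calls `__syncthreads()`, and then each warp computes a dot product along K:
each lane reads 16 bytes of weights per `uint4` load (FP8 = 16 weights, MXFP4 = 32 nibbles),
and the 32 lanes of a warp are contiguous → 512 B coalesced transactions.
- Epilogue (lane 0): `y = acc * w_scale[lid] + bias[lid, n]`. The per-expert
scale and the bias are both fused here, bit-for-bit equivalent in semantics to the dense path's "GEMM → bias add → routing weight"
(the HF reference implementation also adds the bias before multiplying by the routing weight). The MXFP4 kernel applies its per-block scales inside the dot product, so its epilogue only adds the bias.
- gate_up and down share the same kernel, with `x_per_slot` selecting the activation addressing:
for gate_up the 4 slots share `x[token]`; for down each slot reads its own `act[token*4+slot]`.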
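As a reading aid, the indexing described in the list above can be written as a CPU reference. This is a sketch, not the kernel: it uses f32 everywhere instead of FP8/BF16, the name `sparse_gemv_ref` and its signature are made up for illustration, and an unwritten slot is modeled as `None` rather than as uninitialized memory.

```rust
/// CPU reference for the sparse expert-indexed GEMV (f32 throughout).
/// Returns one Option per (token, slot, n); None marks a slot the kernel
/// would leave unwritten because its expert is owned by another rank.
fn sparse_gemv_ref(
    x: &[f32],        // [T, K] or [T*top_k, K]
    w: &[f32],        // [local_experts, N, K]
    w_scale: &[f32],  // [local_experts]
    bias: &[f32],     // [local_experts, N]
    topk_ids: &[i32], // [T, top_k], global expert ids
    (n_out, k): (usize, usize),
    top_k: usize,
    expert_start: i32,
    x_per_slot: bool,
) -> Vec<Option<f32>> {
    let local_experts = w_scale.len() as i32;
    let tokens = topk_ids.len() / top_k;
    let mut y = vec![None; tokens * top_k * n_out];
    for token in 0..tokens {
        for slot in 0..top_k {
            let lid = topk_ids[token * top_k + slot] - expert_start;
            if lid < 0 || lid >= local_experts {
                continue; // the whole (token, slot) "block" exits early
            }
            let lid = lid as usize;
            // gate_up: all slots share x[token]; down: one row per slot.
            let row = if x_per_slot { token * top_k + slot } else { token };
            let xrow = &x[row * k..(row + 1) * k];
            for n in 0..n_out {
                let base = (lid * n_out + n) * k;
                let wrow = &w[base..base + k];
                let acc: f32 = xrow.iter().zip(wrow).map(|(a, b)| a * b).sum();
                // Fused epilogue: per-expert scale first, then bias.
                y[(token * top_k + slot) * n_out + n] =
                    Some(acc * w_scale[lid] + bias[lid * n_out + n]);
            }
        }
    }
    y
}
```

A reference like this is mainly useful for unit-testing the GPU kernels against small random inputs, with a tolerance that accounts for FP8/BF16 rounding.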
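### Two safety points that are easy to get wrong
1. **The early return must be block-uniform.** The garbage-output GEMV bug in Phase 19
(commit `3b9e32e`) was caused by exactly this: some threads returned before `__syncthreads()`,
so uninitialized shared memory was read. Here the return happens **before** the shared-memory load, and the whole
block exits together based on the same `topk_ids` value, so there is no divergence and the return is legal and safe.
2. **The weighted sum must skip non-local slots, not multiply them by 0.** The GEMV output of a non-local slot
is never written (uninitialized device memory, possibly NaN bit patterns), and the GLU computes garbage on top of it.
`NaN × 0 = NaN`, so the sum kernel skips with `if (local) sum += w*v` and
the garbage never enters the data flow (the dense path's `moe_weighted_sum` works the same way).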
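The second point is easy to check on the CPU, since IEEE 754 arithmetic behaves the same way there. The two functions below are illustrative only (names and signatures are not from xserv): one skips non-local slots, the other zeroes their weight and multiplies anyway.

```rust
/// Weighted sum over slots that SKIPS non-local slots entirely.
fn weighted_sum_skip(down: &[f32], weights: &[f32], local: &[bool]) -> f32 {
    down.iter()
        .zip(weights)
        .zip(local)
        .filter(|(_, is_local)| **is_local)
        .map(|((v, w), _)| v * w)
        .sum()
}

/// Buggy alternative: give non-local slots a weight of 0 and multiply anyway.
/// If the slot holds NaN, the product is NaN and poisons the whole sum.
fn weighted_sum_mask(down: &[f32], weights: &[f32], local: &[bool]) -> f32 {
    down.iter()
        .zip(weights)
        .zip(local)
        .map(|((v, w), is_local)| v * if *is_local { *w } else { 0.0 })
        .sum()
}
```

With `down = [1.0, NaN, 2.0]`, `weights = [0.5, 0.3, 0.2]` and the middle slot non-local, the skipping version returns about 0.9 while the masking version returns NaN.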
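## 4. Why prefill stays dense
A dense batched GEMM reads the 16 experts' weights **once** and serves all M tokens;
the sparse GEMV re-reads its own ~2 experts **per token**. Byte crossover:
```text
sparse reads M × 2 experts vs dense reads 16 experts → M ≈ 8 (TP=2)
```
Beyond M > 8 dense reads fewer bytes (and the GEMM becomes compute-bound, so tensor cores start to help).
Sparse is therefore enabled only for `num_tokens ≤ 8`, which covers decode (multi-request decode merged by
continuous batching is also small-M) and very short re-prefills. True sparse prefill
(a grouped GEMM that permutes/gathers tokens by expert, as vLLM does) is
a later phase whose main benefit is long-prompt TTFT.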
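The crossover rule can be stated as a small predicate. This is a sketch of the byte argument only: the function names are hypothetical, `expected_local_hits` is the ~2 experts per token assumed above, and it ignores the compute-bound effects that also favor dense at larger M.

```rust
/// Expert weight sets read by the sparse path: each token re-reads the
/// experts it was routed to on this GPU.
fn sparse_reads(num_tokens: usize, expected_local_hits: f64) -> f64 {
    num_tokens as f64 * expected_local_hits
}

/// Expert weight sets read by the dense path: every local expert, once,
/// regardless of how many tokens are in the batch.
fn dense_reads(local_experts: usize) -> f64 {
    local_experts as f64
}

/// Prefer sparse while it reads no more weight bytes than dense.
fn prefer_sparse(num_tokens: usize, local_experts: usize, expected_local_hits: f64) -> bool {
    sparse_reads(num_tokens, expected_local_hits) <= dense_reads(local_experts)
}
```

For 16 local experts and 2 expected hits, `prefer_sparse` flips from true to false between 8 and 9 tokens, which matches the `num_tokens ≤ 8` threshold.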
## 5. Measured results (2026-06-12; full data in `docs/benchmarks/sparse-moe.md`)
In-process decode (bench-gpt-oss, greedy, 96 tokens):
| | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms (1.8×)** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1 (single GPU) | 7.8 ms | 128 |
Warm-server head-to-head against llama.cpp (`tools/xserv_vs_llama.py`):
- **TP=2 vs TP=2: xserv pulls ahead across the board for the first time.** TPOT is 7.19-7.32 ms vs llama's
7.54-8.42 ms; short/medium-prompt TTFT also leads (35/49 vs 63/65 ms).
- **TP=1 vs TP=1: llama wins by a wide margin** (2.88-3.22 ms vs 7.0-7.2 ms, 347 vs 140
tok/s). A single GPU is llama's best configuration: its cross-GPU split loses ~5 ms per token
on PCIe, whereas on one GPU its "whole model in 4-bit + whole-token CUDA graph replay"
advantage is fully realized. xserv's residual ~7 ms ≈ ~3 ms of device-memory reads (non-expert weights are still
BF16, including the 1.16 GB lm_head) + ~4 ms of launch overhead (~200 kernel
launches per token, no CUDA graphs).
- **Correctness: GSM8K-100 = 96%** (dense FP8 91% / BF16 90%; within greedy noise,
no regression).
Lesson: the earlier conclusion that "CUDA graphs ≈ useless (~0.5-1.5 ms)" was relative to a 13 ms
dense TPOT; with the expert cost cut away, launch overhead has become the largest single item.
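## 6. Next phases (ordered by expected gain)
1. **Decode CUDA graphs** (~2-4 ms): currently the largest single item.
2. **Quantize non-expert weights** (~1-1.5 ms): qkv/o + lm_head are still BF16 and cost
~2.3 GB of reads per token; llama is 4-bit throughout.
3. **Sparse prefill** (grouped GEMM): long-prompt TTFT 94-120 ms → the ~30 ms range
llama reaches.
4. **W4A4 FP4 tensor cores / a bandwidth-tuned MXFP4 GEMV**: make 4-bit experts actually
faster than FP8 (currently 8.4 vs 7.6 ms; GEMV inefficiency cancels the byte advantage).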
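For reference when tuning the 4-bit path, the per-element work it performs can be written as a CPU sketch. It mirrors what `moe_sparse.cu` does for one 32-weight MXFP4 block (E2M1 nibbles, low nibble first, one UE8M0 scale byte per block), but the Rust names are illustrative and the reserved scale value 255 is not given any special handling here.

```rust
/// Magnitudes representable by an E2M1 (FP4) code, indexed by the low 3 bits.
const FP4_LEVELS: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];

/// Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0..2 select the level.
fn fp4_to_f32(code: u8) -> f32 {
    let mag = FP4_LEVELS[(code & 0x7) as usize];
    if code & 0x8 != 0 { -mag } else { mag }
}

/// Decode a UE8M0 block scale: an unsigned exponent with bias 127.
fn ue8m0_to_f32(e: u8) -> f32 {
    2f32.powi(e as i32 - 127)
}

/// Dequantize one MXFP4 block: 16 packed bytes = 32 weights sharing one scale.
/// The low nibble of each byte holds the even-indexed weight.
fn dequant_mxfp4_block(packed: &[u8; 16], scale: u8) -> [f32; 32] {
    let s = ue8m0_to_f32(scale);
    let mut out = [0.0f32; 32];
    for (i, &b) in packed.iter().enumerate() {
        out[2 * i] = fp4_to_f32(b & 0xF) * s;
        out[2 * i + 1] = fp4_to_f32(b >> 4) * s;
    }
    out
}
```

Each weight costs a table lookup, a sign select and a multiply on top of the multiply-add itself, which is one plausible reason a straightforward 4-bit GEMV does not reach the effective bandwidth of the FP8 one.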


@@ -0,0 +1,90 @@
# Sparse MoE decode — 1.8× over dense; beats llama.cpp at TP=2 (gpt-oss-20b, RTX 5090)
Phase 20 (`docs/20-sparse-moe.md`): decode computes only the routed top-4
experts via fused expert-indexed GEMVs (`csrc/moe/moe_sparse.cu`) instead of
the dense all-local-expert batched GEMM. FP8 weights run W8A16 (weights FP8,
activations BF16 — decode is memory-bound, tensor cores irrelevant at M=1);
MXFP4 runs W4A16. Dense path retained for prefill / `num_tokens > 8` and via
`XSERV_DENSE_MOE=1` for A/B.
## In-process decode (bench-gpt-oss, greedy, 96 tokens)
| config | TPOT | tok/s |
|---|---|---|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
| **sparse FP8 TP=2** | **7.6 ms** | **132** |
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
| sparse FP8 TP=1 (one 5090) | 7.8 ms | 128 |
| sparse MXFP4 TP=1 | 8.9 ms | 113 |
- Sparse FP8 = **1.8× over dense**. Greedy output stays coherent.
- TP=1 ≈ TP=2: expert reads are now so small that PCIe all-reduce eats the
TP gain — single-GPU serving becomes the attractive deployment.
- MXFP4 reads half the bytes of FP8 but stays slower: the 4-bit dequant GEMV
has lower effective bandwidth (same fixed inefficiency seen in the dense
MXFP4 experiments); at sparse sizes both are partly launch/latency-bound.
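The tok/s column and the 1.8× figure follow directly from TPOT. A minimal sketch, with hypothetical helper names:

```rust
/// Decode throughput for a single stream, from time per output token in ms.
fn tokens_per_second(tpot_ms: f64) -> f64 {
    1000.0 / tpot_ms
}

/// Speedup of a new configuration over a baseline, from their TPOTs.
fn speedup(baseline_ms: f64, new_ms: f64) -> f64 {
    baseline_ms / new_ms
}
```

For example, 7.6 ms gives about 132 tok/s and 13.9 / 7.6 ≈ 1.8. Rows that differ by one token per second from this formula were presumably computed from unrounded TPOT values.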
## Head-to-head vs llama.cpp (tools/xserv_vs_llama.py, warm servers, TP=2, GPUs 0-1, 6 reps, 256 tok)
| prompt | metric | xserv sparse FP8 | llama MXFP4 | xserv vs llama |
|---|---|---|---|---|
| short | TTFT | **35.3 ms** | 62.7 ms | 1.78× faster |
| short | TPOT | **7.32 ms** | 8.42 ms | 1.15× faster |
| medium | TTFT | **49.4 ms** | 65.0 ms | 1.32× faster |
| medium | TPOT | **7.19 ms** | 7.54 ms | 1.05× faster |
| medium | tok/s | **139.1** | 132.7 | |
| long (1.6k) | TTFT | 94.1 ms | **44.7 ms** | 0.48× (llama wins) |
| long | TPOT | **7.25 ms** | 7.64 ms | 1.05× faster |
**Decode TPOT now beats llama.cpp at every prompt length** (was 2× slower:
13.1 vs 6.6 ms before sparse). Remaining loss: long-prompt TTFT — prefill is
still the dense all-expert GEMM; sparse/grouped prefill is the next phase.
## TP=1 head-to-head (single 5090; server now routes gpt-oss tp=1 to the TP engine)
| prompt | metric | xserv sparse FP8 | llama MXFP4 |
|---|---|---|---|
| short | TTFT / TPOT | 42.8 ms / 7.00 ms | **34.5 ms / 3.22 ms** |
| medium | TTFT / TPOT | 57.1 ms / 7.19 ms | **37.3 ms / 2.89 ms** |
| long | TTFT / TPOT | 119.6 ms / 7.20 ms | **27.8 ms / 2.88 ms** |
| | tok/s | 139-143 | **311-347** |
**Single-GPU is llama.cpp's sweet spot and it wins 2.2-2.5×.** Two structural
reasons, both instructive:
1. llama TP=2 (7.5-8.4 ms) is much WORSE than its TP=1 (2.9 ms): its PCIe
cross-GPU split costs ~5 ms/token. xserv's NCCL all-reduce is cheap enough
that TP=2 ≈ TP=1 (7.2 vs 7.0 ms) — but xserv's single-GPU floor is high.
2. xserv TP=1 reads ~4.7 GB/token (experts FP8 2.4 GB + **non-expert weights
still BF16** ~2.3 GB, half of that the 201k-vocab lm_head) ≈ 3.1 ms of pure
HBM time; the other ~4 ms is launch overhead (~200 kernels/token, no CUDA
graphs) + BF16 GEMV efficiency. llama reads ~1.3 GB (everything MXFP4) and
replays the whole token as one CUDA graph.
## Correctness
- Greedy generations coherent across prompts (FP8/MXFP4, TP=1/2).
- Sparse FP8 is W8A16 vs dense W8A8 — activations are no longer quantized, so
tokens are not expected to be byte-identical to dense; quality is checked by
GSM8K instead.
- **GSM8K-100 (greedy, TP=2, `tools/eval_gsm8k_fast.py`): 96/100 = 96.0%** vs
dense FP8 91.0% / BF16 90.0% — no regression (within greedy-nondeterminism
noise; W8A16 removes activation-quantization error so ≥ dense is expected).
Avg 1.3 s/problem also reflects the decode speedup.
## Remaining gaps / next levers (to catch llama TP=1 at 2.9 ms)
Sparse MoE removed the dominant cost; the residual ~7 ms splits roughly into
~3 ms HBM reads and ~4 ms fixed overhead. In impact order:
1. **CUDA graphs for decode** (~2-4 ms): with experts down to ~1-2 ms, the
~200 un-graphed launches/token are now the single largest cost. (The old
"graphs ≈ useless" conclusion was relative to a 13 ms dense TPOT — no
longer true.)
2. **Quantize non-expert weights** (~1-1.5 ms): attn qkv/o + the 1.16 GB BF16
lm_head read every token; FP8/MXFP4 them like llama quantizes everything.
3. **Sparse prefill** (permute tokens by expert + grouped GEMM): long-prompt
TTFT 94-120 ms → llama's ~30 ms territory.
4. **W4A4 FP4 tensor cores / bandwidth-tuned MXFP4 GEMV**: make 4-bit experts
actually beat FP8 (today sparse MXFP4 is 8.4 ms vs FP8 7.6 ms — the 4-bit
GEMV's lower effective bandwidth still cancels its byte advantage).