moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)

Dense MoE replicated x across all 16 local experts and ran the full
batched GEMM, reading every expert's weights per token; the weighted
sum then discarded 12 of 16 results. Decode is memory-bound, so this
was ~8x wasted expert bytes — the entire decode gap vs llama.cpp.

New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read
topk_ids on-device (no host sync) and early-return block-uniformly
for experts other ranks own. FP8 runs W8A16 (activations stay BF16 —
tensor cores are irrelevant at M=1, and activation quantization error
disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the
GEMV epilogue; slot-indexed weighted sum skips (never multiplies)
unwritten non-local slots. Dense path retained for num_tokens > 8
(prefill) and via XSERV_DENSE_MOE=1 for A/B.

dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms.
Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms —
first config where xserv wins decode outright. GSM8K-100: 96% (dense
FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode
CUDA graphs, non-expert quantization, sparse prefill (docs/20).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 16:29:10 +08:00
parent cf1e9e41db
commit fb20178992
6 changed files with 692 additions and 0 deletions

View File

@@ -31,6 +31,7 @@ fn main() {
.file("../../csrc/attention/paged_attention.cu")
.file("../../csrc/attention/reshape_and_cache.cu")
.file("../../csrc/moe/moe_kernels.cu")
.file("../../csrc/moe/moe_sparse.cu")
.file("../../csrc/quantization/dequant_fp8.cu")
.file("../../csrc/quantization/quantize_fp8.cu")
.file("../../csrc/quantization/mxfp4_gemm.cu")

View File

@@ -29,6 +29,29 @@ unsafe extern "C" {
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_fp8_bf16(
x: *const c_void, w: *const c_void, w_scales: *const c_void,
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
num_tokens: i32, n: i32, k: i32, top_k: i32,
expert_start: i32, local_experts: i32, x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_sparse_gemv_mxfp4_bf16(
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void,
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
num_tokens: i32, n: i32, k: i32, top_k: i32,
expert_start: i32, local_experts: i32, x_per_slot: i32,
stream: *mut c_void,
);
fn launch_moe_weighted_sum_sparse_bf16(
down: *const c_void,
topk_ids: *const c_void, topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32, hidden: i32, top_k: i32,
expert_start: i32, local_experts: i32,
stream: *mut c_void,
);
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
@@ -158,6 +181,110 @@ pub fn moe_weighted_sum(
out
}
/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
///
/// x:        [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
///           [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
/// w_fp8_t:  [local_experts, N, K] FP8E4M3 (transposed weight layout)
/// w_scales: [local_experts] F32 per-expert scalar scales
/// bias:     [local_experts, N] BF16 (fused into the epilogue)
/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
///
/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
/// owned by this rank are left UNWRITTEN (uninitialized memory) — the
/// consumer must skip them (see moe_weighted_sum_sparse).
///
/// The kernel is launched with a null stream pointer.
///
/// # Panics
///
/// Panics if `x` is not contiguous BF16, if `w_fp8_t` is not 3-D, if the
/// leading/trailing dimensions of `x` disagree with `num_tokens`/`top_k`/K,
/// or if any size passed to the kernel does not fit in an `i32`.
///
/// The dtypes and shapes of `w_scales`, `bias` and `topk_ids` are NOT
/// checked here; the caller must uphold the layout listed above.
// The signature mirrors the flat argument list of the CUDA launcher.
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_fp8(
    x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor,
    topk_ids: &Tensor, num_tokens: usize, top_k: usize,
    expert_start: usize, local_experts: usize, x_per_slot: bool,
) -> Tensor {
    assert_eq!(x.dtype(), DType::BF16);
    assert!(x.is_contiguous());
    // Check the rank up front so a malformed weight fails with a clear
    // message instead of an index-out-of-bounds on `shape()`.
    assert_eq!(w_fp8_t.ndim(), 3, "w_fp8_t must be [local_experts, N, K]");
    let n = w_fp8_t.shape()[1];
    let k = w_fp8_t.shape()[2];
    assert_eq!(x.shape()[x.ndim() - 1], k);
    assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });

    // The launcher takes 32-bit sizes. A plain `as i32` would wrap silently
    // and hand the kernel bogus extents, so narrow with an explicit check.
    let narrow = |value: usize, what: &str| -> i32 {
        assert!(
            value <= i32::MAX as usize,
            "moe_sparse_gemv_fp8: {} = {} does not fit in i32",
            what,
            value
        );
        value as i32
    };
    let num_tokens_i32 = narrow(num_tokens, "num_tokens");
    let n_i32 = narrow(n, "n");
    let k_i32 = narrow(k, "k");
    let top_k_i32 = narrow(top_k, "top_k");
    let expert_start_i32 = narrow(expert_start, "expert_start");
    let local_experts_i32 = narrow(local_experts, "local_experts");

    // Deliberately not zero-filled: slots owned by other ranks stay unwritten.
    let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
    // SAFETY: `x` is contiguous BF16 with the extents validated above, `y`
    // was just allocated as [num_tokens, top_k, n] BF16, and every size has
    // been range-checked. The remaining pointers (`w_fp8_t`, `w_scales`,
    // `bias`, `topk_ids`) are trusted to follow the documented layout.
    unsafe {
        launch_moe_sparse_gemv_fp8_bf16(
            x.data_ptr() as *const c_void,
            w_fp8_t.data_ptr() as *const c_void,
            w_scales.data_ptr() as *const c_void,
            bias.data_ptr() as *const c_void,
            topk_ids.data_ptr() as *const c_void,
            y.data_ptr() as *mut c_void,
            num_tokens_i32, n_i32, k_i32, top_k_i32,
            expert_start_i32, local_experts_i32, x_per_slot as i32,
            std::ptr::null_mut(),
        );
    }
    y
}
/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
///
/// x:        [num_tokens, K] BF16 (x_per_slot=false) or
///           [num_tokens * top_k, K] BF16 (x_per_slot=true)
/// n, k:     logical (unpacked) output / reduction sizes; they must agree
///           with `w_packed`, i.e. w_packed.shape() == [_, n, k / 2]
///
/// Returns y [num_tokens, top_k, n] BF16, allocated without initialization.
/// The kernel is launched with a null stream pointer.
///
/// # Panics
///
/// Panics if `x` is not contiguous BF16, if the dimensions of `x` disagree
/// with `num_tokens`/`top_k`/`k`, if `w_packed` is not 3-D or disagrees
/// with `n`/`k`, or if any size passed to the kernel does not fit in an
/// `i32`.
///
/// The dtypes and shapes of `w_scales`, `bias` and `topk_ids` are NOT
/// checked here; the caller must uphold the documented layout.
// The signature mirrors the flat argument list of the CUDA launcher.
#[allow(clippy::too_many_arguments)]
pub fn moe_sparse_gemv_mxfp4(
    x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor,
    topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize,
    expert_start: usize, local_experts: usize, x_per_slot: bool,
) -> Tensor {
    assert_eq!(x.dtype(), DType::BF16);
    assert!(x.is_contiguous());
    assert_eq!(x.shape()[x.ndim() - 1], k);
    assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });

    // `n` and `k` come from the caller, while the kernel walks `w_packed`
    // by raw pointer. A mismatch would become an out-of-bounds device read,
    // so tie them to the packed tensor's real extents first.
    assert_eq!(w_packed.ndim(), 3, "w_packed must be [E, N, K/2]");
    assert_eq!(w_packed.shape()[1], n, "n does not match w_packed.shape()[1]");
    assert_eq!(w_packed.shape()[2] * 2, k, "k does not match 2 * w_packed.shape()[2]");

    // The launcher takes 32-bit sizes. A plain `as i32` would wrap silently
    // and hand the kernel bogus extents, so narrow with an explicit check.
    let narrow = |value: usize, what: &str| -> i32 {
        assert!(
            value <= i32::MAX as usize,
            "moe_sparse_gemv_mxfp4: {} = {} does not fit in i32",
            what,
            value
        );
        value as i32
    };
    let num_tokens_i32 = narrow(num_tokens, "num_tokens");
    let n_i32 = narrow(n, "n");
    let k_i32 = narrow(k, "k");
    let top_k_i32 = narrow(top_k, "top_k");
    let expert_start_i32 = narrow(expert_start, "expert_start");
    let local_experts_i32 = narrow(local_experts, "local_experts");

    // Deliberately not zero-filled, matching the FP8 variant's contract.
    let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
    // SAFETY: `x` is contiguous BF16 with the extents validated above,
    // `w_packed` matches `n`/`k`, `y` was just allocated as
    // [num_tokens, top_k, n] BF16, and every size has been range-checked.
    // `w_scales`, `bias` and `topk_ids` are trusted to follow the
    // documented layout.
    unsafe {
        launch_moe_sparse_gemv_mxfp4_bf16(
            x.data_ptr() as *const c_void,
            w_packed.data_ptr() as *const c_void,
            w_scales.data_ptr() as *const c_void,
            bias.data_ptr() as *const c_void,
            topk_ids.data_ptr() as *const c_void,
            y.data_ptr() as *mut c_void,
            num_tokens_i32, n_i32, k_i32, top_k_i32,
            expert_start_i32, local_experts_i32, x_per_slot as i32,
            std::ptr::null_mut(),
        );
    }
    y
}
/// Weighted sum over the slot axis of the sparse GEMV output.
///
/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
///       and skipped, never multiplied by zero — NaN * 0 = NaN).
/// topk_ids / topk_weights: per-slot routing data handed to the kernel by
///       raw pointer; presumably [num_tokens, top_k] each — their shape
///       and dtype are NOT checked here.
/// expert_start / local_experts: the range of global expert ids owned by
///       this rank, forwarded to the kernel.
///
/// Returns out [num_tokens, hidden] BF16. The kernel is launched with a
/// null stream pointer.
///
/// # Panics
///
/// Panics if `down` is not a contiguous 3-D BF16 tensor, or if any size
/// passed to the kernel does not fit in an `i32`.
pub fn moe_weighted_sum_sparse(
    down: &Tensor,
    topk_ids: &Tensor,
    topk_weights: &Tensor,
    expert_start: usize,
    local_experts: usize,
) -> Tensor {
    assert_eq!(down.ndim(), 3);
    assert_eq!(down.dtype(), DType::BF16);
    // The kernel receives a bare pointer plus extents, so a strided view
    // would be read with the wrong layout. The sibling sparse GEMV wrappers
    // assert the same thing on their activation input.
    assert!(down.is_contiguous());
    let num_tokens = down.shape()[0];
    let top_k = down.shape()[1];
    let hidden = down.shape()[2];

    // The launcher takes 32-bit sizes. A plain `as i32` would wrap silently
    // and hand the kernel bogus extents, so narrow with an explicit check.
    let narrow = |value: usize, what: &str| -> i32 {
        assert!(
            value <= i32::MAX as usize,
            "moe_weighted_sum_sparse: {} = {} does not fit in i32",
            what,
            value
        );
        value as i32
    };
    let num_tokens_i32 = narrow(num_tokens, "num_tokens");
    let hidden_i32 = narrow(hidden, "hidden");
    let top_k_i32 = narrow(top_k, "top_k");
    let expert_start_i32 = narrow(expert_start, "expert_start");
    let local_experts_i32 = narrow(local_experts, "local_experts");

    let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
    // SAFETY: `down` is a contiguous 3-D BF16 tensor whose extents are the
    // ones passed to the kernel, `out` was just allocated as
    // [num_tokens, hidden] BF16, and every size has been range-checked.
    // `topk_ids` and `topk_weights` are trusted to follow the documented
    // layout.
    unsafe {
        launch_moe_weighted_sum_sparse_bf16(
            down.data_ptr() as *const c_void,
            topk_ids.data_ptr() as *const c_void,
            topk_weights.data_ptr() as *const c_void,
            out.data_ptr() as *mut c_void,
            num_tokens_i32, hidden_i32, top_k_i32,
            expert_start_i32, local_experts_i32,
            std::ptr::null_mut(),
        );
    }
    out
}
/// Strided batched GEMM for MoE expert forward.
/// C[b] = A[b] @ B[b] for b in 0..batch
///

View File

@@ -549,6 +549,60 @@ impl GptOss {
&router_logits, num_experts, top_k,
);
// Sparse decode path: compute ONLY the routed experts. The dense path
// below reads every local expert's weights per forward; the sparse
// GEMVs read ~top_k/num_experts of the bytes, which dominates decode
// (memory-bound). Dense reads each weight once for ALL tokens, so it
// wins back at num_tokens ≈ local_experts / E[local hits] ≈ 8.
const SPARSE_MAX_TOKENS: usize = 8;
let quantized = layer.expert_gate_up_fp8.is_some() || layer.expert_gate_up_mxfp4.is_some();
if num_tokens <= SPARSE_MAX_TOKENS && quantized && !dense_moe_forced() {
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
x, packed, scales, &layer.expert_gate_up_bias, &topk_ids,
num_tokens, top_k, n, k, expert_start, local_experts, false,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
x, layer.expert_gate_up_fp8.as_ref().unwrap(),
layer.expert_gate_up_scale.as_ref().unwrap(),
&layer.expert_gate_up_bias, &topk_ids,
num_tokens, top_k, expert_start, local_experts, false,
)
};
// GLU over all slots. Non-local slots hold unwritten memory; they
// are never consumed (the down GEMV and the weighted sum both skip
// slots whose expert this rank does not own).
let inter2 = gate_up.shape()[2];
let gate_up_flat = gate_up.reshape(&[num_tokens * top_k, inter2]);
let activated = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
let n = packed.shape()[1];
let k = packed.shape()[2] * 2;
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
&activated, packed, scales, &layer.expert_down_bias, &topk_ids,
num_tokens, top_k, n, k, expert_start, local_experts, true,
)
} else {
xserv_kernels::moe::moe_sparse_gemv_fp8(
&activated, layer.expert_down_fp8.as_ref().unwrap(),
layer.expert_down_scale.as_ref().unwrap(),
&layer.expert_down_bias, &topk_ids,
num_tokens, top_k, expert_start, local_experts, true,
)
};
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
&down, &topk_ids, &topk_weights, expert_start, local_experts,
);
self.all_reduce(&moe_out);
return moe_out;
}
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
@@ -625,6 +679,12 @@ impl GptOss {
// --- Helpers ---
/// Reports whether the dense all-expert MoE path has been forced through
/// the environment (used for A/B benchmarking against the sparse path).
///
/// The override is active when `XSERV_DENSE_MOE` is set to any readable
/// value other than `"0"`. An unset or unreadable variable leaves it off.
/// The variable is consulted once; the answer is cached for the lifetime
/// of the process.
fn dense_moe_forced() -> bool {
    static DENSE_OVERRIDE: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    let cached = DENSE_OVERRIDE.get_or_init(|| match std::env::var("XSERV_DENSE_MOE") {
        Ok(value) => value != "0",
        Err(_) => false,
    });
    *cached
}
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
assert_eq!(a.ndim(), 2);
assert_eq!(b.ndim(), 2);