moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)
Dense MoE replicated x across all 16 local experts and ran the full batched GEMM, reading every expert's weights per token; the weighted sum then discarded 12 of 16 results. Decode is memory-bound, so this was ~8x wasted expert bytes — the entire decode gap vs llama.cpp. New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read topk_ids on-device (no host sync) and early-return block-uniformly for experts other ranks own. FP8 runs W8A16 (activations stay BF16 — tensor cores are irrelevant at M=1, and activation quantization error disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the GEMV epilogue; slot-indexed weighted sum skips (never multiplies) unwritten non-local slots. Dense path retained for num_tokens > 8 (prefill) and via XSERV_DENSE_MOE=1 for A/B. dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms. Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms — first config where xserv wins decode outright. GSM8K-100: 96% (dense FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode CUDA graphs, non-expert quantization, sparse prefill (docs/20). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -31,6 +31,7 @@ fn main() {
|
||||
.file("../../csrc/attention/paged_attention.cu")
|
||||
.file("../../csrc/attention/reshape_and_cache.cu")
|
||||
.file("../../csrc/moe/moe_kernels.cu")
|
||||
.file("../../csrc/moe/moe_sparse.cu")
|
||||
.file("../../csrc/quantization/dequant_fp8.cu")
|
||||
.file("../../csrc/quantization/quantize_fp8.cu")
|
||||
.file("../../csrc/quantization/mxfp4_gemm.cu")
|
||||
|
||||
@@ -29,6 +29,29 @@ unsafe extern "C" {
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn launch_moe_sparse_gemv_fp8_bf16(
|
||||
x: *const c_void, w: *const c_void, w_scales: *const c_void,
|
||||
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
|
||||
num_tokens: i32, n: i32, k: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32, x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void,
|
||||
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
|
||||
num_tokens: i32, n: i32, k: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32, x_per_slot: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
fn launch_moe_weighted_sum_sparse_bf16(
|
||||
down: *const c_void,
|
||||
topk_ids: *const c_void, topk_weights: *const c_void,
|
||||
out: *mut c_void,
|
||||
num_tokens: i32, hidden: i32, top_k: i32,
|
||||
expert_start: i32, local_experts: i32,
|
||||
stream: *mut c_void,
|
||||
);
|
||||
|
||||
fn cublasGemmStridedBatchedEx(
|
||||
handle: CublasHandle,
|
||||
transa: i32, transb: i32,
|
||||
@@ -158,6 +181,110 @@ pub fn moe_weighted_sum(
|
||||
out
|
||||
}
|
||||
|
||||
/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
|
||||
///
|
||||
/// x: [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
|
||||
/// [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
|
||||
/// w_fp8_t: [local_experts, N, K] FP8E4M3 (transposed weight layout)
|
||||
/// w_scales: [local_experts] F32 per-expert scalar scales
|
||||
/// bias: [local_experts, N] BF16 (fused into the epilogue)
|
||||
/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
|
||||
///
|
||||
/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
|
||||
/// owned by this rank are left UNWRITTEN (uninitialized memory) — the
|
||||
/// consumer must skip them (see moe_weighted_sum_sparse).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn moe_sparse_gemv_fp8(
|
||||
x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor,
|
||||
topk_ids: &Tensor, num_tokens: usize, top_k: usize,
|
||||
expert_start: usize, local_experts: usize, x_per_slot: bool,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
let n = w_fp8_t.shape()[1];
|
||||
let k = w_fp8_t.shape()[2];
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
|
||||
|
||||
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_moe_sparse_gemv_fp8_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
w_fp8_t.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
y
|
||||
}
|
||||
|
||||
/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
|
||||
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn moe_sparse_gemv_mxfp4(
|
||||
x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor,
|
||||
topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize,
|
||||
expert_start: usize, local_experts: usize, x_per_slot: bool,
|
||||
) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous());
|
||||
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
|
||||
|
||||
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_moe_sparse_gemv_mxfp4_bf16(
|
||||
x.data_ptr() as *const c_void,
|
||||
w_packed.data_ptr() as *const c_void,
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
y
|
||||
}
|
||||
|
||||
/// Weighted sum over the slot axis of the sparse GEMV output.
|
||||
///
|
||||
/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
|
||||
/// and skipped, never multiplied by zero — NaN * 0 = NaN).
|
||||
pub fn moe_weighted_sum_sparse(
|
||||
down: &Tensor,
|
||||
topk_ids: &Tensor,
|
||||
topk_weights: &Tensor,
|
||||
expert_start: usize,
|
||||
local_experts: usize,
|
||||
) -> Tensor {
|
||||
assert_eq!(down.ndim(), 3);
|
||||
assert_eq!(down.dtype(), DType::BF16);
|
||||
let num_tokens = down.shape()[0];
|
||||
let top_k = down.shape()[1];
|
||||
let hidden = down.shape()[2];
|
||||
|
||||
let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
|
||||
unsafe {
|
||||
launch_moe_weighted_sum_sparse_bf16(
|
||||
down.data_ptr() as *const c_void,
|
||||
topk_ids.data_ptr() as *const c_void,
|
||||
topk_weights.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32,
|
||||
std::ptr::null_mut(),
|
||||
);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Strided batched GEMM for MoE expert forward.
|
||||
/// C[b] = A[b] @ B[b] for b in 0..batch
|
||||
///
|
||||
|
||||
@@ -549,6 +549,60 @@ impl GptOss {
|
||||
&router_logits, num_experts, top_k,
|
||||
);
|
||||
|
||||
// Sparse decode path: compute ONLY the routed experts. The dense path
|
||||
// below reads every local expert's weights per forward; the sparse
|
||||
// GEMVs read ~top_k/num_experts of the bytes, which dominates decode
|
||||
// (memory-bound). Dense reads each weight once for ALL tokens, so it
|
||||
// wins back at num_tokens ≈ local_experts / E[local hits] ≈ 8.
|
||||
const SPARSE_MAX_TOKENS: usize = 8;
|
||||
let quantized = layer.expert_gate_up_fp8.is_some() || layer.expert_gate_up_mxfp4.is_some();
|
||||
if num_tokens <= SPARSE_MAX_TOKENS && quantized && !dense_moe_forced() {
|
||||
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
|
||||
let n = packed.shape()[1];
|
||||
let k = packed.shape()[2] * 2;
|
||||
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||
x, packed, scales, &layer.expert_gate_up_bias, &topk_ids,
|
||||
num_tokens, top_k, n, k, expert_start, local_experts, false,
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||
x, layer.expert_gate_up_fp8.as_ref().unwrap(),
|
||||
layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||
&layer.expert_gate_up_bias, &topk_ids,
|
||||
num_tokens, top_k, expert_start, local_experts, false,
|
||||
)
|
||||
};
|
||||
|
||||
// GLU over all slots. Non-local slots hold unwritten memory; they
|
||||
// are never consumed (the down GEMV and the weighted sum both skip
|
||||
// slots whose expert this rank does not own).
|
||||
let inter2 = gate_up.shape()[2];
|
||||
let gate_up_flat = gate_up.reshape(&[num_tokens * top_k, inter2]);
|
||||
let activated = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
|
||||
|
||||
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
|
||||
let n = packed.shape()[1];
|
||||
let k = packed.shape()[2] * 2;
|
||||
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||
&activated, packed, scales, &layer.expert_down_bias, &topk_ids,
|
||||
num_tokens, top_k, n, k, expert_start, local_experts, true,
|
||||
)
|
||||
} else {
|
||||
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||
&activated, layer.expert_down_fp8.as_ref().unwrap(),
|
||||
layer.expert_down_scale.as_ref().unwrap(),
|
||||
&layer.expert_down_bias, &topk_ids,
|
||||
num_tokens, top_k, expert_start, local_experts, true,
|
||||
)
|
||||
};
|
||||
|
||||
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
|
||||
&down, &topk_ids, &topk_weights, expert_start, local_experts,
|
||||
);
|
||||
self.all_reduce(&moe_out);
|
||||
return moe_out;
|
||||
}
|
||||
|
||||
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
|
||||
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
||||
|
||||
@@ -625,6 +679,12 @@ impl GptOss {
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
/// Reports whether the dense all-expert MoE path is forced via the
/// `XSERV_DENSE_MOE` environment variable (used for A/B benchmarking).
///
/// Any value other than `"0"` counts as "forced"; an unset or non-Unicode
/// variable does not. The environment is consulted once and the answer is
/// cached for the lifetime of the process.
fn dense_moe_forced() -> bool {
    static FORCED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

    let read_env = || match std::env::var("XSERV_DENSE_MOE") {
        Ok(value) => value != "0",
        Err(_) => false,
    };
    *FORCED.get_or_init(read_env)
}
|
||||
|
||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
assert_eq!(a.ndim(), 2);
|
||||
assert_eq!(b.ndim(), 2);
|
||||
|
||||
Reference in New Issue
Block a user