moe: sparse top-k decode — compute only routed experts (1.8x, beats llama TP=2)
Dense MoE replicated x across all 16 local experts and ran the full batched GEMM, reading every expert's weights per token; the weighted sum then discarded 12 of 16 results. Decode is memory-bound, so this was ~8x wasted expert bytes — the entire decode gap vs llama.cpp. New fused expert-indexed GEMVs (csrc/moe/moe_sparse.cu) read topk_ids on-device (no host sync) and early-return block-uniformly for experts other ranks own. FP8 runs W8A16 (activations stay BF16 — tensor cores are irrelevant at M=1, and activation quantization error disappears); MXFP4 runs W4A16. Per-expert bias + scale fused into the GEMV epilogue; slot-indexed weighted sum skips (never multiplies) unwritten non-local slots. Dense path retained for num_tokens > 8 (prefill) and via XSERV_DENSE_MOE=1 for A/B. dash5 (RTX 5090), gpt-oss-20b FP8, TP=2: decode TPOT 13.9 -> 7.6 ms. Warm-server vs llama.cpp MXFP4 TP=2: TPOT 7.19-7.32 vs 7.54-8.42 ms — first config where xserv wins decode outright. GSM8K-100: 96% (dense FP8: 91%). llama TP=1 (2.9 ms) remains ahead: next levers are decode CUDA graphs, non-expert quantization, sparse prefill (docs/20). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -31,6 +31,7 @@ fn main() {
|
|||||||
.file("../../csrc/attention/paged_attention.cu")
|
.file("../../csrc/attention/paged_attention.cu")
|
||||||
.file("../../csrc/attention/reshape_and_cache.cu")
|
.file("../../csrc/attention/reshape_and_cache.cu")
|
||||||
.file("../../csrc/moe/moe_kernels.cu")
|
.file("../../csrc/moe/moe_kernels.cu")
|
||||||
|
.file("../../csrc/moe/moe_sparse.cu")
|
||||||
.file("../../csrc/quantization/dequant_fp8.cu")
|
.file("../../csrc/quantization/dequant_fp8.cu")
|
||||||
.file("../../csrc/quantization/quantize_fp8.cu")
|
.file("../../csrc/quantization/quantize_fp8.cu")
|
||||||
.file("../../csrc/quantization/mxfp4_gemm.cu")
|
.file("../../csrc/quantization/mxfp4_gemm.cu")
|
||||||
|
|||||||
@@ -29,6 +29,29 @@ unsafe extern "C" {
|
|||||||
stream: *mut c_void,
|
stream: *mut c_void,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
fn launch_moe_sparse_gemv_fp8_bf16(
|
||||||
|
x: *const c_void, w: *const c_void, w_scales: *const c_void,
|
||||||
|
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
|
||||||
|
num_tokens: i32, n: i32, k: i32, top_k: i32,
|
||||||
|
expert_start: i32, local_experts: i32, x_per_slot: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
|
fn launch_moe_sparse_gemv_mxfp4_bf16(
|
||||||
|
x: *const c_void, w_packed: *const c_void, w_scales: *const c_void,
|
||||||
|
bias: *const c_void, topk_ids: *const c_void, y: *mut c_void,
|
||||||
|
num_tokens: i32, n: i32, k: i32, top_k: i32,
|
||||||
|
expert_start: i32, local_experts: i32, x_per_slot: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
|
fn launch_moe_weighted_sum_sparse_bf16(
|
||||||
|
down: *const c_void,
|
||||||
|
topk_ids: *const c_void, topk_weights: *const c_void,
|
||||||
|
out: *mut c_void,
|
||||||
|
num_tokens: i32, hidden: i32, top_k: i32,
|
||||||
|
expert_start: i32, local_experts: i32,
|
||||||
|
stream: *mut c_void,
|
||||||
|
);
|
||||||
|
|
||||||
fn cublasGemmStridedBatchedEx(
|
fn cublasGemmStridedBatchedEx(
|
||||||
handle: CublasHandle,
|
handle: CublasHandle,
|
||||||
transa: i32, transb: i32,
|
transa: i32, transb: i32,
|
||||||
@@ -158,6 +181,110 @@ pub fn moe_weighted_sum(
|
|||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sparse MoE GEMV (FP8 W8A16): compute only the routed experts.
|
||||||
|
///
|
||||||
|
/// x: [num_tokens, K] BF16 (x_per_slot=false, gate_up) or
|
||||||
|
/// [num_tokens * top_k, K] BF16 (x_per_slot=true, down)
|
||||||
|
/// w_fp8_t: [local_experts, N, K] FP8E4M3 (transposed weight layout)
|
||||||
|
/// w_scales: [local_experts] F32 per-expert scalar scales
|
||||||
|
/// bias: [local_experts, N] BF16 (fused into the epilogue)
|
||||||
|
/// topk_ids: [num_tokens, top_k] i32 global expert ids (GPU)
|
||||||
|
///
|
||||||
|
/// Returns y [num_tokens, top_k, N] BF16. Slots routed to experts NOT
|
||||||
|
/// owned by this rank are left UNWRITTEN (uninitialized memory) — the
|
||||||
|
/// consumer must skip them (see moe_weighted_sum_sparse).
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn moe_sparse_gemv_fp8(
|
||||||
|
x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor,
|
||||||
|
topk_ids: &Tensor, num_tokens: usize, top_k: usize,
|
||||||
|
expert_start: usize, local_experts: usize, x_per_slot: bool,
|
||||||
|
) -> Tensor {
|
||||||
|
assert_eq!(x.dtype(), DType::BF16);
|
||||||
|
assert!(x.is_contiguous());
|
||||||
|
let n = w_fp8_t.shape()[1];
|
||||||
|
let k = w_fp8_t.shape()[2];
|
||||||
|
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||||
|
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
|
||||||
|
|
||||||
|
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||||
|
unsafe {
|
||||||
|
launch_moe_sparse_gemv_fp8_bf16(
|
||||||
|
x.data_ptr() as *const c_void,
|
||||||
|
w_fp8_t.data_ptr() as *const c_void,
|
||||||
|
w_scales.data_ptr() as *const c_void,
|
||||||
|
bias.data_ptr() as *const c_void,
|
||||||
|
topk_ids.data_ptr() as *const c_void,
|
||||||
|
y.data_ptr() as *mut c_void,
|
||||||
|
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||||
|
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
y
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sparse MoE GEMV (MXFP4 W4A16): same contract as moe_sparse_gemv_fp8,
|
||||||
|
/// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32].
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn moe_sparse_gemv_mxfp4(
|
||||||
|
x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor,
|
||||||
|
topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize,
|
||||||
|
expert_start: usize, local_experts: usize, x_per_slot: bool,
|
||||||
|
) -> Tensor {
|
||||||
|
assert_eq!(x.dtype(), DType::BF16);
|
||||||
|
assert!(x.is_contiguous());
|
||||||
|
assert_eq!(x.shape()[x.ndim() - 1], k);
|
||||||
|
assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens });
|
||||||
|
|
||||||
|
let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device());
|
||||||
|
unsafe {
|
||||||
|
launch_moe_sparse_gemv_mxfp4_bf16(
|
||||||
|
x.data_ptr() as *const c_void,
|
||||||
|
w_packed.data_ptr() as *const c_void,
|
||||||
|
w_scales.data_ptr() as *const c_void,
|
||||||
|
bias.data_ptr() as *const c_void,
|
||||||
|
topk_ids.data_ptr() as *const c_void,
|
||||||
|
y.data_ptr() as *mut c_void,
|
||||||
|
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||||
|
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
y
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Weighted sum over the slot axis of the sparse GEMV output.
|
||||||
|
///
|
||||||
|
/// down: [num_tokens, top_k, hidden] BF16 (non-local slots uninitialized
|
||||||
|
/// and skipped, never multiplied by zero — NaN * 0 = NaN).
|
||||||
|
pub fn moe_weighted_sum_sparse(
|
||||||
|
down: &Tensor,
|
||||||
|
topk_ids: &Tensor,
|
||||||
|
topk_weights: &Tensor,
|
||||||
|
expert_start: usize,
|
||||||
|
local_experts: usize,
|
||||||
|
) -> Tensor {
|
||||||
|
assert_eq!(down.ndim(), 3);
|
||||||
|
assert_eq!(down.dtype(), DType::BF16);
|
||||||
|
let num_tokens = down.shape()[0];
|
||||||
|
let top_k = down.shape()[1];
|
||||||
|
let hidden = down.shape()[2];
|
||||||
|
|
||||||
|
let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, down.device());
|
||||||
|
unsafe {
|
||||||
|
launch_moe_weighted_sum_sparse_bf16(
|
||||||
|
down.data_ptr() as *const c_void,
|
||||||
|
topk_ids.data_ptr() as *const c_void,
|
||||||
|
topk_weights.data_ptr() as *const c_void,
|
||||||
|
out.data_ptr() as *mut c_void,
|
||||||
|
num_tokens as i32, hidden as i32, top_k as i32,
|
||||||
|
expert_start as i32, local_experts as i32,
|
||||||
|
std::ptr::null_mut(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
/// Strided batched GEMM for MoE expert forward.
|
/// Strided batched GEMM for MoE expert forward.
|
||||||
/// C[b] = A[b] @ B[b] for b in 0..batch
|
/// C[b] = A[b] @ B[b] for b in 0..batch
|
||||||
///
|
///
|
||||||
|
|||||||
@@ -549,6 +549,60 @@ impl GptOss {
|
|||||||
&router_logits, num_experts, top_k,
|
&router_logits, num_experts, top_k,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Sparse decode path: compute ONLY the routed experts. The dense path
|
||||||
|
// below reads every local expert's weights per forward; the sparse
|
||||||
|
// GEMVs read ~top_k/num_experts of the bytes, which dominates decode
|
||||||
|
// (memory-bound). Dense reads each weight once for ALL tokens, so it
|
||||||
|
// wins back at num_tokens ≈ local_experts / E[local hits] ≈ 8.
|
||||||
|
const SPARSE_MAX_TOKENS: usize = 8;
|
||||||
|
let quantized = layer.expert_gate_up_fp8.is_some() || layer.expert_gate_up_mxfp4.is_some();
|
||||||
|
if num_tokens <= SPARSE_MAX_TOKENS && quantized && !dense_moe_forced() {
|
||||||
|
let gate_up = if let Some((ref packed, ref scales)) = layer.expert_gate_up_mxfp4 {
|
||||||
|
let n = packed.shape()[1];
|
||||||
|
let k = packed.shape()[2] * 2;
|
||||||
|
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||||
|
x, packed, scales, &layer.expert_gate_up_bias, &topk_ids,
|
||||||
|
num_tokens, top_k, n, k, expert_start, local_experts, false,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||||
|
x, layer.expert_gate_up_fp8.as_ref().unwrap(),
|
||||||
|
layer.expert_gate_up_scale.as_ref().unwrap(),
|
||||||
|
&layer.expert_gate_up_bias, &topk_ids,
|
||||||
|
num_tokens, top_k, expert_start, local_experts, false,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
// GLU over all slots. Non-local slots hold unwritten memory; they
|
||||||
|
// are never consumed (the down GEMV and the weighted sum both skip
|
||||||
|
// slots whose expert this rank does not own).
|
||||||
|
let inter2 = gate_up.shape()[2];
|
||||||
|
let gate_up_flat = gate_up.reshape(&[num_tokens * top_k, inter2]);
|
||||||
|
let activated = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
|
||||||
|
|
||||||
|
let down = if let Some((ref packed, ref scales)) = layer.expert_down_mxfp4 {
|
||||||
|
let n = packed.shape()[1];
|
||||||
|
let k = packed.shape()[2] * 2;
|
||||||
|
xserv_kernels::moe::moe_sparse_gemv_mxfp4(
|
||||||
|
&activated, packed, scales, &layer.expert_down_bias, &topk_ids,
|
||||||
|
num_tokens, top_k, n, k, expert_start, local_experts, true,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
xserv_kernels::moe::moe_sparse_gemv_fp8(
|
||||||
|
&activated, layer.expert_down_fp8.as_ref().unwrap(),
|
||||||
|
layer.expert_down_scale.as_ref().unwrap(),
|
||||||
|
&layer.expert_down_bias, &topk_ids,
|
||||||
|
num_tokens, top_k, expert_start, local_experts, true,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse(
|
||||||
|
&down, &topk_ids, &topk_weights, expert_start, local_experts,
|
||||||
|
);
|
||||||
|
self.all_reduce(&moe_out);
|
||||||
|
return moe_out;
|
||||||
|
}
|
||||||
|
|
||||||
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
|
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
|
||||||
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
|
||||||
|
|
||||||
@@ -625,6 +679,12 @@ impl GptOss {
|
|||||||
|
|
||||||
// --- Helpers ---
|
// --- Helpers ---
|
||||||
|
|
||||||
|
/// Reports whether the dense all-expert MoE path is forced (A/B benchmarking).
///
/// The dense path is forced when `XSERV_DENSE_MOE` is set to any value other
/// than `"0"` (e.g. `XSERV_DENSE_MOE=1`). The variable is read once on the
/// first call and the answer is cached for the rest of the process.
fn dense_moe_forced() -> bool {
    static CACHED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    // Unset (or non-Unicode) means "not forced"; "0" is the explicit opt-out.
    let read_env = || match std::env::var("XSERV_DENSE_MOE") {
        Ok(value) => value != "0",
        Err(_) => false,
    };
    *CACHED.get_or_init(read_env)
}
|
||||||
|
|
||||||
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor {
|
||||||
assert_eq!(a.ndim(), 2);
|
assert_eq!(a.ndim(), 2);
|
||||||
assert_eq!(b.ndim(), 2);
|
assert_eq!(b.ndim(), 2);
|
||||||
|
|||||||
254
csrc/moe/moe_sparse.cu
Normal file
254
csrc/moe/moe_sparse.cu
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp8.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include "../common.cuh"
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Sparse MoE decode GEMVs — compute ONLY the routed experts.
|
||||||
|
//
|
||||||
|
// The dense path replicates x across all local experts and runs a
|
||||||
|
// batched GEMM, reading every expert's weights per token. Decode is
|
||||||
|
// memory-bound, so reading only the top-k routed experts' weights
|
||||||
|
// (~2 of 16 local on average at TP=2) is a ~8x byte reduction.
|
||||||
|
//
|
||||||
|
// Each block handles one (token, slot) pair's tile of output columns.
|
||||||
|
// It reads topk_ids[token, slot] from device memory (no host sync),
|
||||||
|
// and exits early if the expert is not owned by this rank. The early
|
||||||
|
// return is BLOCK-UNIFORM (every thread sees the same topk_ids value
|
||||||
|
// and returns before the shared-memory staging + __syncthreads), so
|
||||||
|
// it is safe — unlike the divergent-return bug fixed in gemv.cu.
|
||||||
|
//
|
||||||
|
// Outputs for non-local slots are NEVER written (uninitialized memory,
|
||||||
|
// possibly NaN bit patterns). Downstream consumers must SKIP non-local
|
||||||
|
// slots rather than multiply by zero (NaN * 0 = NaN).
|
||||||
|
//
|
||||||
|
// Per-expert weight scale and bias are fused into the epilogue:
|
||||||
|
// y[t, slot, n] = acc * w_scale[lid] + bias[lid, n]
|
||||||
|
// which matches the dense path's GEMM -> moe_bias_add_3d sequence.
|
||||||
|
//
|
||||||
|
// Activation addressing (x_per_slot):
|
||||||
|
// gate_up: all slots of a token share x[token, :] (x_per_slot=0)
|
||||||
|
// down: each slot has its own activation row
|
||||||
|
// x[token * top_k + slot, :] (x_per_slot=1)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
#define SPARSE_TILE_N 8 // output columns per block (= warps per block)
|
||||||
|
|
||||||
|
// Sparse expert GEMV, W8A16: FP8 E4M3 weights [local_experts, N, K] against
// BF16 activations. Decode is memory-bound (~2 FLOP/byte), so dequantizing in
// registers costs nothing versus tensor cores and avoids quantizing the
// activations.
//
// Launch shape (see the launcher): grid = (ceil(N / SPARSE_TILE_N), top_k, T),
// block = SPARSE_TILE_N warps; dynamic shared memory = K floats.
// One block = one (token, slot) pair x SPARSE_TILE_N output columns;
// one warp = one output column.
//
// Only full groups of 16 elements along K are accumulated (K >> 4 vectors), and
// each weight row is read with 16-byte vector loads.
// NOTE(review): this presumes K % 16 == 0 and a 16-byte-aligned weight base —
// confirm at the call site.
__global__ void moe_sparse_gemv_fp8_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,        // [T, K] or [T*top_k, K]
    const __nv_fp8_e4m3* __restrict__ w,        // [local_experts, N, K]
    const float* __restrict__ w_scales,         // [local_experts]
    const __nv_bfloat16* __restrict__ bias,     // [local_experts, N]
    const int* __restrict__ topk_ids,           // [T, top_k] global expert ids
    __nv_bfloat16* __restrict__ y,              // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int token = blockIdx.z;
    const int slot  = blockIdx.y;

    // Translate the routed global expert id into this rank's local index.
    const int local_id = topk_ids[token * top_k + slot] - expert_start;
    const bool owned = (local_id >= 0) && (local_id < local_experts);
    // Every thread of the block reads the same id, so the whole block leaves
    // together, before any shared-memory staging or barrier.
    if (!owned) return;

    // Stage the activation row in shared memory, widened to float.
    extern __shared__ float xs[];               // [K]
    const int act_row = x_per_slot ? (token * top_k + slot) : token;
    const __nv_bfloat16* act = x + (long long)act_row * K;
    for (int idx = threadIdx.x; idx < K; idx += blockDim.x) {
        xs[idx] = __bfloat162float(act[idx]);
    }
    __syncthreads();

    // Warp -> output column. The bounds exit is per-warp and happens after the
    // barrier, so no thread is left waiting on it.
    const int col = blockIdx.x * SPARSE_TILE_N + (threadIdx.x / 32);
    if (col >= N) return;
    const int lane = threadIdx.x % 32;

    // Each lane pulls 16 FP8 weights per 16-byte load; the 32 lanes of a warp
    // together read 512 contiguous bytes per iteration.
    const long long row = (long long)local_id * N + col;
    const uint8_t* weight_row = (const uint8_t*)w + row * K;
    const int num_vecs = K >> 4;
    float acc = 0.0f;
    for (int v = lane; v < num_vecs; v += 32) {
        const uint4 chunk = *(const uint4*)(weight_row + (long long)v * 16);
        const __nv_fp8_e4m3* q = (const __nv_fp8_e4m3*)&chunk;
        const float* a = xs + v * 16;
        #pragma unroll
        for (int j = 0; j < 16; ++j) {
            acc += a[j] * float(q[j]);
        }
    }

    // Warp-level tree reduction; lane 0 ends up holding the full dot product.
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, offset);
    }

    // Epilogue: per-expert scale and bias applied in one step.
    if (lane == 0) {
        const float result = acc * w_scales[local_id]
                           + __bfloat162float(bias[row]);
        y[((long long)token * top_k + slot) * N + col] = __float2bfloat16(result);
    }
}
|
||||||
|
|
||||||
|
// MXFP4 W4A16 variant: packed E2M1 nibbles + per-32 UE8M0 block scale,
|
||||||
|
// same structure as batched_gemv_mxfp4_bf16_kernel but expert-indexed
|
||||||
|
// via topk_ids and with fused per-expert bias.
|
||||||
|
#define MXFP4_BLOCK 32
|
||||||
|
|
||||||
|
// Magnitudes of the eight 3-bit FP4 (E2M1) codes, indexed by the low 3 bits.
__device__ __constant__ float kSparseFp4Levels[8] = {
    0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f
};

// Decode one 4-bit weight code: bits 0-2 select the magnitude from
// kSparseFp4Levels, bit 3 is the sign.
__device__ __forceinline__ float sparse_fp4_to_float(uint8_t code) {
    const float magnitude = kSparseFp4Levels[code & 0x7];
    const bool negative = (code & 0x8) != 0;
    return negative ? -magnitude : magnitude;
}
|
||||||
|
|
||||||
|
// Sparse expert GEMV, W4A16: packed 4-bit weights with one UE8M0 scale byte per
// MXFP4_BLOCK (32) weights, against BF16 activations.
//
// Launch shape (see the launcher): grid = (ceil(N / SPARSE_TILE_N), top_k, T),
// block = SPARSE_TILE_N warps; dynamic shared memory = K floats.
// One block = one (token, slot) pair x SPARSE_TILE_N output columns;
// one warp = one output column; one lane = one 32-weight block per iteration.
//
// Only full 32-weight blocks along K are accumulated (K / MXFP4_BLOCK of them),
// and each block is read with one 16-byte vector load.
// NOTE(review): this presumes K % 32 == 0 and a 16-byte-aligned packed base —
// confirm at the call site.
__global__ void moe_sparse_gemv_mxfp4_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,        // [T, K] or [T*top_k, K]
    const uint8_t* __restrict__ w_packed,       // [local_experts, N, K/2]
    const uint8_t* __restrict__ w_scales,       // [local_experts, N, K/32]
    const __nv_bfloat16* __restrict__ bias,     // [local_experts, N]
    const int* __restrict__ topk_ids,           // [T, top_k]
    __nv_bfloat16* __restrict__ y,              // [T, top_k, N]
    int N, int K, int top_k,
    int expert_start, int local_experts,
    int x_per_slot
) {
    const int token = blockIdx.z;
    const int slot  = blockIdx.y;

    // Translate the routed global expert id into this rank's local index.
    const int local_id = topk_ids[token * top_k + slot] - expert_start;
    const bool owned = (local_id >= 0) && (local_id < local_experts);
    // Same id for every thread of the block: the whole block leaves together,
    // before any shared-memory staging or barrier.
    if (!owned) return;

    // Stage the activation row in shared memory, widened to float.
    extern __shared__ float xs[];               // [K]
    const int act_row = x_per_slot ? (token * top_k + slot) : token;
    const __nv_bfloat16* act = x + (long long)act_row * K;
    for (int idx = threadIdx.x; idx < K; idx += blockDim.x) {
        xs[idx] = __bfloat162float(act[idx]);
    }
    __syncthreads();

    // Warp -> output column; the bounds exit is per-warp and after the barrier.
    const int col = blockIdx.x * SPARSE_TILE_N + (threadIdx.x / 32);
    if (col >= N) return;
    const int lane = threadIdx.x % 32;
    const int num_blocks = K / MXFP4_BLOCK;

    const long long row = (long long)local_id * N + col;
    const uint8_t* packed_row = w_packed + row * (K >> 1);
    const uint8_t* scale_row  = w_scales + row * num_blocks;

    float acc = 0.0f;
    for (int b = lane; b < num_blocks; b += 32) {
        // UE8M0: the byte is a biased power-of-two exponent (bias 127).
        const float scale = exp2f((float)((int)scale_row[b] - 127));
        // 16 bytes = 32 nibbles = one full scale block.
        const uint4 chunk = *(const uint4*)(packed_row + (long long)b * 16);
        const uint8_t* bytes = (const uint8_t*)&chunk;
        const float* a = xs + b * MXFP4_BLOCK;
        #pragma unroll
        for (int j = 0; j < 16; ++j) {
            const uint8_t pair = bytes[j];
            // Low nibble is the even element, high nibble the odd one.
            acc += a[2 * j]     * (sparse_fp4_to_float(pair & 0xF) * scale);
            acc += a[2 * j + 1] * (sparse_fp4_to_float(pair >> 4) * scale);
        }
    }

    // Warp-level tree reduction; lane 0 ends up holding the full dot product.
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        acc += __shfl_down_sync(0xffffffffu, acc, offset);
    }

    // Epilogue: scales were already applied per block, so only the bias is added.
    if (lane == 0) {
        const float result = acc + __bfloat162float(bias[row]);
        y[((long long)token * top_k + slot) * N + col] = __float2bfloat16(result);
    }
}
|
||||||
|
|
||||||
|
// Routing-weighted reduction over the slot axis:
//   out[t, d] = sum over LOCAL slots k of topk_weights[t, k] * down[t, k, d]
// A slot is local when its expert id falls in
// [expert_start, expert_start + local_experts). Non-local slots hold
// uninitialized memory and are skipped outright (not multiplied by zero).
// One thread produces one output element.
__global__ void moe_weighted_sum_sparse_bf16_kernel(
    const __nv_bfloat16* __restrict__ down,     // [T, top_k, hidden]
    const int* __restrict__ topk_ids,           // [T, top_k]
    const float* __restrict__ topk_weights,     // [T, top_k]
    __nv_bfloat16* __restrict__ out,            // [T, hidden]
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    // The last block may overhang the output.
    if (flat >= num_tokens * hidden) return;

    const int token = flat / hidden;
    const int dim   = flat % hidden;

    // Per-token views of the routing tables.
    const int* ids        = topk_ids + token * top_k;
    const float* weights  = topk_weights + token * top_k;

    float acc = 0.0f;
    for (int k = 0; k < top_k; ++k) {
        const int local_id = ids[k] - expert_start;
        // Slot belongs to another rank: its row was never written, skip it.
        if (local_id < 0 || local_id >= local_experts) continue;
        const float value = __bfloat162float(
            down[((long long)token * top_k + k) * hidden + dim]);
        acc += weights[k] * value;
    }
    out[flat] = __float2bfloat16(acc);
}
|
||||||
|
|
||||||
|
extern "C" {

// Host launcher for moe_sparse_gemv_fp8_bf16_kernel.
//   grid  = (ceil(N / SPARSE_TILE_N), top_k, num_tokens)
//   block = SPARSE_TILE_N warps (one warp per output column)
//   smem  = K floats (the staged activation row)
// `stream` is a cudaStream_t passed as an opaque pointer. All data pointers are
// device pointers with the layouts documented on the kernel.
void launch_moe_sparse_gemv_fp8_bf16(
    const void* x, const void* w, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    // A grid with a zero-sized dimension is an invalid launch configuration,
    // and there is nothing to compute anyway.
    if (num_tokens <= 0 || top_k <= 0 || N <= 0) return;

    dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    int block = SPARSE_TILE_N * 32;
    size_t smem = (size_t)K * sizeof(float);
    moe_sparse_gemv_fp8_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const __nv_fp8_e4m3*)w,
        (const float*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for moe_sparse_gemv_mxfp4_bf16_kernel; same grid / block /
// shared-memory shape as the FP8 launcher.
void launch_moe_sparse_gemv_mxfp4_bf16(
    const void* x, const void* w_packed, const void* w_scales, const void* bias,
    const void* topk_ids, void* y,
    int num_tokens, int N, int K, int top_k,
    int expert_start, int local_experts, int x_per_slot,
    void* stream
) {
    // Zero-sized grid dimension: invalid launch configuration, nothing to do.
    if (num_tokens <= 0 || top_k <= 0 || N <= 0) return;

    dim3 grid((N + SPARSE_TILE_N - 1) / SPARSE_TILE_N, top_k, num_tokens);
    int block = SPARSE_TILE_N * 32;
    size_t smem = (size_t)K * sizeof(float);
    moe_sparse_gemv_mxfp4_bf16_kernel<<<grid, block, smem, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (const uint8_t*)w_packed,
        (const uint8_t*)w_scales, (const __nv_bfloat16*)bias,
        (const int*)topk_ids, (__nv_bfloat16*)y,
        N, K, top_k, expert_start, local_experts, x_per_slot
    );
    CUDA_CHECK_LAST_ERROR();
}

// Host launcher for moe_weighted_sum_sparse_bf16_kernel: one thread per output
// element, 256 threads per block.
void launch_moe_weighted_sum_sparse_bf16(
    const void* down, const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    // Empty output: a zero-block grid is an invalid launch configuration.
    if (num_tokens <= 0 || hidden <= 0) return;

    // Round up in 64-bit so `total + block - 1` cannot wrap near INT_MAX.
    const long long total = (long long)num_tokens * hidden;
    int block = 256;
    int grid = (int)((total + block - 1) / block);
    moe_weighted_sum_sparse_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)down,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k, expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}

}
|
||||||
160
docs/20-sparse-moe.md
Normal file
160
docs/20-sparse-moe.md
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# Phase 20: Sparse MoE Decode — 只算被路由到的专家
|
||||||
|
|
||||||
|
> 目标:消除 dense MoE 的无效权重读取,decode TPOT 追上并超过 llama.cpp。
|
||||||
|
> 前置:Phase 19(gpt-oss MoE 正确性)、FP8 W8A8 / MXFP4 W4A16 量化
|
||||||
|
> (见 `docs/benchmarks/fp8-quantization.md`、`docs/benchmarks/mxfp4-and-llama-decode.md`)。
|
||||||
|
|
||||||
|
## 1. 现状:dense MoE 在浪费什么
|
||||||
|
|
||||||
|
gpt-oss-20b 是 32 专家 top-4 的 MoE:router 给每个 token 选 4 个专家,
|
||||||
|
理论上每 token 只需要读 4/32 = 12.5% 的专家权重。但 `moe_forward`
|
||||||
|
(`crates/xserv-model/src/gpt_oss.rs`)目前是 **dense** 实现:
|
||||||
|
|
||||||
|
```text
|
||||||
|
1. router GEMV [T, 2880] → [T, 32]
|
||||||
|
2. topk_softmax (GPU) → topk_ids [T,4], topk_weights [T,4]
|
||||||
|
3. moe_replicate x 复制 16 份 → [16, T, 2880] ← 浪费开始
|
||||||
|
4. batched GEMM gate_up 全部 16 个本地专家都算 ← 读 16 份权重
|
||||||
|
5. bias + GLU
|
||||||
|
6. batched GEMM down 全部 16 个本地专家都算 ← 读 16 份权重
|
||||||
|
7. bias
|
||||||
|
8. moe_weighted_sum 只累加路由命中的本地专家(每卡期望 ~2 个),其余 ~14 个的结果全部丢弃
|
||||||
|
9. all-reduce
|
||||||
|
```
|
||||||
|
|
||||||
|
为什么当初这么写:batched GEMM(cuBLAS strided-batched)要求规则的
|
||||||
|
`[E, T, K]` 形状;top-4 的专家编号在 **GPU** 上(`topk_ids`),host 不知道
|
||||||
|
该挑哪几个,挑了形状也不规则。dense 是"先把正确性做出来"的合理起点,
|
||||||
|
但每 token 把 16 个专家的权重从 HBM 全部读一遍。
|
||||||
|
|
||||||
|
### 字节账本(decode,每 token,TP=2 每卡 16 个本地专家)
|
||||||
|
|
||||||
|
每层每专家:gate_up `[2880, 5760]` + down `[2880, 2880]` ≈ 24.9 M 参数。
|
||||||
|
|
||||||
|
| 方案 | 每卡每 token 专家字节 | 相对量 |
|
||||||
|
|---|---|---|
|
||||||
|
| xserv dense FP8(现状) | 16 × 24.9 MB × 24 层 ≈ **9.6 GB** | 1× |
|
||||||
|
| xserv sparse FP8(本阶段) | ~2 × 24.9 MB × 24 层 ≈ **1.2 GB** | 1/8 |
|
||||||
|
| llama.cpp sparse MXFP4 | ~2 × 12.5 MB × 24 层 ≈ **0.6 GB** | 1/16 |
|
||||||
|
|
||||||
|
(top-4 均匀散落在 2 张卡上,期望每卡 2 个命中;严格说每层取的是
|
||||||
|
两卡命中数的 max,期望 ≈ 2.6,仍是 ~6-8× 的节省。)
|
||||||
|
|
||||||
|
实测旁证:FP8 dense TP=2 TPOT 13.1 ms,其中专家 GEMM ≈ 9.6 GB ÷ ~1 TB/s
|
||||||
|
≈ 9.5 ms,其余(attention、qkv/o、lm_head、48 次 PCIe all-reduce)≈ 3.5 ms。
|
||||||
|
**专家权重读取占 TPOT 的 ~3/4,这就是与 llama.cpp(6.6 ms)的全部差距。**
|
||||||
|
|
||||||
|
## 2. Roofline:M=1 时为什么"省字节 = 省时间"
|
||||||
|
|
||||||
|
decode 的 GEMV(M=1)每读 1 字节 FP8 权重只做 2 FLOP(乘加)。
|
||||||
|
RTX 5090:显存(GDDR7)带宽 ~1.8 TB/s,BF16 算力 ~210 TFLOPS —— 算术强度(arithmetic
intensity)需要达到 ~100 FLOP/byte 才能喂饱算力,而 GEMV 只有 2。结论:
|
||||||
|
|
||||||
|
1. **decode 完全 memory-bound**,tensor core 帮不上忙 → 手写 W8A16 GEMV
|
||||||
|
(权重 FP8、激活保持 BF16)不会输给 cuBLASLt 的 W8A8 tensor-core GEMM,
|
||||||
|
还省掉激活量化 kernel,精度更好(激活不再有量化误差)。
|
||||||
|
2. 优化只有一个方向:**少读字节**。sparse(×8)与 4-bit(×2)正交,
|
||||||
|
可叠加。本阶段先做 sparse,FP8 与 MXFP4 两种权重格式都支持。
|
||||||
|
|
||||||
|
## 3. Sparse 设计:让 kernel 自己按 topk_ids 索引权重
|
||||||
|
|
||||||
|
关键观察:`topk_ids` 本来就在 GPU 上。不需要 host 知道选了谁 ——
|
||||||
|
**让 GEMV kernel 的每个 block 自己读 `topk_ids[token, slot]`,
|
||||||
|
直接寻址到对应专家的权重**,不命中本卡就整块退出。零 host 同步,
|
||||||
|
管线保持完全异步(这是之前排查过的:decode 循环无 per-layer sync)。
|
||||||
|
|
||||||
|
新数据流(`num_tokens ≤ 8` 时启用):
|
||||||
|
|
||||||
|
```text
|
||||||
|
x [T, 2880]
|
||||||
|
├─ router → topk_ids/weights [T, 4] (不变)
|
||||||
|
├─ sparse GEMV gate_up → [T, 4, 5760] bias 已融合,非本地 slot 不写
|
||||||
|
├─ GLU → [T*4, 2880]
|
||||||
|
├─ sparse GEMV down → [T, 4, 2880] bias 已融合,非本地 slot 不写
|
||||||
|
└─ weighted_sum_sparse → [T, 2880] 只累加本地 slot
|
||||||
|
all-reduce (不变)
|
||||||
|
```
|
||||||
|
|
||||||
|
`moe_replicate` 和独立的 bias kernel 在 sparse 路径下消失;FP8 路径还省掉
|
||||||
|
`quantize_bf16_to_fp8_rowwise`。
|
||||||
|
|
||||||
|
### Kernel 设计(`csrc/moe/moe_sparse.cu`)
|
||||||
|
|
||||||
|
`moe_sparse_gemv_{fp8,mxfp4}_bf16_kernel`:
|
||||||
|
|
||||||
|
- **grid = (⌈N/8⌉, top_k, tokens)**,block = 8 warp × 32 lane(256 线程)。
|
||||||
|
每个 block 负责一个 (token, slot) 的 8 个输出列,**一个 warp 算一个输出**。
|
||||||
|
- block 先读 `eid = topk_ids[token*top_k + slot]`,折算 `lid = eid - expert_start`;
|
||||||
|
不在 `[0, local_experts)` 就整块 return。
|
||||||
|
- 命中的 block 把激活行(K=2880 个 BF16 → float)协作搬进 shared memory
|
||||||
|
(11.25 KB),`__syncthreads()`,然后每 warp 沿 K 维做点积:
|
||||||
|
每 lane 一次 `uint4` 读 16 字节权重(FP8 = 16 个权重,MXFP4 = 32 个 nibble),
|
||||||
|
warp 内 32 lane 连续 → 512B coalesced 事务。
|
||||||
|
- epilogue(lane 0):FP8 为 `y = acc * w_scale[lid] + bias[lid, n]` —— per-expert
scale 和 bias 都融合在这里(MXFP4 的 per-32 block scale 已在点积循环内乘入,epilogue 只加 bias),
与 dense 路径的"GEMM → bias add → 路由加权"运算顺序一致、语义等价(HF 参考实现也是先加 bias
再乘路由权重)。注意数值上并非逐位相同:sparse 路径的激活保持 BF16,不再经过激活量化。
|
||||||
|
- gate_up 与 down 共用同一个 kernel,用 `x_per_slot` 区分激活寻址:
|
||||||
|
gate_up 时 4 个 slot 共享 `x[token]`;down 时各读自己的 `act[token*4+slot]`。
|
||||||
|
|
||||||
|
### 两个容易写错的安全点
|
||||||
|
|
||||||
|
1. **early-return 必须 block-uniform。** Phase 19 的 GEMV 垃圾输出 bug
|
||||||
|
(commit `3b9e32e`)正是"部分线程在 `__syncthreads()` 之前 return"导致
|
||||||
|
读未初始化 shared memory。这里的 return 发生在 smem 装载**之前**,且整个
|
||||||
|
block 基于同一个 `topk_ids` 值统一退出 —— 没有 divergence,合法且安全。
|
||||||
|
2. **weighted-sum 对非本地 slot 必须"跳过",不能"乘 0"。** 非本地 slot 的
|
||||||
|
GEMV 输出从未被写入(未初始化显存,可能是 NaN 位型),GLU 也会在上面算出
|
||||||
|
垃圾。`NaN × 0 = NaN`,所以求和 kernel 用 `if (local) sum += w*v` 跳过,
|
||||||
|
垃圾永远不进入数据流(dense 路径的 `moe_weighted_sum` 同理)。
|
||||||
|
|
||||||
|
## 4. 为什么 prefill 保持 dense
|
||||||
|
|
||||||
|
dense batched GEMM 把 16 份权重读**一次**,服务全部 M 个 token;
|
||||||
|
sparse GEMV 是**每 token** 重读自己的 ~2 份。字节交叉点:
|
||||||
|
|
||||||
|
```text
|
||||||
|
sparse 读 M × 2 份 vs dense 读 16 份 → M ≈ 8 (TP=2)
|
||||||
|
```
|
||||||
|
|
||||||
|
M > 8 后 dense 更省(且 GEMM 是 compute-bound,tensor core 开始有用)。
|
||||||
|
所以 sparse 只在 `num_tokens ≤ 8` 启用 —— 覆盖 decode(连续批合并的
|
||||||
|
多请求 decode 也是小 M)和极短的 re-prefill。真正的 sparse prefill
|
||||||
|
(按专家对 token 做 permute/gather 的 grouped GEMM,vLLM 的做法)是
|
||||||
|
后续阶段,主要收益在长 prompt TTFT。
|
||||||
|
|
||||||
|
## 5. 实测结果(2026-06-12,完整数据见 `docs/benchmarks/sparse-moe.md`)
|
||||||
|
|
||||||
|
In-process decode(bench-gpt-oss,greedy 96 tok):
|
||||||
|
|
||||||
|
| | TPOT | tok/s |
|
||||||
|
|---|---|---|
|
||||||
|
| dense FP8 TP=2(基线) | 13.9 ms | 72 |
|
||||||
|
| **sparse FP8 TP=2** | **7.6 ms(1.8×)** | **132** |
|
||||||
|
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
|
||||||
|
| sparse FP8 TP=1(单卡) | 7.8 ms | 128 |
|
||||||
|
|
||||||
|
Warm-server 对打 llama.cpp(`tools/xserv_vs_llama.py`):
|
||||||
|
|
||||||
|
- **TP=2 vs TP=2:xserv 首次全面反超** —— TPOT 7.19-7.32 ms vs llama
|
||||||
|
7.54-8.42 ms;短/中 prompt TTFT 也领先(35/49 vs 63/65 ms)。
|
||||||
|
- **TP=1 vs TP=1:llama 大胜**(2.88-3.22 ms vs 7.0-7.2 ms,347 vs 140
|
||||||
|
tok/s)。单卡才是 llama 的最优配置:它的跨卡 split 在 PCIe 上每 token
|
||||||
|
损失 ~5 ms,而单卡时它"全模型 4-bit + CUDA graph 整 token 回放"的
|
||||||
|
优势全部兑现。xserv 的残余 ~7 ms ≈ ~3 ms HBM(其中非专家权重还是
|
||||||
|
BF16,含 1.16 GB 的 lm_head)+ ~4 ms 启动开销(~200 个 kernel
|
||||||
|
launch/token,无 CUDA graph)。
|
||||||
|
- **正确性:GSM8K-100 = 96%**(dense FP8 91% / BF16 90%,greedy 噪声内,
|
||||||
|
无回归)。
|
||||||
|
|
||||||
|
教训:之前"CUDA graph ≈ 无用(~0.5-1.5ms)"的结论是相对 13 ms 的
|
||||||
|
dense TPOT 而言;专家成本砍掉后,launch 开销变成了最大的单项。
|
||||||
|
|
||||||
|
## 6. 下一阶段(按收益排序)
|
||||||
|
|
||||||
|
1. **decode CUDA graph**(~2-4 ms):当前最大单项。
|
||||||
|
2. **非专家权重量化**(~1-1.5 ms):qkv/o + lm_head 仍是 BF16,每 token
|
||||||
|
白读 ~2.3 GB;llama 是全模型 4-bit。
|
||||||
|
3. **sparse prefill**(grouped GEMM):长 prompt TTFT 94-120 ms → llama
|
||||||
|
的 ~30 ms 量级。
|
||||||
|
4. **W4A4 FP4 tensor core / 带宽调优的 MXFP4 GEMV**:让 4-bit 专家真正
|
||||||
|
快过 FP8(目前 8.4 vs 7.6 ms,GEMV 效率抵消了字节优势)。
|
||||||
90
docs/benchmarks/sparse-moe.md
Normal file
90
docs/benchmarks/sparse-moe.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# Sparse MoE decode — 1.8× over dense; beats llama.cpp at TP=2 (gpt-oss-20b, RTX 5090)
|
||||||
|
|
||||||
|
Phase 20 (`docs/20-sparse-moe.md`): decode computes only the routed top-4
|
||||||
|
experts via fused expert-indexed GEMVs (`csrc/moe/moe_sparse.cu`) instead of
|
||||||
|
the dense all-local-expert batched GEMM. FP8 weights run W8A16 (weights FP8,
|
||||||
|
activations BF16 — decode is memory-bound, tensor cores irrelevant at M=1);
|
||||||
|
MXFP4 runs W4A16. Dense path retained for prefill / `num_tokens > 8` and via
|
||||||
|
`XSERV_DENSE_MOE=1` for A/B.
|
||||||
|
|
||||||
|
## In-process decode (bench-gpt-oss, greedy, 96 tokens)
|
||||||
|
|
||||||
|
| config | TPOT | tok/s |
|
||||||
|
|---|---|---|
|
||||||
|
| dense FP8 TP=2 (baseline) | 13.9 ms | 72 |
|
||||||
|
| **sparse FP8 TP=2** | **7.6 ms** | **132** |
|
||||||
|
| sparse MXFP4 TP=2 | 8.4 ms | 118 |
|
||||||
|
| sparse FP8 TP=1 (one 5090) | 7.8 ms | 128 |
|
||||||
|
| sparse MXFP4 TP=1 | 8.9 ms | 113 |
|
||||||
|
|
||||||
|
- Sparse FP8 = **1.8× over dense**. Greedy output stays coherent.
|
||||||
|
- TP=1 ≈ TP=2: expert reads are now so small that PCIe all-reduce eats the
|
||||||
|
TP gain — single-GPU serving becomes the attractive deployment.
|
||||||
|
- MXFP4 reads half the bytes of FP8 but stays slower: the 4-bit dequant GEMV
|
||||||
|
has lower effective bandwidth (same fixed inefficiency seen in the dense
|
||||||
|
MXFP4 experiments); at sparse sizes both are partly launch/latency-bound.
|
||||||
|
|
||||||
|
## Head-to-head vs llama.cpp (tools/xserv_vs_llama.py, warm servers, TP=2, GPUs 0-1, 6 reps, 256 tok)
|
||||||
|
|
||||||
|
| prompt | metric | xserv sparse FP8 | llama MXFP4 | xserv vs llama |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| short | TTFT | **35.3 ms** | 62.7 ms | 1.78× faster |
|
||||||
|
| short | TPOT | **7.32 ms** | 8.42 ms | 1.15× faster |
|
||||||
|
| medium | TTFT | **49.4 ms** | 65.0 ms | 1.32× faster |
|
||||||
|
| medium | TPOT | **7.19 ms** | 7.54 ms | 1.05× faster |
|
||||||
|
| medium | tok/s | **139.1** | 132.7 | 1.05× faster |
|
||||||
|
| long (1.6k) | TTFT | 94.1 ms | **44.7 ms** | 0.48× (llama wins) |
|
||||||
|
| long | TPOT | **7.25 ms** | 7.64 ms | 1.05× faster |
|
||||||
|
|
||||||
|
**Decode TPOT now beats llama.cpp at every prompt length** (was 2× slower:
|
||||||
|
13.1 vs 6.6 ms before sparse). Remaining loss: long-prompt TTFT — prefill is
|
||||||
|
still the dense all-expert GEMM; sparse/grouped prefill is the next phase.
|
||||||
|
|
||||||
|
## TP=1 head-to-head (single 5090; server now routes gpt-oss tp=1 to the TP engine)
|
||||||
|
|
||||||
|
| prompt | metric | xserv sparse FP8 | llama MXFP4 |
|
||||||
|
|---|---|---|---|
|
||||||
|
| short | TTFT / TPOT | 42.8 ms / 7.00 ms | **34.5 ms / 3.22 ms** |
|
||||||
|
| medium | TTFT / TPOT | 57.1 ms / 7.19 ms | **37.3 ms / 2.89 ms** |
|
||||||
|
| long | TTFT / TPOT | 119.6 ms / 7.20 ms | **27.8 ms / 2.88 ms** |
|
||||||
|
| | tok/s | 139–143 | **311–347** |
|
||||||
|
|
||||||
|
**Single-GPU is llama.cpp's sweet spot and it wins 2.2–2.5×.** Two structural
|
||||||
|
reasons, both instructive:
|
||||||
|
|
||||||
|
1. llama TP=2 (7.5–8.4 ms) is much WORSE than its TP=1 (2.9 ms): its PCIe
|
||||||
|
cross-GPU split costs ~5 ms/token. xserv's NCCL all-reduce is cheap enough
|
||||||
|
that TP=2 ≈ TP=1 (7.2 vs 7.0 ms) — but xserv's single-GPU floor is high.
|
||||||
|
2. xserv TP=1 reads ~4.7 GB/token (experts FP8 2.4 GB + **non-expert weights
|
||||||
|
still BF16** ~2.3 GB, half of that the 201k-vocab lm_head) ≈ 3.1 ms of pure
|
||||||
|
HBM time; the other ~4 ms is launch overhead (~200 kernels/token, no CUDA
|
||||||
|
graphs) + BF16 GEMV efficiency. llama reads ~1.3 GB (everything MXFP4) and
|
||||||
|
replays the whole token as one CUDA graph.
|
||||||
|
|
||||||
|
## Correctness
|
||||||
|
|
||||||
|
- Greedy generations coherent across prompts (FP8/MXFP4, TP=1/2).
|
||||||
|
- Sparse FP8 is W8A16 vs dense W8A8 — activations are no longer quantized, so
|
||||||
|
tokens are not expected to match the dense path exactly; quality is checked by
|
||||||
|
GSM8K instead.
|
||||||
|
- **GSM8K-100 (greedy, TP=2, `tools/eval_gsm8k_fast.py`): 96/100 = 96.0%** vs
|
||||||
|
dense FP8 91.0% / BF16 90.0% — no regression (within greedy-nondeterminism
|
||||||
|
noise; W8A16 removes activation-quantization error so ≥ dense is expected).
|
||||||
|
Avg 1.3 s/problem also reflects the decode speedup.
|
||||||
|
|
||||||
|
## Remaining gaps / next levers (to catch llama TP=1 at 2.9 ms)
|
||||||
|
|
||||||
|
Sparse MoE removed the dominant cost; the residual ~7 ms splits roughly into
|
||||||
|
~3 ms HBM reads and ~4 ms fixed overhead. In impact order:
|
||||||
|
|
||||||
|
1. **CUDA graphs for decode** (~2–4 ms): with experts down to ~1–2 ms, the
|
||||||
|
~200 un-graphed launches/token are now the single largest cost. (The old
|
||||||
|
"graphs ≈ useless" conclusion was relative to a 13 ms dense TPOT — no
|
||||||
|
longer true.)
|
||||||
|
2. **Quantize non-expert weights** (~1–1.5 ms): attn qkv/o + the 1.16 GB BF16
|
||||||
|
lm_head are read every token; quantize them to FP8/MXFP4, as llama.cpp does for the whole model.
|
||||||
|
3. **Sparse prefill** (permute tokens by expert + grouped GEMM): long-prompt
|
||||||
|
TTFT 94–120 ms → llama's ~30 ms territory.
|
||||||
|
4. **W4A4 FP4 tensor cores / bandwidth-tuned MXFP4 GEMV**: make 4-bit experts
|
||||||
|
actually beat FP8 (today sparse MXFP4 is 8.4 ms vs FP8 7.6 ms — the 4-bit
|
||||||
|
GEMV's lower effective bandwidth still cancels its byte advantage).
|
||||||
Reference in New Issue
Block a user