model: fused GPU MoE kernel — eliminate CPU roundtrip

Replace the per-token CPU-routed MoE forward with an all-GPU path:

  1. moe_topk_softmax: GPU top-k + softmax (was CPU sort + softmax)
  2. moe_replicate: broadcast input to all local experts
  3. cublasGemmStridedBatchedEx: batched expert matmul (was per-expert cuBLAS)
  4. moe_weighted_sum: FP32-accumulated weighted sum on GPU (was GPU→CPU→F32→BF16→GPU)

Expert weights stored as contiguous 3D tensors for strided batched GEMM.
Zero CPU↔GPU transfers in the expert path per MoE layer (was ~40 per token per layer); the router's add_bias still tiles its bias through the CPU when rows > 1.

Also: configurable geglu_alpha, LayerNorm bias auto-detect, unused-weight
diagnostic at load time.

GSM8K 30-problem: 11/30 → 23/30 (76.7%) vs llama.cpp 30/30 (100%).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-31 13:22:59 +08:00
parent 377a04b81f
commit 4368e79695
6 changed files with 617 additions and 110 deletions

View File

@@ -29,6 +29,7 @@ fn main() {
.file("../../csrc/attention/flash_attention.cu")
.file("../../csrc/attention/paged_attention.cu")
.file("../../csrc/attention/reshape_and_cache.cu")
.file("../../csrc/moe/moe_kernels.cu")
.compile("xserv_kernels");
println!("cargo:rerun-if-changed=../../csrc/");

View File

@@ -5,6 +5,7 @@ pub mod dispatch;
pub mod embedding;
pub mod gemm;
pub mod layernorm;
pub mod moe;
pub mod rmsnorm;
pub mod rope;
pub mod softmax;

View File

@@ -0,0 +1,223 @@
use std::ffi::c_void;
use xserv_tensor::{DType, Device, Tensor};
use crate::gemm::{cublas_handle, CublasHandle};
// Foreign functions used by the MoE path.
//
// The `launch_moe_*` entry points are the `extern "C"` launchers defined in
// csrc/moe/moe_kernels.cu. Every pointer argument is a device pointer passed as
// an untyped `c_void`; the launcher casts it to the element type named in the
// per-function comment. `stream` is forwarded as the CUDA stream of the launch.
unsafe extern "C" {
// Top-k + softmax over router logits.
// router_logits: BF16 [num_tokens, num_experts]
// topk_ids: int32 [num_tokens, top_k] (written)
// topk_weights: float32 [num_tokens, top_k] (written)
fn launch_moe_topk_softmax_bf16(
router_logits: *const c_void,
topk_ids: *mut c_void, topk_weights: *mut c_void,
num_tokens: i32, num_experts: i32, top_k: i32,
stream: *mut c_void,
);
// Copies x (BF16 [num_tokens, hidden]) into each of the `local_experts`
// slots of x_rep (BF16 [local_experts, num_tokens, hidden]).
fn launch_moe_replicate_bf16(
x: *const c_void, x_rep: *mut c_void,
num_tokens: i32, hidden: i32, local_experts: i32,
stream: *mut c_void,
);
// In-place x[b, t, d] += bias[b, d] with x BF16 [batch, num_tokens, dim]
// and bias BF16 [batch, dim].
fn launch_moe_bias_add_3d_bf16(
x: *mut c_void, bias: *const c_void,
batch: i32, num_tokens: i32, dim: i32,
stream: *mut c_void,
);
// Routing-weighted sum of the local experts' outputs.
// expert_out: BF16 [local_experts, num_tokens, hidden]
// topk_ids / topk_weights: int32 / float32 [num_tokens, top_k]
// out: BF16 [num_tokens, hidden] (written)
// Expert ids outside [expert_start, expert_start + local_experts) are skipped.
fn launch_moe_weighted_sum_bf16(
expert_out: *const c_void,
topk_ids: *const c_void, topk_weights: *const c_void,
out: *mut c_void,
num_tokens: i32, hidden: i32, top_k: i32,
expert_start: i32, local_experts: i32,
stream: *mut c_void,
);
// cuBLAS strided batched GEMM. The enum-typed C parameters (operations,
// data types, compute type, algorithm) are declared as plain i32 here; the
// return value is the cublasStatus_t code, with 0 meaning success.
fn cublasGemmStridedBatchedEx(
handle: CublasHandle,
transa: i32, transb: i32,
m: i32, n: i32, k: i32,
alpha: *const c_void,
a: *const c_void, a_type: i32, lda: i32, stride_a: i64,
b: *const c_void, b_type: i32, ldb: i32, stride_b: i64,
beta: *const c_void,
c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64,
batch_count: i32,
compute_type: i32,
algo: i32,
) -> i32;
// Binds `stream` to the cuBLAS handle; returns the cublasStatus_t code.
fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32;
}
// Numeric values of the CUDA / cuBLAS enums passed through the i32 FFI
// parameters of cublasGemmStridedBatchedEx.
// cudaDataType_t: 16-bit bfloat16 real.
const CUDA_R_16BF: i32 = 14;
// cublasComputeType_t: accumulate in FP32.
const CUBLAS_COMPUTE_32F: i32 = 68;
// cublasGemmAlgo_t: let cuBLAS choose the algorithm.
const CUBLAS_GEMM_DEFAULT: i32 = -1;
/// GPU top-k selection + softmax over router logits.
///
/// For every token the `top_k` largest logits are selected, and a softmax is
/// taken over just those `top_k` values.
///
/// Input:  router_logits [num_tokens, num_experts] BF16, contiguous, on GPU
/// Output: (topk_ids [num_tokens, top_k], topk_weights [num_tokens, top_k] f32)
///
/// `topk_ids` holds the int32 expert indices written by the kernel, but the
/// tensor is allocated with `DType::F32`, used only as a same-width (4-byte)
/// container. TODO(review): switch to an integer dtype if `DType` has one.
///
/// The kernel is launched with a null stream pointer.
///
/// # Panics
/// Panics if the input is not a contiguous 2-D BF16 tensor with `num_experts`
/// columns, if `num_experts` or `top_k` exceed the kernel's fixed shared-memory
/// capacities, if `top_k` is zero or larger than `num_experts`, or if the
/// sizes do not fit the kernel's 32-bit indexing.
pub fn moe_topk_softmax(
    router_logits: &Tensor,
    num_experts: usize,
    top_k: usize,
) -> (Tensor, Tensor) {
    // Capacities of the fixed-size __shared__ arrays in
    // moe_topk_softmax_bf16_kernel (MOE_MAX_EXPERTS / MOE_MAX_TOPK in
    // csrc/moe/moe_kernels.cu). The kernel itself does not bounds-check, so
    // exceeding them would write past shared memory. Keep in sync.
    const MOE_MAX_EXPERTS: usize = 256;
    const MOE_MAX_TOPK: usize = 8;

    assert_eq!(router_logits.ndim(), 2);
    assert_eq!(router_logits.dtype(), DType::BF16);
    assert!(router_logits.is_contiguous());
    let num_tokens = router_logits.shape()[0];
    assert_eq!(router_logits.shape()[1], num_experts);
    assert!(
        num_experts <= MOE_MAX_EXPERTS,
        "moe_topk_softmax: num_experts {num_experts} exceeds kernel limit {MOE_MAX_EXPERTS}"
    );
    assert!(
        (1..=MOE_MAX_TOPK).contains(&top_k),
        "moe_topk_softmax: top_k {top_k} must be in 1..={MOE_MAX_TOPK}"
    );
    assert!(
        top_k <= num_experts,
        "moe_topk_softmax: top_k {top_k} exceeds num_experts {num_experts}"
    );
    // The kernel indexes rows as `token * num_experts` in 32-bit int.
    assert!(
        num_tokens
            .checked_mul(num_experts)
            .is_some_and(|n| i32::try_from(n).is_ok()),
        "moe_topk_softmax: num_tokens * num_experts does not fit in i32"
    );
    let num_tokens_i32 =
        i32::try_from(num_tokens).expect("moe_topk_softmax: num_tokens exceeds i32::MAX");

    let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
    let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device());
    // SAFETY: the input is a contiguous BF16 [num_tokens, num_experts] tensor
    // (asserted above), both outputs were just allocated with
    // num_tokens * top_k 4-byte elements, and num_experts / top_k are within
    // the kernel's shared-memory capacities.
    unsafe {
        launch_moe_topk_softmax_bf16(
            router_logits.data_ptr() as *const c_void,
            topk_ids.data_ptr() as *mut c_void,
            topk_weights.data_ptr() as *mut c_void,
            num_tokens_i32,
            num_experts as i32, // <= 256, checked above
            top_k as i32,       // <= 8, checked above
            std::ptr::null_mut(),
        );
    }
    (topk_ids, topk_weights)
}
/// Replicate x [num_tokens, hidden] → [local_experts, num_tokens, hidden].
///
/// Every expert slot of the result receives a full copy of `x`, so the result
/// can be fed to a strided batched GEMM with one batch entry per local expert.
/// The kernel is launched with a null stream pointer.
///
/// # Panics
/// Panics if `x` is not a contiguous 2-D BF16 tensor, or if
/// `local_experts * num_tokens * hidden` does not fit in `i32` (the launcher
/// and kernel compute that product and all element indices in 32-bit int).
pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
    assert_eq!(x.ndim(), 2);
    assert_eq!(x.dtype(), DType::BF16);
    assert!(x.is_contiguous());
    let num_tokens = x.shape()[0];
    let hidden = x.shape()[1];

    let to_i32 = |value: usize, what: &str| -> i32 {
        i32::try_from(value)
            .unwrap_or_else(|_| panic!("moe_replicate: {what} = {value} exceeds i32::MAX"))
    };
    let total = local_experts
        .checked_mul(num_tokens)
        .and_then(|n| n.checked_mul(hidden));
    assert!(
        total.is_some_and(|n| i32::try_from(n).is_ok()),
        "moe_replicate: local_experts * num_tokens * hidden does not fit in i32"
    );
    let (num_tokens_i32, hidden_i32, local_experts_i32) = (
        to_i32(num_tokens, "num_tokens"),
        to_i32(hidden, "hidden"),
        to_i32(local_experts, "local_experts"),
    );

    let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device());
    // SAFETY: `x` is contiguous BF16 [num_tokens, hidden] (asserted above) and
    // `out` was just allocated with local_experts * num_tokens * hidden BF16
    // elements, which is exactly the range the kernel writes.
    unsafe {
        launch_moe_replicate_bf16(
            x.data_ptr() as *const c_void,
            out.data_ptr() as *mut c_void,
            num_tokens_i32,
            hidden_i32,
            local_experts_i32,
            std::ptr::null_mut(),
        );
    }
    out
}
/// In-place 3D bias add: x [batch, num_tokens, dim] += bias [batch, dim].
///
/// `x` is modified through its device pointer even though it is taken by
/// shared reference; the bias row of batch entry `b` is added to every token
/// row of `x[b]`. The kernel is launched with a null stream pointer.
///
/// # Panics
/// Panics if `x` is not a contiguous 3-D BF16 tensor, if `bias` is not a
/// contiguous BF16 tensor of shape `[batch, dim]`, or if the element count of
/// `x` does not fit in `i32` (the kernel indexes in 32-bit int).
pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
    assert_eq!(x.ndim(), 3);
    assert_eq!(bias.ndim(), 2);
    assert_eq!(x.dtype(), DType::BF16);
    // The kernel reads `bias` as dense BF16 at offset b * dim + d.
    assert_eq!(bias.dtype(), DType::BF16);
    assert!(x.is_contiguous());
    assert!(bias.is_contiguous());
    let batch = x.shape()[0];
    let num_tokens = x.shape()[1];
    let dim = x.shape()[2];
    assert_eq!(bias.shape(), &[batch, dim]);

    let to_i32 = |value: usize, what: &str| -> i32 {
        i32::try_from(value)
            .unwrap_or_else(|_| panic!("moe_bias_add_3d: {what} = {value} exceeds i32::MAX"))
    };
    let total = batch
        .checked_mul(num_tokens)
        .and_then(|n| n.checked_mul(dim));
    assert!(
        total.is_some_and(|n| i32::try_from(n).is_ok()),
        "moe_bias_add_3d: batch * num_tokens * dim does not fit in i32"
    );
    let (batch_i32, num_tokens_i32, dim_i32) = (
        to_i32(batch, "batch"),
        to_i32(num_tokens, "num_tokens"),
        to_i32(dim, "dim"),
    );

    // SAFETY: both tensors are contiguous BF16 with the shapes asserted above,
    // so every index the kernel touches lies inside the respective buffer.
    unsafe {
        launch_moe_bias_add_3d_bf16(
            x.data_ptr() as *mut c_void,
            bias.data_ptr() as *const c_void,
            batch_i32,
            num_tokens_i32,
            dim_i32,
            std::ptr::null_mut(),
        );
    }
}
/// Weighted sum of expert outputs → [num_tokens, hidden].
///
/// expert_out:   [local_experts, num_tokens, hidden] BF16, contiguous
/// topk_ids:     [num_tokens, top_k] int32 values (global expert indices)
/// topk_weights: [num_tokens, top_k] f32
///
/// For each token, only the selected experts whose global id lies in
/// `[expert_start, expert_start + local_experts)` contribute; the kernel
/// accumulates in FP32 and stores the result as BF16. The kernel is launched
/// with a null stream pointer.
///
/// The dtype of `topk_ids` is not checked here because `moe_topk_softmax`
/// allocates it with `DType::F32` as a 4-byte container for int32 values.
///
/// # Panics
/// Panics if `expert_out` is not a contiguous 3-D BF16 tensor whose leading
/// dimension equals `local_experts`, if `topk_ids` / `topk_weights` are not
/// contiguous `[num_tokens, top_k]` tensors, if `topk_weights` is not F32, or
/// if the sizes do not fit the kernel's 32-bit indexing.
pub fn moe_weighted_sum(
    expert_out: &Tensor,
    topk_ids: &Tensor,
    topk_weights: &Tensor,
    expert_start: usize,
    local_experts: usize,
    top_k: usize,
) -> Tensor {
    assert_eq!(expert_out.ndim(), 3);
    assert_eq!(expert_out.dtype(), DType::BF16);
    assert!(expert_out.is_contiguous());
    // The kernel indexes expert_out with local ids in 0..local_experts, so the
    // leading dimension must really hold that many experts.
    assert_eq!(expert_out.shape()[0], local_experts);
    let num_tokens = expert_out.shape()[1];
    let hidden = expert_out.shape()[2];

    assert_eq!(topk_ids.shape(), &[num_tokens, top_k]);
    assert_eq!(topk_weights.shape(), &[num_tokens, top_k]);
    assert_eq!(topk_weights.dtype(), DType::F32);
    assert!(topk_ids.is_contiguous() && topk_weights.is_contiguous());

    let to_i32 = |value: usize, what: &str| -> i32 {
        i32::try_from(value)
            .unwrap_or_else(|_| panic!("moe_weighted_sum: {what} = {value} exceeds i32::MAX"))
    };
    // Largest offset computed by the kernel is bounded by the element count of
    // expert_out; it is evaluated in 32-bit int on the device.
    let total = local_experts
        .checked_mul(num_tokens)
        .and_then(|n| n.checked_mul(hidden));
    assert!(
        total.is_some_and(|n| i32::try_from(n).is_ok()),
        "moe_weighted_sum: local_experts * num_tokens * hidden does not fit in i32"
    );
    assert!(
        num_tokens
            .checked_mul(top_k)
            .is_some_and(|n| i32::try_from(n).is_ok()),
        "moe_weighted_sum: num_tokens * top_k does not fit in i32"
    );
    let (num_tokens_i32, hidden_i32, top_k_i32, expert_start_i32, local_experts_i32) = (
        to_i32(num_tokens, "num_tokens"),
        to_i32(hidden, "hidden"),
        to_i32(top_k, "top_k"),
        to_i32(expert_start, "expert_start"),
        to_i32(local_experts, "local_experts"),
    );

    let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, expert_out.device());
    // SAFETY: all inputs are contiguous with the shapes asserted above and
    // `out` was just allocated with num_tokens * hidden BF16 elements; the
    // kernel skips expert ids outside the local range.
    unsafe {
        launch_moe_weighted_sum_bf16(
            expert_out.data_ptr() as *const c_void,
            topk_ids.data_ptr() as *const c_void,
            topk_weights.data_ptr() as *const c_void,
            out.data_ptr() as *mut c_void,
            num_tokens_i32,
            hidden_i32,
            top_k_i32,
            expert_start_i32,
            local_experts_i32,
            std::ptr::null_mut(),
        );
    }
    out
}
/// Strided batched GEMM for MoE expert forward.
/// C[b] = A[b] @ B[b] for b in 0..batch
///
/// A: [batch, M, K] BF16 contiguous
/// B: [batch, K, N] BF16 contiguous
/// Returns C: [batch, M, N] BF16
///
/// Accumulation uses `CUBLAS_COMPUTE_32F` with alpha = 1 and beta = 0. The
/// shared cuBLAS handle is bound to the null stream before the call.
///
/// # Panics
/// Panics if the inputs are not contiguous 3-D BF16 tensors with matching
/// batch and inner dimensions, if a dimension does not fit in `i32`, or if
/// cuBLAS reports a non-zero status.
pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
    assert_eq!(a.ndim(), 3);
    assert_eq!(b.ndim(), 3);
    assert_eq!(a.dtype(), DType::BF16);
    assert_eq!(b.dtype(), DType::BF16);
    assert!(a.is_contiguous() && b.is_contiguous());
    assert_eq!(a.shape()[0], b.shape()[0]);
    assert_eq!(a.shape()[2], b.shape()[1]);
    let batch = a.shape()[0];
    let m = a.shape()[1];
    let k = a.shape()[2];
    let n = b.shape()[2];

    // cuBLAS takes 32-bit dimensions; fail loudly instead of truncating.
    let to_i32 = |value: usize, what: &str| -> i32 {
        i32::try_from(value)
            .unwrap_or_else(|_| panic!("batched_gemm_strided: {what} = {value} exceeds i32::MAX"))
    };
    let (batch_i32, m_i32, k_i32, n_i32) = (
        to_i32(batch, "batch"),
        to_i32(m, "M"),
        to_i32(k, "K"),
        to_i32(n, "N"),
    );

    let c = Tensor::empty(&[batch, m, n], DType::BF16, a.device());
    let alpha: f32 = 1.0;
    let beta: f32 = 0.0;

    // cuBLAS is column-major. A row-major [R, C] buffer is bit-identical to a
    // column-major [C, R] matrix with leading dimension C, so:
    //   A_row [M, K]  ==  A^T col-major [K, M], ld = K
    //   B_row [K, N]  ==  B^T col-major [N, K], ld = N
    //   C_row [M, N]  ==  C^T col-major [N, M], ld = N
    // Computing C^T = B^T @ A^T therefore yields row-major C = A @ B without
    // any transposes:
    //   transa = N, transb = N
    //   m = N, n = M, k = K
    //   first operand  = B pointer, lda = N, stride = K * N
    //   second operand = A pointer, ldb = K, stride = M * K
    //   output         = C pointer, ldc = N, stride = M * N
    let stride_a = (m * k) as i64;
    let stride_b = (k * n) as i64;
    let stride_c = (m * n) as i64;
    let handle = cublas_handle();
    // SAFETY: all three buffers are contiguous BF16 with the shapes asserted
    // above, `c` was just allocated with batch * M * N elements, and alpha /
    // beta point to live f32 locals as required by CUBLAS_COMPUTE_32F.
    unsafe {
        let set_status = cublasSetStream_v2(handle, std::ptr::null_mut());
        assert_eq!(set_status, 0, "cublasSetStream_v2 failed: {set_status}");
        let status = cublasGemmStridedBatchedEx(
            handle,
            0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
            n_i32, m_i32, k_i32,
            &alpha as *const f32 as *const c_void,
            b.data_ptr() as *const c_void, CUDA_R_16BF, n_i32, stride_b,
            a.data_ptr() as *const c_void, CUDA_R_16BF, k_i32, stride_a,
            &beta as *const f32 as *const c_void,
            c.data_ptr() as *mut c_void, CUDA_R_16BF, n_i32, stride_c,
            batch_i32,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT,
        );
        assert_eq!(status, 0, "cublasGemmStridedBatchedEx failed: {status}");
    }
    c
}

View File

@@ -73,6 +73,10 @@ pub struct ModelConfig {
pub rope_scaling: Option<RopeScaling>,
#[serde(default)]
pub swiglu_limit: Option<f64>,
#[serde(default)]
pub geglu_alpha: Option<f64>,
#[serde(default)]
pub hidden_act: Option<String>,
}
impl ModelConfig {
@@ -144,4 +148,8 @@ impl ModelConfig {
/// Sliding-window size from the config, or 0 when `sliding_window` is unset.
pub fn window_size(&self) -> usize {
    match self.sliding_window {
        Some(width) => width,
        None => 0,
    }
}
/// GLU alpha as `f32`: the configured `geglu_alpha`, or 1.702 when unset.
pub fn geglu_alpha(&self) -> f32 {
    let alpha: f64 = match self.geglu_alpha {
        Some(configured) => configured,
        None => 1.702,
    };
    alpha as f32
}
}

View File

@@ -12,15 +12,18 @@ pub struct GptOss {
embed_tokens: Tensor,
layers: Vec<GptOssBlock>,
norm: Tensor,
norm_bias: Option<Tensor>,
lm_head_t: Tensor,
rope_cache: RopeCache,
tp: Option<std::sync::Arc<xserv_distributed::TpContext>>,
local_num_heads: usize,
local_num_kv_heads: usize,
has_norm_bias: bool,
}
struct GptOssBlock {
input_norm: Tensor,
input_norm_bias: Option<Tensor>,
// Attention (with bias)
q_proj_wt: Tensor,
q_proj_bias: Tensor,
@@ -36,12 +39,15 @@ struct GptOssBlock {
window_size: usize,
// MoE MLP
post_norm: Tensor,
post_norm_bias: Option<Tensor>,
router_wt: Tensor,
router_bias: Tensor,
expert_gate_up_wt: Vec<Tensor>,
expert_gate_up_bias: Vec<Tensor>,
expert_down_wt: Vec<Tensor>,
expert_down_bias: Vec<Tensor>,
// 3D expert weights for batched GEMM (contiguous on GPU)
expert_gate_up_wt: Tensor, // [local_experts, hidden, 2*inter]
expert_gate_up_bias: Tensor, // [local_experts, 2*inter]
expert_down_wt: Tensor, // [local_experts, inter, hidden]
expert_down_bias: Tensor, // [local_experts, hidden]
local_experts: usize,
// Activation params
glu_alpha: f32,
glu_limit: f32,
@@ -80,6 +86,7 @@ impl GptOss {
let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight"));
let norm = repl(take(&mut w, "model.norm.weight"));
let norm_bias = w.remove("model.norm.bias").map(|t| repl(t));
let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous();
let head_dim = config.head_dim();
@@ -106,13 +113,13 @@ impl GptOss {
let num_layers = config.num_layers();
let num_experts = config.num_experts();
let glu_alpha = 1.702f32;
let glu_alpha = config.geglu_alpha();
let glu_limit = config.swiglu_limit.unwrap_or(7.0) as f32;
let mut layers = Vec::with_capacity(num_layers);
if rank == 0 {
eprintln!(
"Loading gpt-oss weights: {} layers, {} experts, world={world}...",
"Loading gpt-oss weights: {} layers, {} experts, world={world}, glu_alpha={glu_alpha}...",
num_layers, num_experts
);
}
@@ -152,34 +159,26 @@ impl GptOss {
let local_experts = num_experts / world;
let expert_start = rank * local_experts;
let mut expert_gate_up_wt = Vec::with_capacity(local_experts);
let mut expert_gate_up_bias = Vec::with_capacity(local_experts);
let mut expert_down_wt = Vec::with_capacity(local_experts);
let mut expert_down_bias = Vec::with_capacity(local_experts);
let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size
let hidden = gate_up_3d.shape()[1];
let inter = down_3d.shape()[1]; // intermediate_size
for local_e in 0..local_experts {
let e = expert_start + local_e;
let gu_slice = slice_expert_3d(&gate_up_3d, e, hidden, inter2);
expert_gate_up_wt.push(gu_slice.to_device(dev));
let gu_bias = slice_expert_2d(&gate_up_bias_2d, e, inter2);
expert_gate_up_bias.push(gu_bias.to_device(dev));
let d_slice = slice_expert_3d(&down_3d, e, inter, hidden);
expert_down_wt.push(d_slice.to_device(dev));
let d_bias = slice_expert_2d(&down_bias_2d, e, hidden);
expert_down_bias.push(d_bias.to_device(dev));
}
// Slice the rank's range of experts as contiguous 3D tensors on GPU
let expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev);
let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev);
let expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev);
let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev);
xserv_cuda::allocator::cached_trim();
let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight")));
let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t));
let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight")));
let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t));
layers.push(GptOssBlock {
input_norm: repl(take(&mut w, &format!("{p}.input_layernorm.weight"))),
input_norm,
input_norm_bias,
q_proj_wt,
q_proj_bias,
k_proj_wt,
@@ -191,13 +190,15 @@ impl GptOss {
sinks,
is_sliding,
window_size,
post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))),
post_norm,
post_norm_bias,
router_wt,
router_bias,
expert_gate_up_wt,
expert_gate_up_bias,
expert_down_wt,
expert_down_bias,
local_experts,
glu_alpha,
glu_limit,
});
@@ -206,16 +207,35 @@ impl GptOss {
let local_num_heads = config.num_heads() / world;
let local_num_kv_heads = config.num_kv_heads() / world;
let has_norm_bias = norm_bias.is_some();
if rank == 0 {
if has_norm_bias {
eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm");
}
}
// Warn about unused weights that the model didn't consume
if rank == 0 && !w.is_empty() {
eprintln!("WARNING: {} unused weight(s) in model:", w.len());
let mut keys: Vec<_> = w.keys().collect();
keys.sort();
for k in &keys {
eprintln!(" {k}");
}
}
Self {
config,
embed_tokens,
layers,
norm,
norm_bias,
lm_head_t,
rope_cache,
tp,
local_num_heads,
local_num_kv_heads,
has_norm_bias,
}
}
@@ -229,6 +249,34 @@ impl GptOss {
}
}
/// Normalizes `x`: `layernorm` when a bias tensor is present, `rmsnorm`
/// otherwise.
#[inline]
fn norm(x: &Tensor, weight: &Tensor, bias: &Option<Tensor>, eps: f32) -> Tensor {
    if let Some(shift) = bias {
        layernorm(x, weight, shift, eps)
    } else {
        rmsnorm(x, weight, eps)
    }
}
/// Residual add followed by normalization.
///
/// With a bias present, returns `(layernorm(x + residual), x + residual)`;
/// without one, delegates to `add_rmsnorm` and returns its pair unchanged.
#[inline]
fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option<Tensor>, eps: f32) -> (Tensor, Tensor) {
    if let Some(shift) = bias {
        let summed = xserv_kernels::add(x, residual);
        let normed = layernorm(&summed, weight, shift, eps);
        return (normed, summed);
    }
    add_rmsnorm(x, residual, weight, eps)
}
/// Epsilon for the normalization layers: the config's `ln_eps()` when the
/// model carries a final-norm bias, else `rms_norm_eps` (default 1e-5).
fn norm_eps(&self) -> f32 {
    if !self.has_norm_bias {
        return self.config.rms_norm_eps.unwrap_or(1e-5) as f32;
    }
    self.config.ln_eps()
}
/// Paged decode: process one token per sequence using paged KV cache.
pub fn forward_decode_paged(
&self,
@@ -245,7 +293,7 @@ impl GptOss {
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-5) as f32;
let eps = self.norm_eps();
let kv_lens: Vec<i32> = positions.iter().map(|&p| (p + 1) as i32).collect();
for (b, &slot) in seq_slots.iter().enumerate() {
@@ -263,7 +311,7 @@ impl GptOss {
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let normed = Self::norm(&x, &layer.input_norm, &layer.input_norm_bias, eps);
// Q/K/V projections with bias
let q_all = add_bias(&matmul_2d(&normed, &layer.q_proj_wt), &layer.q_proj_bias);
@@ -304,13 +352,11 @@ impl GptOss {
// Residual + post-norm
let (normed, x_new) = add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
let residual = x_new;
let normed = normed.contiguous();
// MoE MLP
let moe_out = self.moe_forward(&normed, layer, batch);
x = xserv_kernels::add(&residual, &moe_out);
@@ -322,7 +368,7 @@ impl GptOss {
}
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let x = rmsnorm(&x, &self.norm, eps);
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let logits = matmul_2d(&x, &self.lm_head_t);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
@@ -341,7 +387,7 @@ impl GptOss {
let num_heads = self.local_num_heads;
let num_kv_heads = self.local_num_kv_heads;
let head_dim = self.config.head_dim();
let eps = self.config.rms_norm_eps.unwrap_or(1e-5) as f32;
let eps = self.norm_eps();
paged_cache.ensure_capacity(slot, pos_offset + new_tokens);
paged_cache.advance_seq_len(slot, new_tokens);
@@ -351,7 +397,7 @@ impl GptOss {
for (layer_idx, layer) in self.layers.iter().enumerate() {
let residual = x.clone();
let normed = rmsnorm(&x, &layer.input_norm, eps);
let normed = Self::norm(&x, &layer.input_norm, &layer.input_norm_bias, eps);
let q = add_bias(&matmul_2d(&normed, &layer.q_proj_wt), &layer.q_proj_bias);
let k = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias);
@@ -381,7 +427,7 @@ impl GptOss {
self.all_reduce(&attn_proj);
let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias);
let (normed, x_new) = add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps);
let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps);
let residual = x_new;
// MoE MLP
@@ -389,90 +435,65 @@ impl GptOss {
x = xserv_kernels::add(&residual, &moe_out);
}
let x = rmsnorm(&x, &self.norm, eps);
let x = Self::norm(&x, &self.norm, &self.norm_bias, eps);
let logits = matmul_2d(&x, &self.lm_head_t);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
logits
}
/// MoE forward pass for one layer with expert parallelism.
/// Each rank owns `num_experts / world` experts. Tokens routed to non-local
/// experts get zero contribution from this rank; AllReduce sums all ranks.
/// Input: [tokens, hidden], Output: [tokens, hidden]
/// MoE forward pass — fully on GPU via batched GEMM.
///
/// Each rank owns `local_experts` experts. The input is replicated across all
/// local experts, processed via two batched cuBLAS GEMMs (gate_up and down),
/// and the selected experts' outputs are weighted-summed on GPU. Non-selected
/// experts contribute zero (via the routing weights), so no scatter/gather is
/// needed. AllReduce sums partial results across TP ranks.
fn moe_forward(&self, x: &Tensor, layer: &GptOssBlock, num_tokens: usize) -> Tensor {
let hidden = self.config.hidden();
let num_experts = self.config.num_experts();
let top_k = self.config.experts_per_token();
let world = self.tp.as_ref().map(|tp| tp.world).unwrap_or(1);
let rank = self.tp.as_ref().map(|tp| tp.rank).unwrap_or(0);
let local_experts = num_experts / world;
let local_experts = layer.local_experts;
let expert_start = rank * local_experts;
// Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
// 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts]
let router_logits = add_bias(
&matmul_2d(x, &layer.router_wt),
&layer.router_bias,
);
unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); }
let router_cpu = router_logits.to_device(Device::Cpu);
let router_data = router_cpu.as_slice::<bf16>();
let x_cpu = x.to_device(Device::Cpu);
let x_data = x_cpu.as_slice::<bf16>();
let mut output_acc = vec![0.0f32; num_tokens * hidden];
// 2. GPU top-k + softmax
let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax(
&router_logits, num_experts, top_k,
);
for t in 0..num_tokens {
let row = &router_data[t * num_experts..(t + 1) * num_experts];
// 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden]
let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts);
// Find top-k expert indices (global)
let mut indices: Vec<(usize, f32)> = row.iter()
.enumerate()
.map(|(i, &v)| (i, v.to_f32()))
.collect();
indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let top_indices: Vec<(usize, f32)> = indices[..top_k].to_vec();
// 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter]
let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt);
// Softmax over top-k logits
let max_val = top_indices.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = top_indices.iter().map(|x| (x.1 - max_val).exp()).sum();
let weights: Vec<f32> = top_indices.iter()
.map(|x| (x.1 - max_val).exp() / exp_sum)
.collect();
// 5. Bias add: gate_up += expert_gate_up_bias (in-place)
xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias);
// Fresh GPU upload of token data — immune to cached allocator buffer reuse
let token_slice = &x_data[t * hidden..(t + 1) * hidden];
let token_tensor = Tensor::from_slice(token_slice, &[1, hidden]).to_device(x.device());
// 6. GLU activation: treat [E * tokens, 2*inter] → [E * tokens, inter]
let inter2 = gate_up.shape()[2];
let flat_rows = local_experts * num_tokens;
let gate_up_flat = gate_up.reshape(&[flat_rows, inter2]);
let activated_flat = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit);
let inter = inter2 / 2;
let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]);
// 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden]
let down = xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt);
for (k_idx, &(expert_id, _)) in top_indices.iter().enumerate() {
// Only process experts owned by this rank
if expert_id < expert_start || expert_id >= expert_start + local_experts {
continue;
}
let local_id = expert_id - expert_start;
let weight = weights[k_idx];
// 8. Bias add: down += expert_down_bias (in-place)
xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias);
let gate_up_raw = matmul_2d(&token_tensor, &layer.expert_gate_up_wt[local_id]);
let gate_up = add_bias(&gate_up_raw, &layer.expert_gate_up_bias[local_id]);
let activated = gpt_oss_glu(&gate_up, layer.glu_alpha, layer.glu_limit);
let down_raw = matmul_2d(&activated, &layer.expert_down_wt[local_id]);
let down = add_bias(&down_raw, &layer.expert_down_bias[local_id]);
let down_cpu = down.to_device(Device::Cpu);
let down_data = down_cpu.as_slice::<bf16>();
for d in 0..hidden {
output_acc[t * hidden + d] += weight * down_data[d].to_f32();
}
}
}
// Convert accumulated output to BF16 tensor on GPU
let output_bf16: Vec<bf16> = output_acc.iter().map(|&v| bf16::from_f32(v)).collect();
let moe_out = Tensor::from_slice(&output_bf16, &[num_tokens, hidden]).to_device(x.device());
// 9. Weighted sum across experts → [tokens, hidden]
let moe_out = xserv_kernels::moe::moe_weighted_sum(
&down, &topk_ids, &topk_weights,
expert_start, local_experts, top_k,
);
self.all_reduce(&moe_out);
moe_out
@@ -495,8 +516,15 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
let cols = x.shape()[1];
assert_eq!(bias.shape()[0], cols, "bias size {} != cols {}", bias.shape()[0], cols);
// Broadcast bias to each row using GPU kernels.
// Tile bias [cols] into [rows, cols] by repeating rows, then add element-wise.
let x_c = x.contiguous();
if rows == 1 {
// Fast path: reshape bias [cols] → [1, cols] (zero-copy), add directly on GPU
let bias_2d = bias.reshape(&[1, cols]);
return xserv_kernels::add(&x_c, &bias_2d);
}
// General path: tile bias to [rows, cols] via CPU, then add on GPU
let bias_cpu = bias.to_device(Device::Cpu);
let bias_data = bias_cpu.as_slice::<bf16>();
let mut tiled = Vec::with_capacity(rows * cols);
@@ -504,7 +532,6 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor {
tiled.extend_from_slice(bias_data);
}
let bias_tiled = Tensor::from_slice(&tiled, &[rows, cols]).to_device(x.device());
let x_c = x.contiguous();
xserv_kernels::add(&x_c, &bias_tiled)
}
@@ -554,23 +581,23 @@ fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor {
Tensor::from_slice(&shard, &[local])
}
/// Extract expert `e` from a [num_experts, rows, cols] 3D tensor → [rows, cols] 2D
fn slice_expert_3d(t: &Tensor, e: usize, rows: usize, cols: usize) -> Tensor {
/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor
fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor {
assert_eq!(t.ndim(), 3);
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
let stride = rows * cols;
let start = e * stride;
let slice = data[start..start + stride].to_vec();
Tensor::from_slice(&slice, &[rows, cols])
let offset = start * stride;
let slice = data[offset..offset + count * stride].to_vec();
Tensor::from_slice(&slice, &[count, rows, cols])
}
/// Extract expert `e` from a [num_experts, dim] 2D tensor → [dim] 1D
fn slice_expert_2d(t: &Tensor, e: usize, dim: usize) -> Tensor {
/// Extract experts [start..start+count) from a [num_experts, dim] 2D tensor
fn slice_expert_range_2d(t: &Tensor, start: usize, count: usize, dim: usize) -> Tensor {
assert_eq!(t.ndim(), 2);
let host = t.to_device(Device::Cpu);
let data = host.as_slice::<bf16>();
let start = e * dim;
let slice = data[start..start + dim].to_vec();
Tensor::from_slice(&slice, &[dim])
let offset = start * dim;
let slice = data[offset..offset + count * dim].to_vec();
Tensor::from_slice(&slice, &[count, dim])
}

247
csrc/moe/moe_kernels.cu Normal file
View File

@@ -0,0 +1,247 @@
#include <cuda_bf16.h>
#include <float.h>
#include "../common.cuh"
// ============================================================
// MoE Top-K + Softmax kernel
//
// Input: router_logits [num_tokens, num_experts] BF16
// Output: topk_ids [num_tokens, top_k] int32
// topk_weights [num_tokens, top_k] float32
//
// One block per token (token = blockIdx.x). All threads of the block
// cooperate only on loading the token's logits into shared memory as
// FP32; thread 0 alone then selects the top-k by repeated argmax and
// computes the softmax over the k winners.
//
// Limits: num_experts <= MOE_MAX_EXPERTS and top_k <= MOE_MAX_TOPK.
// These size the shared arrays below and are NOT checked inside the
// kernel, so callers must enforce them.
// ============================================================
#define MOE_MAX_EXPERTS 256
#define MOE_MAX_TOPK 8
__global__ void moe_topk_softmax_bf16_kernel(
const __nv_bfloat16* __restrict__ router_logits,
int* __restrict__ topk_ids,
float* __restrict__ topk_weights,
int num_experts, int top_k
) {
int token = blockIdx.x;
int tid = threadIdx.x;
const __nv_bfloat16* row = router_logits + token * num_experts;
// Shared scratch: FP32 copy of this token's logits, plus the selected
// ids and values (later overwritten with the softmax weights).
__shared__ float smem_logits[MOE_MAX_EXPERTS];
__shared__ int smem_ids[MOE_MAX_TOPK];
__shared__ float smem_vals[MOE_MAX_TOPK];
// Block-strided cooperative load, converting BF16 -> FP32.
for (int i = tid; i < num_experts; i += blockDim.x) {
smem_logits[i] = __bfloat162float(row[i]);
}
// All logits must be visible before thread 0 starts scanning them.
__syncthreads();
// Find top-k via repeated argmax (k is small, typically 4)
if (tid == 0) {
for (int k = 0; k < top_k; k++) {
float best_val = -INFINITY;
int best_idx = 0;
// Strict '>' while scanning in ascending order: ties go to the lowest index.
for (int e = 0; e < num_experts; e++) {
if (smem_logits[e] > best_val) {
best_val = smem_logits[e];
best_idx = e;
}
}
smem_ids[k] = best_idx;
smem_vals[k] = best_val;
smem_logits[best_idx] = -INFINITY; // mask out selected
}
// Softmax over top-k values (in FP32); subtract the max before expf.
float max_val = smem_vals[0];
for (int k = 1; k < top_k; k++)
max_val = fmaxf(max_val, smem_vals[k]);
float exp_sum = 0.0f;
for (int k = 0; k < top_k; k++) {
smem_vals[k] = expf(smem_vals[k] - max_val);
exp_sum += smem_vals[k];
}
float inv_sum = 1.0f / exp_sum;
for (int k = 0; k < top_k; k++)
smem_vals[k] *= inv_sum;
// Write outputs: ids in selection order (largest logit first) with
// their normalized weights.
for (int k = 0; k < top_k; k++) {
topk_ids[token * top_k + k] = smem_ids[k];
topk_weights[token * top_k + k] = smem_vals[k];
}
}
}
// ============================================================
// MoE Replicate kernel
//
// Input:  x     [num_tokens, hidden] BF16
// Output: x_rep [local_experts, num_tokens, hidden] BF16
//
// Copies x into each expert's batch slot. One thread per output
// element; all indices are computed in 32-bit int, so
// local_experts * num_tokens * hidden must fit in int.
// ============================================================
__global__ void moe_replicate_bf16_kernel(
    const __nv_bfloat16* __restrict__ x,
    __nv_bfloat16* __restrict__ x_rep,
    int num_tokens, int hidden, int local_experts
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = local_experts * num_tokens * hidden;
    if (idx >= total) return;

    // x_rep[expert, token, dim] = x[token, dim]. The expert index is not
    // needed: the offset inside the expert slot is the flat index into x.
    x_rep[idx] = x[idx % (num_tokens * hidden)];
}
// ============================================================
// MoE Bias Add 3D kernel
//
// In/out: x    [batch, num_tokens, dim] BF16 (updated in place)
// Input:  bias [batch, dim] BF16
//
// x[b, t, d] += bias[b, d], with the addition done in FP32 and the
// result stored back as BF16. One thread per element of x.
// ============================================================
__global__ void moe_bias_add_3d_bf16_kernel(
    __nv_bfloat16* __restrict__ x,
    const __nv_bfloat16* __restrict__ bias,
    int batch, int num_tokens, int dim
) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= batch * num_tokens * dim) return;

    // Batch entry and column of this element; the token index is not needed
    // because the bias is shared by every token of the batch entry.
    const int slab = num_tokens * dim;
    const int entry = gid / slab;
    const int col = gid % dim;

    const float biased = __bfloat162float(x[gid]) + __bfloat162float(bias[entry * dim + col]);
    x[gid] = __float2bfloat16(biased);
}
// ============================================================
// MoE Weighted Sum kernel
//
// Input:  expert_out   [local_experts, num_tokens, hidden] BF16
//         topk_ids     [num_tokens, top_k] int32 (global expert ids)
//         topk_weights [num_tokens, top_k] float32
//         expert_start:  first global expert id this rank owns
//         local_experts: number of experts this rank owns
//
// Output: out [num_tokens, hidden] BF16
//
// One thread per (token, dim) output element. Each thread walks the
// token's top_k routing slots, skips experts that are not owned by
// this rank, and accumulates weight * expert_out in FP32 before
// rounding the total to BF16.
// ============================================================
__global__ void moe_weighted_sum_bf16_kernel(
    const __nv_bfloat16* __restrict__ expert_out,
    const int* __restrict__ topk_ids,
    const float* __restrict__ topk_weights,
    __nv_bfloat16* __restrict__ out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts
) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= num_tokens * hidden) return;

    const int token = gid / hidden;
    const int dim = gid % hidden;
    const int per_expert = num_tokens * hidden;  // elements per expert slab
    const int route_base = token * top_k;        // first routing slot of this token

    float acc = 0.0f;
    for (int slot = 0; slot < top_k; ++slot) {
        const int local_id = topk_ids[route_base + slot] - expert_start;
        // Experts owned by other ranks contribute nothing here.
        if (local_id < 0 || local_id >= local_experts) continue;
        const float w = topk_weights[route_base + slot];
        const float v = __bfloat162float(expert_out[local_id * per_expert + token * hidden + dim]);
        acc += w * v;
    }
    out[gid] = __float2bfloat16(acc);
}
// C-linkage launchers for the MoE kernels. Each one casts the untyped device
// pointers to their element types, picks a launch configuration, launches on
// `stream`, and checks the launch with CUDA_CHECK_LAST_ERROR().
//
// Empty inputs return early: a grid with zero blocks is an invalid launch
// configuration, and there is nothing to compute anyway. Element counts are
// computed in int and are assumed to fit (the kernels index with int).
extern "C" {

// One 128-thread block per token.
void launch_moe_topk_softmax_bf16(
    const void* router_logits,
    void* topk_ids, void* topk_weights,
    int num_tokens, int num_experts, int top_k,
    void* stream
) {
    if (num_tokens <= 0 || top_k <= 0) return;  // no output elements
    int block = 128;
    moe_topk_softmax_bf16_kernel<<<num_tokens, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)router_logits,
        (int*)topk_ids, (float*)topk_weights,
        num_experts, top_k
    );
    CUDA_CHECK_LAST_ERROR();
}

// One thread per element of x_rep, 256 threads per block.
void launch_moe_replicate_bf16(
    const void* x, void* x_rep,
    int num_tokens, int hidden, int local_experts,
    void* stream
) {
    int total = local_experts * num_tokens * hidden;
    if (total <= 0) return;  // empty output
    int block = 256;
    int grid = (total + block - 1) / block;
    moe_replicate_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep,
        num_tokens, hidden, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}

// One thread per element of x, 256 threads per block.
void launch_moe_bias_add_3d_bf16(
    void* x, const void* bias,
    int batch, int num_tokens, int dim,
    void* stream
) {
    int total = batch * num_tokens * dim;
    if (total <= 0) return;  // nothing to update
    int block = 256;
    int grid = (total + block - 1) / block;
    moe_bias_add_3d_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (__nv_bfloat16*)x, (const __nv_bfloat16*)bias,
        batch, num_tokens, dim
    );
    CUDA_CHECK_LAST_ERROR();
}

// One thread per element of out, 256 threads per block.
void launch_moe_weighted_sum_bf16(
    const void* expert_out,
    const void* topk_ids, const void* topk_weights,
    void* out,
    int num_tokens, int hidden, int top_k,
    int expert_start, int local_experts,
    void* stream
) {
    int total = num_tokens * hidden;
    if (total <= 0) return;  // empty output
    int block = 256;
    int grid = (total + block - 1) / block;
    moe_weighted_sum_bf16_kernel<<<grid, block, 0, (cudaStream_t)stream>>>(
        (const __nv_bfloat16*)expert_out,
        (const int*)topk_ids, (const float*)topk_weights,
        (__nv_bfloat16*)out,
        num_tokens, hidden, top_k,
        expert_start, local_experts
    );
    CUDA_CHECK_LAST_ERROR();
}

}