diff --git a/crates/xserv-kernels/build.rs b/crates/xserv-kernels/build.rs index 6d7ab3e..4512550 100644 --- a/crates/xserv-kernels/build.rs +++ b/crates/xserv-kernels/build.rs @@ -29,6 +29,7 @@ fn main() { .file("../../csrc/attention/flash_attention.cu") .file("../../csrc/attention/paged_attention.cu") .file("../../csrc/attention/reshape_and_cache.cu") + .file("../../csrc/moe/moe_kernels.cu") .compile("xserv_kernels"); println!("cargo:rerun-if-changed=../../csrc/"); diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index 1b1fd80..ee19356 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -5,6 +5,7 @@ pub mod dispatch; pub mod embedding; pub mod gemm; pub mod layernorm; +pub mod moe; pub mod rmsnorm; pub mod rope; pub mod softmax; diff --git a/crates/xserv-kernels/src/moe.rs b/crates/xserv-kernels/src/moe.rs new file mode 100644 index 0000000..5e3d04d --- /dev/null +++ b/crates/xserv-kernels/src/moe.rs @@ -0,0 +1,223 @@ +use std::ffi::c_void; +use xserv_tensor::{DType, Device, Tensor}; + +use crate::gemm::{cublas_handle, CublasHandle}; + +unsafe extern "C" { + fn launch_moe_topk_softmax_bf16( + router_logits: *const c_void, + topk_ids: *mut c_void, topk_weights: *mut c_void, + num_tokens: i32, num_experts: i32, top_k: i32, + stream: *mut c_void, + ); + fn launch_moe_replicate_bf16( + x: *const c_void, x_rep: *mut c_void, + num_tokens: i32, hidden: i32, local_experts: i32, + stream: *mut c_void, + ); + fn launch_moe_bias_add_3d_bf16( + x: *mut c_void, bias: *const c_void, + batch: i32, num_tokens: i32, dim: i32, + stream: *mut c_void, + ); + fn launch_moe_weighted_sum_bf16( + expert_out: *const c_void, + topk_ids: *const c_void, topk_weights: *const c_void, + out: *mut c_void, + num_tokens: i32, hidden: i32, top_k: i32, + expert_start: i32, local_experts: i32, + stream: *mut c_void, + ); + + fn cublasGemmStridedBatchedEx( + handle: CublasHandle, + transa: i32, transb: i32, + m: i32, n: i32, k: i32, + alpha: *const c_void, + a: *const c_void, a_type: i32, lda: i32, stride_a: i64, + b: *const c_void, b_type: i32, ldb: i32, stride_b: i64, + beta: *const c_void, + c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64, + batch_count: i32, + compute_type: i32, + algo: i32, + ) -> i32; + + fn cublasSetStream_v2(handle: CublasHandle, stream: *mut c_void) -> i32; +} + +const CUDA_R_16BF: i32 = 14; +const CUBLAS_COMPUTE_32F: i32 = 68; +const CUBLAS_GEMM_DEFAULT: i32 = -1; + +/// GPU top-k selection + softmax over router logits. +/// +/// Input: router_logits [num_tokens, num_experts] BF16 on GPU +/// Output: (topk_ids [num_tokens, top_k] i32, topk_weights [num_tokens, top_k] f32) +pub fn moe_topk_softmax( + router_logits: &Tensor, + num_experts: usize, + top_k: usize, +) -> (Tensor, Tensor) { + assert_eq!(router_logits.ndim(), 2); + assert_eq!(router_logits.dtype(), DType::BF16); + assert!(router_logits.is_contiguous()); + let num_tokens = router_logits.shape()[0]; + assert_eq!(router_logits.shape()[1], num_experts); + + let topk_ids = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device()); + let topk_weights = Tensor::empty(&[num_tokens, top_k], DType::F32, router_logits.device()); + + unsafe { + launch_moe_topk_softmax_bf16( + router_logits.data_ptr() as *const c_void, + topk_ids.data_ptr() as *mut c_void, + topk_weights.data_ptr() as *mut c_void, + num_tokens as i32, num_experts as i32, top_k as i32, + std::ptr::null_mut(), + ); + } + + (topk_ids, topk_weights) +} + +/// Replicate x [num_tokens, hidden] → [local_experts, num_tokens, hidden]. +pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor { + assert_eq!(x.ndim(), 2); + assert_eq!(x.dtype(), DType::BF16); + assert!(x.is_contiguous()); + let num_tokens = x.shape()[0]; + let hidden = x.shape()[1]; + let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device()); + + unsafe { + launch_moe_replicate_bf16( + x.data_ptr() as *const c_void, + out.data_ptr() as *mut c_void, + num_tokens as i32, hidden as i32, local_experts as i32, + std::ptr::null_mut(), + ); + } + + out +} + +/// In-place 3D bias add: x [batch, num_tokens, dim] += bias [batch, dim]. +pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) { + assert_eq!(x.ndim(), 3); + assert_eq!(bias.ndim(), 2); + assert_eq!(x.dtype(), DType::BF16); + assert!(x.is_contiguous()); + let batch = x.shape()[0]; + let num_tokens = x.shape()[1]; + let dim = x.shape()[2]; + assert_eq!(bias.shape(), &[batch, dim]); + + unsafe { + launch_moe_bias_add_3d_bf16( + x.data_ptr() as *mut c_void, + bias.data_ptr() as *const c_void, + batch as i32, num_tokens as i32, dim as i32, + std::ptr::null_mut(), + ); + } +} + +/// Weighted sum of expert outputs → [num_tokens, hidden]. +/// +/// expert_out: [local_experts, num_tokens, hidden] BF16 +/// topk_ids: [num_tokens, top_k] i32 (global expert indices) +/// topk_weights: [num_tokens, top_k] f32 +pub fn moe_weighted_sum( + expert_out: &Tensor, + topk_ids: &Tensor, + topk_weights: &Tensor, + expert_start: usize, + local_experts: usize, + top_k: usize, +) -> Tensor { + assert_eq!(expert_out.ndim(), 3); + assert_eq!(expert_out.dtype(), DType::BF16); + let num_tokens = expert_out.shape()[1]; + let hidden = expert_out.shape()[2]; + + let out = Tensor::empty(&[num_tokens, hidden], DType::BF16, expert_out.device()); + + unsafe { + launch_moe_weighted_sum_bf16( + expert_out.data_ptr() as *const c_void, + topk_ids.data_ptr() as *const c_void, + topk_weights.data_ptr() as *const c_void, + out.data_ptr() as *mut c_void, + num_tokens as i32, hidden as i32, top_k as i32, + expert_start as i32, local_experts as i32, + std::ptr::null_mut(), + ); + } + + out +} + +/// Strided batched GEMM for MoE expert forward. +/// C[b] = A[b] @ B[b] for b in 0..batch +/// +/// A: [batch, M, K] BF16 contiguous +/// B: [batch, K, N] BF16 contiguous +/// Returns C: [batch, M, N] BF16 +#[allow(clippy::too_many_arguments)] +pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor { + assert_eq!(a.ndim(), 3); + assert_eq!(b.ndim(), 3); + assert_eq!(a.dtype(), DType::BF16); + assert_eq!(b.dtype(), DType::BF16); + assert!(a.is_contiguous() && b.is_contiguous()); + assert_eq!(a.shape()[0], b.shape()[0]); + assert_eq!(a.shape()[2], b.shape()[1]); + + let batch = a.shape()[0]; + let m = a.shape()[1]; + let k = a.shape()[2]; + let n = b.shape()[2]; + + let c = Tensor::empty(&[batch, m, n], DType::BF16, a.device()); + + let alpha: f32 = 1.0; + let beta: f32 = 0.0; + + // cuBLAS column-major: we compute C^T = B^T @ A^T + // A is [batch, M, K] row-major → A^T is [K, M] col-major, lda=K + // B is [batch, K, N] row-major → B^T is [N, K] col-major, ldb=N? No... + // + // Actually for row-major: A[M,K] in memory = col-major A^T[K,M] with lda=K. + // So we call cublasGemmStridedBatchedEx with: + // transa=N, transb=N + // m=N, n=M, k=K (because cuBLAS sees col-major) + // A_cublas = B_row (pointer), lda=N + // B_cublas = A_row (pointer), ldb=K + // C_cublas = C_row (pointer), ldc=N + + let stride_a = (m * k) as i64; + let stride_b = (k * n) as i64; + let stride_c = (m * n) as i64; + + let handle = cublas_handle(); + unsafe { + cublasSetStream_v2(handle, std::ptr::null_mut()); + let status = cublasGemmStridedBatchedEx( + handle, + 0, 0, // CUBLAS_OP_N, CUBLAS_OP_N + n as i32, m as i32, k as i32, + &alpha as *const f32 as *const c_void, + b.data_ptr() as *const c_void, CUDA_R_16BF, n as i32, stride_b, + a.data_ptr() as *const c_void, CUDA_R_16BF, k as i32, stride_a, + &beta as *const f32 as *const c_void, + c.data_ptr() as *mut c_void, CUDA_R_16BF, n as i32, stride_c, + batch as i32, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT, + ); + assert_eq!(status, 0, "cublasGemmStridedBatchedEx failed: {status}"); + } + + c +} diff --git a/crates/xserv-model/src/config.rs b/crates/xserv-model/src/config.rs index 927d618..1225156 100644 --- a/crates/xserv-model/src/config.rs +++ b/crates/xserv-model/src/config.rs @@ -73,6 +73,10 @@ pub struct ModelConfig { pub rope_scaling: Option, #[serde(default)] pub swiglu_limit: Option, + #[serde(default)] + pub geglu_alpha: Option, + #[serde(default)] + pub hidden_act: Option, } impl ModelConfig { @@ -144,4 +148,8 @@ impl ModelConfig { pub fn window_size(&self) -> usize { self.sliding_window.unwrap_or(0) } + + pub fn geglu_alpha(&self) -> f32 { + self.geglu_alpha.unwrap_or(1.702) as f32 + } } diff --git a/crates/xserv-model/src/gpt_oss.rs b/crates/xserv-model/src/gpt_oss.rs index ce1d007..ac15705 100644 --- a/crates/xserv-model/src/gpt_oss.rs +++ b/crates/xserv-model/src/gpt_oss.rs @@ -12,15 +12,18 @@ pub struct GptOss { embed_tokens: Tensor, layers: Vec, norm: Tensor, + norm_bias: Option, lm_head_t: Tensor, rope_cache: RopeCache, tp: Option>, local_num_heads: usize, local_num_kv_heads: usize, + has_norm_bias: bool, } struct GptOssBlock { input_norm: Tensor, + input_norm_bias: Option, // Attention (with bias) q_proj_wt: Tensor, q_proj_bias: Tensor, @@ -36,12 +39,15 @@ struct GptOssBlock { window_size: usize, // MoE MLP post_norm: Tensor, + post_norm_bias: Option, router_wt: Tensor, router_bias: Tensor, - expert_gate_up_wt: Vec, - expert_gate_up_bias: Vec, - expert_down_wt: Vec, - expert_down_bias: Vec, + // 3D expert weights for batched GEMM (contiguous on GPU) + expert_gate_up_wt: Tensor, // [local_experts, hidden, 2*inter] + expert_gate_up_bias: Tensor, // [local_experts, 2*inter] + expert_down_wt: Tensor, // [local_experts, inter, hidden] + expert_down_bias: Tensor, // [local_experts, hidden] + local_experts: usize, // Activation params glu_alpha: f32, glu_limit: f32, @@ -80,6 +86,7 @@ impl GptOss { let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight")); let norm = repl(take(&mut w, "model.norm.weight")); + let norm_bias = w.remove("model.norm.bias").map(|t| repl(t)); let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous(); let head_dim = config.head_dim(); @@ -106,13 +113,13 @@ impl GptOss { let num_layers = config.num_layers(); let num_experts = config.num_experts(); - let glu_alpha = 1.702f32; + let glu_alpha = config.geglu_alpha(); let glu_limit = config.swiglu_limit.unwrap_or(7.0) as f32; let mut layers = Vec::with_capacity(num_layers); if rank == 0 { eprintln!( - "Loading gpt-oss weights: {} layers, {} experts, world={world}...", + "Loading gpt-oss weights: {} layers, {} experts, world={world}, glu_alpha={glu_alpha}...", num_layers, num_experts ); } @@ -152,34 +159,26 @@ impl GptOss { let local_experts = num_experts / world; let expert_start = rank * local_experts; - let mut expert_gate_up_wt = Vec::with_capacity(local_experts); - let mut expert_gate_up_bias = Vec::with_capacity(local_experts); - let mut expert_down_wt = Vec::with_capacity(local_experts); - let mut expert_down_bias = Vec::with_capacity(local_experts); - let inter2 = gate_up_3d.shape()[2]; // 2 * intermediate_size let hidden = gate_up_3d.shape()[1]; let inter = down_3d.shape()[1]; // intermediate_size - for local_e in 0..local_experts { - let e = expert_start + local_e; - let gu_slice = slice_expert_3d(&gate_up_3d, e, hidden, inter2); - expert_gate_up_wt.push(gu_slice.to_device(dev)); - - let gu_bias = slice_expert_2d(&gate_up_bias_2d, e, inter2); - expert_gate_up_bias.push(gu_bias.to_device(dev)); - - let d_slice = slice_expert_3d(&down_3d, e, inter, hidden); - expert_down_wt.push(d_slice.to_device(dev)); - - let d_bias = slice_expert_2d(&down_bias_2d, e, hidden); - expert_down_bias.push(d_bias.to_device(dev)); - } + // Slice the rank's range of experts as contiguous 3D tensors on GPU + let expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev); + let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev); + let expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev); + let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev); xserv_cuda::allocator::cached_trim(); + let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight"))); + let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t)); + let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))); + let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t)); + layers.push(GptOssBlock { - input_norm: repl(take(&mut w, &format!("{p}.input_layernorm.weight"))), + input_norm, + input_norm_bias, q_proj_wt, q_proj_bias, k_proj_wt, @@ -191,13 +190,15 @@ impl GptOss { sinks, is_sliding, window_size, - post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))), + post_norm, + post_norm_bias, router_wt, router_bias, expert_gate_up_wt, expert_gate_up_bias, expert_down_wt, expert_down_bias, + local_experts, glu_alpha, glu_limit, }); @@ -206,16 +207,35 @@ impl GptOss { let local_num_heads = config.num_heads() / world; let local_num_kv_heads = config.num_kv_heads() / world; + let has_norm_bias = norm_bias.is_some(); + if rank == 0 { + if has_norm_bias { + eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm"); + } + } + + // Warn about unused weights that the model didn't consume + if rank == 0 && !w.is_empty() { + eprintln!("WARNING: {} unused weight(s) in model:", w.len()); + let mut keys: Vec<_> = w.keys().collect(); + keys.sort(); + for k in &keys { + eprintln!(" {k}"); + } + } + Self { config, embed_tokens, layers, norm, + norm_bias, lm_head_t, rope_cache, tp, local_num_heads, local_num_kv_heads, + has_norm_bias, } } @@ -229,6 +249,34 @@ impl GptOss { } } + #[inline] + fn norm(x: &Tensor, weight: &Tensor, bias: &Option, eps: f32) -> Tensor { + match bias { + Some(b) => layernorm(x, weight, b, eps), + None => rmsnorm(x, weight, eps), + } + } + + #[inline] + fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option, eps: f32) -> (Tensor, Tensor) { + match bias { + Some(b) => { + let sum = xserv_kernels::add(x, residual); + let normed = layernorm(&sum, weight, b, eps); + (normed, sum) + } + None => add_rmsnorm(x, residual, weight, eps), + } + } + + fn norm_eps(&self) -> f32 { + if self.has_norm_bias { + self.config.ln_eps() + } else { + self.config.rms_norm_eps.unwrap_or(1e-5) as f32 + } + } + /// Paged decode: process one token per sequence using paged KV cache. pub fn forward_decode_paged( &self, @@ -245,7 +293,7 @@ impl GptOss { let num_heads = self.local_num_heads; let num_kv_heads = self.local_num_kv_heads; let head_dim = self.config.head_dim(); - let eps = self.config.rms_norm_eps.unwrap_or(1e-5) as f32; + let eps = self.norm_eps(); let kv_lens: Vec = positions.iter().map(|&p| (p + 1) as i32).collect(); for (b, &slot) in seq_slots.iter().enumerate() { @@ -263,7 +311,7 @@ impl GptOss { for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); - let normed = rmsnorm(&x, &layer.input_norm, eps); + let normed = Self::norm(&x, &layer.input_norm, &layer.input_norm_bias, eps); // Q/K/V projections with bias let q_all = add_bias(&matmul_2d(&normed, &layer.q_proj_wt), &layer.q_proj_bias); @@ -304,13 +352,11 @@ impl GptOss { // Residual + post-norm - let (normed, x_new) = add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); - + let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps); let residual = x_new; let normed = normed.contiguous(); - // MoE MLP let moe_out = self.moe_forward(&normed, layer, batch); x = xserv_kernels::add(&residual, &moe_out); @@ -322,7 +368,7 @@ impl GptOss { } unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } - let x = rmsnorm(&x, &self.norm, eps); + let x = Self::norm(&x, &self.norm, &self.norm_bias, eps); unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } let logits = matmul_2d(&x, &self.lm_head_t); unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } @@ -341,7 +387,7 @@ impl GptOss { let num_heads = self.local_num_heads; let num_kv_heads = self.local_num_kv_heads; let head_dim = self.config.head_dim(); - let eps = self.config.rms_norm_eps.unwrap_or(1e-5) as f32; + let eps = self.norm_eps(); paged_cache.ensure_capacity(slot, pos_offset + new_tokens); paged_cache.advance_seq_len(slot, new_tokens); @@ -351,7 +397,7 @@ impl GptOss { for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); - let normed = rmsnorm(&x, &layer.input_norm, eps); + let normed = Self::norm(&x, &layer.input_norm, &layer.input_norm_bias, eps); let q = add_bias(&matmul_2d(&normed, &layer.q_proj_wt), &layer.q_proj_bias); let k = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias); @@ -381,7 +427,7 @@ impl GptOss { self.all_reduce(&attn_proj); let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias); - let (normed, x_new) = add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps); let residual = x_new; // MoE MLP @@ -389,90 +435,65 @@ impl GptOss { x = xserv_kernels::add(&residual, &moe_out); } - let x = rmsnorm(&x, &self.norm, eps); + let x = Self::norm(&x, &self.norm, &self.norm_bias, eps); let logits = matmul_2d(&x, &self.lm_head_t); unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } logits } - /// MoE forward pass for one layer with expert parallelism. - /// Each rank owns `num_experts / world` experts. Tokens routed to non-local - /// experts get zero contribution from this rank; AllReduce sums all ranks. - /// Input: [tokens, hidden], Output: [tokens, hidden] + /// MoE forward pass — fully on GPU via batched GEMM. + /// + /// Each rank owns `local_experts` experts. The input is replicated across all + /// local experts, processed via two batched cuBLAS GEMMs (gate_up and down), + /// and the selected experts' outputs are weighted-summed on GPU. Non-selected + /// experts contribute zero (via the routing weights), so no scatter/gather is + /// needed. AllReduce sums partial results across TP ranks. fn moe_forward(&self, x: &Tensor, layer: &GptOssBlock, num_tokens: usize) -> Tensor { - let hidden = self.config.hidden(); let num_experts = self.config.num_experts(); let top_k = self.config.experts_per_token(); - let world = self.tp.as_ref().map(|tp| tp.world).unwrap_or(1); let rank = self.tp.as_ref().map(|tp| tp.rank).unwrap_or(0); - let local_experts = num_experts / world; + let local_experts = layer.local_experts; let expert_start = rank * local_experts; - // Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts] - unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } + // 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts] let router_logits = add_bias( &matmul_2d(x, &layer.router_wt), &layer.router_bias, ); - unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } - let router_cpu = router_logits.to_device(Device::Cpu); - let router_data = router_cpu.as_slice::(); - let x_cpu = x.to_device(Device::Cpu); - let x_data = x_cpu.as_slice::(); - let mut output_acc = vec![0.0f32; num_tokens * hidden]; + // 2. GPU top-k + softmax + let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax( + &router_logits, num_experts, top_k, + ); - for t in 0..num_tokens { - let row = &router_data[t * num_experts..(t + 1) * num_experts]; + // 3. Replicate input: [tokens, hidden] → [local_experts, tokens, hidden] + let x_rep = xserv_kernels::moe::moe_replicate(x, local_experts); - // Find top-k expert indices (global) - let mut indices: Vec<(usize, f32)> = row.iter() - .enumerate() - .map(|(i, &v)| (i, v.to_f32())) - .collect(); - indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - let top_indices: Vec<(usize, f32)> = indices[..top_k].to_vec(); + // 4. Batched GEMM gate_up: [E, tokens, hidden] @ [E, hidden, 2*inter] → [E, tokens, 2*inter] + let gate_up = xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt); - // Softmax over top-k logits - let max_val = top_indices.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max); - let exp_sum: f32 = top_indices.iter().map(|x| (x.1 - max_val).exp()).sum(); - let weights: Vec = top_indices.iter() - .map(|x| (x.1 - max_val).exp() / exp_sum) - .collect(); + // 5. Bias add: gate_up += expert_gate_up_bias (in-place) + xserv_kernels::moe::moe_bias_add_3d(&gate_up, &layer.expert_gate_up_bias); - // Fresh GPU upload of token data — immune to cached allocator buffer reuse - let token_slice = &x_data[t * hidden..(t + 1) * hidden]; - let token_tensor = Tensor::from_slice(token_slice, &[1, hidden]).to_device(x.device()); + // 6. GLU activation: treat [E * tokens, 2*inter] → [E * tokens, inter] + let inter2 = gate_up.shape()[2]; + let flat_rows = local_experts * num_tokens; + let gate_up_flat = gate_up.reshape(&[flat_rows, inter2]); + let activated_flat = gpt_oss_glu(&gate_up_flat, layer.glu_alpha, layer.glu_limit); + let inter = inter2 / 2; + let activated = activated_flat.reshape(&[local_experts, num_tokens, inter]); + // 7. Batched GEMM down: [E, tokens, inter] @ [E, inter, hidden] → [E, tokens, hidden] + let down = xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt); - for (k_idx, &(expert_id, _)) in top_indices.iter().enumerate() { - // Only process experts owned by this rank - if expert_id < expert_start || expert_id >= expert_start + local_experts { - continue; - } - let local_id = expert_id - expert_start; - let weight = weights[k_idx]; + // 8. Bias add: down += expert_down_bias (in-place) + xserv_kernels::moe::moe_bias_add_3d(&down, &layer.expert_down_bias); - let gate_up_raw = matmul_2d(&token_tensor, &layer.expert_gate_up_wt[local_id]); - let gate_up = add_bias(&gate_up_raw, &layer.expert_gate_up_bias[local_id]); - - let activated = gpt_oss_glu(&gate_up, layer.glu_alpha, layer.glu_limit); - - let down_raw = matmul_2d(&activated, &layer.expert_down_wt[local_id]); - let down = add_bias(&down_raw, &layer.expert_down_bias[local_id]); - - - let down_cpu = down.to_device(Device::Cpu); - let down_data = down_cpu.as_slice::(); - for d in 0..hidden { - output_acc[t * hidden + d] += weight * down_data[d].to_f32(); - } - } - } - - // Convert accumulated output to BF16 tensor on GPU - let output_bf16: Vec = output_acc.iter().map(|&v| bf16::from_f32(v)).collect(); - let moe_out = Tensor::from_slice(&output_bf16, &[num_tokens, hidden]).to_device(x.device()); + // 9. Weighted sum across experts → [tokens, hidden] + let moe_out = xserv_kernels::moe::moe_weighted_sum( + &down, &topk_ids, &topk_weights, + expert_start, local_experts, top_k, + ); self.all_reduce(&moe_out); moe_out @@ -495,8 +516,15 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor { let cols = x.shape()[1]; assert_eq!(bias.shape()[0], cols, "bias size {} != cols {}", bias.shape()[0], cols); - // Broadcast bias to each row using GPU kernels. - // Tile bias [cols] into [rows, cols] by repeating rows, then add element-wise. + let x_c = x.contiguous(); + + if rows == 1 { + // Fast path: reshape bias [cols] → [1, cols] (zero-copy), add directly on GPU + let bias_2d = bias.reshape(&[1, cols]); + return xserv_kernels::add(&x_c, &bias_2d); + } + + // General path: tile bias to [rows, cols] via CPU, then add on GPU let bias_cpu = bias.to_device(Device::Cpu); let bias_data = bias_cpu.as_slice::(); let mut tiled = Vec::with_capacity(rows * cols); @@ -504,7 +532,6 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor { tiled.extend_from_slice(bias_data); } let bias_tiled = Tensor::from_slice(&tiled, &[rows, cols]).to_device(x.device()); - let x_c = x.contiguous(); xserv_kernels::add(&x_c, &bias_tiled) } @@ -554,23 +581,23 @@ fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor { Tensor::from_slice(&shard, &[local]) } -/// Extract expert `e` from a [num_experts, rows, cols] 3D tensor → [rows, cols] 2D -fn slice_expert_3d(t: &Tensor, e: usize, rows: usize, cols: usize) -> Tensor { +/// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor +fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor { assert_eq!(t.ndim(), 3); let host = t.to_device(Device::Cpu); let data = host.as_slice::(); let stride = rows * cols; - let start = e * stride; - let slice = data[start..start + stride].to_vec(); - Tensor::from_slice(&slice, &[rows, cols]) + let offset = start * stride; + let slice = data[offset..offset + count * stride].to_vec(); + Tensor::from_slice(&slice, &[count, rows, cols]) } -/// Extract expert `e` from a [num_experts, dim] 2D tensor → [dim] 1D -fn slice_expert_2d(t: &Tensor, e: usize, dim: usize) -> Tensor { +/// Extract experts [start..start+count) from a [num_experts, dim] 2D tensor +fn slice_expert_range_2d(t: &Tensor, start: usize, count: usize, dim: usize) -> Tensor { assert_eq!(t.ndim(), 2); let host = t.to_device(Device::Cpu); let data = host.as_slice::(); - let start = e * dim; - let slice = data[start..start + dim].to_vec(); - Tensor::from_slice(&slice, &[dim]) + let offset = start * dim; + let slice = data[offset..offset + count * dim].to_vec(); + Tensor::from_slice(&slice, &[count, dim]) } diff --git a/csrc/moe/moe_kernels.cu b/csrc/moe/moe_kernels.cu new file mode 100644 index 0000000..9b58552 --- /dev/null +++ b/csrc/moe/moe_kernels.cu @@ -0,0 +1,247 @@ +#include +#include +#include "../common.cuh" + +// ============================================================ +// MoE Top-K + Softmax kernel +// +// Input: router_logits [num_tokens, num_experts] BF16 +// Output: topk_ids [num_tokens, top_k] int32 +// topk_weights [num_tokens, top_k] float32 +// +// One block per token. Threads cooperatively find top-k indices +// via repeated argmax, then compute softmax over the k winners. +// num_experts <= 256 (fits in registers / shared memory). +// ============================================================ + +#define MOE_MAX_EXPERTS 256 +#define MOE_MAX_TOPK 8 + +__global__ void moe_topk_softmax_bf16_kernel( + const __nv_bfloat16* __restrict__ router_logits, + int* __restrict__ topk_ids, + float* __restrict__ topk_weights, + int num_experts, int top_k +) { + int token = blockIdx.x; + int tid = threadIdx.x; + const __nv_bfloat16* row = router_logits + token * num_experts; + + // Load logits into shared memory + __shared__ float smem_logits[MOE_MAX_EXPERTS]; + __shared__ int smem_ids[MOE_MAX_TOPK]; + __shared__ float smem_vals[MOE_MAX_TOPK]; + + for (int i = tid; i < num_experts; i += blockDim.x) { + smem_logits[i] = __bfloat162float(row[i]); + } + __syncthreads(); + + // Find top-k via repeated argmax (k is small, typically 4) + if (tid == 0) { + for (int k = 0; k < top_k; k++) { + float best_val = -INFINITY; + int best_idx = 0; + for (int e = 0; e < num_experts; e++) { + if (smem_logits[e] > best_val) { + best_val = smem_logits[e]; + best_idx = e; + } + } + smem_ids[k] = best_idx; + smem_vals[k] = best_val; + smem_logits[best_idx] = -INFINITY; // mask out selected + } + + // Softmax over top-k values (in FP32) + float max_val = smem_vals[0]; + for (int k = 1; k < top_k; k++) + max_val = fmaxf(max_val, smem_vals[k]); + + float exp_sum = 0.0f; + for (int k = 0; k < top_k; k++) { + smem_vals[k] = expf(smem_vals[k] - max_val); + exp_sum += smem_vals[k]; + } + float inv_sum = 1.0f / exp_sum; + for (int k = 0; k < top_k; k++) + smem_vals[k] *= inv_sum; + + // Write outputs + for (int k = 0; k < top_k; k++) { + topk_ids[token * top_k + k] = smem_ids[k]; + topk_weights[token * top_k + k] = smem_vals[k]; + } + } +} + +// ============================================================ +// MoE Replicate kernel +// +// Input: x [num_tokens, hidden] BF16 +// Output: x_rep [local_experts, num_tokens, hidden] BF16 +// +// Copies x into each expert's batch slot. +// ============================================================ + +__global__ void moe_replicate_bf16_kernel( + const __nv_bfloat16* __restrict__ x, + __nv_bfloat16* __restrict__ x_rep, + int num_tokens, int hidden, int local_experts +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = local_experts * num_tokens * hidden; + if (idx >= total) return; + + int expert = idx / (num_tokens * hidden); + int remainder = idx % (num_tokens * hidden); + // x_rep[expert, token, dim] = x[token, dim] + x_rep[idx] = x[remainder]; + (void)expert; // suppress unused warning +} + +// ============================================================ +// MoE Bias Add 3D kernel +// +// Input: x [batch, num_tokens, dim] BF16 (in-place output) +// bias [batch, dim] BF16 +// +// x[b, t, d] += bias[b, d] +// ============================================================ + +__global__ void moe_bias_add_3d_bf16_kernel( + __nv_bfloat16* __restrict__ x, + const __nv_bfloat16* __restrict__ bias, + int batch, int num_tokens, int dim +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch * num_tokens * dim; + if (idx >= total) return; + + int b = idx / (num_tokens * dim); + int d = idx % dim; + float v = __bfloat162float(x[idx]) + __bfloat162float(bias[b * dim + d]); + x[idx] = __float2bfloat16(v); +} + +// ============================================================ +// MoE Weighted Sum kernel +// +// Input: expert_out [local_experts, num_tokens, hidden] BF16 +// topk_ids [num_tokens, top_k] int32 (global expert ids) +// topk_weights[num_tokens, top_k] float32 +// expert_start: first global expert id this rank owns +// local_experts: number of experts this rank owns +// +// Output: out [num_tokens, hidden] BF16 +// +// For each (token, dim): accumulate in FP32: +// sum = 0 +// for k in 0..top_k: +// global_id = topk_ids[token, k] +// if global_id in [expert_start, expert_start + local_experts): +// local_id = global_id - expert_start +// sum += topk_weights[token, k] * expert_out[local_id, token, dim] +// out[token, dim] = bf16(sum) +// ============================================================ + +__global__ void moe_weighted_sum_bf16_kernel( + const __nv_bfloat16* __restrict__ expert_out, + const int* __restrict__ topk_ids, + const float* __restrict__ topk_weights, + __nv_bfloat16* __restrict__ out, + int num_tokens, int hidden, int top_k, + int expert_start, int local_experts +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int total = num_tokens * hidden; + if (idx >= total) return; + + int token = idx / hidden; + int dim = idx % hidden; + + int expert_stride = num_tokens * hidden; // stride between experts in expert_out + + float sum = 0.0f; + for (int k = 0; k < top_k; k++) { + int global_id = topk_ids[token * top_k + k]; + int local_id = global_id - expert_start; + if (local_id >= 0 && local_id < local_experts) { + float w = topk_weights[token * top_k + k]; + float v = __bfloat162float(expert_out[local_id * expert_stride + token * hidden + dim]); + sum += w * v; + } + } + out[idx] = __float2bfloat16(sum); +} + + +extern "C" { + +void launch_moe_topk_softmax_bf16( + const void* router_logits, + void* topk_ids, void* topk_weights, + int num_tokens, int num_experts, int top_k, + void* stream +) { + int block = 128; + moe_topk_softmax_bf16_kernel<<>>( + (const __nv_bfloat16*)router_logits, + (int*)topk_ids, (float*)topk_weights, + num_experts, top_k + ); + CUDA_CHECK_LAST_ERROR(); +} + +void launch_moe_replicate_bf16( + const void* x, void* x_rep, + int num_tokens, int hidden, int local_experts, + void* stream +) { + int total = local_experts * num_tokens * hidden; + int block = 256; + int grid = (total + block - 1) / block; + moe_replicate_bf16_kernel<<>>( + (const __nv_bfloat16*)x, (__nv_bfloat16*)x_rep, + num_tokens, hidden, local_experts + ); + CUDA_CHECK_LAST_ERROR(); +} + +void launch_moe_bias_add_3d_bf16( + void* x, const void* bias, + int batch, int num_tokens, int dim, + void* stream +) { + int total = batch * num_tokens * dim; + int block = 256; + int grid = (total + block - 1) / block; + moe_bias_add_3d_bf16_kernel<<>>( + (__nv_bfloat16*)x, (const __nv_bfloat16*)bias, + batch, num_tokens, dim + ); + CUDA_CHECK_LAST_ERROR(); +} + +void launch_moe_weighted_sum_bf16( + const void* expert_out, + const void* topk_ids, const void* topk_weights, + void* out, + int num_tokens, int hidden, int top_k, + int expert_start, int local_experts, + void* stream +) { + int total = num_tokens * hidden; + int block = 256; + int grid = (total + block - 1) / block; + moe_weighted_sum_bf16_kernel<<>>( + (const __nv_bfloat16*)expert_out, + (const int*)topk_ids, (const float*)topk_weights, + (__nv_bfloat16*)out, + num_tokens, hidden, top_k, + expert_start, local_experts + ); + CUDA_CHECK_LAST_ERROR(); +} + +}