diff --git a/crates/xserv-kernels/src/activation.rs b/crates/xserv-kernels/src/activation.rs index 5ccc156..2477841 100644 --- a/crates/xserv-kernels/src/activation.rs +++ b/crates/xserv-kernels/src/activation.rs @@ -15,6 +15,8 @@ unsafe extern "C" { fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32, alpha: f32, limit: f32, stream: *mut c_void); + fn launch_bias_add_2d_bf16(x: *const c_void, bias: *const c_void, out: *mut c_void, + rows: i32, cols: i32, stream: *mut c_void); } fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void), @@ -77,6 +79,28 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor { pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) } pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) } +/// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only). +pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor { + assert_eq!(x.ndim(), 2); + assert_eq!(bias.ndim(), 1); + assert_eq!(x.dtype(), DType::BF16); + assert_eq!(bias.dtype(), DType::BF16); + assert!(x.is_contiguous() && bias.is_contiguous()); + assert!(matches!(x.device(), Device::Cuda(_))); + let rows = x.shape()[0]; + let cols = x.shape()[1]; + assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]); + assert!(rows * cols <= i32::MAX as usize); + let out = Tensor::empty(&[rows, cols], DType::BF16, x.device()); + unsafe { + launch_bias_add_2d_bf16( + x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void, + rows as i32, cols as i32, std::ptr::null_mut(), + ); + } + out +} + /// Fused SiLU×Mul: out = silu(gate) * up (BF16 only) /// Saves one HBM read + one HBM write compared to separate silu + mul. pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor { diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index adf2ca1..eeddd54 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -12,7 +12,7 @@ pub mod rope; pub mod softmax; pub mod transpose; -pub use activation::{add, gelu, gpt_oss_glu, mul, scale, silu, silu_mul}; +pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul}; pub use argmax::{argmax_bf16_single, argmax_bf16_to_host}; pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu}; pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16}; diff --git a/crates/xserv-model/src/gpt_oss.rs b/crates/xserv-model/src/gpt_oss.rs index 73dbf7f..5250f54 100644 --- a/crates/xserv-model/src/gpt_oss.rs +++ b/crates/xserv-model/src/gpt_oss.rs @@ -450,12 +450,8 @@ impl GptOss { paged_cache.advance_seq_len(slot, 1); } - unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } let x = Self::norm(&x, &self.norm, &self.norm_bias, eps); - unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } - let logits = matmul_2d(&x, &self.lm_head_t); - unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } - logits + matmul_2d(&x, &self.lm_head_t) } /// Paged prefill: process full prompt tokens. @@ -519,9 +515,7 @@ impl GptOss { } let x = Self::norm(&x, &self.norm, &self.norm_bias, eps); - let logits = matmul_2d(&x, &self.lm_head_t); - unsafe { xserv_cuda::ffi::cudaDeviceSynchronize(); } - logits + matmul_2d(&x, &self.lm_head_t) } /// MoE forward pass — fully on GPU via batched GEMM. @@ -691,31 +685,12 @@ fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor { matmul(a, b, GemmBackend::CuBlas) } -/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols] +/// Add bias to a 2D tensor: [rows, cols] + [cols] → [rows, cols]. +/// Single GPU broadcast kernel — the old rows>1 path tiled the bias on the +/// CPU (D2H + host loop + H2D) on every call, 96×/prefill in the hot path. fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor { - assert_eq!(x.ndim(), 2); - assert_eq!(bias.ndim(), 1); - let rows = x.shape()[0]; - let cols = x.shape()[1]; - assert_eq!(bias.shape()[0], cols, "bias size {} != cols {}", bias.shape()[0], cols); - let x_c = x.contiguous(); - - if rows == 1 { - // Fast path: reshape bias [cols] → [1, cols] (zero-copy), add directly on GPU - let bias_2d = bias.reshape(&[1, cols]); - return xserv_kernels::add(&x_c, &bias_2d); - } - - // General path: tile bias to [rows, cols] via CPU, then add on GPU - let bias_cpu = bias.to_device(Device::Cpu); - let bias_data = bias_cpu.as_slice::(); - let mut tiled = Vec::with_capacity(rows * cols); - for _ in 0..rows { - tiled.extend_from_slice(bias_data); - } - let bias_tiled = Tensor::from_slice(&tiled, &[rows, cols]).to_device(x.device()); - xserv_kernels::add(&x_c, &bias_tiled) + xserv_kernels::bias_add_2d(&x_c, bias) } fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor { diff --git a/csrc/activation/activations.cu b/csrc/activation/activations.cu index 899b86c..fc9c672 100644 --- a/csrc/activation/activations.cu +++ b/csrc/activation/activations.cu @@ -87,6 +87,17 @@ __global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx])); } +// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] +__global__ void bias_add_2d_bf16_kernel( + const __nv_bfloat16* __restrict__ x, const __nv_bfloat16* __restrict__ bias, + __nv_bfloat16* __restrict__ out, int rows, int cols +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= rows * cols) return; + float v = __bfloat162float(x[idx]) + __bfloat162float(bias[idx % cols]); + out[idx] = __float2bfloat16(v); +} + // Element-wise mul: out = a * b __global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -159,6 +170,14 @@ void launch_add_bf16(const void* a, const void* b, void* out, int n, void* strea (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n); CUDA_CHECK_LAST_ERROR(); } +void launch_bias_add_2d_bf16(const void* x, const void* bias, void* out, int rows, int cols, void* stream) { + int n = rows * cols; + int block = 256; + int grid = (n + block - 1) / block; + bias_add_2d_bf16_kernel<<>>( + (const __nv_bfloat16*)x, (const __nv_bfloat16*)bias, (__nv_bfloat16*)out, rows, cols); + CUDA_CHECK_LAST_ERROR(); +} void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) { int block = 256; int grid = (n + block - 1) / block;