diff --git a/crates/xserv-kernels/src/activation.rs b/crates/xserv-kernels/src/activation.rs index 09d8e2b..943a531 100644 --- a/crates/xserv-kernels/src/activation.rs +++ b/crates/xserv-kernels/src/activation.rs @@ -8,43 +8,53 @@ unsafe extern "C" { fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void); fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void); + fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); + fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); + fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); + fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); } -pub fn gelu(x: &Tensor) -> Tensor { - assert!(x.is_contiguous()); - assert!(matches!(x.device(), Device::Cuda(_))); +fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void), + bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor { + assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let out = Tensor::zeros(x.shape(), x.dtype(), x.device()); let n = x.numel() as i32; unsafe { match x.dtype() { - DType::F32 => launch_gelu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), - DType::BF16 => launch_gelu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), - _ => panic!("unsupported dtype for gelu"), + DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), + DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), + _ => panic!("unsupported dtype"), } } xserv_cuda::device::synchronize().unwrap(); out } -pub fn silu(x: &Tensor) -> Tensor { - assert!(x.is_contiguous()); - assert!(matches!(x.device(), Device::Cuda(_))); - let out = Tensor::zeros(x.shape(), x.dtype(), x.device()); - let n = x.numel() as i32; +fn dispatch_binary(a: &Tensor, b: &Tensor, + f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void), + bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor { + assert_eq!(a.shape(), b.shape()); + assert!(a.is_contiguous() && b.is_contiguous()); + assert!(matches!(a.device(), Device::Cuda(_))); + assert_eq!(a.dtype(), b.dtype()); + let out = Tensor::zeros(a.shape(), a.dtype(), a.device()); + let n = a.numel() as i32; unsafe { - match x.dtype() { - DType::F32 => launch_silu_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), - DType::BF16 => launch_silu_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), - _ => panic!("unsupported dtype for silu"), + match a.dtype() { + DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), + DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), + _ => panic!("unsupported dtype"), } } xserv_cuda::device::synchronize().unwrap(); out } +pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) } +pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) } + pub fn scale(x: &Tensor, scale_val: f32) -> Tensor { - assert!(x.is_contiguous()); - assert!(matches!(x.device(), Device::Cuda(_))); + assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let out = Tensor::zeros(x.shape(), x.dtype(), x.device()); let n = x.numel() as i32; unsafe { @@ -57,3 +67,6 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor { xserv_cuda::device::synchronize().unwrap(); out } + +pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) } +pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) } diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index 20cab23..c2bcda3 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -7,7 +7,7 @@ pub mod rmsnorm; pub mod rope; pub mod softmax; -pub use activation::{gelu, scale, silu}; +pub use activation::{add, gelu, mul, scale, silu}; pub use attention::attention; pub use embedding::embedding; pub use gemm::{batched_matmul, matmul, GemmBackend}; diff --git a/crates/xserv-model/src/bin/dump-logits.rs b/crates/xserv-model/src/bin/dump-logits.rs new file mode 100644 index 0000000..080345a --- /dev/null +++ b/crates/xserv-model/src/bin/dump-logits.rs @@ -0,0 +1,44 @@ +use std::path::PathBuf; +use xserv_model::{loader, KVCache, ModelConfig, Qwen3}; +use xserv_tensor::{DType, Device}; +use xserv_tokenizer::Tokenizer; +use half::bf16; + +fn main() { + let args: Vec = std::env::args().collect(); + let model_dir = PathBuf::from(&args[1]); + let prompt = &args[2]; + + xserv_cuda::device::set_device(0).unwrap(); + let config = ModelConfig::from_file(&model_dir.join("config.json")); + let weights = loader::load_model_dir(&model_dir, Device::Cuda(0)); + let model = Qwen3::from_weights(config.clone(), weights); + let tokenizer = Tokenizer::from_file(&model_dir.join("tokenizer.json")); + + let token_ids = tokenizer.encode(prompt); + eprintln!("Prompt: {prompt}"); + eprintln!("Token IDs: {token_ids:?}"); + + let mut cache = KVCache::new( + config.num_layers(), config.num_kv_heads(), config.head_dim(), + DType::BF16, Device::Cuda(0), + ); + let logits = model.forward_with_cache(&token_ids, &mut cache); + let logits_cpu = logits.to_device(Device::Cpu); + let data = logits_cpu.as_slice::(); + let vocab_size = logits.shape()[1]; + let seq_len = logits.shape()[0]; + + // Print top-20 logits for the last position + let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; + let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate() + .map(|(i, v)| (i, v.to_f32())) + .collect(); + indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + println!("Top-20 logits (last position):"); + for (rank, (id, val)) in indexed.iter().take(20).enumerate() { + let tok = tokenizer.decode(&[*id as u32]); + println!(" [{rank:>2}] id={id:>6} logit={val:>10.4} token={tok:?}"); + } +} diff --git a/crates/xserv-model/src/gpt2.rs b/crates/xserv-model/src/gpt2.rs index 7daf988..606688c 100644 --- a/crates/xserv-model/src/gpt2.rs +++ b/crates/xserv-model/src/gpt2.rs @@ -247,27 +247,33 @@ fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor { } fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor { - assert_eq!(a.shape(), b.shape()); - assert_eq!(a.dtype(), DType::F32); - let a_cpu = a.to_device(Device::Cpu); - let b_cpu = b.to_device(Device::Cpu); - let a_data = a_cpu.as_slice::(); - let b_data = b_cpu.as_slice::(); - let sum: Vec = a_data.iter().zip(b_data).map(|(x, y)| x + y).collect(); - Tensor::from_slice(&sum, a.shape()).to_device(a.device()) + xserv_kernels::add(a, b) } fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor { + // bias: [N], x: [S, N] — broadcast add via reshape assert_eq!(x.ndim(), 2); assert_eq!(bias.ndim(), 1); - assert_eq!(x.shape()[1], bias.shape()[0]); - let x_cpu = x.to_device(Device::Cpu); - let b_cpu = bias.to_device(Device::Cpu); - let x_data = x_cpu.as_slice::(); - let b_data = b_cpu.as_slice::(); let n = bias.shape()[0]; - let result: Vec = x_data.iter().enumerate().map(|(i, &v)| v + b_data[i % n]).collect(); - Tensor::from_slice(&result, x.shape()).to_device(x.device()) + assert_eq!(x.shape()[1], n); + let rows = x.shape()[0]; + // Broadcast: tile bias to [S, N] on CPU, then GPU add + let b_cpu = bias.to_device(Device::Cpu); + match x.dtype() { + DType::F32 => { + let bd = b_cpu.as_slice::(); + let tiled: Vec = (0..rows).flat_map(|_| bd.iter().copied()).collect(); + let b_full = Tensor::from_slice(&tiled, x.shape()).to_device(x.device()); + xserv_kernels::add(x, &b_full) + } + DType::BF16 => { + let bd = b_cpu.as_slice::(); + let tiled: Vec = (0..rows).flat_map(|_| bd.iter().copied()).collect(); + let b_full = Tensor::from_slice(&tiled, x.shape()).to_device(x.device()); + xserv_kernels::add(x, &b_full) + } + _ => panic!("unsupported dtype"), + } } fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) { diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs index d99dc5d..ceb0a76 100644 --- a/crates/xserv-model/src/qwen3.rs +++ b/crates/xserv-model/src/qwen3.rs @@ -250,27 +250,11 @@ fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor { } fn add_any(a: &Tensor, b: &Tensor) -> Tensor { - assert_eq!(a.shape(), b.shape()); - let a_cpu = a.to_device(Device::Cpu); - let b_cpu = b.to_device(Device::Cpu); - let ad = a_cpu.as_slice::(); - let bd = b_cpu.as_slice::(); - let r: Vec = ad.iter().zip(bd) - .map(|(x, y)| bf16::from_f32(x.to_f32() + y.to_f32())) - .collect(); - Tensor::from_slice(&r, a.shape()).to_device(a.device()) + xserv_kernels::add(a, b) } fn mul_any(a: &Tensor, b: &Tensor) -> Tensor { - assert_eq!(a.shape(), b.shape()); - let a_cpu = a.to_device(Device::Cpu); - let b_cpu = b.to_device(Device::Cpu); - let ad = a_cpu.as_slice::(); - let bd = b_cpu.as_slice::(); - let r: Vec = ad.iter().zip(bd) - .map(|(x, y)| bf16::from_f32(x.to_f32() * y.to_f32())) - .collect(); - Tensor::from_slice(&r, a.shape()).to_device(a.device()) + xserv_kernels::mul(a, b) } pub fn sample_greedy(logits: &Tensor) -> u32 { diff --git a/csrc/activation/activations.cu b/csrc/activation/activations.cu index 3593dc1..c4b02ff 100644 --- a/csrc/activation/activations.cu +++ b/csrc/activation/activations.cu @@ -45,6 +45,26 @@ __global__ void scale_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* out, fl if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(x[idx]) * scale); } +// Element-wise add: out = a + b +__global__ void add_f32_kernel(const float* a, const float* b, float* out, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) out[idx] = a[idx] + b[idx]; +} +__global__ void add_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) + __bfloat162float(b[idx])); +} + +// Element-wise mul: out = a * b +__global__ void mul_f32_kernel(const float* a, const float* b, float* out, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) out[idx] = a[idx] * b[idx]; +} +__global__ void mul_bf16_kernel(const __nv_bfloat16* a, const __nv_bfloat16* b, __nv_bfloat16* out, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) out[idx] = __float2bfloat16(__bfloat162float(a[idx]) * __bfloat162float(b[idx])); +} + extern "C" { void launch_gelu_f32(const void* x, void* out, int n, void* stream) { @@ -87,4 +107,29 @@ void launch_scale_bf16(const void* x, void* out, float scale, int n, void* strea (const __nv_bfloat16*)x, (__nv_bfloat16*)out, scale, n); } +void launch_add_f32(const void* a, const void* b, void* out, int n, void* stream) { + int block = 256; + int grid = (n + block - 1) / block; + add_f32_kernel<<>>( + (const float*)a, (const float*)b, (float*)out, n); +} +void launch_add_bf16(const void* a, const void* b, void* out, int n, void* stream) { + int block = 256; + int grid = (n + block - 1) / block; + add_bf16_kernel<<>>( + (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n); +} +void launch_mul_f32(const void* a, const void* b, void* out, int n, void* stream) { + int block = 256; + int grid = (n + block - 1) / block; + mul_f32_kernel<<>>( + (const float*)a, (const float*)b, (float*)out, n); +} +void launch_mul_bf16(const void* a, const void* b, void* out, int n, void* stream) { + int block = 256; + int grid = (n + block - 1) / block; + mul_bf16_kernel<<>>( + (const __nv_bfloat16*)a, (const __nv_bfloat16*)b, (__nv_bfloat16*)out, n); +} + }