From 531cd3fe08f9bbe0f1d395d64211c6b13f977ae8 Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Thu, 18 Jun 2026 18:11:58 +0800 Subject: [PATCH] style: format Rust workspace --- crates/xserv-cuda/src/allocator.rs | 5 +- crates/xserv-cuda/src/device.rs | 4 +- crates/xserv-cuda/src/ffi.rs | 8 +- crates/xserv-cuda/src/graph.rs | 17 +- crates/xserv-cuda/src/lib.rs | 2 +- crates/xserv-cuda/src/memory.rs | 46 ++- crates/xserv-cuda/tests/integration.rs | 21 +- crates/xserv-distributed/src/ffi.rs | 14 +- crates/xserv-distributed/src/lib.rs | 38 +- crates/xserv-distributed/tests/allreduce.rs | 8 +- crates/xserv-distributed/tests/sendrecv.rs | 7 +- crates/xserv-kernels/src/activation.rs | 189 ++++++++-- crates/xserv-kernels/src/argmax.rs | 23 +- crates/xserv-kernels/src/attention.rs | 237 ++++++++---- crates/xserv-kernels/src/dispatch.rs | 302 +++++++++++++--- crates/xserv-kernels/src/embedding.rs | 56 ++- crates/xserv-kernels/src/gemm.rs | 232 +++++++++--- crates/xserv-kernels/src/layernorm.rs | 52 ++- crates/xserv-kernels/src/lib.rs | 14 +- crates/xserv-kernels/src/moe.rs | 217 ++++++++--- crates/xserv-kernels/src/quantization.rs | 220 +++++++++--- crates/xserv-kernels/src/rmsnorm.rs | 74 +++- crates/xserv-kernels/src/rope.rs | 90 +++-- crates/xserv-kernels/src/softmax.rs | 40 ++- crates/xserv-kernels/src/transpose.rs | 188 ++++++++-- crates/xserv-kernels/tests/attention_test.rs | 81 ++++- crates/xserv-kernels/tests/gemm_test.rs | 82 +++-- crates/xserv-kernels/tests/ops_test.rs | 46 ++- crates/xserv-model/src/bin/bench-gpt-oss.rs | 192 +++++++--- crates/xserv-model/src/bin/bench-gpt2.rs | 51 ++- crates/xserv-model/src/bin/bench-qwen3.rs | 52 ++- crates/xserv-model/src/bin/bench-tp.rs | 78 +++- crates/xserv-model/src/bin/dump-logits.rs | 15 +- crates/xserv-model/src/bin/xserv-chat.rs | 310 +++++++++++++--- crates/xserv-model/src/bin/xserv-cli.rs | 76 +++- crates/xserv-model/src/config.rs | 23 +- crates/xserv-model/src/decode_graph.rs | 338 +++++++++++++---- crates/xserv-model/src/gpt2.rs | 109 ++++-- crates/xserv-model/src/gpt_oss.rs | 359 ++++++++++++++----- crates/xserv-model/src/gpt_oss_graph.rs | 39 +- crates/xserv-model/src/kv_cache.rs | 99 ++++- crates/xserv-model/src/lib.rs | 4 +- crates/xserv-model/src/loader.rs | 14 +- crates/xserv-model/src/paged_kv_cache.rs | 248 ++++++++++--- crates/xserv-model/src/qwen3.rs | 295 ++++++++++----- crates/xserv-model/src/sampling.rs | 17 +- crates/xserv-server/src/api.rs | 44 ++- crates/xserv-server/src/engine.rs | 105 ++++-- crates/xserv-server/src/main.rs | 45 ++- crates/xserv-server/src/pp_engine.rs | 100 +++++- crates/xserv-server/src/tp_engine.rs | 154 ++++++-- crates/xserv-tensor/src/dtype.rs | 24 +- crates/xserv-tensor/src/lib.rs | 2 +- crates/xserv-tensor/src/shape.rs | 27 +- crates/xserv-tensor/src/tensor.rs | 79 +++- crates/xserv-tensor/tests/integration.rs | 6 +- crates/xserv-tokenizer/src/bpe.rs | 31 +- 57 files changed, 4045 insertions(+), 1204 deletions(-) diff --git a/crates/xserv-cuda/src/allocator.rs b/crates/xserv-cuda/src/allocator.rs index e309dd1..c143c57 100644 --- a/crates/xserv-cuda/src/allocator.rs +++ b/crates/xserv-cuda/src/allocator.rs @@ -165,7 +165,10 @@ pub fn begin_retain() { /// Stop quarantining and hand the quarantined blocks to the caller. pub fn end_retain() -> RetainedBlocks { RETAINED.with(|cell| { - let list = cell.borrow_mut().take().expect("end_retain without begin_retain"); + let list = cell + .borrow_mut() + .take() + .expect("end_retain without begin_retain"); RetainedBlocks(list) }) } diff --git a/crates/xserv-cuda/src/device.rs b/crates/xserv-cuda/src/device.rs index 5cb60bf..802db9d 100644 --- a/crates/xserv-cuda/src/device.rs +++ b/crates/xserv-cuda/src/device.rs @@ -48,9 +48,7 @@ pub fn device_info(device: u32) -> Result { // Heap-allocate oversized buffer for cudaDeviceProp (layout varies by CUDA version). // CUDA 12.x struct is ~5-6 KB; use 32 KB to guard against future growth. let mut prop_buf = vec![0u8; 32768]; - error::check(unsafe { - ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) - })?; + error::check(unsafe { ffi::cudaGetDeviceProperties(prop_buf.as_mut_ptr(), device as i32) })?; // Name is always the first field: char[256]. let name = unsafe { CStr::from_ptr(prop_buf.as_ptr() as *const c_char) } .to_string_lossy() diff --git a/crates/xserv-cuda/src/ffi.rs b/crates/xserv-cuda/src/ffi.rs index 8b18571..dff82ea 100644 --- a/crates/xserv-cuda/src/ffi.rs +++ b/crates/xserv-cuda/src/ffi.rs @@ -64,11 +64,5 @@ unsafe extern "C" { pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32; // --- Our test kernel --- - pub fn launch_vecadd_f32( - a: *const f32, - b: *const f32, - c: *mut f32, - n: i32, - stream: CudaStream, - ); + pub fn launch_vecadd_f32(a: *const f32, b: *const f32, c: *mut f32, n: i32, stream: CudaStream); } diff --git a/crates/xserv-cuda/src/graph.rs b/crates/xserv-cuda/src/graph.rs index 748bcbb..b67fef6 100644 --- a/crates/xserv-cuda/src/graph.rs +++ b/crates/xserv-cuda/src/graph.rs @@ -54,30 +54,21 @@ impl CudaGraph { // made by THIS thread invalidate the capture. With GLOBAL mode, TP rank // threads capturing concurrently would poison each other's captures. error::check(unsafe { - ffi::cudaStreamBeginCapture( - stream.as_raw(), - ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL, - ) + ffi::cudaStreamBeginCapture(stream.as_raw(), ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL) }) } /// End capture and instantiate the executable graph. pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> { - error::check(unsafe { - ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) - })?; - error::check(unsafe { - ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) - }) + error::check(unsafe { ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph) })?; + error::check(unsafe { ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0) }) } /// Replay the captured graph on `stream`. /// Panics if no graph has been captured yet. pub fn launch(&self, stream: &CudaStream) -> Result<()> { assert!(self.is_ready(), "CudaGraph::launch called before capture"); - error::check(unsafe { - ffi::cudaGraphLaunch(self.exec, stream.as_raw()) - }) + error::check(unsafe { ffi::cudaGraphLaunch(self.exec, stream.as_raw()) }) } fn destroy_inner(&mut self) { diff --git a/crates/xserv-cuda/src/lib.rs b/crates/xserv-cuda/src/lib.rs index adb3811..0ecf3c6 100644 --- a/crates/xserv-cuda/src/lib.rs +++ b/crates/xserv-cuda/src/lib.rs @@ -11,4 +11,4 @@ pub use device::DeviceInfo; pub use error::{CudaError, Result}; pub use graph::CudaGraph; pub use memory::{GpuBuffer, PinnedBuffer}; -pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard}; +pub use stream::{CudaStream, StreamGuard, current_stream_raw, push_stream}; diff --git a/crates/xserv-cuda/src/memory.rs b/crates/xserv-cuda/src/memory.rs index 974c00e..382930c 100644 --- a/crates/xserv-cuda/src/memory.rs +++ b/crates/xserv-cuda/src/memory.rs @@ -22,7 +22,12 @@ impl GpuBuffer { assert!(len > 0, "cannot allocate 0 bytes on GPU"); let mut ptr = std::ptr::null_mut(); error::check(unsafe { ffi::cudaMalloc(&mut ptr, len) })?; - Ok(Self { ptr, len, owned: true, pooled: false }) + Ok(Self { + ptr, + len, + owned: true, + pooled: false, + }) } /// Mark this buffer as pooled (returned to caching allocator on drop) @@ -92,9 +97,7 @@ impl GpuBuffer { /// Copy from another GPU buffer (D2D). pub fn copy_from_device(&mut self, src: &GpuBuffer) -> Result<()> { let n = src.len.min(self.len); - error::check(unsafe { - ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) - }) + error::check(unsafe { ffi::cudaMemcpy(self.ptr, src.ptr, n, ffi::CUDA_MEMCPY_D2D) }) } /// Fill buffer with zeros. @@ -103,7 +106,13 @@ impl GpuBuffer { } /// Copy `count` bytes from `src` buffer at `src_offset` to this buffer at `dst_offset`. - pub fn copy_from_device_at(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize) -> Result<()> { + pub fn copy_from_device_at( + &mut self, + src: &GpuBuffer, + src_offset: usize, + dst_offset: usize, + count: usize, + ) -> Result<()> { assert!(src_offset + count <= src.len); assert!(dst_offset + count <= self.len); error::check(unsafe { @@ -117,7 +126,14 @@ impl GpuBuffer { } /// Async copy `count` bytes from `src` at `src_offset` to `self` at `dst_offset` on `stream`. - pub fn copy_from_device_at_async(&mut self, src: &GpuBuffer, src_offset: usize, dst_offset: usize, count: usize, stream: &CudaStream) -> Result<()> { + pub fn copy_from_device_at_async( + &mut self, + src: &GpuBuffer, + src_offset: usize, + dst_offset: usize, + count: usize, + stream: &CudaStream, + ) -> Result<()> { assert!(src_offset + count <= src.len); assert!(dst_offset + count <= self.len); error::check(unsafe { @@ -161,9 +177,7 @@ impl GpuBuffer { /// Async zero fill on stream. pub fn zero_async(&mut self, stream: &CudaStream) -> Result<()> { - error::check(unsafe { - ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) - }) + error::check(unsafe { ffi::cudaMemsetAsync(self.ptr, 0, self.len, stream.as_raw()) }) } /// Consume the buffer without freeing GPU memory. Returns the raw pointer and length. @@ -178,7 +192,12 @@ impl GpuBuffer { /// Reconstruct a GpuBuffer from a raw pointer + length. /// Safety: ptr must have been allocated with cudaMalloc, len must be correct. pub unsafe fn from_raw(ptr: *mut u8, len: usize) -> Self { - Self { ptr, len, owned: true, pooled: false } + Self { + ptr, + len, + owned: true, + pooled: false, + } } /// Create a non-owning view of GPU memory. Dropping this buffer does NOT @@ -189,7 +208,12 @@ impl GpuBuffer { /// `ptr` must point to a valid GPU allocation of at least `len` bytes that /// will remain live for the lifetime of the returned `GpuBuffer`. pub unsafe fn borrow_raw(ptr: *mut u8, len: usize) -> Self { - Self { ptr, len, owned: false, pooled: false } + Self { + ptr, + len, + owned: false, + pooled: false, + } } } diff --git a/crates/xserv-cuda/tests/integration.rs b/crates/xserv-cuda/tests/integration.rs index f174b01..2410e2d 100644 --- a/crates/xserv-cuda/tests/integration.rs +++ b/crates/xserv-cuda/tests/integration.rs @@ -14,7 +14,10 @@ fn test_device_info() { info.compute_major, info.compute_minor ); println!(" SM Count: {}", info.sm_count); - println!(" Shared Mem/Block: {} KB", info.shared_mem_per_block / 1024); + println!( + " Shared Mem/Block: {} KB", + info.shared_mem_per_block / 1024 + ); println!(" Warp Size: {}", info.warp_size); println!(" Max Threads/Block: {}", info.max_threads_per_block); @@ -145,7 +148,11 @@ fn test_caching_allocator() { // Second allocation of same size: should hit cache let _buf2 = alloc.alloc(1024).unwrap(); - assert_eq!(alloc.stats().cuda_malloc_count, 1, "should reuse cached buffer"); + assert_eq!( + alloc.stats().cuda_malloc_count, + 1, + "should reuse cached buffer" + ); assert_eq!(alloc.stats().cache_hit_count, 1); } @@ -198,11 +205,17 @@ fn test_async_copy() { } let mut gpu = GpuBuffer::alloc(4096).unwrap(); - unsafe { gpu.copy_from_host_async(pinned.as_slice(), &stream).unwrap() }; + unsafe { + gpu.copy_from_host_async(pinned.as_slice(), &stream) + .unwrap() + }; stream.synchronize().unwrap(); let mut out_pinned = PinnedBuffer::alloc(4096).unwrap(); - unsafe { gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream).unwrap() }; + unsafe { + gpu.copy_to_host_async(out_pinned.as_mut_slice(), &stream) + .unwrap() + }; stream.synchronize().unwrap(); assert_eq!(pinned.as_slice(), out_pinned.as_slice()); diff --git a/crates/xserv-distributed/src/ffi.rs b/crates/xserv-distributed/src/ffi.rs index 00fa57c..625b1ed 100644 --- a/crates/xserv-distributed/src/ffi.rs +++ b/crates/xserv-distributed/src/ffi.rs @@ -34,7 +34,12 @@ pub const NCCL_SUCCESS: i32 = 0; unsafe extern "C" { pub fn ncclGetUniqueId(uid: *mut NcclUniqueId) -> i32; // ncclUniqueId is passed BY VALUE (a 128-byte struct) per the NCCL ABI. - pub fn ncclCommInitRank(comm: *mut NcclComm, nranks: i32, commid: NcclUniqueId, rank: i32) -> i32; + pub fn ncclCommInitRank( + comm: *mut NcclComm, + nranks: i32, + commid: NcclUniqueId, + rank: i32, + ) -> i32; pub fn ncclCommDestroy(comm: NcclComm) -> i32; pub fn ncclAllReduce( sendbuff: *const c_void, @@ -78,5 +83,10 @@ pub fn err_string(result: i32) -> String { } pub fn check(result: i32, what: &str) { - assert_eq!(result, NCCL_SUCCESS, "{what} failed: {}", err_string(result)); + assert_eq!( + result, + NCCL_SUCCESS, + "{what} failed: {}", + err_string(result) + ); } diff --git a/crates/xserv-distributed/src/lib.rs b/crates/xserv-distributed/src/lib.rs index 1608241..be41220 100644 --- a/crates/xserv-distributed/src/lib.rs +++ b/crates/xserv-distributed/src/lib.rs @@ -9,8 +9,8 @@ pub mod ffi; use std::ffi::c_void; use ffi::{NcclComm, NcclUniqueId}; -use xserv_cuda::device; use xserv_cuda::GpuBuffer; +use xserv_cuda::device; pub use ffi::NcclUniqueId as UniqueId; @@ -55,7 +55,12 @@ impl TpContext { "ncclCommInitRank", ); ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)"); - Self { rank, world, device, comm } + Self { + rank, + world, + device, + comm, + } } /// In-place AllReduce(sum) over `count` BF16 elements in `buf`. @@ -127,7 +132,12 @@ impl PpContext { "ncclCommInitRank", ); ffi::check(unsafe { ffi::ncclGroupEnd() }, "ncclGroupEnd(init)"); - Self { stage, world, device, comm } + Self { + stage, + world, + device, + comm, + } } /// Send `count` BF16 elements at `ptr` to `peer`, on the null stream so it is @@ -138,7 +148,16 @@ impl PpContext { /// `ptr` must point to at least `count` BF16 elements of valid device memory. pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) { ffi::check( - unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) }, + unsafe { + ffi::ncclSend( + ptr, + count, + ffi::NCCL_BF16, + peer as i32, + self.comm, + launch_stream(), + ) + }, "ncclSend", ); } @@ -149,7 +168,16 @@ impl PpContext { /// `ptr` must point to at least `count` BF16 elements of valid device memory. pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) { ffi::check( - unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) }, + unsafe { + ffi::ncclRecv( + ptr, + count, + ffi::NCCL_BF16, + peer as i32, + self.comm, + launch_stream(), + ) + }, "ncclRecv", ); } diff --git a/crates/xserv-distributed/tests/allreduce.rs b/crates/xserv-distributed/tests/allreduce.rs index 636a149..edd4ed1 100644 --- a/crates/xserv-distributed/tests/allreduce.rs +++ b/crates/xserv-distributed/tests/allreduce.rs @@ -2,8 +2,8 @@ use half::bf16; use std::thread; -use xserv_cuda::{device, GpuBuffer}; -use xserv_distributed::{get_unique_id, TpContext}; +use xserv_cuda::{GpuBuffer, device}; +use xserv_distributed::{TpContext, get_unique_id}; #[test] fn allreduce_two_gpu_sum() { @@ -25,9 +25,7 @@ fn allreduce_two_gpu_sum() { // Rank r fills its buffer with (r + 1). let val = bf16::from_f32((rank + 1) as f32); let host = vec![val; n]; - let src = unsafe { - std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) - }; + let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) }; let mut buf = GpuBuffer::alloc(n * 2).unwrap(); buf.copy_from_host(src).unwrap(); diff --git a/crates/xserv-distributed/tests/sendrecv.rs b/crates/xserv-distributed/tests/sendrecv.rs index 1975729..b1cf80e 100644 --- a/crates/xserv-distributed/tests/sendrecv.rs +++ b/crates/xserv-distributed/tests/sendrecv.rs @@ -6,8 +6,8 @@ use half::bf16; use std::ffi::c_void; use std::thread; -use xserv_cuda::{device, GpuBuffer}; -use xserv_distributed::{get_unique_id, PpContext}; +use xserv_cuda::{GpuBuffer, device}; +use xserv_distributed::{PpContext, get_unique_id}; #[test] fn pp_send_recv_two_stages() { @@ -30,7 +30,8 @@ fn pp_send_recv_two_stages() { if stage == 0 { // Fill with a known pattern and send to stage 1. let host: Vec = (0..n).map(|i| bf16::from_f32((i % 97) as f32)).collect(); - let src = unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) }; + let src = + unsafe { std::slice::from_raw_parts(host.as_ptr() as *const u8, n * 2) }; buf.copy_from_host(src).unwrap(); pp.send_bf16_ptr(buf.as_mut_ptr() as *const c_void, n, 1); device::synchronize().unwrap(); diff --git a/crates/xserv-kernels/src/activation.rs b/crates/xserv-kernels/src/activation.rs index 288804d..45adc1b 100644 --- a/crates/xserv-kernels/src/activation.rs +++ b/crates/xserv-kernels/src/activation.rs @@ -6,78 +6,189 @@ unsafe extern "C" { fn launch_gelu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); fn launch_silu_f32(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); fn launch_silu_bf16(x: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_scale_f32(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void); - fn launch_scale_bf16(x: *const c_void, out: *mut c_void, scale: f32, n: i32, stream: *mut c_void); - fn launch_add_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_mul_f32(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_mul_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_gpt_oss_glu_bf16(gate_up: *const c_void, out: *mut c_void, n_elements: i32, - alpha: f32, limit: f32, stream: *mut c_void); - fn launch_bias_add_2d_bf16(x: *const c_void, bias: *const c_void, out: *mut c_void, - rows: i32, cols: i32, stream: *mut c_void); + fn launch_scale_f32( + x: *const c_void, + out: *mut c_void, + scale: f32, + n: i32, + stream: *mut c_void, + ); + fn launch_scale_bf16( + x: *const c_void, + out: *mut c_void, + scale: f32, + n: i32, + stream: *mut c_void, + ); + fn launch_add_f32( + a: *const c_void, + b: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_add_bf16( + a: *const c_void, + b: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_mul_f32( + a: *const c_void, + b: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_mul_bf16( + a: *const c_void, + b: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_silu_mul_bf16( + gate: *const c_void, + up: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_gpt_oss_glu_bf16( + gate_up: *const c_void, + out: *mut c_void, + n_elements: i32, + alpha: f32, + limit: f32, + stream: *mut c_void, + ); + fn launch_bias_add_2d_bf16( + x: *const c_void, + bias: *const c_void, + out: *mut c_void, + rows: i32, + cols: i32, + stream: *mut c_void, + ); } -fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void), - bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor { +fn dispatch_unary( + x: &Tensor, + f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void), + bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void), +) -> Tensor { assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let out = Tensor::empty(x.shape(), x.dtype(), x.device()); let n = x.numel(); - assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)"); + assert!( + n <= i32::MAX as usize, + "tensor too large for i32 kernel param ({n} elements)" + ); let n = n as i32; unsafe { match x.dtype() { - DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), - DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), + DType::F32 => f32_fn( + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + n, + xserv_cuda::current_stream_raw(), + ), + DType::BF16 => bf16_fn( + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + n, + xserv_cuda::current_stream_raw(), + ), _ => panic!("unsupported dtype"), } } out } -fn dispatch_binary(a: &Tensor, b: &Tensor, - f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void), - bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void)) -> Tensor { +fn dispatch_binary( + a: &Tensor, + b: &Tensor, + f32_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void), + bf16_fn: unsafe extern "C" fn(*const c_void, *const c_void, *mut c_void, i32, *mut c_void), +) -> Tensor { assert_eq!(a.shape(), b.shape()); assert!(a.is_contiguous() && b.is_contiguous()); assert!(matches!(a.device(), Device::Cuda(_))); assert_eq!(a.dtype(), b.dtype()); let out = Tensor::empty(a.shape(), a.dtype(), a.device()); let n = a.numel(); - assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)"); + assert!( + n <= i32::MAX as usize, + "tensor too large for i32 kernel param ({n} elements)" + ); let n = n as i32; unsafe { match a.dtype() { - DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), - DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), + DType::F32 => f32_fn( + a.data_ptr() as _, + b.data_ptr() as _, + out.data_ptr() as *mut c_void, + n, + xserv_cuda::current_stream_raw(), + ), + DType::BF16 => bf16_fn( + a.data_ptr() as _, + b.data_ptr() as _, + out.data_ptr() as *mut c_void, + n, + xserv_cuda::current_stream_raw(), + ), _ => panic!("unsupported dtype"), } } out } -pub fn gelu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) } -pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_silu_bf16) } +pub fn gelu(x: &Tensor) -> Tensor { + dispatch_unary(x, launch_gelu_f32, launch_gelu_bf16) +} +pub fn silu(x: &Tensor) -> Tensor { + dispatch_unary(x, launch_silu_f32, launch_silu_bf16) +} pub fn scale(x: &Tensor, scale_val: f32) -> Tensor { assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let out = Tensor::empty(x.shape(), x.dtype(), x.device()); let n = x.numel(); - assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)"); + assert!( + n <= i32::MAX as usize, + "tensor too large for i32 kernel param ({n} elements)" + ); let n = n as i32; unsafe { match x.dtype() { - DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()), - DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()), + DType::F32 => launch_scale_f32( + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + scale_val, + n, + xserv_cuda::current_stream_raw(), + ), + DType::BF16 => launch_scale_bf16( + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + scale_val, + n, + xserv_cuda::current_stream_raw(), + ), _ => panic!("unsupported dtype for scale"), } } out } -pub fn add(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_add_f32, launch_add_bf16) } -pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) } +pub fn add(a: &Tensor, b: &Tensor) -> Tensor { + dispatch_binary(a, b, launch_add_f32, launch_add_bf16) +} +pub fn mul(a: &Tensor, b: &Tensor) -> Tensor { + dispatch_binary(a, b, launch_mul_f32, launch_mul_bf16) +} /// Row-broadcast bias add: out[r, c] = x[r, c] + bias[c] (BF16 only). pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor { @@ -89,13 +200,22 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor { assert!(matches!(x.device(), Device::Cuda(_))); let rows = x.shape()[0]; let cols = x.shape()[1]; - assert_eq!(bias.shape()[0], cols, "bias size {} != cols {cols}", bias.shape()[0]); + assert_eq!( + bias.shape()[0], + cols, + "bias size {} != cols {cols}", + bias.shape()[0] + ); assert!(rows * cols <= i32::MAX as usize); let out = Tensor::empty(&[rows, cols], DType::BF16, x.device()); unsafe { launch_bias_add_2d_bf16( - x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, cols as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + bias.data_ptr() as _, + out.data_ptr() as *mut c_void, + rows as i32, + cols as i32, + xserv_cuda::current_stream_raw(), ); } out @@ -110,7 +230,10 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor { assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16"); let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device()); let n = gate.numel(); - assert!(n <= i32::MAX as usize, "tensor too large for i32 kernel param ({n} elements)"); + assert!( + n <= i32::MAX as usize, + "tensor too large for i32 kernel param ({n} elements)" + ); let n = n as i32; unsafe { launch_silu_mul_bf16( diff --git a/crates/xserv-kernels/src/argmax.rs b/crates/xserv-kernels/src/argmax.rs index 3d2765b..7ae0b9e 100644 --- a/crates/xserv-kernels/src/argmax.rs +++ b/crates/xserv-kernels/src/argmax.rs @@ -2,8 +2,13 @@ use std::ffi::c_void; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_argmax_bf16(logits: *const c_void, out_idx: *mut c_void, - rows: i32, cols: i32, stream: *mut c_void); + fn launch_argmax_bf16( + logits: *const c_void, + out_idx: *mut c_void, + rows: i32, + cols: i32, + stream: *mut c_void, + ); } /// GPU argmax over the last dim of a [rows, cols] BF16 tensor. @@ -19,7 +24,10 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec { assert_eq!(logits.ndim(), 2, "argmax expects a 2D [rows, cols] tensor"); assert_eq!(logits.dtype(), DType::BF16, "argmax kernel is BF16-only"); assert!(logits.is_contiguous(), "argmax requires contiguous input"); - assert!(matches!(logits.device(), Device::Cuda(_)), "argmax requires GPU input"); + assert!( + matches!(logits.device(), Device::Cuda(_)), + "argmax requires GPU input" + ); let rows = logits.shape()[0]; let cols = logits.shape()[1]; @@ -35,7 +43,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec { launch_argmax_bf16( logits.data_ptr() as *const c_void, out.as_mut_ptr() as *mut c_void, - rows as i32, cols as i32, + rows as i32, + cols as i32, xserv_cuda::current_stream_raw(), ); } @@ -44,9 +53,8 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec { out.copy_to_host(&mut host_bytes).expect("argmax D2H"); drop(out); // returned to pool - let host_i32: &[i32] = unsafe { - std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) - }; + let host_i32: &[i32] = + unsafe { std::slice::from_raw_parts(host_bytes.as_ptr() as *const i32, rows) }; host_i32.iter().map(|&v| v as u32).collect() } @@ -62,4 +70,3 @@ pub fn argmax_bf16_single(logits: &Tensor) -> u32 { }; argmax_bf16_to_host(&view)[0] } - diff --git a/crates/xserv-kernels/src/attention.rs b/crates/xserv-kernels/src/attention.rs index 61f41b4..fca9b66 100644 --- a/crates/xserv-kernels/src/attention.rs +++ b/crates/xserv-kernels/src/attention.rs @@ -6,28 +6,67 @@ use crate::gemm::batched_matmul; use crate::softmax::softmax; unsafe extern "C" { - fn launch_causal_mask_f32(scores: *mut c_void, batch: i32, rows: i32, cols: i32, - offset: i32, stream: *mut c_void); - fn launch_causal_mask_bf16(scores: *mut c_void, batch: i32, rows: i32, cols: i32, - offset: i32, stream: *mut c_void); + fn launch_causal_mask_f32( + scores: *mut c_void, + batch: i32, + rows: i32, + cols: i32, + offset: i32, + stream: *mut c_void, + ); + fn launch_causal_mask_bf16( + scores: *mut c_void, + batch: i32, + rows: i32, + cols: i32, + offset: i32, + stream: *mut c_void, + ); fn launch_flash_attention_bf16( - q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - q_len: i32, kv_len: i32, head_dim: i32, - scale: f32, causal: i32, stream: *mut c_void, + q: *const c_void, + k: *const c_void, + v: *const c_void, + o: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + q_len: i32, + kv_len: i32, + head_dim: i32, + scale: f32, + causal: i32, + stream: *mut c_void, ); fn launch_flash_attention_sinks_bf16( - q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void, + q: *const c_void, + k: *const c_void, + v: *const c_void, + o: *mut c_void, sinks: *const c_void, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - q_len: i32, kv_len: i32, head_dim: i32, - scale: f32, causal: i32, window_size: i32, stream: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + q_len: i32, + kv_len: i32, + head_dim: i32, + scale: f32, + causal: i32, + window_size: i32, + stream: *mut c_void, ); fn launch_decode_attention_bf16( - q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - kv_len: i32, head_dim: i32, - scale: f32, causal: i32, stream: *mut c_void, + q: *const c_void, + k: *const c_void, + v: *const c_void, + o: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + kv_len: i32, + head_dim: i32, + scale: f32, + causal: i32, + stream: *mut c_void, ); fn launch_paged_decode_attention_bf16( q: *const c_void, @@ -36,9 +75,13 @@ unsafe extern "C" { o: *mut c_void, block_tables: *const i32, context_lens: *const i32, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - head_dim: i32, max_blocks_per_seq: i32, - scale: f32, stream: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + head_dim: i32, + max_blocks_per_seq: i32, + scale: f32, + stream: *mut c_void, ); fn launch_paged_decode_attention_sinks_bf16( q: *const c_void, @@ -48,24 +91,40 @@ unsafe extern "C" { block_tables: *const i32, context_lens: *const i32, sinks: *const c_void, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - head_dim: i32, max_blocks_per_seq: i32, - scale: f32, window_size: i32, stream: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + head_dim: i32, + max_blocks_per_seq: i32, + scale: f32, + window_size: i32, + stream: *mut c_void, ); fn launch_reshape_and_cache_bf16( - k_src: *const c_void, v_src: *const c_void, - k_pool: *mut c_void, v_pool: *mut c_void, + k_src: *const c_void, + v_src: *const c_void, + k_pool: *mut c_void, + v_pool: *mut c_void, block_ids: *const c_void, - num_tokens: i32, num_heads: i32, - head_dim: i32, start_pos: i32, block_size: i32, + num_tokens: i32, + num_heads: i32, + head_dim: i32, + start_pos: i32, + block_size: i32, stream: *mut c_void, ); fn launch_reshape_and_cache_batched_bf16( - k_src: *const c_void, v_src: *const c_void, - k_pool: *mut c_void, v_pool: *mut c_void, - block_tables: *const c_void, kv_lens: *const c_void, - batch: i32, num_heads: i32, - head_dim: i32, block_size: i32, max_blocks_per_seq: i32, + k_src: *const c_void, + v_src: *const c_void, + k_pool: *mut c_void, + v_pool: *mut c_void, + block_tables: *const c_void, + kv_lens: *const c_void, + batch: i32, + num_heads: i32, + head_dim: i32, + block_size: i32, + max_blocks_per_seq: i32, stream: *mut c_void, ); } @@ -84,20 +143,30 @@ unsafe extern "C" { /// `block_ids_gpu` must contain at least `(start_pos + num_tokens + block_size - 1) / block_size` /// valid physical block ids. pub unsafe fn reshape_and_cache_bf16( - k_src: *const c_void, v_src: *const c_void, - k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void, + k_src: *const c_void, + v_src: *const c_void, + k_pool_ptr: *mut c_void, + v_pool_ptr: *mut c_void, block_ids_gpu: *const i32, - num_tokens: usize, num_heads: usize, - head_dim: usize, start_pos: usize, block_size: usize, + num_tokens: usize, + num_heads: usize, + head_dim: usize, + start_pos: usize, + block_size: usize, stream: *mut c_void, ) { unsafe { launch_reshape_and_cache_bf16( - k_src, v_src, - k_pool_ptr, v_pool_ptr, + k_src, + v_src, + k_pool_ptr, + v_pool_ptr, block_ids_gpu as *const c_void, - num_tokens as i32, num_heads as i32, - head_dim as i32, start_pos as i32, block_size as i32, + num_tokens as i32, + num_heads as i32, + head_dim as i32, + start_pos as i32, + block_size as i32, stream, ); } @@ -113,21 +182,32 @@ pub unsafe fn reshape_and_cache_bf16( /// All pointers must be on the same GPU. `block_tables` and `kv_lens` must /// already be synced to the device for the active batch. pub unsafe fn reshape_and_cache_batched_bf16( - k_src: *const c_void, v_src: *const c_void, - k_pool_ptr: *mut c_void, v_pool_ptr: *mut c_void, - block_tables_gpu: *const i32, kv_lens_gpu: *const i32, - batch: usize, num_heads: usize, - head_dim: usize, block_size: usize, max_blocks_per_seq: usize, + k_src: *const c_void, + v_src: *const c_void, + k_pool_ptr: *mut c_void, + v_pool_ptr: *mut c_void, + block_tables_gpu: *const i32, + kv_lens_gpu: *const i32, + batch: usize, + num_heads: usize, + head_dim: usize, + block_size: usize, + max_blocks_per_seq: usize, stream: *mut c_void, ) { unsafe { launch_reshape_and_cache_batched_bf16( - k_src, v_src, - k_pool_ptr, v_pool_ptr, + k_src, + v_src, + k_pool_ptr, + v_pool_ptr, block_tables_gpu as *const c_void, kv_lens_gpu as *const c_void, - batch as i32, num_heads as i32, - head_dim as i32, block_size as i32, max_blocks_per_seq as i32, + batch as i32, + num_heads as i32, + head_dim as i32, + block_size as i32, + max_blocks_per_seq as i32, stream, ); } @@ -143,12 +223,18 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) { match scores.dtype() { DType::F32 => launch_causal_mask_f32( scores.data_ptr() as *mut c_void, - batch as i32, rows as i32, cols as i32, offset as i32, + batch as i32, + rows as i32, + cols as i32, + offset as i32, xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_causal_mask_bf16( scores.data_ptr() as *mut c_void, - batch as i32, rows as i32, cols as i32, offset as i32, + batch as i32, + rows as i32, + cols as i32, + offset as i32, xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for causal mask"), @@ -214,11 +300,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor { let kv_len = k.shape()[2]; let scale = 1.0 / (head_dim as f32).sqrt(); - let output = Tensor::empty( - &[batch, num_q_heads, 1, head_dim], - DType::BF16, - q.device(), - ); + let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device()); unsafe { launch_decode_attention_bf16( @@ -266,8 +348,14 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens assert_eq!(k.shape(), &[batch, num_kv_heads, kv_len, head_dim]); assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]); - assert!(num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"); - assert!(head_dim <= 128, "flash_attention supports head_dim up to 128"); + assert!( + num_q_heads % num_kv_heads == 0, + "num_q_heads must be divisible by num_kv_heads" + ); + assert!( + head_dim <= 128, + "flash_attention supports head_dim up to 128" + ); // Dispatch to specialized decode kernel for single-token generation if q_len == 1 { @@ -333,10 +421,18 @@ pub fn flash_attention_sinks( assert_eq!(v.shape(), &[batch, num_kv_heads, kv_len, head_dim]); assert!(num_q_heads % num_kv_heads == 0); assert!(head_dim <= 128); - assert_eq!(sinks.shape()[0], num_q_heads, "sinks must have num_q_heads entries"); + assert_eq!( + sinks.shape()[0], + num_q_heads, + "sinks must have num_q_heads entries" + ); let scale = 1.0 / (head_dim as f32).sqrt(); - let output = Tensor::empty(&[batch, num_q_heads, q_len, head_dim], DType::BF16, q.device()); + let output = Tensor::empty( + &[batch, num_q_heads, q_len, head_dim], + DType::BF16, + q.device(), + ); unsafe { launch_flash_attention_sinks_bf16( @@ -383,17 +479,20 @@ pub fn paged_decode_attention( max_blocks_per_seq: usize, ) -> Tensor { assert_eq!(q.ndim(), 4); - assert_eq!(q.shape()[2], 1, "paged_decode_attention requires q_len == 1"); + assert_eq!( + q.shape()[2], + 1, + "paged_decode_attention requires q_len == 1" + ); assert_eq!(q.dtype(), DType::BF16); - assert!(num_q_heads % num_kv_heads == 0, "GQA: num_q_heads must be divisible by num_kv_heads"); + assert!( + num_q_heads % num_kv_heads == 0, + "GQA: num_q_heads must be divisible by num_kv_heads" + ); assert!(head_dim <= 128); let scale = 1.0 / (head_dim as f32).sqrt(); - let output = Tensor::empty( - &[batch, num_q_heads, 1, head_dim], - DType::BF16, - q.device(), - ); + let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device()); unsafe { launch_paged_decode_attention_bf16( @@ -442,11 +541,7 @@ pub fn paged_decode_attention_sinks( assert!(head_dim <= 128); let scale = 1.0 / (head_dim as f32).sqrt(); - let output = Tensor::empty( - &[batch, num_q_heads, 1, head_dim], - DType::BF16, - q.device(), - ); + let output = Tensor::empty(&[batch, num_q_heads, 1, head_dim], DType::BF16, q.device()); unsafe { launch_paged_decode_attention_sinks_bf16( diff --git a/crates/xserv-kernels/src/dispatch.rs b/crates/xserv-kernels/src/dispatch.rs index 2de52f9..1901ad4 100644 --- a/crates/xserv-kernels/src/dispatch.rs +++ b/crates/xserv-kernels/src/dispatch.rs @@ -5,104 +5,302 @@ use std::ffi::c_void; // Re-declare the extern functions we need (same as in the individual modules) unsafe extern "C" { - fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); - fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void, - normed_out: *mut c_void, sum_out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); - fn launch_silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void); - fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void, - num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void); - fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void, - positions: *const c_void, num_tokens: i32, num_heads: i32, - head_dim: i32, stream: *mut c_void); - fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, - k: i32, n: i32, stream: *mut c_void); + fn launch_rmsnorm_bf16( + x: *const c_void, + gamma: *const c_void, + out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); + fn launch_add_rmsnorm_bf16( + x: *const c_void, + residual: *const c_void, + gamma: *const c_void, + normed_out: *mut c_void, + sum_out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); + fn launch_silu_mul_bf16( + gate: *const c_void, + up: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_add_bf16( + a: *const c_void, + b: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, + ); + fn launch_embedding_bf16( + table: *const c_void, + token_ids: *const c_void, + out: *mut c_void, + num_tokens: i32, + hidden_size: i32, + vocab_size: i32, + stream: *mut c_void, + ); + fn launch_reshape_heads_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_merge_heads_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_transpose_hsd_to_shd_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_transpose_shd_to_hsd_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_rope_bf16( + x: *mut c_void, + cos_cache: *const c_void, + sin_cache: *const c_void, + positions: *const c_void, + num_tokens: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_gemv_bf16( + x: *const c_void, + w: *const c_void, + y_bf16: *mut c_void, + y_fp32_buf: *mut c_void, + k: i32, + n: i32, + stream: *mut c_void, + ); fn launch_decode_attention_bf16( - q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - kv_len: i32, head_dim: i32, - scale: f32, causal: i32, stream: *mut c_void, + q: *const c_void, + k: *const c_void, + v: *const c_void, + o: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + kv_len: i32, + head_dim: i32, + scale: f32, + causal: i32, + stream: *mut c_void, ); } /// Raw rmsnorm dispatch: writes to pre-allocated `out`. -pub unsafe fn rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) { +pub unsafe fn rmsnorm_bf16( + x: *const c_void, + gamma: *const c_void, + out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, +) { launch_rmsnorm_bf16(x, gamma, out, rows, hidden_size, eps, stream); } /// Raw add_rmsnorm dispatch. -pub unsafe fn add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void, - normed_out: *mut c_void, sum_out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void) { - launch_add_rmsnorm_bf16(x, residual, gamma, normed_out, sum_out, rows, hidden_size, eps, stream); +pub unsafe fn add_rmsnorm_bf16( + x: *const c_void, + residual: *const c_void, + gamma: *const c_void, + normed_out: *mut c_void, + sum_out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, +) { + launch_add_rmsnorm_bf16( + x, + residual, + gamma, + normed_out, + sum_out, + rows, + hidden_size, + eps, + stream, + ); } /// Raw silu_mul dispatch. -pub unsafe fn silu_mul_bf16(gate: *const c_void, up: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) { +pub unsafe fn silu_mul_bf16( + gate: *const c_void, + up: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, +) { launch_silu_mul_bf16(gate, up, out, n, stream); } /// Raw add dispatch. -pub unsafe fn add_bf16(a: *const c_void, b: *const c_void, out: *mut c_void, n: i32, stream: *mut c_void) { +pub unsafe fn add_bf16( + a: *const c_void, + b: *const c_void, + out: *mut c_void, + n: i32, + stream: *mut c_void, +) { launch_add_bf16(a, b, out, n, stream); } /// Raw embedding dispatch. -pub unsafe fn embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void, - num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void) { - launch_embedding_bf16(table, token_ids, out, num_tokens, hidden_size, vocab_size, stream); +pub unsafe fn embedding_bf16( + table: *const c_void, + token_ids: *const c_void, + out: *mut c_void, + num_tokens: i32, + hidden_size: i32, + vocab_size: i32, + stream: *mut c_void, +) { + launch_embedding_bf16( + table, + token_ids, + out, + num_tokens, + hidden_size, + vocab_size, + stream, + ); } /// Raw reshape_heads dispatch. -pub unsafe fn reshape_heads_bf16(inp: *const c_void, out: *mut c_void, - seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) { +pub unsafe fn reshape_heads_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, +) { launch_reshape_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream); } /// Raw merge_heads dispatch. -pub unsafe fn merge_heads_bf16(inp: *const c_void, out: *mut c_void, - seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) { +pub unsafe fn merge_heads_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, +) { launch_merge_heads_bf16(inp, out, seq_len, num_heads, head_dim, stream); } /// Raw transpose HSD->SHD dispatch. -pub unsafe fn transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, - seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) { +pub unsafe fn transpose_hsd_to_shd_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, +) { launch_transpose_hsd_to_shd_bf16(inp, out, seq_len, num_heads, head_dim, stream); } /// Raw transpose SHD->HSD dispatch. -pub unsafe fn transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, - seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void) { +pub unsafe fn transpose_shd_to_hsd_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, +) { launch_transpose_shd_to_hsd_bf16(inp, out, seq_len, num_heads, head_dim, stream); } /// Raw RoPE dispatch (in-place). -pub unsafe fn rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void, - positions: *const c_void, num_tokens: i32, num_heads: i32, - head_dim: i32, stream: *mut c_void) { - launch_rope_bf16(x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream); +pub unsafe fn rope_bf16( + x: *mut c_void, + cos_cache: *const c_void, + sin_cache: *const c_void, + positions: *const c_void, + num_tokens: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, +) { + launch_rope_bf16( + x, cos_cache, sin_cache, positions, num_tokens, num_heads, head_dim, stream, + ); } /// Raw GEMV dispatch (BF16, M=1). Caller must provide fp32 accumulator buffer. -pub unsafe fn gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, - y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void) { +pub unsafe fn gemv_bf16( + x: *const c_void, + w: *const c_void, + y_bf16: *mut c_void, + y_fp32_buf: *mut c_void, + k: i32, + n: i32, + stream: *mut c_void, +) { launch_gemv_bf16(x, w, y_bf16, y_fp32_buf, k, n, stream); } /// Raw decode attention dispatch. -pub unsafe fn decode_attention_bf16(q: *const c_void, k: *const c_void, v: *const c_void, o: *mut c_void, - batch: i32, num_q_heads: i32, num_kv_heads: i32, - kv_len: i32, head_dim: i32, - scale: f32, stream: *mut c_void) { - launch_decode_attention_bf16(q, k, v, o, batch, num_q_heads, num_kv_heads, kv_len, head_dim, scale, 1, stream); +pub unsafe fn decode_attention_bf16( + q: *const c_void, + k: *const c_void, + v: *const c_void, + o: *mut c_void, + batch: i32, + num_q_heads: i32, + num_kv_heads: i32, + kv_len: i32, + head_dim: i32, + scale: f32, + stream: *mut c_void, +) { + launch_decode_attention_bf16( + q, + k, + v, + o, + batch, + num_q_heads, + num_kv_heads, + kv_len, + head_dim, + scale, + 1, + stream, + ); } // cuBLAS FFI diff --git a/crates/xserv-kernels/src/embedding.rs b/crates/xserv-kernels/src/embedding.rs index 0cc2262..1a1d8a2 100644 --- a/crates/xserv-kernels/src/embedding.rs +++ b/crates/xserv-kernels/src/embedding.rs @@ -2,10 +2,24 @@ use std::ffi::c_void; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_embedding_f32(table: *const c_void, token_ids: *const c_void, out: *mut c_void, - num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void); - fn launch_embedding_bf16(table: *const c_void, token_ids: *const c_void, out: *mut c_void, - num_tokens: i32, hidden_size: i32, vocab_size: i32, stream: *mut c_void); + fn launch_embedding_f32( + table: *const c_void, + token_ids: *const c_void, + out: *mut c_void, + num_tokens: i32, + hidden_size: i32, + vocab_size: i32, + stream: *mut c_void, + ); + fn launch_embedding_bf16( + table: *const c_void, + token_ids: *const c_void, + out: *mut c_void, + num_tokens: i32, + hidden_size: i32, + vocab_size: i32, + stream: *mut c_void, + ); } /// Embedding lookup: table[token_ids[i]] for each i. @@ -18,8 +32,14 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor { let hidden_size = table.shape()[1]; let num_tokens = token_ids.len(); let vocab_size = table.shape()[0]; - assert!(num_tokens <= i32::MAX as usize, "too many tokens for i32 kernel param"); - assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param"); + assert!( + num_tokens <= i32::MAX as usize, + "too many tokens for i32 kernel param" + ); + assert!( + hidden_size <= i32::MAX as usize, + "hidden_size too large for i32 kernel param" + ); // Upload token_ids to GPU let ids_bytes = unsafe { @@ -28,11 +48,15 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor { num_tokens * std::mem::size_of::(), ) }; - let mut ids_gpu = xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids"); + let mut ids_gpu = + xserv_cuda::allocator::cached_alloc(ids_bytes.len()).expect("alloc token_ids"); ids_gpu.copy_from_host(ids_bytes).unwrap(); for &tid in token_ids { - assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})"); + assert!( + (tid as usize) < vocab_size, + "token_id {tid} out of bounds (vocab_size={vocab_size})" + ); } embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens) @@ -53,14 +77,22 @@ pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: unsafe { match table.dtype() { DType::F32 => launch_embedding_f32( - table.data_ptr() as _, ids_gpu, + table.data_ptr() as _, + ids_gpu, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(), + num_tokens as i32, + hidden_size as i32, + vocab_size as i32, + xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_embedding_bf16( - table.data_ptr() as _, ids_gpu, + table.data_ptr() as _, + ids_gpu, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(), + num_tokens as i32, + hidden_size as i32, + vocab_size as i32, + xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for embedding"), } diff --git a/crates/xserv-kernels/src/gemm.rs b/crates/xserv-kernels/src/gemm.rs index 88e743f..fc919c6 100644 --- a/crates/xserv-kernels/src/gemm.rs +++ b/crates/xserv-kernels/src/gemm.rs @@ -1,14 +1,22 @@ use std::cell::RefCell; use std::ffi::c_void; -use xserv_cuda::error::{self, Result}; use xserv_cuda::GpuBuffer; +use xserv_cuda::error::{self, Result}; use xserv_tensor::{DType, Device, Tensor}; const CUBLAS_WORKSPACE_BYTES: usize = 32 * 1024 * 1024; // GEMV: single-kernel, no FP32 temp buffer needed unsafe extern "C" { - fn launch_gemv_bf16(x: *const c_void, w: *const c_void, y_bf16: *mut c_void, y_fp32_buf: *mut c_void, k: i32, n: i32, stream: *mut c_void); + fn launch_gemv_bf16( + x: *const c_void, + w: *const c_void, + y_bf16: *mut c_void, + y_fp32_buf: *mut c_void, + k: i32, + n: i32, + stream: *mut c_void, + ); } #[derive(Debug, Clone, Copy)] @@ -20,10 +28,42 @@ pub enum GemmBackend { // --- FFI: custom CUDA kernels --- unsafe extern "C" { - fn launch_gemm_naive_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); - fn launch_gemm_naive_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); - fn launch_gemm_tiled_f32(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); - fn launch_gemm_tiled_bf16(a: *const c_void, b: *const c_void, c: *mut c_void, m: i32, n: i32, k: i32, stream: *mut c_void); + fn launch_gemm_naive_f32( + a: *const c_void, + b: *const c_void, + c: *mut c_void, + m: i32, + n: i32, + k: i32, + stream: *mut c_void, + ); + fn launch_gemm_naive_bf16( + a: *const c_void, + b: *const c_void, + c: *mut c_void, + m: i32, + n: i32, + k: i32, + stream: *mut c_void, + ); + fn launch_gemm_tiled_f32( + a: *const c_void, + b: *const c_void, + c: *mut c_void, + m: i32, + n: i32, + k: i32, + stream: *mut c_void, + ); + fn launch_gemm_tiled_bf16( + a: *const c_void, + b: *const c_void, + c: *mut c_void, + m: i32, + n: i32, + k: i32, + stream: *mut c_void, + ); } // --- FFI: cuBLAS --- @@ -46,25 +86,46 @@ unsafe extern "C" { fn cublasSetWorkspace_v2(handle: CublasHandle, workspace: *mut c_void, size: usize) -> i32; fn cublasGemmEx( handle: CublasHandle, - transa: i32, transb: i32, - m: i32, n: i32, k: i32, + transa: i32, + transb: i32, + m: i32, + n: i32, + k: i32, alpha: *const c_void, - a: *const c_void, a_type: i32, lda: i32, - b: *const c_void, b_type: i32, ldb: i32, + a: *const c_void, + a_type: i32, + lda: i32, + b: *const c_void, + b_type: i32, + ldb: i32, beta: *const c_void, - c: *mut c_void, c_type: i32, ldc: i32, + c: *mut c_void, + c_type: i32, + ldc: i32, compute_type: i32, algo: i32, ) -> i32; fn cublasGemmStridedBatchedEx( handle: CublasHandle, - transa: i32, transb: i32, - m: i32, n: i32, k: i32, + transa: i32, + transb: i32, + m: i32, + n: i32, + k: i32, alpha: *const c_void, - a: *const c_void, a_type: i32, lda: i32, stride_a: i64, - b: *const c_void, b_type: i32, ldb: i32, stride_b: i64, + a: *const c_void, + a_type: i32, + lda: i32, + stride_a: i64, + b: *const c_void, + b_type: i32, + ldb: i32, + stride_b: i64, beta: *const c_void, - c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64, + c: *mut c_void, + c_type: i32, + ldc: i32, + stride_c: i64, batch_count: i32, compute_type: i32, algo: i32, @@ -89,9 +150,16 @@ impl CublasContext { // set, so we keep the GpuBuffer in this struct. let mut workspace = GpuBuffer::alloc(CUBLAS_WORKSPACE_BYTES)?; error::check(unsafe { - cublasSetWorkspace_v2(handle, workspace.as_mut_ptr() as *mut c_void, CUBLAS_WORKSPACE_BYTES) + cublasSetWorkspace_v2( + handle, + workspace.as_mut_ptr() as *mut c_void, + CUBLAS_WORKSPACE_BYTES, + ) })?; - Ok(Self { handle, _workspace: Some(workspace) }) + Ok(Self { + handle, + _workspace: Some(workspace), + }) } } @@ -123,9 +191,7 @@ where /// Get the thread-local cuBLAS handle for use with dispatch module. pub fn cublas_handle() -> CublasHandle { - CUBLAS_CTX.with(|cell| { - cell.borrow().handle - }) + CUBLAS_CTX.with(|cell| cell.borrow().handle) } /// Matrix multiplication: C = A @ B @@ -136,8 +202,14 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { assert_eq!(b.ndim(), 2); assert_eq!(a.shape()[1], b.shape()[0], "inner dimension mismatch"); assert_eq!(a.dtype(), b.dtype(), "dtype mismatch"); - assert!(a.is_contiguous() && b.is_contiguous(), "matmul requires contiguous tensors"); - assert!(matches!(a.device(), Device::Cuda(_)), "matmul requires GPU tensors"); + assert!( + a.is_contiguous() && b.is_contiguous(), + "matmul requires contiguous tensors" + ); + assert!( + matches!(a.device(), Device::Cuda(_)), + "matmul requires GPU tensors" + ); let m = a.shape()[0]; let k = a.shape()[1]; @@ -154,32 +226,63 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { let null_stream = xserv_cuda::current_stream_raw(); match backend { - GemmBackend::Naive => { - unsafe { - match dtype { - DType::F32 => launch_gemm_naive_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), - DType::BF16 => launch_gemm_naive_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), - _ => panic!("unsupported dtype for naive GEMM"), - } + GemmBackend::Naive => unsafe { + match dtype { + DType::F32 => launch_gemm_naive_f32( + a_ptr, + b_ptr, + c_ptr, + m as i32, + n as i32, + k as i32, + null_stream, + ), + DType::BF16 => launch_gemm_naive_bf16( + a_ptr, + b_ptr, + c_ptr, + m as i32, + n as i32, + k as i32, + null_stream, + ), + _ => panic!("unsupported dtype for naive GEMM"), } - } - GemmBackend::Tiled => { - unsafe { - match dtype { - DType::F32 => launch_gemm_tiled_f32(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), - DType::BF16 => launch_gemm_tiled_bf16(a_ptr, b_ptr, c_ptr, m as i32, n as i32, k as i32, null_stream), - _ => panic!("unsupported dtype for tiled GEMM"), - } + }, + GemmBackend::Tiled => unsafe { + match dtype { + DType::F32 => launch_gemm_tiled_f32( + a_ptr, + b_ptr, + c_ptr, + m as i32, + n as i32, + k as i32, + null_stream, + ), + DType::BF16 => launch_gemm_tiled_bf16( + a_ptr, + b_ptr, + c_ptr, + m as i32, + n as i32, + k as i32, + null_stream, + ), + _ => panic!("unsupported dtype for tiled GEMM"), } - } + }, GemmBackend::CuBlas => { if m == 1 && dtype == DType::BF16 && n >= 256 { let mut fp32_buf = xserv_cuda::allocator::cached_alloc(n * 4).unwrap(); unsafe { launch_gemv_bf16( - a_ptr, b_ptr, c_ptr, + a_ptr, + b_ptr, + c_ptr, fp32_buf.as_mut_ptr() as *mut c_void, - k as i32, n as i32, + k as i32, + n as i32, null_stream, ); } @@ -197,16 +300,26 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { cublasSetStream_v2(handle, null_stream); error::check(cublasGemmEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n as i32, m as i32, k as i32, + CUBLAS_OP_N, + CUBLAS_OP_N, + n as i32, + m as i32, + k as i32, &alpha as *const f32 as *const c_void, - b_ptr, b_type, n as i32, - a_ptr, a_type, k as i32, + b_ptr, + b_type, + n as i32, + a_ptr, + a_type, + k as i32, &beta as *const f32 as *const c_void, - c_ptr, c_type, n as i32, + c_ptr, + c_type, + n as i32, CUBLAS_COMPUTE_32F, -1, - )).expect("cuBLAS GEMM failed"); + )) + .expect("cuBLAS GEMM failed"); }); } } @@ -264,17 +377,30 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor { // Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major) error::check(cublasGemmStridedBatchedEx( handle, - CUBLAS_OP_N, CUBLAS_OP_N, - n as i32, m as i32, k as i32, + CUBLAS_OP_N, + CUBLAS_OP_N, + n as i32, + m as i32, + k as i32, &alpha as *const f32 as *const c_void, - b.data_ptr() as _, b_type, n as i32, stride_b, - a.data_ptr() as _, a_type, k as i32, stride_a, + b.data_ptr() as _, + b_type, + n as i32, + stride_b, + a.data_ptr() as _, + a_type, + k as i32, + stride_a, &beta as *const f32 as *const c_void, - c.data_ptr() as *mut c_void, c_type, n as i32, stride_c, + c.data_ptr() as *mut c_void, + c_type, + n as i32, + stride_c, batch as i32, CUBLAS_COMPUTE_32F, -1, - )).expect("cuBLAS batched GEMM failed"); + )) + .expect("cuBLAS batched GEMM failed"); }); c } diff --git a/crates/xserv-kernels/src/layernorm.rs b/crates/xserv-kernels/src/layernorm.rs index e03ce51..47851e4 100644 --- a/crates/xserv-kernels/src/layernorm.rs +++ b/crates/xserv-kernels/src/layernorm.rs @@ -2,10 +2,26 @@ use std::ffi::c_void; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_layernorm_f32(x: *const c_void, gamma: *const c_void, beta: *const c_void, - out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); - fn launch_layernorm_bf16(x: *const c_void, gamma: *const c_void, beta: *const c_void, - out: *mut c_void, rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); + fn launch_layernorm_f32( + x: *const c_void, + gamma: *const c_void, + beta: *const c_void, + out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); + fn launch_layernorm_bf16( + x: *const c_void, + gamma: *const c_void, + beta: *const c_void, + out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); } pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor { @@ -17,21 +33,37 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor assert_eq!(beta.shape(), &[hidden_size]); let rows = x.numel() / hidden_size; - assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param"); - assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param"); + assert!( + rows <= i32::MAX as usize, + "too many rows for i32 kernel param" + ); + assert!( + hidden_size <= i32::MAX as usize, + "hidden_size too large for i32 kernel param" + ); let out = Tensor::empty(x.shape(), x.dtype(), x.device()); unsafe { match x.dtype() { DType::F32 => launch_layernorm_f32( - x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _, + x.data_ptr() as _, + gamma.data_ptr() as _, + beta.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), + rows as i32, + hidden_size as i32, + eps, + xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_layernorm_bf16( - x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _, + x.data_ptr() as _, + gamma.data_ptr() as _, + beta.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), + rows as i32, + hidden_size as i32, + eps, + xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for layernorm"), } diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index 584fa8d..98e6150 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -14,14 +14,20 @@ pub mod transpose; pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu_mul}; pub use argmax::{argmax_bf16_single, argmax_bf16_to_host}; -pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu}; -pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16}; +pub use attention::{ + attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, + paged_decode_attention_sinks, reshape_and_cache_batched_bf16, reshape_and_cache_bf16, +}; pub use embedding::{embedding, embedding_device_ids}; -pub use gemm::{batched_matmul, matmul, GemmBackend}; +pub use gemm::{GemmBackend, batched_matmul, matmul}; pub use layernorm::layernorm; pub use rmsnorm::{add_rmsnorm, rmsnorm}; -pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache}; +pub use rope::{RopeCache, rope_inplace, rope_inplace_device_pos}; pub use softmax::softmax; +pub use transpose::{ + merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, + transpose_for_rope_gpu, transpose_from_rope_gpu, +}; /// Register GPU kernels with the tensor crate. Call once at startup. pub fn init() { diff --git a/crates/xserv-kernels/src/moe.rs b/crates/xserv-kernels/src/moe.rs index 38bb1e2..e04060e 100644 --- a/crates/xserv-kernels/src/moe.rs +++ b/crates/xserv-kernels/src/moe.rs @@ -1,66 +1,113 @@ use std::ffi::c_void; use xserv_tensor::{DType, Tensor}; -use crate::gemm::{cublas_handle, CublasHandle}; +use crate::gemm::{CublasHandle, cublas_handle}; unsafe extern "C" { fn launch_moe_topk_softmax_bf16( router_logits: *const c_void, - topk_ids: *mut c_void, topk_weights: *mut c_void, - num_tokens: i32, num_experts: i32, top_k: i32, + topk_ids: *mut c_void, + topk_weights: *mut c_void, + num_tokens: i32, + num_experts: i32, + top_k: i32, stream: *mut c_void, ); fn launch_moe_replicate_bf16( - x: *const c_void, x_rep: *mut c_void, - num_tokens: i32, hidden: i32, local_experts: i32, + x: *const c_void, + x_rep: *mut c_void, + num_tokens: i32, + hidden: i32, + local_experts: i32, stream: *mut c_void, ); fn launch_moe_bias_add_3d_bf16( - x: *mut c_void, bias: *const c_void, - batch: i32, num_tokens: i32, dim: i32, + x: *mut c_void, + bias: *const c_void, + batch: i32, + num_tokens: i32, + dim: i32, stream: *mut c_void, ); fn launch_moe_weighted_sum_bf16( expert_out: *const c_void, - topk_ids: *const c_void, topk_weights: *const c_void, + topk_ids: *const c_void, + topk_weights: *const c_void, out: *mut c_void, - num_tokens: i32, hidden: i32, top_k: i32, - expert_start: i32, local_experts: i32, + num_tokens: i32, + hidden: i32, + top_k: i32, + expert_start: i32, + local_experts: i32, stream: *mut c_void, ); fn launch_moe_sparse_gemv_fp8_bf16( - x: *const c_void, w: *const c_void, w_scales: *const c_void, - bias: *const c_void, topk_ids: *const c_void, y: *mut c_void, - num_tokens: i32, n: i32, k: i32, top_k: i32, - expert_start: i32, local_experts: i32, x_per_slot: i32, + x: *const c_void, + w: *const c_void, + w_scales: *const c_void, + bias: *const c_void, + topk_ids: *const c_void, + y: *mut c_void, + num_tokens: i32, + n: i32, + k: i32, + top_k: i32, + expert_start: i32, + local_experts: i32, + x_per_slot: i32, stream: *mut c_void, ); fn launch_moe_sparse_gemv_mxfp4_bf16( - x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, - bias: *const c_void, topk_ids: *const c_void, y: *mut c_void, - num_tokens: i32, n: i32, k: i32, top_k: i32, - expert_start: i32, local_experts: i32, x_per_slot: i32, + x: *const c_void, + w_packed: *const c_void, + w_scales: *const c_void, + bias: *const c_void, + topk_ids: *const c_void, + y: *mut c_void, + num_tokens: i32, + n: i32, + k: i32, + top_k: i32, + expert_start: i32, + local_experts: i32, + x_per_slot: i32, stream: *mut c_void, ); fn launch_moe_weighted_sum_sparse_bf16( down: *const c_void, - topk_ids: *const c_void, topk_weights: *const c_void, + topk_ids: *const c_void, + topk_weights: *const c_void, out: *mut c_void, - num_tokens: i32, hidden: i32, top_k: i32, - expert_start: i32, local_experts: i32, + num_tokens: i32, + hidden: i32, + top_k: i32, + expert_start: i32, + local_experts: i32, stream: *mut c_void, ); fn cublasGemmStridedBatchedEx( handle: CublasHandle, - transa: i32, transb: i32, - m: i32, n: i32, k: i32, + transa: i32, + transb: i32, + m: i32, + n: i32, + k: i32, alpha: *const c_void, - a: *const c_void, a_type: i32, lda: i32, stride_a: i64, - b: *const c_void, b_type: i32, ldb: i32, stride_b: i64, + a: *const c_void, + a_type: i32, + lda: i32, + stride_a: i64, + b: *const c_void, + b_type: i32, + ldb: i32, + stride_b: i64, beta: *const c_void, - c: *mut c_void, c_type: i32, ldc: i32, stride_c: i64, + c: *mut c_void, + c_type: i32, + ldc: i32, + stride_c: i64, batch_count: i32, compute_type: i32, algo: i32, @@ -99,7 +146,9 @@ pub fn moe_topk_softmax( router_logits.data_ptr() as *const c_void, topk_ids.data_ptr() as *mut c_void, topk_weights.data_ptr() as *mut c_void, - num_tokens as i32, num_experts as i32, top_k as i32, + num_tokens as i32, + num_experts as i32, + top_k as i32, xserv_cuda::current_stream_raw(), ); } @@ -114,13 +163,19 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor { assert!(x.is_contiguous()); let num_tokens = x.shape()[0]; let hidden = x.shape()[1]; - let out = Tensor::empty(&[local_experts, num_tokens, hidden], DType::BF16, x.device()); + let out = Tensor::empty( + &[local_experts, num_tokens, hidden], + DType::BF16, + x.device(), + ); unsafe { launch_moe_replicate_bf16( x.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden as i32, local_experts as i32, + num_tokens as i32, + hidden as i32, + local_experts as i32, xserv_cuda::current_stream_raw(), ); } @@ -143,7 +198,9 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) { launch_moe_bias_add_3d_bf16( x.data_ptr() as *mut c_void, bias.data_ptr() as *const c_void, - batch as i32, num_tokens as i32, dim as i32, + batch as i32, + num_tokens as i32, + dim as i32, xserv_cuda::current_stream_raw(), ); } @@ -175,8 +232,11 @@ pub fn moe_weighted_sum( topk_ids.data_ptr() as *const c_void, topk_weights.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden as i32, top_k as i32, - expert_start as i32, local_experts as i32, + num_tokens as i32, + hidden as i32, + top_k as i32, + expert_start as i32, + local_experts as i32, xserv_cuda::current_stream_raw(), ); } @@ -198,9 +258,16 @@ pub fn moe_weighted_sum( /// consumer must skip them (see moe_weighted_sum_sparse). #[allow(clippy::too_many_arguments)] pub fn moe_sparse_gemv_fp8( - x: &Tensor, w_fp8_t: &Tensor, w_scales: &Tensor, bias: &Tensor, - topk_ids: &Tensor, num_tokens: usize, top_k: usize, - expert_start: usize, local_experts: usize, x_per_slot: bool, + x: &Tensor, + w_fp8_t: &Tensor, + w_scales: &Tensor, + bias: &Tensor, + topk_ids: &Tensor, + num_tokens: usize, + top_k: usize, + expert_start: usize, + local_experts: usize, + x_per_slot: bool, ) -> Tensor { assert_eq!(x.dtype(), DType::BF16); assert!(x.is_contiguous()); @@ -211,7 +278,14 @@ pub fn moe_sparse_gemv_fp8( // silently skip a K%16 tail. assert_eq!(k % 16, 0, "sparse FP8 GEMV requires K % 16 == 0, got {k}"); assert_eq!(x.shape()[x.ndim() - 1], k); - assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens }); + assert_eq!( + x.shape()[0], + if x_per_slot { + num_tokens * top_k + } else { + num_tokens + } + ); let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device()); unsafe { @@ -222,8 +296,13 @@ pub fn moe_sparse_gemv_fp8( bias.data_ptr() as *const c_void, topk_ids.data_ptr() as *const c_void, y.data_ptr() as *mut c_void, - num_tokens as i32, n as i32, k as i32, top_k as i32, - expert_start as i32, local_experts as i32, x_per_slot as i32, + num_tokens as i32, + n as i32, + k as i32, + top_k as i32, + expert_start as i32, + local_experts as i32, + x_per_slot as i32, xserv_cuda::current_stream_raw(), ); } @@ -234,16 +313,32 @@ pub fn moe_sparse_gemv_fp8( /// with packed 4-bit weights [E, N, K/2] + UE8M0 block scales [E, N, K/32]. #[allow(clippy::too_many_arguments)] pub fn moe_sparse_gemv_mxfp4( - x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, bias: &Tensor, - topk_ids: &Tensor, num_tokens: usize, top_k: usize, n: usize, k: usize, - expert_start: usize, local_experts: usize, x_per_slot: bool, + x: &Tensor, + w_packed: &Tensor, + w_scales: &Tensor, + bias: &Tensor, + topk_ids: &Tensor, + num_tokens: usize, + top_k: usize, + n: usize, + k: usize, + expert_start: usize, + local_experts: usize, + x_per_slot: bool, ) -> Tensor { assert_eq!(x.dtype(), DType::BF16); assert!(x.is_contiguous()); // 32-element MXFP4 blocks, read as uint4 (32 nibbles) per lane. assert_eq!(k % 32, 0, "sparse MXFP4 GEMV requires K % 32 == 0, got {k}"); assert_eq!(x.shape()[x.ndim() - 1], k); - assert_eq!(x.shape()[0], if x_per_slot { num_tokens * top_k } else { num_tokens }); + assert_eq!( + x.shape()[0], + if x_per_slot { + num_tokens * top_k + } else { + num_tokens + } + ); let y = Tensor::empty(&[num_tokens, top_k, n], DType::BF16, x.device()); unsafe { @@ -254,8 +349,13 @@ pub fn moe_sparse_gemv_mxfp4( bias.data_ptr() as *const c_void, topk_ids.data_ptr() as *const c_void, y.data_ptr() as *mut c_void, - num_tokens as i32, n as i32, k as i32, top_k as i32, - expert_start as i32, local_experts as i32, x_per_slot as i32, + num_tokens as i32, + n as i32, + k as i32, + top_k as i32, + expert_start as i32, + local_experts as i32, + x_per_slot as i32, xserv_cuda::current_stream_raw(), ); } @@ -286,8 +386,11 @@ pub fn moe_weighted_sum_sparse( topk_ids.data_ptr() as *const c_void, topk_weights.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden as i32, top_k as i32, - expert_start as i32, local_experts as i32, + num_tokens as i32, + hidden as i32, + top_k as i32, + expert_start as i32, + local_experts as i32, xserv_cuda::current_stream_raw(), ); } @@ -341,13 +444,25 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor { cublasSetStream_v2(handle, xserv_cuda::current_stream_raw()); let status = cublasGemmStridedBatchedEx( handle, - 0, 0, // CUBLAS_OP_N, CUBLAS_OP_N - n as i32, m as i32, k as i32, + 0, + 0, // CUBLAS_OP_N, CUBLAS_OP_N + n as i32, + m as i32, + k as i32, &alpha as *const f32 as *const c_void, - b.data_ptr() as *const c_void, CUDA_R_16BF, n as i32, stride_b, - a.data_ptr() as *const c_void, CUDA_R_16BF, k as i32, stride_a, + b.data_ptr() as *const c_void, + CUDA_R_16BF, + n as i32, + stride_b, + a.data_ptr() as *const c_void, + CUDA_R_16BF, + k as i32, + stride_a, &beta as *const f32 as *const c_void, - c.data_ptr() as *mut c_void, CUDA_R_16BF, n as i32, stride_c, + c.data_ptr() as *mut c_void, + CUDA_R_16BF, + n as i32, + stride_c, batch as i32, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT, diff --git a/crates/xserv-kernels/src/quantization.rs b/crates/xserv-kernels/src/quantization.rs index 6d40520..fcdcf3f 100644 --- a/crates/xserv-kernels/src/quantization.rs +++ b/crates/xserv-kernels/src/quantization.rs @@ -13,30 +13,46 @@ unsafe extern "C" { src: *const c_void, scales: *const c_void, dst: *mut c_void, - num_experts: i32, rows: i32, cols: i32, + num_experts: i32, + rows: i32, + cols: i32, stream: *mut c_void, ); fn launch_quantize_bf16_to_fp8e4m3_rowwise( src: *const c_void, dst: *mut c_void, scales: *mut c_void, - num_rows: i32, cols: i32, + num_rows: i32, + cols: i32, stream: *mut c_void, ); fn launch_rowwise_scale_moe_bf16( data: *mut c_void, a_scales: *const c_void, b_scales: *const c_void, - num_rows: i32, cols: i32, tokens: i32, + num_rows: i32, + cols: i32, + tokens: i32, stream: *mut c_void, ); fn launch_batched_gemv_mxfp4_bf16( - x: *const c_void, w_packed: *const c_void, w_scales: *const c_void, y: *mut c_void, - e: i32, n: i32, k: i32, stream: *mut c_void, + x: *const c_void, + w_packed: *const c_void, + w_scales: *const c_void, + y: *mut c_void, + e: i32, + n: i32, + k: i32, + stream: *mut c_void, ); fn launch_dequant_mxfp4_to_bf16_t( - w_packed: *const c_void, w_scales: *const c_void, out: *mut c_void, - e: i32, n: i32, k: i32, stream: *mut c_void, + w_packed: *const c_void, + w_scales: *const c_void, + out: *mut c_void, + e: i32, + n: i32, + k: i32, + stream: *mut c_void, ); } @@ -66,34 +82,68 @@ struct CublasLtMatmulHeuristicResult { unsafe extern "C" { fn cublasLtCreate(handle: *mut CublasLtHandle) -> i32; fn cublasLtDestroy(handle: CublasLtHandle) -> i32; - fn cublasLtMatmulDescCreate(desc: *mut CublasLtMatmulDesc, compute_type: i32, scale_type: i32) -> i32; + fn cublasLtMatmulDescCreate( + desc: *mut CublasLtMatmulDesc, + compute_type: i32, + scale_type: i32, + ) -> i32; fn cublasLtMatmulDescDestroy(desc: CublasLtMatmulDesc) -> i32; - fn cublasLtMatmulDescSetAttribute(desc: CublasLtMatmulDesc, attr: i32, buf: *const c_void, size: usize) -> i32; - fn cublasLtMatrixLayoutCreate(layout: *mut CublasLtMatrixLayout, dtype: i32, rows: u64, cols: u64, ld: i64) -> i32; + fn cublasLtMatmulDescSetAttribute( + desc: CublasLtMatmulDesc, + attr: i32, + buf: *const c_void, + size: usize, + ) -> i32; + fn cublasLtMatrixLayoutCreate( + layout: *mut CublasLtMatrixLayout, + dtype: i32, + rows: u64, + cols: u64, + ld: i64, + ) -> i32; fn cublasLtMatrixLayoutDestroy(layout: CublasLtMatrixLayout) -> i32; - fn cublasLtMatrixLayoutSetAttribute(layout: CublasLtMatrixLayout, attr: i32, buf: *const c_void, size: usize) -> i32; + fn cublasLtMatrixLayoutSetAttribute( + layout: CublasLtMatrixLayout, + attr: i32, + buf: *const c_void, + size: usize, + ) -> i32; fn cublasLtMatmulPreferenceCreate(pref: *mut CublasLtMatmulPreference) -> i32; fn cublasLtMatmulPreferenceDestroy(pref: CublasLtMatmulPreference) -> i32; - fn cublasLtMatmulPreferenceSetAttribute(pref: CublasLtMatmulPreference, attr: i32, buf: *const c_void, size: usize) -> i32; + fn cublasLtMatmulPreferenceSetAttribute( + pref: CublasLtMatmulPreference, + attr: i32, + buf: *const c_void, + size: usize, + ) -> i32; fn cublasLtMatmulAlgoGetHeuristic( - handle: CublasLtHandle, desc: CublasLtMatmulDesc, - a_layout: CublasLtMatrixLayout, b_layout: CublasLtMatrixLayout, - c_layout: CublasLtMatrixLayout, d_layout: CublasLtMatrixLayout, + handle: CublasLtHandle, + desc: CublasLtMatmulDesc, + a_layout: CublasLtMatrixLayout, + b_layout: CublasLtMatrixLayout, + c_layout: CublasLtMatrixLayout, + d_layout: CublasLtMatrixLayout, pref: CublasLtMatmulPreference, requested: i32, results: *mut CublasLtMatmulHeuristicResult, found: *mut i32, ) -> i32; fn cublasLtMatmul( - handle: CublasLtHandle, desc: CublasLtMatmulDesc, + handle: CublasLtHandle, + desc: CublasLtMatmulDesc, alpha: *const c_void, - a: *const c_void, a_layout: CublasLtMatrixLayout, - b: *const c_void, b_layout: CublasLtMatrixLayout, + a: *const c_void, + a_layout: CublasLtMatrixLayout, + b: *const c_void, + b_layout: CublasLtMatrixLayout, beta: *const c_void, - c: *const c_void, c_layout: CublasLtMatrixLayout, - d: *mut c_void, d_layout: CublasLtMatrixLayout, + c: *const c_void, + c_layout: CublasLtMatrixLayout, + d: *mut c_void, + d_layout: CublasLtMatrixLayout, algo: *const CublasLtMatmulAlgo, - workspace: *mut c_void, workspace_size: usize, + workspace: *mut c_void, + workspace_size: usize, stream: *mut c_void, ) -> i32; } @@ -153,8 +203,15 @@ impl CublasLtContext { assert_eq!(status, 0, "cublasLtCreate failed: {status}"); let workspace = GpuBuffer::alloc(WORKSPACE_BYTES).expect("alloc cublasLt workspace"); let mut one_buf = GpuBuffer::alloc(4).expect("alloc cublasLt fp8 scale"); - one_buf.copy_from_host(&1.0f32.to_le_bytes()).expect("init fp8 scale"); - Self { handle, workspace, one_buf, plans: HashMap::new() } + one_buf + .copy_from_host(&1.0f32.to_le_bytes()) + .expect("init fp8 scale"); + Self { + handle, + workspace, + one_buf, + plans: HashMap::new(), + } } /// Get the cached strided-batched plan for (m, n, k, batch), building it on @@ -210,10 +267,25 @@ unsafe fn build_fp8_plan( // transA=T (required for FP8 on Blackwell) let trans_a: i32 = 1; - cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a as *const i32 as _, 4); + cublasLtMatmulDescSetAttribute( + desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &trans_a as *const i32 as _, + 4, + ); let ptr_sz = std::mem::size_of::<*const c_void>(); - cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz); - cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &one_ptr as *const _ as _, ptr_sz); + cublasLtMatmulDescSetAttribute( + desc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &one_ptr as *const _ as _, + ptr_sz, + ); + cublasLtMatmulDescSetAttribute( + desc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &one_ptr as *const _ as _, + ptr_sz, + ); // Per-expert strides in ELEMENTS for the strided-batch layout. let stride_a = (n * k) as i64; // weights [N, K] @@ -221,10 +293,18 @@ unsafe fn build_fp8_plan( let stride_c = (m * n) as i64; // output [M, N] let bc = batch as i32; let set_batch = |layout: CublasLtMatrixLayout, stride: i64| { - cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &bc as *const i32 as _, 4); - cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &stride as *const i64 as _, 8); + cublasLtMatrixLayoutSetAttribute( + layout, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &bc as *const i32 as _, + 4, + ); + cublasLtMatrixLayoutSetAttribute( + layout, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride as *const i64 as _, + 8, + ); }; // "A" layout (weights, transposed): physical (K, N) col-major, ld=K @@ -246,20 +326,39 @@ unsafe fn build_fp8_plan( let mut pref: CublasLtMatmulPreference = std::ptr::null_mut(); cublasLtMatmulPreferenceCreate(&mut pref); let ws_bytes = WORKSPACE_BYTES as u64; - cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws_bytes as *const u64 as _, 8); + cublasLtMatmulPreferenceSetAttribute( + pref, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &ws_bytes as *const u64 as _, + 8, + ); let mut heuristic = std::mem::zeroed::(); let mut found: i32 = 0; let status = cublasLtMatmulAlgoGetHeuristic( - handle, desc, a_layout, b_layout, c_layout, d_layout, - pref, 1, &mut heuristic, &mut found, + handle, + desc, + a_layout, + b_layout, + c_layout, + d_layout, + pref, + 1, + &mut heuristic, + &mut found, + ); + assert!( + status == 0 && found > 0, + "cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}" ); - assert!(status == 0 && found > 0, - "cublasLtMatmulAlgoGetHeuristic failed for batched FP8 GEMM (m={m}, n={n}, k={k}, batch={batch}): status={status}, found={found}"); cublasLtMatmulPreferenceDestroy(pref); Fp8Plan { - desc, a_layout, b_layout, c_layout, d_layout, + desc, + a_layout, + b_layout, + c_layout, + d_layout, algo: heuristic.algo, workspace_size: heuristic.workspace_size, } @@ -299,7 +398,9 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor { src.data_ptr() as *const c_void, scales.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, - num_experts as i32, rows as i32, cols as i32, + num_experts as i32, + rows as i32, + cols as i32, xserv_cuda::current_stream_raw(), ); } @@ -329,7 +430,8 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) { src.data_ptr() as *const c_void, fp8_out.data_ptr() as *mut c_void, scales.data_ptr() as *mut c_void, - num_rows as i32, cols as i32, + num_rows as i32, + cols as i32, xserv_cuda::current_stream_raw(), ); } @@ -392,23 +494,27 @@ pub fn batched_gemm_fp8( unsafe { let status = cublasLtMatmul( - handle, plan.desc, + handle, + plan.desc, &alpha as *const f32 as _, b_fp8_t.data_ptr() as *const c_void, // cuBLASLt "A" = weights plan.a_layout, - a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations + a_fp8.data_ptr() as *const c_void, // cuBLASLt "B" = activations plan.b_layout, &beta as *const f32 as _, - c.data_ptr() as *const c_void, // C (unused with beta=0) + c.data_ptr() as *const c_void, // C (unused with beta=0) plan.c_layout, - c.data_ptr() as *mut c_void, // D = output + c.data_ptr() as *mut c_void, // D = output plan.d_layout, &plan.algo, ws_ptr, plan.workspace_size, xserv_cuda::current_stream_raw(), ); - assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}"); + assert_eq!( + status, 0, + "batched cublasLtMatmul FP8 failed: status={status}" + ); } }); @@ -423,7 +529,9 @@ pub fn batched_gemm_fp8( c.data_ptr() as *mut c_void, a_scales.data_ptr() as *const c_void, b_scales.data_ptr() as *const c_void, - total_rows, n as i32, m as i32, + total_rows, + n as i32, + m as i32, xserv_cuda::current_stream_raw(), ); } @@ -442,7 +550,13 @@ pub fn batched_gemm_fp8( /// w_scales: [E, N, K/32] byte tensor — UE8M0 scale per 32-element block /// /// Returns: [E, N] BF16, where y[e,n] = sum_k x[e,k] * dequant(W[e,n,k]). -pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: usize, k: usize) -> Tensor { +pub fn batched_gemv_mxfp4( + x: &Tensor, + w_packed: &Tensor, + w_scales: &Tensor, + n: usize, + k: usize, +) -> Tensor { assert_eq!(x.dtype(), DType::BF16); assert!(x.is_contiguous()); let e = x.shape()[0]; @@ -455,7 +569,9 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u w_packed.data_ptr() as *const c_void, w_scales.data_ptr() as *const c_void, y.data_ptr() as *mut c_void, - e as i32, n as i32, k as i32, + e as i32, + n as i32, + k as i32, xserv_cuda::current_stream_raw(), ); } @@ -464,14 +580,22 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u /// Dequantize MXFP4 weights [E, N, K] → BF16 [E, K, N] for the prefill GEMM path /// (the BF16 batched GEMM expects weights as [E, K, N]). -pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n: usize, k: usize) -> Tensor { +pub fn dequant_mxfp4_to_bf16_t( + w_packed: &Tensor, + w_scales: &Tensor, + e: usize, + n: usize, + k: usize, +) -> Tensor { let out = Tensor::empty(&[e, k, n], DType::BF16, w_packed.device()); unsafe { launch_dequant_mxfp4_to_bf16_t( w_packed.data_ptr() as *const c_void, w_scales.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, - e as i32, n as i32, k as i32, + e as i32, + n as i32, + k as i32, xserv_cuda::current_stream_raw(), ); } diff --git a/crates/xserv-kernels/src/rmsnorm.rs b/crates/xserv-kernels/src/rmsnorm.rs index ad4e981..70db94b 100644 --- a/crates/xserv-kernels/src/rmsnorm.rs +++ b/crates/xserv-kernels/src/rmsnorm.rs @@ -2,13 +2,35 @@ use std::ffi::c_void; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_rmsnorm_f32(x: *const c_void, gamma: *const c_void, out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); - fn launch_rmsnorm_bf16(x: *const c_void, gamma: *const c_void, out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); - fn launch_add_rmsnorm_bf16(x: *const c_void, residual: *const c_void, gamma: *const c_void, - normed_out: *mut c_void, sum_out: *mut c_void, - rows: i32, hidden_size: i32, eps: f32, stream: *mut c_void); + fn launch_rmsnorm_f32( + x: *const c_void, + gamma: *const c_void, + out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); + fn launch_rmsnorm_bf16( + x: *const c_void, + gamma: *const c_void, + out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); + fn launch_add_rmsnorm_bf16( + x: *const c_void, + residual: *const c_void, + gamma: *const c_void, + normed_out: *mut c_void, + sum_out: *mut c_void, + rows: i32, + hidden_size: i32, + eps: f32, + stream: *mut c_void, + ); } pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor { @@ -20,19 +42,35 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor { assert_eq!(x.dtype(), gamma.dtype()); let rows = x.numel() / hidden_size; - assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param"); - assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param"); + assert!( + rows <= i32::MAX as usize, + "too many rows for i32 kernel param" + ); + assert!( + hidden_size <= i32::MAX as usize, + "hidden_size too large for i32 kernel param" + ); let out = Tensor::empty(x.shape(), x.dtype(), x.device()); unsafe { match x.dtype() { DType::F32 => launch_rmsnorm_f32( - x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + gamma.data_ptr() as _, + out.data_ptr() as *mut c_void, + rows as i32, + hidden_size as i32, + eps, + xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_rmsnorm_bf16( - x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + gamma.data_ptr() as _, + out.data_ptr() as *mut c_void, + rows as i32, + hidden_size as i32, + eps, + xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for rmsnorm"), } @@ -56,8 +94,14 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> ( assert_eq!(gamma.shape(), &[hidden_size]); let rows = x.numel() / hidden_size; - assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param"); - assert!(hidden_size <= i32::MAX as usize, "hidden_size too large for i32 kernel param"); + assert!( + rows <= i32::MAX as usize, + "too many rows for i32 kernel param" + ); + assert!( + hidden_size <= i32::MAX as usize, + "hidden_size too large for i32 kernel param" + ); let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device()); let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device()); diff --git a/crates/xserv-kernels/src/rope.rs b/crates/xserv-kernels/src/rope.rs index 829fd6c..552a59e 100644 --- a/crates/xserv-kernels/src/rope.rs +++ b/crates/xserv-kernels/src/rope.rs @@ -3,15 +3,34 @@ use xserv_cuda::GpuBuffer; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_rope_f32(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void, - positions: *const c_void, num_tokens: i32, num_heads: i32, - head_dim: i32, stream: *mut c_void); - fn launch_rope_bf16(x: *mut c_void, cos_cache: *const c_void, sin_cache: *const c_void, - positions: *const c_void, num_tokens: i32, num_heads: i32, - head_dim: i32, stream: *mut c_void); - fn launch_compute_rope_cache(cos_cache: *mut c_void, sin_cache: *mut c_void, - max_seq_len: i32, half_dim: i32, theta: f32, - stream: *mut c_void); + fn launch_rope_f32( + x: *mut c_void, + cos_cache: *const c_void, + sin_cache: *const c_void, + positions: *const c_void, + num_tokens: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_rope_bf16( + x: *mut c_void, + cos_cache: *const c_void, + sin_cache: *const c_void, + positions: *const c_void, + num_tokens: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_compute_rope_cache( + cos_cache: *mut c_void, + sin_cache: *mut c_void, + max_seq_len: i32, + half_dim: i32, + theta: f32, + stream: *mut c_void, + ); } pub struct RopeCache { @@ -30,12 +49,21 @@ impl RopeCache { unsafe { launch_compute_rope_cache( - cos.as_mut_ptr() as _, sin.as_mut_ptr() as _, - max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(), + cos.as_mut_ptr() as _, + sin.as_mut_ptr() as _, + max_seq_len as i32, + half_dim as i32, + theta, + xserv_cuda::current_stream_raw(), ); } - Self { cos, sin, max_seq_len, half_dim } + Self { + cos, + sin, + max_seq_len, + half_dim, + } } /// YaRN (Yet another RoPE extensioN) RoPE cache. Applies frequency-dependent @@ -68,8 +96,8 @@ impl RopeCache { let mut inv_freq = vec![0.0f64; half_dim]; for i in 0..half_dim { let pos_freq = theta.powf((2 * i) as f64 / dim); - let inv_freq_extrapolation = 1.0 / pos_freq; // original - let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled + let inv_freq_extrapolation = 1.0 / pos_freq; // original + let inv_freq_interpolation = 1.0 / (factor * pos_freq); // scaled // Linear ramp: 0 where we keep original, 1 where we interpolate let ramp = if (high - low).abs() < 0.001 { @@ -101,16 +129,19 @@ impl RopeCache { let nbytes = total * std::mem::size_of::(); let mut cos = GpuBuffer::alloc(nbytes).expect("alloc yarn cos_cache"); let mut sin = GpuBuffer::alloc(nbytes).expect("alloc yarn sin_cache"); - let cos_bytes = unsafe { - std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) - }; - let sin_bytes = unsafe { - std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) - }; + let cos_bytes = + unsafe { std::slice::from_raw_parts(cos_host.as_ptr() as *const u8, nbytes) }; + let sin_bytes = + unsafe { std::slice::from_raw_parts(sin_host.as_ptr() as *const u8, nbytes) }; cos.copy_from_host(cos_bytes).unwrap(); sin.copy_from_host(sin_bytes).unwrap(); - Self { cos, sin, max_seq_len, half_dim } + Self { + cos, + sin, + max_seq_len, + half_dim, + } } } @@ -133,7 +164,8 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) { num_tokens * std::mem::size_of::(), ) }; - let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions"); + let mut pos_gpu = + xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions"); pos_gpu.copy_from_host(pos_bytes).unwrap(); rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void); @@ -155,16 +187,22 @@ pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_ match x.dtype() { DType::F32 => launch_rope_f32( x.data_ptr() as *mut c_void, - cache.cos.as_ptr() as _, cache.sin.as_ptr() as _, + cache.cos.as_ptr() as _, + cache.sin.as_ptr() as _, pos_gpu, - num_tokens as i32, num_heads as i32, head_dim as i32, + num_tokens as i32, + num_heads as i32, + head_dim as i32, xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_rope_bf16( x.data_ptr() as *mut c_void, - cache.cos.as_ptr() as _, cache.sin.as_ptr() as _, + cache.cos.as_ptr() as _, + cache.sin.as_ptr() as _, pos_gpu, - num_tokens as i32, num_heads as i32, head_dim as i32, + num_tokens as i32, + num_heads as i32, + head_dim as i32, xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for rope"), diff --git a/crates/xserv-kernels/src/softmax.rs b/crates/xserv-kernels/src/softmax.rs index e83c9be..01a5696 100644 --- a/crates/xserv-kernels/src/softmax.rs +++ b/crates/xserv-kernels/src/softmax.rs @@ -2,8 +2,20 @@ use std::ffi::c_void; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_softmax_f32(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void); - fn launch_softmax_bf16(x: *const c_void, out: *mut c_void, rows: i32, cols: i32, stream: *mut c_void); + fn launch_softmax_f32( + x: *const c_void, + out: *mut c_void, + rows: i32, + cols: i32, + stream: *mut c_void, + ); + fn launch_softmax_bf16( + x: *const c_void, + out: *mut c_void, + rows: i32, + cols: i32, + stream: *mut c_void, + ); } /// Softmax along the last dimension. @@ -14,19 +26,31 @@ pub fn softmax(x: &Tensor) -> Tensor { let cols = *x.shape().last().unwrap(); let rows = x.numel() / cols; - assert!(rows <= i32::MAX as usize, "too many rows for i32 kernel param"); - assert!(cols <= i32::MAX as usize, "cols too large for i32 kernel param"); + assert!( + rows <= i32::MAX as usize, + "too many rows for i32 kernel param" + ); + assert!( + cols <= i32::MAX as usize, + "cols too large for i32 kernel param" + ); let out = Tensor::empty(x.shape(), x.dtype(), x.device()); unsafe { match x.dtype() { DType::F32 => launch_softmax_f32( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, cols as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + rows as i32, + cols as i32, + xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_softmax_bf16( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, cols as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + rows as i32, + cols as i32, + xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for softmax"), } diff --git a/crates/xserv-kernels/src/transpose.rs b/crates/xserv-kernels/src/transpose.rs index 0260801..98c9903 100644 --- a/crates/xserv-kernels/src/transpose.rs +++ b/crates/xserv-kernels/src/transpose.rs @@ -2,19 +2,79 @@ use std::ffi::c_void; use xserv_tensor::{DType, Device, Tensor}; unsafe extern "C" { - fn launch_reshape_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_merge_heads_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_transpose_hsd_to_shd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_transpose_shd_to_hsd_bf16(inp: *const c_void, out: *mut c_void, seq_len: i32, num_heads: i32, head_dim: i32, stream: *mut c_void); - fn launch_repeat_kv_bf16(inp: *const c_void, out: *mut c_void, kv_heads: i32, n_rep: i32, seq_len: i32, head_dim: i32, stream: *mut c_void); - fn launch_strided_copy_bf16(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32, - shape0: i32, shape1: i32, shape2: i32, shape3: i32, - in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32, - in_offset: i32, stream: *mut c_void); - fn launch_strided_copy_f32(inp: *const c_void, out: *mut c_void, numel: i32, ndim: i32, - shape0: i32, shape1: i32, shape2: i32, shape3: i32, - in_stride0: i32, in_stride1: i32, in_stride2: i32, in_stride3: i32, - in_offset: i32, stream: *mut c_void); + fn launch_reshape_heads_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_merge_heads_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_transpose_hsd_to_shd_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_transpose_shd_to_hsd_bf16( + inp: *const c_void, + out: *mut c_void, + seq_len: i32, + num_heads: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_repeat_kv_bf16( + inp: *const c_void, + out: *mut c_void, + kv_heads: i32, + n_rep: i32, + seq_len: i32, + head_dim: i32, + stream: *mut c_void, + ); + fn launch_strided_copy_bf16( + inp: *const c_void, + out: *mut c_void, + numel: i32, + ndim: i32, + shape0: i32, + shape1: i32, + shape2: i32, + shape3: i32, + in_stride0: i32, + in_stride1: i32, + in_stride2: i32, + in_stride3: i32, + in_offset: i32, + stream: *mut c_void, + ); + fn launch_strided_copy_f32( + inp: *const c_void, + out: *mut c_void, + numel: i32, + ndim: i32, + shape0: i32, + shape1: i32, + shape2: i32, + shape3: i32, + in_stride0: i32, + in_stride1: i32, + in_stride2: i32, + in_stride3: i32, + in_offset: i32, + stream: *mut c_void, + ); } /// [S, H*D] → [1, H, S, D] on GPU (BF16) @@ -24,8 +84,12 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device()); unsafe { launch_reshape_heads_bf16( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + seq_len as i32, + num_heads as i32, + head_dim as i32, + xserv_cuda::current_stream_raw(), ); } out @@ -39,36 +103,58 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device()); unsafe { launch_merge_heads_bf16( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + seq_len as i32, + num_heads as i32, + head_dim as i32, + xserv_cuda::current_stream_raw(), ); } out } /// [1, H, S, D] → [S, H, D] for RoPE on GPU (BF16) -pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor { +pub fn transpose_for_rope_gpu( + x: &Tensor, + seq_len: usize, + num_heads: usize, + head_dim: usize, +) -> Tensor { assert_eq!(x.dtype(), DType::BF16); assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device()); unsafe { launch_transpose_hsd_to_shd_bf16( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + seq_len as i32, + num_heads as i32, + head_dim as i32, + xserv_cuda::current_stream_raw(), ); } out } /// [S, H, D] → [1, H, S, D] after RoPE on GPU (BF16) -pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor { +pub fn transpose_from_rope_gpu( + x: &Tensor, + seq_len: usize, + num_heads: usize, + head_dim: usize, +) -> Tensor { assert_eq!(x.dtype(), DType::BF16); assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device()); unsafe { launch_transpose_shd_to_hsd_bf16( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + seq_len as i32, + num_heads as i32, + head_dim as i32, + xserv_cuda::current_stream_raw(), ); } out @@ -76,7 +162,9 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea /// [1, KV_H, S, D] → [1, KV_H*n_rep, S, D] on GPU (BF16) pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor { - if n_rep == 1 { return x.clone(); } + if n_rep == 1 { + return x.clone(); + } assert_eq!(x.dtype(), DType::BF16); assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_))); let kv_heads = x.shape()[1]; @@ -86,8 +174,13 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor { let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device()); unsafe { launch_repeat_kv_bf16( - x.data_ptr() as _, out.data_ptr() as *mut c_void, - kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(), + x.data_ptr() as _, + out.data_ptr() as *mut c_void, + kv_heads as i32, + n_rep as i32, + seq_len as i32, + head_dim as i32, + xserv_cuda::current_stream_raw(), ); } out @@ -122,20 +215,41 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor { unsafe { match x.dtype() { DType::BF16 => launch_strided_copy_bf16( - storage_ptr as _, out.data_ptr() as *mut c_void, - numel as i32, ndim as i32, - shape4[0], shape4[1], shape4[2], shape4[3], - strides4[0], strides4[1], strides4[2], strides4[3], - in_offset, xserv_cuda::current_stream_raw(), + storage_ptr as _, + out.data_ptr() as *mut c_void, + numel as i32, + ndim as i32, + shape4[0], + shape4[1], + shape4[2], + shape4[3], + strides4[0], + strides4[1], + strides4[2], + strides4[3], + in_offset, + xserv_cuda::current_stream_raw(), ), DType::F32 => launch_strided_copy_f32( - storage_ptr as _, out.data_ptr() as *mut c_void, - numel as i32, ndim as i32, - shape4[0], shape4[1], shape4[2], shape4[3], - strides4[0], strides4[1], strides4[2], strides4[3], - in_offset, xserv_cuda::current_stream_raw(), + storage_ptr as _, + out.data_ptr() as *mut c_void, + numel as i32, + ndim as i32, + shape4[0], + shape4[1], + shape4[2], + shape4[3], + strides4[0], + strides4[1], + strides4[2], + strides4[3], + in_offset, + xserv_cuda::current_stream_raw(), + ), + _ => panic!( + "strided_to_contiguous_gpu: unsupported dtype {:?}", + x.dtype() ), - _ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()), } } out diff --git a/crates/xserv-kernels/tests/attention_test.rs b/crates/xserv-kernels/tests/attention_test.rs index 744b58e..7c21504 100644 --- a/crates/xserv-kernels/tests/attention_test.rs +++ b/crates/xserv-kernels/tests/attention_test.rs @@ -1,11 +1,21 @@ use xserv_kernels::*; use xserv_tensor::{Device, Tensor}; -fn init() { xserv_cuda::device::set_device(0).unwrap(); } +fn init() { + xserv_cuda::device::set_device(0).unwrap(); +} -fn cpu_attention(q: &[f32], k: &[f32], v: &[f32], - batch: usize, heads: usize, q_len: usize, kv_len: usize, head_dim: usize, - causal: bool) -> Vec { +fn cpu_attention( + q: &[f32], + k: &[f32], + v: &[f32], + batch: usize, + heads: usize, + q_len: usize, + kv_len: usize, + head_dim: usize, + causal: bool, +) -> Vec { let mut out = vec![0.0f32; batch * heads * q_len * head_dim]; let scale = 1.0 / (head_dim as f32).sqrt(); @@ -70,8 +80,13 @@ fn check_close(a: &[f32], b: &[f32], atol: f32, name: &str) { let mut max_err = 0.0f32; for (i, (x, y)) in a.iter().zip(b).enumerate() { let err = (x - y).abs(); - if err > max_err { max_err = err; } - assert!(err <= atol, "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}"); + if err > max_err { + max_err = err; + } + assert!( + err <= atol, + "{name}: mismatch at [{i}]: got {x}, expected {y}, err {err}" + ); } println!("{name}: max_err = {max_err:.6e}"); } @@ -105,7 +120,9 @@ fn test_batched_matmul() { for i in 0..m { for j in 0..n { let mut s = 0.0f32; - for kk in 0..k { s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; } + for kk in 0..k { + s += a_cpu[i * k + kk] * b_cpu[kk * n + j]; + } expected[i * n + j] = s; } } @@ -116,7 +133,10 @@ fn test_batched_matmul() { #[test] fn test_attention_no_causal() { init(); - let b = 1; let h = 2; let s = 8; let d = 16; + let b = 1; + let h = 2; + let s = 8; + let d = 16; let q_data = make_data(b * h * s * d); let k_data = make_data(b * h * s * d); let v_data = make_data(b * h * s * d); @@ -126,13 +146,21 @@ fn test_attention_no_causal() { let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0)); let out = attention(&q, &k, &v, false).to_device(Device::Cpu); - check_close(out.as_slice::(), &expected, 1e-4, "attention_no_causal"); + check_close( + out.as_slice::(), + &expected, + 1e-4, + "attention_no_causal", + ); } #[test] fn test_attention_causal() { init(); - let b = 1; let h = 2; let s = 16; let d = 32; + let b = 1; + let h = 2; + let s = 16; + let d = 32; let q_data = make_data(b * h * s * d); let k_data = make_data(b * h * s * d); let v_data = make_data(b * h * s * d); @@ -148,7 +176,10 @@ fn test_attention_causal() { #[test] fn test_attention_causal_larger() { init(); - let b = 2; let h = 4; let s = 64; let d = 64; + let b = 2; + let h = 4; + let s = 64; + let d = 64; let q_data = make_data(b * h * s * d); let k_data = make_data(b * h * s * d); let v_data = make_data(b * h * s * d); @@ -158,18 +189,28 @@ fn test_attention_causal_larger() { let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); let v = Tensor::from_slice(&v_data, &[b, h, s, d]).to_device(Device::Cuda(0)); let out = attention(&q, &k, &v, true).to_device(Device::Cpu); - check_close(out.as_slice::(), &expected, 1e-2, "attention_causal_larger"); + check_close( + out.as_slice::(), + &expected, + 1e-2, + "attention_causal_larger", + ); } #[test] fn test_attention_causal_first_row_sees_only_first_token() { init(); - let b = 1; let h = 1; let s = 4; let d = 8; + let b = 1; + let h = 1; + let s = 4; + let d = 8; let q_data = make_data(b * h * s * d); let k_data = make_data(b * h * s * d); - let v_data: Vec = (0..s * d).map(|i| { - if i < d { 1.0 } else { 0.0 } // only first V row is nonzero - }).collect(); + let v_data: Vec = (0..s * d) + .map(|i| { + if i < d { 1.0 } else { 0.0 } // only first V row is nonzero + }) + .collect(); let q = Tensor::from_slice(&q_data, &[b, h, s, d]).to_device(Device::Cuda(0)); let k = Tensor::from_slice(&k_data, &[b, h, s, d]).to_device(Device::Cuda(0)); @@ -181,7 +222,11 @@ fn test_attention_causal_first_row_sees_only_first_token() { // output[0] should be exactly V[0] = [1, 1, 1, ...1] let result = out.as_slice::(); for i in 0..d { - assert!((result[i] - 1.0).abs() < 1e-5, - "first row should equal V[0], got {} at dim {}", result[i], i); + assert!( + (result[i] - 1.0).abs() < 1e-5, + "first row should equal V[0], got {} at dim {}", + result[i], + i + ); } } diff --git a/crates/xserv-kernels/tests/gemm_test.rs b/crates/xserv-kernels/tests/gemm_test.rs index f49ccd7..54a157b 100644 --- a/crates/xserv-kernels/tests/gemm_test.rs +++ b/crates/xserv-kernels/tests/gemm_test.rs @@ -1,5 +1,5 @@ use half::bf16; -use xserv_kernels::{matmul, GemmBackend}; +use xserv_kernels::{GemmBackend, matmul}; use xserv_tensor::{Device, Tensor}; fn cpu_matmul_f32(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec { @@ -75,70 +75,110 @@ fn run_gemm_test_bf16(backend: GemmBackend, m: usize, n: usize, k: usize) { // --- F32 tests --- #[test] -fn test_gemm_naive_f32_small() { run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); } +fn test_gemm_naive_f32_small() { + run_gemm_test_f32(GemmBackend::Naive, 4, 4, 4); +} #[test] -fn test_gemm_naive_f32_medium() { run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); } +fn test_gemm_naive_f32_medium() { + run_gemm_test_f32(GemmBackend::Naive, 64, 64, 64); +} #[test] -fn test_gemm_naive_f32_rect() { run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); } +fn test_gemm_naive_f32_rect() { + run_gemm_test_f32(GemmBackend::Naive, 32, 64, 48); +} #[test] -fn test_gemm_tiled_f32_small() { run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); } +fn test_gemm_tiled_f32_small() { + run_gemm_test_f32(GemmBackend::Tiled, 4, 4, 4); +} #[test] -fn test_gemm_tiled_f32_medium() { run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); } +fn test_gemm_tiled_f32_medium() { + run_gemm_test_f32(GemmBackend::Tiled, 128, 128, 128); +} #[test] -fn test_gemm_tiled_f32_rect() { run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); } +fn test_gemm_tiled_f32_rect() { + run_gemm_test_f32(GemmBackend::Tiled, 65, 33, 97); +} #[test] -fn test_gemm_cublas_f32_small() { run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); } +fn test_gemm_cublas_f32_small() { + run_gemm_test_f32(GemmBackend::CuBlas, 4, 4, 4); +} #[test] -fn test_gemm_cublas_f32_medium() { run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); } +fn test_gemm_cublas_f32_medium() { + run_gemm_test_f32(GemmBackend::CuBlas, 256, 256, 256); +} #[test] -fn test_gemm_cublas_f32_rect() { run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); } +fn test_gemm_cublas_f32_rect() { + run_gemm_test_f32(GemmBackend::CuBlas, 65, 33, 97); +} // --- BF16 tests --- #[test] -fn test_gemm_naive_bf16_small() { run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); } +fn test_gemm_naive_bf16_small() { + run_gemm_test_bf16(GemmBackend::Naive, 4, 4, 4); +} #[test] -fn test_gemm_naive_bf16_medium() { run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); } +fn test_gemm_naive_bf16_medium() { + run_gemm_test_bf16(GemmBackend::Naive, 64, 64, 64); +} #[test] -fn test_gemm_tiled_bf16_small() { run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); } +fn test_gemm_tiled_bf16_small() { + run_gemm_test_bf16(GemmBackend::Tiled, 4, 4, 4); +} #[test] -fn test_gemm_tiled_bf16_medium() { run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); } +fn test_gemm_tiled_bf16_medium() { + run_gemm_test_bf16(GemmBackend::Tiled, 128, 128, 128); +} #[test] -fn test_gemm_cublas_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); } +fn test_gemm_cublas_bf16_small() { + run_gemm_test_bf16(GemmBackend::CuBlas, 4, 4, 4); +} #[test] -fn test_gemm_cublas_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); } +fn test_gemm_cublas_bf16_medium() { + run_gemm_test_bf16(GemmBackend::CuBlas, 256, 256, 256); +} // --- Custom GEMV tests (M=1, BF16 fast path) --- #[test] -fn test_gemv_bf16_small() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64); } +fn test_gemv_bf16_small() { + run_gemm_test_bf16(GemmBackend::CuBlas, 1, 64, 64); +} #[test] -fn test_gemv_bf16_medium() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256); } +fn test_gemv_bf16_medium() { + run_gemm_test_bf16(GemmBackend::CuBlas, 1, 256, 256); +} #[test] -fn test_gemv_bf16_4096() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096); } +fn test_gemv_bf16_4096() { + run_gemm_test_bf16(GemmBackend::CuBlas, 1, 4096, 4096); +} #[test] -fn test_gemv_bf16_rect() { run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096); } +fn test_gemv_bf16_rect() { + run_gemm_test_bf16(GemmBackend::CuBlas, 1, 512, 4096); +} // --- Larger benchmark-style tests --- #[test] -fn test_gemm_cublas_f32_1024() { run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); } +fn test_gemm_cublas_f32_1024() { + run_gemm_test_f32(GemmBackend::CuBlas, 1024, 1024, 1024); +} #[test] fn test_gemm_consistency_all_backends() { diff --git a/crates/xserv-kernels/tests/ops_test.rs b/crates/xserv-kernels/tests/ops_test.rs index aced053..c8f04bf 100644 --- a/crates/xserv-kernels/tests/ops_test.rs +++ b/crates/xserv-kernels/tests/ops_test.rs @@ -2,7 +2,9 @@ use half::bf16; use xserv_kernels::*; use xserv_tensor::{Device, Tensor}; -fn init() { xserv_cuda::device::set_device(0).unwrap(); } +fn init() { + xserv_cuda::device::set_device(0).unwrap(); +} // --- CPU reference implementations --- @@ -37,10 +39,12 @@ fn cpu_layernorm(x: &[f32], gamma: &[f32], beta: &[f32], eps: f32, hidden: usize fn cpu_gelu(x: &[f32]) -> Vec { let sqrt_2_over_pi = 0.7978845608f32; - x.iter().map(|&v| { - let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v); - 0.5 * v * (1.0 + inner.tanh()) - }).collect() + x.iter() + .map(|&v| { + let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v); + 0.5 * v * (1.0 + inner.tanh()) + }) + .collect() } fn cpu_silu(x: &[f32]) -> Vec { @@ -88,8 +92,13 @@ fn check_close(result: &[f32], expected: &[f32], atol: f32, name: &str) { let mut max_err = 0.0f32; for (i, (r, e)) in result.iter().zip(expected).enumerate() { let err = (r - e).abs(); - if err > max_err { max_err = err; } - assert!(err <= atol, "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}"); + if err > max_err { + max_err = err; + } + assert!( + err <= atol, + "{name}: mismatch at [{i}]: got {r}, expected {e}, err {err}" + ); } println!("{name}: max_err = {max_err:.6e}"); } @@ -208,13 +217,18 @@ fn test_softmax_sum_to_one() { init(); let rows = 4; let cols = 2048; - let data: Vec = (0..rows * cols).map(|i| ((i % 31) as f32 - 15.0) * 0.5).collect(); + let data: Vec = (0..rows * cols) + .map(|i| ((i % 31) as f32 - 15.0) * 0.5) + .collect(); let x = Tensor::from_slice(&data, &[rows, cols]).to_device(Device::Cuda(0)); let out = softmax(&x).to_device(Device::Cpu); let result = out.as_slice::(); for r in 0..rows { let row_sum: f32 = result[r * cols..(r + 1) * cols].iter().sum(); - assert!((row_sum - 1.0).abs() < 1e-5, "softmax row {r} sum = {row_sum}"); + assert!( + (row_sum - 1.0).abs() < 1e-5, + "softmax row {r} sum = {row_sum}" + ); } } @@ -247,8 +261,10 @@ fn test_embedding_f32() { for i in 0..hidden { let expected = table_data[tid as usize * hidden + i]; let got = result[seq_idx * hidden + i]; - assert!((got - expected).abs() < 1e-6, - "embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}"); + assert!( + (got - expected).abs() < 1e-6, + "embedding mismatch at [{seq_idx},{i}]: got {got}, expected {expected}" + ); } } } @@ -270,8 +286,8 @@ fn test_rope_f32() { let mut expected = x_data.clone(); cpu_rope(&mut expected, &positions, num_heads, head_dim, theta); - let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]) - .to_device(Device::Cuda(0)); + let x = + Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0)); let cache = RopeCache::new(64, head_dim, theta); rope_inplace(&x, &cache, &positions); @@ -292,8 +308,8 @@ fn test_rope_position_0_identity() { .map(|i| (i as f32 + 1.0) * 0.1) .collect(); - let x = Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]) - .to_device(Device::Cuda(0)); + let x = + Tensor::from_slice(&x_data, &[num_tokens, num_heads, head_dim]).to_device(Device::Cuda(0)); let cache = RopeCache::new(64, head_dim, 10000.0); rope_inplace(&x, &cache, &positions); diff --git a/crates/xserv-model/src/bin/bench-gpt-oss.rs b/crates/xserv-model/src/bin/bench-gpt-oss.rs index de983a0..5043a98 100644 --- a/crates/xserv-model/src/bin/bench-gpt-oss.rs +++ b/crates/xserv-model/src/bin/bench-gpt-oss.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::Instant; use xserv_distributed::{TpContext, UniqueId, get_unique_id}; -use xserv_model::{loader, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, BLOCK_SIZE}; +use xserv_model::{BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, loader}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -23,8 +23,12 @@ fn main() { eprintln!( "gpt-oss-20b: layers={}, hidden={}, heads={}/{} kv, experts={}, top_k={}, vocab={}", - config.num_layers(), config.hidden(), config.num_heads(), - config.num_kv_heads(), config.num_experts(), config.experts_per_token(), + config.num_layers(), + config.hidden(), + config.num_heads(), + config.num_kv_heads(), + config.num_experts(), + config.experts_per_token(), config.vocab_size ); eprintln!("TP world={world}, max_tokens={max_tokens}"); @@ -59,17 +63,29 @@ fn main() { let tp0 = Arc::new(TpContext::init(0, world, uid, 0)); eprintln!("[rank 0] Loading weights..."); let weights = loader::load_model_dir(&model_dir, Device::Cpu); - eprintln!("[rank 0] Loaded {} tensors, building model...", weights.len()); + eprintln!( + "[rank 0] Loaded {} tensors, building model...", + weights.len() + ); let model = GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp0)); let total_blocks = max_blocks_per_seq + 64; let mut cache = PagedKVCache::new_tp( - &config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0, + &config, + local_kv, + total_blocks, + 0, + 4, + max_blocks_per_seq, + DType::BF16, + 0, ); eprintln!("[rank 0] Ready."); // Prompt let prompt_arg = get_arg::(&args, "--prompt"); - let prompt = prompt_arg.as_deref().unwrap_or("What is the meaning of life?"); + let prompt = prompt_arg + .as_deref() + .unwrap_or("What is the meaning of life?"); let token_ids = tokenizer.encode(prompt); eprintln!("Prompt ({} tokens): {prompt}", token_ids.len()); @@ -83,11 +99,21 @@ fn main() { // (oracle) next token. Removes free-running compounding so it isolates // whether per-position logits agree with the llama.cpp trajectory. if let Some(forced) = get_arg::(&args, "--forced") { - let forced_ids: Vec = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect(); + let forced_ids: Vec = forced + .split(',') + .filter_map(|s| s.trim().parse().ok()) + .collect(); let mut seq = token_ids.clone(); seq.extend_from_slice(&forced_ids); // Workers must run the same prefill in lockstep (TP AllReduces match up). - broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: seq.clone(), slot }); + broadcast_cmd( + &worker_txs, + &worker_handles, + WorkerCmd::Prefill { + tokens: seq.clone(), + slot, + }, + ); let logits = model.forward_prefill_paged(&seq, slot, &mut cache); wait_workers(&worker_handles); let logits_cpu = logits.to_device(Device::Cpu); @@ -99,19 +125,31 @@ fn main() { // position i predicts seq[i+1]; we check the forced region for i in (plen - 1)..(seq.len() - 1) { let row = &data[i * vocab..(i + 1) * vocab]; - let argmax = row.iter().enumerate() + let argmax = row + .iter() + .enumerate() .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) - .map(|(j, _)| j as u32).unwrap(); + .map(|(j, _)| j as u32) + .unwrap(); let expected = seq[i + 1]; let ok = argmax == expected; - if ok { matches += 1; } + if ok { + matches += 1; + } total += 1; - eprintln!("pos {i}: xserv_argmax={argmax} oracle={expected} {}", if ok {"OK"} else {"DIFF"}); + eprintln!( + "pos {i}: xserv_argmax={argmax} oracle={expected} {}", + if ok { "OK" } else { "DIFF" } + ); } - eprintln!("\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%", - 100.0 * matches as f64 / total as f64); + eprintln!( + "\nTeacher-forced top-1 agreement: {matches}/{total} = {:.1}%", + 100.0 * matches as f64 / total as f64 + ); broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown); - for (h, _) in worker_handles { h.join().unwrap(); } + for (h, _) in worker_handles { + h.join().unwrap(); + } return; } @@ -120,8 +158,18 @@ fn main() { // per-position top-1 agreement bucketed by position. Localizes long-context // decode degradation (which prefill teacher-forcing cannot see). if let Some(forced) = get_arg::(&args, "--forced-decode") { - let forced_ids: Vec = forced.split(',').filter_map(|s| s.trim().parse().ok()).collect(); - broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { tokens: token_ids.clone(), slot }); + let forced_ids: Vec = forced + .split(',') + .filter_map(|s| s.trim().parse().ok()) + .collect(); + broadcast_cmd( + &worker_txs, + &worker_handles, + WorkerCmd::Prefill { + tokens: token_ids.clone(), + slot, + }, + ); let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache); wait_workers(&worker_handles); let mut pred = sample_greedy_last(&logits); // prediction for forced[0] @@ -133,34 +181,55 @@ fn main() { matches += ok as usize; total += 1; let b = i / bucket; - if buckets.len() <= b { buckets.push((0, 0)); } + if buckets.len() <= b { + buckets.push((0, 0)); + } buckets[b].0 += ok as usize; buckets[b].1 += 1; // Teacher-force: feed the oracle token through the decode path. let pos = cache.seq_len(slot); - broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode { - tokens: vec![f], positions: vec![pos], slots: vec![slot], - }); + broadcast_cmd( + &worker_txs, + &worker_handles, + WorkerCmd::Decode { + tokens: vec![f], + positions: vec![pos], + slots: vec![slot], + }, + ); let logits = model.forward_decode_paged(&[f], &[pos], &[slot], &mut cache); wait_workers(&worker_handles); pred = sample_greedy_last(&logits); } - eprintln!("Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%", - 100.0 * matches as f64 / total as f64); + eprintln!( + "Teacher-forced DECODE agreement: {matches}/{total} = {:.1}%", + 100.0 * matches as f64 / total as f64 + ); for (b, (m, t)) in buckets.iter().enumerate() { - eprintln!(" pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%", - b * bucket, b * bucket + t, 100.0 * (*m as f64) / (*t as f64)); + eprintln!( + " pos[{:>4}..{:<4}]: {m:>3}/{t:<3} = {:.0}%", + b * bucket, + b * bucket + t, + 100.0 * (*m as f64) / (*t as f64) + ); } broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Shutdown); - for (h, _) in worker_handles { h.join().unwrap(); } + for (h, _) in worker_handles { + h.join().unwrap(); + } return; } // Prefill let t0 = Instant::now(); - broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Prefill { - tokens: token_ids.clone(), slot, - }); + broadcast_cmd( + &worker_txs, + &worker_handles, + WorkerCmd::Prefill { + tokens: token_ids.clone(), + slot, + }, + ); let logits = model.forward_prefill_paged(&token_ids, slot, &mut cache); wait_workers(&worker_handles); let ttft = t0.elapsed(); @@ -178,12 +247,20 @@ fn main() { let text = tokenizer.decode(&[next]); print!("{text}"); - if tokenizer.eos_token_id() == Some(next) { break; } + if tokenizer.eos_token_id() == Some(next) { + break; + } let pos = cache.seq_len(slot); - broadcast_cmd(&worker_txs, &worker_handles, WorkerCmd::Decode { - tokens: vec![next], positions: vec![pos], slots: vec![slot], - }); + broadcast_cmd( + &worker_txs, + &worker_handles, + WorkerCmd::Decode { + tokens: vec![next], + positions: vec![pos], + slots: vec![slot], + }, + ); let logits = decoder.decode(&model, &[next], &[pos], &[slot], &mut cache); wait_workers(&worker_handles); @@ -196,13 +273,20 @@ fn main() { let gen_tokens = output_tokens.len(); let full_text = tokenizer.decode(&output_tokens); eprintln!("\nGenerated text: {full_text}"); - eprintln!("Token IDs: {:?}", &output_tokens[..output_tokens.len().min(20)]); + eprintln!( + "Token IDs: {:?}", + &output_tokens[..output_tokens.len().min(20)] + ); let tpot = if gen_tokens > 1 { decode_elapsed.as_secs_f64() * 1000.0 / (gen_tokens - 1) as f64 - } else { 0.0 }; + } else { + 0.0 + }; let tok_s = if gen_tokens > 1 { (gen_tokens - 1) as f64 / decode_elapsed.as_secs_f64() - } else { 0.0 }; + } else { + 0.0 + }; eprintln!("\n--- Performance ---"); eprintln!("Generated: {} tokens", gen_tokens); @@ -222,8 +306,15 @@ fn main() { #[derive(Clone)] enum WorkerCmd { Register(usize), - Prefill { tokens: Vec, slot: usize }, - Decode { tokens: Vec, positions: Vec, slots: Vec }, + Prefill { + tokens: Vec, + slot: usize, + }, + Decode { + tokens: Vec, + positions: Vec, + slots: Vec, + }, Shutdown, } @@ -241,12 +332,20 @@ fn worker_loop( let tp = Arc::new(TpContext::init(rank, world, uid, rank as u32)); eprintln!("[rank {rank}] Loading weights..."); let weights = loader::load_model_dir(&model_dir, Device::Cpu); - let model = GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)); + let model = + GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp)); let local_kv = config.num_kv_heads() / world; let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = max_blocks_per_seq + 64; let mut cache = PagedKVCache::new_tp( - &config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, rank as u32, + &config, + local_kv, + total_blocks, + 0, + 4, + max_blocks_per_seq, + DType::BF16, + rank as u32, ); eprintln!("[rank {rank}] Ready."); ack_tx.send(()).unwrap(); @@ -260,7 +359,11 @@ fn worker_loop( WorkerCmd::Prefill { tokens, slot } => { let _ = model.forward_prefill_paged(&tokens, slot, &mut cache); } - WorkerCmd::Decode { tokens, positions, slots } => { + WorkerCmd::Decode { + tokens, + positions, + slots, + } => { let _ = decoder.decode(&model, &tokens, &positions, &slots, &mut cache); } WorkerCmd::Shutdown => break, @@ -299,14 +402,15 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 { let data = logits_cpu.as_slice::(); let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; - - last.iter().enumerate() + last.iter() + .enumerate() .max_by(|a, b| { let af = a.1.to_f32(); let bf = b.1.to_f32(); af.partial_cmp(&bf).unwrap_or(std::cmp::Ordering::Equal) }) - .map(|(i, _)| i as u32).unwrap() + .map(|(i, _)| i as u32) + .unwrap() } fn get_arg(args: &[String], flag: &str) -> Option { diff --git a/crates/xserv-model/src/bin/bench-gpt2.rs b/crates/xserv-model/src/bin/bench-gpt2.rs index 9544aae..f81700c 100644 --- a/crates/xserv-model/src/bin/bench-gpt2.rs +++ b/crates/xserv-model/src/bin/bench-gpt2.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use std::time::Instant; -use xserv_model::gpt2::{sample_greedy, KVCache}; -use xserv_model::{loader, GPT2, ModelConfig}; +use xserv_model::gpt2::{KVCache, sample_greedy}; +use xserv_model::{GPT2, ModelConfig, loader}; use xserv_tensor::Device; use xserv_tokenizer::Tokenizer; @@ -104,9 +104,15 @@ fn main() { let tbt_us = if !token_times_us.is_empty() { token_times_us.iter().sum::() / token_times_us.len() as u128 - } else { 0 }; + } else { + 0 + }; let total_gen_us: u128 = ttft_us + token_times_us.iter().sum::(); - let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 }; + let tpot_us = if num_generated > 0 { + total_gen_us / num_generated as u128 + } else { + 0 + }; let gen_text_escaped = generated_text .replace('\\', "\\\\") @@ -124,11 +130,16 @@ fn main() { print!("\"ttft_us\": {ttft_us}, "); print!("\"tbt_us\": {tbt_us}, "); print!("\"tpot_us\": {tpot_us}}}"); - if i < prompts.len() - 1 { println!(","); } else { println!(); } + if i < prompts.len() - 1 { + println!(","); + } else { + println!(); + } eprintln!( "[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}", - i + 1, prompts.len(), + i + 1, + prompts.len(), ttft_us as f64 / 1000.0, tbt_us as f64 / 1000.0, &generated_text.replace('\n', " ")[..generated_text.len().min(60)] @@ -138,12 +149,18 @@ fn main() { } fn generate_with_cache( - model: &GPT2, config: &ModelConfig, tokenizer: &Tokenizer, - input_ids: &[u32], gen_tokens: usize, + model: &GPT2, + config: &ModelConfig, + tokenizer: &Tokenizer, + input_ids: &[u32], + gen_tokens: usize, ) -> (Vec, u128, Vec) { let mut cache = KVCache::new( - config.num_layers(), config.num_heads(), config.head_dim(), - xserv_tensor::DType::F32, Device::Cuda(0), + config.num_layers(), + config.num_heads(), + config.head_dim(), + xserv_tensor::DType::F32, + Device::Cuda(0), ); // Prefill @@ -163,15 +180,19 @@ fn generate_with_cache( let next = sample_greedy(&logits); token_times.push(t_start.elapsed().as_micros()); generated.push(next); - if tokenizer.eos_token_id() == Some(next) { break; } + if tokenizer.eos_token_id() == Some(next) { + break; + } } (generated, ttft_us, token_times) } fn generate_no_cache( - model: &GPT2, tokenizer: &Tokenizer, - input_ids: &[u32], gen_tokens: usize, + model: &GPT2, + tokenizer: &Tokenizer, + input_ids: &[u32], + gen_tokens: usize, ) -> (Vec, u128, Vec) { let mut all_ids = input_ids.to_vec(); @@ -191,7 +212,9 @@ fn generate_no_cache( token_times.push(t_start.elapsed().as_micros()); all_ids.push(next); generated.push(next); - if tokenizer.eos_token_id() == Some(next) { break; } + if tokenizer.eos_token_id() == Some(next) { + break; + } } (generated, ttft_us, token_times) diff --git a/crates/xserv-model/src/bin/bench-qwen3.rs b/crates/xserv-model/src/bin/bench-qwen3.rs index 5608170..96dad80 100644 --- a/crates/xserv-model/src/bin/bench-qwen3.rs +++ b/crates/xserv-model/src/bin/bench-qwen3.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use std::time::Instant; use xserv_model::qwen3::sample_greedy; -use xserv_model::{loader, DecodeGraphState, GpuKVCache, ModelConfig, Qwen3}; +use xserv_model::{DecodeGraphState, GpuKVCache, ModelConfig, Qwen3, loader}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -139,18 +139,35 @@ fn main() { } else { // Replay captured graphs let pos = cache.seq_len() as u32; - graph.execute(last, pos, &mut cache, &layer_ptrs, embed, config.vocab_size as i32, config.hidden() as i32); + graph.execute( + last, + pos, + &mut cache, + &layer_ptrs, + embed, + config.vocab_size as i32, + config.hidden() as i32, + ); cache.advance_seq_len(1); // Read logits from graph buffer let vocab_size = config.vocab_size; let mut logits_bytes = vec![0u8; vocab_size * 2]; - graph.logits_buffer().copy_to_host(&mut logits_bytes).unwrap(); + graph + .logits_buffer() + .copy_to_host(&mut logits_bytes) + .unwrap(); let logits_data: &[half::bf16] = unsafe { - std::slice::from_raw_parts(logits_bytes.as_ptr() as *const half::bf16, vocab_size) + std::slice::from_raw_parts( + logits_bytes.as_ptr() as *const half::bf16, + vocab_size, + ) }; - logits_data.iter().enumerate() + logits_data + .iter() + .enumerate() .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) - .map(|(idx, _)| idx as u32).unwrap() + .map(|(idx, _)| idx as u32) + .unwrap() } } else { let logits = model.forward_gpu_cache(&[last], &mut cache); @@ -159,16 +176,24 @@ fn main() { token_times.push(t_start.elapsed().as_micros()); generated.push(next); - if tokenizer.eos_token_id() == Some(next) { break; } + if tokenizer.eos_token_id() == Some(next) { + break; + } } let num_generated = generated.len(); let generated_text = tokenizer.decode(&generated); let tbt_us = if !token_times.is_empty() { token_times.iter().sum::() / token_times.len() as u128 - } else { 0 }; + } else { + 0 + }; let total_gen_us: u128 = ttft_us + token_times.iter().sum::(); - let tpot_us = if num_generated > 0 { total_gen_us / num_generated as u128 } else { 0 }; + let tpot_us = if num_generated > 0 { + total_gen_us / num_generated as u128 + } else { + 0 + }; let gen_text_escaped = generated_text .replace('\\', "\\\\") @@ -186,13 +211,18 @@ fn main() { print!("\"ttft_us\": {ttft_us}, "); print!("\"tbt_us\": {tbt_us}, "); print!("\"tpot_us\": {tpot_us}}}"); - if i < prompts.len() - 1 { println!(","); } else { println!(); } + if i < prompts.len() - 1 { + println!(","); + } else { + println!(); + } let display_text = generated_text.replace('\n', " "); let truncated: String = display_text.chars().take(60).collect(); eprintln!( "[{}/{}] input={input_len}tok gen={num_generated}tok ttft={:.1}ms tbt={:.1}ms | {}", - i + 1, prompts.len(), + i + 1, + prompts.len(), ttft_us as f64 / 1000.0, tbt_us as f64 / 1000.0, truncated diff --git a/crates/xserv-model/src/bin/bench-tp.rs b/crates/xserv-model/src/bin/bench-tp.rs index 0a4e8fd..03a444e 100644 --- a/crates/xserv-model/src/bin/bench-tp.rs +++ b/crates/xserv-model/src/bin/bench-tp.rs @@ -18,7 +18,7 @@ use std::thread; use std::time::Instant; use xserv_model::qwen3::sample_greedy; -use xserv_model::{loader, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE}; +use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, loader}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -35,8 +35,13 @@ fn main() { std::process::exit(1); } let model_dir = PathBuf::from(&args[1]); - let world: usize = arg(&args, "--tp").and_then(|s| s.parse().ok()).unwrap_or(1).max(1); - let gen_tokens: usize = arg(&args, "--gen-tokens").and_then(|s| s.parse().ok()).unwrap_or(64); + let world: usize = arg(&args, "--tp") + .and_then(|s| s.parse().ok()) + .unwrap_or(1) + .max(1); + let gen_tokens: usize = arg(&args, "--gen-tokens") + .and_then(|s| s.parse().ok()) + .unwrap_or(64); let devices: Vec = match arg(&args, "--devices") { Some(s) => s.split(',').filter_map(|d| d.trim().parse().ok()).collect(), None => (0..world as u32).collect(), @@ -67,7 +72,11 @@ fn main() { // Tensors are not Send (their Storage holds a raw GPU pointer), so each rank // thread loads its own CPU copy of the weights and shards in-thread. Loading // is not part of the timed region. - let id = if world > 1 { Some(xserv_distributed::get_unique_id()) } else { None }; + let id = if world > 1 { + Some(xserv_distributed::get_unique_id()) + } else { + None + }; let handles: Vec<_> = (0..world) .map(|rank| { @@ -76,7 +85,9 @@ fn main() { let prompt_ids = prompt_ids.clone(); let device = devices[rank]; thread::spawn(move || { - run_rank(rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos) + run_rank( + rank, world, device, id, config, model_dir, prompt_ids, gen_tokens, eos, + ) }) }) .collect(); @@ -91,7 +102,10 @@ fn main() { let results = rank0.expect("rank 0 produced no results"); println!("\n=== TP={world} (devices {devices:?}) — Qwen3 E2E benchmark ==="); - println!("{:<45} {:>10} {:>12} {:>8}", "prompt", "TTFT(ms)", "decode tok/s", "gen"); + println!( + "{:<45} {:>10} {:>12} {:>8}", + "prompt", "TTFT(ms)", "decode tok/s", "gen" + ); let mut tps_sum = 0.0; for (i, r) in results.iter().enumerate() { let text = tokenizer.decode(&r.gen_ids).replace('\n', " "); @@ -99,16 +113,29 @@ fn main() { let p: String = prompts[i].chars().take(43).collect(); println!( "{:<45} {:>10.1} {:>12.1} {:>8} | {}", - p, r.ttft_ms, r.decode_tok_s, r.gen_ids.len(), short + p, + r.ttft_ms, + r.decode_tok_s, + r.gen_ids.len(), + short ); tps_sum += r.decode_tok_s; } - println!("--- mean decode throughput: {:.1} tok/s ---", tps_sum / results.len() as f64); + println!( + "--- mean decode throughput: {:.1} tok/s ---", + tps_sum / results.len() as f64 + ); // Machine-readable line for cross-TP correctness diffing (rank 0 token ids). let all_ids: Vec = results .iter() - .map(|r| r.gen_ids.iter().map(|x| x.to_string()).collect::>().join(",")) + .map(|r| { + r.gen_ids + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(",") + }) .collect(); println!("CORRECTNESS_IDS tp={world} {}", all_ids.join(" | ")); } @@ -126,7 +153,12 @@ fn run_rank( ) -> Option> { // Bind this thread to its GPU and set up the TP communicator. let tp = if world > 1 { - Some(Arc::new(xserv_distributed::TpContext::init(rank, world, id.unwrap(), device))) + Some(Arc::new(xserv_distributed::TpContext::init( + rank, + world, + id.unwrap(), + device, + ))) } else { xserv_cuda::device::set_device(device).unwrap(); None @@ -142,7 +174,14 @@ fn run_rank( let max_blocks_per_seq = max_seq.div_ceil(BLOCK_SIZE); let total_blocks = max_blocks_per_seq + 8; let mut cache = PagedKVCache::new_tp( - &config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, device, + &config, + local_kv, + total_blocks, + 0, + 1, + max_blocks_per_seq, + DType::BF16, + device, ); // Warmup (init kernels / allocator / NCCL channels) — not timed. @@ -177,12 +216,20 @@ fn run_rank( steps += 1; } let decode_s = t1.elapsed().as_secs_f64(); - let decode_tok_s = if steps > 0 && decode_s > 0.0 { steps as f64 / decode_s } else { 0.0 }; + let decode_tok_s = if steps > 0 && decode_s > 0.0 { + steps as f64 / decode_s + } else { + 0.0 + }; cache.free_sequence(0); if rank == 0 { - out.push(PromptResult { gen_ids: generated, ttft_ms, decode_tok_s }); + out.push(PromptResult { + gen_ids: generated, + ttft_ms, + decode_tok_s, + }); } } @@ -190,5 +237,8 @@ fn run_rank( } fn arg<'a>(args: &'a [String], flag: &str) -> Option<&'a str> { - args.iter().position(|a| a == flag).and_then(|i| args.get(i + 1)).map(|s| s.as_str()) + args.iter() + .position(|a| a == flag) + .and_then(|i| args.get(i + 1)) + .map(|s| s.as_str()) } diff --git a/crates/xserv-model/src/bin/dump-logits.rs b/crates/xserv-model/src/bin/dump-logits.rs index 080345a..b8029f3 100644 --- a/crates/xserv-model/src/bin/dump-logits.rs +++ b/crates/xserv-model/src/bin/dump-logits.rs @@ -1,8 +1,8 @@ +use half::bf16; use std::path::PathBuf; -use xserv_model::{loader, KVCache, ModelConfig, Qwen3}; +use xserv_model::{KVCache, ModelConfig, Qwen3, loader}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; -use half::bf16; fn main() { let args: Vec = std::env::args().collect(); @@ -20,8 +20,11 @@ fn main() { eprintln!("Token IDs: {token_ids:?}"); let mut cache = KVCache::new( - config.num_layers(), config.num_kv_heads(), config.head_dim(), - DType::BF16, Device::Cuda(0), + config.num_layers(), + config.num_kv_heads(), + config.head_dim(), + DType::BF16, + Device::Cuda(0), ); let logits = model.forward_with_cache(&token_ids, &mut cache); let logits_cpu = logits.to_device(Device::Cpu); @@ -31,7 +34,9 @@ fn main() { // Print top-20 logits for the last position let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; - let mut indexed: Vec<(usize, f32)> = last_row.iter().enumerate() + let mut indexed: Vec<(usize, f32)> = last_row + .iter() + .enumerate() .map(|(i, v)| (i, v.to_f32())) .collect(); indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); diff --git a/crates/xserv-model/src/bin/xserv-chat.rs b/crates/xserv-model/src/bin/xserv-chat.rs index fbe8357..c02f9e4 100644 --- a/crates/xserv-model/src/bin/xserv-chat.rs +++ b/crates/xserv-model/src/bin/xserv-chat.rs @@ -1,10 +1,13 @@ use std::io::{self, IsTerminal, Read, Write}; use std::path::PathBuf; -use std::sync::{mpsc, Arc}; +use std::sync::{Arc, mpsc}; use std::thread; -use xserv_model::{GraphedGptOssDecoder, loader, sample, sample_greedy_penalized, GptOss, ModelConfig, PagedKVCache, Qwen3, SamplingParams, BLOCK_SIZE}; +use xserv_model::{ + BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, SamplingParams, + loader, sample, sample_greedy_penalized, +}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -14,13 +17,24 @@ enum ChatModel { } impl ChatModel { - fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> xserv_tensor::Tensor { + fn forward_prefill_paged( + &self, + tokens: &[u32], + slot: usize, + cache: &mut PagedKVCache, + ) -> xserv_tensor::Tensor { match self { ChatModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache), ChatModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache), } } - fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> xserv_tensor::Tensor { + fn forward_decode_paged( + &self, + tokens: &[u32], + positions: &[usize], + slots: &[usize], + cache: &mut PagedKVCache, + ) -> xserv_tensor::Tensor { match self { ChatModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache), ChatModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache), @@ -33,8 +47,15 @@ impl ChatModel { enum TpCommand { Register(usize), Free(usize), - Prefill { tokens: Vec, slot: usize }, - Decode { tokens: Vec, positions: Vec, slots: Vec }, + Prefill { + tokens: Vec, + slot: usize, + }, + Decode { + tokens: Vec, + positions: Vec, + slots: Vec, + }, } struct TpHandle { @@ -56,7 +77,8 @@ impl TpHandle { } fn tp_worker_loop( - rank: usize, world: usize, + rank: usize, + world: usize, id: xserv_distributed::UniqueId, model_dir: std::path::PathBuf, config: ModelConfig, @@ -64,29 +86,68 @@ fn tp_worker_loop( cmd_rx: mpsc::Receiver, ack_tx: mpsc::Sender<()>, ) { - let tp = Arc::new(xserv_distributed::TpContext::init(rank, world, id, rank as u32)); + let tp = Arc::new(xserv_distributed::TpContext::init( + rank, + world, + id, + rank as u32, + )); let weights = loader::load_model_dir(&model_dir, Device::Cpu); let model = if config.is_moe() { - ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp))) + ChatModel::GptOss(GptOss::from_weights_tp( + config.clone(), + weights, + rank, + world, + rank as u32, + Some(tp), + )) } else { - ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, rank as u32, Some(tp))) + ChatModel::Qwen3(Qwen3::from_weights_tp( + config.clone(), + weights, + rank, + world, + rank as u32, + Some(tp), + )) }; let local_kv = config.num_kv_heads() / world; let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = max_blocks_per_seq + 8; let mut cache = PagedKVCache::new_tp( - &config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, rank as u32, + &config, + local_kv, + total_blocks, + 0, + 1, + max_blocks_per_seq, + DType::BF16, + rank as u32, ); let mut decoder = GraphedGptOssDecoder::new(); while let Ok(cmd) = cmd_rx.recv() { match cmd { - TpCommand::Register(slot) => { let _ = cache.register_sequence(slot); } + TpCommand::Register(slot) => { + let _ = cache.register_sequence(slot); + } TpCommand::Free(slot) => cache.free_sequence(slot), TpCommand::Prefill { tokens, slot } => { let _ = model.forward_prefill_paged(&tokens, slot, &mut cache); } - TpCommand::Decode { tokens, positions, slots } => { - let _ = chat_decode(&model, &mut decoder, &tokens, &positions, &slots, &mut cache); + TpCommand::Decode { + tokens, + positions, + slots, + } => { + let _ = chat_decode( + &model, + &mut decoder, + &tokens, + &positions, + &slots, + &mut cache, + ); } } let _ = ack_tx.send(()); @@ -221,7 +282,13 @@ fn read_line_edited(prompt: &str) -> Line { } b => { // UTF-8 multi-byte: read the continuation bytes for this char. - let extra = if b >= 0xF0 { 3 } else if b >= 0xE0 { 2 } else { 1 }; + let extra = if b >= 0xF0 { + 3 + } else if b >= 0xE0 { + 2 + } else { + 1 + }; let mut bytes = vec![b]; let mut cont = [0u8; 1]; let mut ok = true; @@ -275,7 +342,8 @@ fn main() { if world > 1 { assert!( config.num_kv_heads() % world == 0, - "num_kv_heads {} not divisible by tp {world}", config.num_kv_heads() + "num_kv_heads {} not divisible by tp {world}", + config.num_kv_heads() ); } @@ -290,7 +358,16 @@ fn main() { let model_dir = opts.model_dir.clone(); let config = config.clone(); thread::spawn(move || { - tp_worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx); + tp_worker_loop( + rank, + world, + id, + model_dir, + config, + max_seq_len, + ctx_rx, + ack_tx, + ); }); } eprintln!("Loading weights (tp={world})..."); @@ -298,14 +375,37 @@ fn main() { let weights = loader::load_model_dir(&opts.model_dir, Device::Cpu); eprintln!("Loaded {} tensors", weights.len()); let m = if is_moe { - ChatModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp))) + ChatModel::GptOss(GptOss::from_weights_tp( + config.clone(), + weights, + 0, + world, + 0, + Some(tp), + )) } else { - ChatModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, 0, world, 0, Some(tp))) + ChatModel::Qwen3(Qwen3::from_weights_tp( + config.clone(), + weights, + 0, + world, + 0, + Some(tp), + )) }; let local_kv = config.num_kv_heads() / world; let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = max_blocks_per_seq + 8; - let c = PagedKVCache::new_tp(&config, local_kv, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0); + let c = PagedKVCache::new_tp( + &config, + local_kv, + total_blocks, + 0, + 1, + max_blocks_per_seq, + DType::BF16, + 0, + ); let h = TpHandle { cmd_txs, ack_rx }; (m, c, Some(h)) } else { @@ -323,7 +423,10 @@ fn main() { let tokenizer = Tokenizer::from_file(&opts.model_dir.join("tokenizer.json")); let mut decoder = GraphedGptOssDecoder::new(); - if let Some(h) = &tp_handle { h.send(TpCommand::Register(SLOT)); h.wait(); } + if let Some(h) = &tp_handle { + h.send(TpCommand::Register(SLOT)); + h.wait(); + } cache.register_sequence(SLOT).expect("register chat slot"); let use_color = opts.color && io::stdout().is_terminal(); @@ -365,11 +468,8 @@ fn main() { if is_moe { // Harmony multi-turn: re-render the whole conversation (prior // analysis dropped) and re-prefill into a freshly cleared slot. - let prompt = build_conversation_gpt_oss( - opts.system_prompt.as_deref(), - &moe_history, - input, - ); + let prompt = + build_conversation_gpt_oss(opts.system_prompt.as_deref(), &moe_history, input); let prompt_tokens = tokenizer.encode(&prompt); if prompt_tokens.is_empty() { continue; @@ -386,8 +486,17 @@ fn main() { print!("assistant> "); io::stdout().flush().unwrap(); let (_finish, answer) = generate_with_paged_cache( - &model, &mut decoder, &mut cache, &tokenizer, &prompt_tokens, &opts.sampling, - max_new_tokens, use_color, &tp_handle, is_moe, opts.enable_thinking, + &model, + &mut decoder, + &mut cache, + &tokenizer, + &prompt_tokens, + &opts.sampling, + max_new_tokens, + use_color, + &tp_handle, + is_moe, + opts.enable_thinking, ); moe_history.push((input.to_string(), answer)); println!(); @@ -436,10 +545,24 @@ fn main() { ); match finish { Finish::Stop { token_id } => { - append_after_stop(&model, &mut cache, &tokenizer, max_seq_len, token_id, &tp_handle); + append_after_stop( + &model, + &mut cache, + &tokenizer, + max_seq_len, + token_id, + &tp_handle, + ); } Finish::Length => { - append_text_to_cache(&model, &mut cache, &tokenizer, max_seq_len, "<|im_end|>\n", &tp_handle); + append_text_to_cache( + &model, + &mut cache, + &tokenizer, + max_seq_len, + "<|im_end|>\n", + &tp_handle, + ); } } println!(); @@ -448,9 +571,15 @@ fn main() { /// Free and re-register the single chat KV slot (clears all cached context). fn reset_slot(cache: &mut PagedKVCache, tp: &Option) { - if let Some(h) = tp { h.send(TpCommand::Free(SLOT)); h.wait(); } + if let Some(h) = tp { + h.send(TpCommand::Free(SLOT)); + h.wait(); + } cache.free_sequence(SLOT); - if let Some(h) = tp { h.send(TpCommand::Register(SLOT)); h.wait(); } + if let Some(h) = tp { + h.send(TpCommand::Register(SLOT)); + h.wait(); + } cache.register_sequence(SLOT).expect("register chat slot"); } @@ -588,7 +717,15 @@ fn new_paged_cache(config: &ModelConfig, max_seq_len: usize) -> PagedKVCache { let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = (max_blocks_per_seq + 1).max(2); // Single-slot interactive CLI: no swap pool (cpu_total_blocks = 0). - PagedKVCache::new(config, total_blocks, 0, 1, max_blocks_per_seq, DType::BF16, 0) + PagedKVCache::new( + config, + total_blocks, + 0, + 1, + max_blocks_per_seq, + DType::BF16, + 0, + ) } fn build_turn_prompt( @@ -668,7 +805,10 @@ fn build_conversation_gpt_oss( /// civil-calendar conversion (same algorithm the server uses for strftime_now). fn today_ymd() -> String { use std::time::{SystemTime, UNIX_EPOCH}; - let secs = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); + let secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); let z = (secs / 86400) as i64 + 719468; let era = (if z >= 0 { z } else { z - 146096 }) / 146097; let doe = z - era * 146097; @@ -709,12 +849,32 @@ fn generate_with_paged_cache( is_moe: bool, enable_thinking: bool, ) -> (Finish, String) { - let harmony_end_id = if is_moe { tokenizer.special_token_id("<|end|>") } else { None }; - let harmony_channel_id = if is_moe { tokenizer.special_token_id("<|channel|>") } else { None }; - let harmony_message_id = if is_moe { tokenizer.special_token_id("<|message|>") } else { None }; + let harmony_end_id = if is_moe { + tokenizer.special_token_id("<|end|>") + } else { + None + }; + let harmony_channel_id = if is_moe { + tokenizer.special_token_id("<|channel|>") + } else { + None + }; + let harmony_message_id = if is_moe { + tokenizer.special_token_id("<|message|>") + } else { + None + }; let harmony_special: Vec = if is_moe { - ["<|channel|>", "<|start|>", "<|end|>", "<|message|>", "<|return|>"] - .iter().filter_map(|s| tokenizer.special_token_id(s)).collect() + [ + "<|channel|>", + "<|start|>", + "<|end|>", + "<|message|>", + "<|return|>", + ] + .iter() + .filter_map(|s| tokenizer.special_token_id(s)) + .collect() } else { Vec::new() }; @@ -722,18 +882,29 @@ fn generate_with_paged_cache( // "analysis" channel is rendered as thinking (gray). After <|channel|> // we read the channel name tokens until <|message|>. #[derive(PartialEq, Clone, Copy)] - enum HarmonyState { Normal, ReadingChannel, InAnalysis, InFinal } - let mut hstate = if is_moe { HarmonyState::InFinal } else { HarmonyState::Normal }; + enum HarmonyState { + Normal, + ReadingChannel, + InAnalysis, + InFinal, + } + let mut hstate = if is_moe { + HarmonyState::InFinal + } else { + HarmonyState::Normal + }; // Off by default. A repetition penalty over a harmony stream penalizes the // control tokens (<|channel|>, <|message|>, <|start|>) that MUST repeat to // open the final channel — so a non-1.0 default makes gpt-oss stop right // after the analysis block, before emitting any answer. Opt in via the env // var if you want it for plain (non-harmony) generation. - let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok() + let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY") + .ok() .and_then(|s| s.parse().ok()) .unwrap_or(1.0); - let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok() + let rep_window: usize = std::env::var("XSERV_REP_WINDOW") + .ok() .and_then(|s| s.parse().ok()) .unwrap_or(512); let mut history: Vec = Vec::new(); @@ -747,9 +918,16 @@ fn generate_with_paged_cache( } }; - if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: prompt_tokens.to_vec(), slot: SLOT }); } + if let Some(h) = tp { + h.send(TpCommand::Prefill { + tokens: prompt_tokens.to_vec(), + slot: SLOT, + }); + } let logits = model.forward_prefill_paged(prompt_tokens, SLOT, cache); - if let Some(h) = tp { h.wait(); } + if let Some(h) = tp { + h.wait(); + } let mut next = pick(&logits, sampling, &history); let mut decode_buffer = Vec::new(); let mut in_thinking = false; @@ -762,9 +940,17 @@ fn generate_with_paged_cache( for _ in 0..max_tokens { let position = cache.seq_len(SLOT); - if let Some(h) = tp { h.send(TpCommand::Decode { tokens: vec![next], positions: vec![position], slots: vec![SLOT] }); } + if let Some(h) = tp { + h.send(TpCommand::Decode { + tokens: vec![next], + positions: vec![position], + slots: vec![SLOT], + }); + } let logits = chat_decode(model, decoder, &[next], &[position], &[SLOT], cache); - if let Some(h) = tp { h.wait(); } + if let Some(h) = tp { + h.wait(); + } if tokenizer.is_eos(next) { print_stream_text( &tokenizer.flush_decode_stream(&mut decode_buffer), @@ -775,7 +961,10 @@ fn generate_with_paged_cache( print_stream_text("\n\n\n", true, use_color); } io::stdout().flush().unwrap(); - return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids)); + return ( + Finish::Stop { token_id: next }, + tokenizer.decode(&answer_ids), + ); } if harmony_end_id == Some(next) { // <|end|> closes current segment; if in final channel, we're done @@ -786,7 +975,10 @@ fn generate_with_paged_cache( ); if hstate == HarmonyState::InFinal { io::stdout().flush().unwrap(); - return (Finish::Stop { token_id: next }, tokenizer.decode(&answer_ids)); + return ( + Finish::Stop { token_id: next }, + tokenizer.decode(&answer_ids), + ); } // Closing a thinking (analysis/commentary) channel: emit the // marker so it renders like Qwen3's thinking block. @@ -842,7 +1034,13 @@ fn generate_with_paged_cache( // Analysis channel = the model's reasoning. With --think, show it as a // thinking block (gray if color); otherwise suppress it (answer only). if show_thinking { - print_generated_token(tokenizer, next, &mut decode_buffer, &mut in_thinking, use_color); + print_generated_token( + tokenizer, + next, + &mut decode_buffer, + &mut in_thinking, + use_color, + ); io::stdout().flush().unwrap(); } next = pick(&logits, sampling, &history); @@ -904,9 +1102,16 @@ fn append_text_to_cache( if tokens.is_empty() || cache.seq_len(SLOT) + tokens.len() > max_seq_len { return; } - if let Some(h) = tp { h.send(TpCommand::Prefill { tokens: tokens.clone(), slot: SLOT }); } + if let Some(h) = tp { + h.send(TpCommand::Prefill { + tokens: tokens.clone(), + slot: SLOT, + }); + } let _ = model.forward_prefill_paged(&tokens, SLOT, cache); - if let Some(h) = tp { h.wait(); } + if let Some(h) = tp { + h.wait(); + } } fn print_generated_token( @@ -952,4 +1157,3 @@ fn print_stream_text(text: &str, in_thinking: bool, use_color: bool) { print!("{text}"); } } - diff --git a/crates/xserv-model/src/bin/xserv-cli.rs b/crates/xserv-model/src/bin/xserv-cli.rs index a7d666e..fd49c82 100644 --- a/crates/xserv-model/src/bin/xserv-cli.rs +++ b/crates/xserv-model/src/bin/xserv-cli.rs @@ -1,6 +1,6 @@ use std::io::{self, Write}; use std::path::PathBuf; -use xserv_model::{loader, KVCache, ModelConfig, PagedKVCache, BLOCK_SIZE}; +use xserv_model::{BLOCK_SIZE, KVCache, ModelConfig, PagedKVCache, loader}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -21,14 +21,21 @@ fn main() { xserv_cuda::device::set_device(0).unwrap(); let info = xserv_cuda::device::device_info(0).unwrap(); - eprintln!("GPU: {} ({} MB free)", info.name, info.free_memory / 1024 / 1024); + eprintln!( + "GPU: {} ({} MB free)", + info.name, + info.free_memory / 1024 / 1024 + ); let config = ModelConfig::from_file(&model_dir.join("config.json")); let model_type = config.model_type.as_deref().unwrap_or("unknown"); eprintln!( "Model: {model_type}, layers={}, hidden={}, heads={}/{} kv, vocab={}", - config.num_layers(), config.hidden(), config.num_heads(), - config.num_kv_heads(), config.vocab_size + config.num_layers(), + config.hidden(), + config.num_heads(), + config.num_kv_heads(), + config.vocab_size ); eprintln!("Loading weights..."); @@ -37,7 +44,11 @@ fn main() { let is_qwen3 = model_type.contains("qwen"); let is_gpt_oss = model_type.contains("gpt_oss"); - let dtype = if is_qwen3 || is_gpt_oss { DType::BF16 } else { DType::F32 }; + let dtype = if is_qwen3 || is_gpt_oss { + DType::BF16 + } else { + DType::F32 + }; // Build model enum Model { @@ -60,10 +71,16 @@ fn main() { print!("xserv> "); io::stdout().flush().unwrap(); let mut input = String::new(); - if io::stdin().read_line(&mut input).unwrap() == 0 { break; } + if io::stdin().read_line(&mut input).unwrap() == 0 { + break; + } let input = input.trim(); - if input.is_empty() { continue; } - if input == "quit" || input == "exit" { break; } + if input.is_empty() { + continue; + } + if input == "quit" || input == "exit" { + break; + } let token_ids = tokenizer.encode(input); @@ -73,12 +90,21 @@ fn main() { let max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = max_blocks_per_seq + 64; let mut paged_cache = PagedKVCache::new( - &config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, 0, + &config, + total_blocks, + 0, + 4, + max_blocks_per_seq, + DType::BF16, + 0, ); let slot = 0; paged_cache.register_sequence(slot).expect("register slot"); - let model = match &model { Model::GptOss(m) => m, _ => unreachable!() }; + let model = match &model { + Model::GptOss(m) => m, + _ => unreachable!(), + }; let logits = model.forward_prefill_paged(&token_ids, slot, &mut paged_cache); let mut next = sample_greedy_last(&logits); @@ -90,20 +116,28 @@ fn main() { print!("{text}"); io::stdout().flush().unwrap(); - if tokenizer.eos_token_id() == Some(next) { break; } + if tokenizer.eos_token_id() == Some(next) { + break; + } let pos = paged_cache.seq_len(slot); - let logits = model.forward_decode_paged( - &[next], &[pos], &[slot], &mut paged_cache, - ); + let logits = model.forward_decode_paged(&[next], &[pos], &[slot], &mut paged_cache); next = sample_greedy_last(&logits); } println!(); paged_cache.free_sequence(slot); } else { - let kv_heads = if is_qwen3 { config.num_kv_heads() } else { config.num_heads() }; + let kv_heads = if is_qwen3 { + config.num_kv_heads() + } else { + config.num_heads() + }; let mut cache = KVCache::new( - config.num_layers(), kv_heads, config.head_dim(), dtype, Device::Cuda(0), + config.num_layers(), + kv_heads, + config.head_dim(), + dtype, + Device::Cuda(0), ); let logits = match &model { @@ -125,7 +159,9 @@ fn main() { print!("{text}"); io::stdout().flush().unwrap(); - if tokenizer.eos_token_id() == Some(next) { break; } + if tokenizer.eos_token_id() == Some(next) { + break; + } let logits = match &model { Model::GPT2(m) => m.forward_with_cache(&[next], &mut cache), @@ -151,7 +187,9 @@ fn sample_greedy_last(logits: &xserv_tensor::Tensor) -> u32 { let seq_len = logits.shape()[0]; let data = logits_cpu.as_slice::(); let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; - last.iter().enumerate() + last.iter() + .enumerate() .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) - .map(|(i, _)| i as u32).unwrap() + .map(|(i, _)| i as u32) + .unwrap() } diff --git a/crates/xserv-model/src/config.rs b/crates/xserv-model/src/config.rs index 1225156..9d5b24b 100644 --- a/crates/xserv-model/src/config.rs +++ b/crates/xserv-model/src/config.rs @@ -88,23 +88,33 @@ impl ModelConfig { } pub fn hidden(&self) -> usize { - self.hidden_size.or(self.n_embd).expect("hidden_size or n_embd required") + self.hidden_size + .or(self.n_embd) + .expect("hidden_size or n_embd required") } pub fn num_heads(&self) -> usize { - self.num_attention_heads.or(self.n_head).expect("num_attention_heads or n_head required") + self.num_attention_heads + .or(self.n_head) + .expect("num_attention_heads or n_head required") } pub fn num_layers(&self) -> usize { - self.num_hidden_layers.or(self.n_layer).expect("num_hidden_layers or n_layer required") + self.num_hidden_layers + .or(self.n_layer) + .expect("num_hidden_layers or n_layer required") } pub fn max_seq_len(&self) -> usize { - self.max_position_embeddings.or(self.n_positions).unwrap_or(2048) + self.max_position_embeddings + .or(self.n_positions) + .unwrap_or(2048) } pub fn ffn_hidden(&self) -> usize { - self.intermediate_size.or(self.n_inner).unwrap_or(self.hidden() * 4) + self.intermediate_size + .or(self.n_inner) + .unwrap_or(self.hidden() * 4) } pub fn num_kv_heads(&self) -> usize { @@ -112,7 +122,8 @@ impl ModelConfig { } pub fn head_dim(&self) -> usize { - self.explicit_head_dim.unwrap_or_else(|| self.hidden() / self.num_heads()) + self.explicit_head_dim + .unwrap_or_else(|| self.hidden() / self.num_heads()) } pub fn ln_eps(&self) -> f32 { diff --git a/crates/xserv-model/src/decode_graph.rs b/crates/xserv-model/src/decode_graph.rs index e78c242..ed518e4 100644 --- a/crates/xserv-model/src/decode_graph.rs +++ b/crates/xserv-model/src/decode_graph.rs @@ -18,19 +18,19 @@ use crate::kv_cache::GpuKVCache; /// All buffers have stable GPU addresses for CUDA Graph replay. struct DecodeBuffers { // Hidden-size buffers: [1, hidden] - x: GpuBuffer, // running hidden state - normed: GpuBuffer, // rmsnorm output - attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim] - attn_merged: GpuBuffer, // merge_heads output [1, hidden] - o_proj: GpuBuffer, // O projection output [1, hidden] - normed2: GpuBuffer, // post-attn norm output [1, hidden] - sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden] - down: GpuBuffer, // down projection output [1, hidden] + x: GpuBuffer, // running hidden state + normed: GpuBuffer, // rmsnorm output + attn_out: GpuBuffer, // attention output [1, num_heads, 1, head_dim] + attn_merged: GpuBuffer, // merge_heads output [1, hidden] + o_proj: GpuBuffer, // O projection output [1, hidden] + normed2: GpuBuffer, // post-attn norm output [1, hidden] + sum_out: GpuBuffer, // add_rmsnorm sum output [1, hidden] + down: GpuBuffer, // down projection output [1, hidden] // QKV projection outputs - q_proj: GpuBuffer, // [1, num_heads * head_dim] - k_proj: GpuBuffer, // [1, num_kv_heads * head_dim] - v_proj: GpuBuffer, // [1, num_kv_heads * head_dim] + q_proj: GpuBuffer, // [1, num_heads * head_dim] + k_proj: GpuBuffer, // [1, num_kv_heads * head_dim] + v_proj: GpuBuffer, // [1, num_kv_heads * head_dim] // Reshaped: [1, H, 1, D] q_reshaped: GpuBuffer, @@ -50,23 +50,23 @@ struct DecodeBuffers { k_final: GpuBuffer, // FFN intermediates - gate: GpuBuffer, // [1, intermediate] - up: GpuBuffer, // [1, intermediate] - silu_out: GpuBuffer, // [1, intermediate] + gate: GpuBuffer, // [1, intermediate] + up: GpuBuffer, // [1, intermediate] + silu_out: GpuBuffer, // [1, intermediate] // GEMV fp32 accumulators (separate per output dimension) - fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs - fp32_q: GpuBuffer, // for Q projection - fp32_kv: GpuBuffer, // for K/V projection - fp32_intermediate: GpuBuffer,// for gate/up projections - fp32_vocab: GpuBuffer, // for lm_head + fp32_hidden: GpuBuffer, // for hidden-sized GEMV outputs + fp32_q: GpuBuffer, // for Q projection + fp32_kv: GpuBuffer, // for K/V projection + fp32_intermediate: GpuBuffer, // for gate/up projections + fp32_vocab: GpuBuffer, // for lm_head // Token ID and position (GPU-resident, updated before replay) - token_id_gpu: GpuBuffer, // 4 bytes (u32) - position_gpu: GpuBuffer, // 4 bytes (u32) + token_id_gpu: GpuBuffer, // 4 bytes (u32) + position_gpu: GpuBuffer, // 4 bytes (u32) // Final output - logits: GpuBuffer, // [1, vocab_size] + logits: GpuBuffer, // [1, vocab_size] } pub struct DecodeGraphState { @@ -199,127 +199,296 @@ impl DecodeGraphState { let cublas = cublas_handle(); // Set cuBLAS to use our stream - unsafe { dispatch::set_cublas_stream(cublas, s); } + unsafe { + dispatch::set_cublas_stream(cublas, s); + } for (l, lw) in layers.iter().enumerate() { // === Pre-attention graph === - self.pre_attn_graphs[l].begin_capture(&self.stream).expect("begin pre-attn capture"); + self.pre_attn_graphs[l] + .begin_capture(&self.stream) + .expect("begin pre-attn capture"); unsafe { // RMSNorm dispatch::rmsnorm_bf16( - self.buffers.x.as_ptr() as _, lw.input_norm, self.buffers.normed.as_mut_ptr() as _, - 1, h, eps, s, + self.buffers.x.as_ptr() as _, + lw.input_norm, + self.buffers.normed.as_mut_ptr() as _, + 1, + h, + eps, + s, ); // Q projection (GEMV) dispatch::gemv_bf16( - self.buffers.normed.as_ptr() as _, lw.q_proj_wt, self.buffers.q_proj.as_mut_ptr() as _, + self.buffers.normed.as_ptr() as _, + lw.q_proj_wt, + self.buffers.q_proj.as_mut_ptr() as _, self.buffers.fp32_q.as_mut_ptr() as _, - h, nh * hd, s, + h, + nh * hd, + s, ); // K projection (GEMV) dispatch::gemv_bf16( - self.buffers.normed.as_ptr() as _, lw.k_proj_wt, self.buffers.k_proj.as_mut_ptr() as _, + self.buffers.normed.as_ptr() as _, + lw.k_proj_wt, + self.buffers.k_proj.as_mut_ptr() as _, self.buffers.fp32_kv.as_mut_ptr() as _, - h, nkv * hd, s, + h, + nkv * hd, + s, ); // V projection (GEMV) dispatch::gemv_bf16( - self.buffers.normed.as_ptr() as _, lw.v_proj_wt, self.buffers.v_proj.as_mut_ptr() as _, + self.buffers.normed.as_ptr() as _, + lw.v_proj_wt, + self.buffers.v_proj.as_mut_ptr() as _, self.buffers.fp32_kv.as_mut_ptr() as _, - h, nkv * hd, s, + h, + nkv * hd, + s, ); // Reshape heads: [1, H*D] -> [1, H, 1, D] - dispatch::reshape_heads_bf16(self.buffers.q_proj.as_ptr() as _, self.buffers.q_reshaped.as_mut_ptr() as _, 1, nh, hd, s); - dispatch::reshape_heads_bf16(self.buffers.k_proj.as_ptr() as _, self.buffers.k_reshaped.as_mut_ptr() as _, 1, nkv, hd, s); - dispatch::reshape_heads_bf16(self.buffers.v_proj.as_ptr() as _, self.buffers.v_reshaped.as_mut_ptr() as _, 1, nkv, hd, s); + dispatch::reshape_heads_bf16( + self.buffers.q_proj.as_ptr() as _, + self.buffers.q_reshaped.as_mut_ptr() as _, + 1, + nh, + hd, + s, + ); + dispatch::reshape_heads_bf16( + self.buffers.k_proj.as_ptr() as _, + self.buffers.k_reshaped.as_mut_ptr() as _, + 1, + nkv, + hd, + s, + ); + dispatch::reshape_heads_bf16( + self.buffers.v_proj.as_ptr() as _, + self.buffers.v_reshaped.as_mut_ptr() as _, + 1, + nkv, + hd, + s, + ); // QK norm (head-level rmsnorm: treat [1,H,1,D] as [H, D]) - dispatch::rmsnorm_bf16(self.buffers.q_reshaped.as_ptr() as _, lw.q_norm, self.buffers.q_normed.as_mut_ptr() as _, nh, hd, eps, s); - dispatch::rmsnorm_bf16(self.buffers.k_reshaped.as_ptr() as _, lw.k_norm, self.buffers.k_normed.as_mut_ptr() as _, nkv, hd, eps, s); + dispatch::rmsnorm_bf16( + self.buffers.q_reshaped.as_ptr() as _, + lw.q_norm, + self.buffers.q_normed.as_mut_ptr() as _, + nh, + hd, + eps, + s, + ); + dispatch::rmsnorm_bf16( + self.buffers.k_reshaped.as_ptr() as _, + lw.k_norm, + self.buffers.k_normed.as_mut_ptr() as _, + nkv, + hd, + eps, + s, + ); // Transpose for RoPE: [1,H,1,D] -> [1,H,D] - dispatch::transpose_hsd_to_shd_bf16(self.buffers.q_normed.as_ptr() as _, self.buffers.q_rope.as_mut_ptr() as _, 1, nh, hd, s); - dispatch::transpose_hsd_to_shd_bf16(self.buffers.k_normed.as_ptr() as _, self.buffers.k_rope.as_mut_ptr() as _, 1, nkv, hd, s); + dispatch::transpose_hsd_to_shd_bf16( + self.buffers.q_normed.as_ptr() as _, + self.buffers.q_rope.as_mut_ptr() as _, + 1, + nh, + hd, + s, + ); + dispatch::transpose_hsd_to_shd_bf16( + self.buffers.k_normed.as_ptr() as _, + self.buffers.k_rope.as_mut_ptr() as _, + 1, + nkv, + hd, + s, + ); // RoPE (in-place, reads position_gpu) - dispatch::rope_bf16(self.buffers.q_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nh, hd, s); - dispatch::rope_bf16(self.buffers.k_rope.as_mut_ptr() as _, rope_cos, rope_sin, self.buffers.position_gpu.as_ptr() as _, 1, nkv, hd, s); + dispatch::rope_bf16( + self.buffers.q_rope.as_mut_ptr() as _, + rope_cos, + rope_sin, + self.buffers.position_gpu.as_ptr() as _, + 1, + nh, + hd, + s, + ); + dispatch::rope_bf16( + self.buffers.k_rope.as_mut_ptr() as _, + rope_cos, + rope_sin, + self.buffers.position_gpu.as_ptr() as _, + 1, + nkv, + hd, + s, + ); // Transpose back: [1,H,D] -> [1,H,1,D] - dispatch::transpose_shd_to_hsd_bf16(self.buffers.q_rope.as_ptr() as _, self.buffers.q_final.as_mut_ptr() as _, 1, nh, hd, s); - dispatch::transpose_shd_to_hsd_bf16(self.buffers.k_rope.as_ptr() as _, self.buffers.k_final.as_mut_ptr() as _, 1, nkv, hd, s); + dispatch::transpose_shd_to_hsd_bf16( + self.buffers.q_rope.as_ptr() as _, + self.buffers.q_final.as_mut_ptr() as _, + 1, + nh, + hd, + s, + ); + dispatch::transpose_shd_to_hsd_bf16( + self.buffers.k_rope.as_ptr() as _, + self.buffers.k_final.as_mut_ptr() as _, + 1, + nkv, + hd, + s, + ); } - self.pre_attn_graphs[l].end_capture(&self.stream).expect("end pre-attn capture"); + self.pre_attn_graphs[l] + .end_capture(&self.stream) + .expect("end pre-attn capture"); // === Post-attention graph === - self.post_attn_graphs[l].begin_capture(&self.stream).expect("begin post-attn capture"); + self.post_attn_graphs[l] + .begin_capture(&self.stream) + .expect("begin post-attn capture"); unsafe { // Merge heads: [1,H,1,D] -> [1, hidden] // attn_out is written by ungraphed attention - dispatch::merge_heads_bf16(self.buffers.attn_out.as_ptr() as _, self.buffers.attn_merged.as_mut_ptr() as _, 1, nh, hd, s); + dispatch::merge_heads_bf16( + self.buffers.attn_out.as_ptr() as _, + self.buffers.attn_merged.as_mut_ptr() as _, + 1, + nh, + hd, + s, + ); // O projection dispatch::gemv_bf16( - self.buffers.attn_merged.as_ptr() as _, lw.o_proj_wt, self.buffers.o_proj.as_mut_ptr() as _, + self.buffers.attn_merged.as_ptr() as _, + lw.o_proj_wt, + self.buffers.o_proj.as_mut_ptr() as _, self.buffers.fp32_hidden.as_mut_ptr() as _, - nh * hd, h, s, + nh * hd, + h, + s, ); // Fused Add+RMSNorm: normed2 = rmsnorm(o_proj + x), sum_out = o_proj + x dispatch::add_rmsnorm_bf16( - self.buffers.o_proj.as_ptr() as _, self.buffers.x.as_ptr() as _, lw.post_norm, - self.buffers.normed2.as_mut_ptr() as _, self.buffers.sum_out.as_mut_ptr() as _, - 1, h, eps, s, + self.buffers.o_proj.as_ptr() as _, + self.buffers.x.as_ptr() as _, + lw.post_norm, + self.buffers.normed2.as_mut_ptr() as _, + self.buffers.sum_out.as_mut_ptr() as _, + 1, + h, + eps, + s, ); // Gate projection dispatch::gemv_bf16( - self.buffers.normed2.as_ptr() as _, lw.gate_proj_wt, self.buffers.gate.as_mut_ptr() as _, + self.buffers.normed2.as_ptr() as _, + lw.gate_proj_wt, + self.buffers.gate.as_mut_ptr() as _, self.buffers.fp32_intermediate.as_mut_ptr() as _, - h, inter, s, + h, + inter, + s, ); // Up projection dispatch::gemv_bf16( - self.buffers.normed2.as_ptr() as _, lw.up_proj_wt, self.buffers.up.as_mut_ptr() as _, + self.buffers.normed2.as_ptr() as _, + lw.up_proj_wt, + self.buffers.up.as_mut_ptr() as _, self.buffers.fp32_intermediate.as_mut_ptr() as _, - h, inter, s, + h, + inter, + s, ); // Fused SiLU x Mul - dispatch::silu_mul_bf16(self.buffers.gate.as_ptr() as _, self.buffers.up.as_ptr() as _, self.buffers.silu_out.as_mut_ptr() as _, inter, s); + dispatch::silu_mul_bf16( + self.buffers.gate.as_ptr() as _, + self.buffers.up.as_ptr() as _, + self.buffers.silu_out.as_mut_ptr() as _, + inter, + s, + ); // Down projection dispatch::gemv_bf16( - self.buffers.silu_out.as_ptr() as _, lw.down_proj_wt, self.buffers.down.as_mut_ptr() as _, + self.buffers.silu_out.as_ptr() as _, + lw.down_proj_wt, + self.buffers.down.as_mut_ptr() as _, self.buffers.fp32_hidden.as_mut_ptr() as _, - inter, h, s, + inter, + h, + s, ); // x = sum_out + down (residual connection for next layer) - dispatch::add_bf16(self.buffers.sum_out.as_ptr() as _, self.buffers.down.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, h, s); + dispatch::add_bf16( + self.buffers.sum_out.as_ptr() as _, + self.buffers.down.as_ptr() as _, + self.buffers.x.as_mut_ptr() as _, + h, + s, + ); } - self.post_attn_graphs[l].end_capture(&self.stream).expect("end post-attn capture"); + self.post_attn_graphs[l] + .end_capture(&self.stream) + .expect("end post-attn capture"); } // === Final graph: norm + lm_head === - self.final_graph.begin_capture(&self.stream).expect("begin final capture"); + self.final_graph + .begin_capture(&self.stream) + .expect("begin final capture"); unsafe { - dispatch::rmsnorm_bf16(self.buffers.x.as_ptr() as _, norm_weight, self.buffers.normed.as_mut_ptr() as _, 1, h, eps, s); + dispatch::rmsnorm_bf16( + self.buffers.x.as_ptr() as _, + norm_weight, + self.buffers.normed.as_mut_ptr() as _, + 1, + h, + eps, + s, + ); dispatch::gemv_bf16( - self.buffers.normed.as_ptr() as _, lm_head_wt, self.buffers.logits.as_mut_ptr() as _, + self.buffers.normed.as_ptr() as _, + lm_head_wt, + self.buffers.logits.as_mut_ptr() as _, self.buffers.fp32_vocab.as_mut_ptr() as _, - h, vocab, s, + h, + vocab, + s, ); } - self.final_graph.end_capture(&self.stream).expect("end final capture"); + self.final_graph + .end_capture(&self.stream) + .expect("end final capture"); // Reset cuBLAS back to null stream - unsafe { dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); } + unsafe { + dispatch::set_cublas_stream(cublas, std::ptr::null_mut()); + } self.captured = true; } @@ -343,8 +512,14 @@ impl DecodeGraphState { let es = 2usize; // BF16 // Upload token ID and position to fixed GPU buffers - self.buffers.token_id_gpu.copy_from_host(&token_id.to_le_bytes()).unwrap(); - self.buffers.position_gpu.copy_from_host(&position.to_le_bytes()).unwrap(); + self.buffers + .token_id_gpu + .copy_from_host(&token_id.to_le_bytes()) + .unwrap(); + self.buffers + .position_gpu + .copy_from_host(&position.to_le_bytes()) + .unwrap(); // Embedding (outside graph since token_id changes each step) unsafe { @@ -352,13 +527,18 @@ impl DecodeGraphState { embed_table, self.buffers.token_id_gpu.as_ptr() as _, self.buffers.x.as_mut_ptr() as _, - 1, hidden_size, vocab_size, s, + 1, + hidden_size, + vocab_size, + s, ); } for l in 0..self.num_layers { // Pre-attention graph (norm + QKV + reshape + QK-norm + RoPE) - self.pre_attn_graphs[l].launch(&self.stream).expect("launch pre-attn graph"); + self.pre_attn_graphs[l] + .launch(&self.stream) + .expect("launch pre-attn graph"); // Ungraphed: KV cache append // k_final shape: [1, num_kv_heads, 1, head_dim] (after RoPE pipeline) @@ -402,9 +582,13 @@ impl DecodeGraphState { k_full.data_ptr() as _, v_full.data_ptr() as _, self.buffers.attn_out.as_mut_ptr() as _, - 1, nh as i32, nkv as i32, - kv_len, hd as i32, - scale, s, + 1, + nh as i32, + nkv as i32, + kv_len, + hd as i32, + scale, + s, ); } @@ -412,11 +596,15 @@ impl DecodeGraphState { self.stream.synchronize().expect("sync before post-attn"); // Post-attention graph (merge + O-proj + add_rmsnorm + FFN + residual) - self.post_attn_graphs[l].launch(&self.stream).expect("launch post-attn graph"); + self.post_attn_graphs[l] + .launch(&self.stream) + .expect("launch post-attn graph"); } // Final graph (norm + lm_head) - self.final_graph.launch(&self.stream).expect("launch final graph"); + self.final_graph + .launch(&self.stream) + .expect("launch final graph"); // Sync to ensure logits are ready self.stream.synchronize().expect("sync after decode"); diff --git a/crates/xserv-model/src/gpt2.rs b/crates/xserv-model/src/gpt2.rs index f44c3c4..797773a 100644 --- a/crates/xserv-model/src/gpt2.rs +++ b/crates/xserv-model/src/gpt2.rs @@ -31,7 +31,7 @@ struct GPT2Block { pub struct KVCache { // Per layer, per head: raw bytes (works for both f32 and bf16) - k: Vec>>, // [num_layers][num_heads][seq_len * head_dim * elem_size] + k: Vec>>, // [num_layers][num_heads][seq_len * head_dim * elem_size] v: Vec>>, len: usize, num_heads: usize, @@ -42,7 +42,13 @@ pub struct KVCache { } impl KVCache { - pub fn new(num_layers: usize, num_heads: usize, head_dim: usize, dtype: DType, device: Device) -> Self { + pub fn new( + num_layers: usize, + num_heads: usize, + head_dim: usize, + dtype: DType, + device: Device, + ) -> Self { Self { k: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(), v: (0..num_layers).map(|_| vec![vec![]; num_heads]).collect(), @@ -55,10 +61,18 @@ impl KVCache { } } - pub fn seq_len(&self) -> usize { self.len } + pub fn seq_len(&self) -> usize { + self.len + } /// Append from a CPU tensor with shape [1, H, new_tokens, D]. - pub fn append_kv_tensor(&mut self, layer: usize, k_cpu: &Tensor, v_cpu: &Tensor, new_tokens: usize) { + pub fn append_kv_tensor( + &mut self, + layer: usize, + k_cpu: &Tensor, + v_cpu: &Tensor, + new_tokens: usize, + ) { let hd = self.head_dim; let es = self.elem_size; let k_bytes = k_cpu.storage().as_cpu_bytes(); @@ -118,7 +132,8 @@ impl GPT2 { pub fn from_weights(config: ModelConfig, mut w: HashMap) -> Self { crate::init_kernels(); let take = |w: &mut HashMap, name: &str| -> Tensor { - w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}")) + w.remove(name) + .unwrap_or_else(|| panic!("missing weight: {name}")) }; let wte = take(&mut w, "wte.weight"); @@ -147,7 +162,15 @@ impl GPT2 { }); } - Self { config, wte, wpe, layers, ln_f_g, ln_f_b, lm_head } + Self { + config, + wte, + wpe, + layers, + ln_f_g, + ln_f_b, + lm_head, + } } /// Full forward pass without KV cache (for testing / correctness comparison). @@ -179,14 +202,22 @@ impl GPT2 { let head_dim = self.config.head_dim(); let tok_emb = embedding(&self.wte, token_ids); - let pos_ids: Vec = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect(); + let pos_ids: Vec = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); let pos_emb = embedding(&self.wpe, &pos_ids); let mut x = add_tensors(&tok_emb, &pos_emb); for (layer_idx, layer) in self.layers.iter().enumerate() { x = self.transformer_block( - layer, &x, Some((cache, layer_idx)), - pos_offset, new_tokens, num_heads, head_dim, hidden, + layer, + &x, + Some((cache, layer_idx)), + pos_offset, + new_tokens, + num_heads, + head_dim, + hidden, ); } @@ -238,7 +269,11 @@ impl GPT2 { fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor { let out = matmul_2d(x, weight); - if let Some(b) = bias { add_bias(&out, b) } else { out } + if let Some(b) = bias { + add_bias(&out, b) + } else { + out + } } fn matmul_2d(a: &Tensor, b: &Tensor) -> Tensor { @@ -277,7 +312,12 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor { } } -fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> (Tensor, Tensor, Tensor) { +fn split_qkv( + qkv: &Tensor, + num_heads: usize, + head_dim: usize, + seq_len: usize, +) -> (Tensor, Tensor, Tensor) { let hidden = num_heads * head_dim; let qkv_cpu = qkv.to_device(Device::Cpu); let device = qkv.device(); @@ -294,14 +334,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> for h in 0..num_heads { let src_off = h * head_dim; let dst_off = (h * seq_len + s) * head_dim; - q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]); - k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]); - v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]); + q_data[dst_off..dst_off + head_dim] + .copy_from_slice(&row[src_off..src_off + head_dim]); + k_data[dst_off..dst_off + head_dim] + .copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]); + v_data[dst_off..dst_off + head_dim].copy_from_slice( + &row[2 * hidden + src_off..2 * hidden + src_off + head_dim], + ); } } - let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device); - let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device); - let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device); + let q = + Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device); + let k = + Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device); + let v = + Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device); (q, k, v) } DType::BF16 => { @@ -314,14 +361,21 @@ fn split_qkv(qkv: &Tensor, num_heads: usize, head_dim: usize, seq_len: usize) -> for h in 0..num_heads { let src_off = h * head_dim; let dst_off = (h * seq_len + s) * head_dim; - q_data[dst_off..dst_off + head_dim].copy_from_slice(&row[src_off..src_off + head_dim]); - k_data[dst_off..dst_off + head_dim].copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]); - v_data[dst_off..dst_off + head_dim].copy_from_slice(&row[2 * hidden + src_off..2 * hidden + src_off + head_dim]); + q_data[dst_off..dst_off + head_dim] + .copy_from_slice(&row[src_off..src_off + head_dim]); + k_data[dst_off..dst_off + head_dim] + .copy_from_slice(&row[hidden + src_off..hidden + src_off + head_dim]); + v_data[dst_off..dst_off + head_dim].copy_from_slice( + &row[2 * hidden + src_off..2 * hidden + src_off + head_dim], + ); } } - let q = Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device); - let k = Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device); - let v = Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device); + let q = + Tensor::from_slice(&q_data, &[1, num_heads, seq_len, head_dim]).to_device(device); + let k = + Tensor::from_slice(&k_data, &[1, num_heads, seq_len, head_dim]).to_device(device); + let v = + Tensor::from_slice(&v_data, &[1, num_heads, seq_len, head_dim]).to_device(device); (q, k, v) } _ => panic!("unsupported dtype {:?} in split_qkv", dtype), @@ -343,7 +397,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor { for h in 0..num_heads { let src_off = (h * seq_len + s) * head_dim; let dst_off = s * hidden + h * head_dim; - out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]); + out[dst_off..dst_off + head_dim] + .copy_from_slice(&src[src_off..src_off + head_dim]); } } Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device) @@ -355,7 +410,8 @@ fn merge_heads(x: &Tensor, seq_len: usize, hidden: usize) -> Tensor { for h in 0..num_heads { let src_off = (h * seq_len + s) * head_dim; let dst_off = s * hidden + h * head_dim; - out[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]); + out[dst_off..dst_off + head_dim] + .copy_from_slice(&src[src_off..src_off + head_dim]); } } Tensor::from_slice(&out, &[seq_len, hidden]).to_device(device) @@ -372,7 +428,8 @@ pub fn sample_greedy(logits: &Tensor) -> u32 { let vocab_size = logits.shape()[1]; let seq_len = logits.shape()[0]; let last_row = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; - last_row.iter() + last_row + .iter() .enumerate() .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) .map(|(idx, _)| idx as u32) diff --git a/crates/xserv-model/src/gpt_oss.rs b/crates/xserv-model/src/gpt_oss.rs index 4c4378c..f6773b8 100644 --- a/crates/xserv-model/src/gpt_oss.rs +++ b/crates/xserv-model/src/gpt_oss.rs @@ -1,6 +1,6 @@ +use half::bf16; use std::collections::HashMap; use std::ffi::c_void; -use half::bf16; use xserv_kernels::*; use xserv_tensor::{Device, Tensor}; @@ -49,10 +49,10 @@ struct GptOssBlock { expert_down_bias: Tensor, // [local_experts, hidden] // FP8 quantized expert weights (Some when running FP8 W8A8) // Transposed layout [E, N, K] for cuBLASLt FP8 (Blackwell requires transA=T) - expert_gate_up_fp8: Option, // [local_experts, 2*inter, hidden] FP8E4M3 - expert_gate_up_scale: Option,// [local_experts] F32 - expert_down_fp8: Option, // [local_experts, hidden, inter] FP8E4M3 - expert_down_scale: Option, // [local_experts] F32 + expert_gate_up_fp8: Option, // [local_experts, 2*inter, hidden] FP8E4M3 + expert_gate_up_scale: Option, // [local_experts] F32 + expert_down_fp8: Option, // [local_experts, hidden, inter] FP8E4M3 + expert_down_scale: Option, // [local_experts] F32 // MXFP4 W4A16 expert weights (Some when running 4-bit weight-only). // (packed [E, N, K/2] u8, scales [E, N, K/32] u8) in [E, N, K] layout. expert_gate_up_mxfp4: Option<(Tensor, Tensor)>, @@ -79,16 +79,23 @@ impl GptOss { crate::init_kernels(); let dev = Device::Cuda(device); let take = |w: &mut HashMap, name: &str| -> Tensor { - w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}")) + w.remove(name) + .unwrap_or_else(|| panic!("missing weight: {name}")) }; let repl = |t: Tensor| -> Tensor { t.to_device(dev) }; // column-parallel: shard rows of [out, in], transpose → [in, out/world] let col = |t: Tensor| -> Tensor { - shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() + shard_rows(&t, rank, world) + .to_device(dev) + .transpose(0, 1) + .contiguous() }; // row-parallel: shard cols of [out, in], transpose → [in/world, out] let row = |t: Tensor| -> Tensor { - shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() + shard_cols(&t, rank, world) + .to_device(dev) + .transpose(0, 1) + .contiguous() }; // Bias sharding helpers let col_bias = |t: Tensor| -> Tensor { shard_1d(&t, rank, world).to_device(dev) }; @@ -97,7 +104,9 @@ impl GptOss { let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight")); let norm = repl(take(&mut w, "model.norm.weight")); let norm_bias = w.remove("model.norm.bias").map(|t| repl(t)); - let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous(); + let lm_head_t = repl(take(&mut w, "lm_head.weight")) + .transpose(0, 1) + .contiguous(); let head_dim = config.head_dim(); let rope_theta = config.rope_theta.unwrap_or(150000.0); @@ -176,15 +185,30 @@ impl GptOss { // MXFP4 stores 4-bit weights in an FP8E4M3 byte container (same dtype // as FP8), so distinguish by the scale rank: FP8 scale is 1-D [E], // MXFP4 scale is 3-D [E, N, K/32]. - let is_mxfp4 = gate_up_scale.as_ref().map(|s| s.ndim() == 3).unwrap_or(false); + let is_mxfp4 = gate_up_scale + .as_ref() + .map(|s| s.ndim() == 3) + .unwrap_or(false); let is_fp8 = !is_mxfp4 && gate_up_3d.dtype() == xserv_tensor::DType::FP8E4M3; let mut expert_gate_up_mxfp4: Option<(Tensor, Tensor)> = None; let mut expert_down_mxfp4: Option<(Tensor, Tensor)> = None; - let inter2 = if is_mxfp4 { gate_up_3d.shape()[1] } else { gate_up_3d.shape()[2] }; // 2*inter (N) - let hidden = if is_mxfp4 { gate_up_3d.shape()[2] * 2 } else { gate_up_3d.shape()[1] }; - let inter = if is_mxfp4 { down_3d.shape()[2] * 2 } else { down_3d.shape()[1] }; + let inter2 = if is_mxfp4 { + gate_up_3d.shape()[1] + } else { + gate_up_3d.shape()[2] + }; // 2*inter (N) + let hidden = if is_mxfp4 { + gate_up_3d.shape()[2] * 2 + } else { + gate_up_3d.shape()[1] + }; + let inter = if is_mxfp4 { + down_3d.shape()[2] * 2 + } else { + down_3d.shape()[1] + }; // Slice the rank's range of experts as contiguous 3D tensors on GPU let expert_gate_up_wt; @@ -199,10 +223,38 @@ impl GptOss { // + scales [E, N, K/32]. Slice this rank's experts (raw bytes). let gu_s = gate_up_scale.expect("MXFP4 model missing gate_up_proj_scale"); let d_s = down_scale.expect("MXFP4 model missing down_proj_scale"); - let gu_packed = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, inter2, hidden / 2).to_device(dev); - let gu_scl = slice_expert_range_3d_raw(&gu_s, expert_start, local_experts, inter2, hidden / 32).to_device(dev); - let dn_packed = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, hidden, inter / 2).to_device(dev); - let dn_scl = slice_expert_range_3d_raw(&d_s, expert_start, local_experts, hidden, inter / 32).to_device(dev); + let gu_packed = slice_expert_range_3d_raw( + &gate_up_3d, + expert_start, + local_experts, + inter2, + hidden / 2, + ) + .to_device(dev); + let gu_scl = slice_expert_range_3d_raw( + &gu_s, + expert_start, + local_experts, + inter2, + hidden / 32, + ) + .to_device(dev); + let dn_packed = slice_expert_range_3d_raw( + &down_3d, + expert_start, + local_experts, + hidden, + inter / 2, + ) + .to_device(dev); + let dn_scl = slice_expert_range_3d_raw( + &d_s, + expert_start, + local_experts, + hidden, + inter / 32, + ) + .to_device(dev); expert_gate_up_mxfp4 = Some((gu_packed, gu_scl)); expert_down_mxfp4 = Some((dn_packed, dn_scl)); expert_gate_up_fp8 = None; @@ -214,36 +266,65 @@ impl GptOss { } else if is_fp8 { // FP8 W8A8 path: load and TRANSPOSE weights for cuBLASLt (requires transA=T on Blackwell). // Original: [E, K, N] → Transposed: [E, N, K] - let gu_sliced = slice_expert_range_3d_raw(&gate_up_3d, expert_start, local_experts, hidden, inter2); - let dn_sliced = slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden); - expert_gate_up_fp8 = Some(transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2).to_device(dev)); - expert_down_fp8 = Some(transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev)); + let gu_sliced = slice_expert_range_3d_raw( + &gate_up_3d, + expert_start, + local_experts, + hidden, + inter2, + ); + let dn_sliced = + slice_expert_range_3d_raw(&down_3d, expert_start, local_experts, inter, hidden); + expert_gate_up_fp8 = Some( + transpose_3d_inner_raw(&gu_sliced, local_experts, hidden, inter2) + .to_device(dev), + ); + expert_down_fp8 = Some( + transpose_3d_inner_raw(&dn_sliced, local_experts, inter, hidden).to_device(dev), + ); // Scales: [num_experts] F32 → slice to [local_experts] let gu_s = gate_up_scale.expect("FP8 model missing gate_up_proj_scale"); let d_s = down_scale.expect("FP8 model missing down_proj_scale"); - expert_gate_up_scale_gpu = Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev)); - expert_down_scale_gpu = Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev)); + expert_gate_up_scale_gpu = + Some(slice_scale_range(&gu_s, expert_start, local_experts).to_device(dev)); + expert_down_scale_gpu = + Some(slice_scale_range(&d_s, expert_start, local_experts).to_device(dev)); // Dummy BF16 tensors (never read in FP8 path) expert_gate_up_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev); expert_down_wt = Tensor::empty(&[1, 1, 1], xserv_tensor::DType::BF16, dev); } else { // BF16 path: existing behavior - expert_gate_up_wt = slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2).to_device(dev); - expert_down_wt = slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden).to_device(dev); + expert_gate_up_wt = + slice_expert_range_3d(&gate_up_3d, expert_start, local_experts, hidden, inter2) + .to_device(dev); + expert_down_wt = + slice_expert_range_3d(&down_3d, expert_start, local_experts, inter, hidden) + .to_device(dev); expert_gate_up_fp8 = None; expert_gate_up_scale_gpu = None; expert_down_fp8 = None; expert_down_scale_gpu = None; } - let expert_gate_up_bias = slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2).to_device(dev); - let expert_down_bias = slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden).to_device(dev); + let expert_gate_up_bias = + slice_expert_range_2d(&gate_up_bias_2d, expert_start, local_experts, inter2) + .to_device(dev); + let expert_down_bias = + slice_expert_range_2d(&down_bias_2d, expert_start, local_experts, hidden) + .to_device(dev); xserv_cuda::allocator::cached_trim(); let input_norm = repl(take(&mut w, &format!("{p}.input_layernorm.weight"))); - let input_norm_bias = w.remove(&format!("{p}.input_layernorm.bias")).map(|t| repl(t)); - let post_norm = repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))); - let post_norm_bias = w.remove(&format!("{p}.post_attention_layernorm.bias")).map(|t| repl(t)); + let input_norm_bias = w + .remove(&format!("{p}.input_layernorm.bias")) + .map(|t| repl(t)); + let post_norm = repl(take( + &mut w, + &format!("{p}.post_attention_layernorm.weight"), + )); + let post_norm_bias = w + .remove(&format!("{p}.post_attention_layernorm.bias")) + .map(|t| repl(t)); layers.push(GptOssBlock { input_norm, @@ -283,17 +364,27 @@ impl GptOss { let local_num_kv_heads = config.num_kv_heads() / world; let has_norm_bias = norm_bias.is_some(); - let is_fp8 = layers.first().map(|l| l.expert_gate_up_fp8.is_some()).unwrap_or(false); - let is_mxfp4 = layers.first().map(|l| l.expert_gate_up_mxfp4.is_some()).unwrap_or(false); + let is_fp8 = layers + .first() + .map(|l| l.expert_gate_up_fp8.is_some()) + .unwrap_or(false); + let is_mxfp4 = layers + .first() + .map(|l| l.expert_gate_up_mxfp4.is_some()) + .unwrap_or(false); if rank == 0 { if has_norm_bias { eprintln!("gpt-oss: detected LayerNorm bias — using LayerNorm instead of RMSNorm"); } if is_fp8 { - eprintln!("gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)"); + eprintln!( + "gpt-oss: FP8 E4M3 quantized expert weights detected (W8A8 cuBLASLt mode)" + ); } if is_mxfp4 { - eprintln!("gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)"); + eprintln!( + "gpt-oss: MXFP4 quantized expert weights detected (W4A16 fused-GEMV mode)" + ); } } @@ -341,7 +432,13 @@ impl GptOss { } #[inline] - fn add_norm(x: &Tensor, residual: &Tensor, weight: &Tensor, bias: &Option, eps: f32) -> (Tensor, Tensor) { + fn add_norm( + x: &Tensor, + residual: &Tensor, + weight: &Tensor, + bias: &Option, + eps: f32, + ) -> (Tensor, Tensor) { match bias { Some(b) => { let sum = xserv_kernels::add(x, residual); @@ -439,7 +536,6 @@ impl GptOss { let k_all = add_bias(&matmul_2d(&normed, &layer.k_proj_wt), &layer.k_proj_bias); let v_all = add_bias(&matmul_2d(&normed, &layer.v_proj_wt), &layer.v_proj_bias); - // Reshape for RoPE: [B, H*D] → [B, H, D] let q_3d = q_all.reshape(&[batch, num_heads, head_dim]); let k_3d = k_all.reshape(&[batch, num_kv_heads, head_dim]); @@ -460,9 +556,17 @@ impl GptOss { let sinks_ptr = layer.sinks.data_ptr() as *const c_void; let attn_out = paged_decode_attention_sinks( - &q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr, + &q_4d, + k_pool_ptr, + v_pool_ptr, + bt_ptr, + cl_ptr, sinks_ptr, - batch, num_heads, num_kv_heads, head_dim, max_blocks, + batch, + num_heads, + num_kv_heads, + head_dim, + max_blocks, layer.window_size, ); @@ -471,9 +575,14 @@ impl GptOss { self.all_reduce(&attn_proj); let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias); - // Residual + post-norm - let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps); + let (normed, x_new) = Self::add_norm( + &attn_proj, + &residual, + &layer.post_norm, + &layer.post_norm_bias, + eps, + ); let residual = x_new; let normed = normed.contiguous(); @@ -505,7 +614,9 @@ impl GptOss { paged_cache.advance_seq_len(slot, new_tokens); let mut x = embedding(&self.embed_tokens, token_ids); - let positions: Vec = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect(); + let positions: Vec = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); @@ -532,14 +643,21 @@ impl GptOss { let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx); // Flash attention with gpt-oss sinks + (per-layer) sliding window. - let attn_out = flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size); + let attn_out = + flash_attention_sinks(&q, &k_full, &v_full, &layer.sinks, layer.window_size); let attn_merged = merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); self.all_reduce(&attn_proj); let attn_proj = add_bias(&attn_proj, &layer.o_proj_bias); - let (normed, x_new) = Self::add_norm(&attn_proj, &residual, &layer.post_norm, &layer.post_norm_bias, eps); + let (normed, x_new) = Self::add_norm( + &attn_proj, + &residual, + &layer.post_norm, + &layer.post_norm_bias, + eps, + ); let residual = x_new; // MoE MLP @@ -566,15 +684,11 @@ impl GptOss { let expert_start = rank * local_experts; // 1. Router: [tokens, hidden] @ [hidden, num_experts] + bias → [tokens, num_experts] - let router_logits = add_bias( - &matmul_2d(x, &layer.router_wt), - &layer.router_bias, - ); + let router_logits = add_bias(&matmul_2d(x, &layer.router_wt), &layer.router_bias); // 2. GPU top-k + softmax - let (topk_ids, topk_weights) = xserv_kernels::moe::moe_topk_softmax( - &router_logits, num_experts, top_k, - ); + let (topk_ids, topk_weights) = + xserv_kernels::moe::moe_topk_softmax(&router_logits, num_experts, top_k); // Sparse decode path: compute ONLY the routed experts. The dense path // below reads every local expert's weights per forward; the sparse @@ -588,15 +702,31 @@ impl GptOss { let n = packed.shape()[1]; let k = packed.shape()[2] * 2; xserv_kernels::moe::moe_sparse_gemv_mxfp4( - x, packed, scales, &layer.expert_gate_up_bias, &topk_ids, - num_tokens, top_k, n, k, expert_start, local_experts, false, + x, + packed, + scales, + &layer.expert_gate_up_bias, + &topk_ids, + num_tokens, + top_k, + n, + k, + expert_start, + local_experts, + false, ) } else { xserv_kernels::moe::moe_sparse_gemv_fp8( - x, layer.expert_gate_up_fp8.as_ref().unwrap(), + x, + layer.expert_gate_up_fp8.as_ref().unwrap(), layer.expert_gate_up_scale.as_ref().unwrap(), - &layer.expert_gate_up_bias, &topk_ids, - num_tokens, top_k, expert_start, local_experts, false, + &layer.expert_gate_up_bias, + &topk_ids, + num_tokens, + top_k, + expert_start, + local_experts, + false, ) }; @@ -611,20 +741,40 @@ impl GptOss { let n = packed.shape()[1]; let k = packed.shape()[2] * 2; xserv_kernels::moe::moe_sparse_gemv_mxfp4( - &activated, packed, scales, &layer.expert_down_bias, &topk_ids, - num_tokens, top_k, n, k, expert_start, local_experts, true, + &activated, + packed, + scales, + &layer.expert_down_bias, + &topk_ids, + num_tokens, + top_k, + n, + k, + expert_start, + local_experts, + true, ) } else { xserv_kernels::moe::moe_sparse_gemv_fp8( - &activated, layer.expert_down_fp8.as_ref().unwrap(), + &activated, + layer.expert_down_fp8.as_ref().unwrap(), layer.expert_down_scale.as_ref().unwrap(), - &layer.expert_down_bias, &topk_ids, - num_tokens, top_k, expert_start, local_experts, true, + &layer.expert_down_bias, + &topk_ids, + num_tokens, + top_k, + expert_start, + local_experts, + true, ) }; let moe_out = xserv_kernels::moe::moe_weighted_sum_sparse( - &down, &topk_ids, &topk_weights, expert_start, local_experts, + &down, + &topk_ids, + &topk_weights, + expert_start, + local_experts, ); self.all_reduce(&moe_out); return moe_out; @@ -644,14 +794,24 @@ impl GptOss { xserv_kernels::quantization::batched_gemv_mxfp4(&x2, packed, scales, n, k) .reshape(&[local_experts, 1, n]) } else { - let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k); + let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t( + packed, + scales, + local_experts, + n, + k, + ); xserv_kernels::moe::batched_gemm_strided(&x_rep, &w_bf16) } } else if let Some(ref wt_fp8_t) = layer.expert_gate_up_fp8 { // W8A8: quantize activations with per-expert scalar scale, use cuBLASLt FP8 GEMM - let (x_fp8, x_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep); + let (x_fp8, x_scales) = + xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&x_rep); xserv_kernels::quantization::batched_gemm_fp8( - &x_fp8, &x_scales, wt_fp8_t, layer.expert_gate_up_scale.as_ref().unwrap(), + &x_fp8, + &x_scales, + wt_fp8_t, + layer.expert_gate_up_scale.as_ref().unwrap(), ) } else { xserv_kernels::moe::batched_gemm_strided(&x_rep, &layer.expert_gate_up_wt) @@ -677,14 +837,24 @@ impl GptOss { xserv_kernels::quantization::batched_gemv_mxfp4(&a2, packed, scales, n, k) .reshape(&[local_experts, 1, n]) } else { - let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t(packed, scales, local_experts, n, k); + let w_bf16 = xserv_kernels::quantization::dequant_mxfp4_to_bf16_t( + packed, + scales, + local_experts, + n, + k, + ); xserv_kernels::moe::batched_gemm_strided(&activated, &w_bf16) } } else if let Some(ref wt_fp8) = layer.expert_down_fp8 { // W8A8: quantize post-GLU activations to FP8, use cuBLASLt FP8 GEMM - let (act_fp8, act_scales) = xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated); + let (act_fp8, act_scales) = + xserv_kernels::quantization::quantize_bf16_to_fp8_rowwise(&activated); xserv_kernels::quantization::batched_gemm_fp8( - &act_fp8, &act_scales, wt_fp8, layer.expert_down_scale.as_ref().unwrap(), + &act_fp8, + &act_scales, + wt_fp8, + layer.expert_down_scale.as_ref().unwrap(), ) } else { xserv_kernels::moe::batched_gemm_strided(&activated, &layer.expert_down_wt) @@ -695,8 +865,12 @@ impl GptOss { // 9. Weighted sum across experts → [tokens, hidden] let moe_out = xserv_kernels::moe::moe_weighted_sum( - &down, &topk_ids, &topk_weights, - expert_start, local_experts, top_k, + &down, + &topk_ids, + &topk_weights, + expert_start, + local_experts, + top_k, ); self.all_reduce(&moe_out); @@ -708,9 +882,7 @@ impl GptOss { /// Upload a u32 slice to a pooled GPU buffer (synchronous H2D). fn upload_u32(vals: &[u32]) -> xserv_cuda::GpuBuffer { - let bytes = unsafe { - std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) - }; + let bytes = unsafe { std::slice::from_raw_parts(vals.as_ptr() as *const u8, vals.len() * 4) }; let mut buf = xserv_cuda::allocator::cached_alloc(bytes.len()).expect("alloc u32 upload"); buf.copy_from_host(bytes).unwrap(); buf @@ -737,11 +909,16 @@ fn add_bias(x: &Tensor, bias: &Tensor) -> Tensor { } fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor { - if world == 1 { return t.clone(); } + if world == 1 { + return t.clone(); + } let shape = t.shape(); assert_eq!(shape.len(), 2); let (rows, cols) = (shape[0], shape[1]); - assert!(rows % world == 0, "rows {rows} not divisible by world {world}"); + assert!( + rows % world == 0, + "rows {rows} not divisible by world {world}" + ); let local = rows / world; let host = t.to_device(Device::Cpu); let data = host.as_slice::(); @@ -751,11 +928,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor { } fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor { - if world == 1 { return t.clone(); } + if world == 1 { + return t.clone(); + } let shape = t.shape(); assert_eq!(shape.len(), 2); let (rows, cols) = (shape[0], shape[1]); - assert!(cols % world == 0, "cols {cols} not divisible by world {world}"); + assert!( + cols % world == 0, + "cols {cols} not divisible by world {world}" + ); let local = cols / world; let c0 = rank * local; let host = t.to_device(Device::Cpu); @@ -769,11 +951,16 @@ fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor { } fn shard_1d(t: &Tensor, rank: usize, world: usize) -> Tensor { - if world == 1 { return t.clone(); } + if world == 1 { + return t.clone(); + } let shape = t.shape(); assert_eq!(shape.len(), 1); let total = shape[0]; - assert!(total % world == 0, "dim {total} not divisible by world {world}"); + assert!( + total % world == 0, + "dim {total} not divisible by world {world}" + ); let local = total / world; let host = t.to_device(Device::Cpu); let data = host.as_slice::(); @@ -804,7 +991,13 @@ fn transpose_3d_inner_raw(t: &Tensor, batch: usize, rows: usize, cols: usize) -> } /// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor (any dtype, raw bytes). -fn slice_expert_range_3d_raw(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor { +fn slice_expert_range_3d_raw( + t: &Tensor, + start: usize, + count: usize, + rows: usize, + cols: usize, +) -> Tensor { assert_eq!(t.ndim(), 3); let host = t.to_device(Device::Cpu); let elem_size = t.dtype().size_bytes(); @@ -826,7 +1019,13 @@ fn slice_scale_range(t: &Tensor, start: usize, count: usize) -> Tensor { } /// Extract experts [start..start+count) from a [num_experts, rows, cols] 3D tensor -fn slice_expert_range_3d(t: &Tensor, start: usize, count: usize, rows: usize, cols: usize) -> Tensor { +fn slice_expert_range_3d( + t: &Tensor, + start: usize, + count: usize, + rows: usize, + cols: usize, +) -> Tensor { assert_eq!(t.ndim(), 3); let host = t.to_device(Device::Cpu); let data = host.as_slice::(); diff --git a/crates/xserv-model/src/gpt_oss_graph.rs b/crates/xserv-model/src/gpt_oss_graph.rs index 3785427..d1727f8 100644 --- a/crates/xserv-model/src/gpt_oss_graph.rs +++ b/crates/xserv-model/src/gpt_oss_graph.rs @@ -59,7 +59,9 @@ impl GptOssDecodeGraph { model.decode_prepare(&[position], &[slot], cache); ids_buf.copy_from_host(&token.to_le_bytes()).unwrap(); - pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap(); + pos_buf + .copy_from_host(&(position as u32).to_le_bytes()) + .unwrap(); // Retained warmup: run the exact step once eagerly with the quarantine // ON. Freed intermediates are held back instead of recycled, so the @@ -88,21 +90,32 @@ impl GptOssDecodeGraph { let logits; { let _guard = xserv_cuda::stream::push_stream(&stream); - graph.begin_capture(&stream).expect("begin decode-graph capture"); + graph + .begin_capture(&stream) + .expect("begin decode-graph capture"); logits = model.decode_core( ids_buf.as_ptr() as *const c_void, pos_buf.as_ptr() as *const c_void, 1, cache, ); - graph.end_capture(&stream).expect("end decode-graph capture"); + graph + .end_capture(&stream) + .expect("end decode-graph capture"); } let arena = allocator::end_retain(); graph.launch(&stream).expect("first decode-graph replay"); cache.advance_seq_len(slot, 1); - Self { stream, graph, ids_buf, pos_buf, logits, _arena: arena } + Self { + stream, + graph, + ids_buf, + pos_buf, + logits, + _arena: arena, + } } /// Run one decode step by replaying the captured graph. @@ -116,8 +129,12 @@ impl GptOssDecodeGraph { ) -> Tensor { model.decode_prepare(&[position], &[slot], cache); self.ids_buf.copy_from_host(&token.to_le_bytes()).unwrap(); - self.pos_buf.copy_from_host(&(position as u32).to_le_bytes()).unwrap(); - self.graph.launch(&self.stream).expect("decode-graph replay"); + self.pos_buf + .copy_from_host(&(position as u32).to_le_bytes()) + .unwrap(); + self.graph + .launch(&self.stream) + .expect("decode-graph replay"); cache.advance_seq_len(slot, 1); // Shallow clone: the caller reads these logits before the next replay // rewrites the underlying buffer. @@ -137,8 +154,14 @@ pub struct GraphedGptOssDecoder { impl GraphedGptOssDecoder { pub fn new() -> Self { - let enabled = std::env::var("XSERV_DECODE_GRAPH").map(|v| v != "0").unwrap_or(true); - Self { graph: None, eager_steps: 0, enabled } + let enabled = std::env::var("XSERV_DECODE_GRAPH") + .map(|v| v != "0") + .unwrap_or(true); + Self { + graph: None, + eager_steps: 0, + enabled, + } } pub fn decode( diff --git a/crates/xserv-model/src/kv_cache.rs b/crates/xserv-model/src/kv_cache.rs index 5daf6cd..6c1d6e5 100644 --- a/crates/xserv-model/src/kv_cache.rs +++ b/crates/xserv-model/src/kv_cache.rs @@ -1,6 +1,6 @@ +use crate::config::ModelConfig; use xserv_cuda::GpuBuffer; use xserv_tensor::{DType, Tensor}; -use crate::config::ModelConfig; /// GPU-resident KV cache. Pre-allocates max_seq_len on GPU, /// appends new K/V via D2D copy at offset (no CPU round-trip). @@ -46,17 +46,43 @@ impl GpuKVCache { v_staging.push(GpuBuffer::alloc(buf_size).expect("alloc KV staging V")); } - Self { k_bufs, v_bufs, k_staging, v_staging, seq_len: 0, max_seq_len, num_kv_heads, head_dim, elem_size, dtype, device } + Self { + k_bufs, + v_bufs, + k_staging, + v_staging, + seq_len: 0, + max_seq_len, + num_kv_heads, + head_dim, + elem_size, + dtype, + device, + } } - pub fn seq_len(&self) -> usize { self.seq_len } - pub fn max_seq_len(&self) -> usize { self.max_seq_len } + pub fn seq_len(&self) -> usize { + self.seq_len + } + pub fn max_seq_len(&self) -> usize { + self.max_seq_len + } /// Append new K/V tensors for a given layer. /// k_new, v_new: [1, num_kv_heads, new_tokens, head_dim] on GPU, contiguous. /// `write_pos` is the sequence position to write at (caller manages this). - pub fn append(&mut self, layer: usize, k_new: &Tensor, v_new: &Tensor, new_tokens: usize, write_pos: usize) { - assert!(write_pos + new_tokens <= self.max_seq_len, "KV cache overflow"); + pub fn append( + &mut self, + layer: usize, + k_new: &Tensor, + v_new: &Tensor, + new_tokens: usize, + write_pos: usize, + ) { + assert!( + write_pos + new_tokens <= self.max_seq_len, + "KV cache overflow" + ); let es = self.elem_size; let hd = self.head_dim; let max_s = self.max_seq_len; @@ -69,14 +95,23 @@ impl GpuKVCache { let src_off = h * new_tokens * hd * es; let dst_off = (h * max_s + write_pos) * hd * es; let count = new_tokens * hd * es; - self.k_bufs[layer].copy_from_device_at(k_src, src_off, dst_off, count).unwrap(); - self.v_bufs[layer].copy_from_device_at(v_src, src_off, dst_off, count).unwrap(); + self.k_bufs[layer] + .copy_from_device_at(k_src, src_off, dst_off, count) + .unwrap(); + self.v_bufs[layer] + .copy_from_device_at(v_src, src_off, dst_off, count) + .unwrap(); } } pub fn advance_seq_len(&mut self, new_tokens: usize) { self.seq_len += new_tokens; - assert!(self.seq_len <= self.max_seq_len, "KV cache seq_len ({}) exceeds max_seq_len ({})", self.seq_len, self.max_seq_len); + assert!( + self.seq_len <= self.max_seq_len, + "KV cache seq_len ({}) exceeds max_seq_len ({})", + self.seq_len, + self.max_seq_len + ); } /// Get K/V cache tensors for a layer up to `seq_len` tokens: [1, num_kv_heads, seq_len, head_dim] @@ -86,7 +121,11 @@ impl GpuKVCache { } pub fn get_kv_len(&mut self, layer: usize, sl: usize) -> (Tensor, Tensor) { - assert!(sl <= self.max_seq_len, "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", self.max_seq_len); + assert!( + sl <= self.max_seq_len, + "get_kv_len: sl ({sl}) exceeds max_seq_len ({})", + self.max_seq_len + ); let hd = self.head_dim; let nh = self.num_kv_heads; let es = self.elem_size; @@ -104,8 +143,12 @@ impl GpuKVCache { let src_off = (h * max_s) * hd * es; let dst_off = (h * sl) * hd * es; let count = sl * hd * es; - k_stg.copy_from_device_at(k_buf, src_off, dst_off, count).unwrap(); - v_stg.copy_from_device_at(v_buf, src_off, dst_off, count).unwrap(); + k_stg + .copy_from_device_at(k_buf, src_off, dst_off, count) + .unwrap(); + v_stg + .copy_from_device_at(v_buf, src_off, dst_off, count) + .unwrap(); } // Grab raw pointers before dropping the mutable borrows let k_ptr = k_stg.as_mut_ptr(); @@ -117,20 +160,35 @@ impl GpuKVCache { // get_kv_len call overwrites the staging buffer). let shape = &[1usize, nh, sl, hd]; let k = unsafe { - tensor_from_gpu_buffer(GpuBuffer::borrow_raw(k_ptr, out_size), shape, self.dtype, self.device) + tensor_from_gpu_buffer( + GpuBuffer::borrow_raw(k_ptr, out_size), + shape, + self.dtype, + self.device, + ) }; let v = unsafe { - tensor_from_gpu_buffer(GpuBuffer::borrow_raw(v_ptr, out_size), shape, self.dtype, self.device) + tensor_from_gpu_buffer( + GpuBuffer::borrow_raw(v_ptr, out_size), + shape, + self.dtype, + self.device, + ) }; (k, v) } } /// Create a Tensor from a GpuBuffer (takes ownership). -unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor { - use xserv_tensor::storage::Storage; - use xserv_tensor::shape::contiguous_strides; +unsafe fn tensor_from_gpu_buffer( + buf: GpuBuffer, + shape: &[usize], + dtype: DType, + device: u32, +) -> Tensor { use smallvec::SmallVec; + use xserv_tensor::shape::contiguous_strides; + use xserv_tensor::storage::Storage; let storage = Storage::cuda(buf, device); Tensor::from_storage( @@ -146,6 +204,11 @@ unsafe fn tensor_from_gpu_buffer(buf: GpuBuffer, shape: &[usize], dtype: DType, /// /// # Safety /// `buf` must be a valid GPU allocation with at least `product(shape) * dtype.size_bytes()` bytes. -pub unsafe fn tensor_from_gpu_buffer_pub(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor { +pub unsafe fn tensor_from_gpu_buffer_pub( + buf: GpuBuffer, + shape: &[usize], + dtype: DType, + device: u32, +) -> Tensor { tensor_from_gpu_buffer(buf, shape, dtype, device) } diff --git a/crates/xserv-model/src/lib.rs b/crates/xserv-model/src/lib.rs index e97a80e..922aa33 100644 --- a/crates/xserv-model/src/lib.rs +++ b/crates/xserv-model/src/lib.rs @@ -11,11 +11,11 @@ pub mod sampling; pub use config::ModelConfig; pub use decode_graph::{DecodeGraphState, LayerWeightPtrs}; -pub use gpt2::{GPT2, KVCache}; pub use gpt_oss::GptOss; pub use gpt_oss_graph::{GptOssDecodeGraph, GraphedGptOssDecoder}; +pub use gpt2::{GPT2, KVCache}; pub use kv_cache::GpuKVCache; -pub use paged_kv_cache::{BlockAllocator, Location, PagedKVCache, BLOCK_SIZE}; +pub use paged_kv_cache::{BLOCK_SIZE, BlockAllocator, Location, PagedKVCache}; pub use qwen3::Qwen3; pub use sampling::{SamplingParams, sample, sample_greedy_penalized}; diff --git a/crates/xserv-model/src/loader.rs b/crates/xserv-model/src/loader.rs index 77b0096..dd2c5e5 100644 --- a/crates/xserv-model/src/loader.rs +++ b/crates/xserv-model/src/loader.rs @@ -5,8 +5,8 @@ use std::path::Path; use xserv_tensor::{DType, Device, Tensor}; pub fn load_safetensors(path: &Path, device: Device) -> HashMap { - let data = std::fs::read(path) - .unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display())); + let data = + std::fs::read(path).unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display())); let st = SafeTensors::deserialize(&data) .unwrap_or_else(|e| panic!("failed to parse safetensors {}: {e}", path.display())); @@ -60,7 +60,11 @@ pub fn load_model_dir(dir: &Path, device: Device) -> HashMap { all_tensors.extend(tensors); } - assert!(!all_tensors.is_empty(), "no safetensors files found in {}", dir.display()); + assert!( + !all_tensors.is_empty(), + "no safetensors files found in {}", + dir.display() + ); all_tensors } @@ -84,8 +88,6 @@ fn make_tensor(raw_bytes: &[u8], shape: &[usize], dtype: DType) -> Tensor { }; Tensor::from_slice(bfs, shape) } - DType::FP8E4M3 => { - Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3) - } + DType::FP8E4M3 => Tensor::from_raw_bytes(raw_bytes, shape, DType::FP8E4M3), } } diff --git a/crates/xserv-model/src/paged_kv_cache.rs b/crates/xserv-model/src/paged_kv_cache.rs index 9b34dbb..50d6256 100644 --- a/crates/xserv-model/src/paged_kv_cache.rs +++ b/crates/xserv-model/src/paged_kv_cache.rs @@ -29,7 +29,10 @@ impl BlockAllocator { for b in (1..total_blocks).rev() { free_stack.push(b as u32); } - Self { free_stack, total: total_blocks } + Self { + free_stack, + total: total_blocks, + } } pub fn alloc(&mut self) -> Option { @@ -136,8 +139,14 @@ impl PagedKVCache { device: u32, ) -> Self { Self::new_tp( - config, config.num_kv_heads(), total_blocks, cpu_total_blocks, - max_seqs, max_blocks_per_seq, dtype, device, + config, + config.num_kv_heads(), + total_blocks, + cpu_total_blocks, + max_seqs, + max_blocks_per_seq, + dtype, + device, ) } @@ -155,7 +164,10 @@ impl PagedKVCache { dtype: DType, device: u32, ) -> Self { - assert!(total_blocks >= 2, "need at least 2 blocks (one is sentinel)"); + assert!( + total_blocks >= 2, + "need at least 2 blocks (one is sentinel)" + ); let num_layers = config.num_layers(); let head_dim = config.head_dim(); let elem_size = dtype.size_bytes(); @@ -179,11 +191,17 @@ impl PagedKVCache { if cpu_total_blocks >= 2 { let cpu_pool_bytes = cpu_total_blocks * block_bytes; for _ in 0..num_layers { - cpu_k_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool")); - cpu_v_pools.push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool")); + cpu_k_pools + .push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU K swap pool")); + cpu_v_pools + .push(PinnedBuffer::alloc(cpu_pool_bytes).expect("alloc CPU V swap pool")); } } - let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { cpu_total_blocks } else { 0 }); + let cpu_allocator = BlockAllocator::new(if cpu_total_blocks >= 2 { + cpu_total_blocks + } else { + 0 + }); let block_table_gpu = GpuBuffer::alloc(max_seqs * max_blocks_per_seq * std::mem::size_of::()) @@ -220,22 +238,49 @@ impl PagedKVCache { } } - pub fn num_layers(&self) -> usize { self.num_layers } - pub fn num_kv_heads(&self) -> usize { self.num_kv_heads } - pub fn head_dim(&self) -> usize { self.head_dim } - pub fn dtype(&self) -> DType { self.dtype } - pub fn max_seqs(&self) -> usize { self.max_seqs } - pub fn max_blocks_per_seq(&self) -> usize { self.max_blocks_per_seq } - pub fn free_blocks(&self) -> usize { self.allocator.free_count() } - pub fn total_blocks(&self) -> usize { self.allocator.total() } + pub fn num_layers(&self) -> usize { + self.num_layers + } + pub fn num_kv_heads(&self) -> usize { + self.num_kv_heads + } + pub fn head_dim(&self) -> usize { + self.head_dim + } + pub fn dtype(&self) -> DType { + self.dtype + } + pub fn max_seqs(&self) -> usize { + self.max_seqs + } + pub fn max_blocks_per_seq(&self) -> usize { + self.max_blocks_per_seq + } + pub fn free_blocks(&self) -> usize { + self.allocator.free_count() + } + pub fn total_blocks(&self) -> usize { + self.allocator.total() + } - pub fn k_pool(&self, layer: usize) -> &GpuBuffer { &self.k_pools[layer] } - pub fn v_pool(&self, layer: usize) -> &GpuBuffer { &self.v_pools[layer] } - pub fn block_table_gpu(&self) -> &GpuBuffer { &self.block_table_gpu } - pub fn context_lens_gpu(&self) -> &GpuBuffer { &self.context_lens_gpu } + pub fn k_pool(&self, layer: usize) -> &GpuBuffer { + &self.k_pools[layer] + } + pub fn v_pool(&self, layer: usize) -> &GpuBuffer { + &self.v_pools[layer] + } + pub fn block_table_gpu(&self) -> &GpuBuffer { + &self.block_table_gpu + } + pub fn context_lens_gpu(&self) -> &GpuBuffer { + &self.context_lens_gpu + } pub fn seq_len(&self, slot: usize) -> usize { - self.seq_states[slot].as_ref().map(|s| s.seq_len).unwrap_or(0) + self.seq_states[slot] + .as_ref() + .map(|s| s.seq_len) + .unwrap_or(0) } pub fn is_slot_free(&self, slot: usize) -> bool { @@ -280,7 +325,11 @@ impl PagedKVCache { let state = self.seq_states[slot].as_ref().expect("unregistered slot"); let cur = state.block_ids.len(); let needed_total = (state.seq_len + new_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE; - if needed_total > cur { needed_total - cur } else { 0 } + if needed_total > cur { + needed_total - cur + } else { + 0 + } } /// Pre-allocate enough physical blocks in `slot` to cover positions @@ -290,8 +339,14 @@ impl PagedKVCache { let state = self.seq_states[slot].as_mut().expect("unregistered slot"); let needed_total = (end_pos + BLOCK_SIZE - 1) / BLOCK_SIZE; while state.block_ids.len() < needed_total { - let b = self.allocator.alloc().expect("out of blocks (caller must check)"); - assert!(state.block_ids.len() < self.max_blocks_per_seq, "block table overflow"); + let b = self + .allocator + .alloc() + .expect("out of blocks (caller must check)"); + assert!( + state.block_ids.len() < self.max_blocks_per_seq, + "block table overflow" + ); state.block_ids.push(b); } } @@ -318,7 +373,9 @@ impl PagedKVCache { num_tokens: usize, start_pos: usize, ) { - if num_tokens == 0 { return; } + if num_tokens == 0 { + return; + } // Make sure blocks exist for the target range. self.ensure_capacity(slot, start_pos + num_tokens); @@ -328,15 +385,21 @@ impl PagedKVCache { // Stage block_ids on the GPU. Pool-allocated so this is essentially // free after the first call (same bucket every step). - let block_ids: Vec = self.seq_states[slot].as_ref().unwrap() - .block_ids.iter().map(|&b| b as i32).collect(); + let block_ids: Vec = self.seq_states[slot] + .as_ref() + .unwrap() + .block_ids + .iter() + .map(|&b| b as i32) + .collect(); let bytes = block_ids.len() * std::mem::size_of::(); - let mut block_ids_gpu = xserv_cuda::allocator::cached_alloc(bytes) - .expect("alloc append block_ids"); - let block_ids_bytes = unsafe { - std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) - }; - block_ids_gpu.copy_from_host(block_ids_bytes).expect("upload block_ids"); + let mut block_ids_gpu = + xserv_cuda::allocator::cached_alloc(bytes).expect("alloc append block_ids"); + let block_ids_bytes = + unsafe { std::slice::from_raw_parts(block_ids.as_ptr() as *const u8, bytes) }; + block_ids_gpu + .copy_from_host(block_ids_bytes) + .expect("upload block_ids"); let k_src = k_new.data_ptr() as *const std::ffi::c_void; let v_src = v_new.data_ptr() as *const std::ffi::c_void; @@ -345,10 +408,16 @@ impl PagedKVCache { unsafe { xserv_kernels::reshape_and_cache_bf16( - k_src, v_src, - k_pool_ptr, v_pool_ptr, + k_src, + v_src, + k_pool_ptr, + v_pool_ptr, block_ids_gpu.as_ptr() as *const i32, - num_tokens, nkv, hd, start_pos, bs, + num_tokens, + nkv, + hd, + start_pos, + bs, xserv_cuda::current_stream_raw(), ); } @@ -378,7 +447,9 @@ impl PagedKVCache { v_new: &Tensor, batch: usize, ) { - if batch == 0 { return; } + if batch == 0 { + return; + } let nkv = self.num_kv_heads; let hd = self.head_dim; debug_assert_eq!(k_new.shape(), &[batch, nkv, hd]); @@ -393,10 +464,17 @@ impl PagedKVCache { unsafe { xserv_kernels::reshape_and_cache_batched_bf16( - k_src, v_src, - k_pool_ptr, v_pool_ptr, - bt_ptr, cl_ptr, - batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq, + k_src, + v_src, + k_pool_ptr, + v_pool_ptr, + bt_ptr, + cl_ptr, + batch, + nkv, + hd, + BLOCK_SIZE, + self.max_blocks_per_seq, xserv_cuda::current_stream_raw(), ); } @@ -447,7 +525,10 @@ impl PagedKVCache { /// before advance_seq_len has run). pub fn sync_active_batch_with_lens(&mut self, slots: &[usize], kv_lens: &[i32]) { assert_eq!(slots.len(), kv_lens.len()); - assert!(slots.len() <= self.max_seqs, "active batch exceeds max_seqs"); + assert!( + slots.len() <= self.max_seqs, + "active batch exceeds max_seqs" + ); let stride = self.max_blocks_per_seq; for row in &mut self.block_table_host { *row = 0; @@ -456,7 +537,9 @@ impl PagedKVCache { *cl = 0; } for (i, &slot) in slots.iter().enumerate() { - let s = self.seq_states[slot].as_ref().expect("unregistered slot in active batch"); + let s = self.seq_states[slot] + .as_ref() + .expect("unregistered slot in active batch"); let row = &mut self.block_table_host[i * stride..(i + 1) * stride]; for (j, b) in s.block_ids.iter().enumerate() { row[j] = *b as i32; @@ -515,8 +598,12 @@ impl PagedKVCache { let src_off = ((phys * nkv + h) * bs + slot_in_blk) * hd * es; let dst_off = (h * sl + p) * hd * es; let count = chunk * hd * es; - k_dst.copy_from_device_at(k_pool, src_off, dst_off, count).unwrap(); - v_dst.copy_from_device_at(v_pool, src_off, dst_off, count).unwrap(); + k_dst + .copy_from_device_at(k_pool, src_off, dst_off, count) + .unwrap(); + v_dst + .copy_from_device_at(v_pool, src_off, dst_off, count) + .unwrap(); } p += chunk; } @@ -529,16 +616,26 @@ impl PagedKVCache { // ----- Swapping (vLLM-style preemption to pinned host memory) ----- - pub fn free_cpu_blocks(&self) -> usize { self.cpu_allocator.free_count() } - pub fn swap_enabled(&self) -> bool { !self.cpu_k_pools.is_empty() } + pub fn free_cpu_blocks(&self) -> usize { + self.cpu_allocator.free_count() + } + pub fn swap_enabled(&self) -> bool { + !self.cpu_k_pools.is_empty() + } pub fn is_swapped(&self, slot: usize) -> bool { - matches!(self.seq_states[slot].as_ref().map(|s| s.location), Some(Location::Cpu)) + matches!( + self.seq_states[slot].as_ref().map(|s| s.location), + Some(Location::Cpu) + ) } /// Number of physical blocks currently held by `slot` (in either pool). pub fn block_count(&self, slot: usize) -> usize { - self.seq_states[slot].as_ref().map(|s| s.block_ids.len()).unwrap_or(0) + self.seq_states[slot] + .as_ref() + .map(|s| s.block_ids.len()) + .unwrap_or(0) } /// Whether a swapped sequence at `slot` can be brought back (enough free GPU blocks). @@ -554,11 +651,17 @@ impl PagedKVCache { /// Evict `slot`'s KV from GPU to pinned host memory and free its GPU blocks. /// The slot stays registered (location = Cpu); the sequence is paused. pub fn swap_out(&mut self, slot: usize) -> Result<(), &'static str> { - let state = self.seq_states[slot].as_ref().ok_or("swap_out: empty slot")?; - if state.location == Location::Cpu { return Ok(()); } + let state = self.seq_states[slot] + .as_ref() + .ok_or("swap_out: empty slot")?; + if state.location == Location::Cpu { + return Ok(()); + } let gpu_ids = state.block_ids.clone(); let n = gpu_ids.len(); - if !self.cpu_allocator.can_alloc(n) { return Err("swap_out: CPU pool full"); } + if !self.cpu_allocator.can_alloc(n) { + return Err("swap_out: CPU pool full"); + } let cpu_ids: Vec = (0..n) .map(|_| self.cpu_allocator.alloc().expect("checked can_alloc")) @@ -570,10 +673,18 @@ impl PagedKVCache { let g_off = gpu_ids[i] as usize * bb; let c_off = cpu_ids[i] as usize * bb; self.k_pools[layer] - .copy_to_host_at(&mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb) + .copy_to_host_at( + &mut self.cpu_k_pools[layer].as_mut_slice()[c_off..c_off + bb], + g_off, + bb, + ) .unwrap(); self.v_pools[layer] - .copy_to_host_at(&mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], g_off, bb) + .copy_to_host_at( + &mut self.cpu_v_pools[layer].as_mut_slice()[c_off..c_off + bb], + g_off, + bb, + ) .unwrap(); } } @@ -589,11 +700,17 @@ impl PagedKVCache { /// Bring `slot`'s KV back from host to GPU and free its CPU blocks. pub fn swap_in(&mut self, slot: usize) -> Result<(), &'static str> { - let state = self.seq_states[slot].as_ref().ok_or("swap_in: empty slot")?; - if state.location == Location::Gpu { return Ok(()); } + let state = self.seq_states[slot] + .as_ref() + .ok_or("swap_in: empty slot")?; + if state.location == Location::Gpu { + return Ok(()); + } let cpu_ids = state.block_ids.clone(); let n = cpu_ids.len(); - if !self.allocator.can_alloc(n) { return Err("swap_in: GPU pool full"); } + if !self.allocator.can_alloc(n) { + return Err("swap_in: GPU pool full"); + } let gpu_ids: Vec = (0..n) .map(|_| self.allocator.alloc().expect("checked can_alloc")) @@ -605,10 +722,18 @@ impl PagedKVCache { let g_off = gpu_ids[i] as usize * bb; let c_off = cpu_ids[i] as usize * bb; self.k_pools[layer] - .copy_from_host_at(&self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb) + .copy_from_host_at( + &self.cpu_k_pools[layer].as_slice()[c_off..c_off + bb], + g_off, + bb, + ) .unwrap(); self.v_pools[layer] - .copy_from_host_at(&self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], g_off, bb) + .copy_from_host_at( + &self.cpu_v_pools[layer].as_slice()[c_off..c_off + bb], + g_off, + bb, + ) .unwrap(); } } @@ -623,7 +748,12 @@ impl PagedKVCache { } } -unsafe fn tensor_from_owned_buf(buf: GpuBuffer, shape: &[usize], dtype: DType, device: u32) -> Tensor { +unsafe fn tensor_from_owned_buf( + buf: GpuBuffer, + shape: &[usize], + dtype: DType, + device: u32, +) -> Tensor { use smallvec::SmallVec; use xserv_tensor::shape::contiguous_strides; use xserv_tensor::storage::Storage; diff --git a/crates/xserv-model/src/qwen3.rs b/crates/xserv-model/src/qwen3.rs index 30998ef..2eef2a2 100644 --- a/crates/xserv-model/src/qwen3.rs +++ b/crates/xserv-model/src/qwen3.rs @@ -1,5 +1,5 @@ -use std::collections::HashMap; use half::bf16; +use std::collections::HashMap; use xserv_kernels::*; use xserv_tensor::{Device, Tensor}; @@ -13,7 +13,7 @@ pub struct Qwen3 { embed_tokens: Tensor, layers: Vec, norm: Tensor, - lm_head_t: Tensor, // precomputed transpose + lm_head_t: Tensor, // precomputed transpose rope_cache: RopeCache, // Tensor parallelism. `tp` is None (or world==1) for single-GPU; otherwise // this rank holds 1/world of the heads and AllReduces after o_proj/down_proj. @@ -28,22 +28,29 @@ pub struct Qwen3 { } struct Qwen3Block { - input_norm: Tensor, // [hidden] + input_norm: Tensor, // [hidden] qkv_proj_wt: Tensor, // FUSED: [hidden, (H+2*KV)*D] — Q|K|V columns q_dim: usize, // num_heads * head_dim (Q slice boundary) kv_dim: usize, // num_kv_heads * head_dim (K/V slice size) - o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden] - q_norm: Tensor, // [head_dim] - k_norm: Tensor, // [head_dim] - post_norm: Tensor, // [hidden] - gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate] - down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden] + o_proj_wt: Tensor, // TRANSPOSED: [num_heads*head_dim, hidden] + q_norm: Tensor, // [head_dim] + k_norm: Tensor, // [head_dim] + post_norm: Tensor, // [hidden] + gate_up_proj_wt: Tensor, // FUSED: [hidden, 2*intermediate] + down_proj_wt: Tensor, // TRANSPOSED: [intermediate, hidden] } impl Qwen3Block { - fn q_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, 0, self.q_dim) } - fn k_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) } - fn v_proj_wt(&self) -> Tensor { self.qkv_proj_wt.narrow(1, self.q_dim + self.kv_dim, self.kv_dim) } + fn q_proj_wt(&self) -> Tensor { + self.qkv_proj_wt.narrow(1, 0, self.q_dim) + } + fn k_proj_wt(&self) -> Tensor { + self.qkv_proj_wt.narrow(1, self.q_dim, self.kv_dim) + } + fn v_proj_wt(&self) -> Tensor { + self.qkv_proj_wt + .narrow(1, self.q_dim + self.kv_dim, self.kv_dim) + } fn gate_proj_wt(&self) -> Tensor { let half = self.gate_up_proj_wt.shape()[1] / 2; self.gate_up_proj_wt.narrow(1, 0, half) @@ -80,18 +87,31 @@ impl Qwen3 { crate::init_kernels(); let dev = Device::Cuda(device); let take = |w: &mut HashMap, name: &str| -> Tensor { - w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}")) + w.remove(name) + .unwrap_or_else(|| panic!("missing weight: {name}")) }; // Replicated weight: upload whole to this rank's device. let repl = |t: Tensor| -> Tensor { t.to_device(dev) }; // column-parallel: keep this rank's rows of [out, in], upload, transpose → [in, out/world]. - let col = |t: Tensor| -> Tensor { shard_rows(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() }; + let col = |t: Tensor| -> Tensor { + shard_rows(&t, rank, world) + .to_device(dev) + .transpose(0, 1) + .contiguous() + }; // row-parallel: keep this rank's cols of [out, in], upload, transpose → [in/world, out]. - let row = |t: Tensor| -> Tensor { shard_cols(&t, rank, world).to_device(dev).transpose(0, 1).contiguous() }; + let row = |t: Tensor| -> Tensor { + shard_cols(&t, rank, world) + .to_device(dev) + .transpose(0, 1) + .contiguous() + }; let embed_tokens = repl(take(&mut w, "model.embed_tokens.weight")); let norm = repl(take(&mut w, "model.norm.weight")); - let lm_head_t = repl(take(&mut w, "lm_head.weight")).transpose(0, 1).contiguous(); + let lm_head_t = repl(take(&mut w, "lm_head.weight")) + .transpose(0, 1) + .contiguous(); let rope_cache = RopeCache::new( config.max_seq_len(), @@ -102,7 +122,10 @@ impl Qwen3 { let num_layers = config.num_layers(); let mut layers = Vec::with_capacity(num_layers); if rank == 0 { - eprintln!("Loading+sharding weights for {} layers (world={world})...", num_layers); + eprintln!( + "Loading+sharding weights for {} layers (world={world})...", + num_layers + ); } for i in 0..num_layers { let p = format!("model.layers.{i}"); @@ -126,7 +149,10 @@ impl Qwen3 { o_proj_wt: row(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))), q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))), k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))), - post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))), + post_norm: repl(take( + &mut w, + &format!("{p}.post_attention_layernorm.weight"), + )), gate_up_proj_wt, down_proj_wt: row(take(&mut w, &format!("{p}.mlp.down_proj.weight"))), }); @@ -165,7 +191,10 @@ impl Qwen3 { let dev = Device::Cuda(device); assert!(num_stages >= 1); let num_layers = config.num_layers(); - assert!(num_layers % num_stages == 0, "num_layers {num_layers} not divisible by pp {num_stages}"); + assert!( + num_layers % num_stages == 0, + "num_layers {num_layers} not divisible by pp {num_stages}" + ); let per_stage = num_layers / num_stages; let lo = stage * per_stage; let hi = lo + per_stage; @@ -173,16 +202,29 @@ impl Qwen3 { let is_last_stage = stage == num_stages - 1; let take = |w: &mut HashMap, name: &str| -> Tensor { - w.remove(name).unwrap_or_else(|| panic!("missing weight: {name}")) + w.remove(name) + .unwrap_or_else(|| panic!("missing weight: {name}")) }; let repl = |t: Tensor| -> Tensor { t.to_device(dev) }; // Pre-transpose like the TP path's `col`/`row` do for world==1 (no shard). let wt = |t: Tensor| -> Tensor { t.to_device(dev).transpose(0, 1).contiguous() }; let placeholder = || Tensor::from_slice(&[bf16::ZERO], &[1, 1]).to_device(dev); - let embed_tokens = if is_first_stage { repl(take(&mut w, "model.embed_tokens.weight")) } else { placeholder() }; - let norm = if is_last_stage { repl(take(&mut w, "model.norm.weight")) } else { placeholder() }; - let lm_head_t = if is_last_stage { wt(take(&mut w, "lm_head.weight")) } else { placeholder() }; + let embed_tokens = if is_first_stage { + repl(take(&mut w, "model.embed_tokens.weight")) + } else { + placeholder() + }; + let norm = if is_last_stage { + repl(take(&mut w, "model.norm.weight")) + } else { + placeholder() + }; + let lm_head_t = if is_last_stage { + wt(take(&mut w, "lm_head.weight")) + } else { + placeholder() + }; let rope_cache = RopeCache::new( config.max_seq_len(), @@ -217,7 +259,10 @@ impl Qwen3 { o_proj_wt: wt(take(&mut w, &format!("{p}.self_attn.o_proj.weight"))), q_norm: repl(take(&mut w, &format!("{p}.self_attn.q_norm.weight"))), k_norm: repl(take(&mut w, &format!("{p}.self_attn.k_norm.weight"))), - post_norm: repl(take(&mut w, &format!("{p}.post_attention_layernorm.weight"))), + post_norm: repl(take( + &mut w, + &format!("{p}.post_attention_layernorm.weight"), + )), gate_up_proj_wt, down_proj_wt: wt(take(&mut w, &format!("{p}.mlp.down_proj.weight"))), }); @@ -252,8 +297,12 @@ impl Qwen3 { matmul_2d(&x, &self.lm_head_t) } - pub fn pp_is_first(&self) -> bool { self.is_first_stage } - pub fn pp_is_last(&self) -> bool { self.is_last_stage } + pub fn pp_is_first(&self) -> bool { + self.is_first_stage + } + pub fn pp_is_last(&self) -> bool { + self.is_last_stage + } /// PP prefill over THIS stage's layers. `x` is `[S, hidden]` (stage 0: from /// `embed`; otherwise received from the previous stage). Writes K/V for this @@ -276,7 +325,9 @@ impl Qwen3 { paged_cache.ensure_capacity(slot, pos_offset + new_tokens); paged_cache.advance_seq_len(slot, new_tokens); - let positions: Vec = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect(); + let positions: Vec = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); @@ -285,7 +336,9 @@ impl Qwen3 { let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); let q = qkv.narrow(1, 0, layer.q_dim).contiguous(); let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous(); - let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous(); + let v = qkv + .narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim) + .contiguous(); let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim); let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim); @@ -305,10 +358,12 @@ impl Qwen3 { let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx); let attn_out = flash_attention(&q, &k_full, &v_full, true); - let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); + let attn_merged = + xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); - let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); @@ -356,7 +411,9 @@ impl Qwen3 { let qkv_all = matmul_2d(&normed, &layer.qkv_proj_wt); let q_all = qkv_all.narrow(1, 0, layer.q_dim).contiguous(); let k_all = qkv_all.narrow(1, layer.q_dim, layer.kv_dim).contiguous(); - let v_all = qkv_all.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous(); + let v_all = qkv_all + .narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim) + .contiguous(); let mut q_rows: Vec = Vec::with_capacity(batch); for b in 0..batch { @@ -394,14 +451,23 @@ impl Qwen3 { let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void; let attn_out = xserv_kernels::paged_decode_attention( - &q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr, - batch, num_heads, num_kv_heads, head_dim, max_blocks, + &q_4d, + k_pool_ptr, + v_pool_ptr, + bt_ptr, + cl_ptr, + batch, + num_heads, + num_kv_heads, + head_dim, + max_blocks, ); let attn_merged = attn_out.reshape(&[batch, num_heads * head_dim]); let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); - let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); @@ -441,7 +507,9 @@ impl Qwen3 { let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32; let mut x = embedding(&self.embed_tokens, token_ids); - let positions: Vec = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect(); + let positions: Vec = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); @@ -450,7 +518,9 @@ impl Qwen3 { let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); let q = qkv.narrow(1, 0, layer.q_dim).contiguous(); let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous(); - let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous(); + let v = qkv + .narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim) + .contiguous(); let q = reshape_heads(&q, new_tokens, num_heads, head_dim); let k = reshape_heads(&k, new_tokens, num_kv_heads, head_dim); @@ -531,7 +601,9 @@ impl Qwen3 { let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); let q_all = qkv.narrow(1, 0, layer.q_dim).contiguous(); let k_all = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous(); - let v_all = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous(); + let v_all = qkv + .narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim) + .contiguous(); // Per-sequence: reshape, qk-norm, RoPE, KV cache, attention, merge let mut attn_outputs: Vec = Vec::with_capacity(batch); @@ -583,7 +655,8 @@ impl Qwen3 { let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); // Fused add + rmsnorm - let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); @@ -662,13 +735,15 @@ impl Qwen3 { let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); // [B, (H+2*KV)*D] let q_dim = num_heads * head_dim; let kv_dim = num_kv_heads * head_dim; - let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view) - let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view) + let q_all = qkv.narrow(1, 0, q_dim); // [B, H*D] (view) + let k_all = qkv.narrow(1, q_dim, kv_dim); // [B, KV*D] (view) let v_all = qkv.narrow(1, q_dim + kv_dim, kv_dim); // Per-head RMSNorm on contiguous copies (narrow views are strided). let q_flat = q_all.contiguous().reshape(&[batch * num_heads, head_dim]); - let k_flat = k_all.contiguous().reshape(&[batch * num_kv_heads, head_dim]); + let k_flat = k_all + .contiguous() + .reshape(&[batch * num_kv_heads, head_dim]); let q_normed = rmsnorm(&q_flat, &layer.q_norm, eps); let k_normed = rmsnorm(&k_flat, &layer.k_norm, eps); @@ -688,8 +763,16 @@ impl Qwen3 { let k_pool_ptr = paged_cache.k_pool(layer_idx).as_ptr() as *const std::ffi::c_void; let v_pool_ptr = paged_cache.v_pool(layer_idx).as_ptr() as *const std::ffi::c_void; let attn_out = xserv_kernels::paged_decode_attention( - &q_4d, k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr, - batch, num_heads, num_kv_heads, head_dim, max_blocks, + &q_4d, + k_pool_ptr, + v_pool_ptr, + bt_ptr, + cl_ptr, + batch, + num_heads, + num_kv_heads, + head_dim, + max_blocks, ); // attn_out shape [B, H, 1, D] is contiguous-equivalent to [B, H*D]. @@ -697,7 +780,8 @@ impl Qwen3 { let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); self.all_reduce(&attn_proj); // TP: sum partial attention outputs - let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); // Fused gate+up projection: one GEMV instead of two. @@ -743,7 +827,9 @@ impl Qwen3 { paged_cache.advance_seq_len(slot, new_tokens); let mut x = embedding(&self.embed_tokens, token_ids); - let positions: Vec = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect(); + let positions: Vec = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); @@ -752,7 +838,9 @@ impl Qwen3 { let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); let q = qkv.narrow(1, 0, layer.q_dim).contiguous(); let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous(); - let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous(); + let v = qkv + .narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim) + .contiguous(); let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim); let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim); @@ -773,11 +861,13 @@ impl Qwen3 { let (k_full, v_full) = paged_cache.gather_kv_contiguous(slot, layer_idx); let attn_out = flash_attention(&q, &k_full, &v_full, true); - let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); + let attn_merged = + xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); self.all_reduce(&attn_proj); - let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); @@ -805,7 +895,9 @@ impl Qwen3 { let eps = self.config.rms_norm_eps.unwrap_or(1e-6) as f32; let mut x = embedding(&self.embed_tokens, token_ids); - let positions: Vec = (pos_offset..pos_offset + new_tokens).map(|p| p as u32).collect(); + let positions: Vec = (pos_offset..pos_offset + new_tokens) + .map(|p| p as u32) + .collect(); for (layer_idx, layer) in self.layers.iter().enumerate() { let residual = x.clone(); @@ -814,7 +906,9 @@ impl Qwen3 { let qkv = matmul_2d(&normed, &layer.qkv_proj_wt); let q = qkv.narrow(1, 0, layer.q_dim).contiguous(); let k = qkv.narrow(1, layer.q_dim, layer.kv_dim).contiguous(); - let v = qkv.narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim).contiguous(); + let v = qkv + .narrow(1, layer.q_dim + layer.kv_dim, layer.kv_dim) + .contiguous(); let q = xserv_kernels::reshape_heads_gpu(&q, new_tokens, num_heads, head_dim); let k = xserv_kernels::reshape_heads_gpu(&k, new_tokens, num_kv_heads, head_dim); @@ -834,10 +928,12 @@ impl Qwen3 { let (k_full, v_full) = cache.get_kv_len(layer_idx, pos_offset + new_tokens); let attn_out = flash_attention(&q, &k_full, &v_full, true); - let attn_merged = xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); + let attn_merged = + xserv_kernels::merge_heads_gpu(&attn_out, new_tokens, num_heads, head_dim); let attn_proj = matmul_2d(&attn_merged, &layer.o_proj_wt); - let (normed, x_new) = xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); + let (normed, x_new) = + xserv_kernels::add_rmsnorm(&attn_proj, &residual, &layer.post_norm, eps); let residual = x_new.clone(); let gate_up = matmul_2d(&normed, &layer.gate_up_proj_wt); @@ -856,28 +952,33 @@ impl Qwen3 { /// Extract weight pointers for CUDA Graph capture. pub fn layer_weight_ptrs(&self) -> Vec { - self.layers.iter().map(|l| crate::decode_graph::LayerWeightPtrs { - input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void, - q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void, - k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void, - v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void, - o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void, - q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void, - k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void, - post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void, - gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void, - up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void, - down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void, - }).collect() + self.layers + .iter() + .map(|l| crate::decode_graph::LayerWeightPtrs { + input_norm: l.input_norm.data_ptr() as *const std::ffi::c_void, + q_proj_wt: l.q_proj_wt().data_ptr() as *const std::ffi::c_void, + k_proj_wt: l.k_proj_wt().data_ptr() as *const std::ffi::c_void, + v_proj_wt: l.v_proj_wt().data_ptr() as *const std::ffi::c_void, + o_proj_wt: l.o_proj_wt.data_ptr() as *const std::ffi::c_void, + q_norm: l.q_norm.data_ptr() as *const std::ffi::c_void, + k_norm: l.k_norm.data_ptr() as *const std::ffi::c_void, + post_norm: l.post_norm.data_ptr() as *const std::ffi::c_void, + gate_proj_wt: l.gate_proj_wt().data_ptr() as *const std::ffi::c_void, + up_proj_wt: l.up_proj_wt().data_ptr() as *const std::ffi::c_void, + down_proj_wt: l.down_proj_wt.data_ptr() as *const std::ffi::c_void, + }) + .collect() } /// Get pointers needed for CUDA Graph capture. - pub fn graph_capture_ptrs(&self) -> ( - *const std::ffi::c_void, // norm weight - *const std::ffi::c_void, // lm_head_t - *const std::ffi::c_void, // embed_tokens - *const std::ffi::c_void, // rope cos - *const std::ffi::c_void, // rope sin + pub fn graph_capture_ptrs( + &self, + ) -> ( + *const std::ffi::c_void, // norm weight + *const std::ffi::c_void, // lm_head_t + *const std::ffi::c_void, // embed_tokens + *const std::ffi::c_void, // rope cos + *const std::ffi::c_void, // rope sin ) { ( self.norm.data_ptr() as *const std::ffi::c_void, @@ -895,11 +996,16 @@ impl Qwen3 { /// (column-parallel split: split the OUTPUT dim). `world==1` returns the whole. /// Input must be a contiguous CPU (or device) BF16 tensor. fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor { - if world == 1 { return t.clone(); } + if world == 1 { + return t.clone(); + } let shape = t.shape(); assert_eq!(shape.len(), 2, "shard_rows expects 2D weight"); let (rows, cols) = (shape[0], shape[1]); - assert!(rows % world == 0, "rows {rows} not divisible by world {world}"); + assert!( + rows % world == 0, + "rows {rows} not divisible by world {world}" + ); let local = rows / world; let host = t.to_device(Device::Cpu); let data = host.as_slice::(); @@ -911,11 +1017,16 @@ fn shard_rows(t: &Tensor, rank: usize, world: usize) -> Tensor { /// Keep this rank's column-block of a 2D `[rows, cols]` BF16 tensor (row-parallel /// split: split the INPUT dim). Strided copy. `world==1` returns the whole. fn shard_cols(t: &Tensor, rank: usize, world: usize) -> Tensor { - if world == 1 { return t.clone(); } + if world == 1 { + return t.clone(); + } let shape = t.shape(); assert_eq!(shape.len(), 2, "shard_cols expects 2D weight"); let (rows, cols) = (shape[0], shape[1]); - assert!(cols % world == 0, "cols {cols} not divisible by world {world}"); + assert!( + cols % world == 0, + "cols {cols} not divisible by world {world}" + ); let local = cols / world; let c0 = rank * local; let host = t.to_device(Device::Cpu); @@ -1009,7 +1120,9 @@ fn transpose_from_rope(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u } fn repeat_kv(x: &Tensor, n_rep: usize) -> Tensor { - if n_rep == 1 { return x.clone(); } + if n_rep == 1 { + return x.clone(); + } let kv_heads = x.shape()[1]; let seq_len = x.shape()[2]; let head_dim = x.shape()[3]; @@ -1065,11 +1178,16 @@ fn concat_rows(rows: &[Tensor]) -> Tensor { let src_buf = row.storage().gpu_buffer(); let src_offset = row.offset() * elem_size; let dst_offset = b * row_bytes; - out_buf.copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes).unwrap(); + out_buf + .copy_from_device_at(src_buf, src_offset, dst_offset, row_bytes) + .unwrap(); } // Wrap in a Tensor - let device_id = match device { Device::Cuda(id) => id, _ => panic!("expected CUDA device") }; + let device_id = match device { + Device::Cuda(id) => id, + _ => panic!("expected CUDA device"), + }; unsafe { crate::kv_cache::tensor_from_gpu_buffer_pub(out_buf, &[batch, cols], dtype, device_id) } @@ -1082,12 +1200,15 @@ fn cat_cols(tensors: &[&Tensor]) -> Tensor { let dtype = tensors[0].dtype(); let device = tensors[0].device(); let elem = dtype.size_bytes(); - let total_cols: usize = tensors.iter().map(|t| { - assert_eq!(t.ndim(), 2); - assert_eq!(t.shape()[0], rows); - assert!(t.is_contiguous()); - t.shape()[1] - }).sum(); + let total_cols: usize = tensors + .iter() + .map(|t| { + assert_eq!(t.ndim(), 2); + assert_eq!(t.shape()[0], rows); + assert!(t.is_contiguous()); + t.shape()[1] + }) + .sum(); let out = Tensor::empty(&[rows, total_cols], dtype, device); let dst_base = out.data_ptr() as *mut u8; for r in 0..rows { @@ -1126,7 +1247,9 @@ pub fn sample_greedy(logits: &Tensor) -> u32 { let seq_len = logits.shape()[0]; let data = logits_cpu.as_slice::(); let last = &data[(seq_len - 1) * vocab_size..seq_len * vocab_size]; - last.iter().enumerate() + last.iter() + .enumerate() .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) - .map(|(i, _)| i as u32).unwrap() + .map(|(i, _)| i as u32) + .unwrap() } diff --git a/crates/xserv-model/src/sampling.rs b/crates/xserv-model/src/sampling.rs index efa2c4c..5b0a030 100644 --- a/crates/xserv-model/src/sampling.rs +++ b/crates/xserv-model/src/sampling.rs @@ -11,7 +11,11 @@ pub struct SamplingParams { impl Default for SamplingParams { fn default() -> Self { - Self { temperature: 0.0, top_k: 0, top_p: 1.0 } + Self { + temperature: 0.0, + top_k: 0, + top_p: 1.0, + } } } @@ -134,9 +138,14 @@ pub fn sample_greedy_penalized(logits: &Tensor, recent: &[u32], penalty: f32) -> let seq_len = logits.shape()[0]; let logits_cpu = logits.to_device(Device::Cpu); let mut last_row: Vec = match logits.dtype() { - DType::F32 => logits_cpu.as_slice::()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec(), - DType::BF16 => logits_cpu.as_slice::()[(seq_len - 1) * vocab_size..seq_len * vocab_size] - .iter().map(|v| v.to_f32()).collect(), + DType::F32 => { + logits_cpu.as_slice::()[(seq_len - 1) * vocab_size..seq_len * vocab_size].to_vec() + } + DType::BF16 => logits_cpu.as_slice::() + [(seq_len - 1) * vocab_size..seq_len * vocab_size] + .iter() + .map(|v| v.to_f32()) + .collect(), _ => panic!("unsupported dtype for sampling: {:?}", logits.dtype()), }; if penalty > 1.0 { diff --git a/crates/xserv-server/src/api.rs b/crates/xserv-server/src/api.rs index c984700..73eacc6 100644 --- a/crates/xserv-server/src/api.rs +++ b/crates/xserv-server/src/api.rs @@ -72,7 +72,10 @@ impl ChatTemplate { let source = std::fs::read_to_string(&jinja_path) .unwrap_or_else(|e| panic!("failed to read {}: {e}", jinja_path.display())); eprintln!("[chat-template] loaded from {}", jinja_path.display()); - return Self { source, model_type: model_type.to_string() }; + return Self { + source, + model_type: model_type.to_string(), + }; } // 2. Try tokenizer_config.json → chat_template field @@ -82,7 +85,10 @@ impl ChatTemplate { if let Ok(v) = serde_json::from_str::(&data) { if let Some(ct) = v.get("chat_template").and_then(|v| v.as_str()) { eprintln!("[chat-template] loaded from tokenizer_config.json"); - return Self { source: ct.to_string(), model_type: model_type.to_string() }; + return Self { + source: ct.to_string(), + model_type: model_type.to_string(), + }; } } } @@ -90,7 +96,10 @@ impl ChatTemplate { // 3. No template found — use empty source, will fall back to hardcoded eprintln!("[chat-template] no Jinja template found, using hardcoded fallback"); - Self { source: String::new(), model_type: model_type.to_string() } + Self { + source: String::new(), + model_type: model_type.to_string(), + } } pub fn render(&self, messages: &[Message]) -> String { @@ -206,7 +215,10 @@ fn build_prompt_gpt_oss(messages: &[Message]) -> String { prompt.push_str("<|start|>system<|message|>"); prompt.push_str("You are ChatGPT, a large language model trained by OpenAI.\n"); prompt.push_str("Knowledge cutoff: 2024-06\n"); - prompt.push_str(&format!("Current date: {}\n\n", strftime_now("%Y-%m-%d".to_string()))); + prompt.push_str(&format!( + "Current date: {}\n\n", + strftime_now("%Y-%m-%d".to_string()) + )); prompt.push_str("Reasoning: low\n\n"); prompt.push_str("# Valid channels: analysis, commentary, final. Channel must be included for every message."); prompt.push_str("<|end|>"); @@ -334,13 +346,11 @@ async fn chat_non_stream(state: Arc, req: ChatRequest) -> Response { "completion_tokens": completion_token_count, "total_tokens": prompt_token_count + completion_token_count } - })).into_response() + })) + .into_response() } -fn chat_stream( - state: Arc, - req: ChatRequest, -) -> Response { +fn chat_stream(state: Arc, req: ChatRequest) -> Response { let id = format!("chatcmpl-{}", Uuid::new_v4()); let model_name = state.model_name.clone(); let created = unix_timestamp(); @@ -356,7 +366,8 @@ fn chat_stream( if prompt_tokens.len() >= max_seq_len { return bad_request(format!( "prompt is {} tokens, exceeds max_seq_len {}", - prompt_tokens.len(), max_seq_len + prompt_tokens.len(), + max_seq_len )); } let max_tokens = req.max_tokens.min(max_seq_len - prompt_tokens.len()); @@ -413,7 +424,9 @@ fn chat_stream( } }); - Sse::new(ReceiverStream::new(sse_rx)).keep_alive(KeepAlive::default()).into_response() + Sse::new(ReceiverStream::new(sse_rx)) + .keep_alive(KeepAlive::default()) + .into_response() } fn validate_request(req: &ChatRequest, model_name: &str) -> Option { @@ -436,8 +449,13 @@ fn validate_request(req: &ChatRequest, model_name: &str) -> Option { /// prior handler panicked) and returns a clean 503 instead of panicking when the /// engine thread is gone, so one dead engine doesn't cascade into every request. fn submit_to_engine(state: &AppState, req: GenerateRequest) -> Result<(), Response> { - let sender = state.engine_sender.lock().unwrap_or_else(|e| e.into_inner()); - sender.send(req).map_err(|_| service_unavailable("inference engine is not available")) + let sender = state + .engine_sender + .lock() + .unwrap_or_else(|e| e.into_inner()); + sender + .send(req) + .map_err(|_| service_unavailable("inference engine is not available")) } fn service_unavailable(message: impl Into) -> Response { diff --git a/crates/xserv-server/src/engine.rs b/crates/xserv-server/src/engine.rs index e280052..5a98d3f 100644 --- a/crates/xserv-server/src/engine.rs +++ b/crates/xserv-server/src/engine.rs @@ -1,10 +1,10 @@ use std::collections::VecDeque; use std::path::Path; -use std::sync::mpsc; use std::sync::Once; +use std::sync::mpsc; use std::time::Instant; -use xserv_model::{ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample, BLOCK_SIZE}; use xserv_model::loader; +use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, SamplingParams, sample}; use xserv_tensor::{DType, Device}; use xserv_tokenizer::Tokenizer; @@ -109,12 +109,23 @@ impl Engine { (total_blocks * bytes_per_block) as f64 / 1e9, info.free_memory as f64 / 1e9, ); - Self { model, config, tokenizer, max_batch_size, max_seq_len, paged_cache } + Self { + model, + config, + tokenizer, + max_batch_size, + max_seq_len, + paged_cache, + } } - pub fn tokenizer(&self) -> &Tokenizer { &self.tokenizer } + pub fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } - pub fn max_seq_len(&self) -> usize { self.max_seq_len } + pub fn max_seq_len(&self) -> usize { + self.max_seq_len + } /// Main scheduler loop. Receives requests from channel, manages concurrent sequences. /// @@ -134,7 +145,8 @@ impl Engine { loop { // Step 1: Remove finished sequences and return their slots. - let finished_slots: Vec = running.iter() + let finished_slots: Vec = running + .iter() .filter(|s| is_finished(s)) .filter_map(|s| s.seq_slot) .collect(); @@ -147,10 +159,16 @@ impl Engine { // room (oldest first). They resume decoding from where they paused. while running.len() < self.max_batch_size && !swapped.is_empty() { let slot = swapped[0].seq_slot.expect("swapped slot"); - if !self.paged_cache.can_swap_in(slot) { break; } + if !self.paged_cache.can_swap_in(slot) { + break; + } self.paged_cache.swap_in(slot).expect("swap_in"); let seq = swapped.remove(0); - eprintln!("[scheduler] swapped in seq {} ({} blocks)", seq.id, self.paged_cache.block_count(slot)); + eprintln!( + "[scheduler] swapped in seq {} ({} blocks)", + seq.id, + self.paged_cache.block_count(slot) + ); running.push(seq); } @@ -161,14 +179,22 @@ impl Engine { let mut avail = self.paged_cache.free_blocks(); let decode_reserve = running.len(); while running.len() < self.max_batch_size { - let Some(front) = waiting.front() else { break; }; + let Some(front) = waiting.front() else { + break; + }; let prompt_blocks = front.prompt_tokens.len().div_ceil(BLOCK_SIZE).max(1); - if avail < prompt_blocks + decode_reserve { break; } + if avail < prompt_blocks + decode_reserve { + break; + } let free_slot = (0..self.paged_cache.max_seqs()) .find(|&s| self.paged_cache.is_slot_free(s)); - let Some(slot) = free_slot else { break; }; + let Some(slot) = free_slot else { + break; + }; let mut seq = waiting.pop_front().unwrap(); - self.paged_cache.register_sequence(slot).expect("register paged slot"); + self.paged_cache + .register_sequence(slot) + .expect("register paged slot"); seq.seq_slot = Some(slot); running.push(seq); avail -= prompt_blocks; // projected free after this seq prefills @@ -199,7 +225,9 @@ impl Engine { if !seq.prefilled { let slot = seq.seq_slot.expect("slot"); let logits = self.model.forward_prefill_paged( - &seq.prompt_tokens, slot, &mut self.paged_cache, + &seq.prompt_tokens, + slot, + &mut self.paged_cache, ); let next = sample(&logits, &seq.sampling); seq.generated_tokens.push(next); @@ -219,13 +247,18 @@ impl Engine { && !newly_prefilled.contains(&running[p].id) && running[p].seq_slot.is_some() }); - let Some(pos) = victim else { break; }; + let Some(pos) = victim else { + break; + }; let seq = running.remove(pos); let slot = seq.seq_slot.unwrap(); if self.paged_cache.can_swap_out(slot) { let nblocks = self.paged_cache.block_count(slot); self.paged_cache.swap_out(slot).expect("swap_out"); - eprintln!("[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", seq.id); + eprintln!( + "[scheduler] preempt: swapped out seq {} ({nblocks} blocks) to host", + seq.id + ); swapped.push(seq); needed = decode_block_need(&self.paged_cache, &running, &newly_prefilled); } else { @@ -235,7 +268,9 @@ impl Engine { } // Step 5c: Batched paged decode for the surviving prefilled sequences. - let decode_indices: Vec = running.iter().enumerate() + let decode_indices: Vec = running + .iter() + .enumerate() .filter(|(_, s)| s.prefilled && !newly_prefilled.contains(&s.id)) .map(|(i, _)| i) .collect(); @@ -246,25 +281,32 @@ impl Engine { eprintln!("[scheduler] paged decode active"); }); - let tokens: Vec = decode_indices.iter() + let tokens: Vec = decode_indices + .iter() .map(|&i| *running[i].generated_tokens.last().unwrap()) .collect(); - let positions: Vec = decode_indices.iter() + let positions: Vec = decode_indices + .iter() .map(|&i| self.paged_cache.seq_len(running[i].seq_slot.unwrap())) .collect(); - let slots: Vec = decode_indices.iter() + let slots: Vec = decode_indices + .iter() .map(|&i| running[i].seq_slot.unwrap()) .collect(); let logits = self.model.forward_decode_paged( - &tokens, &positions, &slots, &mut self.paged_cache, + &tokens, + &positions, + &slots, + &mut self.paged_cache, ); // Fast path: every active sequence is greedy → run argmax on // the GPU and only D2H the chosen token ids (a few bytes per // sequence) instead of the full [B, vocab_size] BF16 logits // (~1.2 MB for B=4, Qwen3 vocab=152K). - let all_greedy = decode_indices.iter() + let all_greedy = decode_indices + .iter() .all(|&i| running[i].sampling.temperature == 0.0); if all_greedy { let next_ids = xserv_kernels::argmax_bf16_to_host(&logits); @@ -285,11 +327,15 @@ impl Engine { let row_start = j * vocab_size; let row_logits = &data[row_start..row_start + vocab_size]; let next = if running[i].sampling.temperature == 0.0 { - row_logits.iter().enumerate() + row_logits + .iter() + .enumerate() .max_by(|a, b| a.1.to_f32().partial_cmp(&b.1.to_f32()).unwrap()) - .map(|(idx, _)| idx as u32).unwrap() + .map(|(idx, _)| idx as u32) + .unwrap() } else { - let row_tensor = xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]); + let row_tensor = + xserv_tensor::Tensor::from_slice(row_logits, &[1, vocab_size]); sample(&row_tensor, &running[i].sampling) }; running[i].generated_tokens.push(next); @@ -334,7 +380,8 @@ impl Engine { /// Total additional GPU blocks the next decode step needs across all /// currently-decoding (prefilled, not just-prefilled) sequences. fn decode_block_need(paged: &PagedKVCache, running: &[Sequence], newly_prefilled: &[u64]) -> usize { - running.iter() + running + .iter() .filter(|s| s.prefilled && !newly_prefilled.contains(&s.id)) .filter_map(|s| s.seq_slot) .map(|slot| paged.additional_blocks_needed(slot, 1)) @@ -372,8 +419,12 @@ fn send_token_if_nonempty(seq: &Sequence, text: String) { } fn is_finished(seq: &Sequence) -> bool { - if seq.generated_tokens.is_empty() { return false; } + if seq.generated_tokens.is_empty() { + return false; + } let last = *seq.generated_tokens.last().unwrap(); - if seq.generated_tokens.len() >= seq.max_tokens { return true; } + if seq.generated_tokens.len() >= seq.max_tokens { + return true; + } seq.sender.is_closed() || seq.eos_token_id == Some(last) } diff --git a/crates/xserv-server/src/main.rs b/crates/xserv-server/src/main.rs index f7937fa..88e9ed0 100644 --- a/crates/xserv-server/src/main.rs +++ b/crates/xserv-server/src/main.rs @@ -3,10 +3,13 @@ mod engine; mod pp_engine; mod tp_engine; -use axum::{routing::{get, post}, Extension, Router}; -use std::path::PathBuf; -use std::sync::{mpsc, Arc, Mutex}; +use axum::{ + Extension, Router, + routing::{get, post}, +}; use engine::GenerateRequest; +use std::path::PathBuf; +use std::sync::{Arc, Mutex, mpsc}; use xserv_model::ModelConfig; pub struct AppState { @@ -21,40 +24,48 @@ pub struct AppState { async fn main() { let args: Vec = std::env::args().collect(); if args.len() < 2 { - eprintln!("Usage: xserv-server [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]"); + eprintln!( + "Usage: xserv-server [--port PORT] [--max-batch N] [--max-seq-len N] [--swap-space-gb N] [--tp N] [--pp N]" + ); std::process::exit(1); } let model_dir = PathBuf::from(&args[1]); - let port: u16 = args.iter() + let port: u16 = args + .iter() .position(|a| a == "--port") .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) .unwrap_or(8080); - let max_batch: usize = args.iter() + let max_batch: usize = args + .iter() .position(|a| a == "--max-batch") .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) .unwrap_or(4) .max(1); - let requested_max_seq_len: usize = args.iter() + let requested_max_seq_len: usize = args + .iter() .position(|a| a == "--max-seq-len") .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) .unwrap_or(2048) .max(1); - let swap_space_gb: usize = args.iter() + let swap_space_gb: usize = args + .iter() .position(|a| a == "--swap-space-gb") .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) .unwrap_or(8); - let tp: usize = args.iter() + let tp: usize = args + .iter() .position(|a| a == "--tp") .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) .unwrap_or(1) .max(1); - let pp: usize = args.iter() + let pp: usize = args + .iter() .position(|a| a == "--pp") .and_then(|i| args.get(i + 1)) .and_then(|s| s.parse().ok()) @@ -69,7 +80,9 @@ async fn main() { // tp=1 (single-rank world) so quantized models can serve on one GPU. let is_gpt_oss = model_config.model_type.as_deref() == Some("gpt_oss"); if pp > 1 && is_gpt_oss { - eprintln!("gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead"); + eprintln!( + "gpt-oss is not supported by the pipeline-parallel engine (Qwen3 only); use --tp instead" + ); std::process::exit(1); } let model_max_seq_len = model_config.max_seq_len(); @@ -84,7 +97,8 @@ async fn main() { ); } - let model_name = model_dir.file_name() + let model_name = model_dir + .file_name() .map(|n| n.to_string_lossy().to_string()) .unwrap_or_else(|| "unknown".to_string()); @@ -99,7 +113,12 @@ async fn main() { // Pipeline-parallel path: stage-0 coordinator + worker stage threads. pp_engine::run_pp(&model_dir_clone, pp, max_seq_len, rx); } else if tp <= 1 && !is_gpt_oss { - let mut engine = engine::Engine::load_with_swap(&model_dir_clone, max_batch, max_seq_len, swap_space_gb); + let mut engine = engine::Engine::load_with_swap( + &model_dir_clone, + max_batch, + max_seq_len, + swap_space_gb, + ); engine.run(rx); } else { // Tensor-parallel path: rank-0 coordinator + worker rank threads. diff --git a/crates/xserv-server/src/pp_engine.rs b/crates/xserv-server/src/pp_engine.rs index 9396fd3..0889165 100644 --- a/crates/xserv-server/src/pp_engine.rs +++ b/crates/xserv-server/src/pp_engine.rs @@ -15,15 +15,15 @@ use std::ffi::c_void; use std::path::{Path, PathBuf}; -use std::sync::mpsc; use std::sync::Arc; +use std::sync::mpsc; use std::thread; use half::bf16; use xserv_distributed::{PpContext, UniqueId}; use xserv_model::loader; use xserv_model::sampling::SamplingParams; -use xserv_model::{sample, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE}; +use xserv_model::{BLOCK_SIZE, ModelConfig, PagedKVCache, Qwen3, sample}; use xserv_tensor::{DType, Device, Tensor}; use xserv_tokenizer::Tokenizer; @@ -38,9 +38,16 @@ enum PpCommand { Free(usize), /// Receive `[n_tokens, hidden]` from the previous stage, run this stage's /// layers; if last stage, sample with `sampling` and return the token. - Prefill { n_tokens: usize, slot: usize, sampling: SamplingParams }, + Prefill { + n_tokens: usize, + slot: usize, + sampling: SamplingParams, + }, /// Receive `[1, hidden]`, run this stage's layers; last stage samples. - Decode { slot: usize, sampling: SamplingParams }, + Decode { + slot: usize, + sampling: SamplingParams, + }, Shutdown, } @@ -76,9 +83,21 @@ fn build_stage( let max_blocks_per_seq = max_seq_len.div_ceil(BLOCK_SIZE); let total_blocks = max_blocks_per_seq + 8; // v1 serial: one active sequence let cache = PagedKVCache::new( - &stage_config, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device, + &stage_config, + total_blocks, + 0, + 4, + max_blocks_per_seq, + DType::BF16, + device, ); - StageCtx { model, cache, pp, hidden: config.hidden(), device } + StageCtx { + model, + cache, + pp, + hidden: config.hidden(), + device, + } } /// Allocate a zeroed `[n, hidden]` device tensor and receive into it from `peer`. @@ -110,7 +129,15 @@ fn worker_loop( ack_tx: mpsc::Sender<()>, token_tx: mpsc::Sender, ) { - let mut sc = build_stage(&model_dir, &config, stage, world, stage as u32, max_seq_len, id); + let mut sc = build_stage( + &model_dir, + &config, + stage, + world, + stage as u32, + max_seq_len, + id, + ); let is_last = stage == world - 1; let prev = stage - 1; let next = stage + 1; @@ -125,7 +152,11 @@ fn worker_loop( sc.cache.free_sequence(slot); let _ = ack_tx.send(()); } - PpCommand::Prefill { n_tokens, slot, sampling } => { + PpCommand::Prefill { + n_tokens, + slot, + sampling, + } => { let x = recv_hidden(&sc, n_tokens, prev); let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache); if is_last { @@ -155,7 +186,12 @@ fn worker_loop( /// Run the PP coordinator (stage 0) on the calling thread. Spawns worker stages /// 1..world and consumes generation requests from `rx`. -pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver) { +pub fn run_pp( + model_dir: &Path, + world: usize, + max_seq_len: usize, + rx: mpsc::Receiver, +) { assert!(world >= 2, "run_pp requires world >= 2"); let config = ModelConfig::from_file(&model_dir.join("config.json")); assert!( @@ -179,7 +215,17 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece let model_dir = model_dir.to_path_buf(); let config = config.clone(); thread::spawn(move || { - worker_loop(stage, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx, token_tx); + worker_loop( + stage, + world, + id, + model_dir, + config, + max_seq_len, + ctx_rx, + ack_tx, + token_tx, + ); }); } @@ -207,11 +253,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece wait_acks(&ack_rx); // Prefill: embed prompt, run stage-0 layers, push hidden into the pipe. - broadcast(&cmd_txs, PpCommand::Prefill { - n_tokens: req.prompt_tokens.len(), - slot, - sampling: req.sampling.clone(), - }); + broadcast( + &cmd_txs, + PpCommand::Prefill { + n_tokens: req.prompt_tokens.len(), + slot, + sampling: req.sampling.clone(), + }, + ); let x = sc.model.embed(&req.prompt_tokens); let x = sc.model.forward_layers_prefill(x, slot, &mut sc.cache); send_hidden(&sc, &x, next_peer); @@ -228,7 +277,13 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece if generated >= req.max_tokens { break "length"; } - broadcast(&cmd_txs, PpCommand::Decode { slot, sampling: req.sampling.clone() }); + broadcast( + &cmd_txs, + PpCommand::Decode { + slot, + sampling: req.sampling.clone(), + }, + ); let x = sc.model.embed(&[next]); let x = sc.model.forward_layers_decode(x, &[slot], &mut sc.cache); send_hidden(&sc, &x, next_peer); @@ -239,9 +294,14 @@ pub fn run_pp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece let tail = tokenizer.flush_decode_stream(&mut decode_buf); if !tail.is_empty() { - let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail }); + let _ = req.sender.blocking_send(GenerateEvent::Token { + id: next, + text: tail, + }); } - let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() }); + let _ = req.sender.blocking_send(GenerateEvent::Done { + finish_reason: finish.to_string(), + }); broadcast(&cmd_txs, PpCommand::Free(slot)); sc.cache.free_sequence(slot); @@ -258,6 +318,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: & } let text = tokenizer.decode_token_stream(token_id, buf); if !text.is_empty() { - let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text }); + let _ = req + .sender + .blocking_send(GenerateEvent::Token { id: token_id, text }); } } diff --git a/crates/xserv-server/src/tp_engine.rs b/crates/xserv-server/src/tp_engine.rs index 14deba0..b4dad64 100644 --- a/crates/xserv-server/src/tp_engine.rs +++ b/crates/xserv-server/src/tp_engine.rs @@ -13,13 +13,16 @@ //! work; the single-GPU `Engine` still handles TP=1. use std::path::{Path, PathBuf}; -use std::sync::mpsc; use std::sync::Arc; +use std::sync::mpsc; use std::thread; use xserv_distributed::{TpContext, UniqueId}; use xserv_model::loader; -use xserv_model::{sample, sample_greedy_penalized, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, BLOCK_SIZE}; +use xserv_model::{ + BLOCK_SIZE, GptOss, GraphedGptOssDecoder, ModelConfig, PagedKVCache, Qwen3, sample, + sample_greedy_penalized, +}; use xserv_tensor::{DType, Device, Tensor}; use xserv_tokenizer::Tokenizer; @@ -29,8 +32,15 @@ use crate::engine::{GenerateEvent, GenerateRequest}; enum TpCommand { Register(usize), Free(usize), - Prefill { tokens: Vec, slot: usize }, - Decode { tokens: Vec, positions: Vec, slots: Vec }, + Prefill { + tokens: Vec, + slot: usize, + }, + Decode { + tokens: Vec, + positions: Vec, + slots: Vec, + }, Shutdown, } @@ -40,14 +50,25 @@ enum TpModel { } impl TpModel { - fn forward_prefill_paged(&self, tokens: &[u32], slot: usize, cache: &mut PagedKVCache) -> Tensor { + fn forward_prefill_paged( + &self, + tokens: &[u32], + slot: usize, + cache: &mut PagedKVCache, + ) -> Tensor { match self { TpModel::Qwen3(m) => m.forward_prefill_paged(tokens, slot, cache), TpModel::GptOss(m) => m.forward_prefill_paged(tokens, slot, cache), } } - fn forward_decode_paged(&self, tokens: &[u32], positions: &[usize], slots: &[usize], cache: &mut PagedKVCache) -> Tensor { + fn forward_decode_paged( + &self, + tokens: &[u32], + positions: &[usize], + slots: &[usize], + cache: &mut PagedKVCache, + ) -> Tensor { match self { TpModel::Qwen3(m) => m.forward_decode_paged(tokens, positions, slots, cache), TpModel::GptOss(m) => m.forward_decode_paged(tokens, positions, slots, cache), @@ -65,8 +86,12 @@ struct RankCtx { /// (lazy capture, replay thereafter); everything else runs eager. fn rank_decode(rc: &mut RankCtx, tokens: &[u32], positions: &[usize], slots: &[usize]) -> Tensor { match &rc.model { - TpModel::GptOss(m) => rc.decoder.decode(m, tokens, positions, slots, &mut rc.cache), - TpModel::Qwen3(_) => rc.model.forward_decode_paged(tokens, positions, slots, &mut rc.cache), + TpModel::GptOss(m) => rc + .decoder + .decode(m, tokens, positions, slots, &mut rc.cache), + TpModel::Qwen3(_) => rc + .model + .forward_decode_paged(tokens, positions, slots, &mut rc.cache), } } @@ -81,17 +106,42 @@ fn build_rank( ) -> RankCtx { let weights = loader::load_model_dir(model_dir, Device::Cpu); let model = if config.is_moe() { - TpModel::GptOss(GptOss::from_weights_tp(config.clone(), weights, rank, world, device, tp)) + TpModel::GptOss(GptOss::from_weights_tp( + config.clone(), + weights, + rank, + world, + device, + tp, + )) } else { - TpModel::Qwen3(Qwen3::from_weights_tp(config.clone(), weights, rank, world, device, tp)) + TpModel::Qwen3(Qwen3::from_weights_tp( + config.clone(), + weights, + rank, + world, + device, + tp, + )) }; let local_kv = config.num_kv_heads() / world; let max_blocks_per_seq = (max_seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; let total_blocks = max_blocks_per_seq + 8; let cache = PagedKVCache::new_tp( - config, local_kv, total_blocks, 0, 4, max_blocks_per_seq, DType::BF16, device, + config, + local_kv, + total_blocks, + 0, + 4, + max_blocks_per_seq, + DType::BF16, + device, ); - RankCtx { model, cache, decoder: GraphedGptOssDecoder::new() } + RankCtx { + model, + cache, + decoder: GraphedGptOssDecoder::new(), + } } fn worker_loop( @@ -105,7 +155,15 @@ fn worker_loop( ack_tx: mpsc::Sender<()>, ) { let tp = Arc::new(TpContext::init(rank, world, id, rank as u32)); - let mut rc = build_rank(&model_dir, &config, rank, world, rank as u32, max_seq_len, Some(tp)); + let mut rc = build_rank( + &model_dir, + &config, + rank, + world, + rank as u32, + max_seq_len, + Some(tp), + ); while let Ok(cmd) = cmd_rx.recv() { match cmd { TpCommand::Register(slot) => { @@ -115,7 +173,11 @@ fn worker_loop( TpCommand::Prefill { tokens, slot } => { let _ = rc.model.forward_prefill_paged(&tokens, slot, &mut rc.cache); } - TpCommand::Decode { tokens, positions, slots } => { + TpCommand::Decode { + tokens, + positions, + slots, + } => { let _ = rank_decode(&mut rc, &tokens, &positions, &slots); } TpCommand::Shutdown => { @@ -129,7 +191,12 @@ fn worker_loop( /// Run the TP coordinator (rank 0) on the calling thread. Spawns worker ranks /// internally and consumes generation requests from `rx`. -pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Receiver) { +pub fn run_tp( + model_dir: &Path, + world: usize, + max_seq_len: usize, + rx: mpsc::Receiver, +) { // world=1 is a valid single-rank configuration (gpt-oss has no // single-GPU engine path; NCCL init and all_reduce no-op at world=1). assert!(world >= 1, "run_tp requires world >= 1"); @@ -152,7 +219,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece let model_dir = model_dir.to_path_buf(); let config = config.clone(); thread::spawn(move || { - worker_loop(rank, world, id, model_dir, config, max_seq_len, ctx_rx, ack_tx); + worker_loop( + rank, + world, + id, + model_dir, + config, + max_seq_len, + ctx_rx, + ack_tx, + ); }); } @@ -165,10 +241,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece // models loop under pure greedy when numerics diverge from the reference). // Off by default; XSERV_REP_PENALTY>1 enables it over the last // XSERV_REP_WINDOW generated tokens. Applied only on the greedy path. - let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY").ok() - .and_then(|s| s.parse().ok()).unwrap_or(1.0); - let rep_window: usize = std::env::var("XSERV_REP_WINDOW").ok() - .and_then(|s| s.parse().ok()).unwrap_or(128); + let rep_penalty: f32 = std::env::var("XSERV_REP_PENALTY") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(1.0); + let rep_window: usize = std::env::var("XSERV_REP_WINDOW") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(128); let pick = |logits: &Tensor, sp: &xserv_model::SamplingParams, history: &[u32]| -> u32 { if rep_penalty > 1.0 && sp.temperature == 0.0 { let start = history.len().saturating_sub(rep_window); @@ -197,8 +277,16 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece wait_acks(&ack_rx); // Prefill. - broadcast(&cmd_txs, TpCommand::Prefill { tokens: req.prompt_tokens.clone(), slot }); - let logits = rc.model.forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache); + broadcast( + &cmd_txs, + TpCommand::Prefill { + tokens: req.prompt_tokens.clone(), + slot, + }, + ); + let logits = rc + .model + .forward_prefill_paged(&req.prompt_tokens, slot, &mut rc.cache); wait_acks(&ack_rx); let mut gen_ids: Vec = Vec::new(); let mut next = pick(&logits, &req.sampling, &gen_ids); @@ -216,7 +304,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece break "length"; } let pos = rc.cache.seq_len(slot); - broadcast(&cmd_txs, TpCommand::Decode { tokens: vec![next], positions: vec![pos], slots: vec![slot] }); + broadcast( + &cmd_txs, + TpCommand::Decode { + tokens: vec![next], + positions: vec![pos], + slots: vec![slot], + }, + ); let logits = rank_decode(&mut rc, &[next], &[pos], &[slot]); wait_acks(&ack_rx); next = pick(&logits, &req.sampling, &gen_ids); @@ -227,9 +322,14 @@ pub fn run_tp(model_dir: &Path, world: usize, max_seq_len: usize, rx: mpsc::Rece let tail = tokenizer.flush_decode_stream(&mut decode_buf); if !tail.is_empty() { - let _ = req.sender.blocking_send(GenerateEvent::Token { id: next, text: tail }); + let _ = req.sender.blocking_send(GenerateEvent::Token { + id: next, + text: tail, + }); } - let _ = req.sender.blocking_send(GenerateEvent::Done { finish_reason: finish.to_string() }); + let _ = req.sender.blocking_send(GenerateEvent::Done { + finish_reason: finish.to_string(), + }); broadcast(&cmd_txs, TpCommand::Free(slot)); rc.cache.free_sequence(slot); @@ -246,6 +346,8 @@ fn emit_text(tokenizer: &Tokenizer, req: &GenerateRequest, token_id: u32, buf: & } let text = tokenizer.decode_token_stream(token_id, buf); if !text.is_empty() { - let _ = req.sender.blocking_send(GenerateEvent::Token { id: token_id, text }); + let _ = req + .sender + .blocking_send(GenerateEvent::Token { id: token_id, text }); } } diff --git a/crates/xserv-tensor/src/dtype.rs b/crates/xserv-tensor/src/dtype.rs index 95d90ec..06b8e30 100644 --- a/crates/xserv-tensor/src/dtype.rs +++ b/crates/xserv-tensor/src/dtype.rs @@ -43,18 +43,30 @@ pub trait TensorDType: Copy + Send + Sync + 'static { impl TensorDType for f32 { const DTYPE: DType = DType::F32; - fn to_f64(self) -> f64 { self as f64 } - fn from_f64(v: f64) -> Self { v as f32 } + fn to_f64(self) -> f64 { + self as f64 + } + fn from_f64(v: f64) -> Self { + v as f32 + } } impl TensorDType for f16 { const DTYPE: DType = DType::F16; - fn to_f64(self) -> f64 { self.to_f32() as f64 } - fn from_f64(v: f64) -> Self { f16::from_f32(v as f32) } + fn to_f64(self) -> f64 { + self.to_f32() as f64 + } + fn from_f64(v: f64) -> Self { + f16::from_f32(v as f32) + } } impl TensorDType for bf16 { const DTYPE: DType = DType::BF16; - fn to_f64(self) -> f64 { self.to_f32() as f64 } - fn from_f64(v: f64) -> Self { bf16::from_f32(v as f32) } + fn to_f64(self) -> f64 { + self.to_f32() as f64 + } + fn from_f64(v: f64) -> Self { + bf16::from_f32(v as f32) + } } diff --git a/crates/xserv-tensor/src/lib.rs b/crates/xserv-tensor/src/lib.rs index 78dc467..a43eeb5 100644 --- a/crates/xserv-tensor/src/lib.rs +++ b/crates/xserv-tensor/src/lib.rs @@ -6,4 +6,4 @@ pub mod tensor; pub use dtype::{DType, TensorDType}; pub use shape::Dims; pub use storage::{Device, Storage}; -pub use tensor::{register_gpu_contiguous, Tensor}; +pub use tensor::{Tensor, register_gpu_contiguous}; diff --git a/crates/xserv-tensor/src/shape.rs b/crates/xserv-tensor/src/shape.rs index bf89593..0a881d5 100644 --- a/crates/xserv-tensor/src/shape.rs +++ b/crates/xserv-tensor/src/shape.rs @@ -46,8 +46,16 @@ pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option { let ndim = a.len().max(b.len()); let mut result = SmallVec::with_capacity(ndim); for i in 0..ndim { - let da = if i < ndim - a.len() { 1 } else { a[i - (ndim - a.len())] }; - let db = if i < ndim - b.len() { 1 } else { b[i - (ndim - b.len())] }; + let da = if i < ndim - a.len() { + 1 + } else { + a[i - (ndim - a.len())] + }; + let db = if i < ndim - b.len() { + 1 + } else { + b[i - (ndim - b.len())] + }; if da == db { result.push(da); } else if da == 1 { @@ -100,8 +108,14 @@ mod tests { #[test] fn test_broadcast_shape() { - assert_eq!(broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), &[3, 4]); - assert_eq!(broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), &[2, 3, 4]); + assert_eq!( + broadcast_shape(&[3, 1], &[1, 4]).unwrap().as_slice(), + &[3, 4] + ); + assert_eq!( + broadcast_shape(&[2, 3, 4], &[4]).unwrap().as_slice(), + &[2, 3, 4] + ); assert_eq!(broadcast_shape(&[1], &[5, 3]).unwrap().as_slice(), &[5, 3]); assert!(broadcast_shape(&[3], &[4]).is_none()); } @@ -109,6 +123,9 @@ mod tests { #[test] fn test_broadcast_strides() { // [3,1] with strides [1,1] broadcast to [3,4] - assert_eq!(broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), &[1, 0]); + assert_eq!( + broadcast_strides(&[3, 1], &[1, 1], &[3, 4]).as_slice(), + &[1, 0] + ); } } diff --git a/crates/xserv-tensor/src/tensor.rs b/crates/xserv-tensor/src/tensor.rs index 8c88cfb..7eec2d9 100644 --- a/crates/xserv-tensor/src/tensor.rs +++ b/crates/xserv-tensor/src/tensor.rs @@ -33,8 +33,20 @@ impl Tensor { // --- Creation --- /// Create a tensor from raw components (for advanced use like GPU KV cache). - pub fn from_storage(storage: Storage, shape: Dims, strides: Dims, offset: usize, dtype: DType) -> Self { - Self { storage, shape, strides, offset, dtype } + pub fn from_storage( + storage: Storage, + shape: Dims, + strides: Dims, + offset: usize, + dtype: DType, + ) -> Self { + Self { + storage, + shape, + strides, + offset, + dtype, + } } pub fn from_slice(data: &[T], shape: &[usize]) -> Self { @@ -60,7 +72,10 @@ impl Tensor { data.len(), numel * dtype.size_bytes(), "raw bytes length {} != expected {} (numel={} * elem_size={})", - data.len(), numel * dtype.size_bytes(), numel, dtype.size_bytes() + data.len(), + numel * dtype.size_bytes(), + numel, + dtype.size_bytes() ); Self { storage: Storage::cpu(data.to_vec()), @@ -112,14 +127,28 @@ impl Tensor { // --- Properties --- - pub fn shape(&self) -> &[usize] { &self.shape } - pub fn strides(&self) -> &[usize] { &self.strides } - pub fn dtype(&self) -> DType { self.dtype } - pub fn ndim(&self) -> usize { self.shape.len() } - pub fn numel(&self) -> usize { shape::num_elements(&self.shape) } - pub fn offset(&self) -> usize { self.offset } + pub fn shape(&self) -> &[usize] { + &self.shape + } + pub fn strides(&self) -> &[usize] { + &self.strides + } + pub fn dtype(&self) -> DType { + self.dtype + } + pub fn ndim(&self) -> usize { + self.shape.len() + } + pub fn numel(&self) -> usize { + shape::num_elements(&self.shape) + } + pub fn offset(&self) -> usize { + self.offset + } - pub fn device(&self) -> Device { self.storage.device() } + pub fn device(&self) -> Device { + self.storage.device() + } pub fn is_contiguous(&self) -> bool { shape::is_contiguous(&self.shape, &self.strides) @@ -193,7 +222,11 @@ impl Tensor { shape::contiguous_strides(&new_shape) } else { let mut s = self.strides.clone(); - let stride_val = if dim < self.strides.len() { self.strides[dim] } else { 1 }; + let stride_val = if dim < self.strides.len() { + self.strides[dim] + } else { + 1 + }; s.insert(dim, stride_val); s }; @@ -230,7 +263,12 @@ impl Tensor { let ndim = self.ndim(); let mut idx = vec![0usize; ndim]; for flat in 0..numel { - let src_offset = self.offset + idx.iter().zip(self.strides.iter()).map(|(i, s)| i * s).sum::(); + let src_offset = self.offset + + idx + .iter() + .zip(self.strides.iter()) + .map(|(i, s)| i * s) + .sum::(); let src_byte_offset = src_offset * elem_size; let dst_byte_offset = flat * elem_size; dst[dst_byte_offset..dst_byte_offset + elem_size] @@ -261,7 +299,10 @@ impl Tensor { } // Transfer the raw storage (preserving strides/offset). // Non-contiguous layout is preserved — the user can call contiguous() after. - let new_storage = self.storage.to_device(device).expect("device transfer failed"); + let new_storage = self + .storage + .to_device(device) + .expect("device transfer failed"); Self { storage: new_storage, shape: self.shape.clone(), @@ -310,14 +351,20 @@ impl Tensor { } } - pub fn storage(&self) -> &Storage { &self.storage } + pub fn storage(&self) -> &Storage { + &self.storage + } } impl std::fmt::Debug for Tensor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( - f, "Tensor(shape={:?}, dtype={}, device={}, contiguous={})", - self.shape.as_slice(), self.dtype, self.device(), self.is_contiguous() + f, + "Tensor(shape={:?}, dtype={}, device={}, contiguous={})", + self.shape.as_slice(), + self.dtype, + self.device(), + self.is_contiguous() ) } } diff --git a/crates/xserv-tensor/tests/integration.rs b/crates/xserv-tensor/tests/integration.rs index 99b7bdd..9299e5f 100644 --- a/crates/xserv-tensor/tests/integration.rs +++ b/crates/xserv-tensor/tests/integration.rs @@ -32,7 +32,11 @@ fn test_zeros_and_ones() { #[test] fn test_bf16_tensor() { - let data: Vec = vec![bf16::from_f32(1.0), bf16::from_f32(2.5), bf16::from_f32(-3.0)]; + let data: Vec = vec![ + bf16::from_f32(1.0), + bf16::from_f32(2.5), + bf16::from_f32(-3.0), + ]; let t = Tensor::from_slice(&data, &[3]); assert_eq!(t.dtype(), DType::BF16); let out = t.as_slice::(); diff --git a/crates/xserv-tokenizer/src/bpe.rs b/crates/xserv-tokenizer/src/bpe.rs index b696b5c..1853d4f 100644 --- a/crates/xserv-tokenizer/src/bpe.rs +++ b/crates/xserv-tokenizer/src/bpe.rs @@ -95,11 +95,15 @@ impl Tokenizer { let (a_str, b_str) = match entry { MergeEntry::Str(s) => { let parts: Vec<&str> = s.splitn(2, ' ').collect(); - if parts.len() != 2 { continue; } + if parts.len() != 2 { + continue; + } (parts[0].to_string(), parts[1].to_string()) } MergeEntry::Pair(v) => { - if v.len() != 2 { continue; } + if v.len() != 2 { + continue; + } (v[0].clone(), v[1].clone()) } }; @@ -174,7 +178,10 @@ impl Tokenizer { if byte_fallback { Regex::new(r"[\p{L}\p{N}]+|[^\s\p{L}\p{N}]|\s+").unwrap() } else { - Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap() + Regex::new( + r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+", + ) + .unwrap() } } } @@ -262,7 +269,9 @@ impl Tokenizer { // BPE merges loop { - if token_ids.len() < 2 { break; } + if token_ids.len() < 2 { + break; + } let mut best_rank = usize::MAX; let mut best_idx = 0; for i in 0..token_ids.len() - 1 { @@ -273,12 +282,15 @@ impl Tokenizer { } } } - if best_rank == usize::MAX { break; } + if best_rank == usize::MAX { + break; + } let merged_bytes = [ self.decoder[token_ids[best_idx] as usize].as_slice(), self.decoder[token_ids[best_idx + 1] as usize].as_slice(), - ].concat(); + ] + .concat(); let merged_id = *self.encoder.get(&merged_bytes).unwrap_or_else(|| { panic!("merged token not in vocab"); }); @@ -389,14 +401,13 @@ fn unicode_to_byte(c: char) -> u8 { m }); - *map.get(&(c as u32)).unwrap_or_else(|| { - panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32) - }) + *map.get(&(c as u32)) + .unwrap_or_else(|| panic!("unmapped unicode char U+{:04X} in tokenizer", c as u32)) } #[cfg(test)] mod tests { - use super::{take_valid_utf8, Tokenizer}; + use super::{Tokenizer, take_valid_utf8}; #[test] fn qwen_added_tokens_are_indivisible_and_im_end_is_eos() {