From 4088f49b7d9ac297e46731cc8065f6a40a12c62c Mon Sep 17 00:00:00 2001 From: Gahow Wang Date: Fri, 12 Jun 2026 20:12:37 +0800 Subject: [PATCH] cuda: infrastructure for whole-step CUDA graph capture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Thread-local launch stream (xserv_cuda::stream): every kernel wrapper, cublasSetStream, and NCCL collective now launches on current_stream_raw() — the legacy null stream by default (behavior unchanged), or the capture stream installed via push_stream during graph capture. Capture is impossible on the legacy stream. - Allocator retain mode: blocks freed inside a retain window are quarantined (RetainedBlocks) instead of pooled, so an instantiated graph keeps exclusive ownership of every intermediate buffer it references across replays. - Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads must not poison each other's captures with their own cudaMallocs. - embedding_device_ids / rope_inplace_device_pos: variants reading token ids / positions from persistent device buffers, replacing the per-call host upload that a captured region cannot contain. Co-Authored-By: Claude Fable 5 --- crates/xserv-cuda/src/allocator.rs | 51 ++++++++++++++++++++++++ crates/xserv-cuda/src/ffi.rs | 1 + crates/xserv-cuda/src/graph.rs | 5 ++- crates/xserv-cuda/src/lib.rs | 2 +- crates/xserv-cuda/src/stream.rs | 36 +++++++++++++++++ crates/xserv-distributed/src/lib.rs | 17 ++++---- crates/xserv-kernels/src/activation.rs | 18 ++++----- crates/xserv-kernels/src/argmax.rs | 2 +- crates/xserv-kernels/src/attention.rs | 14 +++---- crates/xserv-kernels/src/embedding.rs | 21 ++++++++-- crates/xserv-kernels/src/gemm.rs | 4 +- crates/xserv-kernels/src/layernorm.rs | 4 +- crates/xserv-kernels/src/lib.rs | 4 +- crates/xserv-kernels/src/moe.rs | 16 ++++---- crates/xserv-kernels/src/quantization.rs | 12 +++--- crates/xserv-kernels/src/rmsnorm.rs | 6 +-- crates/xserv-kernels/src/rope.rs | 25 +++++++++--- crates/xserv-kernels/src/softmax.rs | 4 +- crates/xserv-kernels/src/transpose.rs | 14 +++---- crates/xserv-model/src/paged_kv_cache.rs | 4 +- 20 files changed, 191 insertions(+), 69 deletions(-) diff --git a/crates/xserv-cuda/src/allocator.rs b/crates/xserv-cuda/src/allocator.rs index 3bd9e31..e309dd1 100644 --- a/crates/xserv-cuda/src/allocator.rs +++ b/crates/xserv-cuda/src/allocator.rs @@ -111,6 +111,22 @@ pub fn cached_trim() { /// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer /// and size to avoid re-triggering Drop. pub fn return_to_pool(ptr: *mut u8, len: usize) { + // During CUDA graph capture, buffers freed by the captured code are + // quarantined instead of pooled: the instantiated graph references their + // addresses on every replay, so they must never be handed to another + // consumer for as long as the graph lives. + let quarantined = RETAINED.with(|cell| { + let mut r = cell.borrow_mut(); + if let Some(list) = r.as_mut() { + list.push((ptr, len)); + true + } else { + false + } + }); + if quarantined { + return; + } ALLOCATOR.with(|cell| { let mut alloc = cell.borrow_mut(); let bucket = bucket_size(len); @@ -119,6 +135,41 @@ pub fn return_to_pool(ptr: *mut u8, len: usize) { }); } +thread_local! { + static RETAINED: RefCell>> = const { RefCell::new(None) }; +} + +/// Buffers freed while a retain window was active. Holding this keeps their +/// memory out of the pool; dropping it returns the blocks (on the owning +/// thread) for reuse. +pub struct RetainedBlocks(Vec<(*mut u8, usize)>); + +impl Drop for RetainedBlocks { + fn drop(&mut self) { + for (ptr, len) in self.0.drain(..) { + return_to_pool(ptr, len); + } + } +} + +/// Start quarantining buffers freed on this thread (see `return_to_pool`). +/// Must be paired with `end_retain` on the same thread; nesting unsupported. +pub fn begin_retain() { + RETAINED.with(|cell| { + let mut r = cell.borrow_mut(); + assert!(r.is_none(), "begin_retain: retain window already active"); + *r = Some(Vec::new()); + }); +} + +/// Stop quarantining and hand the quarantined blocks to the caller. +pub fn end_retain() -> RetainedBlocks { + RETAINED.with(|cell| { + let list = cell.borrow_mut().take().expect("end_retain without begin_retain"); + RetainedBlocks(list) + }) +} + /// Round up to next power-of-2, minimum 512 bytes. fn bucket_size(size: usize) -> usize { let min = 512; diff --git a/crates/xserv-cuda/src/ffi.rs b/crates/xserv-cuda/src/ffi.rs index 191d4fa..8b18571 100644 --- a/crates/xserv-cuda/src/ffi.rs +++ b/crates/xserv-cuda/src/ffi.rs @@ -15,6 +15,7 @@ pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2; /// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0; +pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1; unsafe extern "C" { // --- Device --- diff --git a/crates/xserv-cuda/src/graph.rs b/crates/xserv-cuda/src/graph.rs index c002b45..748bcbb 100644 --- a/crates/xserv-cuda/src/graph.rs +++ b/crates/xserv-cuda/src/graph.rs @@ -50,10 +50,13 @@ impl CudaGraph { pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> { // If we have an old graph, destroy it first self.destroy_inner(); + // THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.) + // made by THIS thread invalidate the capture. With GLOBAL mode, TP rank + // threads capturing concurrently would poison each other's captures. error::check(unsafe { ffi::cudaStreamBeginCapture( stream.as_raw(), - ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL, + ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL, ) }) } diff --git a/crates/xserv-cuda/src/lib.rs b/crates/xserv-cuda/src/lib.rs index 2dc2fbd..adb3811 100644 --- a/crates/xserv-cuda/src/lib.rs +++ b/crates/xserv-cuda/src/lib.rs @@ -11,4 +11,4 @@ pub use device::DeviceInfo; pub use error::{CudaError, Result}; pub use graph::CudaGraph; pub use memory::{GpuBuffer, PinnedBuffer}; -pub use stream::CudaStream; +pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard}; diff --git a/crates/xserv-cuda/src/stream.rs b/crates/xserv-cuda/src/stream.rs index aa85c22..d452a75 100644 --- a/crates/xserv-cuda/src/stream.rs +++ b/crates/xserv-cuda/src/stream.rs @@ -31,3 +31,39 @@ impl Drop for CudaStream { // Can move across threads, but not shared without synchronization unsafe impl Send for CudaStream {} + +// --- Thread-local launch stream ------------------------------------------- +// +// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`, +// which defaults to the legacy null stream (the historical behavior). CUDA +// graph capture requires work to be issued on an explicit stream, so capture +// code installs its stream here for the duration of the captured region via +// `push_stream` / `StreamGuard`. + +use std::cell::Cell; + +thread_local! { + static CURRENT_STREAM: Cell = const { Cell::new(std::ptr::null_mut()) }; +} + +/// The stream kernel launches on this thread should use (null = legacy default). +pub fn current_stream_raw() -> ffi::CudaStream { + CURRENT_STREAM.with(|c| c.get()) +} + +/// RAII guard that installs a launch stream for the current thread and +/// restores the previous one on drop. +pub struct StreamGuard { + prev: ffi::CudaStream, +} + +pub fn push_stream(stream: &CudaStream) -> StreamGuard { + let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw())); + StreamGuard { prev } +} + +impl Drop for StreamGuard { + fn drop(&mut self) { + CURRENT_STREAM.with(|c| c.set(self.prev)); + } +} diff --git a/crates/xserv-distributed/src/lib.rs b/crates/xserv-distributed/src/lib.rs index 11f6f5c..1608241 100644 --- a/crates/xserv-distributed/src/lib.rs +++ b/crates/xserv-distributed/src/lib.rs @@ -14,10 +14,13 @@ use xserv_cuda::GpuBuffer; pub use ffi::NcclUniqueId as UniqueId; -/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run -/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered -/// after the producing matmul and before the consuming kernel — no extra sync. -const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut(); +/// NCCL is issued on the thread's current launch stream (legacy null stream +/// by default, the capture stream during CUDA graph capture). The model's +/// kernels run on the same stream, so AllReduce stays correctly ordered after +/// the producing matmul and before the consuming kernel — no extra sync. +fn launch_stream() -> xserv_cuda::ffi::CudaStream { + xserv_cuda::stream::current_stream_raw() +} /// Generate a unique id on one rank (typically rank 0) and broadcast the bytes /// to all ranks out-of-band (e.g. via a shared variable across threads). @@ -80,7 +83,7 @@ impl TpContext { ffi::NCCL_BF16, ffi::NCCL_SUM, self.comm, - NULL_STREAM, + launch_stream(), ) }, "ncclAllReduce", @@ -135,7 +138,7 @@ impl PpContext { /// `ptr` must point to at least `count` BF16 elements of valid device memory. pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) { ffi::check( - unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) }, + unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) }, "ncclSend", ); } @@ -146,7 +149,7 @@ impl PpContext { /// `ptr` must point to at least `count` BF16 elements of valid device memory. pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) { ffi::check( - unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) }, + unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) }, "ncclRecv", ); } diff --git a/crates/xserv-kernels/src/activation.rs b/crates/xserv-kernels/src/activation.rs index 2477841..288804d 100644 --- a/crates/xserv-kernels/src/activation.rs +++ b/crates/xserv-kernels/src/activation.rs @@ -28,8 +28,8 @@ fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c let n = n as i32; unsafe { match x.dtype() { - DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), - DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), + DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), + DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), _ => panic!("unsupported dtype"), } } @@ -49,8 +49,8 @@ fn dispatch_binary(a: &Tensor, b: &Tensor, let n = n as i32; unsafe { match a.dtype() { - DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), - DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), + DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), + DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()), _ => panic!("unsupported dtype"), } } @@ -68,8 +68,8 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor { let n = n as i32; unsafe { match x.dtype() { - DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()), - DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()), + DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()), + DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()), _ => panic!("unsupported dtype for scale"), } } @@ -95,7 +95,7 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor { unsafe { launch_bias_add_2d_bf16( x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, cols as i32, std::ptr::null_mut(), + rows as i32, cols as i32, xserv_cuda::current_stream_raw(), ); } out @@ -118,7 +118,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor { up.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, n, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } out @@ -146,7 +146,7 @@ pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor { n_elements, alpha, limit, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } out diff --git a/crates/xserv-kernels/src/argmax.rs b/crates/xserv-kernels/src/argmax.rs index 81c42c2..3d2765b 100644 --- a/crates/xserv-kernels/src/argmax.rs +++ b/crates/xserv-kernels/src/argmax.rs @@ -36,7 +36,7 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec { logits.data_ptr() as *const c_void, out.as_mut_ptr() as *mut c_void, rows as i32, cols as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } diff --git a/crates/xserv-kernels/src/attention.rs b/crates/xserv-kernels/src/attention.rs index 3da3e03..61f41b4 100644 --- a/crates/xserv-kernels/src/attention.rs +++ b/crates/xserv-kernels/src/attention.rs @@ -144,12 +144,12 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) { DType::F32 => launch_causal_mask_f32( scores.data_ptr() as *mut c_void, batch as i32, rows as i32, cols as i32, offset as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_causal_mask_bf16( scores.data_ptr() as *mut c_void, batch as i32, rows as i32, cols as i32, offset as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for causal mask"), } @@ -233,7 +233,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor { head_dim as i32, scale, 1, // causal (always 1 for decode) - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -295,7 +295,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens head_dim as i32, scale, if causal { 1 } else { 0 }, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -354,7 +354,7 @@ pub fn flash_attention_sinks( scale, 1, // always causal window_size as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -409,7 +409,7 @@ pub fn paged_decode_attention( head_dim as i32, max_blocks_per_seq as i32, scale, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -464,7 +464,7 @@ pub fn paged_decode_attention_sinks( max_blocks_per_seq as i32, scale, window_size as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } diff --git a/crates/xserv-kernels/src/embedding.rs b/crates/xserv-kernels/src/embedding.rs index 6655e37..0cc2262 100644 --- a/crates/xserv-kernels/src/embedding.rs +++ b/crates/xserv-kernels/src/embedding.rs @@ -35,19 +35,32 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor { assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})"); } + embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens) +} + +/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]). +/// Used by the CUDA-graph decode path, where ids live in a persistent device +/// buffer updated outside the captured region (no bounds check possible here). +pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor { + assert_eq!(table.ndim(), 2); + assert!(table.is_contiguous()); + assert!(matches!(table.device(), Device::Cuda(_))); + let hidden_size = table.shape()[1]; + let vocab_size = table.shape()[0]; + let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device()); unsafe { match table.dtype() { DType::F32 => launch_embedding_f32( - table.data_ptr() as _, ids_gpu.as_ptr() as _, + table.data_ptr() as _, ids_gpu, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(), + num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_embedding_bf16( - table.data_ptr() as _, ids_gpu.as_ptr() as _, + table.data_ptr() as _, ids_gpu, out.data_ptr() as *mut c_void, - num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(), + num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for embedding"), } diff --git a/crates/xserv-kernels/src/gemm.rs b/crates/xserv-kernels/src/gemm.rs index 5a9ad1a..88e743f 100644 --- a/crates/xserv-kernels/src/gemm.rs +++ b/crates/xserv-kernels/src/gemm.rs @@ -151,7 +151,7 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor { let a_ptr = a.data_ptr() as *const c_void; let b_ptr = b.data_ptr() as *const c_void; let c_ptr = c.data_ptr() as *mut c_void; - let null_stream = std::ptr::null_mut(); + let null_stream = xserv_cuda::current_stream_raw(); match backend { GemmBackend::Naive => { @@ -260,7 +260,7 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor { let stride_c = (m * n) as i64; with_cublas(|handle| unsafe { - cublasSetStream_v2(handle, std::ptr::null_mut()); + cublasSetStream_v2(handle, xserv_cuda::current_stream_raw()); // Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major) error::check(cublasGemmStridedBatchedEx( handle, diff --git a/crates/xserv-kernels/src/layernorm.rs b/crates/xserv-kernels/src/layernorm.rs index 08e3c7f..e03ce51 100644 --- a/crates/xserv-kernels/src/layernorm.rs +++ b/crates/xserv-kernels/src/layernorm.rs @@ -26,12 +26,12 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor DType::F32 => launch_layernorm_f32( x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), + rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_layernorm_bf16( x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), + rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for layernorm"), } diff --git a/crates/xserv-kernels/src/lib.rs b/crates/xserv-kernels/src/lib.rs index eeddd54..584fa8d 100644 --- a/crates/xserv-kernels/src/lib.rs +++ b/crates/xserv-kernels/src/lib.rs @@ -16,11 +16,11 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu pub use argmax::{argmax_bf16_single, argmax_bf16_to_host}; pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu}; pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16}; -pub use embedding::embedding; +pub use embedding::{embedding, embedding_device_ids}; pub use gemm::{batched_matmul, matmul, GemmBackend}; pub use layernorm::layernorm; pub use rmsnorm::{add_rmsnorm, rmsnorm}; -pub use rope::{rope_inplace, RopeCache}; +pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache}; pub use softmax::softmax; /// Register GPU kernels with the tensor crate. Call once at startup. diff --git a/crates/xserv-kernels/src/moe.rs b/crates/xserv-kernels/src/moe.rs index 4dd37e1..38bb1e2 100644 --- a/crates/xserv-kernels/src/moe.rs +++ b/crates/xserv-kernels/src/moe.rs @@ -100,7 +100,7 @@ pub fn moe_topk_softmax( topk_ids.data_ptr() as *mut c_void, topk_weights.data_ptr() as *mut c_void, num_tokens as i32, num_experts as i32, top_k as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -121,7 +121,7 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor { x.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, num_tokens as i32, hidden as i32, local_experts as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -144,7 +144,7 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) { x.data_ptr() as *mut c_void, bias.data_ptr() as *const c_void, batch as i32, num_tokens as i32, dim as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } } @@ -177,7 +177,7 @@ pub fn moe_weighted_sum( out.data_ptr() as *mut c_void, num_tokens as i32, hidden as i32, top_k as i32, expert_start as i32, local_experts as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -224,7 +224,7 @@ pub fn moe_sparse_gemv_fp8( y.data_ptr() as *mut c_void, num_tokens as i32, n as i32, k as i32, top_k as i32, expert_start as i32, local_experts as i32, x_per_slot as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } y @@ -256,7 +256,7 @@ pub fn moe_sparse_gemv_mxfp4( y.data_ptr() as *mut c_void, num_tokens as i32, n as i32, k as i32, top_k as i32, expert_start as i32, local_experts as i32, x_per_slot as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } y @@ -288,7 +288,7 @@ pub fn moe_weighted_sum_sparse( out.data_ptr() as *mut c_void, num_tokens as i32, hidden as i32, top_k as i32, expert_start as i32, local_experts as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } out @@ -338,7 +338,7 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor { let handle = cublas_handle(); unsafe { - cublasSetStream_v2(handle, std::ptr::null_mut()); + cublasSetStream_v2(handle, xserv_cuda::current_stream_raw()); let status = cublasGemmStridedBatchedEx( handle, 0, 0, // CUBLAS_OP_N, CUBLAS_OP_N diff --git a/crates/xserv-kernels/src/quantization.rs b/crates/xserv-kernels/src/quantization.rs index a937926..6d40520 100644 --- a/crates/xserv-kernels/src/quantization.rs +++ b/crates/xserv-kernels/src/quantization.rs @@ -300,7 +300,7 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor { scales.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, num_experts as i32, rows as i32, cols as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -330,7 +330,7 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) { fp8_out.data_ptr() as *mut c_void, scales.data_ptr() as *mut c_void, num_rows as i32, cols as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -406,7 +406,7 @@ pub fn batched_gemm_fp8( &plan.algo, ws_ptr, plan.workspace_size, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}"); } @@ -424,7 +424,7 @@ pub fn batched_gemm_fp8( a_scales.data_ptr() as *const c_void, b_scales.data_ptr() as *const c_void, total_rows, n as i32, m as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } @@ -456,7 +456,7 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u w_scales.data_ptr() as *const c_void, y.data_ptr() as *mut c_void, e as i32, n as i32, k as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } y @@ -472,7 +472,7 @@ pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n w_scales.data_ptr() as *const c_void, out.data_ptr() as *mut c_void, e as i32, n as i32, k as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } out diff --git a/crates/xserv-kernels/src/rmsnorm.rs b/crates/xserv-kernels/src/rmsnorm.rs index 0d56f78..ad4e981 100644 --- a/crates/xserv-kernels/src/rmsnorm.rs +++ b/crates/xserv-kernels/src/rmsnorm.rs @@ -28,11 +28,11 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor { match x.dtype() { DType::F32 => launch_rmsnorm_f32( x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), + rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_rmsnorm_bf16( x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), + rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for rmsnorm"), } @@ -71,7 +71,7 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> ( rows as i32, hidden_size as i32, eps, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } diff --git a/crates/xserv-kernels/src/rope.rs b/crates/xserv-kernels/src/rope.rs index 3a47b42..829fd6c 100644 --- a/crates/xserv-kernels/src/rope.rs +++ b/crates/xserv-kernels/src/rope.rs @@ -31,7 +31,7 @@ impl RopeCache { unsafe { launch_compute_rope_cache( cos.as_mut_ptr() as _, sin.as_mut_ptr() as _, - max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(), + max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(), ); } @@ -136,21 +136,36 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) { let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions"); pos_gpu.copy_from_host(pos_bytes).unwrap(); + rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void); +} + +/// RoPE in-place with positions already on the GPU (u32, [num_tokens]). +/// Used by the CUDA-graph decode path, where the position lives in a +/// persistent device buffer updated outside the captured region. +pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) { + assert_eq!(x.ndim(), 3); + assert!(x.is_contiguous()); + assert!(matches!(x.device(), Device::Cuda(_))); + let num_tokens = x.shape()[0]; + let num_heads = x.shape()[1]; + let head_dim = x.shape()[2]; + assert_eq!(head_dim / 2, cache.half_dim); + unsafe { match x.dtype() { DType::F32 => launch_rope_f32( x.data_ptr() as *mut c_void, cache.cos.as_ptr() as _, cache.sin.as_ptr() as _, - pos_gpu.as_ptr() as _, + pos_gpu, num_tokens as i32, num_heads as i32, head_dim as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_rope_bf16( x.data_ptr() as *mut c_void, cache.cos.as_ptr() as _, cache.sin.as_ptr() as _, - pos_gpu.as_ptr() as _, + pos_gpu, num_tokens as i32, num_heads as i32, head_dim as i32, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for rope"), } diff --git a/crates/xserv-kernels/src/softmax.rs b/crates/xserv-kernels/src/softmax.rs index fe2f622..e83c9be 100644 --- a/crates/xserv-kernels/src/softmax.rs +++ b/crates/xserv-kernels/src/softmax.rs @@ -22,11 +22,11 @@ pub fn softmax(x: &Tensor) -> Tensor { match x.dtype() { DType::F32 => launch_softmax_f32( x.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, cols as i32, std::ptr::null_mut(), + rows as i32, cols as i32, xserv_cuda::current_stream_raw(), ), DType::BF16 => launch_softmax_bf16( x.data_ptr() as _, out.data_ptr() as *mut c_void, - rows as i32, cols as i32, std::ptr::null_mut(), + rows as i32, cols as i32, xserv_cuda::current_stream_raw(), ), _ => panic!("unsupported dtype for softmax"), } diff --git a/crates/xserv-kernels/src/transpose.rs b/crates/xserv-kernels/src/transpose.rs index 24d7392..0260801 100644 --- a/crates/xserv-kernels/src/transpose.rs +++ b/crates/xserv-kernels/src/transpose.rs @@ -25,7 +25,7 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: unsafe { launch_reshape_heads_bf16( x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), + seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), ); } out @@ -40,7 +40,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u unsafe { launch_merge_heads_bf16( x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), + seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), ); } out @@ -54,7 +54,7 @@ pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head unsafe { launch_transpose_hsd_to_shd_bf16( x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), + seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), ); } out @@ -68,7 +68,7 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea unsafe { launch_transpose_shd_to_hsd_bf16( x.data_ptr() as _, out.data_ptr() as *mut c_void, - seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), + seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(), ); } out @@ -87,7 +87,7 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor { unsafe { launch_repeat_kv_bf16( x.data_ptr() as _, out.data_ptr() as *mut c_void, - kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(), + kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(), ); } out @@ -126,14 +126,14 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor { numel as i32, ndim as i32, shape4[0], shape4[1], shape4[2], shape4[3], strides4[0], strides4[1], strides4[2], strides4[3], - in_offset, std::ptr::null_mut(), + in_offset, xserv_cuda::current_stream_raw(), ), DType::F32 => launch_strided_copy_f32( storage_ptr as _, out.data_ptr() as *mut c_void, numel as i32, ndim as i32, shape4[0], shape4[1], shape4[2], shape4[3], strides4[0], strides4[1], strides4[2], strides4[3], - in_offset, std::ptr::null_mut(), + in_offset, xserv_cuda::current_stream_raw(), ), _ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()), } diff --git a/crates/xserv-model/src/paged_kv_cache.rs b/crates/xserv-model/src/paged_kv_cache.rs index 5871abc..9b34dbb 100644 --- a/crates/xserv-model/src/paged_kv_cache.rs +++ b/crates/xserv-model/src/paged_kv_cache.rs @@ -349,7 +349,7 @@ impl PagedKVCache { k_pool_ptr, v_pool_ptr, block_ids_gpu.as_ptr() as *const i32, num_tokens, nkv, hd, start_pos, bs, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } // block_ids_gpu drops here; the launch on the null stream will have @@ -397,7 +397,7 @@ impl PagedKVCache { k_pool_ptr, v_pool_ptr, bt_ptr, cl_ptr, batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq, - std::ptr::null_mut(), + xserv_cuda::current_stream_raw(), ); } }