cuda: infrastructure for whole-step CUDA graph capture

- Thread-local launch stream (xserv_cuda::stream): every kernel
  wrapper, cublasSetStream, and NCCL collective now launches on
  current_stream_raw() — the legacy null stream by default (behavior
  unchanged), or the capture stream installed via push_stream during
  graph capture. Capture is impossible on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are
  quarantined (RetainedBlocks) instead of pooled, so an instantiated
  graph keeps exclusive ownership of every intermediate buffer it
  references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads
  must not poison each other's captures with their own cudaMallocs.
- embedding_device_ids / rope_inplace_device_pos: variants reading
  token ids / positions from persistent device buffers, replacing the
  per-call host upload that a captured region cannot contain.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
2026-06-12 20:12:37 +08:00
parent 2a92f268a9
commit 4088f49b7d
20 changed files with 191 additions and 69 deletions

View File

@@ -111,6 +111,22 @@ pub fn cached_trim() {
/// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer /// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
/// and size to avoid re-triggering Drop. /// and size to avoid re-triggering Drop.
pub fn return_to_pool(ptr: *mut u8, len: usize) { pub fn return_to_pool(ptr: *mut u8, len: usize) {
// During CUDA graph capture, buffers freed by the captured code are
// quarantined instead of pooled: the instantiated graph references their
// addresses on every replay, so they must never be handed to another
// consumer for as long as the graph lives.
let quarantined = RETAINED.with(|cell| {
let mut r = cell.borrow_mut();
if let Some(list) = r.as_mut() {
list.push((ptr, len));
true
} else {
false
}
});
if quarantined {
return;
}
ALLOCATOR.with(|cell| { ALLOCATOR.with(|cell| {
let mut alloc = cell.borrow_mut(); let mut alloc = cell.borrow_mut();
let bucket = bucket_size(len); let bucket = bucket_size(len);
@@ -119,6 +135,41 @@ pub fn return_to_pool(ptr: *mut u8, len: usize) {
}); });
} }
thread_local! {
static RETAINED: RefCell<Option<Vec<(*mut u8, usize)>>> = const { RefCell::new(None) };
}
/// Buffers freed while a retain window was active. Holding this keeps their
/// memory out of the pool; dropping it returns the blocks (on the owning
/// thread) for reuse.
pub struct RetainedBlocks(Vec<(*mut u8, usize)>);
impl Drop for RetainedBlocks {
fn drop(&mut self) {
for (ptr, len) in self.0.drain(..) {
return_to_pool(ptr, len);
}
}
}
/// Start quarantining buffers freed on this thread (see `return_to_pool`).
/// Must be paired with `end_retain` on the same thread; nesting unsupported.
pub fn begin_retain() {
RETAINED.with(|cell| {
let mut r = cell.borrow_mut();
assert!(r.is_none(), "begin_retain: retain window already active");
*r = Some(Vec::new());
});
}
/// Stop quarantining and hand the quarantined blocks to the caller.
pub fn end_retain() -> RetainedBlocks {
RETAINED.with(|cell| {
let list = cell.borrow_mut().take().expect("end_retain without begin_retain");
RetainedBlocks(list)
})
}
/// Round up to next power-of-2, minimum 512 bytes. /// Round up to next power-of-2, minimum 512 bytes.
fn bucket_size(size: usize) -> usize { fn bucket_size(size: usize) -> usize {
let min = 512; let min = 512;

View File

@@ -15,6 +15,7 @@ pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal /// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0; pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1;
unsafe extern "C" { unsafe extern "C" {
// --- Device --- // --- Device ---

View File

@@ -50,10 +50,13 @@ impl CudaGraph {
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> { pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
// If we have an old graph, destroy it first // If we have an old graph, destroy it first
self.destroy_inner(); self.destroy_inner();
// THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.)
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
// threads capturing concurrently would poison each other's captures.
error::check(unsafe { error::check(unsafe {
ffi::cudaStreamBeginCapture( ffi::cudaStreamBeginCapture(
stream.as_raw(), stream.as_raw(),
ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL, ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL,
) )
}) })
} }

View File

@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
pub use error::{CudaError, Result}; pub use error::{CudaError, Result};
pub use graph::CudaGraph; pub use graph::CudaGraph;
pub use memory::{GpuBuffer, PinnedBuffer}; pub use memory::{GpuBuffer, PinnedBuffer};
pub use stream::CudaStream; pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard};

View File

@@ -31,3 +31,39 @@ impl Drop for CudaStream {
// Can move across threads, but not shared without synchronization // Can move across threads, but not shared without synchronization
unsafe impl Send for CudaStream {} unsafe impl Send for CudaStream {}
// --- Thread-local launch stream -------------------------------------------
//
// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`,
// which defaults to the legacy null stream (the historical behavior). CUDA
// graph capture requires work to be issued on an explicit stream, so capture
// code installs its stream here for the duration of the captured region via
// `push_stream` / `StreamGuard`.
use std::cell::Cell;
thread_local! {
static CURRENT_STREAM: Cell<ffi::CudaStream> = const { Cell::new(std::ptr::null_mut()) };
}
/// The stream kernel launches on this thread should use (null = legacy default).
pub fn current_stream_raw() -> ffi::CudaStream {
CURRENT_STREAM.with(|c| c.get())
}
/// RAII guard that installs a launch stream for the current thread and
/// restores the previous one on drop.
pub struct StreamGuard {
prev: ffi::CudaStream,
}
pub fn push_stream(stream: &CudaStream) -> StreamGuard {
let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw()));
StreamGuard { prev }
}
impl Drop for StreamGuard {
fn drop(&mut self) {
CURRENT_STREAM.with(|c| c.set(self.prev));
}
}

View File

@@ -14,10 +14,13 @@ use xserv_cuda::GpuBuffer;
pub use ffi::NcclUniqueId as UniqueId; pub use ffi::NcclUniqueId as UniqueId;
/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run /// NCCL is issued on the thread's current launch stream (legacy null stream
/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered /// by default, the capture stream during CUDA graph capture). The model's
/// after the producing matmul and before the consuming kernel — no extra sync. /// kernels run on the same stream, so AllReduce stays correctly ordered after
const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut(); /// the producing matmul and before the consuming kernel — no extra sync.
fn launch_stream() -> xserv_cuda::ffi::CudaStream {
xserv_cuda::stream::current_stream_raw()
}
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes /// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
/// to all ranks out-of-band (e.g. via a shared variable across threads). /// to all ranks out-of-band (e.g. via a shared variable across threads).
@@ -80,7 +83,7 @@ impl TpContext {
ffi::NCCL_BF16, ffi::NCCL_BF16,
ffi::NCCL_SUM, ffi::NCCL_SUM,
self.comm, self.comm,
NULL_STREAM, launch_stream(),
) )
}, },
"ncclAllReduce", "ncclAllReduce",
@@ -135,7 +138,7 @@ impl PpContext {
/// `ptr` must point to at least `count` BF16 elements of valid device memory. /// `ptr` must point to at least `count` BF16 elements of valid device memory.
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) { pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
ffi::check( ffi::check(
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) }, unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
"ncclSend", "ncclSend",
); );
} }
@@ -146,7 +149,7 @@ impl PpContext {
/// `ptr` must point to at least `count` BF16 elements of valid device memory. /// `ptr` must point to at least `count` BF16 elements of valid device memory.
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) { pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
ffi::check( ffi::check(
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) }, unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
"ncclRecv", "ncclRecv",
); );
} }

View File

@@ -28,8 +28,8 @@ fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c
let n = n as i32; let n = n as i32;
unsafe { unsafe {
match x.dtype() { match x.dtype() {
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
_ => panic!("unsupported dtype"), _ => panic!("unsupported dtype"),
} }
} }
@@ -49,8 +49,8 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
let n = n as i32; let n = n as i32;
unsafe { unsafe {
match a.dtype() { match a.dtype() {
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()), DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
_ => panic!("unsupported dtype"), _ => panic!("unsupported dtype"),
} }
} }
@@ -68,8 +68,8 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
let n = n as i32; let n = n as i32;
unsafe { unsafe {
match x.dtype() { match x.dtype() {
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()), DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()), DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
_ => panic!("unsupported dtype for scale"), _ => panic!("unsupported dtype for scale"),
} }
} }
@@ -95,7 +95,7 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
unsafe { unsafe {
launch_bias_add_2d_bf16( launch_bias_add_2d_bf16(
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(), rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -118,7 +118,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
up.data_ptr() as *const c_void, up.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
n, n,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -146,7 +146,7 @@ pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
n_elements, n_elements,
alpha, alpha,
limit, limit,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
out out

View File

@@ -36,7 +36,7 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
logits.data_ptr() as *const c_void, logits.data_ptr() as *const c_void,
out.as_mut_ptr() as *mut c_void, out.as_mut_ptr() as *mut c_void,
rows as i32, cols as i32, rows as i32, cols as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }

View File

@@ -144,12 +144,12 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
DType::F32 => launch_causal_mask_f32( DType::F32 => launch_causal_mask_f32(
scores.data_ptr() as *mut c_void, scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32, batch as i32, rows as i32, cols as i32, offset as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
), ),
DType::BF16 => launch_causal_mask_bf16( DType::BF16 => launch_causal_mask_bf16(
scores.data_ptr() as *mut c_void, scores.data_ptr() as *mut c_void,
batch as i32, rows as i32, cols as i32, offset as i32, batch as i32, rows as i32, cols as i32, offset as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
), ),
_ => panic!("unsupported dtype for causal mask"), _ => panic!("unsupported dtype for causal mask"),
} }
@@ -233,7 +233,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
head_dim as i32, head_dim as i32,
scale, scale,
1, // causal (always 1 for decode) 1, // causal (always 1 for decode)
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -295,7 +295,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
head_dim as i32, head_dim as i32,
scale, scale,
if causal { 1 } else { 0 }, if causal { 1 } else { 0 },
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -354,7 +354,7 @@ pub fn flash_attention_sinks(
scale, scale,
1, // always causal 1, // always causal
window_size as i32, window_size as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -409,7 +409,7 @@ pub fn paged_decode_attention(
head_dim as i32, head_dim as i32,
max_blocks_per_seq as i32, max_blocks_per_seq as i32,
scale, scale,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -464,7 +464,7 @@ pub fn paged_decode_attention_sinks(
max_blocks_per_seq as i32, max_blocks_per_seq as i32,
scale, scale,
window_size as i32, window_size as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }

View File

@@ -35,19 +35,32 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})"); assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
} }
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
}
/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]).
/// Used by the CUDA-graph decode path, where ids live in a persistent device
/// buffer updated outside the captured region (no bounds check possible here).
pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor {
assert_eq!(table.ndim(), 2);
assert!(table.is_contiguous());
assert!(matches!(table.device(), Device::Cuda(_)));
let hidden_size = table.shape()[1];
let vocab_size = table.shape()[0];
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device()); let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
unsafe { unsafe {
match table.dtype() { match table.dtype() {
DType::F32 => launch_embedding_f32( DType::F32 => launch_embedding_f32(
table.data_ptr() as _, ids_gpu.as_ptr() as _, table.data_ptr() as _, ids_gpu,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(), num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
), ),
DType::BF16 => launch_embedding_bf16( DType::BF16 => launch_embedding_bf16(
table.data_ptr() as _, ids_gpu.as_ptr() as _, table.data_ptr() as _, ids_gpu,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(), num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
), ),
_ => panic!("unsupported dtype for embedding"), _ => panic!("unsupported dtype for embedding"),
} }

View File

@@ -151,7 +151,7 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
let a_ptr = a.data_ptr() as *const c_void; let a_ptr = a.data_ptr() as *const c_void;
let b_ptr = b.data_ptr() as *const c_void; let b_ptr = b.data_ptr() as *const c_void;
let c_ptr = c.data_ptr() as *mut c_void; let c_ptr = c.data_ptr() as *mut c_void;
let null_stream = std::ptr::null_mut(); let null_stream = xserv_cuda::current_stream_raw();
match backend { match backend {
GemmBackend::Naive => { GemmBackend::Naive => {
@@ -260,7 +260,7 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
let stride_c = (m * n) as i64; let stride_c = (m * n) as i64;
with_cublas(|handle| unsafe { with_cublas(|handle| unsafe {
cublasSetStream_v2(handle, std::ptr::null_mut()); cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major) // Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
error::check(cublasGemmStridedBatchedEx( error::check(cublasGemmStridedBatchedEx(
handle, handle,

View File

@@ -26,12 +26,12 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
DType::F32 => launch_layernorm_f32( DType::F32 => launch_layernorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _, x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
), ),
DType::BF16 => launch_layernorm_bf16( DType::BF16 => launch_layernorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _, x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
), ),
_ => panic!("unsupported dtype for layernorm"), _ => panic!("unsupported dtype for layernorm"),
} }

View File

@@ -16,11 +16,11 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host}; pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu}; pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16}; pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
pub use embedding::embedding; pub use embedding::{embedding, embedding_device_ids};
pub use gemm::{batched_matmul, matmul, GemmBackend}; pub use gemm::{batched_matmul, matmul, GemmBackend};
pub use layernorm::layernorm; pub use layernorm::layernorm;
pub use rmsnorm::{add_rmsnorm, rmsnorm}; pub use rmsnorm::{add_rmsnorm, rmsnorm};
pub use rope::{rope_inplace, RopeCache}; pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache};
pub use softmax::softmax; pub use softmax::softmax;
/// Register GPU kernels with the tensor crate. Call once at startup. /// Register GPU kernels with the tensor crate. Call once at startup.

View File

@@ -100,7 +100,7 @@ pub fn moe_topk_softmax(
topk_ids.data_ptr() as *mut c_void, topk_ids.data_ptr() as *mut c_void,
topk_weights.data_ptr() as *mut c_void, topk_weights.data_ptr() as *mut c_void,
num_tokens as i32, num_experts as i32, top_k as i32, num_tokens as i32, num_experts as i32, top_k as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -121,7 +121,7 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
x.data_ptr() as *const c_void, x.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, local_experts as i32, num_tokens as i32, hidden as i32, local_experts as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -144,7 +144,7 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
x.data_ptr() as *mut c_void, x.data_ptr() as *mut c_void,
bias.data_ptr() as *const c_void, bias.data_ptr() as *const c_void,
batch as i32, num_tokens as i32, dim as i32, batch as i32, num_tokens as i32, dim as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
} }
@@ -177,7 +177,7 @@ pub fn moe_weighted_sum(
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, top_k as i32, num_tokens as i32, hidden as i32, top_k as i32,
expert_start as i32, local_experts as i32, expert_start as i32, local_experts as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -224,7 +224,7 @@ pub fn moe_sparse_gemv_fp8(
y.data_ptr() as *mut c_void, y.data_ptr() as *mut c_void,
num_tokens as i32, n as i32, k as i32, top_k as i32, num_tokens as i32, n as i32, k as i32, top_k as i32,
expert_start as i32, local_experts as i32, x_per_slot as i32, expert_start as i32, local_experts as i32, x_per_slot as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
y y
@@ -256,7 +256,7 @@ pub fn moe_sparse_gemv_mxfp4(
y.data_ptr() as *mut c_void, y.data_ptr() as *mut c_void,
num_tokens as i32, n as i32, k as i32, top_k as i32, num_tokens as i32, n as i32, k as i32, top_k as i32,
expert_start as i32, local_experts as i32, x_per_slot as i32, expert_start as i32, local_experts as i32, x_per_slot as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
y y
@@ -288,7 +288,7 @@ pub fn moe_weighted_sum_sparse(
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
num_tokens as i32, hidden as i32, top_k as i32, num_tokens as i32, hidden as i32, top_k as i32,
expert_start as i32, local_experts as i32, expert_start as i32, local_experts as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -338,7 +338,7 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
let handle = cublas_handle(); let handle = cublas_handle();
unsafe { unsafe {
cublasSetStream_v2(handle, std::ptr::null_mut()); cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
let status = cublasGemmStridedBatchedEx( let status = cublasGemmStridedBatchedEx(
handle, handle,
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N 0, 0, // CUBLAS_OP_N, CUBLAS_OP_N

View File

@@ -300,7 +300,7 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
scales.data_ptr() as *const c_void, scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
num_experts as i32, rows as i32, cols as i32, num_experts as i32, rows as i32, cols as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -330,7 +330,7 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
fp8_out.data_ptr() as *mut c_void, fp8_out.data_ptr() as *mut c_void,
scales.data_ptr() as *mut c_void, scales.data_ptr() as *mut c_void,
num_rows as i32, cols as i32, num_rows as i32, cols as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -406,7 +406,7 @@ pub fn batched_gemm_fp8(
&plan.algo, &plan.algo,
ws_ptr, ws_ptr,
plan.workspace_size, plan.workspace_size,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}"); assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
} }
@@ -424,7 +424,7 @@ pub fn batched_gemm_fp8(
a_scales.data_ptr() as *const c_void, a_scales.data_ptr() as *const c_void,
b_scales.data_ptr() as *const c_void, b_scales.data_ptr() as *const c_void,
total_rows, n as i32, m as i32, total_rows, n as i32, m as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
@@ -456,7 +456,7 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
w_scales.data_ptr() as *const c_void, w_scales.data_ptr() as *const c_void,
y.data_ptr() as *mut c_void, y.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32, e as i32, n as i32, k as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
y y
@@ -472,7 +472,7 @@ pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n
w_scales.data_ptr() as *const c_void, w_scales.data_ptr() as *const c_void,
out.data_ptr() as *mut c_void, out.data_ptr() as *mut c_void,
e as i32, n as i32, k as i32, e as i32, n as i32, k as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
out out

View File

@@ -28,11 +28,11 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
match x.dtype() { match x.dtype() {
DType::F32 => launch_rmsnorm_f32( DType::F32 => launch_rmsnorm_f32(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
), ),
DType::BF16 => launch_rmsnorm_bf16( DType::BF16 => launch_rmsnorm_bf16(
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(), rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
), ),
_ => panic!("unsupported dtype for rmsnorm"), _ => panic!("unsupported dtype for rmsnorm"),
} }
@@ -71,7 +71,7 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
rows as i32, rows as i32,
hidden_size as i32, hidden_size as i32,
eps, eps,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }

View File

@@ -31,7 +31,7 @@ impl RopeCache {
unsafe { unsafe {
launch_compute_rope_cache( launch_compute_rope_cache(
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _, cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(), max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(),
); );
} }
@@ -136,21 +136,36 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions"); let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
pos_gpu.copy_from_host(pos_bytes).unwrap(); pos_gpu.copy_from_host(pos_bytes).unwrap();
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
}
/// RoPE in-place with positions already on the GPU (u32, [num_tokens]).
/// Used by the CUDA-graph decode path, where the position lives in a
/// persistent device buffer updated outside the captured region.
pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) {
assert_eq!(x.ndim(), 3);
assert!(x.is_contiguous());
assert!(matches!(x.device(), Device::Cuda(_)));
let num_tokens = x.shape()[0];
let num_heads = x.shape()[1];
let head_dim = x.shape()[2];
assert_eq!(head_dim / 2, cache.half_dim);
unsafe { unsafe {
match x.dtype() { match x.dtype() {
DType::F32 => launch_rope_f32( DType::F32 => launch_rope_f32(
x.data_ptr() as *mut c_void, x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _, cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
pos_gpu.as_ptr() as _, pos_gpu,
num_tokens as i32, num_heads as i32, head_dim as i32, num_tokens as i32, num_heads as i32, head_dim as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
), ),
DType::BF16 => launch_rope_bf16( DType::BF16 => launch_rope_bf16(
x.data_ptr() as *mut c_void, x.data_ptr() as *mut c_void,
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _, cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
pos_gpu.as_ptr() as _, pos_gpu,
num_tokens as i32, num_heads as i32, head_dim as i32, num_tokens as i32, num_heads as i32, head_dim as i32,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
), ),
_ => panic!("unsupported dtype for rope"), _ => panic!("unsupported dtype for rope"),
} }

View File

@@ -22,11 +22,11 @@ pub fn softmax(x: &Tensor) -> Tensor {
match x.dtype() { match x.dtype() {
DType::F32 => launch_softmax_f32( DType::F32 => launch_softmax_f32(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(), rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
), ),
DType::BF16 => launch_softmax_bf16( DType::BF16 => launch_softmax_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
rows as i32, cols as i32, std::ptr::null_mut(), rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
), ),
_ => panic!("unsupported dtype for softmax"), _ => panic!("unsupported dtype for softmax"),
} }

View File

@@ -25,7 +25,7 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
unsafe { unsafe {
launch_reshape_heads_bf16( launch_reshape_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -40,7 +40,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
unsafe { unsafe {
launch_merge_heads_bf16( launch_merge_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -54,7 +54,7 @@ pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head
unsafe { unsafe {
launch_transpose_hsd_to_shd_bf16( launch_transpose_hsd_to_shd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -68,7 +68,7 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
unsafe { unsafe {
launch_transpose_shd_to_hsd_bf16( launch_transpose_shd_to_hsd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(), seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -87,7 +87,7 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
unsafe { unsafe {
launch_repeat_kv_bf16( launch_repeat_kv_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void, x.data_ptr() as _, out.data_ptr() as *mut c_void,
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(), kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
); );
} }
out out
@@ -126,14 +126,14 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
numel as i32, ndim as i32, numel as i32, ndim as i32,
shape4[0], shape4[1], shape4[2], shape4[3], shape4[0], shape4[1], shape4[2], shape4[3],
strides4[0], strides4[1], strides4[2], strides4[3], strides4[0], strides4[1], strides4[2], strides4[3],
in_offset, std::ptr::null_mut(), in_offset, xserv_cuda::current_stream_raw(),
), ),
DType::F32 => launch_strided_copy_f32( DType::F32 => launch_strided_copy_f32(
storage_ptr as _, out.data_ptr() as *mut c_void, storage_ptr as _, out.data_ptr() as *mut c_void,
numel as i32, ndim as i32, numel as i32, ndim as i32,
shape4[0], shape4[1], shape4[2], shape4[3], shape4[0], shape4[1], shape4[2], shape4[3],
strides4[0], strides4[1], strides4[2], strides4[3], strides4[0], strides4[1], strides4[2], strides4[3],
in_offset, std::ptr::null_mut(), in_offset, xserv_cuda::current_stream_raw(),
), ),
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()), _ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
} }

View File

@@ -349,7 +349,7 @@ impl PagedKVCache {
k_pool_ptr, v_pool_ptr, k_pool_ptr, v_pool_ptr,
block_ids_gpu.as_ptr() as *const i32, block_ids_gpu.as_ptr() as *const i32,
num_tokens, nkv, hd, start_pos, bs, num_tokens, nkv, hd, start_pos, bs,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
// block_ids_gpu drops here; the launch on the null stream will have // block_ids_gpu drops here; the launch on the null stream will have
@@ -397,7 +397,7 @@ impl PagedKVCache {
k_pool_ptr, v_pool_ptr, k_pool_ptr, v_pool_ptr,
bt_ptr, cl_ptr, bt_ptr, cl_ptr,
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq, batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
std::ptr::null_mut(), xserv_cuda::current_stream_raw(),
); );
} }
} }