cuda: infrastructure for whole-step CUDA graph capture
- Thread-local launch stream (xserv_cuda::stream): every kernel wrapper, cublasSetStream call, and NCCL collective now launches on current_stream_raw() — the legacy null stream by default (behavior unchanged), or the capture stream installed via push_stream during graph capture. Capture is impossible on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are quarantined (RetainedBlocks) instead of pooled, so an instantiated graph keeps exclusive ownership of every intermediate buffer it references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads must not poison each other's captures with their own cudaMallocs.
- embedding_device_ids / rope_inplace_device_pos: variants reading token ids / positions from persistent device buffers, replacing the per-call host upload that a captured region cannot contain.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -111,6 +111,22 @@ pub fn cached_trim() {
|
||||
/// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
|
||||
/// and size to avoid re-triggering Drop.
|
||||
pub fn return_to_pool(ptr: *mut u8, len: usize) {
|
||||
// During CUDA graph capture, buffers freed by the captured code are
|
||||
// quarantined instead of pooled: the instantiated graph references their
|
||||
// addresses on every replay, so they must never be handed to another
|
||||
// consumer for as long as the graph lives.
|
||||
let quarantined = RETAINED.with(|cell| {
|
||||
let mut r = cell.borrow_mut();
|
||||
if let Some(list) = r.as_mut() {
|
||||
list.push((ptr, len));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
if quarantined {
|
||||
return;
|
||||
}
|
||||
ALLOCATOR.with(|cell| {
|
||||
let mut alloc = cell.borrow_mut();
|
||||
let bucket = bucket_size(len);
|
||||
@@ -119,6 +135,41 @@ pub fn return_to_pool(ptr: *mut u8, len: usize) {
|
||||
});
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static RETAINED: RefCell<Option<Vec<(*mut u8, usize)>>> = const { RefCell::new(None) };
|
||||
}
|
||||
|
||||
/// Buffers freed while a retain window was active. Holding this keeps their
|
||||
/// memory out of the pool; dropping it returns the blocks (on the owning
|
||||
/// thread) for reuse.
|
||||
pub struct RetainedBlocks(Vec<(*mut u8, usize)>);
|
||||
|
||||
impl Drop for RetainedBlocks {
|
||||
fn drop(&mut self) {
|
||||
for (ptr, len) in self.0.drain(..) {
|
||||
return_to_pool(ptr, len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start quarantining buffers freed on this thread (see `return_to_pool`).
|
||||
/// Must be paired with `end_retain` on the same thread; nesting unsupported.
|
||||
pub fn begin_retain() {
|
||||
RETAINED.with(|cell| {
|
||||
let mut r = cell.borrow_mut();
|
||||
assert!(r.is_none(), "begin_retain: retain window already active");
|
||||
*r = Some(Vec::new());
|
||||
});
|
||||
}
|
||||
|
||||
/// Stop quarantining and hand the quarantined blocks to the caller.
|
||||
pub fn end_retain() -> RetainedBlocks {
|
||||
RETAINED.with(|cell| {
|
||||
let list = cell.borrow_mut().take().expect("end_retain without begin_retain");
|
||||
RetainedBlocks(list)
|
||||
})
|
||||
}
|
||||
|
||||
/// Round up to next power-of-2, minimum 512 bytes.
|
||||
fn bucket_size(size: usize) -> usize {
|
||||
let min = 512;
|
||||
|
||||
@@ -15,6 +15,7 @@ pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
|
||||
|
||||
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
|
||||
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
|
||||
/// cudaStreamCaptureMode::cudaStreamCaptureModeThreadLocal
pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1;
|
||||
|
||||
unsafe extern "C" {
|
||||
// --- Device ---
|
||||
|
||||
@@ -50,10 +50,13 @@ impl CudaGraph {
|
||||
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
// If we have an old graph, destroy it first
|
||||
self.destroy_inner();
|
||||
// THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.)
|
||||
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
|
||||
// threads capturing concurrently would poison each other's captures.
|
||||
error::check(unsafe {
|
||||
ffi::cudaStreamBeginCapture(
|
||||
stream.as_raw(),
|
||||
ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL,
|
||||
ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
|
||||
pub use error::{CudaError, Result};
|
||||
pub use graph::CudaGraph;
|
||||
pub use memory::{GpuBuffer, PinnedBuffer};
|
||||
pub use stream::CudaStream;
|
||||
pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard};
|
||||
|
||||
@@ -31,3 +31,39 @@ impl Drop for CudaStream {
|
||||
|
||||
// Can move across threads, but not shared without synchronization
|
||||
unsafe impl Send for CudaStream {}
|
||||
|
||||
// --- Thread-local launch stream -------------------------------------------
|
||||
//
|
||||
// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`,
|
||||
// which defaults to the legacy null stream (the historical behavior). CUDA
|
||||
// graph capture requires work to be issued on an explicit stream, so capture
|
||||
// code installs its stream here for the duration of the captured region via
|
||||
// `push_stream` / `StreamGuard`.
|
||||
|
||||
use std::cell::Cell;
|
||||
|
||||
thread_local! {
|
||||
static CURRENT_STREAM: Cell<ffi::CudaStream> = const { Cell::new(std::ptr::null_mut()) };
|
||||
}
|
||||
|
||||
/// The stream kernel launches on this thread should use (null = legacy default).
|
||||
pub fn current_stream_raw() -> ffi::CudaStream {
|
||||
CURRENT_STREAM.with(|c| c.get())
|
||||
}
|
||||
|
||||
/// RAII guard that installs a launch stream for the current thread and
|
||||
/// restores the previous one on drop.
|
||||
pub struct StreamGuard {
|
||||
prev: ffi::CudaStream,
|
||||
}
|
||||
|
||||
pub fn push_stream(stream: &CudaStream) -> StreamGuard {
|
||||
let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw()));
|
||||
StreamGuard { prev }
|
||||
}
|
||||
|
||||
impl Drop for StreamGuard {
|
||||
fn drop(&mut self) {
|
||||
CURRENT_STREAM.with(|c| c.set(self.prev));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,10 +14,13 @@ use xserv_cuda::GpuBuffer;
|
||||
|
||||
pub use ffi::NcclUniqueId as UniqueId;
|
||||
|
||||
/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run
|
||||
/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered
|
||||
/// after the producing matmul and before the consuming kernel — no extra sync.
|
||||
const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
|
||||
/// NCCL is issued on the thread's current launch stream (legacy null stream
|
||||
/// by default, the capture stream during CUDA graph capture). The model's
|
||||
/// kernels run on the same stream, so AllReduce stays correctly ordered after
|
||||
/// the producing matmul and before the consuming kernel — no extra sync.
|
||||
fn launch_stream() -> xserv_cuda::ffi::CudaStream {
|
||||
xserv_cuda::stream::current_stream_raw()
|
||||
}
|
||||
|
||||
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
|
||||
/// to all ranks out-of-band (e.g. via a shared variable across threads).
|
||||
@@ -80,7 +83,7 @@ impl TpContext {
|
||||
ffi::NCCL_BF16,
|
||||
ffi::NCCL_SUM,
|
||||
self.comm,
|
||||
NULL_STREAM,
|
||||
launch_stream(),
|
||||
)
|
||||
},
|
||||
"ncclAllReduce",
|
||||
@@ -135,7 +138,7 @@ impl PpContext {
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
|
||||
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
|
||||
"ncclSend",
|
||||
);
|
||||
}
|
||||
@@ -146,7 +149,7 @@ impl PpContext {
|
||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
|
||||
ffi::check(
|
||||
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
|
||||
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
|
||||
"ncclRecv",
|
||||
);
|
||||
}
|
||||
|
||||
@@ -28,8 +28,8 @@ fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
@@ -49,8 +49,8 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||
_ => panic!("unsupported dtype"),
|
||||
}
|
||||
}
|
||||
@@ -68,8 +68,8 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
let n = n as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||||
_ => panic!("unsupported dtype for scale"),
|
||||
}
|
||||
}
|
||||
@@ -95,7 +95,7 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
|
||||
unsafe {
|
||||
launch_bias_add_2d_bf16(
|
||||
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -118,7 +118,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||||
up.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
n,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -146,7 +146,7 @@ pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
|
||||
n_elements,
|
||||
alpha,
|
||||
limit,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
|
||||
@@ -36,7 +36,7 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
||||
logits.data_ptr() as *const c_void,
|
||||
out.as_mut_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -144,12 +144,12 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
||||
DType::F32 => launch_causal_mask_f32(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_causal_mask_bf16(
|
||||
scores.data_ptr() as *mut c_void,
|
||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for causal mask"),
|
||||
}
|
||||
@@ -233,7 +233,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
||||
head_dim as i32,
|
||||
scale,
|
||||
1, // causal (always 1 for decode)
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -295,7 +295,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
||||
head_dim as i32,
|
||||
scale,
|
||||
if causal { 1 } else { 0 },
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -354,7 +354,7 @@ pub fn flash_attention_sinks(
|
||||
scale,
|
||||
1, // always causal
|
||||
window_size as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -409,7 +409,7 @@ pub fn paged_decode_attention(
|
||||
head_dim as i32,
|
||||
max_blocks_per_seq as i32,
|
||||
scale,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -464,7 +464,7 @@ pub fn paged_decode_attention_sinks(
|
||||
max_blocks_per_seq as i32,
|
||||
scale,
|
||||
window_size as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -35,19 +35,32 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
|
||||
}
|
||||
|
||||
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
|
||||
}
|
||||
|
||||
/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]).
|
||||
/// Used by the CUDA-graph decode path, where ids live in a persistent device
|
||||
/// buffer updated outside the captured region (no bounds check possible here).
|
||||
pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor {
|
||||
assert_eq!(table.ndim(), 2);
|
||||
assert!(table.is_contiguous());
|
||||
assert!(matches!(table.device(), Device::Cuda(_)));
|
||||
let hidden_size = table.shape()[1];
|
||||
let vocab_size = table.shape()[0];
|
||||
|
||||
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
|
||||
unsafe {
|
||||
match table.dtype() {
|
||||
DType::F32 => launch_embedding_f32(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
table.data_ptr() as _, ids_gpu,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_embedding_bf16(
|
||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
||||
table.data_ptr() as _, ids_gpu,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for embedding"),
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
let a_ptr = a.data_ptr() as *const c_void;
|
||||
let b_ptr = b.data_ptr() as *const c_void;
|
||||
let c_ptr = c.data_ptr() as *mut c_void;
|
||||
let null_stream = std::ptr::null_mut();
|
||||
let null_stream = xserv_cuda::current_stream_raw();
|
||||
|
||||
match backend {
|
||||
GemmBackend::Naive => {
|
||||
@@ -260,7 +260,7 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
let stride_c = (m * n) as i64;
|
||||
|
||||
with_cublas(|handle| unsafe {
|
||||
cublasSetStream_v2(handle, std::ptr::null_mut());
|
||||
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||
error::check(cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
|
||||
@@ -26,12 +26,12 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
||||
DType::F32 => launch_layernorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_layernorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||
out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for layernorm"),
|
||||
}
|
||||
|
||||
@@ -16,11 +16,11 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu
|
||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
||||
pub use embedding::embedding;
|
||||
pub use embedding::{embedding, embedding_device_ids};
|
||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||
pub use layernorm::layernorm;
|
||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||
pub use rope::{rope_inplace, RopeCache};
|
||||
pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache};
|
||||
pub use softmax::softmax;
|
||||
|
||||
/// Register GPU kernels with the tensor crate. Call once at startup.
|
||||
|
||||
@@ -100,7 +100,7 @@ pub fn moe_topk_softmax(
|
||||
topk_ids.data_ptr() as *mut c_void,
|
||||
topk_weights.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, num_experts as i32, top_k as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
|
||||
x.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, local_experts as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
|
||||
x.data_ptr() as *mut c_void,
|
||||
bias.data_ptr() as *const c_void,
|
||||
batch as i32, num_tokens as i32, dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -177,7 +177,7 @@ pub fn moe_weighted_sum(
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -224,7 +224,7 @@ pub fn moe_sparse_gemv_fp8(
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
@@ -256,7 +256,7 @@ pub fn moe_sparse_gemv_mxfp4(
|
||||
y.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
@@ -288,7 +288,7 @@ pub fn moe_weighted_sum_sparse(
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_tokens as i32, hidden as i32, top_k as i32,
|
||||
expert_start as i32, local_experts as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -338,7 +338,7 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
|
||||
let handle = cublas_handle();
|
||||
unsafe {
|
||||
cublasSetStream_v2(handle, std::ptr::null_mut());
|
||||
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||
let status = cublasGemmStridedBatchedEx(
|
||||
handle,
|
||||
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||
|
||||
@@ -300,7 +300,7 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
||||
scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
num_experts as i32, rows as i32, cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -330,7 +330,7 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
||||
fp8_out.data_ptr() as *mut c_void,
|
||||
scales.data_ptr() as *mut c_void,
|
||||
num_rows as i32, cols as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -406,7 +406,7 @@ pub fn batched_gemm_fp8(
|
||||
&plan.algo,
|
||||
ws_ptr,
|
||||
plan.workspace_size,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
|
||||
}
|
||||
@@ -424,7 +424,7 @@ pub fn batched_gemm_fp8(
|
||||
a_scales.data_ptr() as *const c_void,
|
||||
b_scales.data_ptr() as *const c_void,
|
||||
total_rows, n as i32, m as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -456,7 +456,7 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
y.data_ptr() as *mut c_void,
|
||||
e as i32, n as i32, k as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
y
|
||||
@@ -472,7 +472,7 @@ pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n
|
||||
w_scales.data_ptr() as *const c_void,
|
||||
out.data_ptr() as *mut c_void,
|
||||
e as i32, n as i32, k as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
|
||||
@@ -28,11 +28,11 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rmsnorm_f32(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rmsnorm_bf16(
|
||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
||||
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rmsnorm"),
|
||||
}
|
||||
@@ -71,7 +71,7 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
||||
rows as i32,
|
||||
hidden_size as i32,
|
||||
eps,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ impl RopeCache {
|
||||
unsafe {
|
||||
launch_compute_rope_cache(
|
||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
|
||||
max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -136,21 +136,36 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
||||
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||
|
||||
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
|
||||
}
|
||||
|
||||
/// RoPE in-place with positions already on the GPU (u32, [num_tokens]).
|
||||
/// Used by the CUDA-graph decode path, where the position lives in a
|
||||
/// persistent device buffer updated outside the captured region.
|
||||
pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) {
|
||||
assert_eq!(x.ndim(), 3);
|
||||
assert!(x.is_contiguous());
|
||||
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||
let num_tokens = x.shape()[0];
|
||||
let num_heads = x.shape()[1];
|
||||
let head_dim = x.shape()[2];
|
||||
assert_eq!(head_dim / 2, cache.half_dim);
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_rope_f32(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_rope_bf16(
|
||||
x.data_ptr() as *mut c_void,
|
||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||
pos_gpu.as_ptr() as _,
|
||||
pos_gpu,
|
||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for rope"),
|
||||
}
|
||||
|
||||
@@ -22,11 +22,11 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
||||
match x.dtype() {
|
||||
DType::F32 => launch_softmax_f32(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::BF16 => launch_softmax_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
||||
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("unsupported dtype for softmax"),
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
|
||||
unsafe {
|
||||
launch_reshape_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -40,7 +40,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
||||
unsafe {
|
||||
launch_merge_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -54,7 +54,7 @@ pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head
|
||||
unsafe {
|
||||
launch_transpose_hsd_to_shd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -68,7 +68,7 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
|
||||
unsafe {
|
||||
launch_transpose_shd_to_hsd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -87,7 +87,7 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
unsafe {
|
||||
launch_repeat_kv_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(),
|
||||
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
out
|
||||
@@ -126,14 +126,14 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
|
||||
numel as i32, ndim as i32,
|
||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||
in_offset, std::ptr::null_mut(),
|
||||
in_offset, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
DType::F32 => launch_strided_copy_f32(
|
||||
storage_ptr as _, out.data_ptr() as *mut c_void,
|
||||
numel as i32, ndim as i32,
|
||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||
in_offset, std::ptr::null_mut(),
|
||||
in_offset, xserv_cuda::current_stream_raw(),
|
||||
),
|
||||
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
|
||||
}
|
||||
|
||||
@@ -349,7 +349,7 @@ impl PagedKVCache {
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
block_ids_gpu.as_ptr() as *const i32,
|
||||
num_tokens, nkv, hd, start_pos, bs,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
// block_ids_gpu drops here; the launch on the current launch stream will have
|
||||
@@ -397,7 +397,7 @@ impl PagedKVCache {
|
||||
k_pool_ptr, v_pool_ptr,
|
||||
bt_ptr, cl_ptr,
|
||||
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
|
||||
std::ptr::null_mut(),
|
||||
xserv_cuda::current_stream_raw(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user