cuda: infrastructure for whole-step CUDA graph capture
- Thread-local launch stream (xserv_cuda::stream): every kernel wrapper, cublasSetStream, and NCCL collective now launches on current_stream_raw() — the legacy null stream by default (behavior unchanged), or the capture stream installed via push_stream during graph capture. Capture is impossible on the legacy stream.
- Allocator retain mode: blocks freed inside a retain window are quarantined (RetainedBlocks) instead of pooled, so an instantiated graph keeps exclusive ownership of every intermediate buffer it references across replays.
- Capture mode GLOBAL -> THREAD_LOCAL: concurrent TP rank threads must not poison each other's captures with their own cudaMallocs.
- embedding_device_ids / rope_inplace_device_pos: variants reading token ids / positions from persistent device buffers, replacing the per-call host upload that a captured region cannot contain.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -111,6 +111,22 @@ pub fn cached_trim() {
|
|||||||
/// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
|
/// Called from `GpuBuffer::Drop` for pooled buffers. Takes raw pointer
|
||||||
/// and size to avoid re-triggering Drop.
|
/// and size to avoid re-triggering Drop.
|
||||||
pub fn return_to_pool(ptr: *mut u8, len: usize) {
|
pub fn return_to_pool(ptr: *mut u8, len: usize) {
|
||||||
|
// During CUDA graph capture, buffers freed by the captured code are
|
||||||
|
// quarantined instead of pooled: the instantiated graph references their
|
||||||
|
// addresses on every replay, so they must never be handed to another
|
||||||
|
// consumer for as long as the graph lives.
|
||||||
|
let quarantined = RETAINED.with(|cell| {
|
||||||
|
let mut r = cell.borrow_mut();
|
||||||
|
if let Some(list) = r.as_mut() {
|
||||||
|
list.push((ptr, len));
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if quarantined {
|
||||||
|
return;
|
||||||
|
}
|
||||||
ALLOCATOR.with(|cell| {
|
ALLOCATOR.with(|cell| {
|
||||||
let mut alloc = cell.borrow_mut();
|
let mut alloc = cell.borrow_mut();
|
||||||
let bucket = bucket_size(len);
|
let bucket = bucket_size(len);
|
||||||
@@ -119,6 +135,41 @@ pub fn return_to_pool(ptr: *mut u8, len: usize) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Per-thread retain-window state, consulted by `return_to_pool` on every free.
// `None`: no retain window is active on this thread.
// `Some(list)`: a window is active; `list` holds the (pointer, length) of every
// block freed on this thread since `begin_retain`.
thread_local! {
|
||||||
|
static RETAINED: RefCell<Option<Vec<(*mut u8, usize)>>> = const { RefCell::new(None) };
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffers freed while a retain window was active. Holding this keeps their
|
||||||
|
/// memory out of the pool; dropping it returns the blocks (on the owning
|
||||||
|
/// thread) for reuse.
|
||||||
|
/// Each entry is the `(pointer, length)` pair that was passed to
|
||||||
|
/// `return_to_pool` while the window was open; values come from `end_retain`.
|
||||||
|
#[must_use = "dropping RetainedBlocks immediately releases the quarantined blocks"]
|
||||||
|
pub struct RetainedBlocks(Vec<(*mut u8, usize)>);
|
||||||
|
|
||||||
|
// Ends the quarantine: every retained block is handed back through
// `return_to_pool`, the same entry point `GpuBuffer::Drop` uses for pooled buffers.
// NOTE(review): `return_to_pool` checks this thread's retain window first, so if a
// `RetainedBlocks` is dropped while another retain window is active on the same
// thread, its blocks are moved into that window's list instead of reaching the
// pool — confirm callers never drop one between `begin_retain` and `end_retain`.
impl Drop for RetainedBlocks {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// `drain(..)` empties the list while yielding each (ptr, len) by value.
for (ptr, len) in self.0.drain(..) {
|
||||||
|
return_to_pool(ptr, len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start quarantining buffers freed on this thread (see `return_to_pool`).
|
||||||
|
/// Must be paired with `end_retain` on the same thread; nesting unsupported.
///
/// # Panics
///
/// Panics if a retain window is already active on this thread.
|
||||||
|
pub fn begin_retain() {
|
||||||
|
RETAINED.with(|cell| {
|
||||||
|
let mut r = cell.borrow_mut();
|
||||||
|
// Nesting is rejected outright rather than silently replacing the open window's list.
assert!(r.is_none(), "begin_retain: retain window already active");
|
||||||
|
// An empty list marks the window as open; `return_to_pool` appends to it.
*r = Some(Vec::new());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stop quarantining and hand the quarantined blocks to the caller.
///
/// After this call the thread has no active retain window, so buffers freed
/// later on this thread are no longer added to a retain list. The blocks collected
/// so far stay out of the pool for as long as the returned value is kept alive.
///
/// # Panics
///
/// Panics if no retain window is active on this thread.
|
||||||
|
pub fn end_retain() -> RetainedBlocks {
|
||||||
|
RETAINED.with(|cell| {
|
||||||
|
// `take()` leaves `None` behind, which is what closes the window.
let list = cell.borrow_mut().take().expect("end_retain without begin_retain");
|
||||||
|
RetainedBlocks(list)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Round up to next power-of-2, minimum 512 bytes.
|
/// Round up to next power-of-2, minimum 512 bytes.
|
||||||
fn bucket_size(size: usize) -> usize {
|
fn bucket_size(size: usize) -> usize {
|
||||||
let min = 512;
|
let min = 512;
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
|
|||||||
|
|
||||||
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
|
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
|
||||||
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
|
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
|
||||||
|
/// cudaStreamCaptureMode::cudaStreamCaptureModeThreadLocal
pub const CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL: i32 = 1;
|
||||||
|
|
||||||
unsafe extern "C" {
|
unsafe extern "C" {
|
||||||
// --- Device ---
|
// --- Device ---
|
||||||
|
|||||||
@@ -50,10 +50,13 @@ impl CudaGraph {
|
|||||||
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||||
// If we have an old graph, destroy it first
|
// If we have an old graph, destroy it first
|
||||||
self.destroy_inner();
|
self.destroy_inner();
|
||||||
|
// THREAD_LOCAL: only "potentially unsafe" CUDA calls (cudaMalloc etc.)
|
||||||
|
// made by THIS thread invalidate the capture. With GLOBAL mode, TP rank
|
||||||
|
// threads capturing concurrently would poison each other's captures.
|
||||||
error::check(unsafe {
|
error::check(unsafe {
|
||||||
ffi::cudaStreamBeginCapture(
|
ffi::cudaStreamBeginCapture(
|
||||||
stream.as_raw(),
|
stream.as_raw(),
|
||||||
ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL,
|
ffi::CUDA_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,4 +11,4 @@ pub use device::DeviceInfo;
|
|||||||
pub use error::{CudaError, Result};
|
pub use error::{CudaError, Result};
|
||||||
pub use graph::CudaGraph;
|
pub use graph::CudaGraph;
|
||||||
pub use memory::{GpuBuffer, PinnedBuffer};
|
pub use memory::{GpuBuffer, PinnedBuffer};
|
||||||
pub use stream::CudaStream;
|
pub use stream::{current_stream_raw, push_stream, CudaStream, StreamGuard};
|
||||||
|
|||||||
@@ -31,3 +31,39 @@ impl Drop for CudaStream {
|
|||||||
|
|
||||||
// Can move across threads, but not shared without synchronization
|
// Can move across threads, but not shared without synchronization
|
||||||
unsafe impl Send for CudaStream {}
|
unsafe impl Send for CudaStream {}
|
||||||
|
|
||||||
|
// --- Thread-local launch stream -------------------------------------------
|
||||||
|
//
|
||||||
|
// Every kernel wrapper in xserv-kernels launches on `current_stream_raw()`,
|
||||||
|
// which defaults to the legacy null stream (the historical behavior). CUDA
|
||||||
|
// graph capture requires work to be issued on an explicit stream, so capture
|
||||||
|
// code installs its stream here for the duration of the captured region via
|
||||||
|
// `push_stream` / `StreamGuard`.
|
||||||
|
|
||||||
|
use std::cell::Cell;
|
||||||
|
|
||||||
|
// Launch stream for the current thread. Starts as the null pointer (the legacy
// null stream, per the notes above). Within this module it is written by
// `push_stream` and by `StreamGuard`'s `Drop`, and read by `current_stream_raw`.
thread_local! {
|
||||||
|
static CURRENT_STREAM: Cell<ffi::CudaStream> = const { Cell::new(std::ptr::null_mut()) };
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The stream kernel launches on this thread should use (null = legacy default).
///
/// Reads the calling thread's `CURRENT_STREAM` cell: the value is whatever
/// `push_stream` last installed or a `StreamGuard` drop last restored, and null
/// if neither has run on this thread.
|
||||||
|
pub fn current_stream_raw() -> ffi::CudaStream {
|
||||||
|
CURRENT_STREAM.with(|c| c.get())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RAII guard that installs a launch stream for the current thread and
|
||||||
|
/// restores the previous one on drop.
|
||||||
|
/// Created by `push_stream`; keep it bound for as long as the installed stream
|
||||||
|
/// should remain the thread's launch stream.
|
||||||
|
#[must_use = "the previous stream is restored as soon as the guard is dropped"]
|
||||||
|
pub struct StreamGuard {
|
||||||
|
// Stream that was current on this thread before the guard was created.
|
||||||
|
prev: ffi::CudaStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Installs `stream` as the calling thread's launch stream and returns a guard
/// that remembers the previously installed stream; the guard's `Drop` puts that
/// previous stream back.
///
/// Bind the result (`let _guard = push_stream(&s);`): an unbound call drops the
/// guard at the end of the statement and restores the previous stream right away.
///
/// NOTE(review): only the raw handle from `stream.as_raw()` is stored and the
/// returned guard carries no borrow of `stream` — presumably callers keep
/// `stream` alive while the guard exists; confirm.
pub fn push_stream(stream: &CudaStream) -> StreamGuard {
|
||||||
|
// `replace` swaps in the new handle and yields the old one in a single step.
let prev = CURRENT_STREAM.with(|c| c.replace(stream.as_raw()));
|
||||||
|
StreamGuard { prev }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Puts back the stream that was current when this guard was created.
// NOTE(review): the restore is unconditional, so guards are presumably meant to
// be dropped in reverse creation order; an out-of-order drop would reinstall a
// stale `prev` and nothing here detects that.
impl Drop for StreamGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
CURRENT_STREAM.with(|c| c.set(self.prev));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,10 +14,13 @@ use xserv_cuda::GpuBuffer;
|
|||||||
|
|
||||||
pub use ffi::NcclUniqueId as UniqueId;
|
pub use ffi::NcclUniqueId as UniqueId;
|
||||||
|
|
||||||
/// The CUDA "null" (default) stream. The model's kernels and cuBLAS calls run
|
/// NCCL is issued on the thread's current launch stream (legacy null stream
|
||||||
/// on it, so issuing NCCL on the same stream keeps AllReduce correctly ordered
|
/// by default, the capture stream during CUDA graph capture). The model's
|
||||||
/// after the producing matmul and before the consuming kernel — no extra sync.
|
/// kernels run on the same stream, so AllReduce stays correctly ordered after
|
||||||
const NULL_STREAM: xserv_cuda::ffi::CudaStream = std::ptr::null_mut();
|
/// the producing matmul and before the consuming kernel — no extra sync.
|
||||||
|
fn launch_stream() -> xserv_cuda::ffi::CudaStream {
|
||||||
|
// Re-read on every call (nothing is cached here), so each NCCL call site that
// invokes this picks up whatever stream is current on its thread at that moment.
xserv_cuda::stream::current_stream_raw()
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
|
/// Generate a unique id on one rank (typically rank 0) and broadcast the bytes
|
||||||
/// to all ranks out-of-band (e.g. via a shared variable across threads).
|
/// to all ranks out-of-band (e.g. via a shared variable across threads).
|
||||||
@@ -80,7 +83,7 @@ impl TpContext {
|
|||||||
ffi::NCCL_BF16,
|
ffi::NCCL_BF16,
|
||||||
ffi::NCCL_SUM,
|
ffi::NCCL_SUM,
|
||||||
self.comm,
|
self.comm,
|
||||||
NULL_STREAM,
|
launch_stream(),
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
"ncclAllReduce",
|
"ncclAllReduce",
|
||||||
@@ -135,7 +138,7 @@ impl PpContext {
|
|||||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||||
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
|
pub fn send_bf16_ptr(&self, ptr: *const c_void, count: usize, peer: usize) {
|
||||||
ffi::check(
|
ffi::check(
|
||||||
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
|
unsafe { ffi::ncclSend(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
|
||||||
"ncclSend",
|
"ncclSend",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -146,7 +149,7 @@ impl PpContext {
|
|||||||
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
/// `ptr` must point to at least `count` BF16 elements of valid device memory.
|
||||||
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
|
pub fn recv_bf16_ptr(&self, ptr: *mut c_void, count: usize, peer: usize) {
|
||||||
ffi::check(
|
ffi::check(
|
||||||
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, NULL_STREAM) },
|
unsafe { ffi::ncclRecv(ptr, count, ffi::NCCL_BF16, peer as i32, self.comm, launch_stream()) },
|
||||||
"ncclRecv",
|
"ncclRecv",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,8 +28,8 @@ fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c
|
|||||||
let n = n as i32;
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
DType::F32 => f32_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||||
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
DType::BF16 => bf16_fn(x.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||||
_ => panic!("unsupported dtype"),
|
_ => panic!("unsupported dtype"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -49,8 +49,8 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
|
|||||||
let n = n as i32;
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
match a.dtype() {
|
match a.dtype() {
|
||||||
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
DType::F32 => f32_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||||
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, std::ptr::null_mut()),
|
DType::BF16 => bf16_fn(a.data_ptr() as _, b.data_ptr() as _, out.data_ptr() as *mut c_void, n, xserv_cuda::current_stream_raw()),
|
||||||
_ => panic!("unsupported dtype"),
|
_ => panic!("unsupported dtype"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -68,8 +68,8 @@ pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
|||||||
let n = n as i32;
|
let n = n as i32;
|
||||||
unsafe {
|
unsafe {
|
||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
DType::F32 => launch_scale_f32(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||||||
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, std::ptr::null_mut()),
|
DType::BF16 => launch_scale_bf16(x.data_ptr() as _, out.data_ptr() as *mut c_void, scale_val, n, xserv_cuda::current_stream_raw()),
|
||||||
_ => panic!("unsupported dtype for scale"),
|
_ => panic!("unsupported dtype for scale"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -95,7 +95,7 @@ pub fn bias_add_2d(x: &Tensor, bias: &Tensor) -> Tensor {
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_bias_add_2d_bf16(
|
launch_bias_add_2d_bf16(
|
||||||
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, bias.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -118,7 +118,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
|||||||
up.data_ptr() as *const c_void,
|
up.data_ptr() as *const c_void,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
n,
|
n,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -146,7 +146,7 @@ pub fn gpt_oss_glu(gate_up: &Tensor, alpha: f32, limit: f32) -> Tensor {
|
|||||||
n_elements,
|
n_elements,
|
||||||
alpha,
|
alpha,
|
||||||
limit,
|
limit,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ pub fn argmax_bf16_to_host(logits: &Tensor) -> Vec<u32> {
|
|||||||
logits.data_ptr() as *const c_void,
|
logits.data_ptr() as *const c_void,
|
||||||
out.as_mut_ptr() as *mut c_void,
|
out.as_mut_ptr() as *mut c_void,
|
||||||
rows as i32, cols as i32,
|
rows as i32, cols as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -144,12 +144,12 @@ fn apply_causal_mask(scores: &Tensor, offset: usize) {
|
|||||||
DType::F32 => launch_causal_mask_f32(
|
DType::F32 => launch_causal_mask_f32(
|
||||||
scores.data_ptr() as *mut c_void,
|
scores.data_ptr() as *mut c_void,
|
||||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_causal_mask_bf16(
|
DType::BF16 => launch_causal_mask_bf16(
|
||||||
scores.data_ptr() as *mut c_void,
|
scores.data_ptr() as *mut c_void,
|
||||||
batch as i32, rows as i32, cols as i32, offset as i32,
|
batch as i32, rows as i32, cols as i32, offset as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for causal mask"),
|
_ => panic!("unsupported dtype for causal mask"),
|
||||||
}
|
}
|
||||||
@@ -233,7 +233,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
|||||||
head_dim as i32,
|
head_dim as i32,
|
||||||
scale,
|
scale,
|
||||||
1, // causal (always 1 for decode)
|
1, // causal (always 1 for decode)
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,7 +295,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
|||||||
head_dim as i32,
|
head_dim as i32,
|
||||||
scale,
|
scale,
|
||||||
if causal { 1 } else { 0 },
|
if causal { 1 } else { 0 },
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,7 +354,7 @@ pub fn flash_attention_sinks(
|
|||||||
scale,
|
scale,
|
||||||
1, // always causal
|
1, // always causal
|
||||||
window_size as i32,
|
window_size as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,7 +409,7 @@ pub fn paged_decode_attention(
|
|||||||
head_dim as i32,
|
head_dim as i32,
|
||||||
max_blocks_per_seq as i32,
|
max_blocks_per_seq as i32,
|
||||||
scale,
|
scale,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -464,7 +464,7 @@ pub fn paged_decode_attention_sinks(
|
|||||||
max_blocks_per_seq as i32,
|
max_blocks_per_seq as i32,
|
||||||
scale,
|
scale,
|
||||||
window_size as i32,
|
window_size as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,19 +35,32 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
|||||||
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
|
assert!((tid as usize) < vocab_size, "token_id {tid} out of bounds (vocab_size={vocab_size})");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
embedding_device_ids(table, ids_gpu.as_ptr() as *const c_void, num_tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embedding lookup with token ids already on the GPU (u32, [num_tokens]).
|
||||||
|
/// Used by the CUDA-graph decode path, where ids live in a persistent device
|
||||||
|
/// buffer updated outside the captured region (no bounds check possible here).
|
||||||
|
pub fn embedding_device_ids(table: &Tensor, ids_gpu: *const c_void, num_tokens: usize) -> Tensor {
|
||||||
|
assert_eq!(table.ndim(), 2);
|
||||||
|
assert!(table.is_contiguous());
|
||||||
|
assert!(matches!(table.device(), Device::Cuda(_)));
|
||||||
|
let hidden_size = table.shape()[1];
|
||||||
|
let vocab_size = table.shape()[0];
|
||||||
|
|
||||||
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
match table.dtype() {
|
match table.dtype() {
|
||||||
DType::F32 => launch_embedding_f32(
|
DType::F32 => launch_embedding_f32(
|
||||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
table.data_ptr() as _, ids_gpu,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_embedding_bf16(
|
DType::BF16 => launch_embedding_bf16(
|
||||||
table.data_ptr() as _, ids_gpu.as_ptr() as _,
|
table.data_ptr() as _, ids_gpu,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden_size as i32, vocab_size as i32, std::ptr::null_mut(),
|
num_tokens as i32, hidden_size as i32, vocab_size as i32, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for embedding"),
|
_ => panic!("unsupported dtype for embedding"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
|||||||
let a_ptr = a.data_ptr() as *const c_void;
|
let a_ptr = a.data_ptr() as *const c_void;
|
||||||
let b_ptr = b.data_ptr() as *const c_void;
|
let b_ptr = b.data_ptr() as *const c_void;
|
||||||
let c_ptr = c.data_ptr() as *mut c_void;
|
let c_ptr = c.data_ptr() as *mut c_void;
|
||||||
let null_stream = std::ptr::null_mut();
|
let null_stream = xserv_cuda::current_stream_raw();
|
||||||
|
|
||||||
match backend {
|
match backend {
|
||||||
GemmBackend::Naive => {
|
GemmBackend::Naive => {
|
||||||
@@ -260,7 +260,7 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
|||||||
let stride_c = (m * n) as i64;
|
let stride_c = (m * n) as i64;
|
||||||
|
|
||||||
with_cublas(|handle| unsafe {
|
with_cublas(|handle| unsafe {
|
||||||
cublasSetStream_v2(handle, std::ptr::null_mut());
|
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||||
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
// Row-major trick: C = A @ B ⟺ C^T = B^T @ A^T (col-major)
|
||||||
error::check(cublasGemmStridedBatchedEx(
|
error::check(cublasGemmStridedBatchedEx(
|
||||||
handle,
|
handle,
|
||||||
|
|||||||
@@ -26,12 +26,12 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
|||||||
DType::F32 => launch_layernorm_f32(
|
DType::F32 => launch_layernorm_f32(
|
||||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_layernorm_bf16(
|
DType::BF16 => launch_layernorm_bf16(
|
||||||
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
x.data_ptr() as _, gamma.data_ptr() as _, beta.data_ptr() as _,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for layernorm"),
|
_ => panic!("unsupported dtype for layernorm"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,11 +16,11 @@ pub use activation::{add, bias_add_2d, gelu, gpt_oss_glu, mul, scale, silu, silu
|
|||||||
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
pub use argmax::{argmax_bf16_single, argmax_bf16_to_host};
|
||||||
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
pub use transpose::{merge_heads_gpu, repeat_kv_gpu, reshape_heads_gpu, strided_to_contiguous_gpu, transpose_for_rope_gpu, transpose_from_rope_gpu};
|
||||||
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
pub use attention::{attention, decode_attention, flash_attention, flash_attention_sinks, paged_decode_attention, paged_decode_attention_sinks, reshape_and_cache_bf16, reshape_and_cache_batched_bf16};
|
||||||
pub use embedding::embedding;
|
pub use embedding::{embedding, embedding_device_ids};
|
||||||
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
pub use gemm::{batched_matmul, matmul, GemmBackend};
|
||||||
pub use layernorm::layernorm;
|
pub use layernorm::layernorm;
|
||||||
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
pub use rmsnorm::{add_rmsnorm, rmsnorm};
|
||||||
pub use rope::{rope_inplace, RopeCache};
|
pub use rope::{rope_inplace, rope_inplace_device_pos, RopeCache};
|
||||||
pub use softmax::softmax;
|
pub use softmax::softmax;
|
||||||
|
|
||||||
/// Register GPU kernels with the tensor crate. Call once at startup.
|
/// Register GPU kernels with the tensor crate. Call once at startup.
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ pub fn moe_topk_softmax(
|
|||||||
topk_ids.data_ptr() as *mut c_void,
|
topk_ids.data_ptr() as *mut c_void,
|
||||||
topk_weights.data_ptr() as *mut c_void,
|
topk_weights.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, num_experts as i32, top_k as i32,
|
num_tokens as i32, num_experts as i32, top_k as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +121,7 @@ pub fn moe_replicate(x: &Tensor, local_experts: usize) -> Tensor {
|
|||||||
x.data_ptr() as *const c_void,
|
x.data_ptr() as *const c_void,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden as i32, local_experts as i32,
|
num_tokens as i32, hidden as i32, local_experts as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ pub fn moe_bias_add_3d(x: &Tensor, bias: &Tensor) {
|
|||||||
x.data_ptr() as *mut c_void,
|
x.data_ptr() as *mut c_void,
|
||||||
bias.data_ptr() as *const c_void,
|
bias.data_ptr() as *const c_void,
|
||||||
batch as i32, num_tokens as i32, dim as i32,
|
batch as i32, num_tokens as i32, dim as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -177,7 +177,7 @@ pub fn moe_weighted_sum(
|
|||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden as i32, top_k as i32,
|
num_tokens as i32, hidden as i32, top_k as i32,
|
||||||
expert_start as i32, local_experts as i32,
|
expert_start as i32, local_experts as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,7 +224,7 @@ pub fn moe_sparse_gemv_fp8(
|
|||||||
y.data_ptr() as *mut c_void,
|
y.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
y
|
y
|
||||||
@@ -256,7 +256,7 @@ pub fn moe_sparse_gemv_mxfp4(
|
|||||||
y.data_ptr() as *mut c_void,
|
y.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
num_tokens as i32, n as i32, k as i32, top_k as i32,
|
||||||
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
expert_start as i32, local_experts as i32, x_per_slot as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
y
|
y
|
||||||
@@ -288,7 +288,7 @@ pub fn moe_weighted_sum_sparse(
|
|||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_tokens as i32, hidden as i32, top_k as i32,
|
num_tokens as i32, hidden as i32, top_k as i32,
|
||||||
expert_start as i32, local_experts as i32,
|
expert_start as i32, local_experts as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -338,7 +338,7 @@ pub fn batched_gemm_strided(a: &Tensor, b: &Tensor) -> Tensor {
|
|||||||
|
|
||||||
let handle = cublas_handle();
|
let handle = cublas_handle();
|
||||||
unsafe {
|
unsafe {
|
||||||
cublasSetStream_v2(handle, std::ptr::null_mut());
|
cublasSetStream_v2(handle, xserv_cuda::current_stream_raw());
|
||||||
let status = cublasGemmStridedBatchedEx(
|
let status = cublasGemmStridedBatchedEx(
|
||||||
handle,
|
handle,
|
||||||
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
|
0, 0, // CUBLAS_OP_N, CUBLAS_OP_N
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ pub fn dequant_fp8_to_bf16(src: &Tensor, scales: &Tensor) -> Tensor {
|
|||||||
scales.data_ptr() as *const c_void,
|
scales.data_ptr() as *const c_void,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
num_experts as i32, rows as i32, cols as i32,
|
num_experts as i32, rows as i32, cols as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +330,7 @@ pub fn quantize_bf16_to_fp8_rowwise(src: &Tensor) -> (Tensor, Tensor) {
|
|||||||
fp8_out.data_ptr() as *mut c_void,
|
fp8_out.data_ptr() as *mut c_void,
|
||||||
scales.data_ptr() as *mut c_void,
|
scales.data_ptr() as *mut c_void,
|
||||||
num_rows as i32, cols as i32,
|
num_rows as i32, cols as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,7 +406,7 @@ pub fn batched_gemm_fp8(
|
|||||||
&plan.algo,
|
&plan.algo,
|
||||||
ws_ptr,
|
ws_ptr,
|
||||||
plan.workspace_size,
|
plan.workspace_size,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
|
assert_eq!(status, 0, "batched cublasLtMatmul FP8 failed: status={status}");
|
||||||
}
|
}
|
||||||
@@ -424,7 +424,7 @@ pub fn batched_gemm_fp8(
|
|||||||
a_scales.data_ptr() as *const c_void,
|
a_scales.data_ptr() as *const c_void,
|
||||||
b_scales.data_ptr() as *const c_void,
|
b_scales.data_ptr() as *const c_void,
|
||||||
total_rows, n as i32, m as i32,
|
total_rows, n as i32, m as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -456,7 +456,7 @@ pub fn batched_gemv_mxfp4(x: &Tensor, w_packed: &Tensor, w_scales: &Tensor, n: u
|
|||||||
w_scales.data_ptr() as *const c_void,
|
w_scales.data_ptr() as *const c_void,
|
||||||
y.data_ptr() as *mut c_void,
|
y.data_ptr() as *mut c_void,
|
||||||
e as i32, n as i32, k as i32,
|
e as i32, n as i32, k as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
y
|
y
|
||||||
@@ -472,7 +472,7 @@ pub fn dequant_mxfp4_to_bf16_t(w_packed: &Tensor, w_scales: &Tensor, e: usize, n
|
|||||||
w_scales.data_ptr() as *const c_void,
|
w_scales.data_ptr() as *const c_void,
|
||||||
out.data_ptr() as *mut c_void,
|
out.data_ptr() as *mut c_void,
|
||||||
e as i32, n as i32, k as i32,
|
e as i32, n as i32, k as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
|||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => launch_rmsnorm_f32(
|
DType::F32 => launch_rmsnorm_f32(
|
||||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_rmsnorm_bf16(
|
DType::BF16 => launch_rmsnorm_bf16(
|
||||||
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, gamma.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
rows as i32, hidden_size as i32, eps, std::ptr::null_mut(),
|
rows as i32, hidden_size as i32, eps, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for rmsnorm"),
|
_ => panic!("unsupported dtype for rmsnorm"),
|
||||||
}
|
}
|
||||||
@@ -71,7 +71,7 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
|||||||
rows as i32,
|
rows as i32,
|
||||||
hidden_size as i32,
|
hidden_size as i32,
|
||||||
eps,
|
eps,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ impl RopeCache {
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_compute_rope_cache(
|
launch_compute_rope_cache(
|
||||||
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
cos.as_mut_ptr() as _, sin.as_mut_ptr() as _,
|
||||||
max_seq_len as i32, half_dim as i32, theta, std::ptr::null_mut(),
|
max_seq_len as i32, half_dim as i32, theta, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,21 +136,36 @@ pub fn rope_inplace(x: &Tensor, cache: &RopeCache, positions: &[u32]) {
|
|||||||
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
let mut pos_gpu = xserv_cuda::allocator::cached_alloc(pos_bytes.len()).expect("alloc positions");
|
||||||
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
pos_gpu.copy_from_host(pos_bytes).unwrap();
|
||||||
|
|
||||||
|
rope_inplace_device_pos(x, cache, pos_gpu.as_ptr() as *const c_void);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RoPE in-place with positions already on the GPU (u32, [num_tokens]).
|
||||||
|
/// Used by the CUDA-graph decode path, where the position lives in a
|
||||||
|
/// persistent device buffer updated outside the captured region.
|
||||||
|
pub fn rope_inplace_device_pos(x: &Tensor, cache: &RopeCache, pos_gpu: *const c_void) {
|
||||||
|
assert_eq!(x.ndim(), 3);
|
||||||
|
assert!(x.is_contiguous());
|
||||||
|
assert!(matches!(x.device(), Device::Cuda(_)));
|
||||||
|
let num_tokens = x.shape()[0];
|
||||||
|
let num_heads = x.shape()[1];
|
||||||
|
let head_dim = x.shape()[2];
|
||||||
|
assert_eq!(head_dim / 2, cache.half_dim);
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => launch_rope_f32(
|
DType::F32 => launch_rope_f32(
|
||||||
x.data_ptr() as *mut c_void,
|
x.data_ptr() as *mut c_void,
|
||||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||||
pos_gpu.as_ptr() as _,
|
pos_gpu,
|
||||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_rope_bf16(
|
DType::BF16 => launch_rope_bf16(
|
||||||
x.data_ptr() as *mut c_void,
|
x.data_ptr() as *mut c_void,
|
||||||
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
cache.cos.as_ptr() as _, cache.sin.as_ptr() as _,
|
||||||
pos_gpu.as_ptr() as _,
|
pos_gpu,
|
||||||
num_tokens as i32, num_heads as i32, head_dim as i32,
|
num_tokens as i32, num_heads as i32, head_dim as i32,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for rope"),
|
_ => panic!("unsupported dtype for rope"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,11 +22,11 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
|||||||
match x.dtype() {
|
match x.dtype() {
|
||||||
DType::F32 => launch_softmax_f32(
|
DType::F32 => launch_softmax_f32(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::BF16 => launch_softmax_bf16(
|
DType::BF16 => launch_softmax_bf16(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
rows as i32, cols as i32, std::ptr::null_mut(),
|
rows as i32, cols as i32, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("unsupported dtype for softmax"),
|
_ => panic!("unsupported dtype for softmax"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim:
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_reshape_heads_bf16(
|
launch_reshape_heads_bf16(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -40,7 +40,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_merge_heads_bf16(
|
launch_merge_heads_bf16(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -54,7 +54,7 @@ pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_transpose_hsd_to_shd_bf16(
|
launch_transpose_hsd_to_shd_bf16(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -68,7 +68,7 @@ pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, hea
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_transpose_shd_to_hsd_bf16(
|
launch_transpose_shd_to_hsd_bf16(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
seq_len as i32, num_heads as i32, head_dim as i32, std::ptr::null_mut(),
|
seq_len as i32, num_heads as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -87,7 +87,7 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
|||||||
unsafe {
|
unsafe {
|
||||||
launch_repeat_kv_bf16(
|
launch_repeat_kv_bf16(
|
||||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||||
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, std::ptr::null_mut(),
|
kv_heads as i32, n_rep as i32, seq_len as i32, head_dim as i32, xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
@@ -126,14 +126,14 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
|
|||||||
numel as i32, ndim as i32,
|
numel as i32, ndim as i32,
|
||||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||||
in_offset, std::ptr::null_mut(),
|
in_offset, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
DType::F32 => launch_strided_copy_f32(
|
DType::F32 => launch_strided_copy_f32(
|
||||||
storage_ptr as _, out.data_ptr() as *mut c_void,
|
storage_ptr as _, out.data_ptr() as *mut c_void,
|
||||||
numel as i32, ndim as i32,
|
numel as i32, ndim as i32,
|
||||||
shape4[0], shape4[1], shape4[2], shape4[3],
|
shape4[0], shape4[1], shape4[2], shape4[3],
|
||||||
strides4[0], strides4[1], strides4[2], strides4[3],
|
strides4[0], strides4[1], strides4[2], strides4[3],
|
||||||
in_offset, std::ptr::null_mut(),
|
in_offset, xserv_cuda::current_stream_raw(),
|
||||||
),
|
),
|
||||||
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
|
_ => panic!("strided_to_contiguous_gpu: unsupported dtype {:?}", x.dtype()),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -349,7 +349,7 @@ impl PagedKVCache {
|
|||||||
k_pool_ptr, v_pool_ptr,
|
k_pool_ptr, v_pool_ptr,
|
||||||
block_ids_gpu.as_ptr() as *const i32,
|
block_ids_gpu.as_ptr() as *const i32,
|
||||||
num_tokens, nkv, hd, start_pos, bs,
|
num_tokens, nkv, hd, start_pos, bs,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
// block_ids_gpu drops here; the launch on the null stream will have
|
// block_ids_gpu drops here; the launch on the current stream (the null stream by default) will have
|
||||||
@@ -397,7 +397,7 @@ impl PagedKVCache {
|
|||||||
k_pool_ptr, v_pool_ptr,
|
k_pool_ptr, v_pool_ptr,
|
||||||
bt_ptr, cl_ptr,
|
bt_ptr, cl_ptr,
|
||||||
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
|
batch, nkv, hd, BLOCK_SIZE, self.max_blocks_per_seq,
|
||||||
std::ptr::null_mut(),
|
xserv_cuda::current_stream_raw(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user