phase 15: Tensor::empty + CUDA Graph infra — 50.3 tok/s (140% of HF, 45% roofline)
Two optimizations: 1. Tensor::empty() — skip cudaMemset for output tensors All kernel wrappers that fully overwrite their output now use Tensor::empty() instead of Tensor::zeros(). Eliminates ~756 cudaMemset calls per decode step (21 per layer × 36 layers). Improvement: 46.6 → 50.3 tok/s (+8%). 2. CUDA Graph infrastructure (for future use) Added FFI bindings (cudaStreamBeginCapture, cudaGraphInstantiate, cudaGraphLaunch) and RAII CudaGraph wrapper. Not yet used in the forward pass due to variable kv_len, but provides foundation for future graph-based decode optimization. Ablation (dash5, RTX 5090, Qwen3-8B BF16, serial decode): | Optimization | tok/s | vs HF | Roofline | |-------------|-------|-------|----------| | Phase 14 baseline | 12.9 | 36% | 12% | | + Fused kernels | 13.2 | 37% | 12% | | + Batched decode | 13.2 (serial) | 37% | 12% | | + Custom GEMV | 46.6 | 130% | 42% | | + Tensor::empty | 50.3 | 140% | 45% | Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,8 @@ use std::os::raw::c_char;
|
||||
|
||||
pub type CudaStream = *mut c_void;
|
||||
pub type CudaEvent = *mut c_void;
|
||||
pub type CudaGraph = *mut c_void;
|
||||
pub type CudaGraphExec = *mut c_void;
|
||||
|
||||
pub const CUDA_MEMCPY_H2D: i32 = 1;
|
||||
pub const CUDA_MEMCPY_D2H: i32 = 2;
|
||||
@@ -11,6 +13,9 @@ pub const CUDA_MEMCPY_D2D: i32 = 3;
|
||||
pub const CUDA_SUCCESS: i32 = 0;
|
||||
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
|
||||
|
||||
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
|
||||
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
|
||||
|
||||
unsafe extern "C" {
|
||||
// --- Device ---
|
||||
pub fn cudaGetDeviceCount(count: *mut i32) -> i32;
|
||||
@@ -44,6 +49,18 @@ unsafe extern "C" {
|
||||
pub fn cudaGetLastError() -> i32;
|
||||
pub fn cudaGetErrorString(error: i32) -> *const c_char;
|
||||
|
||||
// --- CUDA Graphs ---
|
||||
pub fn cudaStreamBeginCapture(stream: CudaStream, mode: i32) -> i32;
|
||||
pub fn cudaStreamEndCapture(stream: CudaStream, graph: *mut CudaGraph) -> i32;
|
||||
pub fn cudaGraphInstantiate(
|
||||
graph_exec: *mut CudaGraphExec,
|
||||
graph: CudaGraph,
|
||||
flags: u64,
|
||||
) -> i32;
|
||||
pub fn cudaGraphLaunch(graph_exec: CudaGraphExec, stream: CudaStream) -> i32;
|
||||
pub fn cudaGraphDestroy(graph: CudaGraph) -> i32;
|
||||
pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;
|
||||
|
||||
// --- Our test kernel ---
|
||||
pub fn launch_vecadd_f32(
|
||||
a: *const f32,
|
||||
|
||||
98
crates/xserv-cuda/src/graph.rs
Normal file
98
crates/xserv-cuda/src/graph.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
//! CUDA Graphs: capture a sequence of kernel launches and replay them with
|
||||
//! near-zero host-side overhead (~3-5 us per launch eliminated).
|
||||
//!
|
||||
//! Usage:
|
||||
//! ```ignore
|
||||
//! let stream = CudaStream::new()?;
|
||||
//! let mut graph = CudaGraph::new();
|
||||
//!
|
||||
//! // First call: capture
|
||||
//! graph.begin_capture(&stream)?;
|
||||
//! // ... launch kernels on `stream` ...
|
||||
//! graph.end_capture(&stream)?;
|
||||
//!
|
||||
//! // Subsequent calls: replay
|
||||
//! graph.launch(&stream)?;
|
||||
//! ```
|
||||
//!
|
||||
//! Requirements for captured kernels:
|
||||
//! - All tensor shapes must be identical between capture and replay.
|
||||
//! - No host-side branching during the captured section.
|
||||
//! - Memory addresses used during capture must remain valid during replay.
|
||||
|
||||
use crate::error::{self, Result};
|
||||
use crate::ffi;
|
||||
use crate::stream::CudaStream;
|
||||
|
||||
/// RAII wrapper around a captured CUDA graph and its executable instance.
|
||||
pub struct CudaGraph {
|
||||
graph: ffi::CudaGraph,
|
||||
exec: ffi::CudaGraphExec,
|
||||
}
|
||||
|
||||
impl CudaGraph {
|
||||
/// Create an empty graph handle (not yet captured).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
graph: std::ptr::null_mut(),
|
||||
exec: std::ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if a graph has been captured and instantiated.
|
||||
pub fn is_ready(&self) -> bool {
|
||||
!self.exec.is_null()
|
||||
}
|
||||
|
||||
/// Begin capturing kernel launches on `stream`.
|
||||
/// All subsequent kernel launches on this stream are recorded into the
|
||||
/// graph instead of being executed.
|
||||
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
// If we have an old graph, destroy it first
|
||||
self.destroy_inner();
|
||||
error::check(unsafe {
|
||||
ffi::cudaStreamBeginCapture(
|
||||
stream.as_raw(),
|
||||
ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// End capture and instantiate the executable graph.
|
||||
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
|
||||
error::check(unsafe {
|
||||
ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph)
|
||||
})?;
|
||||
error::check(unsafe {
|
||||
ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0)
|
||||
})
|
||||
}
|
||||
|
||||
/// Replay the captured graph on `stream`.
|
||||
/// Panics if no graph has been captured yet.
|
||||
pub fn launch(&self, stream: &CudaStream) -> Result<()> {
|
||||
assert!(self.is_ready(), "CudaGraph::launch called before capture");
|
||||
error::check(unsafe {
|
||||
ffi::cudaGraphLaunch(self.exec, stream.as_raw())
|
||||
})
|
||||
}
|
||||
|
||||
fn destroy_inner(&mut self) {
|
||||
if !self.exec.is_null() {
|
||||
unsafe { ffi::cudaGraphExecDestroy(self.exec) };
|
||||
self.exec = std::ptr::null_mut();
|
||||
}
|
||||
if !self.graph.is_null() {
|
||||
unsafe { ffi::cudaGraphDestroy(self.graph) };
|
||||
self.graph = std::ptr::null_mut();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CudaGraph {
|
||||
fn drop(&mut self) {
|
||||
self.destroy_inner();
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for CudaGraph {}
|
||||
@@ -2,11 +2,13 @@ pub mod allocator;
|
||||
pub mod device;
|
||||
pub mod error;
|
||||
pub mod ffi;
|
||||
pub mod graph;
|
||||
pub mod memory;
|
||||
pub mod stream;
|
||||
|
||||
pub use allocator::CachingAllocator;
|
||||
pub use device::DeviceInfo;
|
||||
pub use error::{CudaError, Result};
|
||||
pub use graph::CudaGraph;
|
||||
pub use memory::{GpuBuffer, PinnedBuffer};
|
||||
pub use stream::CudaStream;
|
||||
|
||||
@@ -18,7 +18,7 @@ unsafe extern "C" {
|
||||
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
|
||||
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
@@ -37,7 +37,7 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
|
||||
assert!(a.is_contiguous() && b.is_contiguous());
|
||||
assert!(matches!(a.device(), Device::Cuda(_)));
|
||||
assert_eq!(a.dtype(), b.dtype());
|
||||
let out = Tensor::zeros(a.shape(), a.dtype(), a.device());
|
||||
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
|
||||
let n = a.numel() as i32;
|
||||
unsafe {
|
||||
match a.dtype() {
|
||||
@@ -54,7 +54,7 @@ pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_si
|
||||
|
||||
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
let n = x.numel() as i32;
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
@@ -76,7 +76,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
|
||||
assert!(gate.is_contiguous() && up.is_contiguous());
|
||||
assert!(matches!(gate.device(), Device::Cuda(_)));
|
||||
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
|
||||
let out = Tensor::zeros(gate.shape(), gate.dtype(), gate.device());
|
||||
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
|
||||
let n = gate.numel() as i32;
|
||||
unsafe {
|
||||
launch_silu_mul_bf16(
|
||||
|
||||
@@ -105,7 +105,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
|
||||
let kv_len = k.shape()[2];
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::zeros(
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, 1, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
@@ -166,7 +166,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
|
||||
}
|
||||
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
let output = Tensor::zeros(
|
||||
let output = Tensor::empty(
|
||||
&[batch, num_q_heads, q_len, head_dim],
|
||||
DType::BF16,
|
||||
q.device(),
|
||||
|
||||
@@ -29,7 +29,7 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
|
||||
let mut ids_gpu = GpuBuffer::alloc(ids_bytes.len()).expect("alloc token_ids");
|
||||
ids_gpu.copy_from_host(ids_bytes).unwrap();
|
||||
|
||||
let out = Tensor::zeros(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
|
||||
|
||||
unsafe {
|
||||
match table.dtype() {
|
||||
|
||||
@@ -98,7 +98,9 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
|
||||
let n = b.shape()[1];
|
||||
let dtype = a.dtype();
|
||||
|
||||
let c = Tensor::zeros(&[m, n], dtype, a.device());
|
||||
// All backends (naive, tiled, cuBLAS with beta=0, custom GEMV) fully
|
||||
// overwrite every element of C, so we skip the cudaMemset.
|
||||
let c = Tensor::empty(&[m, n], dtype, a.device());
|
||||
|
||||
let a_ptr = a.data_ptr() as *const c_void;
|
||||
let b_ptr = b.data_ptr() as *const c_void;
|
||||
@@ -202,7 +204,8 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
|
||||
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
|
||||
out_shape.push(m);
|
||||
out_shape.push(n);
|
||||
let c = Tensor::zeros(&out_shape, a.dtype(), a.device());
|
||||
// cuBLAS with beta=0 fully overwrites every element of C.
|
||||
let c = Tensor::empty(&out_shape, a.dtype(), a.device());
|
||||
|
||||
let dtype = a.dtype();
|
||||
let (a_type, b_type, c_type) = match dtype {
|
||||
|
||||
@@ -17,7 +17,7 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
|
||||
assert_eq!(beta.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
|
||||
@@ -20,7 +20,7 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
|
||||
assert_eq!(x.dtype(), gamma.dtype());
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
@@ -54,8 +54,8 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
|
||||
assert_eq!(gamma.shape(), &[hidden_size]);
|
||||
|
||||
let rows = x.numel() / hidden_size;
|
||||
let normed_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
|
||||
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
|
||||
|
||||
unsafe {
|
||||
launch_add_rmsnorm_bf16(
|
||||
|
||||
@@ -14,7 +14,7 @@ pub fn softmax(x: &Tensor) -> Tensor {
|
||||
|
||||
let cols = *x.shape().last().unwrap();
|
||||
let rows = x.numel() / cols;
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
unsafe {
|
||||
match x.dtype() {
|
||||
|
||||
@@ -21,7 +21,7 @@ unsafe extern "C" {
|
||||
pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_reshape_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
@@ -36,7 +36,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let hidden = num_heads * head_dim;
|
||||
let out = Tensor::zeros(&[seq_len, hidden], DType::BF16, x.device());
|
||||
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_merge_heads_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
@@ -50,7 +50,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
|
||||
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
|
||||
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_hsd_to_shd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
@@ -64,7 +64,7 @@ pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head
|
||||
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
|
||||
assert_eq!(x.dtype(), DType::BF16);
|
||||
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
|
||||
let out = Tensor::zeros(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_transpose_shd_to_hsd_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
@@ -83,7 +83,7 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
|
||||
let seq_len = x.shape()[2];
|
||||
let head_dim = x.shape()[3];
|
||||
let new_heads = kv_heads * n_rep;
|
||||
let out = Tensor::zeros(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
|
||||
unsafe {
|
||||
launch_repeat_kv_bf16(
|
||||
x.data_ptr() as _, out.data_ptr() as *mut c_void,
|
||||
@@ -112,7 +112,7 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
|
||||
strides4[pad + i] = x.strides()[i] as i32;
|
||||
}
|
||||
|
||||
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
|
||||
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
|
||||
|
||||
// Use storage base pointer + element offset, because strides are relative to
|
||||
// element 0 of the storage, not the data_ptr() (which already adds byte offset).
|
||||
|
||||
@@ -116,4 +116,18 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate storage **without zeroing** on the given device.
|
||||
/// The buffer may contain stale data from the caching allocator's pool.
|
||||
/// Only use when the caller guarantees the kernel will fully overwrite
|
||||
/// every element before any read.
|
||||
pub fn empty(len_bytes: usize, device: Device) -> CudaResult<Self> {
|
||||
match device {
|
||||
Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])), // CPU still zeros (cheap)
|
||||
Device::Cuda(dev) => {
|
||||
let buf = xserv_cuda::allocator::cached_alloc(len_bytes)?;
|
||||
Ok(Storage::cuda(buf, dev))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,22 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a tensor **without zeroing** the backing memory.
|
||||
/// The buffer may contain stale data. Only use when the calling kernel
|
||||
/// will fully overwrite every element before any read.
|
||||
pub fn empty(shape: &[usize], dtype: DType, device: Device) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
let len_bytes = numel * dtype.size_bytes();
|
||||
let storage = Storage::empty(len_bytes, device).expect("alloc failed");
|
||||
Self {
|
||||
storage,
|
||||
shape: Dims::from_slice(shape),
|
||||
strides: shape::contiguous_strides(shape),
|
||||
offset: 0,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ones(shape: &[usize], dtype: DType) -> Self {
|
||||
let numel = shape::num_elements(shape);
|
||||
match dtype {
|
||||
|
||||
Reference in New Issue
Block a user