phase 15: Tensor::empty + CUDA Graph infra — 50.3 tok/s (140% of HF, 45% roofline)

Two optimizations:

1. Tensor::empty() — skip cudaMemset for output tensors
   All kernel wrappers that fully overwrite their output now use
   Tensor::empty() instead of Tensor::zeros(). Eliminates ~756
   cudaMemset calls per decode step (21 per layer × 36 layers).
   Improvement: 46.6 → 50.3 tok/s (+8%).

2. CUDA Graph infrastructure (for future use)
   Added FFI bindings (cudaStreamBeginCapture, cudaGraphInstantiate,
   cudaGraphLaunch) and RAII CudaGraph wrapper. Not yet used in the
   forward pass due to variable kv_len, but provides foundation for
   future graph-based decode optimization.

Ablation (dash5, RTX 5090, Qwen3-8B BF16, serial decode):

| Optimization | tok/s | vs HF | Roofline |
|-------------|-------|-------|----------|
| Phase 14 baseline | 12.9 | 36% | 12% |
| + Fused kernels | 13.2 | 37% | 12% |
| + Batched decode | 13.2 (serial) | 37% | 12% |
| + Custom GEMV | 46.6 | 130% | 42% |
| + Tensor::empty | 50.3 | 140% | 45% |

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-22 23:57:34 +08:00
parent e207523e21
commit d5532ef209
13 changed files with 170 additions and 20 deletions

View File

@@ -3,6 +3,8 @@ use std::os::raw::c_char;
/// Opaque CUDA stream handle, carried across the FFI boundary as an untyped raw pointer.
pub type CudaStream = *mut c_void;
/// Opaque CUDA event handle, carried across the FFI boundary as an untyped raw pointer.
pub type CudaEvent = *mut c_void;
/// Opaque handle to a captured graph; filled in by `cudaStreamEndCapture`.
pub type CudaGraph = *mut c_void;
/// Opaque handle to an instantiated graph; filled in by `cudaGraphInstantiate`
/// and consumed by `cudaGraphLaunch`.
pub type CudaGraphExec = *mut c_void;
/// Copy-direction code for host-to-device transfers
/// (presumably `cudaMemcpyHostToDevice` — verify against the CUDA headers).
pub const CUDA_MEMCPY_H2D: i32 = 1;
/// Copy-direction code for device-to-host transfers
/// (presumably `cudaMemcpyDeviceToHost` — verify against the CUDA headers).
pub const CUDA_MEMCPY_D2H: i32 = 2;
@@ -11,6 +13,9 @@ pub const CUDA_MEMCPY_D2D: i32 = 3;
/// Status code for a successful call (presumably `cudaSuccess` — verify against the CUDA headers).
pub const CUDA_SUCCESS: i32 = 0;
/// Status code for a failed allocation
/// (presumably `cudaErrorMemoryAllocation` — verify against the CUDA headers).
pub const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;
/// cudaStreamCaptureMode::cudaStreamCaptureModeGlobal
///
/// This is the `mode` argument intended for `cudaStreamBeginCapture`.
pub const CUDA_STREAM_CAPTURE_MODE_GLOBAL: i32 = 0;
unsafe extern "C" {
// --- Device ---
pub fn cudaGetDeviceCount(count: *mut i32) -> i32;
@@ -44,6 +49,18 @@ unsafe extern "C" {
pub fn cudaGetLastError() -> i32;
pub fn cudaGetErrorString(error: i32) -> *const c_char;
// --- CUDA Graphs ---
// Every function below returns an `i32` status code; `CUDA_SUCCESS` (0) is
// the success value declared in this module.

/// Start recording work submitted to `stream`. `mode` is a capture-mode code
/// such as `CUDA_STREAM_CAPTURE_MODE_GLOBAL`.
pub fn cudaStreamBeginCapture(stream: CudaStream, mode: i32) -> i32;
/// Stop recording on `stream` and write the resulting graph handle to `*graph`.
pub fn cudaStreamEndCapture(stream: CudaStream, graph: *mut CudaGraph) -> i32;
/// Build a launchable instance of `graph` and write its handle to `*graph_exec`.
///
/// NOTE(review): this three-argument form with a 64-bit `flags` word is
/// presumably the CUDA 12 signature; older toolkits declared a different
/// parameter list — confirm against the toolkit this crate links.
pub fn cudaGraphInstantiate(
graph_exec: *mut CudaGraphExec,
graph: CudaGraph,
flags: u64,
) -> i32;
/// Submit the instantiated graph `graph_exec` for execution on `stream`.
pub fn cudaGraphLaunch(graph_exec: CudaGraphExec, stream: CudaStream) -> i32;
/// Destroy a graph handle obtained from `cudaStreamEndCapture`.
pub fn cudaGraphDestroy(graph: CudaGraph) -> i32;
/// Destroy an executable-graph handle obtained from `cudaGraphInstantiate`.
pub fn cudaGraphExecDestroy(graph_exec: CudaGraphExec) -> i32;
// --- Our test kernel ---
pub fn launch_vecadd_f32(
a: *const f32,

View File

@@ -0,0 +1,98 @@
//! CUDA Graphs: capture a sequence of kernel launches and replay them with
//! near-zero host-side overhead (~3-5 us per launch eliminated).
//!
//! Usage:
//! ```ignore
//! let stream = CudaStream::new()?;
//! let mut graph = CudaGraph::new();
//!
//! // First call: capture
//! graph.begin_capture(&stream)?;
//! // ... launch kernels on `stream` ...
//! graph.end_capture(&stream)?;
//!
//! // Subsequent calls: replay
//! graph.launch(&stream)?;
//! ```
//!
//! Requirements for captured kernels:
//! - All tensor shapes must be identical between capture and replay.
//! - No host-side branching during the captured section.
//! - Memory addresses used during capture must remain valid during replay.
use crate::error::{self, Result};
use crate::ffi;
use crate::stream::CudaStream;
/// RAII wrapper around a captured CUDA graph and its executable instance.
///
/// Both handles start out null and are filled in by `end_capture`. They are
/// released by `destroy_inner`, which runs at the start of every
/// `begin_capture` and when the wrapper is dropped.
pub struct CudaGraph {
/// Graph handle produced by `cudaStreamEndCapture`; null until a capture completes.
graph: ffi::CudaGraph,
/// Executable handle produced by `cudaGraphInstantiate`; this is what
/// `launch` replays, and `is_ready` reports whether it is non-null.
exec: ffi::CudaGraphExec,
}
impl CudaGraph {
/// Create an empty graph handle (not yet captured).
pub fn new() -> Self {
Self {
graph: std::ptr::null_mut(),
exec: std::ptr::null_mut(),
}
}
/// Returns true if a graph has been captured and instantiated.
pub fn is_ready(&self) -> bool {
!self.exec.is_null()
}
/// Begin capturing kernel launches on `stream`.
/// All subsequent kernel launches on this stream are recorded into the
/// graph instead of being executed.
pub fn begin_capture(&mut self, stream: &CudaStream) -> Result<()> {
// If we have an old graph, destroy it first
self.destroy_inner();
error::check(unsafe {
ffi::cudaStreamBeginCapture(
stream.as_raw(),
ffi::CUDA_STREAM_CAPTURE_MODE_GLOBAL,
)
})
}
/// End capture and instantiate the executable graph.
pub fn end_capture(&mut self, stream: &CudaStream) -> Result<()> {
error::check(unsafe {
ffi::cudaStreamEndCapture(stream.as_raw(), &mut self.graph)
})?;
error::check(unsafe {
ffi::cudaGraphInstantiate(&mut self.exec, self.graph, 0)
})
}
/// Replay the captured graph on `stream`.
/// Panics if no graph has been captured yet.
pub fn launch(&self, stream: &CudaStream) -> Result<()> {
assert!(self.is_ready(), "CudaGraph::launch called before capture");
error::check(unsafe {
ffi::cudaGraphLaunch(self.exec, stream.as_raw())
})
}
fn destroy_inner(&mut self) {
if !self.exec.is_null() {
unsafe { ffi::cudaGraphExecDestroy(self.exec) };
self.exec = std::ptr::null_mut();
}
if !self.graph.is_null() {
unsafe { ffi::cudaGraphDestroy(self.graph) };
self.graph = std::ptr::null_mut();
}
}
}
// Releases any graph handles still held when the wrapper goes out of scope.
impl Drop for CudaGraph {
fn drop(&mut self) {
// Same teardown that `begin_capture` runs before a re-capture; it does
// nothing when both handles are already null.
self.destroy_inner();
}
}
// SAFETY — NOTE(review): the raw-pointer fields make `CudaGraph` `!Send` by
// default; this impl asserts that the graph handles may be moved to, and then
// used and destroyed from, a different thread. Nothing in this file enforces
// that — confirm it against the CUDA runtime's threading rules.
unsafe impl Send for CudaGraph {}

View File

@@ -2,11 +2,13 @@ pub mod allocator;
pub mod device;
pub mod error;
pub mod ffi;
pub mod graph;
pub mod memory;
pub mod stream;
pub use allocator::CachingAllocator;
pub use device::DeviceInfo;
pub use error::{CudaError, Result};
pub use graph::CudaGraph;
pub use memory::{GpuBuffer, PinnedBuffer};
pub use stream::CudaStream;

View File

@@ -18,7 +18,7 @@ unsafe extern "C" {
fn dispatch_unary(x: &Tensor, f32_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void),
bf16_fn: unsafe extern "C" fn(*const c_void, *mut c_void, i32, *mut c_void)) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel() as i32;
unsafe {
match x.dtype() {
@@ -37,7 +37,7 @@ fn dispatch_binary(a: &Tensor, b: &Tensor,
assert!(a.is_contiguous() && b.is_contiguous());
assert!(matches!(a.device(), Device::Cuda(_)));
assert_eq!(a.dtype(), b.dtype());
let out = Tensor::zeros(a.shape(), a.dtype(), a.device());
let out = Tensor::empty(a.shape(), a.dtype(), a.device());
let n = a.numel() as i32;
unsafe {
match a.dtype() {
@@ -54,7 +54,7 @@ pub fn silu(x: &Tensor) -> Tensor { dispatch_unary(x, launch_silu_f32, launch_si
pub fn scale(x: &Tensor, scale_val: f32) -> Tensor {
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
let n = x.numel() as i32;
unsafe {
match x.dtype() {
@@ -76,7 +76,7 @@ pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Tensor {
assert!(gate.is_contiguous() && up.is_contiguous());
assert!(matches!(gate.device(), Device::Cuda(_)));
assert_eq!(gate.dtype(), DType::BF16, "silu_mul requires BF16");
let out = Tensor::zeros(gate.shape(), gate.dtype(), gate.device());
let out = Tensor::empty(gate.shape(), gate.dtype(), gate.device());
let n = gate.numel() as i32;
unsafe {
launch_silu_mul_bf16(

View File

@@ -105,7 +105,7 @@ pub fn decode_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Tensor {
let kv_len = k.shape()[2];
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::zeros(
let output = Tensor::empty(
&[batch, num_q_heads, 1, head_dim],
DType::BF16,
q.device(),
@@ -166,7 +166,7 @@ pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, causal: bool) -> Tens
}
let scale = 1.0 / (head_dim as f32).sqrt();
let output = Tensor::zeros(
let output = Tensor::empty(
&[batch, num_q_heads, q_len, head_dim],
DType::BF16,
q.device(),

View File

@@ -29,7 +29,7 @@ pub fn embedding(table: &Tensor, token_ids: &[u32]) -> Tensor {
let mut ids_gpu = GpuBuffer::alloc(ids_bytes.len()).expect("alloc token_ids");
ids_gpu.copy_from_host(ids_bytes).unwrap();
let out = Tensor::zeros(&[num_tokens, hidden_size], table.dtype(), table.device());
let out = Tensor::empty(&[num_tokens, hidden_size], table.dtype(), table.device());
unsafe {
match table.dtype() {

View File

@@ -98,7 +98,9 @@ pub fn matmul(a: &Tensor, b: &Tensor, backend: GemmBackend) -> Tensor {
let n = b.shape()[1];
let dtype = a.dtype();
let c = Tensor::zeros(&[m, n], dtype, a.device());
// All backends (naive, tiled, cuBLAS with beta=0, custom GEMV) fully
// overwrite every element of C, so we skip the cudaMemset.
let c = Tensor::empty(&[m, n], dtype, a.device());
let a_ptr = a.data_ptr() as *const c_void;
let b_ptr = b.data_ptr() as *const c_void;
@@ -202,7 +204,8 @@ pub fn batched_matmul(a: &Tensor, b: &Tensor) -> Tensor {
let mut out_shape: Vec<usize> = a.shape()[..ndim - 2].to_vec();
out_shape.push(m);
out_shape.push(n);
let c = Tensor::zeros(&out_shape, a.dtype(), a.device());
// cuBLAS with beta=0 fully overwrites every element of C.
let c = Tensor::empty(&out_shape, a.dtype(), a.device());
let dtype = a.dtype();
let (a_type, b_type, c_type) = match dtype {

View File

@@ -17,7 +17,7 @@ pub fn layernorm(x: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor
assert_eq!(beta.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {

View File

@@ -20,7 +20,7 @@ pub fn rmsnorm(x: &Tensor, gamma: &Tensor, eps: f32) -> Tensor {
assert_eq!(x.dtype(), gamma.dtype());
let rows = x.numel() / hidden_size;
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {
@@ -54,8 +54,8 @@ pub fn add_rmsnorm(x: &Tensor, residual: &Tensor, gamma: &Tensor, eps: f32) -> (
assert_eq!(gamma.shape(), &[hidden_size]);
let rows = x.numel() / hidden_size;
let normed_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
let sum_out = Tensor::zeros(x.shape(), DType::BF16, x.device());
let normed_out = Tensor::empty(x.shape(), DType::BF16, x.device());
let sum_out = Tensor::empty(x.shape(), DType::BF16, x.device());
unsafe {
launch_add_rmsnorm_bf16(

View File

@@ -14,7 +14,7 @@ pub fn softmax(x: &Tensor) -> Tensor {
let cols = *x.shape().last().unwrap();
let rows = x.numel() / cols;
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
unsafe {
match x.dtype() {

View File

@@ -21,7 +21,7 @@ unsafe extern "C" {
pub fn reshape_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_reshape_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
@@ -36,7 +36,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let hidden = num_heads * head_dim;
let out = Tensor::zeros(&[seq_len, hidden], DType::BF16, x.device());
let out = Tensor::empty(&[seq_len, hidden], DType::BF16, x.device());
unsafe {
launch_merge_heads_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
@@ -50,7 +50,7 @@ pub fn merge_heads_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: u
pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
let out = Tensor::empty(&[seq_len, num_heads, head_dim], DType::BF16, x.device());
unsafe {
launch_transpose_hsd_to_shd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
@@ -64,7 +64,7 @@ pub fn transpose_for_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head
pub fn transpose_from_rope_gpu(x: &Tensor, seq_len: usize, num_heads: usize, head_dim: usize) -> Tensor {
assert_eq!(x.dtype(), DType::BF16);
assert!(x.is_contiguous() && matches!(x.device(), Device::Cuda(_)));
let out = Tensor::zeros(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
let out = Tensor::empty(&[1, num_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_transpose_shd_to_hsd_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
@@ -83,7 +83,7 @@ pub fn repeat_kv_gpu(x: &Tensor, n_rep: usize) -> Tensor {
let seq_len = x.shape()[2];
let head_dim = x.shape()[3];
let new_heads = kv_heads * n_rep;
let out = Tensor::zeros(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
let out = Tensor::empty(&[1, new_heads, seq_len, head_dim], DType::BF16, x.device());
unsafe {
launch_repeat_kv_bf16(
x.data_ptr() as _, out.data_ptr() as *mut c_void,
@@ -112,7 +112,7 @@ pub fn strided_to_contiguous_gpu(x: &Tensor) -> Tensor {
strides4[pad + i] = x.strides()[i] as i32;
}
let out = Tensor::zeros(x.shape(), x.dtype(), x.device());
let out = Tensor::empty(x.shape(), x.dtype(), x.device());
// Use storage base pointer + element offset, because strides are relative to
// element 0 of the storage, not the data_ptr() (which already adds byte offset).

View File

@@ -116,4 +116,18 @@ impl Storage {
}
}
}
/// Hand out `len_bytes` of storage on `device` **without clearing it first**.
///
/// On CUDA the buffer comes straight from the caching allocator and may still
/// hold stale data from the pool, so the caller must guarantee that a kernel
/// fully overwrites every element before anything reads it. The CPU path
/// still returns zero-filled memory.
pub fn empty(len_bytes: usize, device: Device) -> CudaResult<Self> {
    let Device::Cuda(dev) = device else {
        // CPU still zeros (cheap)
        return Ok(Storage::cpu(vec![0u8; len_bytes]));
    };
    let buf = xserv_cuda::allocator::cached_alloc(len_bytes)?;
    Ok(Storage::cuda(buf, dev))
}
}

View File

@@ -65,6 +65,22 @@ impl Tensor {
}
}
/// Build a contiguous tensor of the given `shape` and `dtype` whose backing
/// memory is **not zeroed**.
///
/// The contents may be stale, so only use this when the calling kernel
/// fully overwrites every element before any read.
///
/// # Panics
/// Panics with "alloc failed" if the underlying storage allocation fails.
pub fn empty(shape: &[usize], dtype: DType, device: Device) -> Self {
    // Byte size = element count × element width.
    let len_bytes = shape::num_elements(shape) * dtype.size_bytes();
    Self {
        storage: Storage::empty(len_bytes, device).expect("alloc failed"),
        shape: Dims::from_slice(shape),
        strides: shape::contiguous_strides(shape),
        offset: 0,
        dtype,
    }
}
pub fn ones(shape: &[usize], dtype: DType) -> Self {
let numel = shape::num_elements(shape);
match dtype {