diff --git a/crates/xtrain-cuda/src/ffi.rs b/crates/xtrain-cuda/src/ffi.rs index 5975471..e18e027 100644 --- a/crates/xtrain-cuda/src/ffi.rs +++ b/crates/xtrain-cuda/src/ffi.rs @@ -19,6 +19,7 @@ unsafe extern "C" { pub fn cudaMalloc(devptr: *mut *mut u8, size: usize) -> i32; pub fn cudaFree(devptr: *mut u8) -> i32; pub fn cudaMemcpy(dst: *mut u8, src: *const u8, count: usize, kind: i32) -> i32; + pub fn cudaMemset(devptr: *mut u8, value: i32, count: usize) -> i32; // --- Error --- pub fn cudaGetErrorString(error: i32) -> *const c_char; diff --git a/crates/xtrain-cuda/src/memory.rs b/crates/xtrain-cuda/src/memory.rs index 4575f98..672b34e 100644 --- a/crates/xtrain-cuda/src/memory.rs +++ b/crates/xtrain-cuda/src/memory.rs @@ -46,6 +46,12 @@ impl GpuBuffer { ffi::cudaMemcpy(dst.as_mut_ptr(), self.ptr, dst.len(), ffi::CUDA_MEMCPY_D2H) }) } + + /// Set every byte of the buffer to `value` on the device (no host copy). + /// Used to zero op-output buffers without a blocking H2D memcpy of zeros. + pub fn memset(&mut self, value: u8) -> Result<()> { + error::check(unsafe { ffi::cudaMemset(self.ptr, value as i32, self.len) }) + } } impl Drop for GpuBuffer { diff --git a/crates/xtrain-tensor/src/storage.rs b/crates/xtrain-tensor/src/storage.rs index ef162df..330b507 100644 --- a/crates/xtrain-tensor/src/storage.rs +++ b/crates/xtrain-tensor/src/storage.rs @@ -99,9 +99,11 @@ impl Storage { match device { Device::Cpu => Ok(Storage::cpu(vec![0u8; len_bytes])), Device::Cuda(dev) => { - // No device memset in T2: stage zeros from the host. + // Device-side memset (Phase T7): avoids a blocking H2D memcpy of a + // host zero buffer on every op-output allocation. On device memory, + // cudaMemset does not block the host and stays ordered on the default stream. 
let mut buf = GpuBuffer::alloc(len_bytes)?; - buf.copy_from_host(&vec![0u8; len_bytes])?; + buf.memset(0)?; Ok(Storage::cuda(buf, dev)) } } diff --git a/crates/xtrain-tensor/src/tensor.rs b/crates/xtrain-tensor/src/tensor.rs index 2d1411a..5ad55f5 100644 --- a/crates/xtrain-tensor/src/tensor.rs +++ b/crates/xtrain-tensor/src/tensor.rs @@ -153,7 +153,6 @@ impl Tensor { std::ptr::null_mut(), // default stream ); } - xtrain_cuda::device::synchronize().expect("scale kernel sync failed"); out } @@ -201,7 +200,6 @@ impl Tensor { 0.0, out.data_ptr() as *mut f32, ); - xtrain_cuda::device::synchronize().expect("matmul kernel sync failed"); out } @@ -229,7 +227,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("transpose kernel sync failed"); out } @@ -277,7 +274,6 @@ impl Tensor { 0.0, db.data_ptr() as *mut f32, ); - xtrain_cuda::device::synchronize().expect("matmul_backward sync failed"); (da, db) } @@ -301,7 +297,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("add sync failed"); out } @@ -319,7 +314,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("mul sync failed"); out } @@ -342,7 +336,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("add_bias sync failed"); out } @@ -363,7 +356,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("sum_rows sync failed"); out } @@ -390,7 +382,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("rms_norm sync failed"); (y, inv_rms) } @@ -427,7 +418,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("rms_norm_backward sync failed"); (dx, dgamma) } @@ -444,7 +434,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("silu sync failed"); out } @@ -462,7 +451,6 @@ impl Tensor { std::ptr::null_mut(), ); } - 
xtrain_cuda::device::synchronize().expect("silu_backward sync failed"); dx } @@ -485,7 +473,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("rope sync failed"); out } @@ -506,7 +493,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("rope_backward sync failed"); dx } @@ -525,7 +511,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("softmax sync failed"); out } @@ -545,7 +530,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("softmax_backward sync failed"); dx } @@ -571,7 +555,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("cross_entropy sync failed"); (probs, loss) } @@ -592,7 +575,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("cross_entropy_backward sync failed"); dx } @@ -640,7 +622,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("embedding sync failed"); out } @@ -660,7 +641,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("embedding_backward sync failed"); dtable } @@ -684,7 +664,6 @@ impl Tensor { std::ptr::null_mut(), ); } - xtrain_cuda::device::synchronize().expect("transpose_3d01 sync failed"); out }