diff --git a/crates/xserv-tensor/src/shape.rs b/crates/xserv-tensor/src/shape.rs index 5f70dc6..bf89593 100644 --- a/crates/xserv-tensor/src/shape.rs +++ b/crates/xserv-tensor/src/shape.rs @@ -18,12 +18,21 @@ pub fn contiguous_strides(shape: &[usize]) -> Dims { } /// Check if the given strides represent contiguous (row-major) layout for the shape. +/// A stride mismatch on a dimension of size 1 is allowed because that +/// dimension is never stepped. pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { if shape.is_empty() { return true; } - let expected = contiguous_strides(shape); - strides == expected.as_slice() + let ndim = shape.len(); + let mut expected_stride = 1usize; + for d in (0..ndim).rev() { + if shape[d] != 1 && strides[d] != expected_stride { + return false; + } + expected_stride *= shape[d]; + } + true } /// Total number of elements given a shape. diff --git a/crates/xserv-tensor/src/tensor.rs b/crates/xserv-tensor/src/tensor.rs index a4aeeb1..33965f6 100644 --- a/crates/xserv-tensor/src/tensor.rs +++ b/crates/xserv-tensor/src/tensor.rs @@ -120,6 +120,21 @@ impl Tensor { } } + /// Zero-copy slice along `dim`: keeps elements `[start, start+len)`. + pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self { + assert!(dim < self.ndim()); + assert!(start + len <= self.shape[dim], "narrow out of bounds"); + let mut new_shape = self.shape.clone(); + new_shape[dim] = len; + Self { + storage: self.storage.clone(), + shape: new_shape, + strides: self.strides.clone(), + offset: self.offset + start * self.strides[dim], + dtype: self.dtype, + } + } + pub fn transpose(&self, dim0: usize, dim1: usize) -> Self { assert!(dim0 < self.ndim() && dim1 < self.ndim()); let mut new_shape = self.shape.clone();