tensor: add narrow() view and relax is_contiguous for size-1 dims
narrow(dim, start, len) creates a zero-copy slice along any dimension. is_contiguous() now ignores stride mismatches on dimensions of size 1, since those dimensions are never stepped. This avoids unnecessary GPU strided copies when slicing fused projection outputs at batch=1. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -18,12 +18,21 @@ pub fn contiguous_strides(shape: &[usize]) -> Dims {
|
||||
}
|
||||
|
||||
/// Check if the given strides represent contiguous (row-major) layout for the shape.
|
||||
/// A stride mismatch on a dimension of size 1 is allowed because that
|
||||
/// dimension is never stepped.
|
||||
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
|
||||
if shape.is_empty() {
|
||||
return true;
|
||||
}
|
||||
let expected = contiguous_strides(shape);
|
||||
strides == expected.as_slice()
|
||||
let ndim = shape.len();
|
||||
let mut expected_stride = 1usize;
|
||||
for d in (0..ndim).rev() {
|
||||
if shape[d] != 1 && strides[d] != expected_stride {
|
||||
return false;
|
||||
}
|
||||
expected_stride *= shape[d];
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Total number of elements given a shape.
|
||||
|
||||
@@ -120,6 +120,21 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Zero-copy slice along `dim`: keeps elements `[start, start+len)`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
|
||||
assert!(dim < self.ndim());
|
||||
assert!(start + len <= self.shape[dim], "narrow out of bounds");
|
||||
let mut new_shape = self.shape.clone();
|
||||
new_shape[dim] = len;
|
||||
Self {
|
||||
storage: self.storage.clone(),
|
||||
shape: new_shape,
|
||||
strides: self.strides.clone(),
|
||||
offset: self.offset + start * self.strides[dim],
|
||||
dtype: self.dtype,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
|
||||
assert!(dim0 < self.ndim() && dim1 < self.ndim());
|
||||
let mut new_shape = self.shape.clone();
|
||||
|
||||
Reference in New Issue
Block a user