tensor: add narrow() view and relax is_contiguous for size-1 dims

narrow(dim, start, len) creates a zero-copy slice along any dimension.
is_contiguous() now ignores stride mismatches on dimensions of size 1,
since those dimensions are never stepped. This avoids unnecessary GPU
strided copies when slicing fused projection outputs at batch=1.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Gahow Wang
2026-05-30 12:49:57 +08:00
parent c2362df1f1
commit 1ab6ca9c09
2 changed files with 26 additions and 2 deletions

View File

@@ -18,12 +18,21 @@ pub fn contiguous_strides(shape: &[usize]) -> Dims {
}
/// Check if the given strides represent contiguous (row-major) layout for the shape.
/// A stride mismatch on a dimension of size 1 is allowed because that
/// dimension is never stepped.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
if shape.is_empty() {
return true;
}
let expected = contiguous_strides(shape);
strides == expected.as_slice()
let ndim = shape.len();
let mut expected_stride = 1usize;
for d in (0..ndim).rev() {
if shape[d] != 1 && strides[d] != expected_stride {
return false;
}
expected_stride *= shape[d];
}
true
}
/// Total number of elements given a shape.

View File

@@ -120,6 +120,21 @@ impl Tensor {
}
}
/// Zero-copy slice along `dim`: keeps elements `[start, start+len)`.
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
assert!(dim < self.ndim());
assert!(start + len <= self.shape[dim], "narrow out of bounds");
let mut new_shape = self.shape.clone();
new_shape[dim] = len;
Self {
storage: self.storage.clone(),
shape: new_shape,
strides: self.strides.clone(),
offset: self.offset + start * self.strides[dim],
dtype: self.dtype,
}
}
pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
assert!(dim0 < self.ndim() && dim1 < self.ndim());
let mut new_shape = self.shape.clone();